# Example: Visualizing Output

Copyright © Scott Workman. 2025.

In [None]:
import _init_paths

In [None]:
import torch

import cvd
from data import HCODataset, HCOPreDataset

import imageio
import numpy as np
from matplotlib import pyplot as plt

%matplotlib inline

### Visualize output

In [None]:
# Name of the trained variant to visualize; it selects the log directory
# ../logs/<method>/ that the checkpoint is read from.
method = "refine_fuse"
assert method in ["refine_base", "refine_fuse", "ground"], "Invalid method"

base_dir = "../logs/{}/".format(method)
checkpoint_path = '{}lightning_logs/version_0/checkpoints/last.ckpt'.format(base_dir)

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
  device = torch.device("cuda:0")
else:
  device = torch.device("cpu")

# Restore the model from the last saved checkpoint, move it to the chosen
# device in float precision, and switch it to evaluation mode.
model = cvd.CVD.load_from_checkpoint(checkpoint_path)
model.to(device)
model.float()
model.eval()

# Emits a blank line; presumably here so that the cell's last expression is
# not model.eval(), whose echoed result would clutter the output -- confirm.
print()

In [None]:
# Show, for a handful of random validation samples, the ground-level image
# next to its label map and the model's prediction.
dataset = HCOPreDataset('val', zoom=16)

# Five random sample indices (drawn with replacement, so repeats can occur).
inds = np.random.randint(len(dataset), size=5)


def t2n(x):
  """Return tensor ``x`` as a NumPy array in host memory."""
  return x.cpu().numpy()


for ind in inds:
  inputs, targets = dataset[ind]
  im_ground, im_context = inputs
  label_ground, valid_ground, _, _ = targets

  # Inference only: disable autograd so no graph is built for the forward
  # pass. Each input gets a leading batch dimension of size one.
  with torch.no_grad():
    _, output = model([x.to(device).unsqueeze(0) for x in inputs])

  output = output.squeeze().detach().cpu().numpy()

  label_ground = t2n(label_ground)
  valid_ground = t2n(valid_ground)
  # Reorder axes for plotting (presumably channels-first -> channels-last).
  im_ground = t2n(im_ground).transpose(1, 2, 0)

  # Overwrite pixels flagged as invalid with the same constant in both maps
  # so they render identically in the label and the prediction.
  output[valid_ground == 0] = 1
  label_ground[valid_ground == 0] = 1

  # 20% / 98% quantiles of the label and of the prediction, for a quick
  # numeric comparison of their value ranges.
  print(np.quantile(label_ground, [.2, .98]), np.quantile(output, [.2, .98]))

  # Left to right: input image, label, prediction.
  plt.figure(figsize=(10, 10))
  plt.subplot(131)
  plt.imshow(im_ground)
  plt.axis('off')
  plt.subplot(132)
  plt.imshow(label_ground, 'gray_r', vmin=0)
  plt.axis('off')
  plt.subplot(133)
  # Share the label's upper colour limit so both panels use the same scale.
  plt.imshow(output, 'gray_r', vmin=0, vmax=label_ground.max())
  plt.axis('off')

  plt.show()

### Visualize output (estimating heights)

In [None]:
# Log directory of the "refine" run whose checkpoint is visualized below.
base_dir = "../logs/refine/"
checkpoint_path = '{}lightning_logs/version_0/checkpoints/last.ckpt'.format(base_dir)

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
  device = torch.device("cuda:0")
else:
  device = torch.device("cpu")

# Restore the model from the last saved checkpoint, move it to the chosen
# device in float precision, and switch it to evaluation mode.
model = cvd.CVD.load_from_checkpoint(checkpoint_path)
model.to(device)
model.float()
model.eval()

# Emits a blank line; presumably here so that the cell's last expression is
# not model.eval(), whose echoed result would clutter the output -- confirm.
print()

In [None]:
# Show, for a handful of random validation samples, the overhead image with
# its label and prediction, followed by the ground-level image with its
# label and prediction.
dataset = HCODataset('val', zoom=16)

# Five random sample indices (drawn with replacement, so repeats can occur).
inds = np.random.randint(len(dataset), size=5)


def t2n(x):
  """Return tensor ``x`` as a NumPy array in host memory."""
  return x.cpu().numpy()


for ind in inds:
  inputs, targets = dataset[ind]
  # Only the two images are used below; the full unpacking is kept because
  # it documents the layout of a dataset sample.
  im_ground, im_overhead, depth_overhead, pano_yaw, tilt_yaw, tilt_pitch, yaw, pitch, gsd = inputs
  label_ground, valid_ground, label_overhead, valid_overhead = targets

  # Inference only: disable autograd so no graph is built for the forward
  # pass. Each input gets a leading batch dimension of size one.
  with torch.no_grad():
    output_overhead, output = model([x.to(device).unsqueeze(0) for x in inputs])

  output = output.squeeze().detach().cpu().numpy()
  output_overhead = output_overhead.squeeze().detach().cpu().numpy()

  # Reorder axes for plotting (presumably channels-first -> channels-last).
  im_overhead = t2n(im_overhead.squeeze()).transpose(1, 2, 0)
  label_overhead = t2n(label_overhead)
  # Convert the overhead mask as well, so NumPy arrays are masked with NumPy
  # masks throughout (matching how valid_ground is handled).
  valid_overhead = t2n(valid_overhead)
  label_ground = t2n(label_ground)
  valid_ground = t2n(valid_ground)
  im_ground = t2n(im_ground).transpose(1, 2, 0)

  # Overwrite ground pixels flagged as invalid with the same constant in both
  # maps so they render identically in the label and the prediction.
  output[valid_ground == 0] = 1
  label_ground[valid_ground == 0] = 1

  # Invalid overhead predictions become NaN, which imshow leaves unpainted.
  output_overhead[valid_overhead == 0] = np.nan

  # 20% / 98% quantiles of the ground label and of the ground prediction,
  # for a quick numeric comparison of their value ranges.
  print(np.quantile(label_ground, [.2, .98]), np.quantile(output, [.2, .98]))

  # Left to right: overhead image, overhead label, overhead prediction,
  # ground image, ground label, ground prediction.
  plt.figure(figsize=(15, 15))
  plt.subplot(161)
  plt.imshow(im_overhead)
  plt.axis('off')
  plt.subplot(162)
  plt.imshow(label_overhead, vmin=0)
  plt.axis('off')
  plt.subplot(163)
  # Share the label's upper colour limit (ignoring NaNs) so the overhead
  # label and prediction use the same scale.
  plt.imshow(output_overhead, vmin=0, vmax=np.nanmax(label_overhead))
  plt.axis('off')
  plt.subplot(164)
  plt.imshow(im_ground)
  plt.axis('off')
  plt.subplot(165)
  plt.imshow(label_ground, 'gray_r', vmin=0)
  plt.axis('off')
  plt.subplot(166)
  # Share the label's upper colour limit so both ground panels match.
  plt.imshow(output, 'gray_r', vmin=0, vmax=label_ground.max())
  plt.axis('off')

  plt.show()