# Example: Visualizing Output

Copyright © Scott Workman. 2024.  

In [None]:
import _init_paths

In [None]:
import torch
from torch.utils.data import DataLoader

from lmm import LMM
from nets import ops
from data import DTSDataset

import os
import glob
import argparse
import matplotlib
import numpy as np
from matplotlib import pyplot as plt

In [None]:
# Experiment configuration, expressed as argparse options so the notebook
# shares its option names and defaults with a command-line entry point.
parser = argparse.ArgumentParser()
for flag, default in (
    ('--method', 'multitask'),
    ('--loss', 'student'),
    ('--decoder', 'mlp'),
    ('--save_dir', '../logs/'),
):
  parser.add_argument(flag, default=default, type=str)

# An empty argv makes parse_args() return the defaults instead of reading
# the notebook kernel's own command line.
args = parser.parse_args([])

# Job directory: save_dir is used as a raw prefix (no separator is inserted),
# followed by "geo_<decoder>_<loss>_<method>".
job_dir = f"{args.save_dir}geo_{args.decoder}_{args.loss}_{args.method}"
print(job_dir)

## Load Model

In [None]:
# Run on the first GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Look for epoch checkpoints under this job's lightning_logs/version_0 folder.
ckpt_pattern = f"{job_dir}/lightning_logs/version_0/checkpoints/epoch*.ckpt"

# glob returns matches in filesystem order; sort them so the same checkpoint
# is picked on every machine when more than one file matches.
ckpt_matches = sorted(glob.glob(ckpt_pattern))
if not ckpt_matches:
  # Fail with the searched pattern instead of a bare IndexError.
  raise FileNotFoundError(f"No checkpoint found matching '{ckpt_pattern}'")
ckpt_fname = ckpt_matches[0]

# The parsed options are forwarded as keyword arguments to the loader.
model = LMM.load_from_checkpoint(ckpt_fname, strict=True, **vars(args))
model.to(device)
model.eval()

# Emits a blank line; presumably here so the notebook cell does not echo the
# value of the last expression (the model) -- confirm before removing.
print()

## Speed Estimation

In [None]:
# 'test' split of the DTS dataset, built with dense=True and full=True.
# NOTE(review): the exact meaning of these two flags is defined in
# data.DTSDataset -- confirm there.
dts = DTSDataset('test', dense=True, full=True)

# One sample per batch and no shuffling, so the first `take` batches drawn by
# the visualization loop are the same on every run.
dataset = DataLoader(dts, batch_size=1, shuffle=False)

In [None]:
def make_sparse(speeds_dense, road_ids):
  """Replace per-pixel speeds with one aggregated speed per road segment.

  Args:
    speeds_dense: tensor of per-pixel speed values.
    road_ids: tensor of road-segment ids, used to index a tensor shaped like
      `speeds_dense` (so it must be compatible with it as a boolean mask).

  Returns:
    A float tensor shaped like `speeds_dense` with size-1 dimensions squeezed
    away. Every pixel of a segment holds that segment's aggregated speed;
    pixels not matched by any aggregated id stay 0.

  NOTE(review): assumes `ops.aggregate(values, ids)` yields one entry per
  segment, in the same order for both calls -- confirm in nets.ops.
  """
  segment_speeds = ops.aggregate(speeds_dense, road_ids)
  segment_ids = ops.aggregate(road_ids, road_ids)

  # start from an all-zero float canvas and paint one segment at a time
  painted = torch.zeros_like(speeds_dense).float()
  for value, raw_id in zip(segment_speeds, segment_ids):
    painted[road_ids == int(raw_id)] = value

  return painted.squeeze()

In [None]:
# number of test batches to visualize
take = 3


def t2n(x):
  """Return tensor `x` as a numpy array on the CPU with size-1 dims removed."""
  return x.cpu().squeeze().numpy()


def _show_speed_panel(slot, image, hidden, speeds, title):
  """Draw one subplot: `image` with a colour-coded speed overlay on top.

  Pixels where `hidden` is true are masked out of the overlay, so only the
  underlying image is visible there.
  """
  plt.subplot(slot)
  plt.imshow(image)
  plt.imshow(np.ma.masked_where(hidden, speeds),
             cmap="RdYlGn",
             vmin=25,
             vmax=85,
             alpha=.8,
             interpolation="none")
  plt.title(title)
  plt.axis("off")


# zipping with range(take) stops the loop after `take` batches
for batch_idx, data in zip(range(take), dataset):
  print(batch_idx)

  # move every tensor of the batch onto the model's device
  inputs, targets = [[y.to(device) for y in x] for x in data]
  im = inputs[0]
  tar_road, tar_speed, _, _, _, _, tar_road_id = targets

  with torch.no_grad():
    out_road, out_angle, out_speed = model(inputs)

    # with the "student" loss only channel 0 of the speed output is kept
    # (NOTE(review): presumably the predicted speed itself -- confirm)
    if args.loss == "student":
      out_speed = out_speed[:, 0, ...].unsqueeze(1)

  # one aggregated speed per ground-truth road segment
  out_speed_sparse = make_sparse(out_speed, tar_road_id)

  # convert to numpy; the image is transposed so the first axis comes last
  im = t2n(im).transpose(1, 2, 0)
  out_speed, out_speed_sparse, tar_speed, tar_road = map(
      t2n, (out_speed, out_speed_sparse, tar_speed, tar_road))

  # predictions are only drawn where the ground-truth road mask is non-zero
  off_road = tar_road == 0

  plt.figure(figsize=(15, 15))
  _show_speed_panel(131, im, tar_speed == 0, tar_speed, "Ground Truth")
  _show_speed_panel(132, im, off_road, out_speed_sparse,
                    "Prediction (Aggregated)")
  _show_speed_panel(133, im, off_road, out_speed, "Prediction (Dense)")
  plt.show()

## Road Segmentation & Orientation Estimation

In [None]:
# 'test' split of the DTS dataset with the constructor's default options
# (the speed section above passes dense=True, full=True instead).
dts = DTSDataset('test')

# One sample per batch and no shuffling, so the first `take` batches drawn by
# the visualization loop are the same on every run.
dataset = DataLoader(dts, batch_size=1, shuffle=False)

In [None]:
def get_dense_angles(input_mask, angle_bin_mask, road_id_mask):
  """Fill every road segment with its most frequent angle bin.

  Args:
    input_mask: array of non-negative integer angle bins to vote with.
    angle_bin_mask: array of the same shape; entries >= 0 mark the pixels
      whose `input_mask` value takes part in the vote.
    road_id_mask: array of the same shape holding road-segment ids. Ids
      below 1 are treated as "no segment". Ids are assumed to be
      integer-valued.

  Returns:
    An array with the shape and dtype of `road_id_mask`. All pixels of a
    segment carry the bin that occurs most often among the segment's valid
    pixels (ties go to the smallest bin). Pixels outside any segment, and
    segments without a single valid pixel, are -1.
  """
  valid = angle_bin_mask >= 0
  output_mask = np.zeros_like(road_id_mask) - 1

  # Visit only the ids that actually occur. Scanning every integer up to the
  # largest id builds a full-size mask per candidate, which is wasteful for
  # sparse or large ids. This also makes an empty mask a no-op.
  for idx in np.unique(road_id_mask):
    if idx < 1:
      continue
    mask_segment = road_id_mask == idx

    # intersect with valid points
    mask_points = mask_segment & valid
    if not mask_points.any():
      continue

    # most frequently occurring bin on this segment
    angle_max = np.argmax(np.bincount(input_mask[mask_points]))
    output_mask[mask_segment] = angle_max

  return output_mask

In [None]:
# number of test batches to visualize
take = 3


def t2n(x):
  """Return tensor `x` as a numpy array on the CPU with size-1 dims removed."""
  return x.cpu().squeeze().numpy()


def _show_road_panel(slot, road, title):
  """Draw a road mask as a grayscale subplot with a fixed 0..1 range."""
  plt.subplot(slot)
  plt.imshow(road, cmap='gray', vmin=0, vmax=1)
  plt.title(title)
  plt.axis("off")


def _show_orientation_panel(slot, road, angles, title):
  """Draw per-segment angle bins over the non-road background.

  The background layer shows `road` only where it is not positive; the
  angle layer hides every pixel whose bin is -1 (no bin assigned).
  """
  plt.subplot(slot)
  plt.imshow(np.ma.masked_where(road > 0, road),
             cmap="gray",
             interpolation="none")
  plt.imshow(np.ma.masked_where(angles == -1, angles),
             cmap="hsv",
             vmin=0,
             vmax=15,
             interpolation="none")
  plt.title(title)
  plt.axis("off")


# zipping with range(take) stops the loop after `take` batches
for batch_idx, data in zip(range(take), dataset):
  print(batch_idx)

  # move every tensor of the batch onto the model's device
  inputs, targets = [[y.to(device) for y in x] for x in data]
  im = inputs[0]
  tar_road, _, _, tar_angle_bin, _, _, tar_road_id = targets

  with torch.no_grad():
    out_road, out_angle, _ = model(inputs)
    # binarize the road output at 0.5
    out_road = out_road > .5
    # most probable angle bin along dimension 1
    out_angle_bin = torch.softmax(out_angle, dim=1).argmax(dim=1)

  # convert to numpy; the image is transposed so the first axis comes last
  im = t2n(im).transpose(1, 2, 0)
  tar_road, tar_road_id, tar_angle_bin, out_road, out_angle_bin = map(
      t2n, (tar_road, tar_road_id, tar_angle_bin, out_road, out_angle_bin))

  # get dense angles: both votes use the ground-truth bins to decide which
  # pixels are valid and the ground-truth ids to delimit segments
  tar_angle_dense = get_dense_angles(tar_angle_bin, tar_angle_bin, tar_road_id)
  out_angle_dense = get_dense_angles(out_angle_bin, tar_angle_bin, tar_road_id)

  plt.figure(figsize=(20, 20))

  plt.subplot(151)
  plt.imshow(im)
  plt.title("Image")
  plt.axis("off")

  _show_road_panel(152, tar_road, "Road (GT)")
  _show_road_panel(153, out_road, "Road (Pred)")
  _show_orientation_panel(154, tar_road, tar_angle_dense, "Orientation (GT)")
  _show_orientation_panel(155, tar_road, out_angle_dense, "Orientation (Pred)")

  plt.show()