In [1]:
import os
import cv2
import random
import scipy as sp
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
np.set_printoptions(suppress=True, precision=5)

# dl imports
import torch

## CARLA dataset

In [2]:
from config import GlobalConfig
from data import CARLA_Data

# Project configuration object; the same instance is passed to the dataset here.
config = GlobalConfig()

# Folder holding the recorded demo scenario.
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
root_dir = '/home/surya/Downloads/transfuser-2022/data/demo/scenario1/'

# Build the demo dataset for routeKey 'route0' and report how many samples it has.
demo_set = CARLA_Data(
    root=root_dir,
    config=config,
    routeKey='route0',
)
print(f"There are {len(demo_set)} samples in Demo dataset")

100%|██████████| 1/1 [00:00<00:00, 171.96it/s]
There are 98 samples in Demo dataset


Create PyTorch-style dataloaders

In [3]:
from torch.utils.data import DataLoader

# Unshuffled loader: batches of 2 samples are drawn in dataset order,
# fetched by 4 worker processes.
dataloader_demo = DataLoader(
    dataset=demo_set,
    batch_size=2,
    shuffle=False,
    num_workers=4,
)

In [4]:
# Pull a single batch from the loader to inspect what it yields.
sample_data = next(iter(dataloader_demo))
print(f"sample data is of type {type(sample_data)} and has following keys")

# One line per entry: the key followed by the shape of its batched value.
for key, value in sample_data.items():
    print(key, list(value.shape))

# The batch was only needed for inspection; drop the reference again.
del sample_data

sample data is of type <class 'dict'> and has following keys
rgb [2, 3, 160, 704]
bev [2, 160, 160]
depth [2, 160, 704]
semantic [2, 160, 704]
speed [2]
x_command [2]
y_command [2]
target_point [2, 2]
target_point_image [2, 1, 256, 256]
lidar [2, 2, 256, 256]
label [2, 20, 7]
ego_waypoint [2, 4, 2]


## Load pretrained model

In [5]:
from model import LidarCenterNet

# Use the first CUDA device when one is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Image and LiDAR branches are both built with the 'regnety_032' architecture string.
model = LidarCenterNet(
    config,
    device,
    config.backbone,
    image_architecture='regnety_032',
    lidar_architecture='regnety_032',
    use_velocity=False,
)
model.to(device);

# NOTE(review): the demo loop reads extra entries (e.g. 'attn_map') from the model
# output — presumably these are only produced with debug enabled; confirm in model.py.
model.config.debug = True
model.eval();

# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted
# source. The path is absolute and machine-specific.
checkpt = torch.load(
    '/home/surya/Downloads/transfuser-2022/model_ckpt/transfuser/transfuser_regnet032_seed1_39.pth',
    map_location=device,
)

# Kept as the cell's last expression so the key-matching report is displayed.
model.load_state_dict(checkpt)

<All keys matched successfully>

## Helper functions

In [6]:
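# Kept for reference only (not executed): a one-off rewrite of the checkpoint's
# state dict that strips the "module." prefix from every key, drops the
# '_model.lidar_encoder._model.stem.conv.weight' entry, and saves the result
# as 'transfuser_regnet032_seed1_39.pth'.
#
# from collections import OrderedDict
# new_state_dict = OrderedDict()

# for k,v in checkpt.items():
#     new_key = k.replace("module.", "")
#     if new_key != '_model.lidar_encoder._model.stem.conv.weight':
#         new_state_dict[new_key] = v
# torch.save(new_state_dict, 'transfuser_regnet032_seed1_39.pth')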

In [7]:
def convert_waypoints_to_image(waypoints, origin=(128, 256), pixels_per_unit=8, image_size=256):
    """Map ego-frame waypoints to integer pixel coordinates of the LiDAR/BEV image.

    The mapping is the same one applied to LiDAR points: waypoint column 1 becomes
    the image x (column) coordinate, and the negated waypoint column 0 becomes the
    image y (row) coordinate. Both are scaled by ``pixels_per_unit`` and shifted so
    that the waypoint (0, 0) lands on ``origin``.

    Args:
        waypoints: array-like of shape (N, 2). The input is never modified.
        origin: (x, y) pixel that waypoint (0, 0) maps to. Defaults to (128, 256).
        pixels_per_unit: pixels per waypoint unit. Defaults to 8.
            NOTE(review): presumably pixels per metre — confirm against the
            LiDAR rasterisation.
        image_size: inclusive upper bound used when clipping. Defaults to 256.
            The bound is inclusive, so ``image_size`` itself can be returned; that
            is fine for plotting but is one past the last valid index of an
            ``image_size``-wide array.

    Returns:
        Integer ndarray of shape (N, 2) holding (x, y) pixel coordinates, clipped
        to ``[0, image_size]``. Fractions are truncated toward zero before clipping.
    """
    # Work on a float64 copy so the caller's array is left untouched.
    points = np.array(waypoints, dtype=np.float64)

    # Image x comes from waypoint column 1; image y from the negated column 0
    # (the y component is negated exactly as is done for LiDAR points).
    cols = points[:, 1] * pixels_per_unit + origin[0]
    rows = -points[:, 0] * pixels_per_unit + origin[1]

    pixels = np.stack((cols, rows), axis=-1).astype(np.int32)
    return np.clip(pixels, 0, image_size)

In [8]:
def get_rotated_bbox(bbox):
    """Return the four corners and the centre of a rotated box in homogeneous coordinates.

    Args:
        bbox: sequence whose first five entries are ``x, y, w, h, yaw``
            (centre, extents, rotation in radians). Any further entries are ignored.

    Returns:
        Float ndarray of shape (5, 3): rows 0-3 are the corners in drawing order,
        row 4 is the centre. Columns are ``(x, y, 1)``.
    """
    x, y, w, h, yaw = bbox[:5]

    # Explicit float dtype: with integer inputs the array would otherwise be an
    # integer array and the in-place division below would raise a casting error.
    corners = np.array([[h,   w, 1],
                        [h,  -w, 1],
                        [-h, -w, 1],
                        [-h,  w, 1],
                        [0, 0, 1],
                        ], dtype=np.float64)

    # The height and width of the bounding box value was changed by this factor
    # during data collection. Fix that for future datasets and remove
    corners[:, :2] /= 2
    # Swap the two coordinate columns (y x -> x y); the fancy-indexed right-hand
    # side is a copy, so the assignment does not alias.
    corners[:, :2] = corners[:, [1, 0]]

    # Rotate by yaw about the box centre, then translate to (x, y).
    c, s = np.cos(yaw), np.sin(yaw)
    box_to_image = np.array([[c, -s, x], [s, c, y], [0, 0, 1]])
    return (box_to_image @ corners.T).T


def draw_bounding_box(ax, bbox, **kwargs):
    """Draw the outline of a box as four line segments.

    Args:
        ax: object with a matplotlib-style ``plot(xs, ys, **kwargs)`` method.
        bbox: array indexed as ``bbox[row, col]``; rows 0-3 are the corners in
            drawing order, column 0 is x and column 1 is y. Further rows are ignored.
        **kwargs: forwarded unchanged to every ``ax.plot`` call.

    Returns:
        List with the first element returned by each of the four ``ax.plot`` calls,
        in the order the edges 0-1, 1-2, 2-3, 3-0 were drawn.
    """
    # Consecutive corner pairs; the final pair closes the outline.
    edges = ((0, 1), (1, 2), (2, 3), (3, 0))
    return [
        ax.plot(
            [bbox[start, 0], bbox[end, 0]],
            [bbox[start, 1], bbox[end, 1]],
            **kwargs,
        )[0]
        for start, end in edges
    ]

In [9]:
def smooth_attention_map(attn_map, reshape_size):
    """Collapse, normalise, upsample and blur an attention map for overlay plotting.

    Steps: sum over axis 0, divide every row of the resulting 2-D map by that row's
    sum, resize with ``cv2.resize`` and apply a Gaussian blur with sigma 1 along
    both axes (``mode='constant'``).

    Args:
        attn_map: ndarray with at least 3 dimensions; axis 0 is summed away.
            NOTE(review): rows that sum to zero produce NaN/inf after the
            normalisation — presumably not an issue for softmax attention; confirm.
        reshape_size: passed to ``cv2.resize`` as ``dsize``, i.e. ``(width, height)``.

    Returns:
        The blurred, resized 2-D map.
    """
    # Use the public scipy.ndimage entry point: `scipy.ndimage.filters` is a
    # deprecated alias, and a bare `import scipy` does not reliably expose the
    # `ndimage` submodule as an attribute.
    from scipy.ndimage import gaussian_filter

    attn_map = attn_map.sum(axis=0)
    # Row-wise normalisation; keepdims keeps the divisor broadcastable per row.
    attn_map = attn_map / attn_map.sum(axis=1, keepdims=True)
    up_attn = cv2.resize(attn_map, reshape_size)
    return gaussian_filter(up_attn, [1, 1], mode='constant')

## Visualization class

In [10]:
class Visualizer:
    """Reusable five-panel matplotlib figure showing one frame of model output.

    Layout on a 3x3 grid: the RGB image spans the top row; depth and semantic
    panels are stacked in the first column below it; the BEV panel and the
    LiDAR/bounding-box panel each span the two lower rows of columns 2 and 3.
    """

    def __init__(self):
        """Create the figure, its five axes and the BEV colour lookup table."""
        self.fig = plt.figure(figsize=(20, 10))
        self.fig.suptitle("Transfuser output")
        self.fig.tight_layout()
        gs = self.fig.add_gridspec(3, 3, height_ratios=[2, 1, 1], width_ratios=[2, 1, 1])
        self.rgb_axes = self.fig.add_subplot(gs[0, :])
        self.depth_axes = self.fig.add_subplot(gs[1, 0])
        self.semantic_axes = self.fig.add_subplot(gs[2, 0])
        self.bev_axes = self.fig.add_subplot(gs[1:, 1])
        self.bbox_axes = self.fig.add_subplot(gs[1:, 2])

        # Lookup table indexed by BEV class id -> RGB colour.
        self.color_code = np.array([
            [0, 0, 0],       # black
            [128, 128, 128], # grey
            [255, 255, 0]    # yellow
        ])

    def plotData(self, rgb_image, depth_image, semantic_image, bev_image,
                 lidar_data, tgt_points, gt_boxes,
                 pred_points=None, pred_boxes=None,
                 img_attn_map=None, lidar_attn_map=None):
        """Redraw every panel for one frame.

        Args:
            rgb_image: image shown in the top panel.
            depth_image: image shown in the depth panel.
            semantic_image: image shown in the semantic panel.
            bev_image: integer class-index array; each entry indexes ``color_code``
                (so values must be 0, 1 or 2).
            lidar_data: image shown as background of the box panel.
            tgt_points: array indexed ``[:, 0]`` / ``[:, 1]`` for x / y; drawn as
                green markers.
            gt_boxes: iterable of corner arrays accepted by ``draw_bounding_box``;
                drawn in white.
            pred_points: optional, same layout as ``tgt_points``; drawn as red markers.
            pred_boxes: optional, same layout as ``gt_boxes``; drawn in cyan.
            img_attn_map: optional map overlaid on the RGB panel (alpha 0.3).
            lidar_attn_map: optional map overlaid on the box panel (alpha 0.3).
        """
        # clear previous data before plotting
        self.clearAxes()

        # rgb image with attention map
        self.rgb_axes.imshow(rgb_image)
        if img_attn_map is not None:
            # Overlay the map that was passed in, not a same-purpose global
            # variable that happens to exist in the notebook namespace.
            self.rgb_axes.imshow(img_attn_map, cmap='inferno', alpha=0.3)
        self.rgb_axes.set(xticks=[], yticks=[])

        # depth image
        self.depth_axes.imshow(depth_image)
        self.depth_axes.set(xticks=[], yticks=[])

        # semantic segmentation
        self.semantic_axes.imshow(semantic_image)
        self.semantic_axes.set(xticks=[], yticks=[])

        # BEV class prediction, mapped from class ids to colours
        self.bev_axes.imshow(self.color_code[bev_image])
        self.bev_axes.set(xticks=[], yticks=[])

        # lidar data with waypoints
        self.bbox_axes.imshow(lidar_data)
        self.bbox_axes.set(xticks=[], yticks=[])
        self.bbox_axes.plot(tgt_points[:, 0], tgt_points[:, 1], 'go', linewidth=3)
        if pred_points is not None:
            self.bbox_axes.plot(pred_points[:, 0], pred_points[:, 1], 'ro', linewidth=3)

        # bounding boxes
        for bbox in gt_boxes:
            draw_bounding_box(self.bbox_axes, bbox, color='white')

        if pred_boxes is not None:
            for bbox in pred_boxes:
                draw_bounding_box(self.bbox_axes, bbox, color='cyan')

        # lidar attention map
        if lidar_attn_map is not None:
            self.bbox_axes.imshow(lidar_attn_map, cmap='inferno', alpha=0.3)

        self.fig.subplots_adjust(wspace=0, hspace=0)

    def clearAxes(self):
        """Remove everything drawn on the five panels."""
        self.rgb_axes.clear()
        self.depth_axes.clear()
        self.semantic_axes.clear()
        self.bev_axes.clear()
        self.bbox_axes.clear()

    def saveFigure(self, outputPath):
        """Write the current figure to ``outputPath``.

        When ``outputPath`` is a filesystem path, its parent directory is created
        first so saving into a not-yet-existing folder does not fail. Other
        targets (e.g. file-like objects) are handed to ``savefig`` unchanged.
        """
        if isinstance(outputPath, (str, os.PathLike)):
            parent = os.path.dirname(os.fspath(outputPath))
            if parent:
                os.makedirs(parent, exist_ok=True)
        self.fig.savefig(outputPath)

## Demo video

In [11]:
%matplotlib agg
visualizer = Visualizer()

frameIdx = 0

for data in tqdm(dataloader_demo):

    # load data to device, according to type
    for k in ['rgb', 'depth', 'lidar', 'label', 'ego_waypoint', \
              'target_point', 'target_point_image', 'speed']:
        data[k] = data[k].to(device, torch.float32)
    for k in ['semantic', 'bev']:
        data[k] = data[k].to(device, torch.long)
    
    # get model predictions
    _, outputs = model(data)

    # iterate through each sample in batch 
    bs = data['rgb'].shape[0]
    for i in range(bs):
        # extract input data
        rgb_image = data['rgb'][i].permute(1, 2, 0).detach().cpu().numpy().astype(np.uint8)
        lidar_data = data['lidar'][i].detach().cpu().numpy().transpose(1,2,0)
        tgt_waypoints = data['ego_waypoint'][i].detach().cpu().numpy()
        tgt_waypoints_image = convert_waypoints_to_image(tgt_waypoints)
        gt_boxes = data['label'][i].detach().cpu().numpy()
        gt_boxes = gt_boxes[gt_boxes.sum(axis=-1) != 0.]
        rotated_bboxes_gt = [get_rotated_bbox(bbox)[:, :2] for bbox in gt_boxes]

        # extract model predictions
        pred_depth = outputs['pred_depth'][i]
        indices = np.argmax(outputs['pred_semantic'], axis=1)
        pred_semantic = np.array(config.classes_list)[indices[i, ...], ...].astype('uint8')        
        pred_bev = outputs['pred_bev'][i].argmax(axis=0).astype(np.uint8)
        pred_waypoints = outputs['pred_wp'][i]
        pred_waypoints_image = convert_waypoints_to_image(pred_waypoints)
        pred_boxes = outputs['detections'][i]
        rotated_bboxes_pred = [get_rotated_bbox(bbox)[:, :2] for bbox in pred_boxes]
        attn_map = np.sum(outputs['attn_map'][i], axis=0)
        image_attn_map = attn_map[0:110, 0:110].reshape(110, 5, 22)
        smooth_img_attn_map = smooth_attention_map(image_attn_map, reshape_size=(704, 160))
        lidar_attn_map = attn_map[110:, 110:].reshape(64, 8, 8)
        smooth_lidar_attn_map = smooth_attention_map(lidar_attn_map, reshape_size=(256, 256))

        # plot all data
        visualizer.plotData(rgb_image, depth_image = pred_depth, semantic_image=pred_semantic, 
            bev_image = pred_bev, lidar_data = lidar_data, tgt_points = tgt_waypoints_image, 
            gt_boxes = rotated_bboxes_gt, pred_points = pred_waypoints_image, 
            pred_boxes = rotated_bboxes_pred, 
            #img_attn_map = smooth_img_attn_map, 
            #lidar_attn_map = smooth_lidar_attn_map
       )

        # save figure
        visualizer.saveFigure(f"outputFolder/Frame{frameIdx}.png")
        frameIdx +=1

  topk_clses = topk_inds // (height * width)
  topk_ys = topk_inds // width
100%|██████████| 49/49 [01:57<00:00,  2.39s/it]
