## Introduction
This notebook is for the reproduction of paper: "A Deep Learning-based Radar and Camera Sensor Fusion Architecture for Object Detection" (https://arxiv.org/abs/2005.07431) using PyTorch. The original implementation of this project in Keras is avaiable  at https://github.com/TUMFTM/CameraRadarFusionNet.
A blog about this reproduction can be found at https://cutt.ly/QvaAsuu.

## Requirements
1. Pytorch
2. Nuscenes dev-kit (https://github.com/nutonomy/nuscenes-devkit)
3. Nuscenes dataset (https://www.nuscenes.org/nuscenes)
4. CRFNet (https://github.com/TUMFTM/CameraRadarFusionNet)(We mainly used some exsting utils for radar signal processing in the existing project)

In [1]:
# import pytorch packages
import torch
import torch.nn as nn
import torchvision.models as tv_models
import torchvision as tv
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.functional as TF
import torch.nn.functional as F

#Nuscenes dev-kit
from nuscenes.nuscenes import NuScenes
from nuscenes.utils.geometry_utils import box_in_image, view_points, BoxVisibility, points_in_box

#Libraries for data fusion
import sys
import math
import time
import cv2
import json
import numpy as np
from pyquaternion import Quaternion
from PIL import Image
from nuscenes.utils.data_classes import PointCloud
from crfnet.utils import radar
from crfnet.utils.nuscenes_helper import get_sensor_sample_data


In [3]:
# Data importing
#nusc = NuScenes(version='v1.0-mini', dataroot='F:\\CameraRadarFusionNet\\crfnet\\data\\nuscenes', verbose=True)
nusc = NuScenes(version='v1.0-mini', dataroot='E:\\Documents\\Robotics\\Deep_Learning\\Anaconda_CRF\\CameraRadarFusionNet\\crfnet\\data\\nuscenes', verbose=True)

Loading NuScenes tables for version v1.0-mini...
23 category,
8 attribute,
4 visibility,
911 instance,
12 sensor,
120 calibrated_sensor,
31206 ego_pose,
8 log,
10 scene,
404 sample,
31206 sample_data,
18538 sample_annotation,
4 map,
Done loading in 0.6 seconds.
Reverse indexing ...
Done reverse indexing in 0.1 seconds.


## Dataset
The construction of custome NuScenes includes 4 steps
1. Initialization of new class map and sample tokens
2. Acquire camera, radar and annotation data from each sample(by nuscenes-devkit)
3. Radar and camera data fusion
4. Transform 3D annotations in NuScenes to 2D annotation box

In [4]:
# Custome pytorch NuScenes dataset
class nuscenes_dataset(Dataset):       
    def __init__(self,nusc,
                category_mapping=None
                ,image_min_side= 360
                ,image_max_side = 640):
            self.nusc=nusc
            self.sample_tokens = {}
            self.image_data = dict()
            self.radar_sensors = ['RADAR_FRONT']
            self.camera_sensors = ['CAM_FRONT']
            self.only_radar_annotated = False
            self.normalize_bbox = False # True for normalizing the bbox to [0,1]
            self.image_min_side = image_min_side
            self.image_max_side = image_max_side
            self.normalize_bbox = False
            self.n_sweeps = 1
            self.classes, self.labels = self._get_class_label_mapping([c['name'] for c in nusc.category], category_mapping)
            prog = 0
            scene_indices = range(len(nusc.scene))
            self.class_map = {0: -1,1: 7,2: 7, 3: 7,4: 7,5: 7,6: 7,7: 7,8: -1,9: -1,10: -1,11: -1,12: -1,
                              13: 2,14: 3,15: 3,16: 0,17: -1,18: 5,19: 5,20: 1,21: 6,22: 4,23: -1}
            for scene_index in scene_indices:
                first_sample_token = nusc.scene[scene_index]['first_sample_token']
                nbr_samples = nusc.scene[scene_index]['nbr_samples']

                curr_sample = nusc.get('sample', first_sample_token)

                for _ in range(nbr_samples):
                    self.sample_tokens[prog] = curr_sample['token']
                    if curr_sample['next']:
                        next_token = curr_sample['next']
                        curr_sample = nusc.get('sample', next_token)
                    prog += 1    
        
    def __len__(self):
        return len(self.sample_tokens)
           
    def __getitem__(self, index):
        # get data 
        sample_token = self.sample_tokens[index]
        sample = nusc.get('sample',sample_token)
        image_data = get_sensor_sample_data(nusc, sample, self.camera_sensors[0])
        radar_data = get_sensor_sample_data(nusc, sample, self.radar_sensors[0])
        camera_token = sample['data'][self.camera_sensors[0]]
        radar_token = sample['data'][self.radar_sensors[0]]
        #nusc.render_sample_data(camera_token)
        
        # image fusing 
        height = (0,3)
        image_target_shape = (self.image_min_side, self.image_max_side)
        image_plus_data = imageplus_creation(self.nusc, image_data, radar_data, radar_token, camera_token, height, image_target_shape)
        image_plus_data = np.transpose(image_plus_data)
        image_plus_data = torch.tensor(image_plus_data)
        image_plus_data = torch.transpose(image_plus_data,2,1)
        image_plus_data = image_plus_data[(0,1,2,5,18) ,:,: ]
        image_plus_data[(0,1,2),:,:] = (image_plus_data[(0,1,2),:,:]-0.5)*255 # Normalization
        
        # annotation
        annotation_2D = self.create_annotations(sample_token,self.camera_sensors)
        annotation = list()
        n=0
        for label, box in zip(annotation_2D['labels'], annotation_2D['bboxes']):
             if self.class_map[label]>=0:
                label = self.class_map[label]
                annotation.append( torch.tensor(np.append(box, label)))
        if not annotation == []:
            annotation = torch.reshape(torch.cat(annotation,dim=0),(-1,5))
        else:
            annotation =  torch.tensor(annotation)
        return image_plus_data, annotation
    
    @staticmethod
    def _get_class_label_mapping(category_names, category_mapping):
        """
        :param category_mapping: [dict] Map from original name to target name. Subsets of names are supported. 
            e.g. {'pedestrian' : 'pedestrian'} will map all pedestrian types to the same label

        :returns: 
            [0]: [dict of (str, int)] mapping from category name to the corresponding index-number
            [1]: [dict of (int, str)] mapping from index number to category name
        """
        # Initialize local variables
        original_name_to_label = {}
        original_category_names = category_names.copy()
        original_category_names.append('bg')
        if category_mapping is None:
            # Create identity mapping and ignore no class
            category_mapping = dict()
            for cat_name in category_names:
                category_mapping[cat_name] = cat_name

        # List of unique class_names
        selected_category_names = set(category_mapping.values()) # unordered
        selected_category_names = list(selected_category_names)
        selected_category_names.sort() # ordered
      
        # Create the label to class_name mapping
        label_to_name = { label:name for label, name in enumerate(selected_category_names)}
        label_to_name[len(label_to_name)] = 'bg' # Add the background class

        # Create original class name to label mapping
        for label, label_name in label_to_name.items():

            # Looking for all the original names that are adressed by label name
            targets = [original_name for original_name in original_category_names if label_name in original_name]

            # Assigning the same label for all adressed targets
            for target in targets:
                
                # Check for ambiguity
                assert target not in original_name_to_label.keys(), 'ambigous mapping found for (%s->%s)'%(target, label_name)
                
                # Assign label to original name
                # Some label_names will have the same label, which is totally fine
                original_name_to_label[target] = label

        # Check for correctness
        actual_labels = original_name_to_label.values()
        expected_labels = range(0, max(actual_labels)+1) # we want to start labels at 0
        assert all([label in actual_labels for label in expected_labels]), 'Expected labels do not match actual labels'

        return original_name_to_label, label_to_name
    
    def create_annotations(self, sample_token, sensor_channels):
            """
            Create annotations for the the given sample token.

            1 bounding box vector contains:


            :param sample_token: the sample_token to get the annotation for
            :param sensor_channels: list of channels for cropping the labels, e.g. ['CAM_FRONT', 'RADAR_FRONT']
                This works only for CAMERA atm

            :returns: 
                annotations dictionary:
                {
                    'labels': [] # <list of n int>  
                    'bboxes': [] # <list of n x 4 float> [xmin, ymin, xmax, ymax]
                    'distances': [] # <list of n float>  Center of box given as x, y, z.
                    'visibilities': [] # <list of n float>  Visibility of annotated object
                }
            """

            if any([s for s in sensor_channels if 'RADAR' in s]):
                print("[WARNING] Cropping to RADAR is not supported atm")
                sensor_channels = [c for c in sensor_channels if 'CAM' in sensor_channels]

            sample = self.nusc.get('sample', sample_token)
            annotations_count = 0
            annotations = {
                'labels': [], # <list of n int>  
                'bboxes': [], # <list of n x 4 float> [xmin, ymin, xmax, ymax]
                'distances': [], # <list of n float>  Center of box given as x, y, z.
                'visibilities': [],
                'num_radar_pts':[] #<list of n int>  number of radar points that cover that annotation
                }

            # Camera parameters
            for selected_sensor_channel in sensor_channels:
                sd_rec = self.nusc.get('sample_data', sample['data'][selected_sensor_channel])

                # Create Boxes:
                _, boxes, camera_intrinsic = self.nusc.get_sample_data(sd_rec['token'], box_vis_level=BoxVisibility.ANY)
                imsize_src = (sd_rec['width'], sd_rec['height']) # nuscenes has (width, height) convention
                #print(type(boxes[0]))

                bbox_resize = [ 1. / sd_rec['height'], 1. / sd_rec['width'] ]
                if not self.normalize_bbox:
                    bbox_resize[0] *= float(self.image_min_side)
                    bbox_resize[1] *= float(self.image_max_side)

                # Create labels for all boxes that are visible
                for box in boxes:

                    # Add labels to boxes 
                    if box.name in self.classes:
                        box.label = self.classes[box.name]
                        # Check if box is visible and transform box to 1D vector
                        if box_in_image(box=box, intrinsic=camera_intrinsic, imsize=imsize_src, vis_level=BoxVisibility.ANY):

                            ## Points in box method for annotation filterS
                            # check if bounding box has an according radar point
                            if self.only_radar_annotated == 2:

                                pcs, times = RadarPointCloud.from_file_multisweep(self.nusc, sample, self.radar_sensors[0], \
                                    selected_sensor_channel, nsweeps=self.n_sweeps, min_distance=0.0, merge=False)

                                for pc in pcs:
                                    pc.points = radar.enrich_radar_data(pc.points)    

                                if len(pcs) > 0:
                                    radar_sample = np.concatenate([pc.points for pc in pcs], axis=-1)
                                else:
                                    print("[WARNING] only_radar_annotated=2 and sweeps=0 removes all annotations")
                                    radar_sample = np.zeros(shape=(len(radar.channel_map), 0))
                                radar_sample = radar_sample.astype(dtype=np.float32)

                                mask = points_in_box(box, radar_sample[0:3,:])
                                if True not in mask:
                                    continue 


                            # If visible, we create the corresponding label
                            box2d = box.box2d(camera_intrinsic) # returns [xmin, ymin, xmax, ymax]
                            box2d[0] *= bbox_resize[1]
                            box2d[1] *= bbox_resize[0]
                            box2d[2] *= bbox_resize[1]
                            box2d[3] *= bbox_resize[0]

                            annotations['bboxes'].insert(annotations_count, box2d)
                            annotations['labels'].insert(annotations_count, box.label)
                            annotations['num_radar_pts'].insert(annotations_count, self.nusc.get('sample_annotation', box.token)['num_radar_pts'])

                            distance =  (box.center[0]**2 + box.center[1]**2 + box.center[2]**2)**0.5
                            annotations['distances'].insert(annotations_count, distance)
                            annotations['visibilities'].insert(annotations_count, int(self.nusc.get('sample_annotation', box.token)['visibility_token']))
                            annotations_count += 1
                    else:
                        # The current name has been ignored
                        pass

            annotations['labels'] = np.array(annotations['labels'])
            annotations['bboxes'] = np.array(annotations['bboxes'])
            annotations['distances'] = np.array(annotations['distances'])
            annotations['num_radar_pts'] = np.array(annotations['num_radar_pts'])
            annotations['visibilities'] = np.array(annotations['visibilities'])

            # num_radar_pts mathod for annotation filter
            if self.only_radar_annotated == 1:

                anns_to_keep = np.where(annotations['num_radar_pts'])[0]

                for key in annotations:
                    annotations[key] = annotations[key][anns_to_keep]

            return annotations
    
#Data Fusion Functions

def _resize_image(image_data, target_shape):
    """
    Perfomrs resizing of the image and calculates a matrix to adapt the intrinsic camera matrix
    :param image_data: [np.array] with shape (height x width x 3)
    :param target_shape: [tuple] with (width, height)
    :return resized image: [np.array] with shape (height x width x 3)
    :return resize matrix: [numpy array (3 x 3)]
    """
    # print('resized', type(image_data))
    stupid_confusing_cv2_size_because_width_and_height_are_in_wrong_order = (target_shape[1], target_shape[0])
    resized_image = cv2.resize(image_data, stupid_confusing_cv2_size_because_width_and_height_are_in_wrong_order)
    resize_matrix = np.eye(3, dtype=resized_image.dtype)
    resize_matrix[1, 1] = target_shape[0]/image_data.shape[0]
    resize_matrix[0, 0] = target_shape[1]/image_data.shape[1]
    return resized_image, resize_matrix

def _radar_transformation(radar_data, height=None):
    """
    Transforms the given radar data with height z = 0 and another height as input using extrinsic radar matrix to vehicle's co-sy
    This function appends the distance to the radar point.
    Parameters:
    :param radar_data: [numpy array] with radar parameter (e.g. velocity) in rows and radar points for one timestep in columns
        Semantics: x y z dyn_prop id rcs vx vy vx_comp vy_comp is_quality_valid ambig_state x_rms y_rms invalid_state pdh0 distance
    :param radar_extrinsic: [numpy array (3x4)] that consists of the extrinsic parameters of the given radar sensor
    :param height: [tuple] (min height, max height) that defines the (unknown) height of the radar points
    Returns:
    :returns radar_data: [numpy array (m x no of points)] that consists of the transformed radar points with z = 0
    :returns radar_xyz_endpoint: [numpy array (3 x no of points)] that consits of the transformed radar points z = height  
    """

    # Field of view (global)
    ELEVATION_FOV_SR = 20
    ELEVATION_FOV_FR = 14  

    # initialization
    num_points = radar_data.shape[1]

    # Radar points for the endpoint
    radar_xyz_endpoint = radar_data[0:3,:].copy()

    # variant 1: constant height substracted by RADAR_HEIGHT
    RADAR_HEIGHT = 0.5
    if height:
        radar_data[2, :] = np.ones((num_points,)) * (height[0] - RADAR_HEIGHT) # lower points
        radar_xyz_endpoint[2, :] = np.ones((num_points,)) * (height[1] - RADAR_HEIGHT) # upper points

    # variant 2: field of view
    else:
        dist = radar_data[-1,:]
        count = 0
        for d in dist:
            # short range mode
            if d <= 70: 
                radar_xyz_endpoint[2, count] = -d * np.tan(ELEVATION_FOV_SR/2)

            # long range mode
            else:
                radar_xyz_endpoint[2, count] = -d * np.tan(ELEVATION_FOV_FR/2)

            count += 1

    return radar_data, radar_xyz_endpoint

def _create_line(P1, P2, img):
    """
    Produces and array that consists of the coordinates and intensities of each pixel in a line between two points
    :param P1: [numpy array] that consists of the coordinate of the first point (x,y)
    :param P2: [numpy array] that consists of the coordinate of the second point (x,y)
    :param img: [numpy array] the image being processed
    :return itbuffer: [numpy array] that consists of the coordinates and intensities of each pixel in the radii (shape: [numPixels, 3], row = [x,y])     
    """
    # define local variables for readability
    imageH = img.shape[0]
    imageW = img.shape[1]

    P1X = P1[0]
    P1Y = P1[1]
    P2X = P2[0]
    P2Y = P2[1]

    # difference and absolute difference between points
    # used to calculate slope and relative location between points
    dX = P2X - P1X
    dY = P2Y - P1Y
    dXa = np.abs(dX)
    dYa = np.abs(dY)

    # predefine numpy array for output based on distance between points
    itbuffer = np.empty(
        shape=(np.maximum(int(dYa), int(dXa)), 2), dtype=np.float32)
    itbuffer.fill(np.nan)

    # Obtain coordinates along the line using a form of Bresenham's algorithm
    negY = P1Y > P2Y
    negX = P1X > P2X
    if P1X == P2X:  # vertical line segment
        itbuffer[:, 0] = P1X
        if negY:
            itbuffer[:, 1] = np.arange(P1Y - 1, P1Y - dYa - 1, -1)
        else:
            itbuffer[:, 1] = np.arange(P1Y+1, P1Y+dYa+1)
    elif P1Y == P2Y:  # horizontal line segment
        itbuffer[:, 1] = P1Y
        if negX:
            itbuffer[:, 0] = np.arange(P1X-1, P1X-dXa-1, -1)
        else:
            itbuffer[:, 0] = np.arange(P1X+1, P1X+dXa+1)
    else:  # diagonal line segment
        steepSlope = dYa > dXa
        if steepSlope:
            slope = dX.astype(np.float32)/dY.astype(np.float32)
            if negY:
                itbuffer[:, 1] = np.arange(P1Y-1, P1Y-dYa-1, -1)
            else:
                itbuffer[:, 1] = np.arange(P1Y+1, P1Y+dYa+1)
            itbuffer[:, 0] = (slope*(itbuffer[:, 1]-P1Y)).astype(np.int) + P1X
        else:
            slope = dY.astype(np.float32)/dX.astype(np.float32)
            if negX:
                itbuffer[:, 0] = np.arange(P1X-1, P1X-dXa-1, -1)
            else:
                itbuffer[:, 0] = np.arange(P1X+1, P1X+dXa+1)
            itbuffer[:, 1] = (slope*(itbuffer[:, 0]-P1X)).astype(np.int) + P1Y

    # Remove points outside of image
    colX = itbuffer[:, 0].astype(int)
    colY = itbuffer[:, 1].astype(int)
    itbuffer = itbuffer[(colX >= 0) & (colY >= 0) &
                        (colX < imageW) & (colY < imageH)]

    return itbuffer

def _create_vertical_line(P1, P2, img):
    """
    Produces and array that consists of the coordinates and intensities of each pixel in a line between two points
    :param P1: [numpy array] that consists of the coordinate of the first point (x,y)
    :param P2: [numpy array] that consists of the coordinate of the second point (x,y)
    :param img: [numpy array] the image being processed
    :return itbuffer: [numpy array] that consists of the coordinates and intensities of each pixel in the radii (shape: [numPixels, 3], row = [x,y])     
    """
    # define local variables for readability
    imageH = img.shape[0]
    imageW = img.shape[1]

    # difference and absolute difference between points
    # used to calculate slope and relative location between points
    P1_y = int(P1[1])
    P2_y = int(P2[1])
    dX = 0
    dY = P2_y - P1_y
    if dY == 0:
        dY = 1
    dXa = np.abs(dX)
    dYa = np.abs(dY)

    # predefine numpy array for output based on distance between points
    itbuffer = np.empty(
        shape=(np.maximum(int(dYa), int(dXa)), 2), dtype=np.float32)
    itbuffer.fill(np.nan)

    # vertical line segment
    itbuffer[:, 0] = int(P1[0])
    if P1_y > P2_y:
        # Obtain coordinates along the line using a form of Bresenham's algorithm
        itbuffer[:, 1] = np.arange(P1_y - 1, P1_y - dYa - 1, -1)
    else:
        itbuffer[:, 1] = np.arange(P1_y+1, P1_y+dYa+1)

    # Remove points outside of image
    colX = itbuffer[:, 0].astype(int)
    colY = itbuffer[:, 1].astype(int)
    itbuffer = itbuffer[(colX >= 0) & (colY >= 0) &
                        (colX < imageW) & (colY < imageH)]

    return itbuffer

def _radar2camera(image_data, radar_data, radar_xyz_endpoints, clear_radar=False):
    """

    Calculates a line of two radar points and puts the radar_meta data as additonal layers to the image -> image_plus
    :param image_data: [numpy array (900 x 1600 x 3)] of image data
    :param radar_data: [numpy array (xyz+meta x no of points)] that consists of the transformed radar points with z = 0
        default semantics: x y z dyn_prop id rcs vx vy vx_comp vy_comp is_quality_valid ambig_state x_rms y_rms invalid_state pdh0 vx_rms vy_rms distance
    :param radar_xyz_endpoints: [numpy array (3 x no of points)] that consits of the transformed radar points z = height
    :param clear_radar: [boolean] True if radar data should be all zero
    :return image_plus: a numpy array (900 x 1600 x (3 + number of radar_meta (e.g. velocity)))
    """

    radar_meta_count = radar_data.shape[0]-3
    radar_extension = np.zeros(
        (image_data.shape[0], image_data.shape[1], radar_meta_count), dtype=np.float32)
    no_of_points = radar_data.shape[1]

    if clear_radar:
        pass # we just don't add it to the image
    else:
        for radar_point in range(0, no_of_points):
            projection_line = _create_vertical_line(
                radar_data[0:2, radar_point], radar_xyz_endpoints[0:2, radar_point], image_data)

            for pixel_point in range(0, projection_line.shape[0]):
                y = projection_line[pixel_point, 1].astype(int)
                x = projection_line[pixel_point, 0].astype(int)

                # Check if pixel is already filled with radar data and overwrite if distance is less than the existing
                if not np.any(radar_extension[y, x]) or radar_data[-1, radar_point] < radar_extension[y, x, -1]:
                    radar_extension[y, x] = radar_data[3:, radar_point]


    image_plus = np.concatenate((image_data, radar_extension), axis=2)

    return image_plus

def view_points(points: np.ndarray, view: np.ndarray, normalize: bool):
    """
    This function is a modification of nuscenes.geometry_utils.view_points function
    This is a helper class that maps 3d points to a 2d plane. It can be used to implement both perspective and
    orthographic projections. It first applies the dot product between the points and the view. By convention,
    the view should be such that the data is projected onto the first 2 axis. It then optionally applies a
    normalization along the third dimension.
    For a perspective projection the view should be a 3x3 camera matrix, and normalize=True
    For an orthographic projection with translation the view is a 3x4 matrix and normalize=False
    For an orthographic projection without translation the view is a 3x3 matrix (optionally 3x4 with last columns
     all zeros) and normalize=False
    :param points: <np.float32: 3, n> Matrix of points, where each point (x, y, z) is along each column.
    :param view: <np.float32: n, n>. Defines an arbitrary projection (n <= 4).
        The projection should be such that the corners are projected onto the first 2 axis.
    :param normalize: Whether to normalize the remaining coordinate (along the third axis).
    :return: <np.float32: 3, n>. Mapped point. If normalize=False, the third coordinate is the height.
    """

    output = points

    assert view.shape[0] <= 4
    assert view.shape[1] <= 4
    assert points.shape[0] >= 3
    points = output[0:3,:]

    viewpad = np.eye(4)
    viewpad[:view.shape[0], :view.shape[1]] = view

    nbr_points = points.shape[1]

    # Do operation in homogenous coordinates
    points = np.concatenate((points, np.ones((1, nbr_points))))
    points = np.dot(viewpad, points)
    points = points[:3, :]

    if normalize:
        points = points / points[2:3, :].repeat(3, 0).reshape(3, nbr_points)

    output[0:3,:] = points
    return output

def map_pointcloud_to_image(nusc, radar_points, pointsensor_token, camera_token, target_resolution=(None,None)):
    """
    Given a point sensor (lidar/radar) token and camera sample_data token, load point-cloud and map it to the image
    plane.
    :param radar_pints: [list] list of radar points
    :param pointsensor_token: [str] Lidar/radar sample_data token.
    :param camera_token: [str] Camera sample_data token.
    :param target_resolution: [tuple of int] determining the output size for the radar_image. None for no change
    :return (points <np.float: 2, n)
    """

    # Initialize the database
    cam = nusc.get('sample_data', camera_token)
    pointsensor = nusc.get('sample_data', pointsensor_token)

    pc = PointCloud(radar_points)

    # Points live in the point sensor frame. So they need to be transformed via global to the image plane.
    # First step: transform the point-cloud to the ego vehicle frame for the timestamp of the sweep.
    cs_record = nusc.get('calibrated_sensor', pointsensor['calibrated_sensor_token'])
    pc.rotate(Quaternion(cs_record['rotation']).rotation_matrix)
    pc.translate(np.array(cs_record['translation']))

    # Second step: transform to the global frame.
    poserecord = nusc.get('ego_pose', pointsensor['ego_pose_token'])
    pc.rotate(Quaternion(poserecord['rotation']).rotation_matrix)
    pc.translate(np.array(poserecord['translation']))

    # Third step: transform into the ego vehicle frame for the timestamp of the image.
    poserecord = nusc.get('ego_pose', cam['ego_pose_token'])
    pc.translate(-np.array(poserecord['translation']))
    pc.rotate(Quaternion(poserecord['rotation']).rotation_matrix.T)

    # Fourth step: transform into the camera.
    cs_record = nusc.get('calibrated_sensor', cam['calibrated_sensor_token'])
    pc.translate(-np.array(cs_record['translation']))
    pc.rotate(Quaternion(cs_record['rotation']).rotation_matrix.T)

    # Fifth step: actually take a "picture" of the point cloud.
    # Grab the depths (camera frame z axis points away from the camera).

    # intrinsic_resized = np.matmul(camera_resize, np.array(cs_record['camera_intrinsic']))
    view = np.array(cs_record['camera_intrinsic'])
    # Take the actual picture (matrix multiplication with camera-matrix + renormalization).
    points = view_points(pc.points, view, normalize=True) #resize here

    # Resizing to target resolution
    if target_resolution[1]: # resizing width
        points[0,:] *= (target_resolution[1]/cam['width'])

    if target_resolution[0]: # resizing height
        points[1,:] *= (target_resolution[0]/cam['height'])

    # actual_resolution = (cam['height'], cam['width'])
    # for i in range(len(target_resolution)):
    #     if target_resolution[i]:
    #         points[i,:] *= (target_resolution[i]/actual_resolution[i])

    return points

def create_spatial_point_array(nusc, radar_data, pointsensor_token, camera_token, target_width=None):
    """
    This function turns a radar point cloud into a 1-D array by encoding the spatial information.
    The position in the array reflects the direction of the radar point with respect to a camera.
    :param nusc: [nuscenes.nuscenes.Nuscenes] nuScenes database
    :param target_width: [int] the target resolution along x-axis for the output array
    :param dim: dimensionality of the target array
    """
    ##########################
    ##### Initialization #####
    ##########################
    radar_meta_count = radar_data.shape[0] - 3 # -3 for substracting the image positions x y z 
    img_data = nusc.get('sample_data', camera_token)
    target_width = target_width or img_data['width']
    target_resolution = (1, target_width)
    radar_array = np.zeros((*target_resolution, radar_meta_count))

    ######################################
    ##### Perform the array creation #####
    ######################################
    # Get radar points with x and y coordinates
    projected_radar_points = map_pointcloud_to_image(nusc, radar_data, pointsensor_token=pointsensor_token, \
        camera_token=camera_token, target_resolution=target_resolution)

    for i in range(projected_radar_points.shape[1]):
        x,y = projected_radar_points[0:2,i].astype(np.int32) # first 
        if x < 0 or x >= target_width:
            continue # we skip this point, because it lies outside of the image
        y = 0 # Set height to zero in case the point is outside of the image
        radar_array[y,x] = projected_radar_points[3:,i]


    ################################
    ##### Postprocess the data #####
    ################################
    # Remove x,y,z from radar data
    # radar_array = radar_array[3:,:]

    return radar_array

def imageplus_creation(nusc, image_data, radar_data, pointsensor_token, camera_token, height=(0,3),image_target_shape=(900, 1600), clear_radar=False, clear_image=False):
    """
    Superordinate function that creates image_plus data of raw camera and radar data
    :param nusc: nuScenes initialization
    :param image_data: [numpy array] (900 x 1600 x 3)
    :param radar_data: [numpy array](SHAPE?) with radar parameter (e.g. velocity) in rows and radar points for one timestep in columns
        Semantics:
            [0]: x (1)
            [1]: y (2)
            [2]: z (3)
            [3]: dyn_prop (4)
            [4]: id (5)
            [5]: rcs (6)
            [6]: vx (7)
            [7]: vy (8)
            [8]: vx_comp (9)
            [9]: vy_comp (10)
            [10]: is_quality_valid (11)
            [11]: ambig_state (12)
            [12]: x_rms (13)
            [13]: y_rms (14)
            [14]: invalid_state (15)
            [15]: pdh0 (16)
            [16]: vx_rms (17)
            [17]: vy_rms (18)
            [18]: distance (19)
    :param pointsensor_token: [str] token of the pointsensor that should be used, most likely radar
    :param camera_token: [str] token of the camera sensor
    :param height: 2 options for 2 different modi
            a.) [tuple] (e.g. height=(0,3)) to define lower and upper boundary
            b.) [str] height = 'FOV' for calculating the heights after the field of view of the radar
    :param image_target_shape: [tuple] with (height, width), default is (900, 1600)
    :param clear_radar: [boolean] True if radar data should be all zero
    :param clear_image: [boolean] True if image data should be all zero
    :returns: [tuple] image_plus, image
        -image_plus: [numpy array] (900 x 1600 x (3 + number of radar_meta (e.g. velocity)))
           Semantics:
            [0]: R (1)
            [1]: G (2)
            [2]: B (3)
            [3]: dyn_prop (4)
            [4]: id (5)
            [5]: rcs (6)
            [6]: vx (7)
            [7]: vy (8)
            [8]: vx_comp (9)
            [9]: vy_comp (10)
            [10]: is_quality_valid (11)
            [11]: ambig_state (12)
            [12]: x_rms (13)
            [13]: y_rms (14)
            [14]: invalid_state (15)
            [15]: pdh0 (16)
            [16]: vx_rms (17)
            [17]: vy_rms (18)
            [18]: distance (19)
        -cur_image: [numpy array] the original, resized image
    """

    ###############################
    ##### Preprocess the data #####
    ###############################
    # enable barcode method
    barcode = False
    if height[1] > 20:
        height = (0,1)
        barcode = True

    # Resize the image due to a target shape
    cur_img, camera_resize = _resize_image(image_data, image_target_shape)

    # Get radar points with the desired height and radar meta data
    radar_points, radar_xyz_endpoint = _radar_transformation(radar_data, height)

    #######################
    ##### Filter Data #####
    #######################
    # Clear the image if clear_image is True
    if clear_image: 
        cur_img.fill(0)

    #####################################
    ##### Perform the actual Fusion #####
    #####################################
    # Map the radar points into the image
    radar_points = map_pointcloud_to_image(nusc, radar_points, pointsensor_token=pointsensor_token, camera_token=camera_token, target_resolution=image_target_shape)
    radar_xyz_endpoint = map_pointcloud_to_image(nusc, radar_xyz_endpoint, pointsensor_token=pointsensor_token, camera_token=camera_token, target_resolution=image_target_shape)

    if barcode:
        radar_points[1,:] = image_data.shape[0]
        radar_xyz_endpoint[1,:] = 0

    # Create image plus by creating projection lines and store them as additional channels in the image
    image_plus = _radar2camera(cur_img, radar_points, radar_xyz_endpoint, clear_radar=clear_radar)

    #########################
    ##### Quality Check #####
    #########################
    # Check if clear_image worked
    # if clear_image and np.count_nonzero(image_plus[0:3]):
    #     print("Clearing image did not work")

    return image_plus

def imageplus_creation_camra(image_data, radar_data, calibrator, height=(0,3), \
        image_target_shape=(800, 1280)):

    ratio = [image_target_shape[0] / image_data.shape[0], image_target_shape[1] / image_data.shape[1]] 
    image_data, _  = _resize_image(image_data, image_target_shape)

    image_data = image_data/255

    x,y,z = radar_data[0:3]

    ## Bottom point of projection line
    z = np.ones(x.shape) *(height[0]+0.5)

    # radar points according to world2cam convention
    radar_points = [z,y,-x]
    cam_points_low = np.array(calibrator.world2cam(radar_points))
    cam_points_low = np.array([ratio[1] * cam_points_low[0], ratio[0] * cam_points_low[1]]).astype(np.uint16)


    ## Ceiling point of projection line
    z = np.ones(x.shape) *(- height[1] +0.5)

    # radar points according to world2cam convention
    radar_points = [z,y,-x]
    cam_points_high = np.array(calibrator.world2cam(radar_points))
    cam_points_high = np.array([ratio[1] * cam_points_high[0], ratio[0] * cam_points_high[1]]).astype(np.uint16)

    # Prevent errors in projection where the high point is lower than the low point
    points_to_keep = cam_points_high[1,:] < cam_points_low[1,:]
    cam_points_high = cam_points_high[:, points_to_keep]
    cam_points_low = cam_points_low[:, points_to_keep]


    radar_meta_count = radar_data.shape[0]-3
    radar_extension = np.zeros((image_data.shape[0], image_data.shape[1], radar_meta_count), dtype=np.float32)
    no_of_points = cam_points_low.shape[1]

    for radar_point in range(0, no_of_points):
        projection_line = _create_vertical_line(
            cam_points_low[:, radar_point], cam_points_high[:, radar_point], image_data)

        for pixel_point in range(0, projection_line.shape[0]):
            y = projection_line[pixel_point, 1].astype(int)
            x = projection_line[pixel_point, 0].astype(int)

            # Check if pixel is already filled with radar data and overwrite if distance is less than the existing
            if not np.any(radar_extension[y, x]) or radar_data[-1, radar_point] < radar_extension[y, x, -1]:
                radar_extension[y, x] = radar_data[3:, radar_point]

    image_plus = np.concatenate((image_data, radar_extension), axis=2)
    return image_plus

def create_imagep_visualization(image_plus_data, color_channel="distance", \
        draw_circles=False, cfg=None, radar_lines_opacity=1.0):
    """
    Visualization of image plus data
    Parameters:
        :image_plus_data: a numpy array (900 x 1600 x (3 + number of radar_meta (e.g. velocity)))
        :image_data: a numpy array (900 x 1600 x 3)
        :color_channel: <str> Image plus channel for colorizing the radar lines. according to radar.channel_map.
        :draw_circles: Draws circles at the bottom of the radar lines
    Returns:
        :image_data: a numpy array (900 x 1600 x 3)
    """
    # read dimensions
    image_plus_height = image_plus_data.shape[0]
    image_plus_width = image_plus_data.shape[1]
    n_channels = image_plus_data.shape[2]

    ##### Extract the image Channels #####
    if cfg is None:
        image_channels = [0,1,2]
    else:
        image_channels = [i_ch for i_ch in cfg.channels if i_ch in [0,1,2]]
    image_data = np.ones(shape=(*image_plus_data.shape[:2],3))
    if len(image_channels) > 0:
        image_data[:,:,image_channels] = image_plus_data[:,:,image_channels].copy() # copy so we dont change the old image

    # Draw the Horizon
    image_data = np.array(image_data*255).astype(np.uint8)
    image_data = cv2.cvtColor(image_data, cv2.COLOR_RGB2BGR)

    ##### Paint every augmented pixel on the image #####
    if n_channels > 3:
        # transfer it to the currently selected channels
        if cfg is None:
            print("Warning, no cfg provided. Thus, its not possible to find out \
                which channel shall be used for colorization")
            radar_img = np.zeros(image_plus_data.shape[:-1]) # we expect the channel index to be the last axis
        else:
            available_channels = {radar.channel_map[ch]:ch_idx for ch_idx, ch in enumerate(cfg.channels) if ch > 2}
            ch_idx = available_channels[color_channel]
            # Normalize the radar
            if cfg.normalize_radar: # normalization happens from -127 to 127
                radar_img = image_plus_data[...,ch_idx] + 127.5
            else:
                radar_img = radar.normalize(color_channel, image_plus_data[..., ch_idx],
                                            normalization_interval=[0, 255], sigma_factor=2)

            radar_img = np.clip(radar_img,0,255)

        radar_colormap = np.array(cv2.applyColorMap(radar_img.astype(np.uint8), cv2.COLORMAP_AUTUMN))

        for x in range(0, image_plus_width):
            for y in range(0, image_plus_height):
                radar_channels = image_plus_data[y, x, 3:]
                pixel_contains_radar = np.count_nonzero(radar_channels)
                if not pixel_contains_radar:
                    continue

                radar_color = radar_colormap[y,x]
                for pixel in [(y,x)]: #[(y,x-1),(y,x),(y,x+1)]:
                    if image_data.shape > pixel:

                        # Calculate the color
                        pixel_color = np.array(image_data[pixel][0:3], dtype=np.uint8)
                        pixel_color = np.squeeze(cv2.addWeighted(pixel_color, 1-radar_lines_opacity, radar_color, radar_lines_opacity, 0))

                        # Draw on image
                        image_data[pixel] = pixel_color

                # only if some radar information is there
                if draw_circles:
                    if image_plus_data.shape[0] > y+1 and not np.any(image_plus_data[y+1, x,3:]):
                        cv2.circle(image_data, (x,y), 3, color=radar_colormap[(y,x)].astype(np.float), thickness=1)


    return image_data

## CRFNet model
The model is based on RetinaNet, which is composed of three main parts:

1. Feature Extractor (VGG16)
2. Feature Pyramid Network (FPN)
3. Regression head and Classification head

Different with a classic RetinaNet, CRFNet has the feature map fused with rader signal in every VGG block and FPN output

In [5]:
# Constructing the Feature extractor (backbone) class based on VGG16
class Backbone(nn.Module):
    
    def __init__(self, in_channels = 5):
        
        super(Backbone, self).__init__()
        
        self.blocks : List[nn.Sequential] = []
        self.in_channels = in_channels
        for i in range(5):
            blockmodules = []
            if i < 4:
                blockmodules.append(nn.Conv2d(self.in_channels, (2**i)*64, (3,3), (1,1), (1,1)))
                blockmodules.append(nn.ReLU(inplace=True))
                blockmodules.append(nn.Conv2d((2**i)*64, (2**i)*64, (3,3), (1,1), (1,1)))
                blockmodules.append(nn.ReLU(inplace=True))
            
            if 4 > i > 1:
                blockmodules.append(nn.Conv2d((2**i)*64, (2**i)*64, (3,3), (1,1), (1,1)))
                blockmodules.append(nn.ReLU(inplace=True))
            
            if i == 4:
                blockmodules.append(nn.Conv2d(514, 512, (3,3), (1,1), (1,1)))
                blockmodules.append(nn.ReLU(inplace=True))
                blockmodules.append(nn.Conv2d(512, 512, (3,3), (1,1), (1,1)))
                blockmodules.append(nn.ReLU(inplace=True))
                blockmodules.append(nn.Conv2d(512, 512, (3,3), (1,1), (1,1)))
                blockmodules.append(nn.ReLU(inplace=True))
            
            blockmodules.append(nn.MaxPool2d(2, 2, 0, 1, ceil_mode=True))
            
            sequence = nn.Sequential(*blockmodules)
            #print(sequence)
            self.blocks.append(sequence)
            
            # plus 2 for radar
            self.in_channels = (2**i)*64+2

        #print(self.blocks)
        self.max_pool = nn.MaxPool2d(2, stride=2,ceil_mode=True)
    
       
        #Initialize dictionary that holds the output data
        output_dict = {}
        
        

    def forward(self, inputs):
        if not torch.is_tensor(inputs):
            inputs = torch.from_numpy(inputs)
        output_dict = {}
        #Split the camera and radar data
        radar_data = inputs[:,3:,:,:]
        camera_data = inputs[:,:3,:,:]
        #For the first 5 blocks:
        for i in range(5):
            #Concatenate camera and radar data
            combined_data = torch.cat((camera_data, radar_data), 1)
            #Apply the max pooling layer to the radar data            
            radar_data = self.max_pool(radar_data)
            #Apply the VGG16 block to the combined data
            camera_data = self.blocks[i](combined_data)
            #For blocks 3-5 add the radar and camera data to the output
            if i > 1:
                vgg_key = "vgg_output_{}".format(i+1)
                output_dict[vgg_key] = torch.cat((camera_data, radar_data), 1)
                rad_key = "rad_output_{}".format(i+1)
                output_dict[rad_key] = radar_data
        
        #Apply the final two max pool layers to the radar data and save to output

        radar_data = self.max_pool(radar_data)
        output_dict['rad_output_6'] = radar_data
        
        radar_data = self.max_pool(radar_data)
        output_dict['rad_output_7'] = radar_data
        
        return output_dict

In [6]:
# Constructing the Feature Pyramid Network (FPN) class and the detection heads
# FPN
class PyramidFeatures(nn.Module):
    def __init__(self, C3_size=258, C4_size=514, C5_size=514, feature_size=254):
        super(PyramidFeatures, self).__init__()

        # upsample C5 to get P5 from the FPN paper
        self.P5_1 = nn.Conv2d(C5_size, feature_size, kernel_size=1, stride=1, padding=0)
        self.P5_2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=1, padding=1)

        # add P5 elementwise to C4
        self.P4_1 = nn.Conv2d(C4_size, feature_size, kernel_size=1, stride=1, padding=0)
        self.P4_2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=1, padding=1)

        # add P4 elementwise to C3
        self.P3_1 = nn.Conv2d(C3_size, feature_size, kernel_size=1, stride=1, padding=0)
        self.P3_2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=1, padding=1)

        # "P6 is obtained via a 3x3 stride-2 conv on C5"
        self.P6 = nn.Conv2d(C5_size, feature_size, kernel_size=3, stride=2, padding=1)

        # "P7 is computed by applying ReLU followed by a 3x3 stride-2 conv on P6"
        self.P7_1 = nn.ReLU()
        self.P7_2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=2, padding=1)
        
    def forward(self, inputs_dict, use_radar=True):
        C3 = inputs_dict['vgg_output_3'].to(device) 
        C4 = inputs_dict['vgg_output_4'].to(device)
        C5 = inputs_dict['vgg_output_5'].to(device)
        R3 = inputs_dict['rad_output_3'].to(device)
        R4 = inputs_dict['rad_output_4'].to(device)
        R5 = inputs_dict['rad_output_5'].to(device)
        R6 = inputs_dict['rad_output_6'].to(device)
        R7 = inputs_dict['rad_output_7'].to(device)

        P5_x = self.P5_1(C5)
        P5_x = self.P5_2(P5_x)

        P4_x = self.P4_1(C4)
        P5_upsampled_x = F.interpolate(P5_x,size=[P4_x.shape[2],P4_x.shape[3]], mode='nearest')
        P4_x = P5_upsampled_x + P4_x
        P4_x = self.P4_2(P4_x)
        
        P3_x = self.P3_1(C3)
        P4_upsampled_x = F.interpolate(P4_x,size=[P3_x.shape[2],P3_x.shape[3]], mode='nearest')
        P3_x = P3_x + P4_upsampled_x
        P3_x = self.P3_2(P3_x)

        P6_x = self.P6(C5)

        P7_x = self.P7_1(P6_x)
        P7_x = self.P7_2(P7_x)
       
        #print(P3_x.shape,P4_x.shape,P5_x.shape,P6_x.shape,P7_x.shape)
        #print(R3.shape,R4.shape,R5.shape,R6.shape,R7.shape)
        if use_radar:
            P3_x = torch.cat((P3_x, R3),1)
            P4_x = torch.cat((P4_x, R4),1)
            P5_x = torch.cat((P5_x, R5),1)
            P6_x = torch.cat((P6_x, R6),1)
            P7_x = torch.cat((P7_x, R7),1)
            return [P3_x, P4_x, P5_x, P6_x, P7_x]
        else:         
            return [P3_x, P4_x, P5_x, P6_x, P7_x]

# Detection Heads
class RegressionModel(nn.Module):
    def __init__(self, num_features_in=256, num_anchors=9, feature_size=256):
        super(RegressionModel, self).__init__()

        # num_features_in: channels of [P3_x, P4_x, P5_x, P6_x, P7_x]
        self.conv1 = nn.Conv2d(num_features_in, feature_size, kernel_size=3, padding=1)
        self.act1 = nn.ReLU()

        self.conv2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, padding=1)
        self.act2 = nn.ReLU()

        self.conv3 = nn.Conv2d(feature_size, feature_size, kernel_size=3, padding=1)
        self.act3 = nn.ReLU()

        self.conv4 = nn.Conv2d(feature_size, feature_size, kernel_size=3, padding=1)
        self.act4 = nn.ReLU()

        self.output = nn.Conv2d(feature_size, num_anchors * 4, kernel_size=3, padding=1)

    def forward(self, x):
        out = self.conv1(x)
        out = self.act1(out)

        out = self.conv2(out)
        out = self.act2(out)

        out = self.conv3(out)
        out = self.act3(out)

        out = self.conv4(out)
        out = self.act4(out)

        out = self.output(out)

        # out is B x C x W x H, with C = 4*num_anchors
        out = out.permute(0, 2, 3, 1)

        return out.contiguous().view(out.shape[0], -1, 4)   
    
class ClassificationModel(nn.Module):
    def __init__(self, num_features_in=256, num_anchors=9, num_classes=8, prior=0.01, feature_size=256):
        super(ClassificationModel, self).__init__()

        self.num_classes = num_classes
        self.num_anchors = num_anchors

        self.conv1 = nn.Conv2d(num_features_in, feature_size, kernel_size=3, padding=1)
        self.act1 = nn.ReLU()

        self.conv2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, padding=1)
        self.act2 = nn.ReLU()

        self.conv3 = nn.Conv2d(feature_size, feature_size, kernel_size=3, padding=1)
        self.act3 = nn.ReLU()

        self.conv4 = nn.Conv2d(feature_size, feature_size, kernel_size=3, padding=1)
        self.act4 = nn.ReLU()

        self.output = nn.Conv2d(feature_size, num_anchors * num_classes, kernel_size=3, padding=1)
        self.output_act = nn.Sigmoid()

    def forward(self, x):
        out = self.conv1(x)
        out = self.act1(out)

        out = self.conv2(out)
        out = self.act2(out)

        out = self.conv3(out)
        out = self.act3(out)

        out = self.conv4(out)
        out = self.act4(out)

        out = self.output(out)
        out = self.output_act(out)

        # out is B x C x W x H, with C = n_classes + n_anchors
        out1 = out.permute(0, 2, 3, 1)

        batch_size, width, height, channels = out1.shape

        out2 = out1.view(batch_size, width, height, self.num_anchors, self.num_classes)

        return out2.contiguous().view(x.shape[0], -1, self.num_classes)

In [7]:
# Constructing the Anchors class to generate anchors for features                                                      
class RetinaAnchors(nn.Module):
    def __init__(self, areas, ratios, scales, strides):
        super(RetinaAnchors, self).__init__()
        self.areas = areas
        self.ratios = ratios
        self.scales = scales
        self.strides = strides

    def forward(self, batch_size, fpn_feature_sizes):
        """
        generate batch anchors
        """
        device = fpn_feature_sizes.device
        one_sample_anchors = []
        for index, area in enumerate(self.areas):
            base_anchors = self.generate_base_anchors(area, self.scales,self.ratios)
            featrue_anchors = self.generate_anchors_on_feature_map(base_anchors, fpn_feature_sizes[index], self.strides[index])
            featrue_anchors = featrue_anchors.to(device)
            one_sample_anchors.append(featrue_anchors)

        batch_anchors = []
        for per_level_featrue_anchors in one_sample_anchors:
            per_level_featrue_anchors = per_level_featrue_anchors.unsqueeze(
                0).repeat(batch_size, 1, 1)
            batch_anchors.append(per_level_featrue_anchors)

        # if input size:[B,3,640,640]
        # batch_anchors shape:[[B, 57600, 4],[B, 14400, 4],[B, 3600, 4],[B, 900, 4],[B, 225, 4]]
        # per anchor format:[x_min,y_min,x_max,y_max]
        return batch_anchors

    def generate_base_anchors(self, area, scales, ratios):
        """
        generate base anchor
        """
        # get w,h aspect ratio,shape:[9,2]
        aspects = torch.tensor([[[s * math.sqrt(r), s * math.sqrt(1 / r)]
                                 for s in scales]
                                for r in ratios]).view(-1, 2)
        # base anchor for each position on feature map,shape[9,4]
        base_anchors = torch.zeros((len(scales) * len(ratios), 4))

        # compute aspect w\h,shape[9,2]
        base_w_h = area * aspects
        base_anchors[:, 2:] += base_w_h

        # base_anchors format: [x_min,y_min,x_max,y_max],center point:[0,0],shape[9,4]
        base_anchors[:, 0] -= base_anchors[:, 2] / 2
        base_anchors[:, 1] -= base_anchors[:, 3] / 2
        base_anchors[:, 2] /= 2
        base_anchors[:, 3] /= 2

        return base_anchors

    def generate_anchors_on_feature_map(self, base_anchors, feature_map_size,
                                        stride):
        """
        generate all anchors on a feature map
        """
        # shifts_x shape:[w],shifts_x shape:[h]
        shifts_x = (torch.arange(0, feature_map_size[0]) + 0.5) * stride
        shifts_y = (torch.arange(0, feature_map_size[1]) + 0.5) * stride

        # shifts shape:[w,h,2] -> [w,h,4] -> [w,h,1,4]
        shifts = torch.tensor([[[shift_x, shift_y] for shift_y in shifts_y]
                               for shift_x in shifts_x]).repeat(1, 1,
                                                                2).unsqueeze(2)

        # base anchors shape:[9,4] -> [1,1,9,4]
        base_anchors = base_anchors.unsqueeze(0).unsqueeze(0)
        # generate all featrue map anchors on each feature map points
        # featrue map anchors shape:[w,h,9,4] -> [h,w,9,4] -> [h*w*9,4]
        feature_map_anchors = (base_anchors + shifts).permute(
            1, 0, 2, 3).contiguous().view(-1, 4)

        # feature_map_anchors format: [anchor_nums,4],4:[x_min,y_min,x_max,y_max]
        return feature_map_anchors

In [8]:
# Constructing the CRF net class, which combines the networks 
class CRFNet(nn.Module):
    def __init__(self,num_class=8):
        super(CRFNet, self).__init__()
        
        self.backbone = Backbone().to(device) 
        self.fpn = PyramidFeatures().to(device)
        self.regression = RegressionModel().to(device)
        self.distance = RegressionModel().to(device)
        self.classification = ClassificationModel().to(device)
        
        # Setting anchors
        self.areas = torch.tensor([[16, 16], [32, 32], [64, 64], [128, 128], [256, 256]])
        self.ratios = torch.tensor([0.5, 1, 2])
        self.scales = torch.tensor([2**0, 2**(1.0 / 3.0), 2**(2.0 / 3.0)])
        self.strides = torch.tensor([8, 16, 32, 64, 128], dtype=torch.float)
        self.anchors = RetinaAnchors(self.areas, self.ratios, self.scales, self.strides)
    
    def forward(self, inputs):
        self.batch_size, _, _, _ = inputs.shape
        device = inputs.device
        
        # Backbone, input:image+radar, output: C3, C4, C5, R3, R4, R5, R6, R7 
        backbone_output = self.backbone.forward(inputs)
        
        # FPN, input: C3, C4, C5, R3, R4, R5, R6, R7 , output: P3, P4, P5, P6, P7 (features)
        features = self.fpn.forward(backbone_output)
        
        # Detection Head, input: P4, P5, P6, P7, output:
        regression = torch.cat([self.regression.forward(feature) for feature in features], dim=1)
        classification = torch.cat([self.classification.forward(feature) for feature in features], dim=1)
        distance = torch.cat([self.distance.forward(feature) for feature in features], dim=1)
        
        # Anchors
        self.fpn_feature_sizes = []
        for feature in features:
            self.fpn_feature_sizes.append([feature.shape[3], feature.shape[2]])
        self.fpn_feature_sizes = torch.tensor(self.fpn_feature_sizes).to(device)
        #print(self.fpn_feature_sizes)
        batch_anchors = self.anchors(self.batch_size, self.fpn_feature_sizes)
        batch_anchors = torch.cat([batch for batch in batch_anchors],dim=1)
        
        return regression, classification, distance, batch_anchors

## Loss and prediction
Besides the model itself, several other components need to be implemented 

1. The focal loss function for classification head result and smoothe-L1 loss for regression head output
2. The decoder based on NMS(non-maximum supression), which transform the output of detection heads to predication boxes
3. The metrix used to measure the predication performance: mAP

In [9]:
# Construting the focal loss and regression loss functions 
class RetinaLoss(nn.Module):
    def __init__(self,
                 image_w,
                 image_h,
                 alpha=0.25,
                 gamma=2,
                 beta=1.0 / 9.0,
                 epsilon=1e-4):
        super(RetinaLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.beta = beta
        self.epsilon = epsilon
        self.image_w = image_w
        self.image_h = image_h

    def forward(self, cls_heads, reg_heads, batch_anchors, annotations):
        """
        compute cls loss and reg loss in one batch
        """
        #device = annotations.device
        #cls_heads = torch.cat(cls_heads, axis=1)
        #reg_heads = torch.cat(reg_heads, axis=1)
        #batch_anchors = torch.cat(batch_anchors, axis=1)

        cls_heads, reg_heads, batch_anchors = self.drop_out_border_anchors_and_heads(
            cls_heads, reg_heads, batch_anchors, self.image_w, self.image_h)
        #print(cls_heads.shape, reg_heads.shape, batch_anchors.shape)
        batch_anchors_annotations = self.get_batch_anchors_annotations(batch_anchors, annotations)
        #print(batch_anchors_annotations.shape)

        cls_loss, reg_loss = [], []
        valid_image_num = 0
        for per_image_cls_heads, per_image_reg_heads, per_image_anchors_annotations in zip(
                cls_heads, reg_heads, batch_anchors_annotations):
            # valid anchors contain all positive anchors
            valid_anchors_num = (per_image_anchors_annotations[
                per_image_anchors_annotations[:, 4] > 0]).shape[0]
            #print("valid_anchors_num:", valid_anchors_num)
            #print(per_image_anchors_annotations[per_image_anchors_annotations[:, 4] == 0].shape)
            if valid_anchors_num == 0:
                cls_loss.append(torch.tensor(0.).to(device))
                reg_loss.append(torch.tensor(0.).to(device))
            else:
                valid_image_num += 1
                one_image_cls_loss = self.compute_one_image_focal_loss(
                    per_image_cls_heads, per_image_anchors_annotations)
                one_image_reg_loss = self.compute_one_image_smoothl1_loss(
                    per_image_reg_heads, per_image_anchors_annotations)
                cls_loss.append(one_image_cls_loss)
                reg_loss.append(one_image_reg_loss)
        
        if valid_image_num==0:
            return cls_loss[0], reg_loss[0]
        
        cls_loss = sum(cls_loss) / valid_image_num
        reg_loss = sum(reg_loss) / valid_image_num

        return cls_loss, reg_loss

    def compute_one_image_focal_loss(self, per_image_cls_heads,
                                     per_image_anchors_annotations):
        """
        compute one image focal loss(cls loss)
        per_image_cls_heads:[anchor_num,num_classes]
        per_image_anchors_annotations:[anchor_num,5]
        """
        # Filter anchors with gt class=-1, this part of anchor doesn't calculate focal loss
        per_image_cls_heads = per_image_cls_heads[per_image_anchors_annotations[:, 4] >= 0]
           
        per_image_anchors_annotations = per_image_anchors_annotations[per_image_anchors_annotations[:, 4] >= 0]
            
        per_image_cls_heads = torch.clamp(per_image_cls_heads,
                                          min=self.epsilon,
                                          max=1. - self.epsilon)
        num_classes = per_image_cls_heads.shape[1]

        # generate 8 binary ground truth classes for each anchor
        loss_ground_truth = F.one_hot(per_image_anchors_annotations[:,4].long(),num_classes=num_classes + 1)
        loss_ground_truth = loss_ground_truth[:, 1:]
        loss_ground_truth = loss_ground_truth.float().to(device)

        alpha_factor = torch.ones_like(per_image_cls_heads) * self.alpha
        alpha_factor = torch.where(torch.eq(loss_ground_truth, 1.),alpha_factor, 1. - alpha_factor)
        
        pt = torch.where(torch.eq(loss_ground_truth, 1.), per_image_cls_heads, 1. - per_image_cls_heads)
        focal_weight = alpha_factor * torch.pow((1. - pt), self.gamma)

        bce_loss = -(loss_ground_truth * torch.log(per_image_cls_heads) +(1. - loss_ground_truth) * torch.log(1. - per_image_cls_heads))

        one_image_focal_loss = focal_weight * bce_loss

        one_image_focal_loss = one_image_focal_loss.sum()
        positive_anchors_num = per_image_anchors_annotations[per_image_anchors_annotations[:, 4] > 0].shape[0]
        
        # according to the original paper,We divide the focal loss by the number of positive sample anchors
        one_image_focal_loss = one_image_focal_loss / positive_anchors_num

        return one_image_focal_loss

    def compute_one_image_smoothl1_loss(self, per_image_reg_heads,
                                        per_image_anchors_annotations):
        """
        compute one image smoothl1 loss(reg loss)
        per_image_reg_heads:[anchor_num,4]
        per_image_anchors_annotations:[anchor_num,5]
        """
        # Filter anchors with gt class=-1, this part of anchor doesn't calculate smoothl1 loss
        #device = per_image_reg_heads.device
        per_image_reg_heads = per_image_reg_heads[
            per_image_anchors_annotations[:, 4] > 0]
        per_image_anchors_annotations = per_image_anchors_annotations[
            per_image_anchors_annotations[:, 4] > 0]
        positive_anchor_num = per_image_anchors_annotations.shape[0]

        if positive_anchor_num == 0:
            return torch.tensor(0.).to(device)

        # compute smoothl1 loss
        loss_ground_truth = per_image_anchors_annotations[:, 0:4]
        x = torch.abs(per_image_reg_heads - loss_ground_truth)
        one_image_smoothl1_loss = torch.where(torch.ge(x, self.beta),
                                              x - 0.5 * self.beta,
                                              0.5 * (x**2) / self.beta)
        one_image_smoothl1_loss = one_image_smoothl1_loss.mean(axis=1).sum()
        # according to the original paper,We divide the smoothl1 loss by the number of positive sample anchors
        one_image_smoothl1_loss = one_image_smoothl1_loss / positive_anchor_num

        return one_image_smoothl1_loss

    def drop_out_border_anchors_and_heads(self, cls_heads, reg_heads,
                                          batch_anchors, image_w, image_h):
        """
        dropout out of border anchors,cls heads and reg heads
        """
        final_cls_heads, final_reg_heads, final_batch_anchors = [], [], []
        for per_image_cls_head, per_image_reg_head, per_image_anchors in zip(
                cls_heads, reg_heads, batch_anchors):
            per_image_cls_head = per_image_cls_head[per_image_anchors[:,
                                                                      0] > 0.0]
            per_image_reg_head = per_image_reg_head[per_image_anchors[:,
                                                                      0] > 0.0]
            per_image_anchors = per_image_anchors[per_image_anchors[:,
                                                                    0] > 0.0]

            per_image_cls_head = per_image_cls_head[per_image_anchors[:,
                                                                      1] > 0.0]
            per_image_reg_head = per_image_reg_head[per_image_anchors[:,
                                                                      1] > 0.0]
            per_image_anchors = per_image_anchors[per_image_anchors[:,
                                                                    1] > 0.0]

            per_image_cls_head = per_image_cls_head[
                per_image_anchors[:, 2] < image_w]
            per_image_reg_head = per_image_reg_head[
                per_image_anchors[:, 2] < image_w]
            per_image_anchors = per_image_anchors[
                per_image_anchors[:, 2] < image_w]

            per_image_cls_head = per_image_cls_head[
                per_image_anchors[:, 3] < image_h]
            per_image_reg_head = per_image_reg_head[
                per_image_anchors[:, 3] < image_h]
            per_image_anchors = per_image_anchors[
                per_image_anchors[:, 3] < image_h]

            per_image_cls_head = per_image_cls_head.unsqueeze(0)
            per_image_reg_head = per_image_reg_head.unsqueeze(0)
            per_image_anchors = per_image_anchors.unsqueeze(0)

            final_cls_heads.append(per_image_cls_head)
            final_reg_heads.append(per_image_reg_head)
            final_batch_anchors.append(per_image_anchors)

        final_cls_heads = torch.cat(final_cls_heads, axis=0)
        final_reg_heads = torch.cat(final_reg_heads, axis=0)
        final_batch_anchors = torch.cat(final_batch_anchors, axis=0)

        # final cls heads shape:[batch_size, anchor_nums, class_num]
        # final reg heads shape:[batch_size, anchor_nums, 4]
        # final batch anchors shape:[batch_size, anchor_nums, 4]
        return final_cls_heads, final_reg_heads, final_batch_anchors

    def get_batch_anchors_annotations(self, batch_anchors, annotations):
        """
        Assign a ground truth box target and a ground truth class target for each anchor
        if anchor gt_class index = -1,this anchor doesn't calculate cls loss and reg loss
        if anchor gt_class index = 0,this anchor is a background class anchor and used in calculate cls loss
        if anchor gt_class index > 0,this anchor is a object class anchor and used in calculate cls loss and reg loss
        """
        #device = annotations.device
        #assert batch_anchors.shape[0] == annotations.shape[0]
        one_image_anchor_nums = batch_anchors.shape[1]
        batch_anchors_annotations = []
        for one_image_anchors, one_image_annotations in zip(
                batch_anchors, annotations):
            # drop all index=-1 class annotations
            one_image_annotations = one_image_annotations[
                one_image_annotations[:, 4] >= 0]
            #print(one_image_annotations.shape)

            if one_image_annotations.shape[0] == 0:
                one_image_anchor_annotations = torch.ones(
                    [one_image_anchor_nums, 5], device=device) * (-1)
            else:
                one_image_gt_bboxes = one_image_annotations[:, 0:4]
                one_image_gt_class = one_image_annotations[:, 4].to(device)
                one_image_ious = self.compute_ious_for_one_image(one_image_anchors, one_image_gt_bboxes)

                # snap per gt bboxes to the best iou anchor
                overlap, indices = one_image_ious.max(axis=1)
                # assgin each anchor gt bboxes for max iou annotation
                per_image_anchors_gt_bboxes = one_image_gt_bboxes[indices]
                # transform gt bboxes to [tx,ty,tw,th] format for each anchor
                one_image_anchors_snaped_boxes = self.snap_annotations_as_tx_ty_tw_th(
                    per_image_anchors_gt_bboxes, one_image_anchors)

                one_image_anchors_gt_class = (torch.ones_like(overlap) *-1).to(device)
                # if iou <0.4,assign anchors gt class as 0:background
                one_image_anchors_gt_class[overlap < 0.4] = 0
                # if iou >=0.5,assign anchors gt class as same as the max iou annotation class:80 classes index from 1 to 80
                one_image_anchors_gt_class[overlap >=0.5] = one_image_gt_class[indices][overlap >= 0.5] + 1

                one_image_anchors_gt_class = one_image_anchors_gt_class.unsqueeze(-1)

                one_image_anchor_annotations = torch.cat([one_image_anchors_snaped_boxes, one_image_anchors_gt_class], axis=1)
            
            one_image_anchor_annotations = one_image_anchor_annotations.unsqueeze(0)
            batch_anchors_annotations.append(one_image_anchor_annotations)

        batch_anchors_annotations = torch.cat(batch_anchors_annotations,axis=0)

        # batch anchors annotations shape:[batch_size, anchor_nums, 5]
        return batch_anchors_annotations

    def snap_annotations_as_tx_ty_tw_th(self, anchors_gt_bboxes, anchors):
        """
        snap each anchor ground truth bbox form format:[x_min,y_min,x_max,y_max] to format:[tx,ty,tw,th]
        """
        anchors_w_h = anchors[:, 2:] - anchors[:, :2]
        anchors_ctr = anchors[:, :2] + 0.5 * anchors_w_h

        anchors_gt_bboxes_w_h = anchors_gt_bboxes[:,
                                                  2:] - anchors_gt_bboxes[:, :2]
        anchors_gt_bboxes_w_h = torch.clamp(anchors_gt_bboxes_w_h, min=1.0)
        anchors_gt_bboxes_ctr = anchors_gt_bboxes[:, :
                                                  2] + 0.5 * anchors_gt_bboxes_w_h

        snaped_annotations_for_anchors = torch.cat(
            [(anchors_gt_bboxes_ctr - anchors_ctr) / anchors_w_h,
             torch.log(anchors_gt_bboxes_w_h / anchors_w_h)],
            axis=1).to(device)
        #device = snaped_annotations_for_anchors.device
        factor = torch.tensor([[0.1, 0.1, 0.2, 0.2]]).to(device)

        snaped_annotations_for_anchors = snaped_annotations_for_anchors / factor

        # snaped_annotations_for_anchors shape:[batch_size, anchor_nums, 4]
        return snaped_annotations_for_anchors

    def compute_ious_for_one_image(self, one_image_anchors,
                                   one_image_annotations):
        """
        compute ious between one image anchors and one image annotations
        """
        # make sure anchors format:[anchor_nums,4],4:[x_min,y_min,x_max,y_max]
        # make sure annotations format: [annotation_nums,4],4:[x_min,y_min,x_max,y_max]
        annotation_num = one_image_annotations.shape[0]

        one_image_ious = []
        for annotation_index in range(annotation_num):
            annotation = one_image_annotations[
                annotation_index:annotation_index + 1, :]
            overlap_area_top_left = torch.max(one_image_anchors[:, :2],
                                              annotation[:, :2])
            overlap_area_bot_right = torch.min(one_image_anchors[:, 2:],
                                               annotation[:, 2:])
            overlap_area_sizes = torch.clamp(overlap_area_bot_right -
                                             overlap_area_top_left,
                                             min=0)
            overlap_area = overlap_area_sizes[:, 0] * overlap_area_sizes[:, 1]
            # anchors and annotations convert format to [x1,y1,w,h]
            anchors_w_h = one_image_anchors[:,
                                            2:] - one_image_anchors[:, :2] + 1
            annotations_w_h = annotation[:, 2:] - annotation[:, :2] + 1
            # compute anchors_area and annotations_area
            anchors_area = anchors_w_h[:, 0] * anchors_w_h[:, 1]
            annotations_area = annotations_w_h[:, 0] * annotations_w_h[:, 1]

            # compute union_area
            union_area = anchors_area + annotations_area - overlap_area
            union_area = torch.clamp(union_area, min=1e-4)
            # compute ious between one image anchors and one image annotations
            ious = (overlap_area / union_area).unsqueeze(-1)

            one_image_ious.append(ious)

        one_image_ious = torch.cat(one_image_ious, axis=1)

        # one image ious shape:[anchors_num,annotation_num]
        return one_image_ious

In [10]:
# Constructing the decoder for transforming the output to prediction
class RetinaDecoder(nn.Module):
    def __init__(self,image_w,image_h,
                 min_score_threshold=0.05,
                 nms_threshold=0.5,
                 max_detection_num=300):
        super(RetinaDecoder, self).__init__()
        self.image_w = image_w
        self.image_h = image_h
        self.min_score_threshold = min_score_threshold
        self.nms_threshold = nms_threshold
        self.max_detection_num = max_detection_num

    def forward(self, cls_heads, reg_heads, batch_anchors):
        with torch.no_grad():
            device = cls_heads[0].device

            batch_scores, batch_classes, batch_pred_bboxes = [], [], []
            for per_image_cls_heads, per_image_reg_heads, per_image_anchors in zip(cls_heads, reg_heads, batch_anchors):
                # transfer reg head
                pred_bboxes = self.snap_tx_ty_tw_th_reg_heads_to_x1_y1_x2_y2_bboxes(per_image_reg_heads, per_image_anchors)
               
                # score of classification heads 
                scores, score_classes = torch.max(per_image_cls_heads, dim=1)
                score_classes = score_classes[scores > self.min_score_threshold].float()    
                pred_bboxes = pred_bboxes[ scores > self.min_score_threshold].float()    
                scores = scores[scores > self.min_score_threshold].float()

                single_image_scores = (-1) * torch.ones((self.max_detection_num, ), device=device)
                single_image_classes = (-1) * torch.ones((self.max_detection_num, ), device=device)
                single_image_pred_bboxes = (-1) * torch.ones((self.max_detection_num, 4), device=device)

                if scores.shape[0] != 0:
                    scores, score_classes, pred_bboxes = self.nms(scores, score_classes, pred_bboxes)
                    #print(scores)
                    sorted_keep_scores, sorted_keep_scores_indexes = torch.sort(scores, descending=True)
                    sorted_keep_classes = score_classes[sorted_keep_scores_indexes]
                    sorted_keep_pred_bboxes = pred_bboxes[sorted_keep_scores_indexes]

                    final_detection_num = min(self.max_detection_num,sorted_keep_scores.shape[0])

                    single_image_scores[0:final_detection_num] = sorted_keep_scores[0:final_detection_num]
                    single_image_classes[0:final_detection_num] = sorted_keep_classes[0:final_detection_num]
                    single_image_pred_bboxes[0:final_detection_num, :] = sorted_keep_pred_bboxes[0:final_detection_num, :]

                single_image_scores = single_image_scores.unsqueeze(0)
                single_image_classes = single_image_classes.unsqueeze(0)
                single_image_pred_bboxes = single_image_pred_bboxes.unsqueeze(0)

                batch_scores.append(single_image_scores)
                batch_classes.append(single_image_classes)
                batch_pred_bboxes.append(single_image_pred_bboxes)

            batch_scores = torch.cat(batch_scores, axis=0)
            batch_classes = torch.cat(batch_classes, axis=0)
            batch_pred_bboxes = torch.cat(batch_pred_bboxes, axis=0)

            # batch_scores shape:[batch_size,max_detection_num]
            # batch_classes shape:[batch_size,max_detection_num]
            # batch_pred_bboxes shape[batch_size,max_detection_num,4]
            box = torch.cat([batch_pred_bboxes[batch_pred_bboxes>=0].reshape(-1,4),batch_classes[batch_classes>=0].reshape(-1,1)],dim=1) 
            
            return batch_scores, batch_classes, batch_pred_bboxes, box

    def nms(self, one_image_scores, one_image_classes, one_image_pred_bboxes):
        """
        one_image_scores:[anchor_nums],4:classification predict scores
        one_image_classes:[anchor_nums],class indexes for predict scores
        one_image_pred_bboxes:[anchor_nums,4],4:x_min,y_min,x_max,y_max
        """
        # Sort boxes
        sorted_one_image_scores, sorted_one_image_scores_indexes = torch.sort(one_image_scores, descending=True)
        sorted_one_image_classes = one_image_classes[sorted_one_image_scores_indexes]
        sorted_one_image_pred_bboxes = one_image_pred_bboxes[sorted_one_image_scores_indexes]
        sorted_pred_bboxes_w_h = sorted_one_image_pred_bboxes[:,2:] - sorted_one_image_pred_bboxes[:, :2]

        sorted_pred_bboxes_areas = sorted_pred_bboxes_w_h[:,0] * sorted_pred_bboxes_w_h[:,1]                                                                          
        detected_classes = torch.unique(sorted_one_image_classes, sorted=True)

        keep_scores, keep_classes, keep_pred_bboxes = [], [], []
        for detected_class in detected_classes:
            single_class_scores = sorted_one_image_scores[sorted_one_image_classes == detected_class]
            single_class_pred_bboxes = sorted_one_image_pred_bboxes[sorted_one_image_classes == detected_class]
            single_class_pred_bboxes_areas = sorted_pred_bboxes_areas[sorted_one_image_classes == detected_class]
            single_class = sorted_one_image_classes[sorted_one_image_classes ==detected_class]

            single_keep_scores,single_keep_classes,single_keep_pred_bboxes=[],[],[]
            while single_class_scores.numel() > 0:
                top1_score, top1_class, top1_pred_bbox = single_class_scores[
                    0:1], single_class[0:1], single_class_pred_bboxes[0:1]

                single_keep_scores.append(top1_score)
                single_keep_classes.append(top1_class)
                single_keep_pred_bboxes.append(top1_pred_bbox)

                top1_areas = single_class_pred_bboxes_areas[0]

                if single_class_scores.numel() == 1:
                    break

                single_class_scores = single_class_scores[1:]
                single_class = single_class[1:]
                single_class_pred_bboxes = single_class_pred_bboxes[1:]
                single_class_pred_bboxes_areas = single_class_pred_bboxes_areas[1:]
                    

                overlap_area_top_left = torch.max( single_class_pred_bboxes[:, :2], top1_pred_bbox[:, :2])
                   
                overlap_area_bot_right = torch.min( single_class_pred_bboxes[:, 2:], top1_pred_bbox[:, 2:])
                   
                overlap_area_sizes = torch.clamp(overlap_area_bot_right -overlap_area_top_left,  min=0)
                                                 
                                               
                overlap_area = overlap_area_sizes[:, 0] * overlap_area_sizes[:,1]
                                                                             

                # compute union_area
                union_area = top1_areas + single_class_pred_bboxes_areas - overlap_area
                union_area = torch.clamp(union_area, min=1e-4)
                # compute ious for top1 pred_bbox and the other pred_bboxes
                ious = overlap_area / union_area

                single_class_scores = single_class_scores[ious < self.nms_threshold]
                single_class = single_class[ious < self.nms_threshold]
                single_class_pred_bboxes = single_class_pred_bboxes[ious < self.nms_threshold]
                single_class_pred_bboxes_areas = single_class_pred_bboxes_areas[ious < self.nms_threshold]

            single_keep_scores = torch.cat(single_keep_scores, axis=0)
            single_keep_classes = torch.cat(single_keep_classes, axis=0)
            single_keep_pred_bboxes = torch.cat(single_keep_pred_bboxes, axis=0)

            keep_scores.append(single_keep_scores)
            keep_classes.append(single_keep_classes)
            keep_pred_bboxes.append(single_keep_pred_bboxes)

        keep_scores = torch.cat(keep_scores, axis=0)
        keep_classes = torch.cat(keep_classes, axis=0)
        keep_pred_bboxes = torch.cat(keep_pred_bboxes, axis=0)

        return keep_scores, keep_classes, keep_pred_bboxes

    def snap_tx_ty_tw_th_reg_heads_to_x1_y1_x2_y2_bboxes(self, reg_heads, anchors):
        """
        snap reg heads to pred bboxes
        reg_heads:[anchor_nums,4],4:[tx,ty,tw,th]
        anchors:[anchor_nums,4],4:[x_min,y_min,x_max,y_max]
        """
        #Regress box
        anchors_wh = anchors[:, 2:] - anchors[:, :2]
        anchors_ctr = anchors[:, :2] + 0.5 * anchors_wh

        device = anchors.device
        factor = torch.tensor([[0.2, 0.2, 0.2, 0.2]]).to(device)

        reg_heads = reg_heads * factor

        pred_bboxes_wh = torch.exp(reg_heads[:, 2:]) * anchors_wh
        pred_bboxes_ctr = reg_heads[:, :2] * anchors_wh + anchors_ctr

        pred_bboxes_x_min_y_min = pred_bboxes_ctr - 0.5 * pred_bboxes_wh
        pred_bboxes_x_max_y_max = pred_bboxes_ctr + 0.5 * pred_bboxes_wh

        pred_bboxes = torch.cat([pred_bboxes_x_min_y_min, pred_bboxes_x_max_y_max], axis=1)
        #pred_bboxes = pred_bboxes.int()
        
        #clamp box
        pred_bboxes[:, 0] = torch.clamp(pred_bboxes[:, 0], min=0)
        pred_bboxes[:, 1] = torch.clamp(pred_bboxes[:, 1], min=0)
        pred_bboxes[:, 2] = torch.clamp(pred_bboxes[:, 2],max=self.image_w - 1)
        pred_bboxes[:, 3] = torch.clamp(pred_bboxes[:, 3],max=self.image_h - 1)

        # pred bboxes shape:[anchor_nums,4]
        return pred_bboxes

In [11]:
# mAP calculation
'''
#N = number of samples
#M = number of detections
#P = number of annots
#D = data (bbox and label)
#num_categories = 8 #number of categories that can be detected
#iou_range = np.linspace(0.5, 0.95, 10)

#detections = [] #List of detections, with size [N * [M, D]]
#annots = []     #List of annots, with size [N * [P, D]]

'''

def overlap_calc(detections_bbox, annots_bbox):
    """calculates the overlap between a detection Bbox and an annotation Bbox """
    
    #for detections an annots size [D]
    
    x_overlap = max(0, (min(detections_bbox[2], annots_bbox[2]) - max(detections_bbox[0], annots_bbox[0])))
    y_overlap = max(0, (min(detections_bbox[3], annots_bbox[3]) - max(detections_bbox[1], annots_bbox[1])))
    overlap_area = x_overlap * y_overlap;
    
    annots_area = (annots_bbox[3] - annots_bbox[1]) * (annots_bbox[2] - annots_bbox[0])       
    
    overlap = overlap_area/annots_area
    
    return overlap

def precision_recall(detections, annots, iou, category):
    """Calculates the precision and recall for a specific iou value """

    TP = 0
    FP = 0
    overlap = 0
    num_annotation = 0
    #for every sample with detections of the label
    for n in range(len(detections)):
        flag = 0;
        #for every detection
        for i in range(detections[n].shape[0]):
            
            if int(detections[n][i, 4].item()) == category:
                max_overlap = 0
                #Compare the detection with all the annots of the sample
                for j in range(annots[n].shape[0]):
                    if int(annots[n][j, 4].item()) == category:
                        overlap = overlap_calc(detections[n][i,:], annots[n][j,:])
                        max_overlap = max(max_overlap, overlap)
                        if flag == 0:
                            num_annotation+=1             
                flag = 1
                if max_overlap > iou:
                    TP += 1
                else:
                    FP += 1
    
    if TP+FP==0:
        precision = 0
    else:
        precision = TP/(TP+FP)
    #calculate 
    if num_annotation==0:
        recall = 0
    else:
        recall = TP/(num_annotation)
    
    return precision, recall
    
    
def AP_calc(detections, annots, iou_range, category):
    """Calculates Average Precision (AP) for a specific label(category)"""
    
    precision_list = []
    recall_list = []
    
    for i in range(len(iou_range)):
        precision, recall = precision_recall(detections, annots, iou_range[i], category)
        precision_list.append(precision)
        recall_list.append(recall)
    
    #print(precision_list,recall_list)

    #recall_list = np.concatenate(([0.], recall_list, [1.]))
    #precision_list = np.concatenate(([0.], precision_list, [0.]))

    #AP = np.trapz(precision_list, recall_list)
    # correct AP calculation
    # first append sentinel values at the end
    mrec = np.concatenate(([0.], recall_list, [1.]))
    mpre = np.concatenate(([0.], precision_list, [0.]))

    # compute the precision envelope
    for i in range(mpre.size - 1, 0, -1):
        mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])

    # to calculate area under PR curve, look for points
    # where X axis (recall) changes value
    i = np.where(mrec[1:] != mrec[:-1])[0]

    # and sum (\Delta recall) * prec
    AP = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
    return AP

def total_instances_calc(annots, num_categories):
    
    total_instances = np.zeros(num_categories)
    
    #For every sample
    for i in range(len(annots)):
        #For every annotation in the sample
        for j in range(annots[i].shape[0]):
            total_instances[int(annots[i][j, 4].item())] += 1
            
    return  total_instances

def mAP_calc(detections, annots, iou_range, num_categories):
    """Calculates the mean Average Precision (mAP)"""
    
    AP_list = np.empty(num_categories)
    
    #For each seperate category
    for i in range(num_categories):
        #Calculate the AP and append it to the AP_list
        AP = AP_calc(detections, annots, iou_range,i) 
        AP_list[i] = AP
    
    #Amount of occurences of each category    
    total_instances = total_instances_calc(annots, num_categories)
        
    #print(total_instances, AP_list)
    mAP = np.sum(AP_list*total_instances) / sum(total_instances)
    
    return mAP
    

## Testing of components

In [12]:
# crf test
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
torch.backends.cudnn.benchmark = True
crf_test = CRFNet().to(device)

pretrained = True
#print(crf_test)
if pretrained == True:
    # Load pretrained weight for model 
    import h5py

    filename = "crf_net.h5"
    model_order = [crf_test.fpn.P3_1, crf_test.fpn.P3_1,crf_test.fpn.P4_1, crf_test.fpn.P4_1,crf_test.fpn.P5_1, crf_test.fpn.P5_1,
                   crf_test.fpn.P3_2, crf_test.fpn.P3_2,crf_test.fpn.P4_2, crf_test.fpn.P4_2,crf_test.fpn.P5_2, crf_test.fpn.P5_2,
                   crf_test.fpn.P6, crf_test.fpn.P6,crf_test.fpn.P7_2, crf_test.fpn.P7_2, 
                   crf_test.backbone.blocks[0][0],crf_test.backbone.blocks[0][0],crf_test.backbone.blocks[0][2],crf_test.backbone.blocks[0][2],
                   crf_test.backbone.blocks[1][0],crf_test.backbone.blocks[1][0],crf_test.backbone.blocks[1][2],crf_test.backbone.blocks[1][2],
                   crf_test.backbone.blocks[2][0],crf_test.backbone.blocks[2][0],crf_test.backbone.blocks[2][2],crf_test.backbone.blocks[2][2],crf_test.backbone.blocks[2][4],crf_test.backbone.blocks[2][4],
                   crf_test.backbone.blocks[3][0],crf_test.backbone.blocks[3][0],crf_test.backbone.blocks[3][2],crf_test.backbone.blocks[3][2],crf_test.backbone.blocks[3][4],crf_test.backbone.blocks[3][4],
                   crf_test.backbone.blocks[4][0],crf_test.backbone.blocks[4][0],crf_test.backbone.blocks[4][2],crf_test.backbone.blocks[4][2],crf_test.backbone.blocks[4][4],crf_test.backbone.blocks[4][4],
                   crf_test.classification.output,crf_test.classification.output,
                   crf_test.classification.conv1,crf_test.classification.conv1,
                   crf_test.classification.conv2,crf_test.classification.conv2,
                   crf_test.classification.conv3,crf_test.classification.conv3,
                   crf_test.classification.conv4,crf_test.classification.conv4,
                   crf_test.regression.output,crf_test.regression.output,
                   crf_test.regression.conv1,crf_test.regression.conv1,
                   crf_test.regression.conv2,crf_test.regression.conv2,
                   crf_test.regression.conv3,crf_test.regression.conv3,
                   crf_test.regression.conv4,crf_test.regression.conv4]
    h5 = h5py.File(filename,'r')
    list(h5.keys())
    model_weights = h5['model_weights'] 
    optimizer_weights = h5['optimizer_weights']  
    def get_dataset_keys(f):
        keys = []
        f.visit(lambda key : keys.append(key) if type(f[key]) is h5py._hl.dataset.Dataset else None)
        return keys
    weight_keys = get_dataset_keys(model_weights)
    for key, model in zip(weight_keys, model_order):
        tensor_data = torch.tensor(np.transpose(model_weights[key][()]))
        #print(key,tensor_data.shape)
        para = nn.Parameter(tensor_data, requires_grad=True)
        if 'bias' in key:
            model.bias = para
        if 'kernel' in key:
            model.weight= para
train_dataset = nuscenes_dataset(nusc,image_min_side=360,image_max_side=640)
for i in range(1):
    fused_image, annotations = train_dataset[i]
    #print("Image shape: ",fused_image.shape, 'Annotation shape:',annotations.size())
    test_img = fused_image
    test_img.unsqueeze_(0)
    annotations.unsqueeze_(0)
    
    if not annotations.shape[1]==0:
        regression, classification, distance, batch_anchors = crf_test.forward(test_img)
        #print("Regression output:", crf_output[0].size(),
        #      "Classification output:",crf_output[1].size(),
        #      "Distance output:",crf_output[2].size(),
        #     "Anchors:", crf_output[3].size())
        #print("Annotation", annotations)

        # loss1 test
        loss = RetinaLoss(image_w=640,image_h=360)
        cls_loss, reg_loss = loss(classification, regression, batch_anchors, annotations)
        print(cls_loss, reg_loss)

        # Decoder test
        decoder = RetinaDecoder(image_w=640,image_h=360) 
        batch_scores, batch_classes, batch_pred_bboxes, predicted_box = decoder(classification, regression, batch_anchors)
        print("Image index:", i,"Number of detected box:", len(batch_scores[batch_scores>0]), "Number of annotation", len(annotations[0,:,0]))
        #print("Ground truth:",'\n', annotations)
        #print("Prediction:",'\n', predicted_box)
        

RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same

In [None]:
# visualize
def visualize_predictions(predictions, image_data_vis, dist=False, verbose=False):
    """
    Visualizes the predictions as bounding boxes with distances or confidence score in a given image.

    :param predictions:         <list>              List with [bboxes, probs, labels]
    :param image_data_vis:      <np.array>          Image where the predictions should be visualized
    :param generator:           <Generator>         Data generator used for name to label mapping
    :dist:                      <bool>              True if distance detection is enabled
    :verbose:                   <bool>              True if detetions should be printed 

    """
    
    font                   = cv2.FONT_HERSHEY_SIMPLEX
    fontScale              = 0.4

    # Visualization prediction
    all_dets = []
    bboxes = predictions
    print(bboxes.shape)
    for jk in range(bboxes.shape[0]):
        (x1, y1, x2, y2) = bboxes[jk,:]
        
   
        cv2.rectangle(image_data_vis,(x1, y1), (x2, y2),2)

image = cv2.imread('F:\\CameraRadarFusionNet\\crfnet\\data\\nuscenes\\samples\\CAM_FRONT\\n015-2018-07-24-11-22-45+0800__CAM_FRONT__1532402928112460.jpg')
image = cv2.resize(image, (640,360))
visualize_predictions(predicted_box[:,0:4],image)
#visualize_predictions(annotations[0,:,0:4],image)
cv2.imshow('image',image)
cv2.waitKey(0)

In [None]:
#backbone test
train_dataset = nuscenes_dataset(nusc,image_min_side= 360,image_max_side = 640)
train_loader = DataLoader(dataset=train_dataset, batch_size=1, shuffle=True)
fused_image, annotations = train_dataset[i]
test_img = fused_image
test_img.unsqueeze_(0)

backbone_test = Backbone()
backbone_output =  backbone_test.forward(test_img)
#print(test_img)
C3 = backbone_output['vgg_output_3'] 
C4 = backbone_output['vgg_output_4']
C5 = backbone_output['vgg_output_5']
R3 = backbone_output['rad_output_3']
R4 = backbone_output['rad_output_4']
R5 = backbone_output['rad_output_5']
R6 = backbone_output['rad_output_6']
R7 = backbone_output['rad_output_7']
print(C3.shape,C4.shape,C5.shape)
print(R3.shape,R4.shape,R5.shape,R6.shape,R7.shape)

In [None]:
#FPN test
fpn_test = PyramidFeatures()
#print(fpn_test)
fpn_output = fpn_test.forward(backbone_output)
P3,P4,P5,P6,P7 = fpn_output
print(P3.shape,P4.shape,P5.shape,P6.shape,P7.shape)

#Detection head test
regress_test= RegressionModel()
class_test= ClassificationModel()
regress_output=regress_test.forward(P3)
print(regress_output.shape)
class_output=class_test.forward(P3)
print(class_output.shape)

## Training on the dataset 

In [13]:
# training stage
max_epochs = 20
batch_size = 1
num_workers = 4

use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
print("device:", device)
# crf test
crf_test = CRFNet().to(device)
# pretrain weights
pretrained = True
if pretrained == True:
    # Load pretrained weight for model 
    import h5py

    filename = "crf_net.h5"
    model_order = [crf_test.fpn.P3_1, crf_test.fpn.P3_1,crf_test.fpn.P4_1, crf_test.fpn.P4_1,crf_test.fpn.P5_1, crf_test.fpn.P5_1,
                   crf_test.fpn.P3_2, crf_test.fpn.P3_2,crf_test.fpn.P4_2, crf_test.fpn.P4_2,crf_test.fpn.P5_2, crf_test.fpn.P5_2,
                   crf_test.fpn.P6, crf_test.fpn.P6,crf_test.fpn.P7_2, crf_test.fpn.P7_2, 
                   crf_test.backbone.blocks[0][0],crf_test.backbone.blocks[0][0],crf_test.backbone.blocks[0][2],crf_test.backbone.blocks[0][2],
                   crf_test.backbone.blocks[1][0],crf_test.backbone.blocks[1][0],crf_test.backbone.blocks[1][2],crf_test.backbone.blocks[1][2],
                   crf_test.backbone.blocks[2][0],crf_test.backbone.blocks[2][0],crf_test.backbone.blocks[2][2],crf_test.backbone.blocks[2][2],crf_test.backbone.blocks[2][4],crf_test.backbone.blocks[2][4],
                   crf_test.backbone.blocks[3][0],crf_test.backbone.blocks[3][0],crf_test.backbone.blocks[3][2],crf_test.backbone.blocks[3][2],crf_test.backbone.blocks[3][4],crf_test.backbone.blocks[3][4],
                   crf_test.backbone.blocks[4][0],crf_test.backbone.blocks[4][0],crf_test.backbone.blocks[4][2],crf_test.backbone.blocks[4][2],crf_test.backbone.blocks[4][4],crf_test.backbone.blocks[4][4],
                   crf_test.classification.output,crf_test.classification.output,
                   crf_test.classification.conv1,crf_test.classification.conv1,
                   crf_test.classification.conv2,crf_test.classification.conv2,
                   crf_test.classification.conv3,crf_test.classification.conv3,
                   crf_test.classification.conv4,crf_test.classification.conv4,
                   crf_test.regression.output,crf_test.regression.output,
                   crf_test.regression.conv1,crf_test.regression.conv1,
                   crf_test.regression.conv2,crf_test.regression.conv2,
                   crf_test.regression.conv3,crf_test.regression.conv3,
                   crf_test.regression.conv4,crf_test.regression.conv4]
    h5 = h5py.File(filename,'r')
    list(h5.keys())
    model_weights = h5['model_weights'] 
    optimizer_weights = h5['optimizer_weights']  
    def get_dataset_keys(f):
        keys = []
        f.visit(lambda key : keys.append(key) if type(f[key]) is h5py._hl.dataset.Dataset else None)
        return keys
    weight_keys = get_dataset_keys(model_weights)
    for key, model in zip(weight_keys, model_order):
        #if 'block' in key:
        tensor_data = torch.tensor(np.transpose(model_weights[key][()])).to(device)
        #print(key,tensor_data.shape)
        para = nn.Parameter(tensor_data, requires_grad=True)
        if 'bias' in key:
            model.bias = para
        if 'kernel' in key:
            model.weight= para

# loss test
criterion = RetinaLoss(image_w=360,image_h=640).to(device)

torch.backends.cudnn.benchmark = True

# Optimizer:
optimizer = torch.optim.Adam(crf_test.parameters(), lr=2*(10**-5))

training_set = nuscenes_dataset(nusc,image_min_side= 360,image_max_side = 640)
training_generator = DataLoader(dataset=training_set, batch_size=batch_size, shuffle=False)

# Decoder
decoder = RetinaDecoder(image_w=640,image_h=360)
num_categories = 8 #number of categories that can be detected
iou_range = np.linspace(0.5, 0.95, 10)


for epoch in range(max_epochs):
    index = 0
    loss_sum = 0
    detections = [] #List of detections, with size [N * [M, D]]
    annotations = []     #List of annots, with size [N * [P, D]]
    print("Epoch {}/{}, start".format(epoch+1, max_epochs))
    for img_plus, annots in training_generator:

            print(index, ":")
            if not annots.shape[1]==0:
                
                img_plus = img_plus.to(device)
                annots = annots.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()
                #print(img_plus.shape)

                output = crf_test(img_plus)
                #print("output calculated")

                #print(annots)
                cls_loss, reg_loss = criterion(output[1], output[0], output[3], annots) 
                loss = cls_loss + reg_loss
                print('cls_loss:',"%.3f" % cls_loss.item(),'reg_loss:', "%.3f" % reg_loss.item(),'total loss:', "%.3f" % loss.item())

                loss_sum += loss.item()
                if not cls_loss.grad_fn == None:
                    loss.backward()
                    optimizer.step()
                    index += 1
                    _, _, _, predicted_box = decoder(output[1], output[0], output[3])
                    detections.append(predicted_box) #List of detections, with size [N * [M, D]]
                    annotations.append(annots.squeeze())  #List of annots, with size [N * [P, D]]
                if index==10:
                    break
    
    mAP = mAP_calc(detections, annotations, iou_range, num_categories)
    #print epoch and corresponding loss
    print("Epoch {}/{}, Average Loss: {:.3f}, mAP: {:.3f}".format(epoch+1, max_epochs, loss_sum/index, mAP))

device: cuda:0
Epoch 1/20, start
0 :
cls_loss: 1.876 reg_loss: 0.963 total loss: 2.839
1 :
cls_loss: 1.211 reg_loss: 0.881 total loss: 2.092
2 :
cls_loss: 1.344 reg_loss: 0.829 total loss: 2.174
3 :
cls_loss: 1.144 reg_loss: 0.895 total loss: 2.040
4 :
cls_loss: 1.244 reg_loss: 0.881 total loss: 2.125
5 :
cls_loss: 1.011 reg_loss: 0.858 total loss: 1.869
6 :
cls_loss: 1.197 reg_loss: 0.931 total loss: 2.128
7 :
cls_loss: 1.366 reg_loss: 0.836 total loss: 2.202
8 :
cls_loss: 1.467 reg_loss: 0.791 total loss: 2.258
9 :
cls_loss: 1.498 reg_loss: 0.877 total loss: 2.375
Epoch 1/20, Average Loss: 2.210, mAP: 0.001
Epoch 2/20, start
0 :
cls_loss: 1.303 reg_loss: 0.792 total loss: 2.094
1 :
cls_loss: 0.779 reg_loss: 0.751 total loss: 1.530
2 :
cls_loss: 1.009 reg_loss: 0.687 total loss: 1.696
3 :
cls_loss: 0.780 reg_loss: 0.731 total loss: 1.510
4 :
cls_loss: 0.861 reg_loss: 0.750 total loss: 1.610
5 :
cls_loss: 0.774 reg_loss: 0.733 total loss: 1.507
6 :
cls_loss: 0.967 reg_loss: 0.819 total

7 :
cls_loss: 0.193 reg_loss: 0.287 total loss: 0.480
8 :
cls_loss: 0.347 reg_loss: 0.268 total loss: 0.615
9 :
cls_loss: 0.412 reg_loss: 0.349 total loss: 0.761
Epoch 14/20, Average Loss: 0.477, mAP: 0.593
Epoch 15/20, start
0 :
cls_loss: 0.399 reg_loss: 0.196 total loss: 0.595
1 :
cls_loss: 0.427 reg_loss: 0.304 total loss: 0.731
2 :
cls_loss: 0.202 reg_loss: 0.186 total loss: 0.388
3 :
cls_loss: 0.087 reg_loss: 0.140 total loss: 0.227
4 :
cls_loss: 0.064 reg_loss: 0.177 total loss: 0.241
5 :
cls_loss: 0.059 reg_loss: 0.196 total loss: 0.255
6 :
cls_loss: 0.211 reg_loss: 0.339 total loss: 0.550
7 :
cls_loss: 0.205 reg_loss: 0.305 total loss: 0.510
8 :
cls_loss: 0.352 reg_loss: 0.296 total loss: 0.647
9 :
cls_loss: 0.450 reg_loss: 0.342 total loss: 0.793
Epoch 15/20, Average Loss: 0.494, mAP: 0.551
Epoch 16/20, start
0 :
cls_loss: 0.345 reg_loss: 0.187 total loss: 0.532
1 :
cls_loss: 0.276 reg_loss: 0.183 total loss: 0.459
2 :
cls_loss: 0.190 reg_loss: 0.132 total loss: 0.322
3 :
cls_