# Final competition of Deep Learning 2020 Spring
Traffic environment semi-supervised Learning Contest

## Goals
The objective is to train a model using images captured by six different cameras attached to the same car to generate a top down view of the surrounding area. The performance of the model will be evaluated by (1) the ability of detecting objects (like car, trucks, bicycles, etc.) and (2) the ability to draw the road map layout.

## Data
You will be given two sets of data:

 1. Unlabeled set: just images
 2. Labeled set: images and the labels(bounding box and road map layout)

This notebook will help you understand the dataset.

In [1]:
!nvidia-smi

Sat May  2 15:37:39 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 418.39       Driver Version: 418.39       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  Tesla P100-PCIE...  On   | 00000000:85:00.0 Off |                    0 |
| N/A   28C    P0    39W / 250W |      0MiB / 16280MiB |      0%   E. Process |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage    

In [2]:
import os
import time
import sys
import random
import psutil

import numpy as np
import pandas as pd

import matplotlib
import matplotlib.pyplot as plt
matplotlib.rcParams['figure.figsize'] = [5, 5]
matplotlib.rcParams['figure.dpi'] = 200

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

from data_helper import UnlabeledDataset, LabeledDataset
from helper import collate_fn, draw_box
from sklearn.metrics import confusion_matrix

random.seed(0)
np.random.seed(0)
torch.manual_seed(0);

# All the images are saved in image_folder
# All the labels are saved in the annotation_csv file
image_folder = '../../DLSP20Dataset/data'
annotation_csv = '../../DLSP20Dataset/data/annotation.csv'

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

torch.cuda.is_available()

True

In [3]:
# function to count number of parameters
def get_n_params(model):
    np=0
    for p in list(model.parameters()):
        np += p.nelement()
    return np

def order_points(pts):
    from scipy.spatial import distance as dist
    import numpy as np
    
    xSorted = pts[np.argsort(pts[:, 0]), :]

    leftMost = xSorted[:2, :]
    rightMost = xSorted[2:, :]

    leftMost = leftMost[np.argsort(leftMost[:, 1]), :]
    (tl, bl) = leftMost

    D = dist.cdist(tl[np.newaxis], rightMost, "euclidean")[0]
    (br, tr) = rightMost[np.argsort(D)[::-1], :]

    return np.array([tl, tr, br, bl], dtype="float32")

def arrange_box(x1,y1):
    box=np.array(list(zip(x1,y1)))
    box=order_points(box)
    return box

def iou(box1, box2):
    from shapely.geometry import Polygon
    a = Polygon(torch.t(box1)).convex_hull
    b = Polygon(torch.t(box2)).convex_hull
    
    return a.intersection(b).area / a.union(b).area

#def iou(xy1,xy2):
#    
#    from shapely.geometry import Polygon
#    
#    boxA = Polygon(arrange_box(xy1[0],xy1[1])).buffer(1e-9)
#    boxB = Polygon(arrange_box(xy2[0],xy2[1])).buffer(1e-9)
#    
#    try:
#        return boxA.intersection(boxB).area / boxA.union(boxB).area
#    except:
#        print('Box 1:',xy1[0],xy1[1])
#        print('Box 2:',xy2[0],xy2[1])
#        sys.exit(1)

def map_to_ground_truth(overlaps, print_it=False):
    prior_overlap, prior_idx = overlaps.max(1)
    if print_it: print(prior_overlap)
#     pdb.set_trace()
    gt_overlap, gt_idx = overlaps.max(0)
    gt_overlap[prior_idx] = 1.99
    for i,o in enumerate(prior_idx): gt_idx[o] = i
    return gt_overlap,gt_idx

def calculate_overlap(target_bb, predicted_bb):
    overlaps = torch.zeros(target_bb.size(0),predicted_bb.size(0))

    for j in range(overlaps.shape[0]):
        for k in range(overlaps.shape[1]):
            overlaps[j][k] = iou(target_bb[j],predicted_bb[k])
            
    return overlaps

def one_hot_embedding(labels, num_classes):
    return torch.eye(num_classes)[labels.data.cpu()]

from skimage import draw
import numpy as np

def poly2mask(vertex_row_coords, vertex_col_coords, shape):
    fill_row_coords, fill_col_coords = draw.polygon(vertex_row_coords, vertex_col_coords, shape)
    mask = torch.zeros(shape, dtype=np.bool)
    mask[fill_row_coords, fill_col_coords] = True
    return mask

def convert_to_binary_mask(corners, shape=(800,800)):
    point_squence = torch.stack([corners[:, 0], corners[:, 1], corners[:, 3], corners[:, 2], corners[:, 0]])
    x,y = point_squence.T[0].detach() * 10 + 400, -point_squence.T[1].detach() * 10 + 400
    new_im = poly2mask(y, x, shape)
    return new_im

def create_conf_matrix(target, pred, debug=True):
    import sys
    
    target = target.reshape(-1)
    pred = pred.reshape(-1)
    
    if debug:
        print('Target values:', target.unique())
        print('Predicted values:', pred.unique())
        print('Target shape:', target.shape)
        print('Predicted shape:', pred.shape)
    
    nb_classes = max(target.unique())
    if len(pred.unique()) > (nb_classes+1) :
        print('More predicted classes than true classes')
        sys.exit(1)
        
    conf_matrix = torch.zeros(nb_classes+1, nb_classes+1)
    for t, p in zip(target, pred):
        conf_matrix[t, p] += 1
    
    return conf_matrix

def create_conf_matrix2(target, pred, debug=True):
    import sys
    
    target = target.reshape(-1).cpu().numpy()
    pred = pred.reshape(-1).cpu().numpy()
    
        
    conf_matrix = torch.from_numpy(confusion_matrix(target, pred)).to(device)
    
    print('Threat Score: {}'.format((1.0*conf_matrix[1,1])/(conf_matrix[1,1]+conf_matrix[1,0]+conf_matrix[0,1])))
    
    return conf_matrix

def classScores(conf_matrix):
    print('Confusion matrix\n', conf_matrix)
    TP = conf_matrix.diag()
    TN = torch.zeros_like(TP)
    FP = torch.zeros_like(TP)
    FN = torch.zeros_like(TP)
    for c in range(conf_matrix.size(0)):
        idx = torch.ones(conf_matrix.size(0)).byte()
        idx[c] = 0
        # all non-class samples classified as non-class
        TN[c] = conf_matrix[idx.nonzero()[:, None], idx.nonzero()].sum() #conf_matrix[idx[:, None], idx].sum() - conf_matrix[idx, c].sum()
        # all non-class samples classified as class
        FP[c] = conf_matrix[idx, c].sum()
        # all class samples not classified as class
        FN[c] = conf_matrix[c, idx].sum()

        print('Class {}\nTP {}, TN {}, FP {}, FN {}'.format(
            c, TP[c], TN[c], FP[c], FN[c]))
        
    return TP, TN, FP, FN

def split_list(a_list):
    half = len(a_list)//2
    return a_list[:half], a_list[half:]

# Dataset

You will get two different datasets:

 1. an unlabeled dataset for pre-training
 2. a labeled dataset for both training and validation
 
## The dataset is organized into three levels: scene, sample and image

 1. A scene is 25 seconds of a car's journey.
 2. A sample is a snapshot of a scene at a given timeframe. Each scene will be divided into 126 samples, so about 0.2 seconds between consecutive samples.
 3. Each sample contains 6 images captured by camera facing different orientation.
    Each camera will capture 70 degree view. To make it simple, you can safely assume that the angle between the cameras is 60 degrees 

106 scenes in the unlabeled dataset and 28 scenes in the labeled dataset

In [4]:
# You shouldn't change the unlabeled_scene_index
# The first 106 scenes are unlabeled
unlabeled_scene_index = np.arange(106)
# The scenes from 106 - 133 are labeled
# You should devide the labeled_scene_index into two subsets (training and validation)
labeled_scene_index = np.arange(106, 134)

train_scene_index = np.random.choice(labeled_scene_index, int(np.ceil(0.8*len(labeled_scene_index))))

test_scene_index = labeled_scene_index[np.isin(labeled_scene_index, train_scene_index,invert=True)]

# Unlabeled dataset

You get two ways to access the dataset, by sample or by image

## Get Sample

In [5]:
#transform = torchvision.transforms.ToTensor()

transform=torchvision.transforms.Compose([torchvision.transforms.Resize((256,256)),
                                          torchvision.transforms.ToTensor(),
                              torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                             ])

unlabeled_trainset = UnlabeledDataset(image_folder=image_folder, scene_index=unlabeled_scene_index, first_dim='sample', transform=transform)
trainloader = torch.utils.data.DataLoader(unlabeled_trainset, batch_size=16, shuffle=True, num_workers=2)

In [6]:
# [batch_size, 6(images per sample), 3, H, W]
sample = iter(trainloader).next()
print(sample.shape)

torch.Size([16, 6, 3, 256, 256])


In [7]:
sample[0][1].shape

torch.Size([3, 256, 256])

In [8]:
torchvision.utils.make_grid(sample[2], nrow=3).numpy().transpose(1, 2, 0).shape

(518, 776, 3)

## Get individual image

In [9]:
unlabeled_trainset = UnlabeledDataset(image_folder=image_folder, scene_index=unlabeled_scene_index, first_dim='image', transform=transform)
trainloader = torch.utils.data.DataLoader(unlabeled_trainset, batch_size=2, shuffle=True, num_workers=2)

In [10]:
# [batch_size, 3, H, W]
image, camera_index = iter(trainloader).next()
print(image.shape)

torch.Size([2, 3, 256, 256])


In [11]:
# Camera_index is to tell you which camera is used. The order is
# CAM_FRONT_LEFT, CAM_FRONT, CAM_FRONT_RIGHT, CAM_BACK_LEFT, CAM_BACK, CAM_BACK_RIGHT
print(camera_index[0])

tensor(2)


# Labeled dataset

In [12]:
batch_size = 8

In [13]:
# The labeled dataset can only be retrieved by sample.
# And all the returned data are tuple of tensors, since bounding boxes may have different size
# You can choose whether the loader returns the extra_info. It is optional. You don't have to use it.
labeled_trainset = LabeledDataset(image_folder=image_folder,
                                  annotation_file=annotation_csv,
                                  scene_index=train_scene_index,
                                  transform=transform,
                                  extra_info=True
                                 )

trainloader = torch.utils.data.DataLoader(labeled_trainset, 
                                          batch_size=batch_size, 
                                          shuffle=True, 
                                          num_workers=2, 
                                          collate_fn=collate_fn)

labeled_testset = LabeledDataset(image_folder=image_folder,
                                  annotation_file=annotation_csv,
                                  scene_index=test_scene_index,
                                  transform=transform,
                                  extra_info=True
                                 )

testloader = torch.utils.data.DataLoader(labeled_testset, 
                                          batch_size=batch_size, 
                                          shuffle=True, 
                                          num_workers=2, 
                                          collate_fn=collate_fn)

In [14]:
sample, target, road_image, extra = iter(trainloader).next()

There are two kind of labels

 1. The bounding box of surrounding objects
 2. The binary road_image

In [15]:
len(sample)

8

In [16]:
n_feature = 20

In [17]:
torch.stack(road_image).shape

torch.Size([8, 800, 800])

In [18]:
class ConvLayer(nn.Module):
    def __init__(self, 
                 in_channels, 
                 out_channels, 
                 kernel_size=3, 
                 stride=1, 
                 padding=0, 
                 bias = True, 
                 pool=False,
                 mp_kernel_size=2, 
                 mp_stride=2):
        super(ConvLayer, self).__init__()
        if pool:
            self.layer = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias),
                nn.BatchNorm2d(out_channels),
                nn.Dropout(0.5),
                nn.LeakyReLU(negative_slope=0.1), ## nn.ReLU(), 
                nn.MaxPool2d(kernel_size=mp_kernel_size, stride=mp_stride))
        else:
            self.layer = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias),
                nn.BatchNorm2d(out_channels),
                nn.Dropout(0.5),
                nn.LeakyReLU(negative_slope=0.1), ## nn.ReLU(), 
                )
        
    def forward(self, x):
        return self.layer(x)

class LinearLayer(nn.Module):
    def __init__(self, in_features, out_features):
        super(LinearLayer, self).__init__()
        self.layer = nn.Sequential(
            torch.nn.Linear(in_features, out_features),
            nn.BatchNorm1d(out_features),
            nn.Dropout(0.5),
            nn.LeakyReLU(negative_slope=0.1) ## nn.ReLU()
        )
        
    def forward(self, x):
        return self.layer(x)

class ConvTLayer(nn.Module):
    def __init__(self, 
                 in_channels, 
                 out_channels, 
                 kernel_size=3, 
                 stride=1, 
                 padding=0, 
                 output_padding=0, 
                 unpool=False,
                 mp_kernel_size=2, 
                 mp_stride=2):
        super(ConvTLayer, self).__init__()
        if unpool:
            self.layer = nn.Sequential(
                nn.ConvTranspose2d(in_channels, 
                                   out_channels, 
                                   kernel_size, 
                                   stride=stride, 
                                   padding=padding, 
                                   output_padding=output_padding, 
                                   bias=False),
                nn.BatchNorm2d(out_channels),
                nn.Dropout(0.5),
                nn.LeakyReLU(negative_slope=0.1), ## nn.ReLU()
                nn.MaxUnpool2d(kernel_size=mp_kernel_size, stride=mp_stride)
            )
        else:
            self.layer = nn.Sequential(
                nn.ConvTranspose2d(in_channels, 
                                   out_channels, 
                                   kernel_size, 
                                   stride=stride, 
                                   padding=padding, 
                                   output_padding=output_padding, 
                                   bias=False),
                nn.BatchNorm2d(out_channels),
                nn.Dropout(0.5),
                nn.LeakyReLU(negative_slope=0.1), ## nn.ReLU()
            )        
    def forward(self, x):
        return self.layer(x)

class Encoder1(nn.Module):
    def __init__(self):
        super(Encoder1, self).__init__()
        self.conv1 = ConvLayer(3,96, stride=2)
        self.conv2 = ConvLayer(96,128, stride=2)
        self.conv3 = ConvLayer(128,256, stride=2)
        self.conv4 = ConvLayer(256,512, stride=2)
        self.conv5 = ConvLayer(512,1024, padding=(0,0))
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        return x
    
class EncoderY(nn.Module):
    def __init__(self):
        super(Encoder1, self).__init__()
        self.conv1 = ConvLayer(3,96, stride=2)
        self.conv2 = ConvLayer(96,128, stride=2)
        self.conv3 = ConvLayer(128,256, stride=2)
        self.conv4 = ConvLayer(256,512, stride=2)
        self.conv5 = ConvLayer(512,1024, padding=(0,0))
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        return x

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.encoder = Encoder()
        self.linear = nn.Linear(1024*7*7,1)
        
    def forward(self,x):
        x = self.encoder(x)
        x = self.linear(x.reshape(-1,1024*7*7))
        return torch.sigmoid(x)

class Discriminator1(nn.Module):
    def __init__(self):
        super(Discriminator1, self).__init__()
        self.encoder = Encoder1()
        self.linear = nn.Linear(1024*13*13,1)
        
    def forward(self,x):
        x = self.encoder(x)
        x = self.linear(x.reshape(-1,1024*13*13))
        return torch.sigmoid(x)

def random_vector(batch_size, length):
    # Sample from a Gaussian distribution
    z_vec = torch.randn(batch_size, length, 1, 1).float()
    if torch.cuda.is_available():
        z_vec = z_vec.to(device)
    return z_vec

class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.convt1 = ConvTLayer(4096, 2048, stride=2)
        self.convt2 = ConvTLayer(2048, 1024, stride=2, output_padding=(0,0))
        self.convt3 = ConvTLayer(1024, 512, stride=2, padding=(1,1), output_padding=(1,1))
        self.convt4 = ConvTLayer(512, 256, stride=2, output_padding=(1,1))
        self.convt5 = ConvTLayer(256, 128, stride=2, output_padding=(1,1))
        self.convt6 = ConvTLayer(128, 96, stride=2, output_padding=(1,1))
        self.convt7 = ConvTLayer(96, 64, stride=2, output_padding=(1,1))
        self.convt8 = ConvTLayer(64, 32, stride=1, output_padding=(0,0))
        self.convt9 = ConvTLayer(32, 18, stride=1, padding=(1,1), output_padding=(0,0))
        
    def forward(self,z):
        z = self.convt1(z)
        z = self.convt2(z)
        z = self.convt3(z)
        z = self.convt4(z)
        z = self.convt5(z)
        z = self.convt6(z)
        z = self.convt7(z)
        z = self.convt8(z)
        z = self.convt9(z)
        return z

class Decoder1(nn.Module):
    def __init__(self):
        super(Decoder1, self).__init__()
        self.convt1 = ConvTLayer(4096, 2048, stride=2)
        self.convt2 = ConvTLayer(2048, 1024, stride=2, output_padding=(1,1))
        self.convt3 = ConvTLayer(1024, 512, stride=2, padding=(1,1), output_padding=(0,0))
        self.convt4 = ConvTLayer(512, 256, stride=2, output_padding=(0,0))
        self.convt5 = ConvTLayer(256, 128, stride=2, output_padding=(0,0))
        self.convt6 = ConvTLayer(128, 96, stride=2, output_padding=(0,0))
        self.convt7 = ConvTLayer(96, 3, stride=2, output_padding=(1,1))
        
    def forward(self,z):
        z = self.convt1(z)
        z = self.convt2(z)
        z = self.convt3(z)
        z = self.convt4(z)
        z = self.convt5(z)
        z = self.convt6(z)
        z = self.convt7(z)
        return z

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.decoder = Decoder()
        
    def forward(self,x):
        x = self.decoder(x)
        return torch.tanh(x).reshape(6,-1,3,256,256)

class Generator1(nn.Module):
    def __init__(self):
        super(Generator1, self).__init__()
        self.decoder = Decoder1()
        
    def forward(self,x):
        x = self.decoder(x)
        return torch.tanh(x)

In [19]:
!nvidia-smi

Sat May  2 15:37:44 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 418.39       Driver Version: 418.39       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  Tesla P100-PCIE...  On   | 00000000:85:00.0 Off |                    0 |
| N/A   28C    P0    39W / 250W |     10MiB / 16280MiB |      0%   E. Process |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage    

In [20]:
class EncoderY(nn.Module):
    def __init__(self,  d):
        super(EncoderY, self).__init__()
        self.conv1 = ConvLayer(3,96, stride=2)
        self.conv2 = ConvLayer(96,128, stride=2)
        self.conv3 = ConvLayer(128,256, stride=2)
        self.conv4 = ConvLayer(256,512, stride=2)
        self.conv5 = ConvLayer(512,1024, stride=2)
        self.conv6 = ConvLayer(1024,2048, stride=2)
        self.lin1 = nn.Linear(2048*3*3, d)
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.conv6(x)
        #print(x.shape)
        x = self.lin1(x.reshape(-1,2048*3*3))
        return x

In [21]:
class EncoderX(nn.Module):
    def __init__(self, d):
        super(EncoderX, self).__init__()
        self.conv1 = ConvLayer(2,16, stride=2)
        self.conv2 = ConvLayer(16,32, stride=2)
        self.conv3 = ConvLayer(32,48, stride=2)
        self.conv4 = ConvLayer(48,64, stride=2)
        self.conv5 = ConvLayer(64,96, stride=2)
        self.conv6 = ConvLayer(96,128, stride=2)
        self.conv7 = ConvLayer(128,256, stride=2)
        self.conv8 = ConvLayer(256,512, stride=2)
        self.lin1 = nn.Linear(512*2*2, d)
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.conv6(x)
        x = self.conv7(x)
        x = self.conv8(x)
        #print(x.shape)
        x = self.lin1(x.reshape(-1,512*2*2))
        return x

In [22]:
class DecoderX(nn.Module):
    def __init__(self):
        super(DecoderX, self).__init__()
        self.convt1 = ConvTLayer(4096, 2048, kernel_size=3, stride=2)
        self.convt2 = ConvTLayer(2048, 1024, kernel_size=3, stride=3, output_padding=(0,0))
        self.convt3 = ConvTLayer(1024, 512, kernel_size=3, stride=2, padding=(1,1), output_padding=(0,0))
        self.convt4 = ConvTLayer(512, 256, kernel_size=3, stride=3, padding=(1,1), output_padding=(0,0))
        self.convt5 = ConvTLayer(256, 128, kernel_size=3, stride=2, output_padding=(0,0))
        self.convt6 = ConvTLayer(128, 96, kernel_size=3, stride=2, output_padding=(0,0))
        self.convt7 = ConvTLayer(96, 64, kernel_size=3, stride=2, output_padding=(0,0))
        self.convt8 = ConvTLayer(64, 2, kernel_size=3, stride=2, output_padding=(1,1))
        
    def forward(self,z):
        z = self.convt1(z)
        z = self.convt2(z)
        z = self.convt3(z)
        z = self.convt4(z)
        z = self.convt5(z)
        z = self.convt6(z)
        z = self.convt7(z)
        z = self.convt8(z)
        return torch.sigmoid(z)

In [23]:
# Defining the model

class CNN_VAE(nn.Module):
    def __init__(self, hidden_d=286, image_d=635): #hidden_d=196, image_d=650 or hidden_d=286, image_d=635
        super().__init__()
        
        self.d = hidden_d
        self.id = image_d
        
        self.y_encoder = EncoderY(d=self.id)

        self.x_encoder = EncoderX(d=2*self.d)

        self.x_decoder = DecoderX()

    def reparameterise(self, mu, logvar):
        if self.training:
            std = logvar.mul(0.5).exp_()
            eps = std.data.new(std.size()).normal_()
            return eps.mul(std).add_(mu)
        else:
            return mu

    def forward(self, x, y):
        mu_logvar = self.x_encoder(x).view(-1, 2, self.d)
        #print(mu_logvar.shape)
        img_enc = [self.y_encoder(img.squeeze()) for img in y] 
        mu = mu_logvar[:, 0, :]
        #print(mu.shape)
        logvar = mu_logvar[:, 1, :]
        #print(logvar.shape)
        z = self.reparameterise(mu, logvar)
        img_enc.append(z)
        out = torch.cat(img_enc,axis=1).reshape(-1,4096,1,1)
        return self.x_decoder(out), mu, logvar
    
    def inference(self, y, mu=None, logvar=None):
        N = y.size(1)
        z = torch.randn((N, self.d)).to(device)
        #print('Prior:',z.shape)
        if mu is not None and logvar is not None:
            #print(mu.shape)
            #print(logvar.shape)
            std = logvar.mul(0.5).exp_()
            eps = std.data.new(std.size()).normal_()
            z = eps.mul(std).add_(mu)
            #print('Post:',z.shape)
        z = z.reshape(-1,self.d)
        img_enc = [self.y_encoder(img.squeeze()) for img in y] 
        img_enc.append(z)
        out = torch.cat(img_enc,axis=1).reshape(-1,4096,1,1)
        return self.x_decoder(out)
    

In [24]:
    
model = CNN_VAE().to(device)
# Setting the optimiser

learning_rate = 1e-3

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=learning_rate,
)

In [25]:
# Reconstruction + KL divergence losses summed over all elements and batch

def loss_function(x_hat, x, mu, logvar):
    BCE = nn.functional.binary_cross_entropy(
        x_hat, x, reduction='sum'
    )
    KLD = 0.5 * torch.sum(logvar.exp() - logvar - 1 + mu.pow(2))

    return BCE + KLD

In [26]:
40*0.005

0.2

In [27]:
hidden_dim = 4096 - 6*635
hidden_dim

286

In [28]:
# Training and testing the VAE

accuracy_list = []
threshold = 0.5
epochs = 10
codes = dict(μ=list(), logσ2=list(), y=list())
for epoch in range(0, epochs + 1):
    # Training
    if epoch >= 0:  # test untrained net first
        model.train()
        train_loss = 0
        for i, data in enumerate(trainloader):
            sample, target, road_image, extra = data
            batch_size = len(road_image)
            x = torch.zeros((batch_size,2,800,800))
            x[:,0,:,:] = 1.0*torch.stack(road_image).reshape(-1, 800, 800)
            x[:,1,:,:] = torch.zeros((batch_size,800,800))
            for i in range(batch_size):
                for cat, bb in zip(target[i]['category'], target[i]['bounding_box']):
                    x[i,1,:,:] = 1.0*convert_to_binary_mask(bb)
            x = x.to(device)       
            y = torch.stack(sample).reshape(6,-1,3,256,256).to(device)
            # ===================forward=====================
            x_hat, mu, logvar = model(x, y)
            loss = loss_function(x_hat, x, mu, logvar)
            train_loss += loss.item()
            # ===================backward====================
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # ===================log========================
        print(f'====> Epoch: {epoch} Average loss: {train_loss / len(trainloader.dataset):.4f}')

    means, logvars, labels = list(), list(), list()
    with torch.no_grad():
        model.eval()
        test_loss_post = 0
        test_loss_prior = 0
        
        road_correct_post = 0
        road_correct_prior = 0
        total_road = 0
        
        bb_correct_post = 0
        bb_correct_prior = 0
        total_bb = 0
        
        conf_matrix = [torch.zeros(2,2).to(device),torch.zeros(2,2).to(device)]
        for batch_idx, data in enumerate(testloader):
            sample, target, road_image, extra = data
            batch_size = len(road_image)
            x = torch.zeros((batch_size,2,800,800))
            x[:,0,:,:] = 1.0*torch.stack(road_image).reshape(-1, 800, 800)
            x[:,1,:,:] = torch.zeros((batch_size,800,800))
            for i in range(batch_size):
                for cat, bb in zip(target[i]['category'], target[i]['bounding_box']):
                    x[i,1,:,:] = 1.0*convert_to_binary_mask(bb)
            x = x.to(device)
            y = torch.stack(sample).reshape(6,-1,3,256,256).to(device)
            
            # ===================forward=====================
            mu = torch.mean(mu,0).repeat(batch_size).view(batch_size,hidden_dim)
            logvar = torch.mean(logvar,0).repeat(batch_size).view(batch_size,hidden_dim)
            x_hat_post = model.inference(y, mu, logvar)
            x_hat_prior = model.inference(y)
            test_loss_post += loss_function(x_hat_post, x, mu, logvar).item()
            test_loss_prior += loss_function(x_hat_prior, x, mu, logvar).item()
            # =====================log=======================
            means.append(mu.detach())
            logvars.append(logvar.detach())

            
            for channel in [0,1]:
                if channel == 0:
                    road_correct_post += (x_hat_post[:,channel,:,:]>threshold).eq(
                    (x[:,channel,:,:]==1).data.view_as(
                        (x_hat_post[:,channel,:,:]>threshold))).cpu().sum().item()
                    road_correct_prior += (x_hat_prior[:,channel,:,:]>threshold).eq(
                    (x[:,channel,:,:]==1).data.view_as((x_hat_prior[:,channel,:,:]>threshold))).cpu().sum().item()
                    total_road += x[:,channel,:,:].nelement()
                    
                    road_accuracy_post = 100. * road_correct_post / total_road
                    road_accuracy_prior = 100. * road_correct_prior / total_road
                    
                else:   
                    bb_correct_post += (x_hat_post[:,channel,:,:]>threshold).eq(
                        (x[:,channel,:,:]==1).data.view_as((x_hat_post[:,channel,:,:]>threshold))).cpu().sum().item()
                    bb_correct_prior += (x_hat_prior[:,channel,:,:]>threshold).eq(
                        (x[:,channel,:,:]==1).data.view_as((x_hat_prior[:,channel,:,:]>threshold))).cpu().sum().item()
                    total_bb += x[:,channel,:,:].nelement()
                    
                    bb_accuracy_post = 100. * bb_correct_post / total_bb
                    bb_accuracy_prior = 100. * bb_correct_prior / total_bb
            
                print('='*100)
                print('Channel: {}'.format(channel))
                if batch_idx % 100 == 0:
                    for i in range(0,40):
                        thld = 0.48+i*0.005
                        print('Confusion Matrix (Post) at threshold: {}'.format(thld))
                        print(create_conf_matrix2(1*(x[:,channel,:,:]==1), 
                                                  1*(x_hat_post[:,channel,:,:]>thld)))
                        print('='*50)
                        print('Confusion Matrix (Prior) at threshold: {}'.format(thld))
                        print(create_conf_matrix2(1*(x[:,channel,:,:]==1), 
                                                  1*(x_hat_prior[:,channel,:,:]>thld)))
                        print('='*50)
                    print('='*75)
                    print('='*75)
                conf_matrix[channel] += create_conf_matrix2(1*(x[:,channel,:,:]==1), 
                                                            1*(x_hat_post[:,channel,:,:]>threshold))
                print('='*100)

    accuracy_list.append((road_accuracy_prior,bb_accuracy_prior))
    print("""\nTest set: Average loss: {:.4f}, 
    Accuracy Road (Post): {}/{} ({:.0f}%) ,
    Accuracy Road (Prior): {}/{} ({:.0f}%) ,
    Road: 
    TP {} 
    TN {}
    FP {}
    FN {}
    Accuracy BB (Post): {}/{} ({:.0f}%) ,
    Accuracy BB (Prior): {}/{} ({:.0f}%) ,
    BB:
    TP {} , 
    TN {}
    FP {}
    FN {}""".format(
        test_loss_post, 
        road_correct_post, total_road, road_accuracy_post, 
        road_correct_prior, total_road, road_accuracy_prior, 
        *classScores(conf_matrix_road),
        bb_correct_post, total_bb, bb_accuracy_post, 
        bb_correct_prior, total_bb, bb_accuracy_prior, 
        *classScores(conf_matrix_bb)
    ))

            #labels.append(y.detach())
    # ===================log========================
    codes['μ'].append(torch.cat(means))
    codes['logσ2'].append(torch.cat(logvars))
    #codes['y'].append(torch.cat(labels))
    test_loss_post /= len(testloader.dataset)
    test_loss_prior /= len(testloader.dataset)
    print(f'====> Posterior Test set loss: {test_loss_post:.4f}')
    print(f'====> Prior Test set loss: {test_loss_prior:.4f}')
    plt.figure(figsize=(10, 6))
    plt.subplot(1,3,1)
    plt.imshow((x[:,0,:,:][0].squeeze()==1).detach().cpu().numpy(), cmap='binary')
    plt.imshow((x[:,1,:,:][0].squeeze()==1).detach().cpu().numpy(), cmap='binary')
    plt.subplot(1,3,2)
    plt.imshow((x_hat_post[:,0,:,:][0].squeeze()>threshold).detach().cpu().numpy(), cmap='binary')
    plt.imshow((x_hat_post[:,1,:,:][0].squeeze()>threshold).detach().cpu().numpy(), cmap='binary')  
    plt.subplot(1,3,3)
    plt.imshow((x_hat_prior[:,0,:,:][0].squeeze()>threshold).detach().cpu().numpy(), cmap='binary')
    plt.imshow((x_hat_prior[:,1,:,:][0].squeeze()>threshold).detach().cpu().numpy(), cmap='binary')
    plt.show()

====> Epoch: 0 Average loss: 858537.5597
Channel: 0
Confusion Matrix (Post) at threshold: 0.48
Threat Score: 0.529310941696167
tensor([[1921245, 1140681],
        [ 364938, 1693136]], device='cuda:0')
Confusion Matrix (Prior) at threshold: 0.48
Threat Score: 0.5293111205101013
tensor([[1921248, 1140678],
        [ 364939, 1693135]], device='cuda:0')
Confusion Matrix (Post) at threshold: 0.485
Threat Score: 0.6198192238807678
tensor([[2611112,  450814],
        [ 503017, 1555057]], device='cuda:0')
Confusion Matrix (Prior) at threshold: 0.485
Threat Score: 0.619819700717926
tensor([[2611114,  450812],
        [ 503017, 1555057]], device='cuda:0')
Confusion Matrix (Post) at threshold: 0.49
Threat Score: 0.6379300951957703
tensor([[2733450,  328476],
        [ 535622, 1522452]], device='cuda:0')
Confusion Matrix (Prior) at threshold: 0.49
Threat Score: 0.637930691242218
tensor([[2733454,  328472],
        [ 535623, 1522451]], device='cuda:0')
Confusion Matrix (Post) at threshold: 0.495
Th

Threat Score: 0.6404475569725037
tensor([[2822260,  239666],
        [ 586492, 1471582]], device='cuda:0')
Confusion Matrix (Prior) at threshold: 0.58
Threat Score: 0.6404480338096619
tensor([[2822260,  239666],
        [ 586491, 1471583]], device='cuda:0')
Confusion Matrix (Post) at threshold: 0.585
Threat Score: 0.6398299932479858
tensor([[2823154,  238772],
        [ 588483, 1469591]], device='cuda:0')
Confusion Matrix (Prior) at threshold: 0.585
Threat Score: 0.639830470085144
tensor([[2823154,  238772],
        [ 588482, 1469592]], device='cuda:0')
Confusion Matrix (Post) at threshold: 0.59
Threat Score: 0.6392008066177368
tensor([[2824327,  237599],
        [ 590678, 1467396]], device='cuda:0')
Confusion Matrix (Prior) at threshold: 0.59
Threat Score: 0.6392012238502502
tensor([[2824327,  237599],
        [ 590677, 1467397]], device='cuda:0')
Confusion Matrix (Post) at threshold: 0.595
Threat Score: 0.6385893821716309
tensor([[2826202,  235724],
        [ 593279, 1464795]], devic

Threat Score: 0.644788384437561
Channel: 1
Confusion Matrix (Post) at threshold: 0.48
Threat Score: 0.0014475909993052483
tensor([[  10822, 5101758],
        [     24,    7396]], device='cuda:0')
Confusion Matrix (Prior) at threshold: 0.48
Threat Score: 0.0014475906500592828
tensor([[  10821, 5101759],
        [     24,    7396]], device='cuda:0')
Confusion Matrix (Post) at threshold: 0.485
Threat Score: 0.001370859332382679
tensor([[  60400, 5052180],
        [    484,    6936]], device='cuda:0')
Confusion Matrix (Prior) at threshold: 0.485
Threat Score: 0.0013708644546568394
tensor([[  60419, 5052161],
        [    484,    6936]], device='cuda:0')
Confusion Matrix (Post) at threshold: 0.49
Threat Score: 0.001274792361073196
tensor([[ 402369, 4710211],
        [   1406,    6014]], device='cuda:0')
Confusion Matrix (Prior) at threshold: 0.49
Threat Score: 0.0012747931759804487
tensor([[ 402372, 4710208],
        [   1406,    6014]], device='cuda:0')
Confusion Matrix (Post) at threshold

Threat Score: 0.0
tensor([[5112388,     192],
        [   7420,       0]], device='cuda:0')
Confusion Matrix (Prior) at threshold: 0.585
Threat Score: 0.0
tensor([[5112388,     192],
        [   7420,       0]], device='cuda:0')
Confusion Matrix (Post) at threshold: 0.59
Threat Score: 0.0
tensor([[5112388,     192],
        [   7420,       0]], device='cuda:0')
Confusion Matrix (Prior) at threshold: 0.59
Threat Score: 0.0
tensor([[5112388,     192],
        [   7420,       0]], device='cuda:0')
Confusion Matrix (Post) at threshold: 0.595
Threat Score: 0.0
tensor([[5112388,     192],
        [   7420,       0]], device='cuda:0')
Confusion Matrix (Prior) at threshold: 0.595
Threat Score: 0.0
tensor([[5112388,     192],
        [   7420,       0]], device='cuda:0')
Confusion Matrix (Post) at threshold: 0.6
Threat Score: 0.0
tensor([[5112388,     192],
        [   7420,       0]], device='cuda:0')
Confusion Matrix (Prior) at threshold: 0.6
Threat Score: 0.0
tensor([[5112388,     192],
    

Threat Score: 0.0
Channel: 0
Threat Score: 0.6960929036140442
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.6498263478279114
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.6242702603340149
Channel: 1
Threat Score: 0.00018011526844929904
Channel: 0
Threat Score: 0.6963444948196411
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.6114349961280823
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.5753323435783386
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.6059544086456299
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.7283192276954651
Channel: 1
Threat Score: 0.00014365752576850355
Channel: 0
Threat Score: 0.7274636626243591
Channel: 1
Threat Score: 0.0008036431972868741
Channel: 0
Threat Score: 0.7015412449836731
Channel: 1
Threat Score: 0.00046382189611904323
Channel: 0
Threat Score: 0.605883002281189
Channel: 1
Threat Score: 0.00013533630408346653
Channel: 0
Threat Score: 0.59597247838974
Channel: 1
Threat Score: 0.0
Channel: 0
Thre

Threat Score: 0.0002870263997465372
Channel: 0
Threat Score: 0.7414059638977051
Channel: 1
Threat Score: 0.0004573170735966414
Channel: 0
Threat Score: 0.5319705009460449
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.7352715730667114
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.621837854385376
Channel: 1
Threat Score: 0.00015535186685156077
Channel: 0
Threat Score: 0.6520913243293762
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.6850830316543579
Channel: 1
Threat Score: 0.00037735849036835134
Channel: 0
Threat Score: 0.7319718599319458
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.5141781568527222
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.6875779032707214
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.5698744654655457
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.5390536189079285
Channel: 1
Threat Score: 0.0001718508283374831
Channel: 0
Threat Score: 0.6814765334129333
Channel: 1
Threat Score: 0.00071204785490408

Threat Score: 0.00023934897035360336
Channel: 0
Threat Score: 0.6949109435081482
Channel: 1
Threat Score: 0.0001430819829693064
Channel: 0
Threat Score: 0.5223160982131958
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.6565854549407959
Channel: 1
Threat Score: 0.00016417665756307542
Channel: 0
Threat Score: 0.6443363428115845
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.7603977918624878
Channel: 1
Threat Score: 0.0002715177834033966
Channel: 0
Threat Score: 0.647404670715332
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.6864015460014343
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.5593398213386536
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.6706070899963379
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.6157615184783936
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.5987209677696228
Channel: 1
Threat Score: 0.00043503480264917016
Channel: 0
Threat Score: 0.6384660601615906
Channel: 1
Threat Score: 0.0
Channel: 0
Thr

Threat Score: 0.0
Channel: 0
Threat Score: 0.6826359033584595
Channel: 1
Threat Score: 0.00021831678168382496
Channel: 0
Threat Score: 0.6715788841247559
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.5166200399398804
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.6236804127693176
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.5606751441955566
Channel: 1
Threat Score: 0.00019681165576912463
Channel: 0
Threat Score: 0.5290414094924927
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.7023520469665527
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.7688657641410828
Channel: 1
Threat Score: 0.0001414227153873071
Channel: 0
Threat Score: 0.5784556269645691
Channel: 1
Threat Score: 0.00015688735584262758
Channel: 0
Threat Score: 0.732282817363739
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.7563847303390503
Channel: 1
Threat Score: 0.0004179728275630623
Channel: 0
Threat Score: 0.6066539883613586
Channel: 1
Threat Score: 0.0
Channel: 0
Thr

Threat Score: 0.0
Channel: 0
Threat Score: 0.6234142184257507
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.6327633857727051
Channel: 1
Threat Score: 0.00019665683794301003
Channel: 0
Threat Score: 0.5591535568237305
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.623867392539978
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.608511209487915
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.6178319454193115
Channel: 1
Threat Score: 0.000179726819624193
Channel: 0
Threat Score: 0.7167844176292419
Channel: 1
Threat Score: 0.0001204529035021551
Channel: 0
Threat Score: 0.611294150352478
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.6522701382637024
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.5862525105476379
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.6414278745651245
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.650468647480011
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.6806517243385315
Channel: 1
T

Channel: 0
Threat Score: 0.6524090766906738
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.7455138564109802
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.6763597726821899
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.5638886094093323
Channel: 1
Threat Score: 0.00031729243346489966
Channel: 0
Threat Score: 0.6673617959022522
Channel: 1
Threat Score: 0.0004329629009589553
Channel: 0
Threat Score: 0.5790317058563232
Channel: 1
Threat Score: 0.00011140819697175175
Channel: 0
Threat Score: 0.5616307854652405
Channel: 1
Threat Score: 0.00013879250036552548
Channel: 0
Threat Score: 0.6741235852241516
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.6243883371353149
Channel: 1
Threat Score: 0.0003993078717030585
Channel: 0
Threat Score: 0.6239777207374573
Channel: 1
Threat Score: 0.0005780346691608429
Channel: 0
Threat Score: 0.6076785326004028
Channel: 1
Threat Score: 0.00015921032172627747
Channel: 0
Threat Score: 0.6197803616523743
Channel: 1
Threat Score

Threat Score: 0.6342443227767944
tensor([[2661460,  186431],
        [ 712794, 1559315]], device='cuda:0')
Confusion Matrix (Post) at threshold: 0.51
Threat Score: 0.634174644947052
tensor([[2663276,  184615],
        [ 714117, 1557992]], device='cuda:0')
Confusion Matrix (Prior) at threshold: 0.51
Threat Score: 0.6341745257377625
tensor([[2663274,  184617],
        [ 714116, 1557993]], device='cuda:0')
Confusion Matrix (Post) at threshold: 0.515
Threat Score: 0.6340705156326294
tensor([[2665038,  182853],
        [ 715490, 1556619]], device='cuda:0')
Confusion Matrix (Prior) at threshold: 0.515
Threat Score: 0.6340705156326294
tensor([[2665038,  182853],
        [ 715490, 1556619]], device='cuda:0')
Confusion Matrix (Post) at threshold: 0.52
Threat Score: 0.6339502334594727
tensor([[2666484,  181407],
        [ 716702, 1555407]], device='cuda:0')
Confusion Matrix (Prior) at threshold: 0.52
Threat Score: 0.6339502334594727
tensor([[2666484,  181407],
        [ 716702, 1555407]], device

Threat Score: 0.6287357807159424
tensor([[2687094,  160797],
        [ 742454, 1529655]], device='cuda:0')
Confusion Matrix (Prior) at threshold: 0.61
Threat Score: 0.6287358999252319
tensor([[2687093,  160798],
        [ 742453, 1529656]], device='cuda:0')
Confusion Matrix (Post) at threshold: 0.615
Threat Score: 0.6283963918685913
tensor([[2687911,  159980],
        [ 743793, 1528316]], device='cuda:0')
Confusion Matrix (Prior) at threshold: 0.615
Threat Score: 0.6283965110778809
tensor([[2687913,  159978],
        [ 743794, 1528315]], device='cuda:0')
Confusion Matrix (Post) at threshold: 0.62
Threat Score: 0.628025233745575
tensor([[2688499,  159392],
        [ 745065, 1527044]], device='cuda:0')
Confusion Matrix (Prior) at threshold: 0.62
Threat Score: 0.6280244588851929
tensor([[2688496,  159395],
        [ 745065, 1527044]], device='cuda:0')
Confusion Matrix (Post) at threshold: 0.625
Threat Score: 0.6274024844169617
tensor([[2689176,  158715],
        [ 747004, 1525105]], devic

Threat Score: 0.0
tensor([[5113701,     192],
        [   6107,       0]], device='cuda:0')
Confusion Matrix (Post) at threshold: 0.51
Threat Score: 0.0
tensor([[5113701,     192],
        [   6107,       0]], device='cuda:0')
Confusion Matrix (Prior) at threshold: 0.51
Threat Score: 0.0
tensor([[5113701,     192],
        [   6107,       0]], device='cuda:0')
Confusion Matrix (Post) at threshold: 0.515
Threat Score: 0.0
tensor([[5113701,     192],
        [   6107,       0]], device='cuda:0')
Confusion Matrix (Prior) at threshold: 0.515
Threat Score: 0.0
tensor([[5113701,     192],
        [   6107,       0]], device='cuda:0')
Confusion Matrix (Post) at threshold: 0.52
Threat Score: 0.0
tensor([[5113701,     192],
        [   6107,       0]], device='cuda:0')
Confusion Matrix (Prior) at threshold: 0.52
Threat Score: 0.0
tensor([[5113701,     192],
        [   6107,       0]], device='cuda:0')
Confusion Matrix (Post) at threshold: 0.525
Threat Score: 0.0
tensor([[5113701,     192],
   

Threat Score: 0.0
tensor([[5113701,     192],
        [   6107,       0]], device='cuda:0')
Confusion Matrix (Post) at threshold: 0.62
Threat Score: 0.0
tensor([[5113701,     192],
        [   6107,       0]], device='cuda:0')
Confusion Matrix (Prior) at threshold: 0.62
Threat Score: 0.0
tensor([[5113701,     192],
        [   6107,       0]], device='cuda:0')
Confusion Matrix (Post) at threshold: 0.625
Threat Score: 0.0
tensor([[5113701,     192],
        [   6107,       0]], device='cuda:0')
Confusion Matrix (Prior) at threshold: 0.625
Threat Score: 0.0
tensor([[5113701,     192],
        [   6107,       0]], device='cuda:0')
Confusion Matrix (Post) at threshold: 0.63
Threat Score: 0.0
tensor([[5113701,     192],
        [   6107,       0]], device='cuda:0')
Confusion Matrix (Prior) at threshold: 0.63
Threat Score: 0.0
tensor([[5113701,     192],
        [   6107,       0]], device='cuda:0')
Confusion Matrix (Post) at threshold: 0.635
Threat Score: 0.0
tensor([[5113701,     192],
   

Threat Score: 0.00040777490357868373
Channel: 0
Threat Score: 0.7097308039665222
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.6678136587142944
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.7604116797447205
Channel: 1
Threat Score: 0.00012779553071595728
Channel: 0
Threat Score: 0.6174949407577515
Channel: 1
Threat Score: 0.0001506704866187647
Channel: 0
Threat Score: 0.6366108655929565
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.5951631665229797
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.633738100528717
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.6484169363975525
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.5639044642448425
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.6982651948928833
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.6419941186904907
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.5666593909263611
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.5518009662628174
Channel

Channel: 0
Threat Score: 0.5508935451507568
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.645287036895752
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.8186184763908386
Channel: 1
Threat Score: 0.00013819789455737919
Channel: 0
Threat Score: 0.6461601257324219
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.6069929003715515
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.626806914806366
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.6215731501579285
Channel: 1
Threat Score: 0.0001156336729764007
Channel: 0
Threat Score: 0.7060264348983765
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.7342374324798584
Channel: 1
Threat Score: 0.0003824091691058129
Channel: 0
Threat Score: 0.6256887316703796
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.7534170746803284
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.6805355548858643
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.6125805377960205
Channel: 1
Threat Score: 0.

Channel: 0
Threat Score: 0.5932362675666809
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.6100805997848511
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.7071407437324524
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.7041047811508179
Channel: 1
Threat Score: 0.0004662004648707807
Channel: 0
Threat Score: 0.7179152965545654
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.49342507123947144
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.6133284568786621
Channel: 1
Threat Score: 0.0003713330952450633
Channel: 0
Threat Score: 0.6970299482345581
Channel: 1
Threat Score: 0.00014821402146480978
Channel: 0
Threat Score: 0.6492627263069153
Channel: 1
Threat Score: 0.00045464877621270716
Channel: 0
Threat Score: 0.6768777370452881
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.691694974899292
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.6494233012199402
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.6113026738166809
Channe

Channel: 0
Threat Score: 0.6888639330863953
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.6902775168418884
Channel: 1
Threat Score: 0.00033112583332695067
Channel: 0
Threat Score: 0.6381734013557434
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.6727343797683716
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.6340559720993042
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.6522277593612671
Channel: 1
Threat Score: 0.00011248594091739506
Channel: 0
Threat Score: 0.457999587059021
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.7253381013870239
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.7398738265037537
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.702374279499054
Channel: 1
Threat Score: 0.00016923337534535676
Channel: 0
Threat Score: 0.6633254289627075
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.5129944682121277
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.6651936173439026
Channel: 1
Threat Score: 

Channel: 0
Threat Score: 0.7079533338546753
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.6218166351318359
Channel: 1
Threat Score: 0.0003085467324126512
Channel: 0
Threat Score: 0.7344418168067932
Channel: 1
Threat Score: 0.0009883864549919963
Channel: 0
Threat Score: 0.760807454586029
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.6611596345901489
Channel: 1
Threat Score: 0.00017534631479065865
Channel: 0
Threat Score: 0.7849985957145691
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.6947615146636963
Channel: 1
Threat Score: 0.0003113324928563088
Channel: 0
Threat Score: 0.4873782992362976
Channel: 1
Threat Score: 0.00020644096366595477
Channel: 0
Threat Score: 0.6294384598731995
Channel: 1
Threat Score: 0.0
Channel: 0
Threat Score: 0.6499521136283875
Channel: 1
Threat Score: 0.00012495314877014607
Channel: 0
Threat Score: 0.6256510615348816
Channel: 1
Threat Score: 0.00044503781828098
Channel: 0
Threat Score: 0.7266672253608704
Channel: 1
Threat Score: 0.

NameError: name 'conf_matrix_road' is not defined

In [None]:
import torch
from torch import nn
from torch.autograd import Variable
import torch.nn.functional as F


class CDAutoEncoder(nn.Module):
    r"""
    Convolutional denoising autoencoder layer for stacked autoencoders.
    This module is automatically trained when in model.training is True.

    Args:
        input_size: The number of features in the inputa
        output_size: The number of features to output
        stride: Stride of the convolutional layers.
    """
    def __init__(self, input_size, output_size, stride):
        super(CDAutoEncoder, self).__init__()

        self.forward_pass = nn.Sequential(
            nn.Conv2d(input_size, output_size, kernel_size=2, stride=stride, padding=0),
            nn.ReLU(),
        )
        self.backward_pass = nn.Sequential(
            nn.ConvTranspose2d(output_size, input_size, kernel_size=2, stride=2, padding=0), 
            nn.ReLU(),
        )

        self.criterion = nn.MSELoss()
        self.optimizer = torch.optim.SGD(self.parameters(), lr=0.1)

    def forward(self, x):
        # Train each autoencoder individually
        x = x.detach()
        # Add noise, but use the original lossless input as the target.
        x_noisy = x * (Variable(x.data.new(x.size()).normal_(0, 0.1)) > -.1).type_as(x)
        y = self.forward_pass(x_noisy)

        if self.training:
            x_reconstruct = self.backward_pass(y)
            loss = self.criterion(x_reconstruct, Variable(x.data, requires_grad=False))
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            
        return y.detach()

    def reconstruct(self, x):
        return self.backward_pass(x)


class StackedAutoEncoder(nn.Module):
    """
    A stacked autoencoder made from the convolutional denoising autoencoders above.
    Each autoencoder is trained independently and at the same time.
    """

    def __init__(self):
        super(StackedAutoEncoder, self).__init__()

        self.ae1 = CDAutoEncoder(3, 128, 2)
        self.ae2 = CDAutoEncoder(128, 256, 2)
        self.ae3 = CDAutoEncoder(256, 512, 2)
        self.ae3 = CDAutoEncoder(512, 1024, 2)
        self.ae4 = CDAutoEncoder(1024, 2048, 2)

    def forward(self, x):
        a1 = self.ae1(x)
        a2 = self.ae2(a1)
        a3 = self.ae3(a2)
        a4 = self.ae4(a3)
        a5 = self.ae4(a4)

        if self.training:
            return a5

        else:
            return a5, self.reconstruct(a5)

    def reconstruct(self, x):
        a4_reconstruct = self.ae5.reconstruct(x)
        a3_reconstruct = self.ae4.reconstruct(x)
        a2_reconstruct = self.ae3.reconstruct(x)
        a1_reconstruct = self.ae2.reconstruct(a2_reconstruct)
        x_reconstruct = self.ae1.reconstruct(a1_reconstruct)
        return x_reconstruct

model = StackedAutoEncoder()
#model = StackedAutoEncoder().to(device)

In [None]:
num_epochs = 10

for epoch in range(num_epochs):
    if epoch % 10 == 0:
        # Test the quality of our features with a randomly initialzed linear classifier.
        classifier = nn.Sequential(nn.Linear(6*2056*4*4, 1000),
                                  nn.ReLU(),
                                  nn.Linear(1000, 800*800)).to(device)
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(classifier.parameters(), lr=0.001)

    model.train()
    total_time = time.time()
    correct = 0
    for i, data in enumerate(trainloader):
        sample, target, road_image, extra = data
        road_image = 1*torch.stack(road_image).to(device).reshape(-1,800*800)
        batch_size = road_image.size(0)
        sample = torch.stack(sample).reshape(6,-1,3,256,256).to(device)
        features = []
        for img in sample:
            features.append(model(img).detach().view(features.size(0), -1))
        features = torch.cat(features, axis=1)
        prediction = classifier(features)
        loss = criterion(prediction, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        pred = prediction.data.max(1, keepdim=True)[1]
        correct += pred.eq(target.data.view_as(pred)).cpu().sum()

    total_time = time.time() - total_time

    model.eval()
    img, _ = data
    img = Variable(img).cuda()
    features, x_reconstructed = model(img)
    reconstruction_loss = torch.mean((x_reconstructed.data - img.data)**2)

    if epoch % 10 == 0:
        print("Saving epoch {}".format(epoch))
        orig = to_img(img.cpu().data)
        save_image(orig, './imgs/orig_{}.png'.format(epoch))
        pic = to_img(x_reconstructed.cpu().data)
        save_image(pic, './imgs/reconstruction_{}.png'.format(epoch))

    print("Epoch {} complete\tTime: {:.4f}s\t\tLoss: {:.4f}".format(epoch, total_time, reconstruction_loss))
    print("Feature Statistics\tMean: {:.4f}\t\tMax: {:.4f}\t\tSparsity: {:.4f}%".format(
        torch.mean(features.data), torch.max(features.data), torch.sum(features.data == 0.0)*100 / features.data.numel())
    )
    print("Linear classifier performance: {}/{} = {:.2f}%".format(correct, len(dataloader)*batch_size, 100*float(correct) / (len(dataloader)*batch_size)))
    print("="*80)

torch.save(model.state_dict(), './CDAE.pth')

In [None]:
class Encoder(nn.Module):
    def __init__(self,n_feature, hidden):
        super(Encoder, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, 
                           out_channels=n_feature, 
                           kernel_size=(3,5), 
                           stride=2)
        self.maxp1 = nn.MaxPool2d(kernel_size = 2, stride=2)
        self.conv1_bn = nn.BatchNorm2d(n_feature)
        
        self.conv2 = nn.Conv2d(in_channels=n_feature*6, 
                               out_channels=64, 
                               kernel_size=(3,5), 
                               stride=1,
                               padding=(0,0))
        self.maxp2 = nn.MaxPool2d(kernel_size = 2, stride=2)
        self.conv2_bn = nn.BatchNorm2d(64)
        
        self.conv3 = nn.Conv2d(in_channels=64, 
                               out_channels=96, 
                               kernel_size=(3,3), 
                               stride=1,
                               padding=(0,0))
        self.maxp3 = nn.MaxPool2d(kernel_size = 2, stride=2)
        self.conv3_bn = nn.BatchNorm2d(96)
        
        self.lin1 = nn.Linear(96*14*16,hidden)
        self.lin1_bn = nn.BatchNorm1d(hidden)
        
    def forward(self, x, verbose=False):
        x = [y for y in x]
        x = [F.relu(self.conv1_bn(self.maxp1(self.conv1(y)))) for y in x]
        x = tuple(x)
        x = torch.cat(x,axis=1)
        x = F.relu(self.conv2_bn(self.maxp1(self.conv2(x))))
        x = F.relu(self.conv3_bn(self.maxp1(self.conv3(x))))
        x = F.relu(self.lin1_bn(self.lin1(x.reshape(-1,96*14*16))))
        return x

In [None]:
class Decoder(nn.Module):
    def __init__(self, hidden, hidden_img):
        super(Decoder, self).__init__()
        self.hidden_img = hidden_img

        self.lin1 = nn.Linear(hidden,6*3*hidden_img*int(1.2*(hidden_img)))
        self.lin1_bn = nn.BatchNorm1d(6*3*hidden_img*int(1.2*(hidden_img)))

        self.convT1 = nn.ConvTranspose2d(in_channels=3, 
                                 out_channels=3, 
                                 kernel_size=(3,3), 
                                 padding = (3,4),
                                 stride=2, 
                                 dilation=(1,1),
                                 output_padding=(0,0))
        
        self.convT1_bn = nn.BatchNorm2d(3)

        self.convT2 = nn.ConvTranspose2d(in_channels=3, 
                                 out_channels=3, 
                                 kernel_size=(3,3), 
                                 padding = (3,3),
                                 stride=2, 
                                 dilation=(1,1),
                                 output_padding=(0,0))
        
        self.convT2_bn = nn.BatchNorm2d(3)

        self.convT3 = nn.ConvTranspose2d(in_channels=3, 
                                 out_channels=3, 
                                 kernel_size=(3,3), 
                                 padding = (2,1),
                                 stride=2, 
                                 dilation=(1,1),
                                 output_padding=(1,1))
        
    def forward(self, x, verbose=False):
        x = F.relu(self.lin1_bn(self.lin1(x)).reshape(6,-1,3,hidden_img,int(1.2*hidden_img)))
        x = [y for y in x]
        x = [self.convT3(
            self.convT2_bn(self.convT2(
            self.convT1_bn(self.convT1(y))))) for y in x]
        x = tuple(x)
        x = torch.stack(x)
        return x

In [None]:
class Autoencoder(nn.Module):
    def __init__(self, n_features=32, hidden=1000, hidden_img=36):
        super().__init__()
        self.encoder = Encoder(n_features, hidden)
        self.decoder = Decoder(hidden, hidden_img)

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

In [None]:
def train(epoch, model, criterion, optimizer, batch_size,show_photo=False):
    model.train()
    for batch_idx, (sample, target, road_image, extra) in enumerate(trainloader):
        sample = torch.stack(sample).reshape(6,-1,3,256,306).to(device)
        img = sample 
        img = img.to(device)
        # img = img.view(img.size(0), -1)
        # noise = do(torch.ones(img.shape))
        # img_bad = (img * noise).to(device)  # comment out for standard AE
        # ===================forward=====================
        output = model(img)  # feed <img> (for std AE) or <img_bad> (for denoising AE)
        loss = criterion(output, img.data)
        # ===================backward====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # ===================log========================
    print(f'epoch [{epoch + 1}/{num_epochs}], loss:{loss.item():.4f}')
    plt_out = output.reshape(-1,6,3,256,306)
    plt.imshow(torchvision.utils.make_grid(plt_out[0].detach().cpu()
                                           , nrow=3).numpy().transpose(1, 2, 0))
    plt.show()

In [None]:
convT1 = nn.ConvTranspose2d(in_channels=128, 
                                 out_channels=96, 
                                 kernel_size=1, 
                                 padding = (2,2),
                                 stride=2, 
                                 dilation=(1,1),
                                 output_padding=(0,0))

## convT1(x).shape ## can add conv3(x) as a skip connection here

convT2 = nn.ConvTranspose2d(in_channels=96, 
                                 out_channels=96, 
                                 kernel_size=1, 
                                 padding = (4,4),
                                 stride = 3, 
                                 dilation=(1,1),
                                 output_padding=(0,0))

convT3 = nn.ConvTranspose2d(in_channels=96, 
                                 out_channels=96, 
                                 kernel_size=3, 
                                 padding = (2,2),
                                 stride = 2, 
                                 dilation=(1,1),
                                 output_padding=(1,1))

convT4 = nn.ConvTranspose2d(in_channels=96, 
                                 out_channels=64, 
                                 kernel_size=2, 
                                 padding = (0,0),
                                 stride = 2, 
                                 dilation=(1,1),
                                 output_padding=(0,0))

convT5 = nn.ConvTranspose2d(in_channels=64, 
                                 out_channels=2, 
                                 kernel_size=2, 
                                 padding = (0,0),
                                 stride = 2, 
                                 dilation=(1,1),
                                 output_padding=(0,0))

In [None]:
class SemSeg(nn.Module):
    def __init__(self,n_feature):
        super(SemSeg, self).__init__()
        self.conv1 = conv1 = nn.Conv2d(in_channels=3, 
                                       out_channels=n_feature, 
                                       kernel_size=(3,7), 
                                       stride=2)
        self.conv2 = nn.Conv2d(in_channels=n_feature*6, 
                               out_channels=2, 
                               kernel_size=(3,7), 
                               stride=2,
                               padding=(2,3))
        self.conv3 = nn.Conv2d(in_channels=2, 
                               out_channels=2, 
                               kernel_size=(1,5), 
                               stride=2,
                               padding=(2,0))
        self.conv4 = nn.Conv2d(in_channels=2, 
                               out_channels=2, 
                               kernel_size=(1,1), 
                               stride=2,
                               padding=(2,2))
        self.convT1 = nn.ConvTranspose2d(in_channels=2, 
                                         out_channels=2, 
                                         kernel_size=1, 
                                         stride=42, 
                                         dilation=1,
                                         output_padding=1)
        
    def forward(self, x, verbose=False):
        x = [y for y in x]
        x = [self.conv1(y) for y in x]
        x = tuple(x)
        x = torch.cat(x,axis=1)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.convT1(x)
        return x

In [None]:
class SemSegMulti(nn.Module):
    def __init__(self,n_feature):
        super(SemSegMulti, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, 
                                       out_channels=n_feature, 
                                       kernel_size=(3,7), 
                                       stride=2)
        self.conv2 = nn.Conv2d(in_channels=n_feature*6, 
                               out_channels=50, 
                               kernel_size=(3,7), 
                               stride=2,
                               padding=(2,3))
        self.conv3 = nn.Conv2d(in_channels=50, 
                               out_channels=25, 
                               kernel_size=(1,5), 
                               stride=2,
                               padding=(2,0))
        self.conv4 = nn.Conv2d(in_channels=25, 
                               out_channels=15, 
                               kernel_size=(1,1), 
                               stride=2,
                               padding=(2,2))
        self.convT1 = nn.ConvTranspose2d(in_channels=15, 
                                         out_channels=10, 
                                         kernel_size=1, 
                                         stride=42, 
                                         dilation=1,
                                         output_padding=1)
        self.convT2 = nn.ConvTranspose2d(in_channels=15, 
                                         out_channels=2, 
                                         kernel_size=1, 
                                         stride=42, 
                                         dilation=1,
                                         output_padding=1)
        
    def forward(self, x, verbose=False):
        x = [y for y in x]
        x = [self.conv1(y) for y in x]
        x = tuple(x)
        x = torch.cat(x,axis=1)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        return [self.convT1(x),
                self.convT2(x)]

In [None]:
class SemSegMulti2(nn.Module):
    def __init__(self,n_feature):
        super(SemSegMulti2, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, 
                                       out_channels=n_feature, 
                                       kernel_size=(3,7), 
                                       stride=2)
        self.conv1_bn = nn.BatchNorm2d(n_feature)
        self.conv2 = nn.Conv2d(in_channels=n_feature*6, 
                               out_channels=64, 
                               kernel_size=(3,7), 
                               stride=2,
                               padding=(2,3))
        self.conv2_bn = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(in_channels=64, 
                               out_channels=96, 
                               kernel_size=(1,3), 
                               stride=2,
                               padding=(4,0))
        self.conv3_bn = nn.BatchNorm2d(96)
        self.conv4 = nn.Conv2d(in_channels=96, 
                               out_channels=128, 
                               kernel_size=(7,7), 
                               stride=1,
                               padding=(1,1))
        self.conv4_bn = nn.BatchNorm2d(128)
        self.conv5 = nn.Conv2d(in_channels=128, 
                               out_channels=128, 
                               kernel_size=(7,7), 
                               stride=1,
                               padding=(0,0))
        self.conv5_bn = nn.BatchNorm2d(128)
        self.conv6 = nn.Conv2d(in_channels=128, 
                               out_channels=128, 
                               kernel_size=(7,7), 
                               stride=1,
                               padding=(0,0))
        self.conv6_bn = nn.BatchNorm2d(128)
        
        self.convT1 = nn.ConvTranspose2d(in_channels=128, 
                                         out_channels=96, 
                                         kernel_size=1, 
                                         padding = (2,2),
                                         stride=2, 
                                         dilation=(1,1),
                                         output_padding=(0,0))
        self.convT1_bn = nn.BatchNorm2d(96)
        ## convT1(x).shape ## can add conv3(x) as a skip connection here

        self.convT2 = nn.ConvTranspose2d(in_channels=96, 
                                         out_channels=96, 
                                         kernel_size=1, 
                                         padding = (4,4),
                                         stride = 3, 
                                         dilation=(1,1),
                                         output_padding=(0,0))
        self.convT2_bn = nn.BatchNorm2d(96)
        self.convT3 = nn.ConvTranspose2d(in_channels=96, 
                                         out_channels=96, 
                                         kernel_size=3, 
                                         padding = (2,2),
                                         stride = 2, 
                                         dilation=(1,1),
                                         output_padding=(1,1))
        self.convT3_bn = nn.BatchNorm2d(96)
        self.convT4 = nn.ConvTranspose2d(in_channels=96, 
                                         out_channels=64, 
                                         kernel_size=2, 
                                         padding = (0,0),
                                         stride = 2, 
                                         dilation=(1,1),
                                         output_padding=(0,0))
        self.convT4_bn = nn.BatchNorm2d(64)
        self.convT5 = nn.ConvTranspose2d(in_channels=64, 
                                         out_channels=2, 
                                         kernel_size=2, 
                                         padding = (0,0),
                                         stride = 2, 
                                         dilation=(1,1),
                                         output_padding=(0,0))
        
        self.convT6 = nn.ConvTranspose2d(in_channels=64, 
                                         out_channels=10, 
                                         kernel_size=2, 
                                         padding = (0,0),
                                         stride = 2, 
                                         dilation=(1,1),
                                         output_padding=(0,0))

        
    def forward(self, x, verbose=False):
        x = [y for y in x]
        x = [F.relu(self.conv1_bn(self.conv1(y))) for y in x]
        x = tuple(x)
        x = torch.cat(x,axis=1)
        x = F.relu(self.conv2_bn(self.conv2(x)))
        x = F.relu(self.conv3_bn(self.conv3(x)))
        x1 = F.relu(self.conv4_bn(self.conv4(x)))
        x1 = F.relu(self.conv5_bn(self.conv5(x1)))
        x1 = F.relu(self.conv6_bn(self.conv6(x1)))
        x1 = F.relu(self.convT1_bn(self.convT1(x1))) + x
        return [
            self.convT6(
            F.relu(self.convT4_bn(self.convT4(
            F.relu(self.convT3_bn(self.convT3(
            F.relu(self.convT2_bn(self.convT2(x1)))))
            ))))),
            self.convT5(
            F.relu(self.convT4_bn(self.convT4(
            F.relu(self.convT3_bn(self.convT3(
            F.relu(self.convT2_bn(self.convT2(x1)))))
            )))))
            ]

In [None]:
class SemSegMultiBB(nn.Module):
    def __init__(self,n_feature):
        super(SemSegMultiBB, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, 
                                       out_channels=n_feature, 
                                       kernel_size=(3,7), 
                                       stride=2)
        self.conv1_bn = nn.BatchNorm2d(n_feature)
        self.conv2 = nn.Conv2d(in_channels=n_feature*6, 
                               out_channels=64, 
                               kernel_size=(3,7), 
                               stride=2,
                               padding=(2,3))
        self.conv2_bn = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(in_channels=64, 
                               out_channels=96, 
                               kernel_size=(1,3), 
                               stride=2,
                               padding=(4,0))
        self.conv3_bn = nn.BatchNorm2d(96)
        self.conv4 = nn.Conv2d(in_channels=96, 
                               out_channels=128, 
                               kernel_size=(7,7), 
                               stride=1,
                               padding=(1,1))
        self.conv4_bn = nn.BatchNorm2d(128)
        self.conv5 = nn.Conv2d(in_channels=128, 
                               out_channels=128, 
                               kernel_size=(7,7), 
                               stride=1,
                               padding=(0,0))
        self.conv5_bn = nn.BatchNorm2d(128)
        self.conv6 = nn.Conv2d(in_channels=128, 
                               out_channels=128, 
                               kernel_size=(7,7), 
                               stride=1,
                               padding=(0,0))
        self.conv6_bn = nn.BatchNorm2d(128)
        
        self.convT1 = nn.ConvTranspose2d(in_channels=128, 
                                         out_channels=96, 
                                         kernel_size=1, 
                                         padding = (2,2),
                                         stride=2, 
                                         dilation=(1,1),
                                         output_padding=(0,0))
        self.convT1_bn = nn.BatchNorm2d(96)
        ## convT1(x).shape ## can add conv3(x) as a skip connection here

        self.convT2 = nn.ConvTranspose2d(in_channels=96, 
                                         out_channels=96, 
                                         kernel_size=1, 
                                         padding = (4,4),
                                         stride = 3, 
                                         dilation=(1,1),
                                         output_padding=(0,0))
        self.convT2_bn = nn.BatchNorm2d(96)
        self.convT3 = nn.ConvTranspose2d(in_channels=96, 
                                         out_channels=96, 
                                         kernel_size=3, 
                                         padding = (2,2),
                                         stride = 2, 
                                         dilation=(1,1),
                                         output_padding=(1,1))
        self.convT3_bn = nn.BatchNorm2d(96)
        self.convT4 = nn.ConvTranspose2d(in_channels=96, 
                                         out_channels=64, 
                                         kernel_size=2, 
                                         padding = (0,0),
                                         stride = 2, 
                                         dilation=(1,1),
                                         output_padding=(0,0))
        self.convT4_bn = nn.BatchNorm2d(64)
        self.convT5 = nn.ConvTranspose2d(in_channels=64, 
                                         out_channels=2, 
                                         kernel_size=2, 
                                         padding = (0,0),
                                         stride = 2, 
                                         dilation=(1,1),
                                         output_padding=(0,0))
        
        self.convT6 = nn.ConvTranspose2d(in_channels=64, 
                                         out_channels=2, 
                                         kernel_size=2, 
                                         padding = (0,0),
                                         stride = 2, 
                                         dilation=(1,1),
                                         output_padding=(0,0))

        
    def forward(self, x, verbose=False):
        x = [y for y in x]
        x = [F.relu(self.conv1_bn(self.conv1(y))) for y in x]
        x = tuple(x)
        x = torch.cat(x,axis=1)
        x = F.relu(self.conv2_bn(self.conv2(x)))
        x = F.relu(self.conv3_bn(self.conv3(x)))
        x1 = F.relu(self.conv4_bn(self.conv4(x)))
        x1 = F.relu(self.conv5_bn(self.conv5(x1)))
        x1 = F.relu(self.conv6_bn(self.conv6(x1)))
        x1 = F.relu(self.convT1_bn(self.convT1(x1))) + x
        return [
            self.convT6(
            F.relu(self.convT4_bn(self.convT4(
            F.relu(self.convT3_bn(self.convT3(
            F.relu(self.convT2_bn(self.convT2(x1)))))
            ))))),
            self.convT5(
            F.relu(self.convT4_bn(self.convT4(
            F.relu(self.convT3_bn(self.convT3(
            F.relu(self.convT2_bn(self.convT2(x1)))))
            )))))
            ]

In [None]:
class SemSegMultiBB2(nn.Module):
    def __init__(self,n_feature):
        super(SemSegMultiBB2, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, 
                                       out_channels=n_feature, 
                                       kernel_size=(3,7), 
                                       stride=2)
        self.conv1_bn = nn.BatchNorm2d(n_feature)
        self.conv2 = nn.Conv2d(in_channels=n_feature*6, 
                               out_channels=64, 
                               kernel_size=(3,7), 
                               stride=2,
                               padding=(2,3))
        self.conv2_bn = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(in_channels=64, 
                               out_channels=96, 
                               kernel_size=(1,3), 
                               stride=2,
                               padding=(4,0))
        self.conv3_bn = nn.BatchNorm2d(96)
        self.conv4 = nn.Conv2d(in_channels=96, 
                               out_channels=128, 
                               kernel_size=(7,7), 
                               stride=1,
                               padding=(1,1))
        self.conv4_bn = nn.BatchNorm2d(128)
        self.conv5 = nn.Conv2d(in_channels=128, 
                               out_channels=128, 
                               kernel_size=(7,7), 
                               stride=1,
                               padding=(0,0))
        self.conv5_bn = nn.BatchNorm2d(128)
        self.conv6 = nn.Conv2d(in_channels=128, 
                               out_channels=128, 
                               kernel_size=(7,7), 
                               stride=1,
                               padding=(0,0))
        self.conv6_bn = nn.BatchNorm2d(128)
        
        self.convT1 = nn.ConvTranspose2d(in_channels=128, 
                                         out_channels=96, 
                                         kernel_size=1, 
                                         padding = (2,2),
                                         stride=2, 
                                         dilation=(1,1),
                                         output_padding=(0,0))
        self.convT1_bn = nn.BatchNorm2d(96)
        ## convT1(x).shape ## can add conv3(x) as a skip connection here

        self.convT2 = nn.ConvTranspose2d(in_channels=96, 
                                         out_channels=96, 
                                         kernel_size=1, 
                                         padding = (4,4),
                                         stride = 3, 
                                         dilation=(1,1),
                                         output_padding=(0,0))
        self.convT2_bn = nn.BatchNorm2d(96)
        self.convT3 = nn.ConvTranspose2d(in_channels=96, 
                                         out_channels=96, 
                                         kernel_size=3, 
                                         padding = (2,2),
                                         stride = 2, 
                                         dilation=(1,1),
                                         output_padding=(1,1))
        self.convT3_bn = nn.BatchNorm2d(96)
        self.convT4 = nn.ConvTranspose2d(in_channels=96, 
                                         out_channels=64, 
                                         kernel_size=2, 
                                         padding = (0,0),
                                         stride = 2, 
                                         dilation=(1,1),
                                         output_padding=(0,0))
        self.convT4_bn = nn.BatchNorm2d(64)
        self.convT5 = nn.ConvTranspose2d(in_channels=64, 
                                         out_channels=2, 
                                         kernel_size=2, 
                                         padding = (0,0),
                                         stride = 2, 
                                         dilation=(1,1),
                                         output_padding=(0,0))
        
        self.convT6 = nn.ConvTranspose2d(in_channels=64, 
                                         out_channels=2, 
                                         kernel_size=2, 
                                         padding = (0,0),
                                         stride = 2, 
                                         dilation=(1,1),
                                         output_padding=(0,0))

        
    def forward(self, x, verbose=False):
        x = [y for y in x]
        x = [F.relu(self.conv1_bn(self.conv1(y))) for y in x]
        x = tuple(x)
        x = torch.cat(x,axis=1)
        x = F.relu(self.conv2_bn(self.conv2(x)))
        x = F.relu(self.conv3_bn(self.conv3(x)))
        x1 = F.relu(self.conv4_bn(self.conv4(x)))
        x1 = F.relu(self.conv5_bn(self.conv5(x1)))
        x1 = F.relu(self.conv6_bn(self.conv6(x1)))
        x1 = F.relu(self.convT1_bn(self.convT1(x1))) + x
        return [
            self.convT6(
            F.relu(self.convT4_bn(self.convT4(
            F.relu(self.convT3_bn(self.convT3(
            F.relu(self.convT2_bn(self.convT2(x1)))))
            ))))),
            self.convT5(
            F.relu(self.convT4_bn(self.convT4(
            F.relu(self.convT3_bn(self.convT3(
            F.relu(self.convT2_bn(self.convT2(x1)))))
            )))))
            ]

In [None]:
class SemSegVAE(nn.Module):
    def __init__(self,n_feature):
        super(SemSegVAE, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, 
                                       out_channels=n_feature, 
                                       kernel_size=(3,7), 
                                       stride=2)
        self.conv2 = nn.Conv2d(in_channels=n_feature*6, 
                               out_channels=64, 
                               kernel_size=(3,7), 
                               stride=2,
                               padding=(2,3))
        self.conv3 = nn.Conv2d(in_channels=64, 
                               out_channels=96, 
                               kernel_size=(1,3), 
                               stride=2,
                               padding=(4,0))
        self.conv4 = nn.Conv2d(in_channels=96, 
                               out_channels=128, 
                               kernel_size=(7,7), 
                               stride=1,
                               padding=(1,1))
        self.conv5 = nn.Conv2d(in_channels=128, 
                               out_channels=128, 
                               kernel_size=(7,7), 
                               stride=1,
                               padding=(0,0))
        self.conv6 = nn.Conv2d(in_channels=128, 
                               out_channels=128, 
                               kernel_size=(7,7), 
                               stride=1,
                               padding=(0,0))
        
        self.encoder = nn.Sequential(
            nn.Linear(784, d ** 2),
            nn.ReLU(),
            nn.Linear(d ** 2, d * 2)
        )
        
        self.convT1 = nn.ConvTranspose2d(in_channels=128, 
                                         out_channels=96, 
                                         kernel_size=1, 
                                         padding = (2,2),
                                         stride=2, 
                                         dilation=(1,1),
                                         output_padding=(0,0))

        ## convT1(x).shape ## can add conv3(x) as a skip connection here

        self.convT2 = nn.ConvTranspose2d(in_channels=96, 
                                         out_channels=96, 
                                         kernel_size=1, 
                                         padding = (4,4),
                                         stride = 3, 
                                         dilation=(1,1),
                                         output_padding=(0,0))

        self.convT3 = nn.ConvTranspose2d(in_channels=96, 
                                         out_channels=96, 
                                         kernel_size=3, 
                                         padding = (2,2),
                                         stride = 2, 
                                         dilation=(1,1),
                                         output_padding=(1,1))

        self.convT4 = nn.ConvTranspose2d(in_channels=96, 
                                         out_channels=64, 
                                         kernel_size=2, 
                                         padding = (0,0),
                                         stride = 2, 
                                         dilation=(1,1),
                                         output_padding=(0,0))

        self.convT5 = nn.ConvTranspose2d(in_channels=64, 
                                         out_channels=2, 
                                         kernel_size=2, 
                                         padding = (0,0),
                                         stride = 2, 
                                         dilation=(1,1),
                                         output_padding=(0,0))
        
        self.convT6 = nn.ConvTranspose2d(in_channels=64, 
                                         out_channels=10, 
                                         kernel_size=2, 
                                         padding = (0,0),
                                         stride = 2, 
                                         dilation=(1,1),
                                         output_padding=(0,0))

        
    def forward(self, x, verbose=False):
        x = [y for y in x]
        x = [self.conv1(y) for y in x]
        x = tuple(x)
        x = torch.cat(x,axis=1)
        x = self.conv2(x)
        x = self.conv3(x)
        x1 = self.conv4(x)
        x1 = self.conv5(x1)
        x1 = self.conv6(x1)
        x1 = self.convT1(x1) + x
        return [self.convT6(self.convT4(self.convT3(self.convT2(x1)))),
                self.convT5(self.convT4(self.convT3(self.convT2(x1))))]

In [None]:
def train(epoch, model, criterion, optimizer, batch_size):
    model.train()
    for batch_idx, (sample, target, road_image, extra) in enumerate(trainloader):
        # send to device
        sample = torch.stack(sample).reshape(6,-1,3,256,306).to(device)
        road_image = 1*torch.stack(road_image).to(device)
        output = model(sample)        
        optimizer.zero_grad()
        loss = criterion(output, road_image)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, 
                batch_idx * len(sample), len(trainloader.dataset),
                #batch_idx * len(sample), len(trainloader),
                100. * batch_idx * len(sample) / len(trainloader.dataset), 
                loss.item()))

In [None]:
accuracy_list = []
def test(model, criterion):
    model.eval()
    test_loss = 0
    correct = 0
    for batch_idx, (sample, target, road_image, extra) in enumerate(testloader):
        # send to device
        sample = torch.stack(sample).reshape(6,-1,3,256,306).to(device)
        road_image = 1*torch.stack(road_image).to(device)
        
        output = model(sample) 
        test_loss += criterion(output, road_image).item() # sum up batch loss                                                               
        pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability                                                                 
        correct += pred.eq(road_image.data.view_as(pred)).cpu().sum().item()

    test_loss /= len(testloader.dataset)
    accuracy = 100. * correct / (16*800*800*len(testloader))
    accuracy_list.append(accuracy)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, (batch_size*800*800*len(testloader)),
        accuracy))

In [None]:
def trainMulti(epoch, model, criterion1, criterion2, optimizer, batch_size):
    model.train()
    for batch_idx, (sample, target, road_image, extra) in enumerate(trainloader):
        # send to device
        sample = torch.stack(sample).reshape(6,-1,3,256,306).to(device)
        road_image = 1*torch.stack(road_image).to(device)
        batch_size = sample.size(1)
        y_target = torch.zeros((batch_size,10,800,800))
        for i in range(batch_size):
            for cat, bb in zip(target[i]['category'], target[i]['bounding_box']):
                y_target[i,cat+1,:,:] += 1*convert_to_binary_mask(bb)
        y_target = y_target.to(device)
        values, indices = torch.max(y_target,1)
        y_target = torch.max(y_target,1, keepdim=True)
        y_targ = torch.zeros_like(values, dtype=torch.long)
        y_targ[values > 0] = indices[values > 0]
        y_target = y_targ
        output = model(sample) 
        optimizer.zero_grad()
        #print(y_target.unique())
        loss1 = criterion1(output[0], y_target)
        #print(road_image.unique())
        loss2 = criterion2(output[1], road_image)
        loss = loss1 + loss2
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, 
                batch_idx * batch_size, len(trainloader.dataset),
                #batch_idx * len(sample), len(trainloader),
                100. * batch_idx * batch_size / len(trainloader.dataset), 
                loss.item()))

accuracy_list = []
def testMulti(model, criterion1, criterion2):
    model.eval()
    test_loss = 0
    road_correct = 0
    other_correct = 0
    total_road = 0
    total_other = 0
    conf_matrix_road = torch.zeros(2, 2)
    conf_matrix_other = torch.zeros(10, 10)
    for batch_idx, (sample, target, road_image, extra) in enumerate(testloader):
        # send to device
        sample = torch.stack(sample).reshape(6,-1,3,256,306).to(device)
        road_image = 1*torch.stack(road_image).to(device)
        batch_size = sample.size(1)
        y_target = torch.zeros((batch_size,10,800,800))
        for i in range(batch_size):
            for cat, bb in zip(target[i]['category'], target[i]['bounding_box']):
                y_target[i,cat+1,:,:] += 1*convert_to_binary_mask(bb)
        y_target = y_target.to(device)
        values, indices = torch.max(y_target,1)
        y_target = torch.max(y_target,1, keepdim=True)
        y_targ = torch.zeros_like(values, dtype=torch.long)
        y_targ[values > 0] = indices[values > 0]
        y_target = y_targ
        
        output = model(sample)
        
        loss1 = criterion1(output[0], y_target)
        #print(road_image.unique())
        loss2 = criterion2(output[1], road_image)
        loss = loss1 + loss2
        
        test_loss += loss.item() # sum up batch loss                                                               
        pred_road = output[1].data.max(1, keepdim=True)[1] # get the index of the max log-probability  
        pred_other = output[0].data.max(1, keepdim=True)[1] # get the index of the max log-probability  
        road_correct += pred_road.eq(road_image.data.view_as(pred_road)).cpu().sum().item()
        other_correct += pred_other.eq(y_target.data.view_as(pred_other)).cpu().sum().item()
        total_road += road_image.nelement()
        total_other += y_target.nelement()
                
        conf_matrix_road = create_conf_matrix2(road_image, pred_road)
        conf_matrix_other = create_conf_matrix2(y_target, pred_other)
                
    test_loss /= len(testloader.dataset)
    road_accuracy = 100. * road_correct / total_road
    other_accuracy = 100. * other_correct / total_other
    accuracy_list.append((road_accuracy + other_accuracy)/2)
    print("""\nTest set: Average loss: {:.4f}, 
    Accuracy Road: {}/{} ({:.0f}%) , 
    Accuracy Other: {}/{}, ({:.0f}%),
    Road: TP {} , 
    TN {}
    FP {}
    FN {},
    Other: TP {} 
    TN {} 
    FP {}
    FN {}
    \n""".format(
        test_loss, road_correct, total_road, road_accuracy,
        other_correct, total_other, other_accuracy, 
        *classScores(conf_matrix_road),
        *classScores(conf_matrix_other)))
    
        

In [None]:
weights1 = [0.1, 1.0]
class_weights1 = torch.FloatTensor(weights1).cuda()

criterion1 = nn.CrossEntropyLoss(reduction='mean',weight=class_weights1)

weights = [0.1, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
class_weights = torch.FloatTensor(weights).cuda()
criterion2 = nn.CrossEntropyLoss(reduction='mean')
model = SemSegMulti2(32)
model.to(device)
learning_rate = 1e-2
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=learning_rate,
)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)
print('Number of parameters: {}'.format(get_n_params(model)))

for epoch in range(0, 50):
    trainMulti(epoch, model, criterion1, criterion2, optimizer, batch_size)
    testMulti(model, criterion1, criterion2)
    scheduler.step(accuracy_list[epoch])

In [None]:
def trainBB(epoch, model, criterion1, criterion2, optimizer, batch_size):
    model.train()
    for batch_idx, (sample, target, road_image, extra) in enumerate(trainloader):
        # send to device
        sample = torch.stack(sample).reshape(6,-1,3,256,306).to(device)
        road_image = 1*torch.stack(road_image).to(device)
        batch_size = sample.size(1)
        y_target = torch.zeros((batch_size,1,800,800))
        for i in range(batch_size):
            for cat, bb in zip(target[i]['category'], target[i]['bounding_box']):
                y_target[i,0,:,:] += 1*convert_to_binary_mask(bb)
        y_target = 1*(y_target>0)
        y_target = y_target.to(device).squeeze()
        output = model(sample) 
        optimizer.zero_grad()
        #print(y_target.unique())
        loss1 = criterion1(output[0], y_target)
        #print(road_image.unique())
        loss2 = criterion2(output[1], road_image)
        loss = loss1 + loss2
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, 
                batch_idx * batch_size, len(trainloader.dataset),
                #batch_idx * len(sample), len(trainloader),
                100. * batch_idx * batch_size / len(trainloader.dataset), 
                loss.item()))

accuracy_list = []
def testBB(model, criterion1, criterion2):
    model.eval()
    test_loss = 0
    road_correct = 0
    other_correct = 0
    total_road = 0
    total_other = 0
    conf_matrix_road = torch.zeros(2, 2)
    conf_matrix_other = torch.zeros(10, 10)
    for batch_idx, (sample, target, road_image, extra) in enumerate(testloader):
        # send to device
        sample = torch.stack(sample).reshape(6,-1,3,256,306).to(device)
        road_image = 1*torch.stack(road_image).to(device)
        batch_size = sample.size(1)
        y_target = torch.zeros((batch_size,1,800,800))
        for i in range(batch_size):
            for cat, bb in zip(target[i]['category'], target[i]['bounding_box']):
                y_target[i,0,:,:] += 1*convert_to_binary_mask(bb)
        y_target = 1*(y_target>0)
        y_target = y_target.to(device).squeeze()
        
        output = model(sample)
        
        loss1 = criterion1(output[0], y_target)
        #print(road_image.unique())
        loss2 = criterion2(output[1], road_image)
        loss = loss1 + loss2
        
        test_loss += loss.item() # sum up batch loss                                                               
        pred_road = output[1].data.max(1, keepdim=True)[1] # get the index of the max log-probability  
        pred_other = output[0].data.max(1, keepdim=True)[1] # get the index of the max log-probability  
        road_correct += pred_road.eq(road_image.data.view_as(pred_road)).cpu().sum().item()
        other_correct += pred_other.eq(y_target.data.view_as(pred_other)).cpu().sum().item()
        total_road += road_image.nelement()
        total_other += y_target.nelement()
                
        conf_matrix_road = create_conf_matrix2(road_image, pred_road)
        conf_matrix_other = create_conf_matrix2(y_target, pred_other)
                
    test_loss /= len(testloader.dataset)
    road_accuracy = 100. * road_correct / total_road
    other_accuracy = 100. * other_correct / total_other
    accuracy_list.append((road_accuracy + other_accuracy)/2)
    print("""\nTest set: Average loss: {:.4f}, 
    Accuracy Road: {}/{} ({:.0f}%) , 
    Accuracy Other: {}/{}, ({:.0f}%),
    Road: TP {} , 
    TN {}
    FP {}
    FN {},
    Other: TP {} 
    TN {} 
    FP {}
    FN {}
    \n""".format(
        test_loss, road_correct, total_road, road_accuracy,
        other_correct, total_other, other_accuracy, 
        *classScores(conf_matrix_road),
        *classScores(conf_matrix_other)))
    

weights1 = [0.6, 1.0]
class_weights1 = torch.FloatTensor(weights1).cuda()

criterion1 = nn.CrossEntropyLoss(reduction='mean',weight=class_weights1)

weights2 = [0.02, 1.0]
class_weights2 = torch.FloatTensor(weights2).cuda()
criterion2 = nn.CrossEntropyLoss(reduction='mean',weight=class_weights2)
model = SemSegMultiBB(32)
model.to(device)
learning_rate = 1e-2
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=learning_rate,
)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)
print('Number of parameters: {}'.format(get_n_params(model)))

for epoch in range(0, 50):
    trainBB(epoch, model, criterion1, criterion2, optimizer, batch_size)
    testBB(model, criterion1, criterion2)
    scheduler.step(accuracy_list[epoch])

In [None]:
sample, target, road_image, extra = iter(testloader).next()

In [None]:
y_target = torch.zeros((batch_size,1,800,800))
for i in range(batch_size):
    for cat, bb in zip(target[i]['category'], target[i]['bounding_box']):
        y_target[i,0,:,:] += 1*convert_to_binary_mask(bb)

In [None]:
(y_target>0).sum()

In [None]:
y_target = y_target>0

In [None]:
for batch_idx, (sample, target, road_image, extra) in enumerate(testloader):
    print(batch_idx, target[0]['category'])
    batch_size = 1
    y_target = torch.zeros((batch_size,9,800,800))
    for i in range(batch_size):
        for cat, bb in zip(target[i]['category'], target[i]['bounding_box']):
            y_target[i,cat+1,:,:] += 1*convert_to_binary_mask(bb)
            print(i, cat, y_target[i,cat+1,:,:].sum())
    y_target = y_target.to(device)
    values, indices = torch.max(y_target,1)
    y_target = torch.max(y_target,1, keepdim=True)
    y_targ = torch.zeros_like(values, dtype=torch.long)
    y_targ[values > 0] = indices[values > 0]
    print(y_targ)
    print(y_targ.sum())
    break

In [None]:
# The shape of bounding box is [batch_size, N (the number of object), 2, 4]
print(target[0]['bounding_box'].shape)

In [None]:
# All bounding box are retangles
# Each bounding box is organized with four corners of the box
# All the values are in meter and bounded by 40 meters, and the origin is the center of ego car
# the order of the four courners are front left, front right, back left and back right
print(target[0]['bounding_box'][0])

In [None]:
corners = target[0]['bounding_box'][11]

In [None]:
corners

In [None]:
point_squence = torch.stack([corners[:, 0], corners[:, 1], corners[:, 3], corners[:, 2], corners[:, 0]])
x,y = point_squence.T[0].detach() * 10 + 400, -point_squence.T[1].detach() * 10 + 400

In [None]:
plt.figure()

for i, bb in enumerate(target[0]['bounding_box']):
    ax= plt.subplot(6,5 ,i+1)
    new_im = convert_to_binary_mask(bb)
    ax.imshow(new_im, cmap='binary')

In [None]:
list(zip(x,y))

In [None]:
# Each bounding box has a category
# 'other_vehicle': 0,
# 'bicycle': 1,
# 'car': 2,
# 'pedestrian': 3,
# 'truck': 4,
# 'bus': 5,
# 'motorcycle': 6,
# 'emergency_vehicle': 7,
# 'animal': 8
print(target[0]['category'])

In [None]:
sem_seg_target = np.zeros((10,800,800))

for cat, bb in zip(target[0]['category'], target[0]['bounding_box']):
    sem_seg_target[cat,:,:] += convert_to_binary_mask(bb)

In [None]:
plt.figure()
for i in range(9):
    ax=plt.subplot(3,3,i+1)
    ax.imshow(sem_seg_target[i,:,:]==1, cmap='binary')

## Road Map Layout


In [None]:
road_image[0]

In [None]:
# The road map layout is encoded into a binary array of size [800, 800] per sample 
# Each pixel is 0.1 meter in physiscal space, so 800 * 800 is 80m * 80m centered at the ego car
# The ego car is located in the center of the map (400, 400) and it is always facing the left

fig, ax = plt.subplots()

ax.imshow(road_image[0], cmap='binary');

In [None]:
print(road_image[0].shape)

In [None]:
print(road_image[0])

## Extra Info

There is some extra information you can use in your model, but it is optional.

In [None]:
# Action
# Action is the label that what the object is doing

# 'object_action_parked': 0,
# 'object_action_driving_straight_forward': 1,
# 'object_action_walking': 2,
# 'object_action_running': 3,
# 'object_action_lane_change_right': 4,
# 'object_action_stopped': 5,
# 'object_action_left_turn': 6,
# 'object_action_right_turn': 7,
# 'object_action_sitting': 8,
# 'object_action_standing': 9,
# 'object_action_gliding_on_wheels': 10,
# 'object_action_abnormal_or_traffic_violation': 11,
# 'object_action_lane_change_left': 12,
# 'object_action_other_motion': 13,
# 'object_action_reversing': 14,
# 'object_action_u_turn': 15,
# 'object_action_loss_of_control': 16

In [None]:
print(extra[0]['action'])

In [None]:
# Ego Image
# A more detailed ego image
fig, ax = plt.subplots()

ax.imshow(extra[0]['ego_image'].numpy().transpose(1, 2, 0));

In [None]:
# Lane Image
# Binary lane image
fig, ax = plt.subplots()

ax.imshow(extra[0]['lane_image'], cmap='binary');

# Visualize the bounding box

In [None]:
target[0]

In [None]:
target[0]['bounding_box']

In [None]:
# The center of image is 400 * 400

fig, ax = plt.subplots()

color_list = ['b', 'g', 'orange', 'c', 'm', 'y', 'k', 'w', 'r']

ax.imshow(road_image[0], cmap ='binary');

# The ego car position
ax.plot(400, 400, 'x', color="red")

for i, bb in enumerate(target[0]['bounding_box']):
    # You can check the implementation of the draw box to understand how it works 
    draw_box(ax, bb, color=color_list[target[0]['category'][i]])    

## Object Detection

In [None]:
class ObjDet(nn.Module):
    def __init__(self,n_feature, n_categories, n_boxes):
        super(ObjDet, self).__init__()
        self.n_categories = n_categories
        self.n_boxes = n_boxes
        self.conv1 = conv1 = nn.Conv2d(in_channels=3, 
                                       out_channels=n_feature, 
                                       kernel_size=(3,7), 
                                       stride=2)
        self.conv1_bn = nn.BatchNorm2d(n_feature)
        self.conv2 = nn.Conv2d(in_channels=n_feature*6, 
                               out_channels=64, 
                               kernel_size=(3,7), 
                               stride=2,
                               padding=(2,3))
        self.conv2_bn = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(in_channels=64, 
                               out_channels=96, 
                               kernel_size=(1,5), 
                               stride=2,
                               padding=(2,0))
        self.conv3_bn = nn.BatchNorm2d(96)
        self.conv4 = nn.Conv2d(in_channels=96, 
                               out_channels=96, 
                               kernel_size=(1,1), 
                               stride=2,
                               padding=(2,2))
        self.conv4_bn = nn.BatchNorm2d(96)
        self.lin2 = nn.Linear(96*20*20, self.n_boxes*8)
        
    def forward(self, x, verbose=False):
        x = [y for y in x]
        x = [F.relu(self.conv1_bn(self.conv1(y))) for y in x]
        x = tuple(x)
        x = torch.cat(x,axis=1)
        x = F.relu(self.conv2_bn(self.conv2(x)))
        x = F.relu(self.conv3_bn(self.conv3(x)))
        x = F.relu(self.conv4_bn(self.conv4(x)))
        return self.lin2(x.reshape(-1,96*20*20)).reshape(-1,2,self.n_boxes*4)

In [None]:
n_cat = 10
n_boxes = 50
n_feature = 20

obj = ObjDet(n_feature, n_cat, n_boxes)

cat_prob, xy = obj(torch.stack(sample)[0].reshape(6,-1,3,256,306))

cat_prob = cat_prob.reshape(-1,n_cat)
cat_prob.shape

bb = xy.reshape(-1,2,4)
bb.shape

target_c = target[0]['category']
print(target_c.shape)
target_bb = target[0]['bounding_box'].reshape(-1,2,4)
print(target_bb.shape)

bb[0,:,:] = target_bb[3,:,:]
bb[1,:,:] = target_bb[1,:,:]

sample, target, road_image, extra = iter(trainloader).next()

sample = torch.stack(sample)

sample = sample.reshape(6,-1,3,256,306)

batch_size = sample.size(1)

pred = obj(sample)

pred = tuple(zip(pred[0].reshape(batch_size,-1,10), pred[1].reshape(batch_size,-1,2,4)))

targ = tuple([(targ_i['category'],targ_i['bounding_box']) for targ_i in target])

loss = ssd_loss(pred,targ)
loss.backward()
optimizer.step()

In [None]:
def trainObjDet(epoch, model, optimizer):
    model.train()
    for batch_idx, (sample, target, road_image, extra) in enumerate(trainloader):
        sample = torch.stack(sample).reshape(6,-1,3,256,306).to(device)
        batch_size = sample.size(1)
        pred = model(sample)
        pred = tuple([bb for bb in pred.reshape(batch_size,-1,2,4)])
        targ = tuple([targ_i['bounding_box'].to(device) for targ_i in target])
        loss = ssd_loss(pred,targ)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, 
                batch_idx * len(sample), len(trainloader.dataset),
                #batch_idx * len(sample), len(trainloader),
                100. * batch_idx * len(sample) / len(trainloader.dataset), 
                loss.item()))

In [None]:
n_cat = 10
n_boxes = 50
n_feature = 20


def ssd_1_loss(b_bb,
               bbox,
               print_it=False):
    overlaps = calculate_overlap(bbox.data, b_bb.data)
    gt_overlap,gt_idx = map_to_ground_truth(overlaps,print_it)
    #gt_clas = clas[gt_idx]
    pos = gt_overlap > 0.4
    pos_idx = torch.nonzero(pos)[:,0]
    #gt_clas[~pos] = n_cat
    gt_bbox = bbox[gt_idx]
    loc_loss = ((b_bb[pos_idx] - gt_bbox[pos_idx]).abs()).mean()
    #clas_loss  = loss_f(b_c, gt_clas)
    #return loc_loss, clas_loss
    return loc_loss
    
def ssd_loss(pred,targ,print_it=False):
    lls = 0.
    for b, t in zip(pred,targ):
        b_bb = b
        bbox = t
        #print(b_c.shape,b_bb.shape)
        #print(clas.shape,bbox.shape)
        loc_loss = ssd_1_loss(b_bb,
                              bbox,
                              print_it)
        lls += loc_loss
    if print_it: print(f'loc: {lls.data[0]}')
    return lls

In [None]:
model = ObjDet(n_feature, n_cat, n_boxes)
model.to(device)
learning_rate = 1e-5
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=learning_rate,
)
#scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)
print('Number of parameters: {}'.format(get_n_params(model)))

for epoch in range(0, 5):
    trainObjDet(epoch, model, optimizer)
    #testObjDet(model, criterion, batch_size)
    #scheduler.step(accuracy_list[epoch])

# Evaluation
During the whole competition, you have three submission deadlines. The dates will be announced on Piazza. You will have to fill up the template 'data_loader.py' for evaluation. (see the comment inside data_loader.py' for more information)

There will be two leaderboards for the competition:
The leaderboard for binary road map.
We will evaluate your model's performance by using the average threat score (TS) across the test set:
$$\text{TS} = \frac{\text{TP}}{\text{TP} + \text{FP} + \text{FN}}$$
The leaderboard for object detection:
We will evaluate your model's performance for object detection by using the average mean threat score at different intersection over union (IoU) thresholds.
There will be five different thresholds (0.5, 0.6, 0.7, 0.8, 0.9). For each thresholds, we will calculate the threat score. The final score will be a weighted average of all the threat scores:
$$\text{Final Score} = \sum_t \frac{1}{t} \cdot \frac{\text{TP}(t)}{\text{TP}(t) + \text{FP}(t) + \text{FN}(t)}$$
