In [None]:
import os
import random

import numpy as np
import pandas as pd
from tqdm import tqdm

import matplotlib
import matplotlib.pyplot as plt
matplotlib.rcParams['figure.figsize'] = [12,12]
matplotlib.rcParams['figure.dpi'] = 200

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

from data_helper import UnlabeledDataset, LabeledDataset
from helper import collate_fn, draw_box

In [None]:
# Seed every random number generator the notebook relies on (Python's
# `random`, NumPy's legacy global RNG and PyTorch) with the same value so a
# fresh "Restart & Run All" reproduces the same split and initial weights.
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(0)

In [None]:
# Dataset locations, relative to the directory the notebook is run from.
# image_folder:   root directory handed to LabeledDataset as `image_folder`.
# annotation_csv: CSV file handed to LabeledDataset as `annotation_file`.
image_folder = '../data'
annotation_csv = '../data/annotation.csv'

In [None]:
# You shouldn't change the unlabeled_scene_index.
# The first 106 scenes (indices 0-105) are unlabeled.
unlabeled_scene_index = np.arange(106)
# Scenes 106-133 (28 scenes) are labeled.
# Drawing all 28 without replacement yields a random permutation of them; it
# is divided into training and validation subsets further down.
labeled_scene_index = np.random.choice(np.arange(106, 134), size=28,replace=False)

In [None]:
# Hold out 8 of the 28 labeled scenes: draw 20 training scenes at random and
# keep every remaining scene (in its original order) for validation.
train_inds = np.random.choice(labeled_scene_index, size=20, replace=False)
held_out = ~np.isin(labeled_scene_index, train_inds)
val_inds = labeled_scene_index[held_out]

In [None]:
# Pre-processing shared by both splits: resize to 256x256 and normalise with
# the per-channel mean/std that torchvision documents for its
# ImageNet-pretrained models.
resize = torchvision.transforms.Resize((256, 256))
normalize = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])

# NOTE(review): RandomHorizontalFlip only sees the image handed to the
# transform; nothing visible in this notebook mirrors road_image or the
# bounding boxes to match. Confirm in data_helper.LabeledDataset that a
# flipped input still agrees with its labels, otherwise drop the flip.
train_transform = torchvision.transforms.Compose([resize,
                                                  torchvision.transforms.RandomHorizontalFlip(),
                                                  torchvision.transforms.ToTensor(),
                                                  normalize])

# Validation uses the deterministic part of the pipeline only.
val_transform = torchvision.transforms.Compose([resize,
                                                torchvision.transforms.ToTensor(),
                                                normalize])

labeled_trainset = LabeledDataset(image_folder=image_folder,
                                  annotation_file=annotation_csv,
                                  scene_index=train_inds,
                                  transform=train_transform,
                                  extra_info=True
                                 )

labeled_valset = LabeledDataset(image_folder=image_folder,
                                annotation_file=annotation_csv,
                                scene_index=val_inds,
                                transform=val_transform,
                                extra_info=True
                               )

# batch_size=1: the training loop feeds one scene (sample[0]) to the model per step.
trainloader = torch.utils.data.DataLoader(labeled_trainset, batch_size=1, shuffle=True,
                                          num_workers=2, collate_fn=collate_fn)
# Validation is not shuffled: the order does not affect the metrics, and a
# fixed order keeps evaluation deterministic and avoids drawing from the RNG.
valloader = torch.utils.data.DataLoader(labeled_valset, batch_size=1, shuffle=False,
                                        num_workers=2, collate_fn=collate_fn)

In [None]:
# Pull a single batch to sanity-check the loader output.
# Use the built-in next(): DataLoader iterators follow the standard iterator
# protocol (__next__); the Python-2 style `.next()` method is not part of it
# and is unavailable on current PyTorch releases.
sample, target, road_image, extra = next(iter(trainloader))
print(torch.stack(sample).shape)

In [None]:
# The training transform normalises every image with a per-channel mean/std,
# so the raw tensors fall outside the [0, 1] range imshow expects for float
# RGB data (matplotlib would clip them). Undo the normalisation before plotting.
norm_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
norm_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
camera_views = (sample[0] * norm_std + norm_mean).clamp(0, 1)
plt.imshow(torchvision.utils.make_grid(camera_views, nrow=3).numpy().transpose(1, 2, 0));

In [None]:
fig, ax = plt.subplots()

# One colour per category id (the id is used directly as the list index).
color_list = ['b', 'g', 'orange', 'c', 'm', 'y', 'k', 'w', 'r']

# Binary road map of the first scene in the batch, with a red cross at (400, 400).
ax.imshow(road_image[0], cmap ='binary')
ax.plot(400, 400, 'x', color="red")

# Overlay every annotated box, coloured by its category.
# See helper.draw_box for how a box is rendered.
boxes, categories = target[0]['bounding_box'], target[0]['category']
for bb, cat in zip(boxes, categories):
    draw_box(ax, bb, color=color_list[cat])

In [None]:
# Category name -> integer id; the id is the position of the name in this list.
category_names = ['other_vehicle', 'bicycle', 'car', 'pedestrian', 'truck',
                  'bus', 'motorcycle', 'emergency_vehicle', 'animal']
category_map = {name: idx for idx, name in enumerate(category_names)}

## Model

In [None]:
# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
device

In [None]:
class UNet(nn.Module):
    """U-Net style segmenter with a torchvision ResNet encoder that fuses several views.

    ``forward`` receives the views of ONE scene stacked along dim 0, i.e. a
    tensor of shape ``(V, 3, H, W)``. Each view is encoded independently, the
    bottleneck features are summed over the view axis and broadcast back to
    every view, the decoder upsamples with skip connections, and the decoder
    outputs are summed over the views once more before the output head.

    The network returns raw logits (no activation after the last layer), which
    is what ``nn.BCEWithLogitsLoss`` and a ``sigmoid(out) >= 0.5`` threshold
    expect.

    NOTE(review): the skip connections assume the encoder halves the
    resolution at every stage, so H and W presumably need to be multiples of
    32; the notebook uses 256x256 inputs.
    """

    def ConvBlock(self, in_channels, out_channels, kernel_size = 3, stride = 1, padding = 1, use_bias = False):
        """Return a Conv2d -> BatchNorm2d -> ReLU(inplace) block."""
        block = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size,
                                        stride, padding, bias = use_bias),
                              nn.BatchNorm2d(out_channels),
                              nn.ReLU(True)
                             )
        return block

    def Bridge(self, in_channels, out_channels):
        """Return the bottleneck between encoder and decoder: two ConvBlocks."""
        bridge = nn.Sequential(self.ConvBlock(in_channels, out_channels),
                               self.ConvBlock(out_channels, out_channels)
                              )
        return bridge

    def UpsampleBlock(self, in_channels, out_channels, use_bias=False):
        """Return ConvTranspose2d(k=2, s=2) -> BatchNorm2d -> ReLU; doubles H and W."""
        upsample = nn.Sequential(nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, bias=use_bias),
                                 nn.BatchNorm2d(out_channels),
                                 nn.ReLU(True))
        return upsample

    def UpsampleConv(self, in_channels, out_channels):
        """Return the two ConvBlocks applied after a skip connection is concatenated."""
        upsample_conv = nn.Sequential(self.ConvBlock(in_channels, out_channels),
                                      self.ConvBlock(out_channels, out_channels))
        return upsample_conv

    def OutputHead(self, in_channels, out_channels):
        """Return the last 2x upsampling layer: a bare transposed conv emitting logits.

        No BatchNorm/ReLU follows it on purpose: a trailing ReLU would clamp
        every logit to >= 0, so sigmoid(logit) could never fall below 0.5 and
        the network could never predict the negative class.
        """
        return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)

    def __init__(self, num_classes, output_size, encoder='resnet18', pretrained = False, depth = 6):
        '''
        num_classes: Number of channels/classes for segmentation
        output_size: Final output size of the image (H*H)
        encoder: "resnet50" selects ResNet-50; any other value falls back to ResNet-18
        pretrained: For loading pretrained resnet models as encoders
        depth: Number of resolution levels used to name the skip connections.
               The decoder below is built for the default of 6.
        '''
        super(UNet,self).__init__()
        self.depth = depth
        self.num_classes = num_classes
        self.output_size = output_size

        self.resnet = torchvision.models.resnet50(pretrained=pretrained) if encoder == "resnet50" else \
                                                    torchvision.models.resnet18(pretrained=pretrained)
        self.resnet_layers = list(self.resnet.children())
        # Channel count of the deepest encoder feature map.
        self.n = 2048 if encoder == "resnet50" else 512

        # First three ResNet children form the stem; the fourth is the pooling
        # layer, kept separate so the pre-pool activation can serve as a skip.
        self.input_block = nn.Sequential(*self.resnet_layers)[:3]
        self.input_pool = self.resnet_layers[3]
        # The nn.Sequential children of the ResNet are its residual stages.
        self.down_blocks = nn.ModuleList([i for i in self.resnet_layers if isinstance(i, nn.Sequential)])

        self.bridge = self.Bridge(self.n, self.n)

        # Only the transposed conv ([0]) of each UpsampleBlock is kept here;
        # normalisation/activation happen in the matching up_conv instead.
        self.up_blocks = nn.ModuleList([self.UpsampleBlock(self.n,self.n//2)[0],
                                        self.UpsampleBlock(self.n//2,self.n//4)[0],
                                        self.UpsampleBlock(self.n//4,self.n//8)[0],
                                        self.UpsampleBlock(self.n//8,self.n//16)[0],
                                        self.UpsampleBlock(self.n//16,self.n//32)[0]])

        # Input widths are (upsampled channels + skip channels). The last two
        # skips are the 64-channel stem output and the 3-channel input image.
        self.up_conv = nn.ModuleList([self.UpsampleConv(self.n,self.n//2),
                                      self.UpsampleConv(self.n//2,self.n//4),
                                      self.UpsampleConv(self.n//4,self.n//8),
                                      self.UpsampleConv(self.n//16 + 64,self.n//16),
                                      self.UpsampleConv(self.n//32 + 3,self.n//32)])

        self.final_upsample_1 = self.UpsampleBlock(self.n//32,self.n//64)
        # Logit head. This used to be a full UpsampleBlock (with BN + ReLU), so
        # state dicts saved from that version do not match this layer's keys.
        self.final_upsample_2 = self.OutputHead(self.n//64,self.num_classes)

        # Brings the upsampled map to exactly output_size x output_size.
        self.final_pooling = nn.AdaptiveMaxPool2d(output_size=self.output_size)

    def forward(self, x):
        """Map the stacked views of one scene to a grid of logits.

        x: tensor whose dim 0 indexes the views of a single scene.
        Returns the logits reshaped to (-1, output_size); with num_classes=1
        that is an (output_size, output_size) map.
        """
        # Remember how many views came in instead of assuming a fixed count.
        num_views = x.shape[0]

        skip_conn = {"layer_0": x}
        x = self.input_block(x)
        skip_conn["layer_1"] = x
        x = self.input_pool(x)

        for i, block in enumerate(self.down_blocks, 2):
            x = block(x)
            # The deepest stage feeds the bridge directly and is not a skip.
            if i != (self.depth - 1):
                skip_conn[f"layer_{i}"] = x

        x = self.bridge(x)
        # Fuse the views at the bottleneck, then hand the fused features back
        # to every view so each can be decoded with its own skip connections.
        x = torch.sum(x, dim=0, keepdim=True)
        x = x.repeat((num_views, 1, 1, 1))

        for i, block in enumerate(self.up_blocks):
            key = f"layer_{self.depth - i - 2}"
            x = block(x)
            x = torch.cat([x, skip_conn[key]],1)
            x = self.up_conv[i](x)

        del skip_conn

        # Collapse the views into a single feature map for the output head.
        x = torch.sum(x, dim=0, keepdim=True)
        x = self.final_upsample_1(x)
        x = self.final_upsample_2(x)
        x = self.final_pooling(x)

        return x.view(-1,self.output_size)

In [None]:
# Single-channel (road / not road) model producing an 800x800 map, built on a
# randomly initialised ResNet-18 encoder and moved to the selected device.
model = UNet(num_classes=1, output_size=800, encoder="resnet18", pretrained=False)
model = model.to(device)

In [None]:
# Optimisation hyper-parameters: SGD learning rate, SGD momentum and the
# number of passes over the training set.
lr, momentum, num_epochs = 1e-2, 0.5, 10

In [None]:
# Plain SGD with momentum over all model parameters.
optimizer = torch.optim.SGD(params=model.parameters(), lr=lr, momentum=momentum)
# Binary cross-entropy that takes raw logits (the sigmoid is applied inside the loss).
criterion = nn.BCEWithLogitsLoss()

### Training

In [None]:
def compute_ts_road_map(road_map1, road_map2):
    """Return the threat score (intersection over union) of two road maps.

    Both arguments are expected to be same-shaped 0/1 maps. The overlap is the
    sum of their element-wise product and the union is the sum of both maps
    minus that overlap.
    """
    overlap = (road_map1 * road_map2).sum()
    union = road_map1.sum() + road_map2.sum() - overlap
    return overlap * 1.0 / union

In [None]:
def dice_loss(pred, truth, eps=1e-7):
    """Soft Dice loss: 1 - 2*sum(pred*truth) / (sum(pred^2) + sum(truth^2)).

    pred:  predicted scores; pass probabilities in [0, 1] for the loss to stay
           within [0, 1].
    truth: target map with the same shape as pred.
    eps:   small smoothing term added to numerator and denominator so that two
           all-zero inputs give a loss of 0 (perfect agreement) instead of the
           NaN produced by 0/0.
    """
    overlap = torch.sum(pred * truth)
    energy = torch.sum(pred * pred) + torch.sum(truth * truth)
    return 1 - (2 * overlap + eps) / (energy + eps)

In [None]:
# Train for num_epochs epochs; after each epoch evaluate on the validation
# scenes. The loaders use batch_size=1, so [0] selects the single scene of
# each batch.
for epoch in range(num_epochs):
    train_loss = 0
    model.train()
    for i, (sample, _, road_image, _) in enumerate(tqdm(trainloader)):

        sample, road_image = sample[0].to(device), road_image[0].float().to(device)

        optimizer.zero_grad()
        out = model(sample)
        # `out` holds logits: BCEWithLogitsLoss applies the sigmoid itself,
        # while the Dice term is defined on probabilities, so squash it first.
        loss = criterion(out, road_image) + dice_loss(torch.sigmoid(out), road_image)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        if (i+1)%200 == 0:
            print("Epoch: {} | Iter: {} | Train loss: {}".format(epoch+1, i+1, train_loss/(i+1)))

    model.eval()
    val_loss = 0
    val_ts = []
    with torch.no_grad():
        for i, (sample, _, road_image, _) in enumerate(tqdm(valloader)):
            sample, road_image = sample[0].to(device), road_image[0].float().to(device)
            out = model(sample)
            loss = criterion(out, road_image) + dice_loss(torch.sigmoid(out), road_image)
            val_loss += loss.item()
            # The threat score compares binary maps, so threshold the
            # prediction the same way the visualisation does.
            pred_mask = (torch.sigmoid(out) >= 0.5).float()
            val_ts.append(compute_ts_road_map(pred_mask, road_image).item())

        # Average over the number of batches, matching how the running train
        # loss is averaged (one loss value is accumulated per batch).
        print("Epoch: {} | Val loss: {} | Val TS: {}".format(epoch+1, val_loss/max(len(valloader), 1), np.mean(val_ts)))

### Visualize predictions

In [None]:
# Fetch one batch with the built-in next() (DataLoader iterators implement
# __next__; the `.next()` method is unavailable on current PyTorch releases)
# and run inference in eval mode without building an autograd graph.
sample, target, road_image, extra = next(iter(trainloader))
model.eval()
with torch.no_grad():
    out = model(sample[0].to(device))

In [None]:
# Side-by-side comparison: thresholded prediction (left) vs. ground truth (right).
fig, (ax_pred, ax_truth) = plt.subplots(1, 2, figsize=(8, 8))

# Logits -> probabilities -> binary mask at 0.5.
pred_mask = (torch.sigmoid(out) >= 0.5).float().detach().cpu().numpy()
ax_pred.imshow(pred_mask, cmap='binary')
ax_pred.set_title("Prediction")

ax_truth.imshow(road_image[0].cpu().numpy(), cmap='binary')
ax_truth.set_title("Ground Truth")