In [1]:
import os
import random

import numpy as np
import pandas as pd
from tqdm import tqdm, trange

import matplotlib
import matplotlib.pyplot as plt
matplotlib.rcParams['figure.figsize'] = [12,12]
matplotlib.rcParams['figure.dpi'] = 200

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from PIL import Image

from data_helper import LabeledDataset,UnlabeledDataset
from helper import collate_fn, draw_box

import itertools
from scipy.spatial.distance import cdist
from helper import compute_ats_bounding_boxes, compute_ts_road_map

In [2]:
random.seed(0)
np.random.seed(0)
torch.manual_seed(0);

In [3]:
# All the images are saved in image_folder
# All the labels are saved in the annotation_csv file
image_folder = '../data'
annotation_csv = '../data/annotation.csv'

In [4]:
# You shouldn't change the unlabeled_scene_index
# The first 106 scenes are unlabeled
unlabeled_scene_index = np.arange(106)
# The scenes from 106 - 133 are labeled
# You should devide the labeled_scene_index into two subsets (training and validation)
labeled_scene_index = np.arange(106, 134)

### Labeled Dataloader 

In [5]:
class AddGaussianNoise(object):
    def __init__(self, mean=0., std=1.):
        self.std = std
        self.mean = mean
        
    def __call__(self, tensor):
        return tensor + torch.randn(tensor.size()) * self.std + self.mean
    
    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)

In [6]:
labeled_scene_index = np.arange(106, 134)
train_inds = labeled_scene_index[:25]
val_inds = labeled_scene_index[25:]

In [7]:
transform_train = transforms.Compose([torchvision.transforms.Resize((256,256)),
                                      torchvision.transforms.RandomApply([torchvision.transforms.ColorJitter(brightness=[0,1], contrast=[0,1], saturation=[0,1], hue=[-0.5,0.5])],p=0.5),
                                      torchvision.transforms.RandomGrayscale(p=0.5),
                                      torchvision.transforms.RandomHorizontalFlip(p=0.5), 
                                      transforms.ToTensor(), 
                                      #transforms.Normalize([ 0.485, 0.456, 0.406 ],[ 0.229, 0.224, 0.225 ]),
                                      AddGaussianNoise(0., 0.1),
                                    ])
transform_val = transforms.Compose([torchvision.transforms.Resize((256,256)),
                                transforms.ToTensor(), 
                                #transforms.Normalize([ 0.485, 0.456, 0.406 ],[ 0.229, 0.224, 0.225 ])
                                   ])

labeled_trainset = LabeledDataset(image_folder=image_folder,
                                  annotation_file=annotation_csv,
                                  scene_index=train_inds,
                                  transform=transform_train,
                                  extra_info=True
                                 )
labeled_valset = LabeledDataset(image_folder=image_folder,
                                  annotation_file=annotation_csv,
                                  scene_index=val_inds,
                                  transform=transform_val,
                                  extra_info=True
                                 )
labeled_trainloader = torch.utils.data.DataLoader(labeled_trainset, batch_size=2, shuffle=True, num_workers=2, collate_fn=collate_fn)
labeled_valloader = torch.utils.data.DataLoader(labeled_valset, batch_size=1,shuffle=False, num_workers=2)

In [8]:
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=1)

### Model

In [9]:
class Encoder(nn.Module):
    def ConvBlock(self, in_channels, out_channels, kernel_size = 3, stride = 1, padding = 1, use_bias = False):
        block = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size, 
                                        stride, padding, bias = use_bias),
                              nn.BatchNorm2d(out_channels),
                              nn.ReLU(True)
                             )
        return block
    
    def Bridge(self, in_channels, out_channels):
        bridge = nn.Sequential(self.ConvBlock(in_channels, out_channels),
                               self.ConvBlock(out_channels, out_channels)
                              )
        return bridge
        
        
    def __init__(self, encoder='resnet34', pretrained = False, depth = 6):
        '''
        num_classes: Number of channels/classes for segmentation
        output_size: Final output size of the image (H*H)
        encoder: Supports resnet18, resnet 34 and resnet50 architectures
        pretrained: For loading a pretrained resnet model as encoder
        '''
        super(Encoder,self).__init__()  
        self.depth = depth        
        self.resnet = torchvision.models.resnet50(pretrained=pretrained) if encoder == "resnet50" else\
                            torchvision.models.resnet34(pretrained=pretrained) if encoder == "resnet34" else\
                            torchvision.models.resnet18(pretrained=pretrained)
        
        self.resnet_layers = list(self.resnet.children())
        self.n = 2048 if encoder == "resnet50" else 512
        
        self.input_block = nn.Sequential(*self.resnet_layers)[:3]
        #self.input_block[0] = nn.Conv2d(18, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        self.input_pool = self.resnet_layers[3]
        self.down_blocks = nn.ModuleList([i for i in self.resnet_layers if isinstance(i, nn.Sequential)])

        self.bridge_mu = self.Bridge(self.n, self.n)
        self.bridge_logvar = self.Bridge(self.n, self.n)
        
    def reparameterize(self,mu,logvar):
        if self.training:
            std = logvar.mul(0.5).exp_()
            eps = std.data.new(std.size()).normal_()
            return eps.mul(std).add_(mu)
        else:
            return mu
        
    def forward(self,x):
        B = x.shape[0]
        x = x.transpose(0,1)
        
        mu_list = []
        logvar_list = []
        for i in range(6):
            z = self.input_pool(self.input_block(x[i]))
            for block in self.down_blocks:
                z = block(z)
            mu = self.bridge_mu(z)
            logvar = self.bridge_logvar(z)
            mu_list.append(mu)
            logvar_list.append(logvar)
#             h = self.pool(z)
#             h = h.view([B,1,-1])
#             h_list.append(h)
            
#         del h   
        del mu
        del logvar
#         h = torch.cat(h_list,1)
        mu = torch.cat(mu_list,1)
        logvar = torch.cat(logvar_list,1)
#         del h_list
        del mu_list
        del logvar_list
        z = self.reparameterize(mu,logvar)
        return z,mu,logvar #b,512*6,8,8

In [10]:
class Decoder(nn.Module):
    def ConvBlock(self, in_channels, out_channels, kernel_size = 3, stride = 1, padding = 1, use_bias = False):
        block = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size, 
                                        stride, padding, bias = use_bias),
                              nn.BatchNorm2d(out_channels),
                              nn.ReLU(True)
                             )
        return block
    
    def Bridge(self, in_channels, out_channels):
        bridge = nn.Sequential(self.ConvBlock(in_channels, out_channels),
                               self.ConvBlock(out_channels, out_channels)
                              )
        return bridge
    
    def UpsampleBlock(self, in_channels, out_channels, use_bias=False):
        upsample = nn.Sequential(nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, bias=use_bias),
                                 nn.BatchNorm2d(out_channels),
                                 nn.ReLU(True))
        return upsample
        
    def UpsampleConv(self, in_channels, out_channels):
        upsample_conv = nn.Sequential(self.ConvBlock(in_channels, out_channels),
                                      self.ConvBlock(out_channels, out_channels))    
        return upsample_conv
        
        
    def __init__(self, classes = 1, depth = 6, output_size=(800,800), input_channels=512):
        '''
        num_classes: Number of channels/classes for segmentation
        output_size: Final output size of the image (H*H)
        encoder: Supports resnet18, resnet 34 and resnet50 architectures
        pretrained: For loading a pretrained resnet model as encoder
        '''
        super(Decoder,self).__init__()  
        self.depth = depth        
        self.num_classes = classes
        self.output_size = output_size
        self.n  = input_channels
        
        self.up_blocks = nn.ModuleList([self.UpsampleBlock(self.n,self.n//2)[0],
                                        self.UpsampleBlock(self.n//2,self.n//4)[0],
                                        self.UpsampleBlock(self.n//4,self.n//8)[0],
                                        self.UpsampleBlock(self.n//8,self.n//16)[0],
                                        self.UpsampleBlock(self.n//16,self.n//32)[0]])
        
        self.up_conv = nn.ModuleList([self.UpsampleConv(self.n//2,self.n//2),
                                      self.UpsampleConv(self.n//4,self.n//4),
                                      self.UpsampleConv(self.n//8,self.n//8),
                                      self.UpsampleConv(self.n//16 ,self.n//16),
                                      self.UpsampleConv(self.n//32,self.n//32)])
        
        self.final_upsample_1 = self.UpsampleBlock(self.n//32,self.n//64)
        self.final_upsample_2 = self.UpsampleBlock(self.n//64,self.num_classes)[0]
        
        self.final_pooling = nn.AdaptiveMaxPool2d(output_size=self.output_size)
        
    def forward(self, z):
        num_iters = z.shape[1]//self.n
        x_list = []
        for j in range(num_iters):
            hidden = z[:,self.n*j:self.n*(j+1)]
            for i, block in enumerate(self.up_blocks):          
                hidden = block(hidden)
                hidden = self.up_conv[i](hidden)
            x = self.final_upsample_1(hidden)#.unsqueeze(0))
            del hidden
            x = self.final_upsample_2(x)
            x = self.final_pooling(x)
            x = x.view(-1,self.num_classes,self.output_size,self.output_size)
            x_list.append(x)
            del x
            #x = torch.sigmoid(x)
        return torch.stack(x_list,dim=1)

In [11]:
# lr = 5e-5
# momentum = 0.9
# num_epochs = 50
# weight_decay = 5e-4

In [12]:
#vae_decoder = Decoder(classes = 3, depth = 6, output_size=256, input_channels=512).to(device)
road_decoder = Decoder(classes = 1, depth = 6, output_size=800, input_channels=512*6).to(device)
encoder = Encoder(encoder='resnet18', pretrained = False).to(device)

In [13]:
def loss_function_road(x_hat, x,mu,logvar):
    #x_hat_sig = torch.sigmoid(x_hat)
    #DICE =  1 - (2*torch.sum(x_hat_sig*x))/(torch.sum(x*x) + torch.sum(x_hat_sig*x_hat_sig))
    #tp = (x_hat_sig * x).sum()
    #ts = tp * 1.0 / (x_hat_sig.sum() + x.sum() - tp)
    BCE = F.binary_cross_entropy_with_logits(
        x_hat, x, reduction='mean'
    )
    KLD = (0.5 * torch.sum(logvar.exp() - logvar - 1 + mu.pow(2)))/mu.numel()
    return BCE+KLD

In [14]:
# def loss_function_vae(x_hat, x, mu, logvar):
#     MSE = F.mse_loss(
#         x_hat, x, reduction='mean'
#     )
#     KLD = (0.5 * torch.sum(logvar.exp() - logvar - 1 + mu.pow(2)))/mu.numel()
#     #print(MSE,KLD)
#     return MSE+5*KLD

In [15]:
#criterion_vae = loss_function_vae
criterion_road = loss_function_road
#vae_optimizer = torch.optim.Adam(vae_decoder.parameters(),lr=5e-5)
encoder_optimizer = torch.optim.Adam(encoder.parameters(),lr=5e-5)
road_optimizer = torch.optim.Adam(road_decoder.parameters(),lr=5e-5)

In [16]:
def validation(encoder,decoder,valloader,device):
    encoder.eval()
    decoder.eval()
    with torch.no_grad():
        tse_road = 0.
        total = 0.
        for i,(sample, target, road_image,extra) in enumerate(valloader):
            total+=1.0  
            input_img = sample.to(device)
            z,mu,logvar = encoder(input_img)
            predicted_road = decoder(z)[:,0,0]
            predicted_road_map = torch.sigmoid(predicted_road)>0.5
            predicted_road_map = (predicted_road_map.squeeze(1)).float()
            
            tse_road+=compute_ts_road_map(predicted_road_map, road_image.float().to(device))
            
        return (tse_road/total).item()

In [17]:
# def train(encoder, 
#           road_decoder, 
#           vae_decoder, 
#           labeled_trainloader, 
#           unlabeled_trainloader, 
#           valloader, 
#           criterion_vae, 
#           vae_optimizer, 
#           encoder_optimizer, 
#           road_optimizer,
#           criterion_road, 
#           device=device, 
#           c_point=1, 
#           max_iter=400):
#     labeled_iter = iter(labeled_trainloader)
#     unlabeled_iter = iter(unlabeled_trainloader)
#     unlabeled_training_loss = 0
#     labeled_training_loss = 0
#     threshold_val = 0.722
#     for iterator in tqdm(range(1,max_iter+1)):
        
#         vae_optimizer.zero_grad() 
#         encoder_optimizer.zero_grad() 
#         road_optimizer.zero_grad()
        
#         try:
#             unlabeled_img = next(unlabeled_iter)
#         except:
#             unlabeled_iter = iter(unlabeled_trainloader)
#             unlabeled_img = next(unlabeled_iter,None)
#         #print(unlabeled_img.shape)
#         z,mu,logvar = encoder(unlabeled_img.to(device))
#         #print(z.shape)
#         imgs_predicted = vae_decoder(z)
#         #print(imgs_predicted.shape)
#         u_loss = criterion_vae(imgs_predicted, unlabeled_img.to(device),mu,logvar)
#         unlabeled_training_loss+=u_loss.item()
        
#         l_loss = 0
#         for j in range(5):
#             try:
#                 labeled_img,_,road_image,_ = next(labeled_iter)
#             except:
#                 labeled_iter = iter(labeled_trainloader)
#                 labeled_img,_,road_image,_ = next(labeled_iter)
#             #print(labeled_img.shape)
#             labeled_img = torch.stack(labeled_img).to(device)
#             #print(z.shape)
#             z,_,_ = encoder(labeled_img)
#             road_predicted = road_decoder(z)
#             #print(road_predicted.shape)
#             l_loss+= criterion_road(road_predicted[:,0,0], torch.stack(road_image).float().to(device))
#         labeled_training_loss+=l_loss.item()

#         loss = u_loss+4*l_loss
#         loss.backward()
        
#         vae_optimizer.step() 
#         encoder_optimizer.step() 
#         road_optimizer.step()
        
#         if iterator%c_point==0:
#             val_tse_road = validation(encoder,road_decoder,valloader,device)
#             print('iterator: {}/{} | train loss labeled: {} | train loss unlabeled: {} | val ts road: {}'.format(iterator, 
#                                                                                                                  max_iter,
#                                                                                                                  round(labeled_training_loss/c_point,2), 
#                                                                                                                  round(unlabeled_training_loss/c_point,2), 
#                                                                                                                  round(val_tse_road,3)))
#             unlabeled_training_loss = 0
#             labeled_training_loss = 0
#             encoder.train()
#             road_decoder.train()
            
#             if val_tse_road>threshold_val:
#                 print('--Saving--')
#                 torch.save(encoder.state_dict(),'ss/encoder.pth')
#                 torch.save(road_decoder.state_dict(),'ss/road_decoder.pth')
#                 torch.save(vae_decoder.state_dict(),'ss/vae_decoder.pth')
#                 threshold_val = val_tse_road
                
                
                

In [18]:
# train(encoder, 
#       road_decoder, 
#       vae_decoder, 
#       labeled_trainloader, 
#       unlabeled_trainloader, 
#       labeled_valloader, 
#       criterion_vae, 
#       vae_optimizer, 
#       encoder_optimizer, 
#       road_optimizer,
#       criterion_road, 
#       device=device, 
#       c_point=20, 
#       max_iter=1000)

In [19]:
def train(encoder, 
          road_decoder, 
          labeled_trainloader, 
          valloader,
          encoder_optimizer, 
          road_optimizer,
          criterion_road, 
          device=device, 
          c_point=1, 
          max_iter=400):
    labeled_iter = iter(labeled_trainloader)
    labeled_training_loss = 0
    threshold_val = 0
    encoder.train()
    road_decoder.train()
    #encoder.load_state_dict(torch.load('da/encoder.pth'))
    #road_decoder.load_state_dict(torch.load('da/road_decoder.pth'))
    for iterator in tqdm(range(1,max_iter+1)):
        road_optimizer.zero_grad()
        encoder_optimizer.zero_grad()
        
        try:
            labeled_img,_,road_image,_ = next(labeled_iter)
        except:
            labeled_iter = iter(labeled_trainloader)
            labeled_img,_,road_image,_ = next(labeled_iter)
        #print(labeled_img.shape)
        labeled_img = torch.stack(labeled_img).to(device)
        #print(z.shape)
        z,mu,logvar = encoder(labeled_img)
        road_predicted = road_decoder(z)
        #print(road_predicted.shape)
        l_loss = criterion_road(road_predicted[:,0,0], torch.stack(road_image).float().to(device),mu,logvar)
        labeled_training_loss+=l_loss.item()

        loss = l_loss
        loss.backward()
        
        road_optimizer.step()
        encoder_optimizer.step()
        
        if iterator%c_point==0:
            val_tse_road = validation(encoder,road_decoder,valloader,device)
            print('iterator: {}/{} | train loss labeled: {} | val ts road: {}'.format(iterator, 
                                                                                     max_iter,
                                                                                     round(labeled_training_loss/c_point,2), 
                                                                                     round(val_tse_road,3)))
            
            labeled_training_loss = 0
            road_decoder.train()
            encoder.train()
            
            if val_tse_road>threshold_val:
                print('--Saving--')
                torch.save(road_decoder.state_dict(),'da/road_decoder.pth')
                torch.save(encoder.state_dict(),'da/encoder.pth')
                threshold_val = val_tse_road
                
                
                

In [None]:
train(encoder, 
          road_decoder, 
          labeled_trainloader, 
          labeled_valloader,
          encoder_optimizer, 
          road_optimizer,
          criterion_road, 
          device=device, 
          c_point=50, 
          max_iter=1000)

  5%|▍         | 49/1000 [00:12<03:51,  4.10it/s]

iterator: 50/1000 | train loss labeled: 0.94 | val ts road: 0.568
--Saving--


 10%|▉         | 99/1000 [00:49<03:40,  4.09it/s]  

iterator: 100/1000 | train loss labeled: 0.63 | val ts road: 0.706
--Saving--


 15%|█▌        | 150/1000 [01:50<1:39:04,  6.99s/it]

iterator: 150/1000 | train loss labeled: 0.57 | val ts road: 0.392


 20%|█▉        | 199/1000 [02:02<03:15,  4.10it/s]  

iterator: 200/1000 | train loss labeled: 0.55 | val ts road: 0.709
--Saving--


 25%|██▍       | 249/1000 [02:38<03:03,  4.09it/s]  

iterator: 250/1000 | train loss labeled: 0.52 | val ts road: 0.716
--Saving--


 30%|███       | 300/1000 [03:33<1:09:07,  5.92s/it]

iterator: 300/1000 | train loss labeled: 0.49 | val ts road: 0.68


 35%|███▌      | 350/1000 [04:06<1:08:20,  6.31s/it]

iterator: 350/1000 | train loss labeled: 0.52 | val ts road: 0.715


 40%|███▉      | 399/1000 [04:18<02:27,  4.08it/s]  

iterator: 400/1000 | train loss labeled: 0.44 | val ts road: 0.722
--Saving--


 45%|████▍     | 449/1000 [04:56<02:15,  4.07it/s]  

iterator: 450/1000 | train loss labeled: 0.48 | val ts road: 0.723
--Saving--


 50%|█████     | 500/1000 [05:56<58:35,  7.03s/it]  

iterator: 500/1000 | train loss labeled: 0.46 | val ts road: 0.587


 55%|█████▌    | 550/1000 [06:32<55:47,  7.44s/it]

iterator: 550/1000 | train loss labeled: 0.44 | val ts road: 0.715


 60%|█████▉    | 599/1000 [06:44<01:38,  4.07it/s]