In [1]:
import os
import random

import numpy as np
import pandas as pd
from tqdm import tqdm, trange

import matplotlib
import matplotlib.pyplot as plt
matplotlib.rcParams['figure.figsize'] = [12,12]
matplotlib.rcParams['figure.dpi'] = 200

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from PIL import Image

from data_helper import LabeledDataset
from helper import collate_fn, draw_box

import itertools
from scipy.spatial.distance import cdist
from helper import compute_ats_bounding_boxes, compute_ts_road_map

In [2]:
# Seed the Python, NumPy and PyTorch random number generators with a fixed value.
# The trailing ';' on the last statement keeps the cell from echoing the return value.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0);

In [3]:
# All the images are saved under image_folder.
# All the labels are saved in the annotation_csv file.
# Both paths are relative to the notebook's working directory.
image_folder = '../data'
annotation_csv = '../data/annotation.csv'

In [4]:
# You shouldn't change the unlabeled_scene_index.
# The first 106 scenes (indices 0-105) are unlabeled.
unlabeled_scene_index = np.arange(106)
# The scenes with indices 106-133 are labeled.
# You should divide the labeled_scene_index into two subsets (training and validation).
labeled_scene_index = np.arange(106, 134)

### Labeled Dataloader 

In [5]:
class AddGaussianNoise(object):
    """Callable transform that adds Gaussian noise to a tensor.

    Calling an instance returns ``tensor + torch.randn(tensor.size()) * std + mean``;
    the input tensor itself is not modified.

    Args:
        mean: Constant offset added to the noise.
        std: Scale applied to the standard-normal noise sample.
    """

    def __init__(self, mean=0., std=1.):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        # One standard-normal draw per element, scaled before it is added.
        scaled_noise = torch.randn(tensor.size()) * self.std
        return tensor + scaled_noise + self.mean

    def __repr__(self):
        return type(self).__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)

In [6]:
# Labeled scenes are 106-133 (28 scenes): the first 23 are used for training,
# the remaining 5 are held out for validation.
labeled_scene_index = np.arange(106, 134)
num_train_scenes = 23
train_inds, val_inds = (labeled_scene_index[:num_train_scenes],
                        labeled_scene_index[num_train_scenes:])

In [7]:
# Channel-wise normalisation constants shared by the training and validation pipelines.
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]


def _build_transform(*extra_steps):
    """Resize to 256x256, convert to a tensor, normalise, then run any ``extra_steps``."""
    steps = [
        torchvision.transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(list(_NORMALIZE_MEAN), list(_NORMALIZE_STD)),
    ]
    steps.extend(extra_steps)
    return transforms.Compose(steps)


def _build_labeled_dataset(scenes, transform):
    """Create a ``LabeledDataset`` over ``scenes`` with ``extra_info`` enabled."""
    return LabeledDataset(
        image_folder=image_folder,
        annotation_file=annotation_csv,
        scene_index=scenes,
        transform=transform,
        extra_info=True,
    )


# Training images additionally receive Gaussian noise (std 0.1); validation images do not.
transform_train = _build_transform(AddGaussianNoise(0., 0.1))
transform_val = _build_transform()

labeled_trainset = _build_labeled_dataset(train_inds, transform_train)
labeled_valset = _build_labeled_dataset(val_inds, transform_val)

labeled_trainloader = torch.utils.data.DataLoader(
    labeled_trainset,
    batch_size=16,
    shuffle=True,
    num_workers=2,
    collate_fn=collate_fn,
)
# The validation loader keeps the default collate function and yields one sample per batch.
labeled_valloader = torch.utils.data.DataLoader(
    labeled_valset,
    batch_size=1,
    shuffle=False,
    num_workers=2,
)

### Unlabeled Dataloader

In [8]:
# Number of samples in each scene folder; used to map a flat dataset index
# to a (scene, sample) pair.
NUM_SAMPLE_PER_SCENE = 126
# Number of camera images that make up one sample.
NUM_IMAGE_PER_SAMPLE = 6
# File names of the camera images inside a sample folder, in the order the
# jigsaw dataset loads them as tiles.
image_names = [
    'CAM_FRONT_LEFT.jpeg',
    'CAM_FRONT.jpeg',
    'CAM_FRONT_RIGHT.jpeg',
    'CAM_BACK_LEFT.jpeg',
    'CAM_BACK.jpeg',
    'CAM_BACK_RIGHT.jpeg',
    ]

def rgb_jittering(im):
    """Shift each of the first three channels of an image by a small random offset.

    The image is copied into an int32 array, every channel gets its own offset
    drawn with ``np.random.randint(-2, 2)`` (one of -2, -1, 0, 1; the upper bound
    is exclusive), and the result is saturated to [0, 255].

    Args:
        im: Image (anything ``np.array`` accepts) laid out as height x width x channel.

    Returns:
        A new ``uint8`` array with the jittered pixel values.
    """
    pixels = np.array(im, 'int32')
    for channel in (0, 1, 2):
        offset = np.random.randint(-2, 2)
        pixels[:, :, channel] += offset
    # Saturate rather than wrap around when converting back to 8 bits.
    pixels[pixels < 0] = 0
    pixels[pixels > 255] = 255
    return pixels.astype('uint8')

class JigsawDataset(torch.utils.data.Dataset):
    """Self-supervised "jigsaw" dataset over the six camera images of a sample.

    Each item loads the six camera images of one sample, augments and normalises
    every image independently, reorders them with a randomly chosen permutation
    from a fixed permutation set, and returns the reordered stack together with
    the index of the permutation that was applied (the classification target).

    Args:
        image_folder: Root folder containing ``scene_<id>/sample_<id>/`` directories.
        scene_index: 1-D sequence of scene ids covered by this dataset.
        first_dim: Accepted for interface compatibility; not read by this class.
    """

    def __init__(self, image_folder, scene_index, first_dim):
        self.image_folder = image_folder
        self.scene_index = scene_index
        # Fixed set of tile orderings; an item's label is an index into this array.
        self.permutations = self.get_permutations()

        # Random 256x256 crop -> small per-channel colour jitter -> float tensor.
        self.__augment_tile = transforms.Compose([
            transforms.RandomCrop((256, 256)),
            transforms.Lambda(rgb_jittering),
            transforms.ToTensor(),
        ])

    def __getitem__(self, index):
        """Return ``(tiles, order)`` for the sample at flat position ``index``.

        ``tiles`` is the stack of the six augmented camera images arranged
        according to ``self.permutations[order]``; ``order`` is a Python int.
        """
        scene_id = self.scene_index[index // NUM_SAMPLE_PER_SCENE]
        sample_id = index % NUM_SAMPLE_PER_SCENE
        sample_path = os.path.join(self.image_folder, f'scene_{scene_id}', f'sample_{sample_id}')

        tiles = []
        for image_name in image_names[:6]:
            # The context manager releases the file handle as soon as the
            # augmentation pipeline has turned the image into a tensor.
            with Image.open(os.path.join(sample_path, image_name)) as image:
                tile = self.__augment_tile(image)

            # Normalize the patches independently to avoid low level features shortcut
            channels = tile.view(3, -1)
            mean = channels.mean(dim=1)
            std = channels.std(dim=1)
            std[std == 0] = 1  # constant channel: avoid dividing by zero
            tile = transforms.Normalize(mean=mean.tolist(), std=std.tolist())(tile)
            tiles.append(tile)

        # Draw the label from PyTorch's RNG: DataLoader gives every worker its own
        # torch seed, whereas the NumPy global state may simply be copied into each
        # worker, which makes all workers emit the same sequence of labels.
        order = int(torch.randint(len(self.permutations), (1,)).item())
        data = torch.stack([tiles[t] for t in self.permutations[order]], 0)

        return data, order

    def __len__(self):
        """Number of samples: one per (scene, sample-in-scene) pair."""
        return len(self.scene_index) * NUM_SAMPLE_PER_SCENE

    def get_permutations(self, classes=100, selection="max"):
        """Greedily select ``classes`` distinct permutations of six tile positions.

        The first permutation is picked uniformly at random. With
        ``selection == 'max'`` every following pick is the remaining candidate with
        the largest mean Hamming distance to the permutations chosen so far; with
        any other value a candidate near the middle of the distance ranking is
        picked at random.

        Args:
            classes: Number of permutations to return (1 to 720).
            selection: ``'max'`` for the farthest-candidate rule, anything else for
                the randomised mid-ranking rule.

        Returns:
            Integer array of shape ``(classes, 6)``; each row is a permutation of 0-5.

        Raises:
            ValueError: If ``classes`` is outside ``1 .. 720``.
        """
        candidates = np.array(list(itertools.permutations(range(6), 6)))
        total = candidates.shape[0]
        if not 1 <= classes <= total:
            raise ValueError(
                'classes must be between 1 and {}, got {}'.format(total, classes))

        pick = np.random.randint(total)
        chosen = []
        for i in trange(classes):
            chosen.append(candidates[pick])
            candidates = np.delete(candidates, pick, axis=0)

            # After the final pick there is nothing left to select, and the
            # candidate set may even be empty (classes == 720).
            if i < classes - 1:
                D = cdist(np.stack(chosen), candidates, metric='hamming').mean(axis=0).flatten()

                if selection == 'max':
                    pick = D.argmax()
                else:
                    mid = D.shape[0] // 2
                    ranked = D.argsort()
                    # Clamp the sampling window so it stays inside the ranking
                    # even when only a few candidates remain.
                    low = max(mid - 10, 0)
                    high = min(mid + 10, D.shape[0])
                    pick = ranked[np.random.randint(low, high)]

        return np.stack(chosen)

In [9]:
# Jigsaw dataset over all unlabeled scenes. Building it runs get_permutations(),
# which is what prints the progress bar below this cell.
# NOTE(review): first_dim is a required constructor argument, but JigsawDataset does not read it.
unlabeled_trainset = JigsawDataset(image_folder=image_folder, scene_index=unlabeled_scene_index, first_dim='sample')

# Shuffled batches of 16 samples, loaded by two worker processes into pinned memory.
unlabeled_trainloader = torch.utils.data.DataLoader(unlabeled_trainset, batch_size=16, pin_memory=True,shuffle=True, num_workers=2)

100%|██████████| 100/100 [00:00<00:00, 2777.96it/s]


In [10]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Bare expression so the cell displays the selected device.
device

device(type='cuda', index=0)

### Model

In [11]:
class Encoder(nn.Module):
    """ResNet feature extractor applied separately to each of six input views.

    The input is indexed as ``x[:, i]`` for ``i`` in 0-5. Every view is passed
    through the ResNet stem, the ResNet residual stages and a two-layer
    convolutional bridge; the six resulting feature maps are concatenated along
    the channel dimension.
    """

    def ConvBlock(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, use_bias=False):
        """Return a Conv2d -> BatchNorm2d -> in-place ReLU stack."""
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=use_bias),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
        ]
        return nn.Sequential(*layers)

    def Bridge(self, in_channels, out_channels):
        """Return two consecutive ``ConvBlock``s (in -> out, then out -> out)."""
        first = self.ConvBlock(in_channels, out_channels)
        second = self.ConvBlock(out_channels, out_channels)
        return nn.Sequential(first, second)

    def __init__(self, encoder='resnet34', pretrained=False, depth=6):
        '''
        encoder: "resnet50" or "resnet34"; any other value selects resnet18
        pretrained: Forwarded to the torchvision constructor to load pretrained weights
        depth: Stored on the instance; not otherwise used by this class
        '''
        super().__init__()
        self.depth = depth

        if encoder == "resnet50":
            build_backbone = torchvision.models.resnet50
        elif encoder == "resnet34":
            build_backbone = torchvision.models.resnet34
        else:
            build_backbone = torchvision.models.resnet18
        self.resnet = build_backbone(pretrained=pretrained)

        self.resnet_layers = list(self.resnet.children())
        # Channel width fed to (and produced by) the bridge.
        self.n = 2048 if encoder == "resnet50" else 512

        # First three backbone children form the stem; the fourth is applied right after it.
        self.input_block = nn.Sequential(*self.resnet_layers)[:3]
        self.input_pool = self.resnet_layers[3]
        # Every backbone child that is itself an nn.Sequential (the residual stages).
        self.down_blocks = nn.ModuleList(
            [layer for layer in self.resnet_layers if isinstance(layer, nn.Sequential)]
        )

        self.bridge = self.Bridge(self.n, self.n)

    def forward(self, x):
        """Encode the six views and concatenate their feature maps on dim 1.

        The original author's note gives the output as (b, 512*6, 8, 8).
        """
        views = x.transpose(0, 1)

        per_view_features = []
        for view_idx in range(6):
            feat = self.input_pool(self.input_block(views[view_idx]))
            for stage in self.down_blocks:
                feat = stage(feat)
            per_view_features.append(self.bridge(feat))

        return torch.cat(per_view_features, 1)

In [12]:
class JigsawNet(nn.Module):
    """Classification head that predicts which tile permutation was applied.

    The incoming feature map is average-pooled to 1x1, flattened to ``6*512``
    values per sample, passed through a 512-unit hidden layer (ReLU + dropout)
    and mapped to ``classes`` logits.
    """

    def __init__(self, classes=500):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(output_size=(1, 1))

        self.fc = nn.Sequential()
        hidden_layers = (
            ('fc7', nn.Linear(6 * 512, 512)),
            ('relu7', nn.ReLU(inplace=True)),
            ('drop7', nn.Dropout(p=0.5)),
        )
        for layer_name, layer in hidden_layers:
            self.fc.add_module(layer_name, layer)

        self.classifier = nn.Sequential()
        self.classifier.add_module('fc8', nn.Linear(512, classes))

    def forward(self, x):
        """Return permutation logits of shape ``(batch, classes)``."""
        pooled = self.pool(x)
        hidden = self.fc(pooled.view(x.shape[0], -1))
        return self.classifier(hidden)

In [13]:
class Decoder(nn.Module):
    """Transposed-convolution decoder that turns encoder features into a map of logits.

    The feature map is upsampled by 2x five times (each step is a transposed
    convolution followed by two conv blocks), then twice more by the final
    upsampling layers, and finally adaptively max-pooled to ``output_size``.
    The result holds raw logits; no sigmoid is applied here.
    """

    def ConvBlock(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, use_bias=False):
        """Return a Conv2d -> BatchNorm2d -> in-place ReLU stack."""
        block = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size,
                                        stride, padding, bias=use_bias),
                              nn.BatchNorm2d(out_channels),
                              nn.ReLU(True)
                              )
        return block

    def Bridge(self, in_channels, out_channels):
        """Return two consecutive ``ConvBlock``s (in -> out, then out -> out)."""
        bridge = nn.Sequential(self.ConvBlock(in_channels, out_channels),
                               self.ConvBlock(out_channels, out_channels)
                               )
        return bridge

    def UpsampleBlock(self, in_channels, out_channels, use_bias=False):
        """Return a 2x ConvTranspose2d -> BatchNorm2d -> in-place ReLU stack."""
        upsample = nn.Sequential(nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, bias=use_bias),
                                 nn.BatchNorm2d(out_channels),
                                 nn.ReLU(True))
        return upsample

    def UpsampleConv(self, in_channels, out_channels):
        """Return the pair of ``ConvBlock``s applied after each upsampling step."""
        upsample_conv = nn.Sequential(self.ConvBlock(in_channels, out_channels),
                                      self.ConvBlock(out_channels, out_channels))
        return upsample_conv

    def __init__(self, classes=1, depth=6, output_size=(800, 800), in_channels=512 * 6):
        '''
        classes: Number of output channels produced by the last upsampling layer
        depth: Stored on the instance; not otherwise used by this class
        output_size: Spatial size of the output, either one int (square) or an (H, W) pair
        in_channels: Channel count of the incoming feature map (at least 64)
        '''
        super().__init__()
        if in_channels < 64:
            raise ValueError('in_channels must be at least 64, got {}'.format(in_channels))

        self.depth = depth
        self.num_classes = classes
        self.output_size = output_size
        self.n = in_channels

        # Normalise output_size to an (H, W) pair so forward() can reshape with it
        # whether the caller passed an int or a tuple.
        try:
            out_h, out_w = output_size
        except TypeError:
            out_h = out_w = output_size
        self._output_hw = (int(out_h), int(out_w))

        # Channel widths n, n/2, n/4, ..., n/64.
        widths = [self.n // (2 ** k) for k in range(7)]

        # Only the transposed convolution ([0]) of each block is kept here; the
        # matching UpsampleConv that follows supplies BatchNorm and ReLU.
        self.up_blocks = nn.ModuleList(
            [self.UpsampleBlock(widths[k], widths[k + 1])[0] for k in range(5)])

        self.up_conv = nn.ModuleList(
            [self.UpsampleConv(widths[k + 1], widths[k + 1]) for k in range(5)])

        self.final_upsample_1 = self.UpsampleBlock(widths[5], widths[6])
        # Last layer is a bare transposed convolution so the output stays a logit.
        self.final_upsample_2 = self.UpsampleBlock(widths[6], self.num_classes)[0]

        self.final_pooling = nn.AdaptiveMaxPool2d(output_size=self._output_hw)

    def forward(self, z):
        """Decode feature map ``z`` into logits reshaped to ``(-1, H, W)``.

        With ``classes == 1`` this removes the channel dimension; with more
        classes the channel dimension is folded into the leading one.
        """
        for upsample, refine in zip(self.up_blocks, self.up_conv):
            z = refine(upsample(z))
        x = self.final_upsample_1(z)
        del z
        x = self.final_upsample_2(x)
        x = self.final_pooling(x)
        return x.view(-1, *self._output_hw)

In [14]:
# Learning rate, momentum and weight decay are passed to the SGD optimiser of the jigsaw head.
lr = 0.0005
momentum = 0.9
# NOTE(review): num_epochs is not read by the iteration-based train() loop in this notebook.
num_epochs = 50
weight_decay = 5e-4

In [15]:
# classes=100 matches the default number of permutations built by JigsawDataset.get_permutations.
jigsawnet = JigsawNet(classes=100).to(device)
# ResNet-34 backbone with no pretrained weights.
encoder = Encoder(encoder='resnet34', pretrained = False).to(device)
# Single-channel output pooled to 800x800.
decoder = Decoder(classes = 1, output_size=800).to(device)

In [16]:
def loss_function_road(x_hat, x, eps=1e-6):
    """Road-map loss: ``-0.3 * log(soft threat score) + 0.7 * BCE``.

    The soft threat score is ``tp / (sum(p) + sum(x) - tp)`` with ``p = sigmoid(x_hat)``
    and ``tp = sum(p * x)``, computed over the whole batch at once.

    Args:
        x_hat: Predicted logits.
        x: Float target tensor with the same shape as ``x_hat``.
        eps: Smoothing term added to the numerator and denominator of the score so
            that the logarithm stays finite when the target contains no positive
            pixels (``tp == 0``) or when the denominator is zero.

    Returns:
        Scalar loss tensor.
    """
    probs = torch.sigmoid(x_hat)
    intersection = (probs * x).sum()
    union = probs.sum() + x.sum() - intersection
    soft_ts = (intersection + eps) / (union + eps)
    bce = F.binary_cross_entropy_with_logits(
        x_hat, x, reduction='mean'
    )
    return -0.3 * torch.log(soft_ts) + 0.7 * bce

In [17]:
# Cross-entropy over permutation classes for the jigsaw task; loss_function_road for the road map.
criterion_jigsaw = nn.CrossEntropyLoss()
criterion_road = loss_function_road
# The jigsaw head is trained with SGD; encoder and decoder each get their own Adam optimiser (lr 5e-5).
jigsaw_optimizer = torch.optim.SGD(jigsawnet.parameters(),lr=lr,momentum=momentum,weight_decay = weight_decay)
encoder_optimizer = torch.optim.Adam(encoder.parameters(),lr=5e-5)
decoder_optimizer = torch.optim.Adam(decoder.parameters(),lr=5e-5)

In [18]:
def validation(encoder,decoder,valloader,device):
    """Return the mean ``compute_ts_road_map`` score of the model over ``valloader``.

    Both networks are switched to eval mode and are left in eval mode on return.
    For every batch the decoder logits are passed through a sigmoid and
    thresholded at 0.5 to obtain a binary road map, which is scored against the
    ground-truth road image. The per-batch scores are averaged over the number of
    batches and returned as a Python float.
    """
    encoder.eval()
    decoder.eval()
    score_sum = 0.
    num_batches = 0.
    with torch.no_grad():
        for sample, target, road_image, extra in valloader:
            num_batches += 1.0
            logits = decoder(encoder(sample.to(device)))
            road_mask = (torch.sigmoid(logits) > 0.5).squeeze(1).float()
            score_sum += compute_ts_road_map(road_mask, road_image.float().to(device))

        return (score_sum / num_batches).item()

In [19]:
def _next_batch(batch_iter, loader):
    """Return ``(batch, iterator)``, starting a new pass over ``loader`` when exhausted."""
    try:
        return next(batch_iter), batch_iter
    except StopIteration:
        # Only exhaustion restarts the loader; genuine loading errors propagate.
        batch_iter = iter(loader)
        return next(batch_iter), batch_iter


def train(encoder,
          decoder,
          jigsawnet,
          labeled_trainloader,
          unlabeled_trainloader,
          valloader,
          criterion_jigsaw,
          jigsaw_optimizer,
          encoder_optimizer,
          decoder_optimizer,
          criterion_road,
          device=device,
          c_point=1,
          max_iter=400,
          save_dir='ss'):
    """Jointly train the jigsaw pretext task and the road-map task.

    Every iteration takes one unlabeled (jigsaw) batch and one labeled (road-map)
    batch, runs both through the shared encoder, and back-propagates
    ``jigsaw_loss + 5 * road_loss`` before stepping all three optimisers. Both
    loaders are cycled indefinitely.

    Every ``c_point`` iterations the model is validated, the running average
    losses are printed, and the encoder, decoder and jigsaw head are saved when
    the validation score is strictly better than the best one so far.

    Args:
        encoder, decoder, jigsawnet: Modules to train.
        labeled_trainloader: Yields ``(images, _, road_images, _)`` batches.
        unlabeled_trainloader: Yields ``(tiles, permutation_index)`` batches.
        valloader: Loader handed to ``validation``.
        criterion_jigsaw, criterion_road: Loss callables for the two tasks.
        jigsaw_optimizer, encoder_optimizer, decoder_optimizer: One optimiser per module.
        device: Device the batches are moved to.
        c_point: Number of iterations between validation/checkpoint steps.
        max_iter: Total number of training iterations.
        save_dir: Directory for the checkpoint files; created if it does not exist.
    """
    # Make sure the first checkpoint cannot fail on a missing directory.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    encoder.train()
    decoder.train()
    jigsawnet.train()

    labeled_iter = iter(labeled_trainloader)
    unlabeled_iter = iter(unlabeled_trainloader)
    unlabeled_training_loss = 0
    labeled_training_loss = 0
    threshold_val = 0
    for iterator in tqdm(range(1,max_iter+1)):

        jigsaw_optimizer.zero_grad()
        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()

        # Pretext task: predict which permutation was applied to the six tiles.
        (unlabeled_img, permutation_target), unlabeled_iter = _next_batch(
            unlabeled_iter, unlabeled_trainloader)
        permutation_predicted = jigsawnet(encoder(unlabeled_img.to(device)))
        u_loss = criterion_jigsaw(permutation_predicted, permutation_target.to(device))
        unlabeled_training_loss+=u_loss.item()

        # Supervised task: predict the road map from the labeled images.
        (labeled_img, _, road_image, _), labeled_iter = _next_batch(
            labeled_iter, labeled_trainloader)
        labeled_img = torch.stack(labeled_img).to(device)
        road_predicted = decoder(encoder(labeled_img))
        l_loss = criterion_road(road_predicted, torch.stack(road_image).float().to(device))
        labeled_training_loss+=l_loss.item()

        # The road-map loss is weighted five times the jigsaw loss.
        loss = u_loss+5*l_loss
        loss.backward()

        jigsaw_optimizer.step()
        encoder_optimizer.step()
        decoder_optimizer.step()

        if iterator%c_point==0:
            val_tse_road = validation(encoder,decoder,valloader,device)
            print('iterator: {}/{} | train loss labeled: {} | train loss unlabeled: {} | val ts road: {}'.format(iterator,
                                                                                                                 max_iter,
                                                                                                                 round(labeled_training_loss/c_point,2),
                                                                                                                 round(unlabeled_training_loss/c_point,2),
                                                                                                                 round(val_tse_road,3)))
            unlabeled_training_loss = 0
            labeled_training_loss = 0
            # validation() leaves the networks in eval mode; switch back.
            encoder.train()
            decoder.train()

            if val_tse_road>threshold_val:
                print('--Saving--')
                torch.save(encoder.state_dict(), os.path.join(save_dir, 'encoder.pth'))
                torch.save(decoder.state_dict(), os.path.join(save_dir, 'decoder.pth'))
                torch.save(jigsawnet.state_dict(), os.path.join(save_dir, 'jigsaw.pth'))
                threshold_val = val_tse_road
                
                
                

In [None]:
# Run 4000 joint training iterations, validating (and checkpointing on improvement) every 10.
train(encoder, 
          decoder, 
          jigsawnet, 
          labeled_trainloader, 
          unlabeled_trainloader, 
          labeled_valloader, 
          criterion_jigsaw, 
          jigsaw_optimizer, 
          encoder_optimizer, 
          decoder_optimizer,
          criterion_road, 
          device=device, 
          c_point=10, 
          max_iter=4000)

  0%|          | 9/4000 [00:16<1:56:20,  1.75s/it]

iterator: 10/4000 | train loss labeled: 0.84 | train loss unlabeled: 4.61 | val ts road: 0.488
--Saving--


  0%|          | 20/4000 [02:15<19:05:10, 17.26s/it]

iterator: 20/4000 | train loss labeled: 0.64 | train loss unlabeled: 4.6 | val ts road: 0.45


  1%|          | 29/4000 [02:30<2:31:28,  2.29s/it] 

iterator: 30/4000 | train loss labeled: 0.56 | train loss unlabeled: 4.6 | val ts road: 0.697
--Saving--


  1%|          | 40/4000 [04:30<19:20:35, 17.58s/it]

iterator: 40/4000 | train loss labeled: 0.51 | train loss unlabeled: 4.62 | val ts road: 0.684


  1%|▏         | 50/4000 [05:35<18:15:13, 16.64s/it]

iterator: 50/4000 | train loss labeled: 0.49 | train loss unlabeled: 4.58 | val ts road: 0.691


  2%|▏         | 60/4000 [06:42<18:41:14, 17.07s/it]

iterator: 60/4000 | train loss labeled: 0.45 | train loss unlabeled: 4.59 | val ts road: 0.694


  2%|▏         | 70/4000 [07:47<18:07:03, 16.60s/it]

iterator: 70/4000 | train loss labeled: 0.44 | train loss unlabeled: 4.63 | val ts road: 0.692


  2%|▏         | 80/4000 [08:54<18:45:31, 17.23s/it]

iterator: 80/4000 | train loss labeled: 0.44 | train loss unlabeled: 4.57 | val ts road: 0.692


  2%|▏         | 90/4000 [09:59<18:12:17, 16.76s/it]

iterator: 90/4000 | train loss labeled: 0.42 | train loss unlabeled: 4.58 | val ts road: 0.668


  2%|▎         | 100/4000 [11:05<18:18:24, 16.90s/it]

iterator: 100/4000 | train loss labeled: 0.42 | train loss unlabeled: 4.61 | val ts road: 0.669


  3%|▎         | 110/4000 [12:12<18:37:32, 17.24s/it]

iterator: 110/4000 | train loss labeled: 0.39 | train loss unlabeled: 4.6 | val ts road: 0.696


  3%|▎         | 120/4000 [13:18<18:06:10, 16.80s/it]

iterator: 120/4000 | train loss labeled: 0.42 | train loss unlabeled: 4.57 | val ts road: 0.669


  3%|▎         | 130/4000 [14:24<18:12:14, 16.93s/it]

iterator: 130/4000 | train loss labeled: 0.41 | train loss unlabeled: 4.54 | val ts road: 0.676


  3%|▎         | 137/4000 [14:36<3:08:07,  2.92s/it] 

In [None]:
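# NOTE(review): everything in this cell is commented out. It appears to be an earlier
# epoch-based training loop (it refers to `model`, `trainloader` and `optimizer`, none of
# which are defined in the cells above) and is kept for reference only.
# best_val_loss = 1000000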
# for epoch in range(num_epochs):
#     train_loss = 0
#     model.train()
#     for i, (sample, target) in enumerate(tqdm(trainloader)):        
#         sample, target = sample.to(device), target.to(device)
        
#         optimizer.zero_grad()
#         out = model(sample)

#         loss = criterion(out, target)
#         loss.backward()
#         optimizer.step()
        
#         train_loss += loss.item()
#         if (i+1)%150 == 0:
#             print("Epoch: {} | Iter: {} | Train loss: {}".format(epoch+1, i+1, train_loss/(i+1)))
    
#     model.eval()
#     val_loss = 0
#     with torch.no_grad():        
#         for i, (sample,target) in enumerate(tqdm(valloader)):
#             sample, target = sample.to(device),target.to(device)
#             out = model(sample)
#             loss = criterion(out, target)
#             val_loss += loss.item()
    
#     epoch_val_loss = val_loss/len(valloader)
#     print("Epoch: {} | Val loss: {}".format(epoch+1,epoch_val_loss))
#     if epoch_val_loss<best_val_loss:
#         best_val_loss = epoch_val_loss
#         print("Saving model...")
#         torch.save(model.state_dict(),'jigsaw_task_001.pth')