In [None]:
import os
import random

import numpy as np
import pandas as pd
from tqdm import tqdm, trange

import matplotlib
import matplotlib.pyplot as plt
matplotlib.rcParams['figure.figsize'] = [12,12]
matplotlib.rcParams['figure.dpi'] = 200

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from PIL import Image

from data_helper import LabeledDataset
from helper import collate_fn, draw_box

import itertools
from scipy.spatial.distance import cdist

In [None]:
# Seed the PyTorch, NumPy and Python random number generators with the same
# value so that repeated runs of the notebook draw the same random numbers.
torch.manual_seed(0)
np.random.seed(0)
random.seed(0)

In [None]:
# image_folder is the directory that holds all the images;
# annotation_csv is the file inside it that holds all the labels.
image_folder = '../data'
annotation_csv = f'{image_folder}/annotation.csv'

In [None]:
# Do not change unlabeled_scene_index: the first 106 scenes (0-105) are unlabeled.
unlabeled_scene_index = np.arange(106)
# Scenes 106-133 are labeled; they should be divided into two subsets
# (training and validation).
labeled_scene_index = np.arange(106, 134)

# The first 80% of the unlabeled scenes are used for training and the
# remaining scenes for validation.
train_split = int(len(unlabeled_scene_index) * 0.8)
train_inds, val_inds = np.split(unlabeled_scene_index, [train_split])

### Dataloader

In [None]:
# Number of samples stored per scene and number of camera images per sample.
NUM_SAMPLE_PER_SCENE = 126
NUM_IMAGE_PER_SAMPLE = 6
# File name of each camera view inside a sample folder, in the order used
# when the views are loaded.
image_names = [
    f'CAM_{view}.jpeg'
    for view in (
        'FRONT_LEFT',
        'FRONT',
        'FRONT_RIGHT',
        'BACK_LEFT',
        'BACK',
        'BACK_RIGHT',
    )
]

def rgb_jittering(im):
    """Add a small random offset to each of the first three colour channels.

    The input is copied into an int32 array, each of channels 0-2 is shifted
    by its own integer drawn with ``np.random.randint(-2, 2)`` (so -2..1),
    and the result is clamped to [0, 255] and returned as a uint8 array.
    """
    pixels = np.array(im, 'int32')
    for channel in range(3):
        offset = np.random.randint(-2, 2)
        pixels[:, :, channel] += offset
    # Clamp before the cast so out-of-range values do not wrap around.
    return np.clip(pixels, 0, 255).astype('uint8')

class JigsawDataset(torch.utils.data.Dataset):
    """Jigsaw-puzzle pretext dataset over the six camera views of a sample.

    Each item is one sample: its six camera images are randomly cropped,
    colour-jittered, normalised per tile and then reordered by a permutation
    picked from a fixed table.  The index of that permutation is the label.
    """

    def __init__(self,  image_folder, scene_index, first_dim):
        """
        Args:
            image_folder: Root directory containing ``scene_<id>/sample_<id>/``.
            scene_index: Scene ids to draw samples from (array-like of ints).
            first_dim: Accepted for call-site compatibility; not used here.
        """
        self.image_folder = image_folder
        # __len__ relies on ``.size``, so accept plain sequences as well.
        self.scene_index = np.asarray(scene_index)
        self.permutations = self.get_permutations()

        self.__augment_tile = transforms.Compose([
            transforms.RandomCrop((256,256)),
            transforms.Lambda(rgb_jittering),
            transforms.ToTensor(),
        ])

    def __getitem__(self, index):
        """Return ``(tiles, label)`` for the sample at ``index``.

        ``tiles`` stacks the six processed camera views along dim 0 in the
        order given by the chosen permutation; ``label`` is the row index of
        that permutation in ``self.permutations``.
        """
        scene_id = self.scene_index[index // NUM_SAMPLE_PER_SCENE]
        sample_id = index % NUM_SAMPLE_PER_SCENE
        sample_path = os.path.join(self.image_folder, f'scene_{scene_id}', f'sample_{sample_id}')

        tiles = []
        for image_name in image_names:
            # The context manager releases the file handle as soon as the
            # tile has been converted to a tensor.
            with Image.open(os.path.join(sample_path, image_name)) as image:
                tile = self.__augment_tile(image)

            # Normalize the patches independently to avoid low level features shortcut
            flat = tile.view(3, -1)
            m, s = flat.mean(dim=1).numpy(), flat.std(dim=1).numpy()
            s[s == 0] = 1  # constant channels would otherwise divide by zero
            tile = transforms.Normalize(mean=m.tolist(), std=s.tolist())(tile)
            tiles.append(tile)

        order = np.random.randint(len(self.permutations))
        data = torch.stack([tiles[t] for t in self.permutations[order]], 0)

        return data, int(order)

    def __len__(self):
        """Number of samples: scenes times samples per scene."""
        return self.scene_index.size * NUM_SAMPLE_PER_SCENE

    def get_permutations(self, classes=500, selection="max", seed=0):
        """Build the table of tile permutations used as class labels.

        Starting from one randomly chosen permutation of six elements, the
        table is grown greedily: with ``selection == "max"`` the candidate
        with the largest mean Hamming distance to the rows chosen so far is
        added; otherwise a candidate near the middle of the distance ranking
        is drawn at random.

        A private generator seeded with ``seed`` is used instead of the
        global NumPy state, so every dataset instance (e.g. the training and
        the validation set) maps a given label to the same permutation.

        Args:
            classes: Number of permutations to select (1..720).
            selection: ``"max"`` for the greedy farthest choice, anything
                else for the randomised middle-of-ranking choice.
            seed: Seed of the private random generator.

        Returns:
            Integer array of shape ``(classes, 6)``.

        Raises:
            ValueError: If ``classes`` is outside 1..720.
        """
        P_hat = np.array(list(itertools.permutations(range(6), 6)))
        n = P_hat.shape[0]
        if not 1 <= classes <= n:
            raise ValueError(
                "classes must be between 1 and {}, got {}".format(n, classes))

        rng = np.random.RandomState(seed)
        j = rng.randint(n)
        P = P_hat[j].reshape([1, -1])

        for i in trange(classes):
            if i > 0:
                P = np.concatenate([P, P_hat[j].reshape([1, -1])], axis=0)
            P_hat = np.delete(P_hat, j, axis=0)

            # Only look for another candidate while more rows are needed;
            # this also avoids an argmax over an empty pool at classes == n.
            if i + 1 < classes:
                D = cdist(P, P_hat, metric='hamming').mean(axis=0).flatten()
                if selection == 'max':
                    j = D.argmax()
                else:
                    m = D.shape[0] // 2
                    S = D.argsort()
                    # Clamp the window so it stays inside the ranking.
                    lo, hi = max(m - 10, 0), min(m + 10, D.shape[0])
                    j = S[rng.randint(lo, hi)]

        return P

In [None]:
def _seed_worker(worker_id):
    """Give each DataLoader worker its own NumPy / ``random`` seed.

    The dataset draws crops, colour jitter and permutations from NumPy's
    global generator.  PyTorch versions that do not reseed NumPy in worker
    processes would otherwise produce the same random stream in every worker.
    """
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


unlabeled_trainset = JigsawDataset(image_folder=image_folder, scene_index=train_inds, first_dim='sample')
unlabeled_valset = JigsawDataset(image_folder=image_folder, scene_index=val_inds, first_dim='sample')

trainloader = torch.utils.data.DataLoader(
    unlabeled_trainset, batch_size=32, pin_memory=True, shuffle=True,
    num_workers=2, worker_init_fn=_seed_worker)
# Validation does not need shuffling.
valloader = torch.utils.data.DataLoader(
    unlabeled_valset, batch_size=1, pin_memory=True, shuffle=False,
    num_workers=2, worker_init_fn=_seed_worker)

In [None]:
# Use the first CUDA device when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
device

### Model

In [None]:
class JigsawNet(nn.Module):
    """Permutation classifier that runs a shared ResNet encoder on six tiles.

    Each tile goes through the ResNet stem and residual stages, a 1x1
    convolution that reduces the channels to 32, and ``fc6``.  The six
    resulting feature vectors are concatenated and passed through ``fc7``
    and ``classifier`` to produce one score per permutation class.
    """

    def ConvBlock(self, in_channels, out_channels, kernel_size = 3, stride = 1, padding = 1, use_bias = False):
        """Return a Conv2d -> BatchNorm2d -> ReLU block."""
        block = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size,
                                        stride, padding, bias = use_bias),
                              nn.BatchNorm2d(out_channels),
                              nn.ReLU(True)
                             )
        return block

    def Bridge(self, in_channels, out_channels):
        """Return two consecutive ConvBlocks."""
        bridge = nn.Sequential(self.ConvBlock(in_channels, out_channels),
                               self.ConvBlock(out_channels, out_channels)
                              )
        return bridge

    def __init__(self, classes = 500, encoder='resnet34', pretrained = False, depth = 6):
        '''
        classes: Number of permutation classes (size of the output layer)
        encoder: "resnet50" or "resnet34"; any other value selects resnet18
        pretrained: Passed to the torchvision constructor of the encoder
        depth: Stored as self.depth; not used by forward
        '''
        super(JigsawNet,self).__init__()
        self.depth = depth
        if encoder == "resnet50":
            self.resnet = torchvision.models.resnet50(pretrained=pretrained)
        elif encoder == "resnet34":
            self.resnet = torchvision.models.resnet34(pretrained=pretrained)
        else:
            self.resnet = torchvision.models.resnet18(pretrained=pretrained)

        self.resnet_layers = list(self.resnet.children())
        # Channel width used for the layers that follow the encoder.
        self.n = 2048 if encoder == "resnet50" else 512

        # First three children of the ResNet, then its fourth child, then
        # every child that is an nn.Sequential (the residual stages).
        self.input_block = nn.Sequential(*self.resnet_layers)[:3]
        self.input_pool = self.resnet_layers[3]
        self.down_blocks = nn.ModuleList([i for i in self.resnet_layers if isinstance(i, nn.Sequential)])

        # NOTE: bridge is registered but never called by forward; it is kept
        # so that the module's state_dict keys stay the same.
        self.bridge = self.Bridge(self.n, self.n)
        # Input width follows the encoder (was hard-coded to 512, which does
        # not match the 2048 channels used for resnet50).
        self.down_channels = nn.Conv2d(self.n,32,kernel_size=1,stride=1)
        self.fc6 = nn.Sequential()
        # 32*8*8 fixes the spatial size after down_channels to 8x8.
        self.fc6.add_module('fc6_s1',nn.Linear(32*8*8, 1024))
        self.fc6.add_module('relu6_s1',nn.ReLU(inplace=True))
        self.fc6.add_module('drop6_s1',nn.Dropout(p=0.5))

        self.fc7 = nn.Sequential()
        self.fc7.add_module('fc7',nn.Linear(6*1024,4096))
        self.fc7.add_module('relu7',nn.ReLU(inplace=True))
        self.fc7.add_module('drop7',nn.Dropout(p=0.5))

        self.classifier = nn.Sequential()
        self.classifier.add_module('fc8',nn.Linear(4096, classes))

    def forward(self, x):
        """Score the permutation classes for a batch of tile stacks.

        ``x`` is indexed as (batch, tile, ...); the first six tiles along
        dim 1 are encoded one after another with the shared encoder.
        Returns a tensor with one row of ``classes`` scores per batch item.
        """
        B = x.shape[0]
        x = x.transpose(0,1)

        tile_features = []
        for i in range(6):
            z = self.input_pool(self.input_block(x[i]))
            for block in self.down_blocks:
                z = block(z)
            z = self.down_channels(z)
            z = self.fc6(z.view(B,-1))
            tile_features.append(z.view([B,1,-1]))

        x = torch.cat(tile_features,1)
        x = self.fc7(x.view(B,-1))
        x = self.classifier(x)
        return x

In [None]:
# Settings consumed by the optimizer (lr, momentum, weight_decay) and by the
# training loop (num_epochs).
lr, momentum = 5e-4, 0.9
weight_decay = 5e-4
num_epochs = 50

In [None]:
# Build the network with a randomly initialised resnet34 encoder and 500
# output classes, then move it to the selected device.
model = JigsawNet(classes=500, encoder="resnet34", pretrained=False)
model = model.to(device)

In [None]:
# Cross-entropy loss over the model's class scores, optimised with SGD using
# the learning rate, momentum and weight decay configured above.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=lr,
    momentum=momentum,
    weight_decay=weight_decay,
)

In [None]:
# Train for num_epochs epochs; after each epoch evaluate on the validation
# loader and keep the weights with the lowest validation loss seen so far.
# Starting from +inf guarantees that the first epoch always saves a checkpoint.
best_val_loss = float("inf")
for epoch in range(num_epochs):
    train_loss = 0
    model.train()
    for i, (sample, target) in enumerate(tqdm(trainloader)):
        sample, target = sample.to(device), target.to(device)

        optimizer.zero_grad()
        out = model(sample)

        loss = criterion(out, target)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        if (i+1)%150 == 0:
            # Running mean of the batch losses seen so far in this epoch.
            print("Epoch: {} | Iter: {} | Train loss: {}".format(epoch+1, i+1, train_loss/(i+1)))

    model.eval()
    val_loss = 0
    num_val_samples = 0
    with torch.no_grad():
        for sample, target in tqdm(valloader):
            sample, target = sample.to(device), target.to(device)
            out = model(sample)
            loss = criterion(out, target)
            # Weight each batch mean by its size so the epoch value is a
            # per-sample average for any validation batch size.
            batch_size = target.size(0)
            val_loss += loss.item() * batch_size
            num_val_samples += batch_size

    epoch_val_loss = val_loss/num_val_samples
    print("Epoch: {} | Val loss: {}".format(epoch+1,epoch_val_loss))
    if epoch_val_loss<best_val_loss:
        best_val_loss = epoch_val_loss
        print("Saving model...")
        torch.save(model.state_dict(),'jigsaw_task_001.pth')