# One Shot Learning with Siamese Networks

This Jupyter notebook trains and evaluates a Siamese network for one-shot learning.

## Imports
All the imports are defined here

In [1]:
%matplotlib inline
import torchvision
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader,Dataset
from torchvision.transforms import functional as TF
import matplotlib.pyplot as plt
import torchvision.utils
import numpy as np
import random
from PIL import Image
import torch
import os
from tqdm import tqdm_notebook as tqdm
from torch.autograd import Variable
import PIL.ImageOps    
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from res_unet import SiameseNetwork

## Helper functions
Set of helper functions

In [None]:
def imshow(img1, img2):
    """Display two single-channel image tensors side by side.

    Args:
        img1: Image tensor drawn in the left panel. It is converted with
            ``TF.to_pil_image(..., mode='L')`` and therefore has to be
            convertible to an 8-bit grayscale image
            (assumed shape (1, H, W) -- TODO confirm against callers).
        img2: Image tensor drawn in the right panel; same requirements.
    """
    # mode='L' produces 8-bit grayscale PIL images.
    left = TF.to_pil_image(img1, mode='L')
    right = TF.to_pil_image(img2, mode='L')

    # Without an explicit colormap matplotlib renders 2-D data with its
    # default false-colour map and stretches the contrast per image; pin a
    # gray map and the full 8-bit range so both panels are comparable.
    for position, picture in ((121, left), (122, right)):
        plt.subplot(position)
        plt.imshow(picture, cmap='gray', vmin=0, vmax=255)

    plt.show()
    
def show_plot(iteration, loss):
    """Draw ``loss`` against ``iteration`` as a line plot with a grid.

    Args:
        iteration: x-axis values.
        loss: y-axis values, one per entry of ``iteration``.
    """
    # Work on the current axes (created on demand), then render the figure.
    axes = plt.gca()
    axes.plot(iteration, loss)
    axes.grid()
    plt.show()

## Configuration Class
A simple class to manage configuration

In [None]:
class Config:
    """Static settings shared by the notebook cells."""

    # Directory whose files are used to build the training pairs.
    training_dir = "./data_madori/train/"
    # Directory read by the evaluation cell.
    testing_dir = "./data_madori/test/"
    # Number of image pairs per training mini-batch.
    train_batch_size = 16
    # Number of passes over the training data.
    train_number_epochs = 100

## Custom Dataset Class
This dataset generates pairs of images, labelled 0 for a genuine (similar) pair and 1 for an imposter (dissimilar) pair.

In [None]:
class MadoriSiameseDataset(Dataset):
    """Pairs of grayscale images for training a Siamese network.

    Every item is ``(img1, img2, label)``:

    * ``label == 0`` -- ``img2`` is a randomly rotated/flipped copy of
      ``img1`` (similar pair).
    * ``label == 1`` -- ``img2`` is a different file from the same
      directory (dissimilar pair).

    Images are resized to fit inside ``img_size`` while keeping their aspect
    ratio, then edge-padded on the right/bottom up to exactly ``img_size``.

    Args:
        data_dir: Directory whose entries are all treated as image files.
        img_size: Target ``(width, height)`` in pixels, following the
            ordering of ``PIL.Image.size``.

    Raises:
        ValueError: If ``data_dir`` holds fewer than two entries, because a
            dissimilar pair could never be drawn.
    """

    def __init__(self, data_dir, img_size=(256, 256)):
        # Join with the directory that was actually listed (not a fixed
        # config path); sort so the index -> file mapping is reproducible.
        self.img_paths = [os.path.join(data_dir, name)
                          for name in sorted(os.listdir(data_dir))]
        if len(self.img_paths) < 2:
            raise ValueError(
                "MadoriSiameseDataset needs at least two images in "
                "{!r}, found {}".format(data_dir, len(self.img_paths)))
        self.img_size = tuple(img_size)

    def __len__(self):
        return len(self.img_paths)

    @staticmethod
    def _fit_size(w, h, img_size):
        """Return the largest ``(w, h)`` fitting in ``img_size`` at the same aspect ratio."""
        target_w, target_h = img_size
        # Cross-multiplied aspect-ratio comparison: true when the height is
        # the side that reaches its target first.
        if w * target_h < h * target_w:
            return max(1, int(w * (target_h / h))), target_h
        return target_w, max(1, int(h * (target_w / w)))

    def _resize(self, img):
        """Scale ``img`` so it fits inside ``self.img_size``."""
        w, h = img.size
        return img.resize(self._fit_size(w, h, self.img_size), Image.BILINEAR)

    def _pad(self, img):
        """Edge-pad ``img`` on the right and bottom up to ``self.img_size``."""
        w, h = img.size
        target_w, target_h = self.img_size
        # Padding order is (left, top, right, bottom).
        return TF.pad(img, (0, 0, target_w - w, target_h - h),
                      padding_mode='edge')

    def _transform(self, img):
        """Resize then pad ``img`` to exactly ``self.img_size``."""
        return self._pad(self._resize(img))

    def _load(self, path):
        """Open ``path`` as grayscale and bring it to the target size."""
        # The context manager releases the file handle right away instead of
        # waiting for garbage collection.
        with Image.open(path) as handle:
            return self._transform(handle.convert('L'))

    def _aug_img(self, image):
        """Return ``image`` after independent random rotation and flips.

        Each of the three steps is applied with probability 0.5, so the
        result can also be the unchanged image.
        """
        if random.random() > 0.5:
            image = TF.rotate(image, random.choice([90, 180, 270]))
        if random.random() > 0.5:
            image = TF.hflip(image)
        if random.random() > 0.5:
            image = TF.vflip(image)
        return image

    def __getitem__(self, idx):
        """Return ``(img1, img2, label)`` with ``label`` a float32 tensor of shape (1,)."""
        anchor_path = self.img_paths[idx]
        img1 = self._load(anchor_path)
        label = random.randint(0, 1)
        if label:
            # Dissimilar pair: draw until a different file comes up. This
            # terminates because __init__ guarantees at least two paths.
            other_path = anchor_path
            while other_path == anchor_path:
                other_path = random.choice(self.img_paths)
            img2 = self._load(other_path)
        else:
            # Similar pair: an augmented view of the same image.
            img2 = self._aug_img(img1)
        return (TF.to_tensor(img1), TF.to_tensor(img2),
                torch.tensor([label], dtype=torch.float32))

## Visualising some of the data
Each pair is displayed side by side, with its label printed above it.
1 indicates dissimilar, and 0 indicates similar.

In [None]:
# Loader over the training pairs, used here only to preview a few samples.
dset_loader = DataLoader(MadoriSiameseDataset(Config.training_dir),
                         batch_size=Config.train_batch_size,
                         shuffle=True)

# Only the first batch is needed for the preview.
img1, img2, label = next(iter(dset_loader))
print(img1.size(), img2.size())
# A small dataset can yield a batch with fewer than three pairs.
for k in range(min(3, len(label))):
    print(label[k])
    imshow(img1[k], img2[k])

## Contrastive Loss

In [None]:
class ContrastiveLoss(torch.nn.Module):
    """Contrastive loss over pairs of embeddings.

    Pairs labelled 0 are penalised by their squared distance; pairs labelled
    1 are penalised by the squared amount by which their distance falls
    short of ``margin``. The result is the mean over all elements.

    Reference: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
    """

    def __init__(self, margin=2.0):
        super().__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        """Return the scalar loss for a batch of embedding pairs.

        Args:
            output1: Embeddings of the first images.
            output2: Embeddings of the second images.
            label: Per-pair labels, 0 for similar and 1 for dissimilar.
        """
        # keepdim=True keeps a trailing singleton axis on the distances.
        distance = F.pairwise_distance(output1, output2, keepdim=True)
        pull_term = (1 - label) * distance.pow(2)
        # Dissimilar pairs stop contributing once they are `margin` apart.
        shortfall = (self.margin - distance).clamp(min=0.0)
        push_term = label * shortfall.pow(2)
        return (pull_term + push_term).mean()

## Training Time!

In [None]:
# Shuffled mini-batches of training pairs for the optimisation loop.
train_dataloader = DataLoader(MadoriSiameseDataset(Config.training_dir),
                              shuffle=True,
                              batch_size=Config.train_batch_size)

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# Model, loss and optimiser used for training.
net = SiameseNetwork()
net = net.to(device)
criterion = ContrastiveLoss()
optimizer = optim.Adam(params=net.parameters(), lr=5e-4)

In [None]:
# Loss values collected during training, one entry appended per epoch.
loss_history = [] 

In [None]:
# Make sure train-time behaviour of the network's layers is active.
net.train()
for epoch in tqdm(range(Config.train_number_epochs)):
    epoch_loss = 0.0
    num_batches = 0
    for img0, img1, label in train_dataloader:
        img0, img1, label = img0.to(device), img1.to(device), label.to(device)
        optimizer.zero_grad()
        output1, output2 = net(img0, img1)
        loss_contrastive = criterion(output1, output2, label)
        loss_contrastive.backward()
        optimizer.step()
        epoch_loss += loss_contrastive.item()
        num_batches += 1
    # Report the mean over the epoch rather than the last mini-batch only:
    # it is far less noisy and stays defined even if no batch was produced.
    mean_loss = epoch_loss / max(num_batches, 1)
    print("Epoch number {}\n Current loss {}\n".format(epoch, mean_loss))
    loss_history.append(mean_loss)

In [None]:
# One x position per recorded loss value.
show_plot(list(range(len(loss_history))), loss_history)

## Some simple testing
The last 3 subjects were held out from the training, and will be used to test. The distance between each image pair denotes the degree of similarity the model found between the two images: lower values mean the model found the images more similar, while higher values mean it found them more dissimilar.

In [None]:
# Evaluate on the held-out directory with the same preprocessing that the
# training pairs went through.
siamese_dataset = MadoriSiameseDataset(Config.testing_dir)

test_dataloader = DataLoader(siamese_dataset, num_workers=6, batch_size=1, shuffle=True)
dataiter = iter(test_dataloader)
# The first image of the first pair is the reference for every comparison.
x0, _, _ = next(dataiter)

net.eval()
# Inference only: skip gradient bookkeeping.
with torch.no_grad():
    for i in range(10):
        batch = next(dataiter, None)
        if batch is None:
            # Fewer than eleven samples in the test set.
            break
        _, x1, label2 = batch

        output1, output2 = net(x0.to(device), x1.to(device))
        euclidean_distance = F.pairwise_distance(output1, output2)
        print('Dissimilarity: {:.2f}'.format(euclidean_distance.item()))
        # batch_size is 1, so index 0 selects the single image of each batch.
        imshow(x0[0], x1[0])

