# One Shot Learning with Siamese Networks

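This is the Jupyter notebook that accompanies **TODO: add the reference/link that was cut off here**. It trains a Siamese network on feature maps from a pretrained ResNet-18 and shows how (dis)similar it rates pairs of images.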

## Imports
All the imports are defined here

In [2]:
%matplotlib inline
import torchvision
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader,Dataset
import matplotlib.pyplot as plt
import torchvision.utils
import numpy as np
import random
from PIL import Image
import torch
from torch.autograd import Variable
import PIL.ImageOps    
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torchsummary import summary
import torchvision.models as models
import torchvision.transforms as transforms
from torch.autograd import Variable
from torchsummary import summary
import img_to_vec

## Helper functions
Helper functions for displaying images and plotting the training loss.

In [None]:
def imshow(img,text=None,should_save=False):
    """Display a (channels, height, width) image tensor, optionally captioned.

    Args:
        img: CPU tensor in CHW layout, e.g. the result of
            ``torchvision.utils.make_grid``; ``.numpy()`` requires it on the CPU.
        text: optional caption, drawn in a white box near the top-left.
        should_save: accepted for compatibility but currently ignored.
    """
    # matplotlib expects HWC, so move the channel axis last.
    hwc = np.transpose(img.numpy(), (1, 2, 0))
    plt.axis("off")
    if text:
        caption_box = {'facecolor': 'white', 'alpha': 0.8, 'pad': 10}
        plt.text(75, 8, text, style='italic', fontweight='bold', bbox=caption_box)
    plt.imshow(hwc)
    plt.show()

def show_plot(iteration,loss):
    """Plot logged loss values against the iteration at which they were logged.

    Uses an explicit figure/axes pair and labels the axes so the figure stands
    on its own when the notebook is skimmed.

    Args:
        iteration: x values (iteration counts of each log entry).
        loss: y values (logged loss values); same length as ``iteration``.
    """
    _, ax = plt.subplots()
    ax.plot(iteration, loss)
    ax.set(xlabel="iteration", ylabel="loss", title="Training loss")
    plt.show()

## Configuration Class
A simple class that holds the dataset paths and training hyperparameters.

In [3]:
class Config:
    """Central place for dataset locations and training hyperparameters."""

    # Dataset folders, each read with torchvision's ImageFolder.
    training_dir = "data/train_2/"
    testing_dir = "data/test/"
    custom_test_dir = "data/custom_test/"

    # Training hyperparameters.
    train_batch_size = 32
    train_number_epochs = 25

In [4]:
# Peek at the custom-test folder. A dedicated name avoids clobbering
# ``folder_dataset``, which the training cells later bind to the training
# ImageFolder (re-running this cell would otherwise swap the training data).
custom_test_folder = dset.ImageFolder(root=Config.custom_test_dir)
custom_test_folder.imgs

[('data/custom_test/0\\BAG4551_T00238_2009_12_15_090191f680fef1a1-1.Jpeg', 0),
 ('data/custom_test/1\\BAH2014_T00238_2016_01_29_090191f691c95307-1.Jpeg', 1)]

In [5]:
# ImageNet-pretrained ResNet-18, used only as a feature extractor via a forward
# hook (see get_vector); its parameters are not passed to the optimizer below.
# NOTE(review): newer torchvision versions deprecate `pretrained=` in favour of `weights=`.
model = models.resnet18(pretrained=True)

In [6]:
# Display the ResNet-18 architecture.
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Co

In [8]:
# Layer whose activations become the Siamese network's input. SiameseNetwork's
# first conv expects 128 channels and its first Linear layer expects 28x28 maps,
# which is ResNet-18 `layer2` output for 224x224 images. `layer3` (256 channels,
# 14x14) does not fit and makes training fail with a channel mismatch.
layer = model.layer2
model.cuda()
model.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Co

In [9]:
# Per-layer output shapes and parameter counts for a single 3x224x224 input.
summary(model, (3,224,224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]          36,864
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]          36,864
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
       BasicBlock-11           [-1, 64, 56, 56]               0
           Conv2d-12           [-1, 64, 56, 56]          36,864
      BatchNorm2d-13           [-1, 64, 56, 56]             128
             ReLU-14           [-1, 64,

In [None]:
# Preprocessing for the pretrained ResNet: resize to 224x224, convert to a
# tensor, then normalise with the ImageNet channel statistics.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]
to_tensor = transforms.ToTensor()
scaler = transforms.Resize((224, 224))
normalize = transforms.Normalize(mean=imagenet_mean, std=imagenet_std)

In [None]:
def get_vector(image_name, is_path=True):
    """Return the activations of the hooked ResNet ``layer`` for one image.

    A forward hook is temporarily registered on the global ``layer``. The full
    global ``model`` is run on the image, and the hook captures that layer's output.

    Args:
        image_name: path to an image file when ``is_path`` is True. Otherwise
            an already-resized image tensor that only needs normalising
            (assumed (3, H, W) float in [0, 1] -- TODO confirm with callers).
        is_path: selects how ``image_name`` is interpreted.

    Returns:
        A one-element list holding the captured activation tensor, detached
        from autograd and with a leading batch dimension of 1.
    """
    if is_path:
        # The context manager closes the file handle. convert("RGB") guarantees
        # the 3 channels that `normalize` and the ResNet stem expect, even for
        # grayscale, RGBA or palette images.
        with Image.open(image_name) as img:
            img = img.convert("RGB")
        t_img = normalize(to_tensor(scaler(img))).unsqueeze(0)
    else:
        # The original passed an undefined name `dt` here (NameError).
        t_img = normalize(image_name).unsqueeze(0)
    # Run on whichever device the model lives on (cuda in this notebook).
    t_img = t_img.to(next(model.parameters()).device)

    image_embedding = []

    def copy_data(module, inputs, output):
        image_embedding.append(output.data)

    handle = layer.register_forward_hook(copy_data)
    try:
        # Only the activations are needed, so skip building the autograd graph.
        with torch.no_grad():
            model(t_img)
    finally:
        # Always detach the hook, even if the forward pass raises.
        handle.remove()
    return image_embedding

## Custom Dataset Class
This dataset generates pairs of images together with their ResNet feature maps. The label is 0 for a genuine (same-class) pair and 1 for an imposter (different-class) pair.

In [None]:
class SiameseNetworkDataset(Dataset):
    """Dataset of image pairs plus their ResNet feature maps.

    Each item is ``(img0, img1, img0_vec, img1_vec, label)``:

    * ``img0`` / ``img1``: the two images converted to grayscale, optionally
      inverted, then passed through ``transform`` (used for display).
    * ``img0_vec`` / ``img1_vec``: the activations returned by the global
      ``get_vector`` helper for the two *original image files*, with size-1
      dimensions squeezed out. NOTE(review): ``transform`` augmentations are
      therefore not applied to these network inputs.
    * ``label``: float32 tensor of shape (1,). It is 0.0 for a genuine
      (same-class) pair and 1.0 for an imposter (different-class) pair.

    ``index`` is ignored by ``__getitem__``. Pairs are drawn at random, or are
    fixed when ``is_custom_test`` is set.

    Args:
        imageFolderDataset: object exposing ``imgs``, a list of
            ``(path, class_index)`` tuples (e.g. torchvision ``ImageFolder``).
        transform: optional callable applied to both grayscale PIL images.
        should_invert: invert both images before ``transform``.
        is_test: if True, the pair type is fixed by ``pick_similar_samples``
            instead of being chosen at random (~50/50).
        pick_similar_samples: in test mode, True gives same-class pairs and
            False gives different-class pairs.
        is_custom_test: always pair ``imgs[0]`` with ``imgs[1]``.
    """

    def __init__(self, imageFolderDataset, transform=None, should_invert=True,
                 is_test=False, pick_similar_samples=True,
                 is_custom_test=False):
        self.imageFolderDataset = imageFolderDataset
        self.transform = transform
        self.should_invert = should_invert
        self.is_test = is_test
        self.pick_similar_samples = pick_similar_samples
        self.is_custom_test = is_custom_test

    def _pick_pair(self):
        """Return the ``(img0_tuple, img1_tuple)`` entries of ``imgs`` forming one pair.

        Raises:
            ValueError: if a different-class pair is requested but every image
                belongs to the same class. The original sampling loop would
                spin forever in that case.
        """
        imgs = self.imageFolderDataset.imgs
        if self.is_custom_test:
            return imgs[0], imgs[1]

        img0_tuple = random.choice(imgs)
        if self.is_test:
            want_same_class = bool(self.pick_similar_samples)
        else:
            # Training: make roughly 50% of the pairs genuine.
            want_same_class = random.randint(0, 1) == 1

        if not want_same_class and all(label == img0_tuple[1] for _, label in imgs):
            raise ValueError(
                "cannot build a different-class pair: all images share class {!r}".format(img0_tuple[1]))

        # Rejection sampling: keep drawing until the class relation matches.
        # Same-class always terminates (img0 itself qualifies). The check above
        # guarantees that a different-class candidate exists otherwise.
        while True:
            img1_tuple = random.choice(imgs)
            if (img1_tuple[1] == img0_tuple[1]) == want_same_class:
                return img0_tuple, img1_tuple

    @staticmethod
    def _load_gray(path):
        """Open ``path`` as a grayscale ("L") PIL image, closing the file handle."""
        with Image.open(path) as img:
            return img.convert("L")

    def __getitem__(self, index):
        img0_tuple, img1_tuple = self._pick_pair()

        img0 = self._load_gray(img0_tuple[0])
        img1 = self._load_gray(img1_tuple[0])

        if self.should_invert:
            img0 = PIL.ImageOps.invert(img0)
            img1 = PIL.ImageOps.invert(img1)

        if self.transform is not None:
            img0 = self.transform(img0)
            img1 = self.transform(img1)

        # Network inputs: ResNet activations of the untransformed files.
        img0_vec = get_vector(img0_tuple[0], is_path=True)[0]
        img1_vec = get_vector(img1_tuple[0], is_path=True)[0]
        img0_vec.squeeze_()
        img1_vec.squeeze_()

        label = torch.from_numpy(np.array([int(img1_tuple[1] != img0_tuple[1])], dtype=np.float32))
        return img0, img1, img0_vec, img1_vec, label

    def __len__(self):
        return len(self.imageFolderDataset.imgs)

## Using Image Folder Dataset

In [None]:
# Training images, one sub-folder per class.
folder_dataset = dset.ImageFolder(Config.training_dir)
folder_dataset

In [None]:
# transforms.RandomAffine([5,7,10,13,15], translate=(5,15)),
#transforms.RandomHorizontalFlip(p=0.5),

In [None]:
# Augmentations for the displayed images.
# NOTE(review): the feature maps fed to the network are computed by get_vector
# from the raw files, so these transforms do not reach the training inputs.
# NOTE(review): translate=(0.1, 0.95) allows vertical shifts of up to 95% of
# the image height. Confirm this is intended.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ColorJitter(brightness=0.1, contrast=0.1),
    transforms.Grayscale(num_output_channels=1),
    transforms.RandomRotation([0, 75]),
    transforms.RandomAffine([0, 20], translate=(0.1, 0.95), scale=(0.5, 2), shear=5),
    transforms.ToTensor(),
])
siamese_dataset = SiameseNetworkDataset(imageFolderDataset=folder_dataset,
                                        transform=train_transform,
                                        should_invert=False)

## Visualising some of the data
The top and bottom images in each column form one pair. The printed 0s and 1s are the labels of the corresponding columns:
1 indicates dissimilar, and 0 indicates similar.

In [None]:
# vis_dataloader = DataLoader(siamese_dataset,
#                         shuffle=True,
#                         num_workers=0,
#                         batch_size=8)
# dataiter = iter(vis_dataloader)


# example_batch = next(dataiter)
# # print(example_batch[0].shape)
# concatenated = torch.cat((example_batch[0],example_batch[1]),0)
# print(concatenated.size())
# #imshow(torchvision.utils.make_grid(concatenated))
# print(example_batch[2].numpy())

## Neural Net Definition
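We use a small convolutional network that takes ResNet-18 feature maps (not raw pixels) as input. It has three padded 3x3 convolution layers followed by three fully connected layers that produce the embedding.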

In [None]:
class SiameseNetwork(nn.Module):
    """Twin network that embeds a pair of feature maps with shared weights.

    Both inputs go through the same layers (``forward_once``). Each input is
    a batch of shape ``(batch, in_channels, feature_size, feature_size)``. The
    padded 3x3 convolutions preserve the spatial size, so ``feature_size``
    fixes the width of the first Linear layer. The defaults (128, 28) match
    ResNet-18 ``layer2`` activations for 224x224 images.

    Args:
        in_channels: channel count of the incoming feature maps.
        feature_size: spatial height and width of the incoming feature maps.
        embedding_dim: size of the output embedding.
    """
    def __init__(self, in_channels=128, feature_size=28, embedding_dim=5):
        super(SiameseNetwork, self).__init__()
        self.cnn1 = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_channels, 32, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(32),

            nn.ReflectionPad2d(1),
            nn.Conv2d(32, 16, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(16),

            nn.ReflectionPad2d(1),
            nn.Conv2d(16, 16, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(16),
        )

        self.fc1 = nn.Sequential(
            nn.Linear(16 * feature_size * feature_size, 500),
            nn.ReLU(inplace=True),

            nn.Linear(500, 500),
            nn.ReLU(inplace=True),

            nn.Linear(500, embedding_dim))

    def forward_once(self, x):
        """Embed one batch of feature maps into ``(batch, embedding_dim)``."""
        output = self.cnn1(x)
        output = output.view(output.size(0), -1)
        return self.fc1(output)

    def forward(self, input1, input2):
        """Embed both inputs with the shared weights; return the two embeddings."""
        output1 = self.forward_once(input1)
        output2 = self.forward_once(input2)
        return output1, output2

## Contrastive Loss

In [None]:
class ContrastiveLoss(torch.nn.Module):
    """
    Contrastive loss function.
    Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf

    loss = mean((1 - y) * d^2 + y * max(margin - d, 0)^2)

    d is the Euclidean distance between the two embeddings. y is the label:
    0 for a similar (genuine) pair and 1 for a dissimilar (imposter) pair.

    Args:
        margin: dissimilar pairs closer than this distance are penalised.
    """

    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        euclidean_distance = F.pairwise_distance(output1, output2)
        # The DataLoader yields labels as (batch, 1), while pairwise_distance
        # returns (batch,). Without this reshape the two broadcast to a
        # (batch, batch) matrix and every label is paired with every distance.
        label = label.reshape(euclidean_distance.shape)
        loss_contrastive = torch.mean((1 - label) * torch.pow(euclidean_distance, 2) +
                                      label * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))
        return loss_contrastive

## Training Time!

In [None]:
# num_workers=0 keeps loading in the main process: __getitem__ runs the CUDA
# ResNet feature extractor, and CUDA cannot be re-initialised in forked workers.
train_dataloader = DataLoader(
    siamese_dataset,
    batch_size=Config.train_batch_size,
    shuffle=True,
    num_workers=0,
)

In [None]:
# Model, loss and optimiser. Only the Siamese head is optimised; the ResNet
# feature extractor stays untouched.
criterion = ContrastiveLoss()
net = SiameseNetwork().cuda()
optimizer = optim.Adam(net.parameters(), lr=5e-4)

In [None]:
from torchsummary import summary  # NOTE(review): already imported in the imports cell
# SiameseNetwork consumes (128, 28, 28) ResNet feature maps, not 224x224 images:
# summary(net,(128,28,28),device="cuda")

In [None]:
# Loss-curve bookkeeping: x values, y values, and the running iteration count.
counter, loss_history = [], []
iteration_number = 0

In [None]:
# Training loop. The first two items of each batch are the display images and
# are unused; the network is trained on the ResNet feature maps.
# The evaluation cells call net.eval(), so switch back to training mode in case
# this cell is re-run after them (BatchNorm must use batch statistics here).
net.train()
for epoch in range(Config.train_number_epochs):
    for i, data in enumerate(train_dataloader):
        _img0, _img1, feat0, feat1, label = data
        feat0, feat1, label = feat0.cuda(), feat1.cuda(), label.cuda()

        optimizer.zero_grad()
        output1, output2 = net(feat0, feat1)
        loss_contrastive = criterion(output1, output2, label)
        loss_contrastive.backward()
        optimizer.step()

        # Log every 10th batch; iteration_number advances by 10 per log entry.
        if i % 10 == 0:
            print("Epoch number {}\n Current loss {}\n".format(epoch, loss_contrastive.item()))
            iteration_number += 10
            counter.append(iteration_number)
            loss_history.append(loss_contrastive.item())
show_plot(counter,loss_history)

## Some simple testing
The last 3 subjects were held out from training and are used for testing. The distance between the two images of each pair denotes how similar the model found them: a lower value means the model found them more similar, while a higher value indicates it found them dissimilar.

In [None]:
# Evaluate the network trained above. Fall back to the checkpoint on disk only
# when no training ran in this kernel. Loading it unconditionally would silently
# discard the freshly trained weights (the checkpoint is written by a later cell).
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
if not loss_history:
    net = torch.load("model")
net.eval()

In [None]:
# Show up to 10 genuine (same-class) test pairs with the model's dissimilarity.
# The 229x229 resize only affects the displayed images; the network inputs
# come from get_vector.
folder_dataset_test = dset.ImageFolder(root=Config.testing_dir)
similar_pairs_dataset = SiameseNetworkDataset(imageFolderDataset=folder_dataset_test,
                                              transform=transforms.Compose([transforms.Resize((229,229)),
                                                                            transforms.Grayscale(num_output_channels=1),
                                                                            transforms.RandomRotation([5,75]),
                                                                            transforms.ToTensor()
                                                                            ]),
                                              should_invert=False,
                                              is_test=True,
                                              pick_similar_samples=True)

test_dataloader = DataLoader(similar_pairs_dataset,num_workers=0,batch_size=1,shuffle=True)

net.eval()
with torch.no_grad():
    # zip stops cleanly (no StopIteration) when the test folder holds fewer
    # than 10 images, since the loader yields one pair per image.
    for pair_idx, (_x0, _x1, x0, x1, _label) in zip(range(10), test_dataloader):
        concatenated = torch.cat((_x0, _x1), 0)
        output1, output2 = net(x0.cuda(), x1.cuda())
        euclidean_distance = F.pairwise_distance(output1, output2)
        imshow(torchvision.utils.make_grid(concatenated),
               'Dissimilarity: {:.2f}'.format(euclidean_distance.item()))



In [None]:
# Show up to 10 imposter (different-class) test pairs with the model's dissimilarity.
folder_dataset_test = dset.ImageFolder(root=Config.testing_dir)
dissimilar_pairs_dataset = SiameseNetworkDataset(imageFolderDataset=folder_dataset_test,
                                                 transform=transforms.Compose([transforms.Resize((229,229)),
                                                                               transforms.Grayscale(num_output_channels=1),
                                                                               transforms.RandomRotation([5,75]),
                                                                               transforms.ToTensor()
                                                                               ]),
                                                 should_invert=False,
                                                 is_test=True,
                                                 pick_similar_samples=False)

test_dataloader = DataLoader(dissimilar_pairs_dataset,num_workers=0,batch_size=1,shuffle=True)

net.eval()
with torch.no_grad():
    # zip stops cleanly (no StopIteration) when the test folder holds fewer
    # than 10 images, since the loader yields one pair per image.
    for pair_idx, (_x0, _x1, x0, x1, _label) in zip(range(10), test_dataloader):
        concatenated = torch.cat((_x0, _x1), 0)
        output1, output2 = net(x0.cuda(), x1.cuda())
        euclidean_distance = F.pairwise_distance(output1, output2)
        imshow(torchvision.utils.make_grid(concatenated),
               'Dissimilarity: {:.2f}'.format(euclidean_distance.item()))


In [None]:
# Compare the two images in Config.custom_test_dir (with is_custom_test=True the
# dataset always pairs imgs[0] with imgs[1]).
folder_dataset_test = dset.ImageFolder(root=Config.custom_test_dir)
custom_pair_dataset = SiameseNetworkDataset(imageFolderDataset=folder_dataset_test,
                                            transform=transforms.Compose([transforms.Resize((224,224)),
                                                                          transforms.Grayscale(num_output_channels=1),
                                                                          transforms.RandomRotation([5,75]),
                                                                          transforms.ToTensor()
                                                                          ]),
                                            should_invert=False,
                                            is_test=True,
                                            pick_similar_samples=False,
                                            is_custom_test=True)

test_dataloader = DataLoader(custom_pair_dataset,num_workers=0,batch_size=1,shuffle=True)

net.eval()
with torch.no_grad():
    for pair_idx, (_x0, _x1, x0, x1, _label) in zip(range(1), test_dataloader):
        concatenated = torch.cat((_x0, _x1), 0)
        output1, output2 = net(x0.cuda(), x1.cuda())
        euclidean_distance = F.pairwise_distance(output1, output2)
        imshow(torchvision.utils.make_grid(concatenated),
               'Dissimilarity: {:.2f}'.format(euclidean_distance.item()))


In [None]:
# Pickle the whole module (class reference + weights) to ./model. Loading it
# requires SiameseNetwork to be defined in the loading session.
# NOTE(review): net.state_dict() would be more portable, but the load cells
# expect a full module object.
torch.save(net, "model")

In [None]:
# Reload the checkpoint only to display it; the result is not assigned.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
torch.load("model")

In [None]:
import time

In [None]:
# Wall-clock timer start (seconds since the epoch).
start = time.time()

In [None]:
# Elapsed seconds since `start` was recorded (despite the name `end`).
end = time.time()-start

In [None]:
# Display the elapsed time in seconds.
end

In [None]:
# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
device