# One Shot Learning with Siamese Networks

[Modified from Harshvardhan Gupta's work](https://github.com/harveyslash/Facial-Similarity-with-Siamese-Networks-in-Pytorch/blob/master/Siamese-networks-medium.ipynb)  

[for the Kaggle whale tail competition](https://www.kaggle.com/c/humpback-whale-identification)

In [None]:
%matplotlib inline
import os
import glob
import random
import numpy as np
import PIL.ImageOps    
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm_notebook as tqdm


import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.autograd import Variable

import torch
import torchvision
import torchvision.utils
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader,Dataset


## Helper functions
Helper functions for displaying an image tensor and for plotting the training loss.

In [None]:
def imshow(img, text=None, should_save=False):
    """Show an image tensor with matplotlib, optionally with a caption.

    Args:
        img: Tensor exposing ``.numpy()``; axis 0 is moved to the end before
            display, so a channels-first (C, H, W) layout is expected.
        text: Optional caption drawn at pixel position (75, 8); skipped when
            falsy.
        should_save: Accepted for interface compatibility; not used.
    """
    pixels = img.numpy()
    plt.axis("off")
    if text:
        caption_box = {'facecolor': 'white', 'alpha': 0.8, 'pad': 10}
        plt.text(75, 8, text, style='italic', fontweight='bold',
                 bbox=caption_box)
    # matplotlib wants channels last, hence the (1, 2, 0) axis order.
    channels_last = np.transpose(pixels, (1, 2, 0))
    plt.imshow(channels_last)
    plt.show()

def show_plot(iteration, loss):
    """Draw ``loss`` against ``iteration`` as a line plot and display it."""
    plt.plot(iteration, loss)
    plt.show()

In [None]:
# Root directory of the whale data. Keep the trailing slash: other cells build
# file names by plain string concatenation, e.g. f'{pather}train.csv'.
pather = '/home/eagle/whale/data/'

In [None]:
# Training labels, the distinct whale Ids as strings, and a working copy of
# the sample submission that is filled in with predictions later.
train_csv = pd.read_csv(f'{pather}train.csv')
classes = [str(whale_id) for whale_id in set(train_csv['Id'])]
samp = pd.read_csv(f'{pather}sample_submission.csv')
sub = samp.copy()

In [None]:
# Sorts the class names and enumerates them.
def find_classes(classes):
    """Sort ``classes`` in place and map every class name to its position.

    Args:
        classes: Mutable list of class names; it is sorted in place.

    Returns:
        Tuple ``(classes, class_to_idx)`` where ``classes`` is the same list
        object (now sorted) and ``class_to_idx`` maps name -> sorted index.
    """
    classes.sort()
    class_to_idx = {name: position for position, name in enumerate(classes)}
    return classes, class_to_idx

def make_dataset(train, csv_file, class_to_idx):
    """Build the list of ``(image path, class index)`` pairs listed in a CSV.

    Args:
        train: Directory that contains the image files.
        csv_file: Path of a CSV file whose first column holds the image file
            name and whose second column holds the class name.
        class_to_idx: Mapping from class name to integer class index.

    Returns:
        List of ``(path, class_index)`` tuples, one per CSV row, in row order.

    Raises:
        KeyError: If a class name in the CSV is missing from ``class_to_idx``.
    """
    trn_csv = pd.read_csv(csv_file)
    # Select the two columns by position with .iloc: integer keys on a
    # string-labelled row rely on a deprecated pandas fallback, and per-row
    # .loc lookups are needlessly slow.
    file_names = trn_csv.iloc[:, 0]
    class_names = trn_csv.iloc[:, 1]
    return [
        (os.path.join(train, file_name), class_to_idx[class_name])
        for file_name, class_name in zip(file_names, class_names)
    ]

In [None]:
#Creates a dataset with the files
class Whale_csv_Dataset(Dataset):
    """Dataset over the images listed in a CSV file, plus a list of test files.

    Attributes:
        imgs: List of ``(image path, class index)`` tuples from ``make_dataset``.
        classes: The class names, sorted by ``find_classes``.
        class_to_idx: Mapping class name -> class index.
        test_path: Sorted list of the paths found directly under the test
            directory (empty when no test directory is given).
        whale_frame: The CSV loaded as a DataFrame.
        train: Directory containing the training images.
        transform: Optional callable applied to each loaded PIL image.
    """

    def __init__(self, csv_file, train, classes, test_path=None, transform=None):
        """Index the CSV and collect the test file paths.

        Args:
            csv_file: CSV with image file names (first column) and class
                names (second column).
            train: Directory containing the images named in ``csv_file``.
            classes: List of class names; it is sorted in place.
            test_path: Optional directory whose entries become ``test_path``.
            transform: Optional callable applied to each image in
                ``__getitem__``.
        """
        classes, class_to_idx = find_classes(classes)

        self.imgs = make_dataset(train, csv_file, class_to_idx)
        self.class_to_idx = class_to_idx

        # Without a test directory there is nothing to glob; formatting None
        # into the pattern would search a directory literally named "None".
        if test_path is None:
            self.test_path = []
        else:
            self.test_path = sorted(glob.glob(f'{test_path}/*'))
        self.whale_frame = pd.read_csv(csv_file)
        self.train = train
        self.classes = classes
        self.transform = transform

    def __len__(self):
        """Return the number of indexed training images."""
        return len(self.imgs)

    def __getitem__(self, idx):
        """Load one training image.

        Args:
            idx: Position in ``imgs``.

        Returns:
            Tuple ``(image, target)``: the RGB PIL image (after ``transform``
            when one is set) and its integer class index.
        """
        path, target = self.imgs[idx]
        image = Image.open(path).convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, target

In [None]:
# Index the training CSV and collect the test image paths in one object.
Trywhale = Whale_csv_Dataset(
    csv_file=f'{pather}train.csv',
    train=f'{pather}train/',
    classes=classes,
    test_path=f'{pather}test/',
)

In [None]:
#Creates a Pandas Dataframe of the set NOT including class 0
# (presumably 'new_whale', which sorts before the 'w_...' ids - verify).
# Column 0 holds the image path, column 1 the class index.
df = pd.DataFrame(Trywhale.imgs)
# One boolean filter instead of dropping rows one at a time.
df = df[df[1] != 0].reset_index(drop=True)
# Map every remaining class index to the row positions of its images; the
# label array is built once rather than once per label.
row_labels = df[1].to_numpy()
label_to_indices = {label: np.flatnonzero(row_labels == label)
                    for label in set(df[1])}

In [None]:
#Selects images for the Siamese network
def grab_two(df):
    """Draw a random image pair for training the Siamese network.

    A class is drawn uniformly from the module-level ``label_to_indices`` and
    one of its images becomes the first element of the pair. When the class
    has more than one image, half of the time the second element is a
    different image of the same class; otherwise it is an image of a
    different, uniformly drawn class.

    Args:
        df: DataFrame whose column 0 is the image path and column 1 the class
            index, indexed by the row positions stored in
            ``label_to_indices``.

    Returns:
        Tuple ``(path0, path1, label0, label1)``.
    """
    # Draw from the actual keys: labels do not have to start at 0 or be
    # contiguous, and every class must be reachable.
    labels = list(label_to_indices)
    label = random.choice(labels)
    members = label_to_indices[label]
    w0 = np.random.choice(members)

    # A positive pair needs a second image, so single-image classes always
    # yield a negative pair; otherwise flip a fair coin.
    want_same = len(members) > 1 and np.random.choice(2) == 0
    if want_same:
        w1 = w0
        # Resample until the partner is not the image itself.
        while w1 == w0:
            w1 = np.random.choice(members)
    elif len(labels) > 1:
        other = label
        # Resample until the partner class really differs.
        while other == label:
            other = random.choice(labels)
        w1 = np.random.choice(label_to_indices[other])
    else:
        # Only one class exists, so no negative pair is possible; fall back
        # to any row.
        w1 = np.random.choice(df.index)

    return df.at[w0, 0], df.at[w1, 0], df.at[w0, 1], df.at[w1, 1]

In [None]:
class SiameseNetworkDataset(Dataset):
    """Dataset of image pairs taken from the module-level ``df``.

    In the default mode every item is a random pair from ``grab_two``. In
    ``show_mode`` item ``index`` is row ``index`` of ``df`` paired with itself.
    """

    def __init__(self, pic_set, transform=None, should_invert=True, show_mode=False):
        """Store the configuration; no files are opened here.

        Args:
            pic_set: Stored on the instance; not read by this class.
            transform: Optional callable applied to each PIL image.
            should_invert: Invert image colours before the transform.
            show_mode: Return the image at ``index`` instead of a random pair.
        """
        self.pic_set = pic_set
        self.transform = transform
        self.should_invert = should_invert
        self.show_mode = show_mode

    def __len__(self):
        """Return the number of rows of the module-level ``df``."""
        return len(df)

    def __getitem__(self, index):
        """Return ``(img0, img1, differ, label_0)``.

        ``differ`` is a float32 tensor of shape (1,): 1.0 when the two labels
        differ and 0.0 when they match. ``label_0`` is the first image's label.
        """
        if not self.show_mode:
            path_a, path_b, label_0, label_1 = grab_two(df)
        else:
            # Prediction mode: one image, duplicated to keep the pair layout.
            path_a, label_0 = df.loc[index]
            path_b, label_1 = path_a, label_0

        pair = [Image.open(path_a), Image.open(path_b)]
        pair = [picture.convert("RGB") for picture in pair]

        if self.should_invert:
            pair = [PIL.ImageOps.invert(picture) for picture in pair]

        if self.transform is not None:
            pair = [self.transform(picture) for picture in pair]

        img0, img1 = pair
        differ = np.array([int(label_0 != label_1)], dtype=np.float32)
        return img0, img1, torch.from_numpy(differ), label_0

## Visualising some of the data
The top row and the bottom row of any column form one pair. The 0s and 1s correspond to the columns of the image.
1 indicates dissimilar, and 0 indicates similar.

In [None]:
# Training dataset of random pairs: each image is resized to 170x300 and
# converted to a tensor; colours are not inverted.
siamese_dataset = SiameseNetworkDataset(
    Trywhale,
    transform=transforms.Compose([
        transforms.Resize((170, 300)),
        transforms.ToTensor(),
    ]),
    should_invert=False,
)

In [None]:
# Pull a single pair through a DataLoader and show both images side by side.
vis_dataloader = DataLoader(
    siamese_dataset,
    batch_size=1,
    shuffle=True,
    num_workers=8,
)
dataiter = iter(vis_dataloader)

example_batch = next(dataiter)
# Elements 0 and 1 of the batch are the two images; element 2 is the
# similar/dissimilar flag printed underneath.
concatenated = torch.cat((example_batch[0], example_batch[1]), 0)
imshow(torchvision.utils.make_grid(concatenated))
print(example_batch[2].numpy())

## Neural Net Definition
We will use a standard convolutional neural network

In [None]:
class SiameseNetwork(nn.Module):
    """Twin network that maps an RGB image to a single scalar embedding.

    Both inputs of a pair go through the same convolutional stack and the
    same fully connected head, so the weights are shared.
    """

    def __init__(self, input_size=(170, 300)):
        """Build the layers.

        Args:
            input_size: ``(height, width)`` of the input images. The default
                matches the 170x300 resize used by the datasets in this file.
        """
        super().__init__()
        height, width = input_size

        # Reflection padding of 1 followed by a 3x3 convolution keeps the
        # spatial size, so the feature map stays height x width throughout.
        self.cnn1 = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(3, 4, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(4),
            nn.ReflectionPad2d(1),
            nn.Conv2d(4, 8, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(8))

        # 8 output channels times the unchanged spatial size.
        self.fc1 = nn.Sequential(
            nn.Linear(8 * height * width, 700),
            nn.ReLU(inplace=True),
            nn.Linear(700, 700),
            nn.ReLU(inplace=True),
            nn.Linear(700, 500),
            nn.ReLU(inplace=True),
            nn.Linear(500, 1))

    def forward_once(self, x):
        """Embed one batch of images: (B, 3, H, W) -> (B, 1)."""
        output = self.cnn1(x)
        output = output.view(output.size(0), -1)
        output = self.fc1(output)
        return output

    def forward(self, input1, input2):
        """Embed both halves of a batch of pairs with the shared weights.

        Returns:
            Tuple ``(output1, output2)`` of embeddings, each of shape (B, 1).
        """
        output1 = self.forward_once(input1)
        output2 = self.forward_once(input2)
        return output1, output2

    def get_embedding(self, x):
        """Return the embedding of ``x``; same computation as ``forward_once``."""
        return self.forward_once(x)

## Contrastive Loss

In [None]:
class ContrastiveLoss(torch.nn.Module):
    """Contrastive loss over pairs of embeddings.

    Similar pairs (label 0) are penalised by their squared distance;
    dissimilar pairs (label 1) by the squared amount by which their distance
    falls short of ``margin``.
    """

    def __init__(self, margin=1):
        """Store the margin that dissimilar pairs should be separated by."""
        super().__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        """Return the mean contrastive loss of a batch.

        Args:
            output1: Embeddings of the first images, shape (B, D).
            output2: Embeddings of the second images, shape (B, D).
            label: One value per pair, 0 for similar and 1 for dissimilar;
                any shape holding B elements, e.g. (B,) or (B, 1).

        Returns:
            Scalar tensor with the loss averaged over the B pairs.
        """
        euclidean_distance = F.pairwise_distance(output1, output2)
        # pairwise_distance returns shape (B,). A (B, 1) label would broadcast
        # against it to (B, B) and mix every pair with every label, so align
        # the label with the distance first.
        label = label.reshape(euclidean_distance.shape)
        similar_term = (1 - label) * torch.pow(euclidean_distance, 2)
        dissimilar_term = label * torch.pow(
            torch.clamp(self.margin - euclidean_distance, min=0.0), 2)
        loss_contrastive = torch.mean(similar_term + dissimilar_term)

        return loss_contrastive

## Training Time!

In [None]:
# Shuffled mini-batches of 4 random pairs, loaded by 4 worker processes.
train_dataloader = DataLoader(
    siamese_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=4,
)

In [None]:
# Model on the GPU, the contrastive criterion, and Adam over all parameters.
net = SiameseNetwork()
net = net.cuda()
criterion = ContrastiveLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)

In [None]:
# Bookkeeping for the loss curve: x values, y values, and the running x value.
counter, loss_history = [], []
iteration_number = 0

In [None]:
# Train for one epoch, capped at batch index 5001, and record the loss curve.
log_every = 25
# Later cells switch the model to eval mode; make sure a re-run trains with
# batch-norm statistics being updated.
net.train()
for epoch in range(1):
    for i, data in enumerate(train_dataloader):
        # The fourth element (label of the first image) is not needed here.
        img0, img1, label, _ = data
        img0, img1, label = img0.cuda(), img1.cuda(), label.cuda()
        optimizer.zero_grad()
        output1, output2 = net(img0, img1)
        loss_contrastive = criterion(output1, output2, label)
        loss_contrastive.backward()
        optimizer.step()
        if i % log_every == 0:
            current_loss = loss_contrastive.item()
            # Report the epoch as the epoch and the batch index separately.
            print("Epoch number {}\n Iteration {}\n Current loss {}\n".format(
                epoch, i, current_loss))
            # Advance the x axis by the real logging interval.
            iteration_number += log_every
            counter.append(iteration_number)
            loss_history.append(current_loss)
        if i > 5000:
            break
show_plot(counter, loss_history)

## Prepare the Test Images for the model to predict

In [None]:
class SiameseNet_TEST_set(Dataset):
    """Dataset over the test image paths stored in ``pic_set.test_path``."""

    def __init__(self, pic_set, transform=None, should_invert=False):
        """Store the configuration; no files are opened here.

        Args:
            pic_set: Object exposing a ``test_path`` sequence of image paths.
            transform: Optional callable applied to each PIL image.
            should_invert: Invert image colours before the transform.
        """
        self.pic_set = pic_set
        self.transform = transform
        self.should_invert = should_invert

    def __len__(self):
        """Return the number of test image paths."""
        return len(self.pic_set.test_path)

    def __getitem__(self, index):
        """Return ``(image, index)`` for the test image at ``index``."""
        picture = Image.open(self.pic_set.test_path[index]).convert("RGB")

        if self.should_invert:
            picture = PIL.ImageOps.invert(picture)

        if self.transform is not None:
            picture = self.transform(picture)

        return picture, index

In [None]:
# Same training images, now in show_mode: one image and its label per index,
# used only to compute the average embedding of each class.
siamese_dataset = SiameseNetworkDataset(
    Trywhale,
    transform=transforms.Compose([
        transforms.Resize((170, 300)),
        transforms.ToTensor(),
    ]),
    should_invert=False,
    show_mode=True,
)

In [None]:
# Test images, preprocessed exactly like the training images.
siamese_testset = SiameseNet_TEST_set(
    Trywhale,
    transform=transforms.Compose([
        transforms.Resize((170, 300)),
        transforms.ToTensor(),
    ]),
    should_invert=False,
)

In [None]:
# Walk the training images one by one, in dataset order.
trained_dataloader = DataLoader(
    siamese_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=6,
)
dataiter = iter(trained_dataloader)

In [None]:
# Embed every training image; df_tr gets one row per image with the scalar
# embedding in column 0 and the class index in column 1.
with torch.no_grad():
    net.eval()
    embedding_rows = []
    # Each batch is (image, image, flag, label); only the first image and
    # the label are used.
    for i, _, _, x in tqdm(dataiter):
        i = i.cuda()
        # .item() extracts the single value of the (1, 1) output directly;
        # float() on a non-0-d NumPy array is deprecated.
        embedding_rows.append((net.get_embedding(i).item(), int(x)))
    # Build the frame once instead of growing it row by row from a dummy row.
    df_tr = pd.DataFrame(embedding_rows, columns=[0, 1])

In [None]:
# Preview: one row per training image of a class > 0, numbered in loading
# order; column 0 is the scalar embedding, column 1 the class index.
df_tr.head()

In [None]:
# Groups the embeddings by class: {class index: [embedding, ...]}
s = (
    df_tr.groupby(1)[0]
    .apply(lambda group: group.tolist())
    .to_dict()
)

In [None]:
# Mean embedding of each class: {class index: average embedding}
avg_emb = {label: np.mean(values) for label, values in s.items()}

In [None]:
# Make a dataframe indexed by the average class embedding (sorted ascending)
# whose single remaining column 0 holds the class index.
# Take the class index from the dict key itself rather than from the
# position in the dict, so gaps in the class numbering cannot shift labels.
avg_emb2class = pd.DataFrame(
    [(int(label), embedding) for label, embedding in avg_emb.items()])
avg_emb2class.sort_values(1, inplace=True)
avg_emb2class.set_index(1, inplace=True)

### Send the test images through the model to get embeddings

In [None]:
# Embed every test image; df_T is indexed by the test image index and holds
# the scalar embedding in column 0.
# Use a separate name for the loader so siamese_testset stays the dataset
# and this cell can be run again.
test_dataloader = DataLoader(siamese_testset, num_workers=6, batch_size=1, shuffle=False)
with torch.no_grad():
    net.eval()
    test_embeddings = {}
    for i, x in tqdm(test_dataloader):
        i = i.cuda()
        # .item() extracts the single value of the (1, 1) output directly;
        # float() on a non-0-d NumPy array is deprecated.
        test_embeddings[int(x)] = net.get_embedding(i).item()
# Build the frame once instead of growing it row by row from a dummy row.
df_T = pd.DataFrame({0: pd.Series(test_embeddings, dtype=float)})

In [None]:
#Function to grab the closest number in an array then remove to find next closest
def closest(num, arr):
    """Remove and return the element of ``arr`` nearest to ``num``.

    Args:
        num: Value to compare against.
        arr: Mutable list of numbers. The returned element is removed from
            it, so repeated calls yield the next-nearest values.

    Returns:
        The element with the smallest absolute difference to ``num``; on a
        tie, the one that comes first in ``arr``.

    Raises:
        IndexError: If ``arr`` is empty.
    """
    if not arr:
        raise IndexError("closest() called with an empty list")
    # Search the list that was passed in, not a module-level variable.
    # min() keeps the first of several equally near values.
    curr = min(arr, key=lambda val: abs(num - val))
    arr.remove(curr)
    return curr

In [None]:
# For every test embedding, write the names of the 4 classes with the nearest
# average embedding, followed by 'new_whale', into the submission.
# Invert the name -> index mapping once instead of scanning it per prediction.
idx_to_class = {idx: name for name, idx in Trywhale.class_to_idx.items()}
class_embeddings = avg_emb2class.index.to_numpy()
class_labels = avg_emb2class[0].to_numpy()
# NOTE(review): assumes the rows of `sub` are in the same order as the sorted
# test file list behind df_T - confirm against the sample submission.
for aa, i in enumerate(tqdm(df_T[0])):
    # A stable sort keeps the earlier class on equal distances, the same
    # order as repeatedly taking the first nearest remaining value.
    nearest = np.argsort(np.abs(class_embeddings - i), kind="stable")[:4]
    names = [str(idx_to_class[int(class_labels[pos])]) for pos in nearest]
    sub.at[aa, 'Id'] = ' '.join(names + ['new_whale'])

In [None]:
# Display the filled-in submission table for a visual check.
sub

In [None]:
# Write the submission into the data directory; index=False keeps the
# DataFrame row index out of the file.
sub.to_csv(f'{pather}DEC27_submission.csv', index=False)