https://github.com/harveyslash/Facial-Similarity-with-Siamese-Networks-in-Pytorch

In [None]:
from IPython.core.interactiveshell import InteractiveShell

# Display the value of every top-level expression in a cell, not only the last one.
InteractiveShell.ast_node_interactivity = "all"

In [None]:
%matplotlib inline
import random
import matplotlib.pyplot as plt
from PIL import Image
import PIL.ImageOps    
from tqdm.auto import tqdm

import numpy as np
import pandas as pd

import torch
import torchvision
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader,Dataset
import torchvision.utils
from torch.autograd import Variable
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

from transformers import AutoImageProcessor, ResNetForImageClassification

# Configuration Class

In [None]:
class Config:
    """Static settings shared by the notebook cells."""

    # Dataset locations (training and testing currently read the same folder).
    training_dir = "./data/DigiFace1M"
    testing_dir = "./data/DigiFace1M"

    # Optimisation schedule.
    train_batch_size = 64
    train_number_epochs = 100

    # Per-channel statistics handed to the Normalize transform.
    image_mean = np.array([0.485, 0.456, 0.406])
    image_std = np.array([0.229, 0.224, 0.225])

    # Logging interval (in batches) and output locations.
    save_freq = 10
    save_prefix = "runs/overfit_test_0"
    save_path = f"{save_prefix}/overfit_100epochs.pth"


# Helper functions
Set of helper functions

In [None]:
def imshow(img,text=None,should_save=False):
    """Show an image with matplotlib, optionally with a caption box.

    img: anything np.asarray can convert into an array plt.imshow accepts.
    text: optional caption, drawn at coordinates (75, 8).
    should_save: accepted for compatibility; not used by this function.
    """
    pixels = np.asarray(img)
    plt.axis("off")
    if text:
        caption_box = {'facecolor':'white', 'alpha':0.8, 'pad':10}
        plt.text(75, 8, text, style='italic', fontweight='bold', bbox=caption_box)
    plt.imshow(pixels)
    plt.show()

def show_plot(iteration,loss):
    """Plot `loss` against `iteration` as a line chart and display it."""
    plt.plot(iteration,loss)
    plt.show()

In [None]:
def prep_images(img_path, transform=None):
    """Open the image at `img_path` as RGB and apply `transform` when given.

    Returns the transform's output, or the RGB PIL image itself when
    `transform` is falsy.
    """
    rgb = Image.open(img_path).convert("RGB")
    return transform(rgb) if transform else rgb

# Model

In [None]:
class ContrastiveLoss(torch.nn.Module):
    """
    Contrastive loss over pairs of embeddings.
    Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf

    With d the pairwise Euclidean distance of a pair, each pair contributes
    (1 - label) * d^2 + label * max(margin - d, 0)^2; the result is the mean
    over all elements.
    """

    def __init__(self, margin=2.0):
        super().__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        distance = F.pairwise_distance(output1, output2, keepdim=True)
        # label == 0 term: penalise any distance between the two embeddings.
        attract = (1 - label) * distance.pow(2)
        # label == 1 term: penalise only distances that fall short of the margin.
        shortfall = torch.clamp(self.margin - distance, min=0.0)
        repel = label * shortfall.pow(2)
        return (attract + repel).mean()

In [None]:
class SiameseNetwork(nn.Module):
    """Siamese embedding model: one shared ResNet backbone plus an MLP head.

    Both inputs pass through the same weights; the head ends in a
    32-unit linear layer.
    """

    def __init__(self, model_path = "microsoft/resnet-50"):
        super().__init__()
        # Load the classification checkpoint and keep only its `resnet` backbone.
        classifier = ResNetForImageClassification.from_pretrained(model_path)
        self.cnn = classifier.resnet

        head_layers = [
            nn.Linear(2048, 500),
            nn.ReLU(inplace=True),
            nn.Linear(500, 500),
            nn.ReLU(inplace=True),
            nn.Linear(500, 32),
        ]
        self.fc1 = nn.Sequential(*head_layers)

    def forward_once(self, x):
        """Embed one batch: backbone pooled output -> flatten -> MLP head."""
        pooled = self.cnn(x).pooler_output
        flat = pooled.view(pooled.size(0), -1)
        return self.fc1(flat)

    def forward(self, input1, input2):
        """Embed both inputs with the shared weights and return the two results."""
        return self.forward_once(input1), self.forward_once(input2)

# Dataset
Custom Dataset Class

This dataset generates pairs of images, labelled 0 for a genuine pair (same identity) and 1 for an impostor pair (different identities).


In [None]:

class SiameseNetworkDataset(Dataset):
    """Dataset that yields random image pairs for Siamese training.

    Each item is ``((path0, img0), (path1, img1), label)`` where ``label`` is
    a float32 tensor of shape ``(1,)``: 0 for a same-class pair and 1 for a
    different-class pair. A fresh pair is drawn on every access, so ``index``
    is ignored. All randomness comes from the standard ``random`` module.

    Args:
        imageFolderDataset: object exposing ``imgs``, a sequence of
            ``(image_path, class_id)`` tuples (e.g. torchvision ImageFolder).
        transform: optional callable applied to each PIL image.
        should_invert: if true, invert both images before the transform.

    Raises:
        ValueError: from ``__getitem__`` when a same-class pair is requested
            for a class with fewer than two images, or when fewer than two
            classes exist for a different-class pair.
    """

    def __init__(self,imageFolderDataset,transform=None,should_invert=False):
        self.imageFolderDataset = imageFolderDataset
        self.transform = transform
        self.should_invert = should_invert

        images, ids = list(zip(*self.imageFolderDataset.imgs))
        self.data_df = pd.DataFrame({'img':images, 'ids':ids})

        # Index the paths by class once here, so __getitem__ does not have to
        # rescan the whole table (unique() + boolean masks) for every sample.
        self._paths_by_id = {}
        for path, class_id in zip(images, ids):
            self._paths_by_id.setdefault(class_id, []).append(path)
        self._class_ids = list(self._paths_by_id)

    def __getitem__(self,index):
        # Roughly half of the pairs come from a single class.
        should_get_same_class = random.randint(0,1)
        if should_get_same_class:
            class_id = random.choice(self._class_ids)
            # Two distinct images of the same class.
            path0, path1 = random.sample(self._paths_by_id[class_id], 2)
        else:
            # One image from each of two distinct classes.
            class_id0, class_id1 = random.sample(self._class_ids, 2)
            path0 = random.choice(self._paths_by_id[class_id0])
            path1 = random.choice(self._paths_by_id[class_id1])

        img0 = Image.open(path0).convert("RGB")
        img1 = Image.open(path1).convert("RGB")

        if self.should_invert:
            img0 = PIL.ImageOps.invert(img0)
            img1 = PIL.ImageOps.invert(img1)

        if self.transform is not None:
            img0 = self.transform(img0)
            img1 = self.transform(img1)

        # Label convention: 0 = same class, 1 = different classes.
        label = torch.from_numpy(np.array([1-int(should_get_same_class)],dtype=np.float32))
        return (path0,img0), (path1,img1), label

    def __len__(self):
        return len(self.imageFolderDataset.imgs)
    

# Data Visualization

In [None]:
# Index the training images under Config.training_dir; ImageFolder assigns one class per sub-directory.
folder_dataset = dset.ImageFolder(root=Config.training_dir)

In [None]:
# Forward pipeline: PIL image -> float tensor -> per-channel (x - mean) / std.
transform = transforms.Compose([transforms.ToTensor(),
                               transforms.Normalize(
                                    mean=Config.image_mean,
                                    std=Config.image_std)])
# Inverse of the normalisation above, used only for display:
# Normalize(mean=-m/s, std=1/s) computes (y + m/s) * s = y * s + m.
inv_normalize = transforms.Normalize(
    mean=-Config.image_mean/Config.image_std,
    std=1/Config.image_std
)

In [None]:
# Pair dataset over the training images, with colour inversion disabled.
siamese_dataset = SiameseNetworkDataset(imageFolderDataset=folder_dataset, transform = transform,should_invert=False)

In [None]:
# x = siamese_dataset[0]
# x

In [None]:
# x[0][1].shape
# imshow(inv_normalize(x[0][1]).permute(1,2,0))
# imshow(inv_normalize(x[1][1]).permute(1,2,0))

# Training

In [None]:
# Prefer the second GPU when the host has one; otherwise fall back to the
# first GPU, or to the CPU when CUDA is unavailable, instead of failing later.
if torch.cuda.is_available():
    device = "cuda:1" if torch.cuda.device_count() > 1 else "cuda:0"
else:
    device = "cpu"

In [None]:
# Batches of Config.train_batch_size pairs, loaded in the main process (num_workers=0).
# NOTE(review): the dataset draws a random pair regardless of the index it is given,
# so shuffle=True does not influence which pairs are produced.
train_dataloader = DataLoader(siamese_dataset,
                        shuffle=True,
                        pin_memory=False,
                        num_workers=0,
                        batch_size=Config.train_batch_size)

In [None]:
# Build the Siamese model from its default checkpoint and move it to the chosen device.
net = SiameseNetwork().to(device)

In [None]:
# Contrastive loss with its default margin; Adam updates every parameter of `net`
# (backbone and head alike) with a learning rate of 5e-4.
criterion = ContrastiveLoss()
optimizer = optim.Adam(net.parameters(),lr = 0.0005 )

In [None]:
save_path = Config.save_path

# The log directory lives on Config; there is no module-level `save_prefix`.
writer = SummaryWriter(Config.save_prefix)
batches_per_epoch = len(train_dataloader)

# NOTE(review): net.train() is never called in this notebook; confirm that the
# mode the pretrained backbone is loaded in is the one intended for training.

try:
    for epoch in range(Config.train_number_epochs):
        pbar = tqdm(train_dataloader)
        epoch_loss = 0.0
        # Restart the logging window every epoch so leftover batches from the
        # previous epoch are not averaged into the first window of this one.
        running_loss = 0.0
        for i, data in enumerate(pbar):
            # Each side of the pair arrives as (paths, image batch); paths are unused here.
            (_,img0), (_,img1), label = data
            img0, img1, label = img0.to(device), img1.to(device), label.to(device)

            optimizer.zero_grad()
            output1, output2 = net(img0, img1)
            loss_contrastive = criterion(output1, output2, label)
            loss_contrastive.backward()
            optimizer.step()

            batch_loss = loss_contrastive.item()
            epoch_loss += batch_loss
            running_loss += batch_loss

            if (i+1) % Config.save_freq == 0:
                # x-axis: integer count of batches processed since training began.
                global_step = epoch * batches_per_epoch + i + 1
                writer.add_scalar('training loss', running_loss/Config.save_freq, global_step)
                running_loss = 0.0

            # Running mean of the loss over the batches seen so far in this epoch.
            pbar.set_description(f"Epoch {epoch}: loss : {epoch_loss/(i+1):.5f}")
finally:
    # Flush and close the event file even if training is interrupted.
    writer.close()

In [None]:
# Save only the learned weights (state_dict) to Config.save_path.
torch.save(net.state_dict(), Config.save_path)

# Evaluate

In [None]:
# Evaluation images are read from Config.testing_dir.
# NOTE(review): Config.testing_dir equals Config.training_dir, so this is not a held-out set.
folder_dataset_test = dset.ImageFolder(root=Config.testing_dir)
# Number of comparison images to draw with the same / a different identity than the query.
num_same_id = 4
num_diff_id = 8

In [None]:
# Turn the (path, class_id) list into a table and draw one random row as the query image.
images, ids = list(zip(*folder_dataset_test.imgs))    
df = pd.DataFrame({'img':images, 'ids':ids})
img0 = df.sample(n=1)  # one-row DataFrame

In [None]:
# Class id and file path of the query row.
test_id = img0.iloc[0].ids
test_img = img0.iloc[0].img

In [None]:
# Comparison set: num_same_id rows sharing the query's class id followed by num_diff_id rows from other classes.
# NOTE(review): the same-class rows are drawn from all rows of that class, so the query image itself may be included.
df_test = pd.concat([df[df.ids == test_id].sample(n=num_same_id), df[df.ids != test_id].sample(n=num_diff_id)])

In [None]:
# Load the query image, apply `transform`, and add a leading batch axis.
x0 = prep_images(test_img, transform=transform)[None, :]

In [None]:
# Inference only: put the model in eval mode and skip autograd bookkeeping.
net.eval()
with torch.no_grad():
    for row in df_test.itertuples(index=False):
        x1 = prep_images(row.img, transform=transform)[None, :]
        # 0 = same identity as the query, 1 = different identity.
        label = int(row.ids != test_id)

        # Query and candidate stacked along the batch axis, for display only.
        concatenated = torch.cat((x0,x1),0)

        output1,output2 = net(x0.to(device),x1.to(device))
        euclidean_distance = F.pairwise_distance(output1, output2)
        # Undo the normalisation and move channels last before plotting.
        imshow(torchvision.utils.make_grid(inv_normalize(concatenated)).permute(1,2,0),f'Label: {label}; Dissimilarity: {euclidean_distance.item():.2f}')

In [None]:
# Inspect the shapes of the query batch and of the last comparison batch.
x0.shape
x1.shape

In [None]:
# Run only the backbone's embedder stage on the query; the input has to be on
# the same device as the model's weights.
net.cnn.embedder(x0.to(device))

In [None]:
# SiameseNetwork defines only `cnn` and `fc1`; `net.re` would raise
# AttributeError, so list the attributes of the ResNet backbone instead.
dir(net.cnn)