In [None]:
# Start from a clean slate. -f keeps a fresh run (nothing to delete yet) from printing errors.
!rm -rf data GTdb_crop GTdb_crop.zip
!wget http://www.anefian.com/research/GTdb_crop.zip
# -q: do not list every extracted file in the cell output.
!unzip -q GTdb_crop.zip -d GTdb_crop

In [None]:
import os, glob, random
import shutil
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as img_transf
from torchvision import datasets as ds
from PIL import Image
from torch.autograd import Variable
import torchvision
import torch
import torch.nn.functional as F
from torch import nn
from torch import optim
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Copy the cropped faces into one folder per subject (data/s01 ... data/s50), so whole
# subjects can later be split into training/testing folders for ImageFolder.
# Assumes file names start with the subject id (glob pattern "sNN*").
N_SUBJECTS = 50
os.makedirs("data", exist_ok=True)
for i in range(1, N_SUBJECTS + 1):
    pattern = "s" + str(i).zfill(2)
    path = os.path.join("data", pattern)
    f = sorted(glob.glob(os.path.join("GTdb_crop/cropped_faces/", pattern + "*")))
    # Fail loudly here rather than with an empty class folder much later.
    assert f, "no images found for subject {}".format(pattern)
    os.makedirs(path, exist_ok=True)
    for image_path in f:
        shutil.copy(image_path, path)

In [None]:
# Hold out N_TEST_SUBJECTS whole subjects for testing; every other subject goes to training.
src = "data/"
N_TEST_SUBJECTS = 3
SPLIT_SEED = 0  # fixed so the train/test split is reproducible across runs
# Only the per-subject folders (s01..s50). This excludes training/testing on a re-run.
# Sorting matters because os.listdir order is arbitrary, which would defeat the seed.
d = sorted(name for name in os.listdir(src)
           if name.startswith("s") and os.path.isdir(os.path.join(src, name)))
# A local Random instance leaves the global `random` state untouched.
for r in random.Random(SPLIT_SEED).sample(d, N_TEST_SUBJECTS):
    shutil.copytree(src + r, "data/testing/" + r)
    d.remove(r)
for i in d:
    shutil.copytree(src + i, "data/training/" + i)

In [None]:
class SiaDataset(Dataset):
    """Pairs images from a torchvision ``ImageFolder``-like object for Siamese training.

    Each item is ``(img1, img2, label)``:

    * ``img1`` is the sample at ``idx``.
    * With probability 1/2, ``img2`` is a *different* image of the same class (a
      positive pair). Otherwise it is an image of a different class (a negative pair).
    * ``label`` is a float32 tensor of shape (1,). It is 0.0 when both images share a
      class and 1.0 when they do not. This matches ``ContrastiveLoss`` in this notebook,
      which applies its margin term to label == 1.

    Args:
        imageDir: object exposing ``imgs``, a list of ``(path, class_index)`` tuples
            (e.g. ``torchvision.datasets.ImageFolder``).
        image_transforms: optional callable applied to each PIL image.
        gray_scale: convert images to single-channel "L" mode. NOTE(review): the
            ``Siamese`` network's first conv expects 3 channels, so this option does
            not fit that model as written.
    """

    def __init__(self, imageDir, image_transforms=None, gray_scale=False):
        self.imageDir = imageDir
        self.image_transforms = image_transforms
        self.gray_scale = gray_scale
        # Map class index -> samples of that class, so positives are drawn directly.
        self._by_class = {}
        for sample in imageDir.imgs:
            self._by_class.setdefault(sample[1], []).append(sample)

    def _pick_partner(self, im1):
        """Return the second ``(path, class_index)`` sample to pair with ``im1``."""
        if random.randint(0, 1):
            # Positive pair: another image of the same class when one exists.
            # The previous version paired the anchor with itself.
            same = [s for s in self._by_class[im1[1]] if s[0] != im1[0]]
            return random.choice(same) if same else im1
        if len(self._by_class) > 1:
            # Negative pair: resample until the class differs. This terminates
            # because at least one other class exists.
            im2 = random.choice(self.imageDir.imgs)
            while im2[1] == im1[1]:
                im2 = random.choice(self.imageDir.imgs)
            return im2
        # A single-class dataset cannot produce a negative pair.
        return random.choice(self.imageDir.imgs)

    def __getitem__(self, idx):
        im1 = self.imageDir.imgs[idx]
        im2 = self._pick_partner(im1)

        img1 = Image.open(im1[0]).convert("RGB")
        img2 = Image.open(im2[0]).convert("RGB")
        # 1.0 = different classes (dissimilar), 0.0 = same class (similar).
        label = torch.from_numpy(np.array([int(im1[1] != im2[1])], dtype=np.float32))

        if self.gray_scale:
            img1 = img1.convert("L")
            img2 = img2.convert("L")

        if self.image_transforms:
            img1 = self.image_transforms(img1)
            img2 = self.image_transforms(img2)
        return img1, img2, label

    def __len__(self):
        return len(self.imageDir.imgs)
  

In [None]:
class Siamese(nn.Module):
    """Siamese CNN: embeds two image batches with one shared set of weights.

    Both inputs pass through the same four conv blocks and the same fully connected
    head, which outputs an 8-dimensional embedding per image. ``fc1`` expects a
    256x6x6 feature map. That is what the conv stack produces for 3x105x105 inputs,
    the size the notebook's transform resizes to.
    """

    def __init__(self):
        super(Siamese, self).__init__()
        self.cnn1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=10),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2,2),
            nn.BatchNorm2d(64)
        )
        self.cnn2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=7),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2,2),
            nn.BatchNorm2d(128)
        )

        self.cnn3 = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=4),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2,2),
            nn.BatchNorm2d(128)
        )

        self.cnn4 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=4),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(256)
        )

        self.fc1 = nn.Sequential(
            nn.Linear(256*6*6, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 8),
        )

    def forward_once(self, x):
        """Embed one batch of images, (B, 3, 105, 105) -> (B, 8)."""
        x = self.cnn1(x)
        x = self.cnn2(x)
        x = self.cnn3(x)
        x = self.cnn4(x)
        x = x.view(x.size(0), -1)  # flatten to (B, 256*6*6) for fc1
        return self.fc1(x)

    def forward(self, img1, img2):
        """Return the embeddings ``(emb1, emb2)`` of the two input batches.

        Both branches share one code path, so they cannot drift apart. The old
        second branch referenced an undefined ``y``.
        """
        return self.forward_once(img1), self.forward_once(img2)
        

In [None]:
class ContrastiveLoss(torch.nn.Module):
    """Contrastive loss over pairs of embeddings.

    For a pair at Euclidean distance ``d``:

    * label == 0 (similar pair):    ``d**2``
    * label == 1 (dissimilar pair): ``max(margin - d, 0)**2``

    The returned loss is the mean over the pairs in the batch.

    Args:
        margin: distance beyond which dissimilar pairs stop contributing.
    """

    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        """Compute the mean contrastive loss.

        Args:
            output1, output2: (N, D) embedding batches.
            label: N values of 0/1, shaped (N,) or (N, 1).
        Returns:
            Scalar tensor.
        """
        euclidean_distance = F.pairwise_distance(output1, output2)
        # pairwise_distance returns shape (N,), but per-sample labels of shape (1,) batch
        # to (N, 1). Without the reshape, broadcasting builds an (N, N) matrix that pairs
        # every distance with every label.
        label = label.reshape(euclidean_distance.shape)
        loss_contrastive = torch.mean((1 - label) * torch.pow(euclidean_distance, 2) +
                                      label * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))
        return loss_contrastive

In [None]:
def show(img, text=None):
    """Draw a channel-first image tensor with a bold caption and display it.

    Args:
        img: (C, H, W) tensor. It must be on the CPU, since ``.numpy()`` is called on it.
        text: caption placed at data coordinates (75, 120); defaults to None.
    """
    pixels = np.transpose(img.numpy(), (1, 2, 0))  # CHW -> HWC for imshow
    plt.axis("off")
    plt.text(75, 120, text, fontweight='bold')
    plt.imshow(pixels)
    plt.show()

In [None]:
# Resize every face to 105x105, the input size the Siamese conv stack's fc1 is sized for,
# then convert it to a float tensor.
im_trans = img_transf.Compose([img_transf.Resize((105, 105)), img_transf.ToTensor()])

# Do not request more DataLoader workers than there are CPUs. The fixed counts used
# before (8 and 6) oversubscribe small runtimes, where PyTorch warns about too many workers.
_cpu_count = os.cpu_count() or 1

folder_dataset = ds.ImageFolder("data/training")
sia_dataset = SiaDataset(imageDir=folder_dataset,
                         image_transforms=im_trans,
                         gray_scale=False)
train_dataloader = DataLoader(sia_dataset,
                              shuffle=True,
                              num_workers=min(8, _cpu_count),
                              batch_size=64)

folder_dataset = ds.ImageFolder("data/testing")
siatest_dataset = SiaDataset(imageDir=folder_dataset,
                             image_transforms=im_trans,
                             gray_scale=False)
test_dataloader = DataLoader(siatest_dataset,
                             num_workers=min(6, _cpu_count),
                             batch_size=1,
                             shuffle=True)


In [None]:
# Model on the GPU, contrastive criterion (default margin), and Adam optimizer.
net = Siamese()
net = net.cuda()
loss_fun = ContrastiveLoss()
optims = optim.Adam(params=net.parameters(), lr=3e-4)

In [None]:
# Train the Siamese network, logging the loss every LOG_EVERY batches.
NUM_EPOCHS = 100
LOG_EVERY = 12
counter = []   # global batch index of each logged point (x-axis)
l = []         # loss value of each logged point (y-axis)
iter_num = 0   # batches processed so far, across all epochs

net.train()  # restore training-mode BatchNorm in case an evaluation cell ran first
for epoch in range(NUM_EPOCHS):
    for i, data in enumerate(train_dataloader, 0):
        img0, img1, label = data
        img0, img1, label = img0.cuda(), img1.cuda(), label.cuda()
        optims.zero_grad()
        op1, op2 = net(img0, img1)
        loss = loss_fun(op1, op2, label)
        loss.backward()
        optims.step()
        if i % LOG_EVERY == 0:
            loss_value = loss.item()
            print("Epoch {} with {} loss\n".format(epoch, loss_value))
            # Record the true batch index. The old "+5 per log" bore no relation to
            # the LOG_EVERY-batch spacing of the points.
            counter.append(iter_num)
            l.append(loss_value)
        iter_num += 1

fig, ax = plt.subplots()
ax.plot(counter, l)
ax.set(title="Training loss", xlabel="batch", ylabel="contrastive loss")
plt.show()

In [None]:
# Show each test pair with the embedding distance the trained network assigns to it.
was_training = net.training
# eval(): BatchNorm uses its running statistics instead of stats from a single-image
# batch, and those running statistics are not updated by test data.
net.eval()
with torch.no_grad():  # inference only; no autograd graph is needed
    for i, data in enumerate(test_dataloader, 0):
        img0, img1, label = data
        concatenated = torch.cat((img0, img1))
        output1, output2 = net(img0.cuda(), img1.cuda())
        distance = F.pairwise_distance(output1, output2)
        show(torchvision.utils.make_grid(concatenated), 'Mismatch: {:.3f}'.format(distance.item()))
net.train(was_training)  # leave the model in the mode it was in before this cell