In [1]:
import torchvision
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader,Dataset
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import torchvision.utils
import numpy as np
import random
from PIL import Image
import torch
from torch.autograd import Variable
import PIL.ImageOps    
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

In [2]:
def imshow(img,text=None,should_save=False):
    """Draw a channel-first image tensor on the current matplotlib figure.

    Args:
        img: torch tensor laid out as (C, H, W); it is transposed to
            (H, W, C) before being passed to ``plt.imshow``.
        text: optional caption, drawn in a semi-transparent white box at
            position (65, 8). Skipped when falsy.
        should_save: accepted for backward compatibility; this function does
            not read it.
    """
    # detach()/cpu() are no-ops for plain CPU tensors, but without them
    # .numpy() raises for tensors on a GPU or attached to the autograd graph.
    npimg = img.detach().cpu().numpy()
    plt.axis("off")
    if text:
        plt.text(65, 8, text, style='italic',fontweight='bold',
            bbox={'facecolor':'white', 'alpha':0.8, 'pad':10})
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

def show_plot(iteration,loss):
    """Plot ``loss`` (y-axis) against ``iteration`` (x-axis) and show the figure."""
    x_values, y_values = iteration, loss
    plt.plot(x_values, y_values)
    plt.show()

class Config:
    """Static settings read by the dataset-loading and training cells."""

    # Image folders, given relative to the working directory.
    training_dir = "att_faces/training/"
    testing_dir = "att_faces/testing/"

    # Mini-batch size used by the training DataLoader.
    train_batch_size = 64
    # Number of passes over the training DataLoader.
    train_number_epochs = 100

In [3]:
class SiameseNetworkDataset(Dataset):
    """Dataset that yields random image pairs plus a same/different label.

    Pairs are drawn from ``imageFolderDataset.imgs``, a sequence of
    ``(path, class_index)`` tuples. About half of the pairs share a class.
    """

    # Random draws attempted before falling back to an exhaustive scan.
    _MAX_REJECTION_TRIES = 100

    def __init__(self,imageFolderDataset,transform=None,should_invert=True,should_gray=True):
        """Store the image folder and the per-image preprocessing options.

        Args:
            imageFolderDataset: object exposing ``imgs`` as (path, class) tuples.
            transform: optional callable applied to each PIL image.
            should_invert: invert each image with ``PIL.ImageOps.invert``.
            should_gray: convert each image to single-channel mode "L".
        """
        self.imageFolderDataset = imageFolderDataset
        self.transform = transform
        self.should_invert = should_invert
        self.should_gray = should_gray

    def _sample_partner(self, anchor, same_class):
        """Pick a random sample whose class matches (or differs from) ``anchor``.

        Rejection sampling is tried first because it is cheap in the common
        case. If it keeps missing, the candidates are enumerated, which also
        detects the case where no valid partner exists. Both paths are uniform
        over the valid candidates.

        Raises:
            ValueError: if no sample satisfies the requested relation, e.g. a
                different-class partner is requested from a one-class folder.
                The previous implementation looped forever in that case.
        """
        samples = self.imageFolderDataset.imgs
        for _ in range(self._MAX_REJECTION_TRIES):
            candidate = random.choice(samples)
            if (candidate[1] == anchor[1]) == same_class:
                return candidate

        candidates = [s for s in samples if (s[1] == anchor[1]) == same_class]
        if not candidates:
            relation = "the same class as" if same_class else "a different class from"
            raise ValueError(
                "no image with %s class index %r is available" % (relation, anchor[1])
            )
        return random.choice(candidates)

    def __getitem__(self,index):
        """Return ``(img0, img1, label)`` for a freshly drawn random pair.

        ``index`` is not used: every call draws a new random pair. ``label`` is
        a float32 tensor of shape (1,), 1.0 when the classes differ, else 0.0.
        """
        img0_tuple = random.choice(self.imageFolderDataset.imgs)
        # We need approximately 50% of the pairs to come from the same class.
        should_get_same_class = random.randint(0,1)
        img1_tuple = self._sample_partner(img0_tuple, bool(should_get_same_class))

        img0 = Image.open(img0_tuple[0])
        img1 = Image.open(img1_tuple[0])

        if self.should_gray:
            img0 = img0.convert("L")
            img1 = img1.convert("L")

        if self.should_invert:
            img0 = PIL.ImageOps.invert(img0)
            img1 = PIL.ImageOps.invert(img1)

        if self.transform is not None:
            img0 = self.transform(img0)
            img1 = self.transform(img1)

        return img0, img1 , torch.from_numpy(np.array([int(img1_tuple[1]!=img0_tuple[1])],dtype=np.float32))

    def __len__(self):
        """Number of images in the wrapped folder dataset."""
        return len(self.imageFolderDataset.imgs)

In [4]:
class SiameseNetwork(nn.Module):
    """Twin-branch CNN that maps each single-channel image to an 8-d embedding.

    Both inputs pass through the same weights: three conv/pool/ReLU/batch-norm
    stages followed by three fully connected layers.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 1 -> 8 channels, spatial size halved by the pooling.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, padding=1)
        self.maxPool1 = nn.MaxPool2d(kernel_size=2)
        self.bn1 = nn.BatchNorm2d(num_features=8)
        # Stage 2: 8 -> 16 channels.
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, padding=1)
        self.maxPool2 = nn.MaxPool2d(kernel_size=2)
        self.bn2 = nn.BatchNorm2d(num_features=16)
        # Stage 3: 16 -> 32 channels.
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
        self.maxPool3 = nn.MaxPool2d(kernel_size=2)
        self.bn3 = nn.BatchNorm2d(num_features=32)
        # Head: 4608 = 32 * 12 * 12, the flattened stage-3 output for a
        # 100x100 input (100 -> 50 -> 25 -> 12 after the three poolings).
        self.fc1 = nn.Linear(in_features=4608, out_features=1024)
        self.fc2 = nn.Linear(in_features=1024, out_features=512)
        self.fc3 = nn.Linear(in_features=512, out_features=8)

    def forward_once(self, x):
        """Embed one batch of images; returns a (batch, 8) tensor."""
        stages = (
            (self.conv1, self.maxPool1, self.bn1),
            (self.conv2, self.maxPool2, self.bn2),
            (self.conv3, self.maxPool3, self.bn3),
        )
        features = x
        for conv, pool, norm in stages:
            # Order within a stage: convolution, pooling, ReLU, batch norm.
            features = norm(F.relu(pool(conv(features))))

        # Flatten everything except the batch dimension for the dense head.
        flat = features.view(features.shape[0], -1)
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

    def forward(self, input1, input2):
        """Embed both inputs with the shared weights and return the two embeddings."""
        return self.forward_once(input1), self.forward_once(input2)

In [5]:
class ContrastiveLoss(torch.nn.Module):
    """
    Contrastive loss function.
    Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf

    For a pair at Euclidean distance d with label y, the per-pair loss is
    (1 - y) * d**2 + y * max(margin - d, 0)**2. Pairs labelled 0 are pulled
    together; pairs labelled 1 are pushed at least ``margin`` apart.
    """

    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        """Return the mean contrastive loss over the batch.

        Args:
            output1: embeddings f(x1), one row per pair.
            output2: embeddings f(x2), same shape as ``output1``.
            label: one value per pair, 0 for a matching pair and 1 for a
                non-matching pair. Shape (B, 1) or (B,) is accepted.
        """
        euc_distances = F.pairwise_distance(output1, output2).view(-1,1)
        # Force the labels into a column. A flat (B,) label would otherwise
        # broadcast against the (B, 1) distances into a (B, B) matrix and
        # silently average the wrong terms.
        label = label.reshape(-1, 1)
        loss = (1-label)*(euc_distances**2) + label*(((self.margin - euc_distances).clamp(min=0))**2)
        loss = torch.mean(loss)
        return loss

In [6]:
def evaluate(dataiter, net, split, device):
    """Save side-by-side pair visualisations annotated with embedding distance.

    For each of 2 anchor images, the next 10 items from ``dataiter`` supply
    comparison images. Each pair is drawn with ``imshow`` and written to
    ``<split>_<i>_<j>.png`` in the working directory.

    Args:
        dataiter: iterator yielding ``(img0, img1, label)`` batches. 22 items
            are consumed; ``next`` raises ``StopIteration`` if fewer remain.
            ``.item()`` on the distance requires a batch size of 1.
        net: Siamese network returning one embedding per input.
        split: tag used in the caption and in the file names.
        device: device on which the network is evaluated.
    """
    # Run in inference mode: batch norm then uses its running statistics
    # instead of per-batch ones and does not update them, and no autograd graph
    # is built. The previous train/eval mode is restored on exit.
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            for i in range(2):
                x0,_,_ = next(dataiter)
                for j in range(10):
                    _,x1,_ = next(dataiter)
                    concatenated = torch.cat((x0,x1),0)
                    output1,output2 = net(x0.to(device),x1.to(device))
                    euclidean_distance = F.pairwise_distance(output1, output2)
                    imshow(torchvision.utils.make_grid(concatenated),'%s, dissimilarity:%.2f'%(split, euclidean_distance.item()))
                    plt.savefig('%s_%d_%d.png'%(split,i, j))
                    plt.close()
    finally:
        net.train(was_training)

In [108]:
from tqdm import tqdm
def get_mean_and_std(dataloader):
    """Compute the per-pixel mean and standard deviation over a dataloader.

    Only the first element of each yielded item is used. Its last two
    dimensions are treated as the image plane; any leading dimensions (batch,
    channel) are flattened into the sample axis. The sample count and image
    size are taken from the data rather than fixed at 370 samples of 100x100.

    Args:
        dataloader: iterable whose items are indexable, with a tensor at
            position 0.

    Returns:
        ``(mean, std)``, each of shape (H, W), reduced over all samples.
        ``std`` uses ``torch.std``'s default estimator.

    Raises:
        ValueError: if the dataloader yields nothing.
    """
    frames = []
    for image in dataloader:
        first = image[0]
        frames.append(first.reshape(-1, *first.shape[-2:]).to(torch.float32))
    if not frames:
        raise ValueError("get_mean_and_std received an empty dataloader")
    images = torch.cat(frames, dim=0)
    mean = torch.mean(images, dim=0)
    std = torch.std(images, dim=0)
    return mean, std

In [109]:
# Estimate the normalisation statistics from raw (un-normalised) training
# images. A dedicated batch_size=1 loader is built here so that this cell runs
# on a fresh kernel: `train_dataloader` is only created further down, and its
# transform already needs `mean`/`std`.
stats_dataset = SiameseNetworkDataset(imageFolderDataset=dset.ImageFolder(root=Config.training_dir),
                                      transform=transforms.Compose([transforms.Resize((100,100)),
                                                                    transforms.ToTensor()]),
                                      should_invert=False)
stats_dataloader = DataLoader(stats_dataset, shuffle=False, num_workers=0, batch_size=1)
mean,std = get_mean_and_std(stats_dataloader)

In [112]:
# Shared preprocessing: resize, convert to tensor, then normalise with the
# `mean`/`std` computed from the training images. The same pipeline is reused
# for the test set below, so evaluation inputs are on the same scale as the
# inputs the network was trained on (the test set was previously left
# un-normalised).
siamese_transform = transforms.Compose([transforms.Resize((100,100)),
                                        transforms.ToTensor(),
                                        transforms.Normalize(mean, std)])

folder_dataset = dset.ImageFolder(root=Config.training_dir)

siamese_dataset = SiameseNetworkDataset(imageFolderDataset=folder_dataset,
                                        transform=siamese_transform,
                                        should_invert=False)
train_dataloader = DataLoader(siamese_dataset, shuffle=True, num_workers=10, batch_size=Config.train_batch_size)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = SiameseNetwork().to(device)
criterion = ContrastiveLoss()
optimizer = optim.Adam(net.parameters(),lr = 0.0005)

# Loss bookkeeping: one entry is recorded every 10th batch of each epoch.
counter = []
loss_history = []
iteration_number= 0

for epoch in range(0,Config.train_number_epochs):
    for i, data in enumerate(train_dataloader,0):
        img0, img1 , label = data
        img0, img1 , label = img0.to(device), img1.to(device) , label.to(device)
        optimizer.zero_grad()
        output1,output2 = net(img0,img1)
        loss_contrastive = criterion(output1,output2,label)
        loss_contrastive.backward()
        optimizer.step()
        if i %10 == 0 :
            print("Epoch number {}\n Current loss {}\n".format(epoch,loss_contrastive.item()))
            iteration_number +=10
            counter.append(iteration_number)
            loss_history.append(loss_contrastive.item())

# Visualise pairs from the training data, one pair per batch.
train_dataloader = DataLoader(siamese_dataset, shuffle=False, num_workers=10, batch_size=1)
dataiter = iter(train_dataloader)
evaluate(dataiter, net, 'train', device)

# Visualise pairs from the held-out test folder with the same preprocessing.
folder_dataset_test = dset.ImageFolder(root=Config.testing_dir)
siamese_dataset = SiameseNetworkDataset(imageFolderDataset=folder_dataset_test,
                                        transform=siamese_transform,
                                        should_invert=False)

test_dataloader = DataLoader(siamese_dataset,num_workers=10,batch_size=1,shuffle=True)
dataiter = iter(test_dataloader)
evaluate(dataiter, net, 'test', device)

Epoch number 0
 Current loss 1.6904455423355103

Epoch number 1
 Current loss 0.5754615068435669

Epoch number 2
 Current loss 0.4397837221622467

Epoch number 3
 Current loss 0.2804790437221527

Epoch number 4
 Current loss 0.19823719561100006

Epoch number 5
 Current loss 0.22115091979503632

Epoch number 6
 Current loss 0.12760627269744873

Epoch number 7
 Current loss 0.209361732006073

Epoch number 8
 Current loss 0.0912131518125534

Epoch number 9
 Current loss 0.09882546961307526

Epoch number 10
 Current loss 0.0905810296535492

Epoch number 11
 Current loss 0.0685461238026619

Epoch number 12
 Current loss 0.06907980144023895

Epoch number 13
 Current loss 0.03984855115413666

Epoch number 14
 Current loss 0.039949316531419754

Epoch number 15
 Current loss 0.0965958833694458

Epoch number 16
 Current loss 0.03565568849444389

Epoch number 17
 Current loss 0.05688484013080597

Epoch number 18
 Current loss 0.03184136748313904

Epoch number 19
 Current loss 0.029206089675426483

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

Matplotlib is currently using agg, which is a non-GUI backend, so cannot show the figure.

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the v

In [11]:
import plotly.plotly as py
import plotly.graph_objs as go
import numpy as np

In [113]:
# Display the loss values that the training loop appended to `loss_history`.
loss_history

[1.6904455423355103,
 0.5754615068435669,
 0.4397837221622467,
 0.2804790437221527,
 0.19823719561100006,
 0.22115091979503632,
 0.12760627269744873,
 0.209361732006073,
 0.0912131518125534,
 0.09882546961307526,
 0.0905810296535492,
 0.0685461238026619,
 0.06907980144023895,
 0.03984855115413666,
 0.039949316531419754,
 0.0965958833694458,
 0.03565568849444389,
 0.05688484013080597,
 0.03184136748313904,
 0.029206089675426483,
 0.026910856366157532,
 0.01840069144964218,
 0.030751390382647514,
 0.03653538599610329,
 0.014518612995743752,
 0.024776320904493332,
 0.03055027313530445,
 0.024485645815730095,
 0.010981111787259579,
 0.015117126516997814,
 0.011697828769683838,
 0.01418964471668005,
 0.006581231486052275,
 0.007254813332110643,
 0.00833601038902998,
 0.0074264882132411,
 0.005201428197324276,
 0.01122827548533678,
 0.00777999684214592,
 0.013007022440433502,
 0.009342709556221962,
 0.0056999665684998035,
 0.005728492978960276,
 0.005195084027945995,
 0.008028329350054264,
 

In [43]:
# 100 hard-coded loss values kept for comparison. The run that produced them is
# not recorded in this notebook (NOTE(review): presumably the baseline setup;
# confirm). The literal 0.002840265166014433 used to be split across two lines,
# which is a syntax error; it is rejoined here.
loss_history_simple = [
    2.004492998123169, 0.4032353162765503, 0.32568204402923584, 0.2171175181865692,
    0.20326577126979828, 0.14600753784179688, 0.13853979110717773, 0.1044771671295166,
    0.09993784874677658, 0.0751296654343605, 0.05321107059717178, 0.0801149383187294,
    0.07795143872499466, 0.05339547619223595, 0.08486290276050568, 0.029832255095243454,
    0.03434082865715027, 0.029086865484714508, 0.032689034938812256, 0.037102289497852325,
    0.0278323944658041, 0.024693816900253296, 0.01813843473792076, 0.023272279649972916,
    0.015175173990428448, 0.015235042199492455, 0.010254344902932644, 0.02685128152370453,
    0.019651245325803757, 0.01973552815616131, 0.015522783622145653, 0.01158134825527668,
    0.014464851468801498, 0.023881783708930016, 0.011985798366367817, 0.010766906663775444,
    0.016548587009310722, 0.010207918472588062, 0.011326385661959648, 0.012325907126069069,
    0.008463562466204166, 0.010520274750888348, 0.012202589772641659, 0.0063684433698654175,
    0.011463809758424759, 0.005379859358072281, 0.007480606436729431, 0.006172084249556065,
    0.004802504554390907, 0.005046031903475523, 0.005612905602902174, 0.0064800274558365345,
    0.004149800166487694, 0.008091883733868599, 0.004458779469132423, 0.005127860698848963,
    0.008432081900537014, 0.007117325905710459, 0.01082070916891098, 0.00427946588024497,
    0.005698100198060274, 0.004849138204008341, 0.004517256282269955, 0.005028571933507919,
    0.005384495947510004, 0.003996153362095356, 0.006916731130331755, 0.00555469049140811,
    0.004786317702382803, 0.004698920529335737, 0.004839837085455656, 0.0053241560235619545,
    0.005228434223681688, 0.004067743197083473, 0.00489381467923522, 0.0034766942262649536,
    0.0044547561556100845, 0.0089314179494977, 0.005077831447124481, 0.004841863643378019,
    0.0034785596653819084, 0.004436390474438667, 0.005524823907762766, 0.00225108047015965,
    0.0044188895262777805, 0.0036274550948292017, 0.0035408998373895884, 0.003085918491706252,
    0.0029984675347805023, 0.0024481925647705793, 0.002426530234515667, 0.004022778477519751,
    0.0019963716622442007, 0.0021325729321688414, 0.0019137444905936718, 0.002840265166014433,
    0.0032041354570537806, 0.0018502279417589307, 0.0045061311684548855, 0.0015171277336776257,
]
# 100 hard-coded loss values kept for comparison. The run that produced them is
# not recorded in this notebook (NOTE(review): the name suggests a modified
# setup; confirm).
loss_history_change = [
    1.8403959274291992, 0.41868335008621216, 0.27898257970809937, 0.16674131155014038,
    0.1491120159626007, 0.08871656656265259, 0.18025213479995728, 0.056647781282663345,
    0.09543798863887787, 0.07643908262252808, 0.10679472237825394, 0.13478484749794006,
    0.06725107133388519, 0.059527575969696045, 0.04184179753065109, 0.03874550759792328,
    0.0456191748380661, 0.032816141843795776, 0.033699847757816315, 0.0269613079726696,
    0.06431301683187485, 0.039705291390419006, 0.025552116334438324, 0.029886342585086823,
    0.01608881540596485, 0.021837681531906128, 0.02505965158343315, 0.02299339696764946,
    0.01259642280638218, 0.016523580998182297, 0.022832956165075302, 0.017558995634317398,
    0.01079526636749506, 0.01585199125111103, 0.010289555415511131, 0.034904126077890396,
    0.013666396029293537, 0.014933628961443901, 0.012351276353001595, 0.01052769087255001,
    0.006821881979703903, 0.013541504740715027, 0.019387859851121902, 0.006530501879751682,
    0.012358814477920532, 0.00968804582953453, 0.00872714538127184, 0.006737476214766502,
    0.008688675239682198, 0.005961311981081963, 0.006322861183434725, 0.006466441787779331,
    0.0053357710130512714, 0.005648092366755009, 0.006502725183963776, 0.005812946241348982,
    0.008841054514050484, 0.003004188183695078, 0.004964965395629406, 0.0036152724642306566,
    0.014889519661664963, 0.0059528290294110775, 0.004709242843091488, 0.005400128196924925,
    0.004518333822488785, 0.005707269534468651, 0.003616920206695795, 0.004159178119152784,
    0.006849345285445452, 0.0030636400915682316, 0.00574517622590065, 0.010539848357439041,
    0.003283688798546791, 0.004299634136259556, 0.004218427464365959, 0.004886045586317778,
    0.0057366993278265, 0.010670450516045094, 0.005874735303223133, 0.004429766908288002,
    0.007891305722296238, 0.004132836125791073, 0.0022304467856884003, 0.0057703647762537,
    0.004807622637599707, 0.0040344372391700745, 0.006578688044101, 0.004907539114356041,
    0.0039510829374194145, 0.005668643396347761, 0.0036194603890180588, 0.004237660206854343,
    0.0055985040962696075, 0.004397407174110413, 0.00558597594499588, 0.0030481922440230846,
    0.004669278860092163, 0.0031238803640007973, 0.003587383544072509, 0.0018048817291855812,
]

In [114]:
# Line trace of the recorded loss values, indexed by their position in the list.
trace = go.Scatter(x=list(range(len(loss_history))), y=loss_history)
data = [trace]

# Y axis with evenly spaced ticks: start at 0, one tick every 0.1.
layout = go.Layout(
    yaxis={
        'tickmode': 'linear',
        'tick0': 0,
        'dtick': 0.1,
        'ticklen': 8,
        'tickcolor': '#000',
    }
)

fig = go.Figure(data=data, layout=layout)
py.iplot(fig, filename='axes-ticks')

In [98]:
# Scratch check of how torch.std reduces along dim=1 on a small 2x2x2 tensor.
test = torch.tensor(
    [
        [[10, 20], [0, 1]],
        [[1000, 2000], [100, 100]],
    ],
    dtype=torch.float32,
)
test.std(dim=1)

tensor([[   7.0711,   13.4350],
        [ 636.3961, 1343.5029]])