In [None]:
%matplotlib inline
import torchvision
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader,Dataset
import matplotlib.pyplot as plt
import torchvision.utils
import numpy as np
import random
import os
from PIL import Image
import torch
from torch.autograd import Variable
import PIL.ImageOps    
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from prettytable import PrettyTable

In [None]:
def imshow(img,text=None,should_save=False):
    """Show a channel-first image tensor with matplotlib, optionally captioned.

    Args:
        img: object with a ``.numpy()`` method returning a (C, H, W) array, e.g. the
            output of ``torchvision.utils.make_grid``.
        text: optional caption; drawn in a translucent white box at data position (75, 8).
        should_save: accepted for compatibility but currently ignored.
    """
    pixels = img.numpy()
    plt.axis("off")
    if text:
        caption_box = {'facecolor': 'white', 'alpha': 0.8, 'pad': 10}
        plt.text(75, 8, text, style='italic', fontweight='bold', bbox=caption_box)
    # matplotlib expects channel-last (H, W, C) data
    plt.imshow(pixels.transpose(1, 2, 0))
    plt.show()

In [None]:
class SiameseDataset(Dataset):
    """Dataset yielding random image pairs for training a Siamese network.

    Each ``__getitem__`` call ignores ``index`` and draws a fresh random pair. With
    probability 1/2 both images share a class; otherwise they come from different
    classes. Items are ``(img1, img2, label)`` where ``label`` is a float32 tensor
    of shape ``(1,)``: ``0.0`` for a same-class pair, ``1.0`` for a different-class pair.

    Args:
        root_dir: directory the image folder was built from (stored only, not read here).
        imageFolderDataset: object exposing ``imgs``, a list of ``(path, class_index)``
            tuples (e.g. ``torchvision.datasets.ImageFolder``).
        transform: optional callable applied to each PIL image.
        should_invert: if True, both images are passed through ``PIL.ImageOps.invert``
            before ``transform``.
    """

    def __init__(self, root_dir, imageFolderDataset, transform=None, should_invert=True):
        self.root_dir = root_dir
        self.imageFolderDataset = imageFolderDataset
        self.transform = transform
        self.should_invert = should_invert
        # Group samples by class once, so a same-class partner is drawn directly instead of
        # rejection-sampling the whole dataset (about #classes draws per pair otherwise).
        self._samples_by_class = {}
        for sample in imageFolderDataset.imgs:
            self._samples_by_class.setdefault(sample[1], []).append(sample)

    def _sample_pair(self):
        """Return two ``(path, class_index)`` tuples forming one random pair.

        Raises:
            ValueError: if a different-class pair is requested but fewer than two
                classes exist (such a pair cannot be built).
        """
        imgs = self.imageFolderDataset.imgs
        should_get_same_class = random.randint(0, 1)
        first = random.choice(imgs)
        if should_get_same_class:
            # Uniform over all images of the same class, ``first`` itself included.
            return first, random.choice(self._samples_by_class[first[1]])
        if len(self._samples_by_class) < 2:
            raise ValueError("SiameseDataset needs at least two classes to build a different-class pair")
        # Most images belong to other classes, so rejection sampling terminates quickly here.
        while True:
            second = random.choice(imgs)
            if second[1] != first[1]:
                return first, second

    @staticmethod
    def _load_image(path):
        """Open ``path`` with PIL and read its pixels so the file handle is closed right away."""
        with open(path, "rb") as f:
            img = Image.open(f)
            img.load()
        return img

    def __getitem__(self, index):
        img1_tuple, img2_tuple = self._sample_pair()

        img1 = self._load_image(img1_tuple[0])
        img2 = self._load_image(img2_tuple[0])

        if self.should_invert:
            img1 = PIL.ImageOps.invert(img1)
            img2 = PIL.ImageOps.invert(img2)

        if self.transform is not None:
            img1 = self.transform(img1)
            img2 = self.transform(img2)

        # 1.0 when the two classes differ, 0.0 when they match
        label = torch.from_numpy(np.array([(img1_tuple[1] != img2_tuple[1])], dtype=np.float32))
        return img1, img2, label

    def __len__(self):
        return len(self.imageFolderDataset.imgs)

In [None]:
img_size=105   # input side length; the default SiameseKoch architecture expects 105x105 images
epochs=25
batchsize=64
seed=0

# `device` is used by the model, training and evaluation cells; without this the
# notebook fails with NameError on a fresh kernel.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed every RNG in play: pair sampling uses `random`, torchvision's random
# transforms and weight init use torch.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
# Other Omniglot layouts tried previously:
#   ./data/Omniglot/alphabet_dataset/images_background/
#   ./data/Omniglot/character_dataset/train/
train_dir = './data/Omniglot/changed/train/'
# One class per sub-directory (torchvision ImageFolder convention)
train_imagefolder = dset.ImageFolder(root=train_dir)

In [None]:
# Training pairs: resize to the network input size, then light augmentation
# (random horizontal flip + random crop of 80-100% of the area, rescaled to img_size).
train_transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomResizedCrop(img_size, scale=(0.8, 1.0)),
    transforms.ToTensor(),
])
train_dataset = SiameseDataset(
    root_dir=train_dir,
    imageFolderDataset=train_imagefolder,
    transform=train_transform,
    should_invert=False,
)

train_dataloader = DataLoader(train_dataset, batch_size=batchsize, shuffle=True, num_workers=0)

In [None]:
# Eyeball a small batch of training pairs before training.
visualize_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=0)
dataiter = iter(visualize_dataloader)
example_batch = next(dataiter)

# First images of every pair, followed by the second images of every pair
concatenated = torch.cat([example_batch[0], example_batch[1]], dim=0)
imshow(torchvision.utils.make_grid(concatenated))
# Pair labels: 0.0 = same class, 1.0 = different class
print(example_batch[2].numpy())

In [None]:
# Koch et al. Siamese architecture; designed for 105x105 single-channel inputs.
class SiameseKoch(nn.Module):
    """Siamese CNN (after Koch et al.) producing one "are these different?" logit per pair.

    Both inputs go through the same conv tower and sigmoid fully-connected layer. The
    element-wise absolute difference of the two embeddings is mapped to a single
    unnormalised logit, so pair it with ``nn.BCEWithLogitsLoss``.

    Args:
        img_size: side length of the square single-channel inputs. Defaults to 105,
            which gives the original 256*6*6 = 9216 flattened feature size.

    Raises:
        ValueError: if ``img_size`` is too small for the conv/pool stack.
    """

    def __init__(self, img_size=105):
        super(SiameseKoch, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 64, 10),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 7),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 128, 4),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 256, 4),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
        )

        side = self._conv_output_side(img_size)
        if side < 1:
            raise ValueError("img_size={} is too small for SiameseKoch's conv stack".format(img_size))

        # Flatten size derives from img_size instead of a hard-coded 9216.
        self.fc1 = nn.Sequential(
            nn.Linear(256 * side * side, 4096),
            nn.Sigmoid())

        self.out = nn.Linear(4096, 1)

    @staticmethod
    def _conv_output_side(img_size):
        """Spatial side length of ``self.conv``'s output for a square ``img_size`` input."""
        side = img_size
        for kernel in (10, 7, 4):
            side = (side - kernel + 1) // 2   # unpadded conv, then MaxPool2d(2) (floor)
        return side - 4 + 1                   # last conv has no pooling

    def forward_once(self, inp):
        """Embed one batch of images: conv tower -> flatten -> sigmoid FC layer."""
        inp = self.conv(inp)
        inp = torch.flatten(inp, 1)
        return self.fc1(inp)

    def forward(self, inp1, inp2):
        """Return one logit per pair, computed from |embedding(inp2) - embedding(inp1)|."""
        out1 = self.forward_once(inp1)
        out2 = self.forward_once(inp2)
        dis = torch.abs(out2 - out1)
        return self.out(dis)

In [None]:
class ContrastiveLoss(torch.nn.Module):
    """
    Contrastive loss function.
    Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf

    ``label`` is 0 for a similar pair and 1 for a dissimilar pair. Similar pairs are
    penalised by their squared distance; dissimilar pairs are penalised only while
    closer than ``margin``.
    """

    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        """Mean contrastive loss over the batch.

        Args:
            output1, output2: embeddings, presumably (batch, dim); the distance is
                taken over the last dimension.
            label: one 0/1 value per pair, shape (batch,) or (batch, 1).

        Returns:
            Scalar tensor.
        """
        euclidean_distance = F.pairwise_distance(output1, output2)
        # Labels from SiameseDataset batch to (batch, 1) while the distance is (batch,);
        # without this reshape the products broadcast to (batch, batch) and the mean is wrong.
        label = label.to(euclidean_distance.dtype).reshape(euclidean_distance.shape)
        loss_contrastive = torch.mean((1 - label) * torch.pow(euclidean_distance, 2) +
                                      label * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))

        return loss_contrastive

In [None]:
# Model, loss and optimiser used by the training cell.
# SiameseKoch returns one logit per pair, hence BCE-with-logits on the 0/1 "different class" label.
# Alternatives tried earlier: ContrastiveLoss() (needs the net to return both embeddings
# instead of a logit) and optim.SGD(net.parameters(), lr=0.001).
net = SiameseKoch().to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(params=net.parameters(), lr=0.001)

In [None]:
# Held-out images used for the per-epoch validation loss
valid_dir = './data/Omniglot/changed/valid/'
valid_imagefolder = dset.ImageFolder(root=valid_dir)

In [None]:
# Validation images get only the deterministic resize. Random flips/crops are
# training-time augmentation and would perturb the images the model is scored on.
# (SiameseDataset still draws the pairs themselves at random.)
valid_dataset=SiameseDataset(root_dir=valid_dir, imageFolderDataset=valid_imagefolder,
                             transform=transforms.Compose([transforms.Resize((img_size, img_size)),
                                                           transforms.ToTensor()
                                                          ]),
                             should_invert=False)

valid_dataloader=DataLoader(valid_dataset, batch_size=batchsize, num_workers=0, shuffle=True)

In [None]:
def train(net, train_dataloader, valid_dataloader, epochs, criterion, opt=None, dev=None):
    """Train a pair-scoring network and track per-epoch train/validation loss.

    Args:
        net: model called as ``net(img1, img2)``; its output is passed to ``criterion``.
        train_dataloader: iterable of ``(img1, img2, label)`` training batches.
        valid_dataloader: iterable of ``(img1, img2, label)`` validation batches.
        epochs: number of passes over ``train_dataloader``.
        criterion: loss called as ``criterion(output, label)``.
        opt: optimizer to step. Defaults to the notebook-level ``optimizer`` so
            existing calls keep working.
        dev: device to move batches to. Defaults to the device of ``net``'s parameters.

    Returns:
        ``(train_loss, valid_loss)``: lists with one mean batch loss per epoch.
    """
    active_optimizer = opt if opt is not None else optimizer
    target_device = dev if dev is not None else next(net.parameters()).device

    train_loss = []  # mean training loss per epoch
    valid_loss = []  # mean validation loss per epoch

    for epoch in range(1, epochs + 1):
        net.train()
        train_epoch_loss = 0.0
        for i, (img1, img2, label) in enumerate(train_dataloader):
            img1, img2 = img1.to(target_device), img2.to(target_device)
            label = label.to(target_device).float()

            loss = criterion(net(img1, img2), label)

            active_optimizer.zero_grad()
            loss.backward()
            active_optimizer.step()

            # Incremental mean: the average batch loss without storing every value
            train_epoch_loss += (loss.item() - train_epoch_loss) / (i + 1)
        train_loss.append(train_epoch_loss)

        valid_epoch_loss = _mean_batch_loss(net, valid_dataloader, criterion, target_device)
        valid_loss.append(valid_epoch_loss)

        print("Epoch {}/{}\n Train loss : {} \t Valid loss {}\n"
              .format(epoch, epochs, train_epoch_loss, valid_epoch_loss))

    if epochs > 0:  # avoid ZeroDivisionError when no epoch ran
        print("Average training loss after {} epochs : {}".format(epochs, sum(train_loss) / epochs))
        print("Average validation loss after {} epochs : {}".format(epochs, sum(valid_loss) / epochs))

    return train_loss, valid_loss


def _mean_batch_loss(net, dataloader, criterion, device):
    """Mean per-batch ``criterion`` value over ``dataloader``, in eval mode without gradients."""
    net.eval()
    mean_loss = 0.0
    with torch.no_grad():
        for i, (img1, img2, label) in enumerate(dataloader):
            img1, img2 = img1.to(device), img2.to(device)
            label = label.to(device).float()  # same label dtype as the training loop
            loss = criterion(net(img1, img2), label)
            mean_loss += (loss.item() - mean_loss) / (i + 1)
    return mean_loss

In [None]:
train_losses, valid_losses = train(net, train_dataloader, valid_dataloader,
                                   epochs=epochs, criterion=criterion)

In [None]:
# Loss curves per epoch. x positions start at 1 to match the "Epoch i/N" log lines
# (plotting the bare lists would put epoch 1 at x=0).
epoch_numbers = range(1, len(train_losses) + 1)
fig, ax = plt.subplots()
ax.plot(epoch_numbers, train_losses, label="Train loss")
ax.plot(epoch_numbers, valid_losses, label="Validation loss")
ax.set(xlabel='epochs', ylabel='loss', title='Siamese network loss per epoch')
ax.legend(bbox_to_anchor=(1.1, 1.0), loc='upper left')
fig.savefig('siamese.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Unseen images used for the final accuracy estimate
test_dir = './data/Omniglot/changed/test/'
test_imagefolder = dset.ImageFolder(root=test_dir)

In [None]:
# Test images get only the deterministic resize. Random flips/crops are training-time
# augmentation and would perturb the images the model is scored on.
# (SiameseDataset still draws the pairs themselves at random.)
test_dataset=SiameseDataset(root_dir=test_dir, imageFolderDataset=test_imagefolder,
                            transform=transforms.Compose([transforms.Resize((img_size, img_size)),
                                                          transforms.ToTensor()
                                                         ]),
                            should_invert=False)

test_dataloader=DataLoader(test_dataset, batch_size=1, num_workers=0, shuffle=True)

In [None]:
def eval(net, test_dataloader, count=100, show=True, dev=None):
    """Measure pair-classification accuracy of a Siamese network.

    NOTE: this name shadows Python's built-in ``eval`` in the notebook namespace; it
    is kept so existing calls keep working.

    The network's logit goes through a sigmoid and is thresholded at 0.5. Above 0.5
    predicts label 1 (different classes); otherwise it predicts label 0 (same class).
    These match the labels produced by ``SiameseDataset``.

    Args:
        net: model called as ``net(img1, img2)``, expected to return one logit per pair.
        test_dataloader: iterable of ``(img1, img2, label)`` batches.
        count: maximum number of batches to evaluate (fewer if the loader runs out).
        show: if True, display every evaluated pair with its prediction and label.
        dev: device to run the model on. Defaults to the device of ``net``'s parameters.

    Returns:
        Fraction of evaluated pairs classified correctly, in [0, 1].

    Raises:
        ValueError: if no batch was evaluated (empty loader or ``count`` < 1).
    """
    target_device = dev if dev is not None else next(net.parameters()).device
    correct = 0
    seen = 0
    net.eval()
    with torch.no_grad():
        for batch_idx, (img1, img2, label) in enumerate(test_dataloader):
            if batch_idx >= count:
                break
            prediction = torch.sigmoid(net(img1.to(target_device), img2.to(target_device))).cpu()
            # Align shapes so the comparison is element-wise, never a broadcast
            label = label.reshape(prediction.shape)
            predicted_label = (prediction > 0.5).to(label.dtype)
            correct += int((predicted_label == label).sum().item())
            seen += label.size(0)

            if show:
                for j in range(label.size(0)):
                    pair = torch.cat((img1[j:j + 1], img2[j:j + 1]), 0)
                    imshow(torchvision.utils.make_grid(pair),
                           'Pred : {:.2f} Label : {}'.format(prediction[j].item(), label[j].item()))

    if seen == 0:
        raise ValueError("eval() evaluated no batches; check test_dataloader and count")
    # Accuracy over pairs, not a sum of per-batch accuracies
    return correct / seen

In [None]:
# Score the trained network on test pairs and report the result
acc = eval(net, test_dataloader=test_dataloader)
print("Accuracy of the network : ", acc)