In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import torchvision

import matplotlib.pyplot as plt
import PIL
import numpy
import albumentations
import random
import torch
from albumentations.pytorch.transforms import ToTensor
import numpy as np
from PIL import Image
import torchvision.transforms as transforms

from sklearn.metrics import accuracy_score, confusion_matrix, classification_report

In [23]:
class SiameseVanillaDataset():
    '''
    Author @Pranav Pandey, Date: 04_03_2020.
    This class is for loading dataset from a given folder in pairs with a label given to the pair of images;
    if they are simillar (1) or different (0) to each other.
    '''

    def __init__(self, imageFolderDataset, img_height, img_width, mean, std, no_template=False, transform=False):
        self.imageFolderDataset = imageFolderDataset
        self.no_template = no_template

        if transform:
            self.aug = albumentations.Compose([
                albumentations.Resize(img_height, img_width, always_apply=True),
                albumentations.ShiftScaleRotate(shift_limit=0.0625,
                                scale_limit=0.1,
                                rotate_limit=5,
                                p=0.9),
                albumentations.Normalize(mean, std, always_apply= True),
                ToTensor()
            ])
        else:
            self.aug = albumentations.Compose([
                albumentations.Resize(img_height, img_width, always_apply=True),
                albumentations.Normalize(mean, std, always_apply= True),
                ToTensor()
            ])
            self.aug_2 = transforms.Compose([transforms.Resize((520,200)),
                                            transforms.ToTensor(),
                                            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                                            ])
    
    def __getitem__(self,index):
        img0_tuple = random.choice(self.imageFolderDataset.imgs)
        #we need to make sure approx 50% of images are in the same class
        should_get_same_class = random.randint(0,1) 
        if should_get_same_class:
            while True:
                #keep looping till the same class image is found
                img1_tuple = random.choice(self.imageFolderDataset.imgs) 
                if img0_tuple[1]==img1_tuple[1]:
                    break
        else:
            while True:
                #keep looping till a different class image is found
                
                img1_tuple = random.choice(self.imageFolderDataset.imgs) 
                if img0_tuple[1] !=img1_tuple[1]:
                    break

        img0_raw = Image.open(img0_tuple[0]).convert(mode='RGB')
        img0 = self.aug_2(img0_raw)
        if(self.no_template == True):
            img1_raw = img0_raw.filter(PIL.ImageFilter.GaussianBlur(radius=5))
            img1 = self.aug_2(img1_raw)
        else:
            img1_raw = Image.open(img1_tuple[0]).convert(mode='RGB')
            img1 = self.aug_2(img1_raw)
        # img0 = torch.from_numpy(np.moveaxis(img0 / (255.0 if img0.dtype == np.uint8 else 1), -1, 0).astype(np.float32))
        # img1 = torch.from_numpy(np.moveaxis(img1 / (255.0 if img1.dtype == np.uint8 else 1), -1, 0).astype(np.float32))
        
        return img0, img1 , torch.from_numpy(np.array([int(should_get_same_class)],dtype=np.float32))
    
    def __len__(self):
        return len(self.imageFolderDataset.imgs)

In [17]:
class SiameseVanilla(nn.Module):
    def __init__(self):
        super(SiameseVanilla, self).__init__()
        self.Convolve = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=0),
            nn.BatchNorm2d(num_features=32),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=128, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2,2), stride=2),
            nn.Conv2d(in_channels=128, out_channels=32, kernel_size=3, stride=1, padding=0),
            nn.BatchNorm2d(num_features=32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2,2), stride=2),
            nn.Conv2d(in_channels=32, out_channels=16, kernel_size=3, stride=1, padding=0),
            nn.BatchNorm2d(num_features=16),
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels=2, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
        )
        self.Linear = nn.Sequential(
            nn.Linear(10912,256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 24)
        )
    def forward(self, x_1, x_2):
        '''
        Keeping the passing of 2 inputs through the network explicit here for the sake of transperancy
        '''
        x_1 = self.Convolve(x_1)
        x_1 = x_1.reshape(x_1.size()[0], -1)
        x_1 = self.Linear(x_1)

        x_2 = self.Convolve(x_2)
        x_2 = x_2.reshape(x_2.size()[0], -1)
        x_2 = self.Linear(x_2)
        return x_1, x_2

In [24]:
imgFD_test = torchvision.datasets.ImageFolder(root="/home/transpacks/Repos/Siamese-Network/test/")
imgFD_train = torchvision.datasets.ImageFolder(root="/home/transpacks/Repos/Siamese-Network/input/")

test_dataset = SiameseVanillaDataset(
    imageFolderDataset = imgFD_test,
    img_height = 520,
    img_width = 200,
    mean = 0,
    std = 0,
    no_template=True
)
test_loader = torch.utils.data.DataLoader(
    dataset = test_dataset,
    batch_size = 1,
    shuffle = True,
    num_workers=4,
    drop_last=True)

train_dataset = SiameseVanillaDataset(
    imageFolderDataset = imgFD_train,
    img_height = 520,
    img_width = 200,
    mean = 0,
    std = 0,
    no_template=True
)
train_loader = torch.utils.data.DataLoader(
    dataset = train_dataset,
    batch_size = 1,
    shuffle = True,
    num_workers=4,
    drop_last=True)

In [21]:
model = SiameseVanilla()
model.load_state_dict(torch.load("/home/transpacks/Repos/Siamese-Network/results/Random_Test/model_95.bin"))
model.to("cuda:0")

SiameseVanilla(
  (Convolve): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(32, 128, kernel_size=(1, 1), stride=(1, 1))
    (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): MaxPool2d(kernel_size=(2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1))
    (8): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU()
    (10): MaxPool2d(kernel_size=(2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)
    (11): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1))
    (12): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (13): ReLU()
    (14): Conv2d(16, 2, kernel_size=(3, 3), stride=(1, 1))
    (15): ReLU()
  )
  (Linear): Sequential(
    (0): Linea

## Training Accuracy

In [25]:
dataiter = iter(train_loader)
accuracy = 0
preds = []
labels = []

for i, (x0,x1,label) in enumerate(dataiter):
    #concatenated = torch.cat((x0,x1),0)
    
    output1,output2 = model(x0.to("cuda:0"),x1.to("cuda:0"))
    euclidean_distance = F.pairwise_distance(output1, output2)
    labels.append(label.item())
    if (euclidean_distance <= 0.5):
        preds.append(0)
    else:
        preds.append(1)

print(accuracy_score(labels, preds))

0.5036019062396099


In [26]:
print(confusion_matrix(labels, preds))

[[1333 3179]
 [1300 3211]]


## Testing

In [26]:
dataiter = iter(test_loader)
x0,_,label1 = next(dataiter)

for i in range(10):
    _,x1,label2 = next(dataiter)
    concatenated = torch.cat((x0,x1),0)
    
    output1,output2 = model(x0.to("cuda:0"),x1.to("cuda:0"))
    euclidean_distance = F.pairwise_distance(output1, output2)
    print(euclidean_distance.item())
    print("Label 1: {0} and Label 2: {1}".format(label1.item(), label2.item()))
    #plt.imshow(torchvision.utils.make_grid(concatenated),'Similarity: {:.2f}'.format(euclidean_distance.item()))

0.7145547270774841
Label 1: 0.0 and Label 2: 0.0
0.8847765922546387
Label 1: 0.0 and Label 2: 1.0
0.5884808897972107
Label 1: 0.0 and Label 2: 1.0
0.6108886003494263
Label 1: 0.0 and Label 2: 1.0
0.4660487174987793
Label 1: 0.0 and Label 2: 1.0
0.8675166964530945
Label 1: 0.0 and Label 2: 0.0
0.5574768781661987
Label 1: 0.0 and Label 2: 0.0
0.5961430668830872
Label 1: 0.0 and Label 2: 1.0
0.6219456195831299
Label 1: 0.0 and Label 2: 0.0
0.9052636027336121
Label 1: 0.0 and Label 2: 0.0
