In [None]:
import cv2
import os
import matplotlib.pyplot as plt
import numpy as np
import random
import torch
import torch.nn as nn
import torchvision

from torch.utils.data import DataLoader, Dataset
from torchvision import datasets
from torchvision import transforms

In [None]:
def show_sample_image(image):
    """Render a BGR image (OpenCV channel order) in a matplotlib figure.

    matplotlib reads the last axis as RGB, so the channels are reordered with
    OpenCV before plotting; the figure is shown immediately.
    """
    rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    plt.imshow(rgb)
    plt.show()

In [None]:
class SiameseNetwork(nn.Module):
    """
        Siamese Network for estimating image similarity
        Taken from https://github.com/pytorch/examples/blob/main/siamese_network/main.py
        Modified to use ResNet50 instead of ResNet18

        Both images go through the same (weight-shared) ResNet50 trunk. Their
        pooled feature vectors are concatenated, and a small MLP head maps the
        pair to a single similarity score in (0, 1) via a sigmoid.
    """
    def __init__(self):
        super().__init__()
        # randomly initialised backbone (no pretrained weights)
        resnet = torchvision.models.resnet50(weights=None)

        # width of the pooled feature vector that fed ResNet50's classifier
        self.fc_in_features = resnet.fc.in_features

        # remove last layer of ResNet50 (the classifier), keeping everything
        # up to and including global average pooling
        self.resnet = nn.Sequential(*list(resnet.children())[:-1])

        # head over the concatenated features of both images
        self.fc = nn.Sequential(
            nn.Linear(self.fc_in_features * 2, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 1),
        )

        self.sigmoid = nn.Sigmoid()

        # initialize the weights. init_weights only touches nn.Linear layers, and
        # the truncated ResNet contains none, so the backbone keeps its default
        # initialisation; only the head is re-initialised.
        self.resnet.apply(self.init_weights)
        self.fc.apply(self.init_weights)

    def init_weights(self, m):
        """Xavier-uniform weights and a 0.01 bias for ``nn.Linear`` modules.

        Meant to be passed to ``Module.apply``; non-Linear modules are left
        untouched.
        """
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            # Linear layers built with bias=False have m.bias = None
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.01)

    def forward_once(self, x):
        """Embed one batch of images into flat feature vectors of shape (batch, features)."""
        output = self.resnet(x)
        # collapse every dim after the batch dim (the pooled spatial dims);
        # unlike view(), this also works on non-contiguous and empty batches
        return torch.flatten(output, 1)

    def forward(self, input1, input2):
        """Return the similarity score for each image pair, shape (batch, 1), values in (0, 1)."""
        # get features for both images with the shared backbone
        output1 = self.forward_once(input1)
        output2 = self.forward_once(input2)

        # concatenate features along the feature dimension
        output = torch.cat((output1, output2), 1)

        # pass the concatenation to the linear layers
        output = self.fc(output)

        # squash the single logit into (0, 1)
        return self.sigmoid(output)

In [None]:
class MiniImageNet(Dataset):
    """Random image-pair dataset built on ``torchvision.datasets.ImageFolder``.

    Based on https://github.com/pytorch/examples/blob/main/siamese_network/main.py.
    Each item is ``(image_1, image_2, target)``:

    - an even ``index`` yields two distinct images from one class, with ``target == 1.0``;
    - an odd ``index`` yields images from two different classes, with ``target == 0.0``.

    Which images are drawn is random (Python's ``random`` module); ``index``
    only selects the pair type.

    Args:
        root: directory with one sub-directory of images per class.
        transform: transform applied by ``ImageFolder`` to every loaded image.
    """

    def __init__(self, root, transform):
        super().__init__()
        self.data = datasets.ImageFolder(root, transform=transform)
        self._index_classes()

    def _index_classes(self):
        """Derive per-class index bookkeeping from ``self.data`` (ImageFolder-like)."""
        # Count the classes ImageFolder actually loaded. os.listdir(root) would
        # also count stray files (e.g. .DS_Store) and shift every class range.
        self.num_classes = len(self.data.classes)

        # Dataset indices grouped by class label, so classes need not all hold
        # the same number of images.
        self._class_indices = [[] for _ in range(self.num_classes)]
        for sample_index, label in enumerate(self.data.targets):
            self._class_indices[label].append(sample_index)

        # Kept for code that reads this attribute. It is only exact when every
        # class holds the same number of images (600 each in mini-ImageNet).
        self.num_samples_per_class = len(self) // self.num_classes

    def __getitem__(self, index):
        """Return a random ``(image_1, image_2, target)`` pair; the parity of ``index`` picks same/different class."""
        selected_class = random.randrange(self.num_classes)
        selected_indices = self._class_indices[selected_class]

        if index % 2 == 0:
            # get indices of 2 distinct images from the same class
            index_1, index_2 = random.sample(selected_indices, 2)
            target = torch.tensor(1, dtype=torch.float)
        else:
            # One image from the selected class, the other drawn uniformly from
            # all images of the other classes. Rejection sampling avoids
            # rebuilding an O(len(dataset)) candidate list on every call.
            if len(selected_indices) == len(self):
                raise ValueError("negative pairs require images from at least two classes")
            index_1 = random.choice(selected_indices)
            index_2 = random.randrange(len(self))
            while self.data.targets[index_2] == selected_class:
                index_2 = random.randrange(len(self))
            target = torch.tensor(0, dtype=torch.float)

        # ImageFolder items are (image, label); only the image is needed
        image_1 = self.data[index_1][0]
        image_2 = self.data[index_2][0]

        return image_1, image_2, target

    def __len__(self):
        """Number of images in the underlying ImageFolder (one pair is drawn per index)."""
        return len(self.data)

In [None]:
# Seed every RNG the pipeline touches: MiniImageNet draws its image pairs with
# Python's `random`, and torch initialises the model weights, so runs are
# reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

train_directory = './data/train'

# ToTensor converts each loaded image to a float tensor in [0, 1], (C, H, W) layout
transform = transforms.Compose([
    transforms.ToTensor()
])

# random same/different-class image pairs drawn from the training folders
train_data = MiniImageNet(train_directory, transform=transform)

# NOTE(review): DataLoader defaults to batch_size=1. Raising it requires every
# image to have the same size (there is no Resize in `transform`); presumably
# the mini-ImageNet images are pre-resized, but confirm before changing it.
train_dataloader = DataLoader(train_data)

In [None]:
# Validation pairs reuse `transform` from the training cell, so both splits
# are preprocessed identically.
val_directory = './data/val'
val_data = MiniImageNet(root=val_directory, transform=transform)
val_dataloader = DataLoader(dataset=val_data)

In [None]:
# Instantiate the Siamese network. Its ResNet50 trunk is built with
# weights=None, i.e. randomly initialised rather than pretrained.
model = SiameseNetwork()