# Training the Siamese Network
A Siamese network is a neural network architecture used for image similarity estimation. It consists of identical sub-networks that share the same weights, each taking an input image and producing a feature vector. The feature vectors are then compared using a contrastive loss function or, more recently, a triplet loss function to determine the similarity between the images. With contrastive loss, the network is trained on image pairs, where positive pairs come from the same class and negative pairs come from different classes; with triplet loss, each training example consists of an anchor image, a positive image from the same class, and a negative image from a different class. The training process optimizes the network's parameters to minimize the chosen loss.

In [1]:
%pip install -r "/kaggle/input/requirements-siamese/requirements.txt"

Note: you may need to restart the kernel to use updated packages.


In [2]:
from __future__ import print_function
import argparse, random, copy
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision import transforms
from torch.optim.lr_scheduler import StepLR

In [5]:
class Config:
    """Static settings shared by the dataset wrappers and the training code."""

    # Image root folders; both splits currently point at the same directory.
    training_dir = "/kaggle/input/classify-by-brand-dataset-fixed/classify_by_brand/classify_by_brand_dataset"
    testing_dir = "/kaggle/input/classify-by-brand-dataset-fixed/classify_by_brand/classify_by_brand_dataset"

    # Annotation files (read with pandas.read_csv by CustomImageDataset).
    annotations_file = "/kaggle/working/training_dataset.csv"
    test_annotations_file = '/kaggle/working/test_dataset.csv'

    train_batch_size = 64
    train_number_epochs = 20

    # Preprocessing pipeline applied to every image before it reaches the model.
    transform = transforms.Compose([
        # Scale the image so that its shorter side is 256 pixels.
        transforms.Resize(256),
        # Keep the central 224x224 region.
        transforms.CenterCrop(224),
        # PIL image -> float tensor.
        transforms.ToTensor(),
        # Per-channel normalisation (the widely used ImageNet statistics).
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ])

    # Number of distinct class labels, assumed to be encoded as 0..NUM_CLASSES-1.
    NUM_CLASSES = 10

In [46]:
# Contrastive Loss definition
# In this project Triplet loss is used which has proved to be more effective
# for face recognition
class ContrastiveLoss(torch.nn.Module):
    """
    Contrastive loss over pairs of embeddings.

    Follows Hadsell, Chopra & LeCun (2006):
    http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf

    A label of 0 marks a similar pair (penalised by its squared distance);
    a label of 1 marks a dissimilar pair (penalised only while it is closer
    than ``margin``).
    """

    def __init__(self, margin=2.0):
        super().__init__()
        # Distance beyond which dissimilar pairs stop contributing to the loss.
        self.margin = margin

    def forward(self, output1, output2, label):
        """Return the mean contrastive loss of the batch as a scalar tensor."""
        distance = F.pairwise_distance(output1, output2)

        # Similar pairs: pull the embeddings together.
        pull_term = (1 - label) * distance.pow(2)

        # Dissimilar pairs: push them apart until they are `margin` away.
        shortfall = torch.clamp(self.margin - distance, min=0.0)
        push_term = label * shortfall.pow(2)

        return torch.mean(pull_term + push_term)

In [6]:
import os
import pandas as pd
from PIL import Image

# defining the dataset
class CustomImageDataset(Dataset):
    """Image dataset described by a CSV annotation file.

    The CSV is read with ``pandas.read_csv``. Column 0 of each row is the
    image path relative to ``img_dir`` and column 1 is the label.

    Args:
        img_dir: Directory that the paths in the CSV are relative to.
        label_dir: Path of the CSV annotation file (despite the name, this is
            a file path, not a directory).
        transform: Optional callable applied to the loaded PIL image.
        target_transform: Optional callable applied to the label.
    """

    def __init__(self, img_dir, label_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(label_dir)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of annotated images."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Load the ``idx``-th image and return ``(image, label)``."""
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])

        # Image.open is lazy and keeps the file handle open; convert() forces
        # the pixel data to be read so the handle can be closed right away.
        # Converting to RGB also guarantees three channels, which grayscale,
        # palette or RGBA files would otherwise not provide to a 3-channel
        # normalisation step.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert("RGB")

        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

In [18]:
"""
    The following is a piece of code from the pytorch Siamese Network example
    which has been adapted to fit the needs of this project
    The original implementation can be found at https://github.com/pytorch/examples/tree/main/siamese_network
"""

class SiameseNetwork(nn.Module):
    """
        Siamese embedding network for image similarity estimation.

        A single ResNet-18 trunk (see `"Deep Residual Learning for Image
        Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_) with its final
        classification layer removed is applied to every input, so all
        branches share the same weights. `forward` returns one flattened
        embedding per input, ready for a triplet loss in the spirit of
        `"FaceNet" <https://arxiv.org/pdf/1503.03832.pdf>`_.

        A small fully connected comparison head (`fc`) and a sigmoid are also
        created here, but `forward` does not use them.
    """
    def __init__(self):
        super().__init__()

        # untrained ResNet-18 used as the feature extractor
        backbone = torchvision.models.resnet18(weights=None)

        # width of the features that feed ResNet's own classifier
        self.fc_in_features = backbone.fc.in_features

        # keep every child up to and including the average pool; only the
        # final fully connected classifier is dropped
        trunk_layers = list(backbone.children())[:-1]
        self.resnet = nn.Sequential(*trunk_layers)

        # comparison head over two concatenated feature vectors
        # (not called by `forward`)
        self.fc = nn.Sequential(
            nn.Linear(2 * self.fc_in_features, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 1),
        )

        self.sigmoid = nn.Sigmoid()

        # custom initialisation; only nn.Linear layers are affected
        for submodule in (self.resnet, self.fc):
            submodule.apply(self.init_weights)

    def init_weights(self, m):
        """Xavier-initialise the weight of a Linear layer and set its bias to 0.01."""
        if not isinstance(m, nn.Linear):
            return
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)

    def forward_once(self, x):
        """Embed one batch and flatten everything after the batch dimension."""
        features = self.resnet(x)
        return features.view(features.shape[0], -1)

    def forward(self, input1, input2, input3):
        """Return the embeddings of the three inputs as a 3-tuple, in order."""
        return tuple(
            self.forward_once(batch) for batch in (input1, input2, input3)
        )

class APP_MATCHER(Dataset):
    """Dataset that yields random ``(anchor, positive, negative)`` image triplets.

    Images and labels come from a `CustomImageDataset`; the labels are read
    from its ``class_label`` column and are expected to be the integers
    ``0 .. Config.NUM_CLASSES - 1``.

    Args:
        root: Unused; accepted so existing call sites keep working.
        train: Truthy selects the training directory/annotations from
            `Config`, falsy selects the testing ones.
        download: Unused; accepted so existing call sites keep working.

    Raises:
        ValueError: If the annotations cannot produce a triplet (see
            `group_examples`).
    """

    def __init__(self, root, train, download=False):
        super(APP_MATCHER, self).__init__()

        if train:
            image_dir = Config.training_dir
            label_file = Config.annotations_file
        else:
            # The evaluation split has its own directory setting.
            image_dir = Config.testing_dir
            label_file = Config.test_annotations_file

        self.dataset = CustomImageDataset(image_dir, transform=Config.transform, label_dir=label_file)
        self.data = self.dataset.img_labels

        # Number of classes to sample from; group_examples and __getitem__
        # both rely on this single value so they cannot disagree.
        self.num_classes = Config.NUM_CLASSES

        self.group_examples()

    def group_examples(self):
        """
            Index the examples by class so triplets can be sampled quickly.

            Fills `grouped_examples` (class label -> array of row indices) for
            every label in ``range(self.num_classes)`` and records which
            classes can act as anchor/positive source (at least two examples)
            and which can supply a negative (at least one example).

            Raises:
                ValueError: If no class has two examples, or fewer than two
                    classes have any example.
        """

        # get the targets from dataset
        labels = np.asarray(self.dataset.img_labels['class_label'])

        # group examples based on class
        self.grouped_examples = {
            label: np.where(labels == label)[0]
            for label in range(self.num_classes)
        }

        # An anchor needs a *different* image of the same class as positive.
        self._anchor_classes = [
            label for label, rows in self.grouped_examples.items()
            if rows.shape[0] >= 2
        ]
        # A negative only needs one image of some other class.
        self._populated_classes = [
            label for label, rows in self.grouped_examples.items()
            if rows.shape[0] >= 1
        ]

        if not self._anchor_classes or len(self._populated_classes) < 2:
            raise ValueError(
                "Cannot build triplets: need at least one class with two or "
                "more examples and at least two classes with examples."
            )

    def __len__(self):
        """Return the number of rows in the annotation table."""
        return self.data.shape[0]

    def __getitem__(self, index):
        """
            Draw one random triplet of images.

            `index` is ignored: every call samples afresh. The anchor and the
            positive are two different images of one randomly chosen class;
            the negative is a random image of a different class.

            Returns:
                ``(anchor, positive, negative)`` images as produced by the
                underlying dataset.
        """

        # pick the class shared by the anchor and the positive
        anchor_class = random.choice(self._anchor_classes)
        same_class_rows = self.grouped_examples[anchor_class]

        # two distinct positions inside that class, drawn without replacement
        anchor_pos, positive_pos = random.sample(range(same_class_rows.shape[0]), 2)

        # pick a different, non-empty class for the negative
        negative_class = random.choice(
            [label for label in self._populated_classes if label != anchor_class]
        )
        other_class_rows = self.grouped_examples[negative_class]
        negative_pos = random.randrange(other_class_rows.shape[0])

        # load the three images; the labels returned by the dataset are dropped
        anchor = self.dataset[same_class_rows[anchor_pos]][0]
        positive = self.dataset[same_class_rows[positive_pos]][0]
        negative = self.dataset[other_class_rows[negative_pos]][0]

        return anchor, positive, negative


def train(args, model, device, train_loader, optimizer, epoch):
    """Run one training epoch with a triplet margin loss.

    Args:
        args: Settings object; `log_interval` and `dry_run` are read.
        model: Module called with three image batches and returning three
            embeddings.
        device: Device the batches are moved to.
        train_loader: Iterable of ``(anchor, positive, negative)`` batches
            that also exposes ``.dataset``.
        optimizer: Optimizer stepped once per batch.
        epoch: Epoch number, used only in the progress message.
    """
    model.train()

    triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7)

    for step, batch in enumerate(train_loader):
        anchors, positives, negatives = (tensor.to(device) for tensor in batch)

        optimizer.zero_grad()
        embeddings = model(anchors, positives, negatives)
        loss = triplet_loss(embeddings[0], embeddings[1], embeddings[2])
        loss.backward()
        optimizer.step()

        # Only every `log_interval`-th batch reports progress.
        if step % args.log_interval != 0:
            continue

        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, step * len(anchors), len(train_loader.dataset),
            100. * step / len(train_loader), loss.item()))

        # A dry run stops right after its first progress report.
        if args.dry_run:
            break


def test(model, device, test_loader):
    """Evaluate the model on triplets and print loss and ranking accuracy.

    The reported loss is the triplet margin loss averaged over all evaluated
    triplets. A triplet counts as correct when the anchor embedding is
    strictly closer to the positive embedding than to the negative one.

    Args:
        model: Module called with three image batches and returning three
            embeddings.
        device: Device the batches are moved to.
        test_loader: Iterable of ``(anchor, positive, negative)`` batches.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    seen = 0

    # reduction='sum' keeps the average exact even when the last batch is
    # smaller than the others.
    criterion = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7, reduction='sum')

    with torch.no_grad():
        for (images_1, images_2, images_3) in test_loader:
            images_1, images_2, images_3 = images_1.to(device), images_2.to(device), images_3.to(device)
            anchor, positive, negative = model(images_1, images_2, images_3)
            loss_sum += criterion(anchor, positive, negative).item()  # sum up per-triplet loss

            # Same distance definition as the loss above.
            positive_distance = F.pairwise_distance(anchor, positive, p=2, eps=1e-7)
            negative_distance = F.pairwise_distance(anchor, negative, p=2, eps=1e-7)
            correct += int((positive_distance < negative_distance).sum().item())
            seen += anchor.size(0)

    # Guard against an empty loader so the summary can still be printed.
    denominator = max(seen, 1)
    test_loss = loss_sum / denominator

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, seen,
        100. * correct / denominator))


# Keep pytest from collecting this helper as a test case because of its name.
test.__test__ = False


def main():
    """Train the Siamese network on triplets and optionally save its weights.

    Settings are hard-coded in the local `Args` class. The function picks a
    device, builds the triplet dataset and loader, trains for `args.epochs`
    epochs with Adadelta and a per-epoch StepLR decay, and finally writes the
    model's ``state_dict`` to ``siamese_network.pt`` when `save_model` is set.
    """
    # Training settings
    class Args(argparse.Namespace):
        batch_size=64
        test_batch_size=64
        epochs=20
        lr=0.005
        gamma=0.7
        no_cuda=False
        no_mps=True
        dry_run=False
        seed=1
        log_interval=10
        save_model=True

    args=Args()

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    use_mps = not args.no_mps and torch.backends.mps.is_available()

    # Seed torch (weight initialisation) and Python's `random` module, which
    # APP_MATCHER uses to draw triplets; seeding only torch leaves the
    # sampling unseeded when the loader runs in the main process.
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    if use_cuda:
        device = torch.device("cuda")
    elif use_mps:
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    train_kwargs = {'batch_size': args.batch_size}
    test_kwargs = {'batch_size': args.test_batch_size}
    if use_cuda:
        # Extra loader options are only applied when running on CUDA.
        cuda_kwargs = {'num_workers': 1,
                       'pin_memory': True,
                       'shuffle': True}
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)

    train_dataset = APP_MATCHER('../data', train=True, download=True)
    # test_dataset = APP_MATCHER('../data', train=False)
    train_loader = torch.utils.data.DataLoader(train_dataset,**train_kwargs)
    # test_loader = torch.utils.data.DataLoader(test_dataset, **test_kwargs)

    model = SiameseNetwork().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    # Multiply the learning rate by `gamma` after every epoch.
    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        # test(model, device, test_loader)
        scheduler.step()

    if args.save_model:
        torch.save(model.state_dict(), "siamese_network.pt")

In [19]:
# Run the training pipeline defined in main().
main()

