In [10]:
import torch
from torch import nn
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18, ResNet18_Weights
from torch.optim import Adam 

from statistics import mean
from random import choices, choice, seed
import matplotlib.pyplot as plt
import numpy as np

# Siamese Neural Network for object classification
We use a resnet as the backbone of our Siamese Neural Network (SNN) with a simple feed 
forward neural net on top as a classifier. Our SNN takes two images and outputs their 
similarity where numbers close to 0 indicate similar images and numbers close to 1 indicate 
dissimilar images. 

In [11]:
class SiameseNN(nn.Module): 
    def __init__(self, backbone=None, distance_dim=4096): 
        super(SiameseNN, self).__init__()
        self.distance_dim = distance_dim
        if backbone is None: 
            self.backbone = resnet18(weights=ResNet18_Weights.DEFAULT)
        else: 
            self.backbone = backbone
        self.sim = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(1000, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),

            nn.Dropout(p=0.5),
            nn.Linear(512, 64),
            nn.BatchNorm1d(64),
            nn.Sigmoid(),
            nn.Dropout(p=0.5),

            nn.Linear(64, 1),
            nn.Sigmoid(),
        )

    def forward(self, base_im, cmp_im): 
        out_1 = self.backbone(base_im)
        out_2 = self.backbone(cmp_im)
        return self.sim(torch.abs(out_1 - out_2))

# Triplet Dataset and Loss
We will be training our Siamese Neural Net on the CIFAR100 dataset. Each training sample 
consists of three images. Two images will be of the same class called the anchor and 
positive image, and the last will be from a different class called the negative image. 
Denote the anchor image by $a$, the positive image by $p$, and the negative image by $n$. 
We also set a hyperparameter $m$ as the margin
Our loss is given by: 

$$L(a, n, p) = \max(SNN(a, p) - SNN(a, n) + m, 0)$$

## Observations about the loss
- The loss is minimized when the SNN predicts the anchor and positive image as similar 
- The loss is minimized when the SNN predicts the anchor and negative image as dissimilar
- The loss is 0 when the difference in similarity is greater than the margin



In [12]:
class Triplet_Loss: 
    def __init__(self, margin) -> None:
        self.margin = margin

    def __call__(self, dist_same, dist_diff):
        return torch.mean(torch.clamp(dist_same - dist_diff + self.margin,min=0.0))

# Training and Validation
The CIFAR100 dataset has 100 different classes. We will take a subset of these classes
and use them to train the model. We will test how well the model generalizes by using the 
rest of the classes to validate. Ideally we will see that the model can tell that images 
of the same class are similar, on classes that weren't seen in the training set. 

## Specifics of the Dataset
We create a sample in the training set by the following procedure: 
1. split the 100 different classes for training or validation 
2. sample a class from the list of training classes
3. sample a class from the list of training classes which is not the class sampled in (2)
4. randomly sample two images from class chosen in (2). These are the anchor and positive image. 
5. randomly sample an image from the class chosen in (3). This is the negative image. 

In [13]:
class Triplet_Dataset(torch.utils.data.Dataset): 
    def __init__(self, train_classes, num_samples, seed, train=True) -> None:
        super().__init__()
        
        transform = transforms.Compose(
        [transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

        self.fullset = torchvision.datasets.CIFAR100(root='../data/processed', train=True,
                                                    download=True, transform=transform)
        
        self.train_classes = train_classes
        self.num_samples = num_samples
        self.seed = seed
        
        self._create_dataset(train=train)

    def __len__(self): 
        return len(self.anchors)
    
    def __getitem__(self, idx): 
        return self.anchors[idx, ...], self.im_positive[idx, ...], self.im_negative[idx, ...]
    
    def create_label_map(self): 
        seed(self.seed)
        labels = {}
        for i in range(len(self.fullset)): 
            if self.fullset[i][1] in labels: 
                labels[self.fullset[i][1]].append(i)
            else: 
                labels[self.fullset[i][1]] = [i]
        return labels

    def train_test_split(self): 
        train_classes = choices(list(range(100)), k=self.train_classes)
        valid_classes = [i for i in range(100) if i not in train_classes]
        return train_classes, valid_classes

    def _create_dataset(self, train=True): 
        anchors = torch.zeros((self.num_samples, 3, 32, 32), dtype=torch.float)
        im_positive = torch.zeros((self.num_samples, 3, 32, 32), dtype=torch.float)
        im_negative = torch.zeros((self.num_samples, 3, 32, 32), dtype=torch.float)
        train_classes, valid_classes = self.train_test_split()
        label_map = self.create_label_map()
        potential_classes = train_classes
        if not train: 
            potential_classes = valid_classes
        for idx in range(self.num_samples): 
            true_class = choice(potential_classes)
            false_class = choice(potential_classes)
            while false_class == true_class: 
                false_class = choice(potential_classes)
            anchor_idx = choice(label_map[true_class])
            pos_idx = choice(label_map[true_class])
            neg_idx = choice(label_map[false_class])
            anchors[idx, :, :, :] = self.fullset[anchor_idx][0]
            im_positive[idx, :, :, :] = self.fullset[pos_idx][0]
            im_negative[idx, :, :, :] = self.fullset[neg_idx][0]
        self.anchors, self.im_positive, self.im_negative = anchors, im_positive, im_negative

In [14]:
def eval(loader, model, loss):
    same_score = []
    diff_score = []
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    with torch.no_grad(): 
        for anchor, pos, neg in loader: 
            anchor, pos, neg = anchor.to(device), pos.to(device), neg.to(device)
            dist_pos = model(anchor, pos)
            dist_neg = model(anchor, neg)
            
            same = dist_pos.reshape(-1).tolist()
            same_score.extend(same)

            diff = dist_neg.reshape(-1).tolist()
            diff_score.extend(diff)
    return mean(same_score), mean(diff_score)

In [None]:
config = {
        'lr': .00001, 
        'batchsize': 32, 
        'betas': (.9, .99), 
        'distance_dim': 4096, 
        'epochs': 100,
        'margin': .9,
        'train_classes': 75, 
        'train_samples': 50_000, 
        'val_samples': 10_000
    }

In [18]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'using gpu: {torch.cuda.is_available()}')

model = SiameseNN(distance_dim=config['distance_dim']).to(device)
optimizer = Adam(params=model.parameters(), lr=config['lr'], betas=config['betas'])
    
    
trainset = Triplet_Dataset(config['train_classes'], config['train_samples'], 0, train=True)
valset = Triplet_Dataset(config['train_classes'], config['val_samples'], 0, train=False)

train_loader = torch.utils.data.DataLoader(trainset, batch_size=config['batchsize'],
                                        shuffle=True, num_workers=8)
val_loader = torch.utils.data.DataLoader(valset, batch_size=config['batchsize'],
                                        shuffle=True, num_workers=8)

loss = Triplet_Loss(margin=config['margin'])

model.train()
for e in range(config['epochs']): 
    errs = []
    same_score = []
    diff_score = []
    acc = 0
    model.train()
    for anchor, pos, neg  in train_loader: 
        anchor, pos, neg = anchor.to(device), pos.to(device), neg.to(device)
        dist_pos = model(anchor, pos)
        dist_neg = model(anchor, neg)
        
        err = loss(dist_pos, dist_neg)

        same = dist_pos.reshape(-1).tolist()
        same_score.extend(same)

        diff = dist_neg.reshape(-1).tolist()
        diff_score.extend(diff)

        err.backward()
        optimizer.step()
        optimizer.zero_grad()
        errs.append(err.item())
    val_same_mean, val_diff_mean = eval(val_loader, model, loss)
    print(f"""epoch: {e}, train_error: {round(mean(errs), 4)}, train_dist_pos: {round(mean(same_score), 3)},\
 train_dist_neg: {round(mean(diff_score), 3)}, val_dist_pos: {round(val_same_mean, 3)},\
 val_dist_neg: {round(val_diff_mean, 3)}""")

using gpu: True
Files already downloaded and verified
Files already downloaded and verified
epoch: 0, train_error: 0.9005, train_dist_pos: 0.463, train_dist_neg: 0.462, val_dist_pos: 0.46, val_dist_neg: 0.461
epoch: 1, train_error: 0.8975, train_dist_pos: 0.466, train_dist_neg: 0.468, val_dist_pos: 0.494, val_dist_neg: 0.481
epoch: 2, train_error: 0.8557, train_dist_pos: 0.456, train_dist_neg: 0.5, val_dist_pos: 0.475, val_dist_neg: 0.455
epoch: 3, train_error: 0.7789, train_dist_pos: 0.432, train_dist_neg: 0.553, val_dist_pos: 0.557, val_dist_neg: 0.525
epoch: 4, train_error: 0.7057, train_dist_pos: 0.375, train_dist_neg: 0.569, val_dist_pos: 0.365, val_dist_neg: 0.334
epoch: 5, train_error: 0.6291, train_dist_pos: 0.319, train_dist_neg: 0.591, val_dist_pos: 0.415, val_dist_neg: 0.389
epoch: 6, train_error: 0.5844, train_dist_pos: 0.285, train_dist_neg: 0.605, val_dist_pos: 0.263, val_dist_neg: 0.244
epoch: 7, train_error: 0.5676, train_dist_pos: 0.28, train_dist_neg: 0.62, val_dist_p

KeyboardInterrupt: 

# Issues
The main issure is generalization. The training loss is decreasing, but the validation loss is not. We hope that the metric between two images generalizes to classes we have never seen before but this is certainly not the case. 

## Observations
- validation loss fluctuates wildly (maybe weights are changing too rapidly)

## Things I have tried
- playing around with the learning rate. 
- Tried different loss functions (binary cross entropy and contrastive loss)
- Increased data (CIFAR10 to CIFAR100 and increased number of samples to 50_000)
- various batchsizes (64, 32, 16)
