Imports

In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F


import math
import numpy as np

import matplotlib.pyplot as plt

Device

In [None]:
# Fixed seed so shuffling, pair sampling and weight initialisation are reproducible
# across Restart & Run All (torch.manual_seed seeds the generators on all devices).
SEED = 0
torch.manual_seed(SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEVICE

device(type='cuda')

**Custom Dataset:**

>    Input: a labelled image dataset (CIFAR-10, MNIST, etc.)  
     Returns: a dataset of random image pairs, labelled 1 if both images share a class and 0 otherwise

In [None]:
class siameseDataset(Dataset):
    """Turn a labelled image dataset into a dataset of random image pairs.

    Each item is ``(pair, flag)`` where ``pair = torch.stack([img1, img2])`` and
    ``flag`` is 1 when the two images' labels are equal and 0 otherwise.

    Pairs come from two independently shuffled ``DataLoader(batch_size=1)``
    iterators over ``dataset``, so the ``idx`` given to ``__getitem__`` is ignored.
    NOTE(review): the pairing state lives on this object, so using it with
    ``DataLoader(num_workers>0)`` is presumably unsafe (each worker would hold its
    own copy of the iterators) -- confirm before enabling workers.
    """

    def __init__(self, dataset):
        # original dataset
        self.trainset = dataset
        self.testset = None  # placeholder, not used yet
        self.len = len(self.trainset)

        # two independent shuffled single-sample loaders, one per slot of the pair
        self.trainloader1 = DataLoader(self.trainset, batch_size=1, shuffle=True)
        self.trainloader2 = DataLoader(self.trainset, batch_size=1, shuffle=True)
        self.reset_iterators()

    def __len__(self):
        # same as the length of the original dataset
        return self.len

    def __getitem__(self, idx):
        # Each iterator yields exactly self.len samples per pass. Restart both once
        # that many have been drawn, and count every draw -- including the first one
        # after a restart -- so a pass never asks for self.len + 1 samples. Asking for
        # one more raises StopIteration here, which a surrounding DataLoader loop may
        # treat as the end of the epoch.
        if self.counter >= self.len:
            self.reset_iterators()
        self.counter += 1

        img1, label1 = next(self.iter1)
        img2, label2 = next(self.iter2)

        # drop the leading batch dim added by batch_size=1
        output = torch.stack([img1[0], img2[0]])

        if label1 == label2:
            return output, 1
        return output, 0

    def reset_iterators(self):
        """Start fresh shuffled passes over the dataset and zero the draw counter."""
        self.iter1 = iter(self.trainloader1)
        self.iter2 = iter(self.trainloader2)
        self.counter = 0

**Embedding network:**

> Maps a CIFAR image (3x32x32) to a 10-dimensional embedding vector

In [None]:
class CNN(nn.Module):
    """LeNet-style embedding network: two conv/pool stages, then three linear layers to 10 outputs."""

    def __init__(self):
        super().__init__()
        # convolutional feature extractor
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)

        # fully connected head; 16 * 5 * 5 = 400 flattened features
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        # conv -> ReLU -> 2x2 max-pool, twice
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), 2)

        hidden = x.view(-1, self.num_flat_features(x))

        # two ReLU hidden layers, then an unactivated output layer
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.fc3(hidden)

    def num_flat_features(self, x):
        """Product of all dimensions of ``x`` except the first (batch) one."""
        return x.size()[1:].numel()

**Siamese network:**

> Maps a pair of images to a pair of embeddings, using the same (weight-shared) embedding network for both images

In [None]:
class SiameseNet(nn.Module):
    """Apply one shared embedding network to both images of each pair.

    ``forward`` reads the two images of a pair as ``inputs[:, 0]`` and
    ``inputs[:, 1]`` and returns the tuple of their embeddings, in that order.
    """

    def __init__(self, embedding_net):
        super().__init__()
        self.embedding_net = embedding_net

    def forward(self, inputs):
        # the same module (same weights) embeds both slots of the pair
        return tuple(self.embedding_net(inputs[:, slot]) for slot in (0, 1))

    def get_embedding(self, x):
        """Embed a single batch of images with the shared network."""
        return self.embedding_net(x)

**Dataset:**

In [None]:
# ToTensor scales pixels to [0, 1]; Normalize with mean 0.5 / std 0.5 per channel maps them to [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# CIFAR10 training split, stored under ./data
cifar = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# Siamese pairs on top of CIFAR10: dataset & dataloader
ds = siameseDataset(cifar)
dl = DataLoader(ds, batch_size=32)

Files already downloaded and verified


CIFAR's 10 different classes:

In [None]:
# class index -> human-readable name for the 10 CIFAR-10 classes
labels = {
    0: 'airplane',
    1: 'car',
    2: 'bird',
    3: 'cat',
    4: 'deer',
    5: 'dog',
    6: 'frog',
    7: 'horse',
    8: 'ship',
    9: 'truck',
}

Sanity check:

In [None]:
# next(iter(...)) rather than iter(...).next(): recent PyTorch DataLoader iterators
# no longer provide the Python-2 style .next() alias.
inputs, label = next(iter(dl))

# each inputs[k] is one stacked pair, which CNN treats as a batch of two images
assert CNN()(inputs[0]).shape == (2, 10)
assert CNN()(inputs[1]).shape == (2, 10)

# the Siamese wrapper returns one 10-dim embedding per pair for each slot
emb1, emb2 = SiameseNet(CNN())(inputs)
assert emb1.shape == emb2.shape == (inputs.shape[0], 10)
'Done!'

'Done!'

<br><br><br><br><br><br>
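**Contrastive Loss**:

$L_{contrast} = \frac{1}{2}(1-Y)D_w^2 + \frac{1}{2}Y\max(0, m-D_w)^2$

where $D_w$ is the Euclidean distance between the two embeddings and $m$ is the margin. In this formulation (as in Hadsell et al., 2006), $Y=0$ marks a similar pair and $Y=1$ a dissimilar pair. Note that `siameseDataset` emits 1 for *same*-class pairs, i.e. the opposite convention, so its label corresponds to $1-Y$.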

In [None]:
class ContrastiveLoss(nn.Module):
    """Contrastive loss for pairs of embeddings.

    Label convention matches ``siameseDataset``: ``label == 1`` for a same-class
    pair, penalised by its squared distance; ``label == 0`` for a different-class
    pair, penalised by ``max(0, margin - distance)**2``. Returns the mean over
    the batch.
    """

    def __init__(self, margin=1.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin
        # added under the sqrt: d/dx sqrt(x) is infinite at 0, so coinciding
        # embeddings would otherwise back-propagate NaN into the weights
        self.eps = 1e-9

    def forward(self, emb1, emb2, label):
        label = label.float()

        # squared distance is used directly for the pull term (no sqrt needed);
        # the push term needs the true distance, made sqrt-safe by eps
        sq_distance = (emb1 - emb2).pow(2).sum(1)
        distance = (sq_distance + self.eps).sqrt()

        similar = label * sq_distance                                        # pull same-class pairs together
        dissimilar = (1 - label) * F.relu(self.margin - distance).pow(2)     # relu is max(0, _)

        loss = 0.5 * (similar + dissimilar)
        return loss.mean()


In [None]:
# torch.Tensor(64, 10) allocates uninitialised memory (it may hold NaN/garbage),
# so draw real values for the sanity check instead.
sanity_loss = ContrastiveLoss()(torch.randn(64, 10), torch.randn(64, 10), torch.zeros(64))
assert torch.isfinite(sanity_loss)
sanity_loss

tensor(nan)

<br><br><br><br><br>

**Train function:**

In [None]:
def train_model(model, trainset, EPOCHS=10, BATCH_SIZE=64, DEVICDE=DEVICE, VERBOSE=False, BASIC_VERBOSE=True):
    """Train ``model`` on ``trainset`` with ``ContrastiveLoss`` and Adam.

    Args:
        model: module whose forward returns a pair of embeddings (e.g. ``SiameseNet``).
        trainset: dataset yielding ``(inputs, labels)`` items (e.g. ``siameseDataset``).
        EPOCHS: number of passes over ``trainset``.
        BATCH_SIZE: batch size of the shuffled ``DataLoader``.
        DEVICDE: device to train on. The misspelled name is kept so existing keyword
            callers keep working; it defaults to the notebook-level ``DEVICE``.
        VERBOSE: print the mean loss of every window of 200 batches.
        BASIC_VERBOSE: print start/finish messages and the final epoch loss.

    Returns:
        A list with the mean training loss of each epoch, or -- as a debugging aid --
        the tuple ``(loss, outputs, data, model)`` for the first batch whose loss is NaN.
    """
    log_every = 200     # batches per VERBOSE progress line
    device = DEVICDE    # honour the argument rather than the global DEVICE

    if BASIC_VERBOSE:
        print('Training - began...')

    criterion = ContrastiveLoss()

    # move the model first so the optimizer is built over the on-device parameters
    model = model.to(device)
    model.train()
    optimizer = torch.optim.Adam(model.parameters())

    trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)

    epoch_losses = []
    for epoch in range(EPOCHS):
        epoch_loss = 0.0    # sum over the whole epoch -> per-epoch mean
        window_loss = 0.0   # sum over the current logging window only
        for i, data in enumerate(trainloader):
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)

            # forward + backward + optimize
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(*outputs, labels)

            if torch.isnan(loss):
                print('\n Got Loss=NaN when i=',i,'\n Please find returns: loss,outputs,data,model')
                return loss, outputs, data, model

            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            epoch_loss += batch_loss
            window_loss += batch_loss
            # report only after a full window so the average really covers log_every batches
            if VERBOSE and (i + 1) % log_every == 0:
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, window_loss / log_every))
                window_loss = 0.0

        # max(..., 1) guards against an empty loader
        epoch_losses.append(epoch_loss / max(len(trainloader), 1))

    if BASIC_VERBOSE:
        print('Training - Done!')
        final_loss = epoch_losses[-1] if epoch_losses else None
        print('loss len: ', len(epoch_losses), '\tfinal loss: ', final_loss)
    return epoch_losses

In [None]:
model = SiameseNet(CNN())
train_result = train_model(model=model, trainset=ds, VERBOSE=True)

# train_model returns the per-epoch loss list on success, or the tuple
# (loss, outputs, data, model) from the first NaN batch; unpacking four values
# unconditionally would fail on the success path.
if isinstance(train_result, tuple):
    loss, outputs, data, model = train_result
else:
    epoch_losses = train_result

Training - began...
[1,     1] loss: 0.001
[1,   201] loss: 0.045
[1,   401] loss: 0.045

 Got Loss=NaN when i= 536 
 Please find returns: loss,outputs,data,model


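Here is a contrastive loss implementation from the web:  
https://github.com/adambielski/siamese-triplet/blob/master/losses.py

But its sanity check also returned NaN. Note that `torch.Tensor(64, 10)` allocates uninitialized memory, which may contain NaN or arbitrary values, so these sanity checks should use e.g. `torch.randn(64, 10)` instead.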

In [None]:
class ContrastiveLoss(nn.Module):
    """Contrastive loss, adapted from adambielski/siamese-triplet.

    ``target == 1`` marks a same-class pair, penalised by its squared distance;
    ``target == 0`` marks a different-class pair, penalised while its distance is
    below ``margin``. NOTE: this cell redefines ``ContrastiveLoss`` and shadows the
    earlier class of the same name.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin
        # keeps the gradient of sqrt finite when the squared distance is exactly 0
        self.eps = 1e-9

    def forward(self, output1, output2, target, size_average=True):
        """Per-pair loss reduced by mean (``size_average=True``) or by sum."""
        sq_dist = (output2 - output1).pow(2).sum(1)
        dist = (sq_dist + self.eps).sqrt()

        pull = target.float() * sq_dist
        push = (1 + -1 * target).float() * F.relu(self.margin - dist).pow(2)
        per_pair = 0.5 * (pull + push)

        if not size_average:
            return per_pair.sum()
        return per_pair.mean()

In [None]:
# torch.Tensor(64, 10) allocates uninitialised memory (it may hold NaN/garbage),
# so draw real values for the sanity check instead.
sanity_loss = ContrastiveLoss()(torch.randn(64, 10), torch.randn(64, 10), torch.zeros(64))
assert torch.isfinite(sanity_loss)
sanity_loss