In [16]:
# -*- encoding: utf-8 -*-
import argparse
import torch
import torchvision.datasets as dsets
import random
import numpy as np
import time
import matplotlib.pyplot as plt
from torch.autograd import Variable
from torchvision import transforms
import pickle
import torch
import torch.nn as nn


In [17]:
# Preprocess data
class Dataset(object):

    def __init__(self, x0, x1, label):
        self.size = label.shape[0]
        self.x0 = torch.from_numpy(x0)
        self.x1 = torch.from_numpy(x1)
        self.label = torch.from_numpy(label)

    def __getitem__(self, index):
        return (self.x0[index],
                self.x1[index],
                self.label[index])

    def __len__(self):
        return self.size

def create_pairs(data, digit_indices):
    x0_data = []
    x1_data = []
    label = []
    n = min([len(digit_indices[d]) for d in range(10)]) - 1
    for d in range(10):
        for i in range(n):
            z1, z2 = digit_indices[d][i], digit_indices[d][i + 1]
            x0_data.append(data[z1] / 255.)
            x1_data.append(data[z2] / 255.)
            label.append(1)
            inc = random.randrange(1, 10)
            dn = (d + inc) % 10
            z1, z2 = digit_indices[d][i], digit_indices[dn][i]
            x0_data.append(data[z1] / 255.)
            x1_data.append(data[z2] / 255.)
            label.append(0)

    x0_data = np.array(x0_data, dtype=np.float32)
    x0_data = x0_data.reshape([-1, 1, 28, 28])
    x1_data = np.array(x1_data, dtype=np.float32)
    x1_data = x1_data.reshape([-1, 1, 28, 28])
    label = np.array(label, dtype=np.int32)
    return x0_data, x1_data, label


def create_iterator(data, label, batchsize, shuffle=False):
    digit_indices = [np.where(label == i)[0] for i in range(10)]
    x0, x1, label = create_pairs(data, digit_indices)
    ret = Dataset(x0, x1, label)
    return ret


# Loss Function: Contrastive Loss Function 

In [18]:
def contrastive_loss_function(self, x0, x1, y, margin=1.0):
    # euclidian distance
    diff = x0 - x1
    dist_sq = torch.sum(torch.pow(diff, 2), 1)
    dist = torch.sqrt(dist_sq)

    mdist = self.margin - dist
    dist = torch.clamp(mdist, min=0.0)
    loss = y * dist_sq + (1 - y) * torch.pow(dist, 2)
    loss = torch.sum(loss) / 2.0 / x0.size()[0]
    return loss

# Defining Siamese Network Architecture

In [35]:
class SiameseNetwork(nn.Module):
    def __init__(self,flag_kaf=False):
        super(SiameseNetwork, self).__init__()
        self.cnn1 = nn.Sequential(
            nn.Conv2d(1, 20, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.Conv2d(20, 50, kernel_size=5),
            nn.MaxPool2d(2, stride=2))
        if(flag_kaf):
            self.fc1 = nn.Sequential(
                nn.Linear(50 * 4 * 4, 500),
                nn.ReLU(inplace=True),
                nn.Linear(250, 10),
                nn.Linear(500,10),
                nn.Linear(10, 2))
        else:
            self.fc1 = nn.Sequential(
                nn.Linear(50 * 4 * 4, 500),
                nn.ReLU(inplace=True),
                nn.Linear(250, 10),
                nn.Linear(500,10),
                nn.Linear(10, 2))

    def forward_once(self, x):
        print ("came inside forward once")
        output = self.cnn1(x)
        output = output.view(output.size()[0], -1)
        output = self.fc1(output)
        print (output.shape)
        return output

    def forward(self, input1, input2):
        print ("checking input")
        print (input1.shape)
        print (input2.shape)
        print ("finished input")
        output1 = self.forward_once(input1)
        output2 = self.forward_once(input2)
        print (output1.shape)
        print (output2.shape)
        return output1, output2


In [36]:
def plot_mnist(numpy_all, numpy_labels,name="../Results/embeddings_plot.png"):
        c = ['#ff0000', '#ffff00', '#00ff00', '#00ffff', '#0000ff',
             '#ff00ff', '#990000', '#999900', '#009900', '#009999']

        for i in range(10):
            f = numpy_all[np.where(numpy_labels == i)]
            plt.plot(f[:, 0], f[:, 1], '.', c=c[i])
        plt.legend(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'])
        plt.savefig(name)

def plot_loss(train_loss,name="../Results/train_loss.png"):
    plt.gca().cla()
    plt.plot(train_loss, label="train loss")
    plt.legend()
    plt.draw()
    plt.savefig(name)
    plt.gca().clear()

In [37]:
batchsize=16
train = dsets.MNIST(root='../data/',train=True,download=True)
test = dsets.MNIST(root='../data/',train=False,transform=transforms.Compose([transforms.ToTensor(),]))
train_iter = create_iterator(train.train_data.numpy(),train.train_labels.numpy(),batchsize)
# model
model = SiameseNetwork()
learning_rate = 0.01
momentum = 0.9
# Loss and Optimizer
criterion = contrastive_loss_function
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate,
                            momentum=momentum)

#kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
kwargs = {}
train_loader = torch.utils.data.DataLoader(train_iter,batch_size=batchsize, shuffle=True, **kwargs)

test_loader = torch.utils.data.DataLoader(test,batch_size=batchsize, shuffle=True, **kwargs)


In [38]:
train_loss = []
running_loss=0.0
epochs =10
for epoch in range(epochs):
    print('Train Epoch:'+str(epoch)+"------------------")
    for batch_idx, (x0, x1, labels) in enumerate(train_loader):
        labels = labels.float()
#         if args.cuda:
#             x0, x1, labels = x0.cuda(), x1.cuda(), labels.cuda()
        x0, x1, labels = Variable(x0), Variable(x1), Variable(labels)
        output1, output2 = model(x0, x1)
        loss = criterion(output1, output2, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss.append(loss.item())
        if batch_idx % args.batchsize == 0:
            print('Batch id: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                batch_idx, batch_idx * len(labels), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
    torch.save(model.state_dict(), './weights/model-epoch-%s.pth' % epoch)
print ("finished_training")
return train_loss

Train Epoch:0------------------
checking input
torch.Size([16, 1, 28, 28])
torch.Size([16, 1, 28, 28])
finished input
came inside forward once


RuntimeError: size mismatch, m1: [16 x 500], m2: [250 x 10] at /Users/distiller/project/conda/conda-bld/pytorch_1556653492823/work/aten/src/TH/generic/THTensorMath.cpp:961