In [1]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as dsets
import torch.nn.functional as F
import matplotlib.pylab as plt
import numpy as np
from time import time
from torch.utils.data import Dataset

## Prep Functions- show images

In [3]:
def plot_accuracy_loss(training_results):
    plt.subplot(2, 1, 1)
    plt.plot(training_results['training_loss'], 'r')
    plt.ylabel('loss')
    plt.title('training loss iterations')
    plt.subplot(2, 1, 2)
    plt.plot(training_results['validation_accuracy'])
    plt.ylabel('accuracy')
    plt.xlabel('epochs')
    plt.show()

def show_data(data_sample):
    plt.imshow(data_sample[0].numpy().reshape(28, 28), cmap='gray')
    plt.title('y = ' + str(data_sample[1]))
    plt.show()

def show_dataComp(data_sample,y):
    plt.subplot(1,2,1)
    plt.imshow(data_sample[0].numpy().reshape(28, 28), cmap='gray')
    plt.subplot(1,2,2)
    plt.imshow(data_sample[1].numpy().reshape(28, 28), cmap='gray')
    plt.title('y = ' + str(y))

    plt.show()

## Data Pairing

In [4]:
def get_same_index(target, label):
    label_indices = []

    for i in range(len(target)):
        if target[i] == label:
            label_indices.append(i)

    return label_indices

In [5]:
def comparisonDataConsecutive(dataSet):
    indices = []
    #gets all the indices of the data obsv with same y from the dataset that is passed in
    for i in range(5):
        indlist = get_same_index(dataSet.targets, i)
        indices.append(indlist)
    subsets = []
    [subsets.append(torch.utils.data.Subset(dataSet, i)) for i in indices]
    comp = []
    for indi in range(len(indices) - 1):
        comp.append(int(min(len(subsets[indi]), len(subsets[indi + 1]))))
    tot1 = sum(comp)
    x = torch.zeros([tot1, 2, 28, 28], dtype=torch.float32)
    y = torch.zeros([tot1,1])
    # 1 for first pic greater, 0 for first pic less
    k = 0
    for i in range(len(subsets) - 1):
        for j in range(int(comp[i] / 2)):
            x[k][0] = subsets[i][j][0]
            x[k][1] = subsets[i + 1][j][0]
            y[k][0] = 0
            k += 1
        for j in range(int(comp[i] / 2), comp[i]):
            x[k][1] = subsets[i][j][0]
            x[k][0] = subsets[i + 1][j][0]
            y[k][0] = 1
            k += 1
    return x,y

In [6]:
def comparisonDataNonconsecutive(dataSet):
    indices = []
    # gets all the indices of the data obsv with same y from the dataset that is passed in
    for i in range(5):
        indlist = get_same_index(dataSet.targets, i)
        indices.append(indlist)
    subsets = []
    [subsets.append(torch.utils.data.Subset(dataSet, i)) for i in indices]
    comp = {}
    for i in range(len(subsets) - 2):
        for j in range(i + 2, len(subsets), 1):
            comp[(i, j)] = int(min(len(subsets[i]), len(subsets[j])))
    tot = sum(comp.values())
    x = torch.zeros([tot, 2, 28, 28], dtype=torch.float32)
    y = torch.zeros([tot, 1])
    k = 0
    for key, values in comp.items():
        for value in range(int(values / 2)):
            x[k][0] = subsets[key[0]][value][0]
            x[k][1] = subsets[key[1]][value][0]
            y[k][0] = 0
            k += 1
        for value in range(int(values / 2), values):
            x[k][0] = subsets[key[1]][value][0]
            x[k][1] = subsets[key[0]][value][0]
            y[k][0] = 1
            k += 1
    return x,y


In [9]:
class Trainsetcomp(Dataset):
    def __init__(self, x, y):
        self.len = (x.shape[0])
        self.x = x
        self.y = y

    def __getitem__(self, index):
        return (self.x[index], self.y[index])

    def __len__(self):
        return self.len

In [10]:
transform = transforms.Compose([transforms.ToTensor(),
                              transforms.Normalize((0.5,), (0.5,)),])

trainset = dsets.MNIST(root='./../data',
                            train=True,
                            download=True,
                            transform=transform)
valset = dsets.MNIST(root='./../data',
                            train=False,
                            download=True,
                            transform=transform)

In [11]:
x, y = comparisonDataConsecutive(trainset)
traindataComp = Trainsetcomp(x,y)
x,y = comparisonDataConsecutive(valset)
valdataComp = Trainsetcomp(x,y)

In [12]:
trainloader = torch.utils.data.DataLoader(traindataComp,
                                          batch_size=64,
                                          shuffle=True)
valloader = torch.utils.data.DataLoader(valdataComp,
                                          batch_size=64,
                                          shuffle=True)
x,y = comparisonDataNonconsecutive(valset)
testdata = Trainsetcomp(x,y)
testloader = torch.utils.data.DataLoader(testdata,
                                          batch_size=64,
                                          shuffle=True)

## Model

In [53]:
ind = 2 * 28 * 28
hiddendim = [256,128,64]
outd = 1
# 0 if first image is less than and 1 if frist image is greater than

In [74]:
class ModelFull(nn.Module):
    def __init__(self, ind,h1d,h2d,h3d,outd):
        super(ModelFull, self).__init__()
        self.lin1 = nn.Linear(ind, h1d)
        self.lin2 = nn.Linear(h1d, h2d)
        self.lin3 = nn.Linear(h2d, h3d)
        self.lin4 = nn.Linear(h3d, outd)
        self.activations = {'h1':[],'h2':[],'h3':[]}

    def forward(self, x, recActivations = False):
        if recActivations:
            x1 = torch.relu(self.lin1(x))
            self.activations['h1'].append(x1)
            x2 = torch.relu(self.lin2(x1))
            self.activations['h2'].append(x2)
            x3 = torch.relu(self.lin3(x2))
            self.activations['h3'].append(x3)
            x4 = torch.sigmoid(self.lin4(x3))
            return x4
        else:
            x1 = torch.relu(self.lin1(x))
            x2 = torch.relu(self.lin2(x1))
            x3 = torch.relu(self.lin3(x2))
            x4 = torch.sigmoid(self.lin4(x3))
            return x4

In [75]:
def train(model, criterion, optimizer, epochs = 30):
    lossList = []
    #time0 = time()
    for i in range(epochs):
        runningLoss = 0
        for x, y in trainloader:
            optimizer.zero_grad()
            yhat = model(x.view(-1, 2 * 28 *28))
            #print(yhat.shape)
            #print(y.shape)
            loss = criterion(yhat, y)
            loss.backward()
            optimizer.step()
            runningLoss += loss.item()
        print('epoch ', i, ' loss: ', str(runningLoss / len(traindataComp)))
        lossList.append(runningLoss / len(traindataComp))

In [76]:
model = ModelFull(ind,hiddendim[0],hiddendim[1],hiddendim[2], outd)

In [77]:
criterion = nn.BCELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.03)
results = train(model,criterion,optimizer,25)

epoch  0  loss:  0.008575960630377544
epoch  1  loss:  0.003752651454524602
epoch  2  loss:  0.002658630019780359
epoch  3  loss:  0.0015895875997655416
epoch  4  loss:  0.0012948909899556458
epoch  5  loss:  0.001110205670841418
epoch  6  loss:  0.0015078286783617371
epoch  7  loss:  0.0009902292535938935
epoch  8  loss:  0.0008478872185409142
epoch  9  loss:  0.0007440199225985267
epoch  10  loss:  0.0006870708614965823
epoch  11  loss:  0.0006146428428192812
epoch  12  loss:  0.0005412911956966704
epoch  13  loss:  0.0005852784405366366
epoch  14  loss:  0.0004439410185467331
epoch  15  loss:  0.00036050798534829095
epoch  16  loss:  0.0003289471525574093
epoch  17  loss:  0.0002800824272379851
epoch  18  loss:  0.00025627378670592307
epoch  19  loss:  0.00019368897315860068
epoch  20  loss:  0.00020857671430596873
epoch  21  loss:  0.0001176749896105278
epoch  22  loss:  0.0005636333829042921
epoch  23  loss:  0.0005239267853290685
epoch  24  loss:  0.0002438008596358646


In [78]:
## Consecutive test set
totcount = 0
correctcount = 0
for x,y in valloader:
    x = x.view(-1, 2 * 28 *28)
    with torch.no_grad():
        yhat = model(x)
    ones = torch.ones(yhat.shape)
    yhat = torch.where(yhat>.9, ones, yhat)
    z = torch.zeros(yhat.shape)
    yhat = torch.where(yhat<0.1, z, yhat)
    for i,j in zip(yhat,y):
        if i[0] == j[0]:
            correctcount+=1
        totcount+=1
print(correctcount)
print(totcount)
#print(len(dataPairing.valdataComp))
print('valset accuracy: ', correctcount/totcount)

3858
4004
valset accuracy:  0.9635364635364635


In [79]:
## non consecutive test set
for x,y in testloader:
    x = x.view(-1, 2 * 28 *28)
    with torch.no_grad():
        yhat = model(x)
    ones = torch.ones(yhat.shape)
    yhat = torch.where(yhat>.5, ones, yhat)
    z = torch.zeros(yhat.shape)
    yhat = torch.where(yhat<0.5, z, yhat)
    wrongC = 0
    for i in range(len(y)):
        #print(i)
        if yhat[i][0] == y[i][0]:
            correctcount+=1
            #p = torch.reshape(x[i], (2,28,28))
            #show_dataComp(p, y[i])
            
        else:
            #if wrongC < 2:
                #p = torch.reshape(x[i], (2,28,28))
                #show_dataComp(p, yhat[i])
            wrongC += 1
        totcount+=1
print(correctcount)
print(totcount)
print('test set accuracy: ', correctcount/totcount)

9167
9918
test set accuracy:  0.9242790885259125


## Looking at activations

look activations of the numbers of equal step
<br>
hyp: activations of the comparison pairs should be more similar if the the difference in the comparison is the same

In [88]:
x = traindataComp[0][0].view(-1, 2 * 28 *28)
yhat = model.forward(x,recActivations=True)
yhat[0].round() == traindataComp[0][1]

tensor([True])

In [90]:
model.activations['h1'][0].shape

torch.Size([1, 256])