In [35]:
import torch
import math
import numpy as np
from torch import optim
from torch import Tensor
from torch import nn
from torch.nn import functional as F

import dlc_practical_prologue as prologue

In [3]:
N = 1000 # Number of data samples in training and test set

train_input, train_target, train_classes, \
    test_input, test_target, test_classes = prologue.generate_pair_sets(N)

print(train_input.shape)
print(train_target.shape)
print(train_classes.shape)

train_classes[:5]

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/mnist/MNIST\raw\train-images-idx3-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Extracting ./data/mnist/MNIST\raw\train-images-idx3-ubyte.gz to ./data/mnist/MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/mnist/MNIST\raw\train-labels-idx1-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Extracting ./data/mnist/MNIST\raw\train-labels-idx1-ubyte.gz to ./data/mnist/MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/mnist/MNIST\raw\t10k-images-idx3-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Extracting ./data/mnist/MNIST\raw\t10k-images-idx3-ubyte.gz to ./data/mnist/MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/mnist/MNIST\raw\t10k-labels-idx1-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Extracting ./data/mnist/MNIST\raw\t10k-labels-idx1-ubyte.gz to ./data/mnist/MNIST\raw
Processing...
Done!
torch.Size([1000, 2, 14, 14])
torch.Size([1000])
torch.Size([1000, 2])


tensor([[9, 3],
        [5, 4],
        [7, 4],
        [9, 6],
        [8, 8]])

In [110]:
class DigitNetSingleOutput(nn.Module):
    def __init__(self, nb_hidden):
        super(DigitNetSingleOutput, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.fc1 = nn.Linear(256, nb_hidden)
        self.fc2 = nn.Linear(nb_hidden, 10)
        #self.fc3 = nn.Linear(10,1)
        
        
    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), kernel_size=2))
        x = F.relu(F.max_pool2d(self.conv2(x), kernel_size=2))
        x = F.relu(self.fc1(x.view(-1, 256)))
        x = F.relu(self.fc2(x))
        #x = (self.fc3(x))
        return x



In [111]:
def normalize(input, mean, std):
    input.sub_(mean).div_(std)
    return 
    
def process_data(img_input,classes,one_hot_classes=False):
    
    n_img = img_input.size(0) 
    img_input_1 = img_input[:,0,:,:].reshape(n_img, 1, 14, 14)
    img_input_2 = img_input[:,1,:,:].reshape(n_img, 1, 14, 14)
    
    img_classes_1 = prologue.convert_to_one_hot_labels(img_input_1,classes[:,0]) if one_hot_classes else classes[:,0]
    img_classes_2 = prologue.convert_to_one_hot_labels(img_input_2,classes[:,1]) if one_hot_classes else classes[:,1]
    
    img_classes_1.reshape(-1,1)
    img_classes_2.reshape(-1,1)
    
    return img_input_1,img_input_2,img_classes_1,img_classes_2
    
    

In [112]:
mean = train_input.mean(dim=(0,2,3),keepdim=True)
std = train_input.std(dim=(0,2,3),keepdim=True)

normalize(train_input, mean, std)
normalize(test_input, mean, std)

train_input_1,train_input_2,train_classes_1,train_classes_2 = process_data(train_input,train_classes)
test_input_1,test_input_2,test_classes_1,test_classes_2 = process_data(test_input,test_classes)




In [128]:
'''
Function to compare two batches of 1-hot encoded digits in [0,9]
returns array of 0s and 1s encoding the comparaison
'''
def get_targets(output_1,output_2):
    result = torch.zeros(output_1.size(0))
    digits_1 = torch.argmax(output_1,dim=1,keepdim=True)
    digits_2 = torch.argmax(output_2,dim=1,keepdim=True)

    output_targets = torch.ge(digits_2,digits_1).float()
    
    return output_targets.reshape(-1,1)
    

In [132]:
'''
Function to train a siamese model.
Since this is used to compare two images as input, we use the contrastive loss rather than the cross entropy.
'''
def train_siamese_model(model, train_input_1,train_input_2, train_classes_1,train_classes_2, train_target, mini_batch_size=25, 
                nb_epochs=25, criterion_digit=nn.CrossEntropyLoss(),criterion_comp=nn.BCELoss() ,lr=1e-1):
    
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
   
    
    for e in range(nb_epochs):
        for b in range(0, train_input_1.size(0), mini_batch_size):
            
            # digit classification 
            output_1 = model(train_input_1.narrow(0, b, mini_batch_size))
            output_2 = model(train_input_2.narrow(0, b, mini_batch_size))
            
            loss_img_1 = criterion_digit(output_1,train_classes_1.narrow(0,b,mini_batch_size))
            loss_img_2 = criterion_digit(output_2,train_classes_2.narrow(0,b,mini_batch_size))
            loss_img = loss_img_1 + loss_img_2 
            
            # digit comparaison
            output_target = get_targets(output_1,output_2)
            
            batch_target = train_target.narrow(0, b, mini_batch_size).reshape(-1,1).float()
            
            loss_comp = criterion_comp(output_target,batch_target)
            
            loss = loss_img + loss_comp
            
            model.zero_grad()
            loss.backward()
            optimizer.step()

In [115]:
def compute_nb_errors_siamese(model, data_input_1, data_input_2, data_target, mini_batch_size=25):

    nb_data_errors = 0

    for b in range(0, data_input_1.size(0), mini_batch_size):
        output_1 = model(data_input_1.narrow(0, b, mini_batch_size))
        output_2 = model(data_input_2.narrow(0, b, mini_batch_size))
        predicted_targets = get_targets(output_1,output_2)
        #output = torch.cat((output_1, output_2), 1)
        #_, predicted_classes = torch.max(output, 1)
        for k in range(mini_batch_size):
            if data_target[b + k] != predicted_targets[k]:
                nb_data_errors = nb_data_errors + 1

    return nb_data_errors

In [116]:
def print_error_siamese(model, tr_input_1, tr_input_2, tr_target, te_input_1, te_input_2, te_target):
    print('train_error {:.02f}% test_error {:.02f}%'.format(
                compute_nb_errors_siamese(model, tr_input_1, tr_input_2, tr_target) / N * 100,
                compute_nb_errors_siamese(model, te_input_1, te_input_2, te_target) / N * 100)) 

In [85]:
sum(p.numel() for p in model.parameters() if p.requires_grad)

152326

In [133]:
model = DigitNetSingleOutput(500)
train_siamese_model(model, train_input_1, train_input_2,train_classes_1,train_classes_2, train_target)

In [134]:
print_error_siamese(model, train_input_1, train_input_2, train_target, test_input_1, test_input_2, test_target)

train_error 33.50% test_error 34.40%
