In [1]:
import torch
import math
import numpy as np
from torch import optim
from torch import Tensor
from torch import nn
from torch.nn import functional as F

import dlc_practical_prologue as prologue

In [2]:
N = 1000 # Number of data samples in training and test set

train_input, train_target, train_classes, \
    test_input, test_target, test_classes = prologue.generate_pair_sets(N)

print(train_input.shape)
print(train_target.shape)
print(train_classes.shape)

train_classes[:5]

torch.Size([1000, 2, 14, 14])
torch.Size([1000])
torch.Size([1000, 2])


tensor([[9, 3],
        [5, 4],
        [7, 4],
        [9, 6],
        [8, 8]])

In [3]:
def normalize(input, mean, std):
    input.sub_(mean).div_(std)
    
def process_data(img_input, classes, one_hot_classes=False):
    
    n_img = img_input.size(0) 
    img_input_1 = img_input[:,0,:,:].reshape(n_img, 1, 14, 14)
    img_input_2 = img_input[:,1,:,:].reshape(n_img, 1, 14, 14)
    
    img_classes_1 = prologue.convert_to_one_hot_labels(img_input_1, classes[:,0]) if one_hot_classes else classes[:,0]
    img_classes_2 = prologue.convert_to_one_hot_labels(img_input_2, classes[:,1]) if one_hot_classes else classes[:,1]
    
    img_classes_1.reshape(-1,1)
    img_classes_2.reshape(-1,1)
    
    return img_input_1, img_input_2, img_classes_1, img_classes_2

In [4]:
mean = train_input.mean(dim=(0,2,3), keepdim=True)
std = train_input.std(dim=(0,2,3), keepdim=True)

normalize(train_input, mean, std)
normalize(test_input, mean, std)

train_input_1, train_input_2, train_classes_1, train_classes_2 = process_data(train_input, train_classes)
test_input_1, test_input_2, test_classes_1, test_classes_2 = process_data(test_input, test_classes)

In [5]:
class DigitNet(nn.Module):
    def __init__(self, nb_hidden):
        super(DigitNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.fc1 = nn.Linear(256, nb_hidden)
        self.fc2 = nn.Linear(nb_hidden, 10)
        
    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), kernel_size=2))
        x = F.relu(F.max_pool2d(self.conv2(x), kernel_size=2))
        x = F.relu(self.fc1(x.view(-1, 256)))
        x = self.fc2(x)
        return x

In [6]:
class CompNet(torch.nn.Module):
    def __init__(self):
        super(CompNet, self).__init__()
        self.fc1 = nn.Linear(20, 50)
        self.fc2 = nn.Linear(50, 50)
        self.fc3 = nn.Linear(50, 1)
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = torch.sigmoid(x)
        return x

In [7]:
def train_model(model_digit, model_comp, 
                train_input_1, train_input_2, train_classes_1, train_classes_2, train_target, 
                criterion_digit=nn.CrossEntropyLoss(), criterion_comp=nn.BCELoss(), 
                mini_batch_size=25,nb_epochs=25, lr=1e-1):
    
    optimizer_digit = torch.optim.SGD(model_digit.parameters(), lr=lr)
    optimizer_comp = torch.optim.SGD(model_comp.parameters(), lr=lr)
    
    for e in range(nb_epochs):
        if e % 5 == 0:
            print("Epochs {}".format(e))
        for b in range(0, train_input.size(0), mini_batch_size):
            
            # digit classification 
            output_img_1 = model_digit(train_input_1.narrow(0, b, mini_batch_size))
            output_img_2 = model_digit(train_input_2.narrow(0, b, mini_batch_size))
            
            loss_img_1 = criterion_digit(output_img_1, train_classes_1.narrow(0, b, mini_batch_size))
            loss_img_2 = criterion_digit(output_img_2, train_classes_2.narrow(0, b, mini_batch_size))
            loss_img = loss_img_1 + loss_img_2
            
            output_comp = model_comp(torch.cat((output_img_1, output_img_2), 1))
            batch_target = train_target.narrow(0, b, mini_batch_size).reshape(-1,1).float()
            loss_comp = criterion_comp(output_comp, batch_target)
            
            loss = loss_img + loss_comp
            
            if b==0:
                print("loss = {}, loss_img = {}, loss_comp = {}".format(loss, loss_img, loss_comp))
                
            model_digit.zero_grad()
            model_comp.zero_grad()
            loss.backward()
            optimizer_digit.step()
            optimizer_comp.step()

In [8]:
model_digit = DigitNet(500)
model_comp = CompNet()

print(sum(p.numel() for p in model_digit.parameters() if p.requires_grad))
print(sum(p.numel() for p in model_comp.parameters() if p.requires_grad))
print("training...")

train_model(model_digit=model_digit, model_comp=model_comp,
            train_input_1=train_input_1, train_input_2=train_input_2,
            train_classes_1=train_classes_1, train_classes_2=train_classes_2, 
            train_target=train_target)

152326
3651
training...
Epochs 0
loss = 5.326606750488281, loss_img = 4.633459568023682, loss_comp = 0.6931474208831787
loss = 3.1775200366973877, loss_img = 2.463043212890625, loss_comp = 0.7144768238067627
loss = 1.6411175727844238, loss_img = 0.9218927025794983, loss_comp = 0.7192248702049255
loss = 1.3241312503814697, loss_img = 0.6552929878234863, loss_comp = 0.6688382625579834
loss = 1.100134253501892, loss_img = 0.38247573375701904, loss_comp = 0.717658519744873
Epochs 5
loss = 0.9456673860549927, loss_img = 0.2893448770046234, loss_comp = 0.6563225388526917
loss = 0.7134124636650085, loss_img = 0.1570490002632141, loss_comp = 0.5563634634017944
loss = 0.6869116425514221, loss_img = 0.07245473563671112, loss_comp = 0.6144568920135498
loss = 0.5992786884307861, loss_img = 0.06728833168745041, loss_comp = 0.5319903492927551
loss = 0.5215492844581604, loss_img = 0.032031022012233734, loss_comp = 0.48951828479766846
Epochs 10
loss = 0.5321062803268433, loss_img = 0.03115367330610752

In [9]:
def compute_nb_errors_siamese(model_digit, model_comp,
                              data_input_1, data_input_2, data_target, mini_batch_size=25):

    nb_data_errors = 0

    for b in range(0, data_input_1.size(0), mini_batch_size):
        output_img_1 = model_digit(data_input_1.narrow(0, b, mini_batch_size))
        output_img_2 = model_digit(data_input_2.narrow(0, b, mini_batch_size))
        
        output_comp = model_comp(torch.cat((output_img_1, output_img_2), 1))
        output_comp = torch.round(output_comp)
    
        for k in range(mini_batch_size):
            if data_target[b + k] != output_comp[k]:
                nb_data_errors = nb_data_errors + 1

    return nb_data_errors

In [10]:
def print_error_siamese(model_digit, model_comp, tr_input_1, tr_input_2, tr_target, te_input_1, te_input_2, te_target):
    print('train_error {:.02f}% test_error {:.02f}%'.format(
                compute_nb_errors_siamese(model_digit, model_comp, tr_input_1, tr_input_2, tr_target) / N * 100,
                compute_nb_errors_siamese(model_digit, model_comp, te_input_1, te_input_2, te_target) / N * 100))

In [11]:
print_error_siamese(model_digit, model_comp, train_input_1, train_input_2, train_target, test_input_1, test_input_2, test_target)

train_error 2.30% test_error 8.80%
