In [3]:
import torch
import math
import numpy as np
from torch import optim
from torch import Tensor
from torch import nn
from torch.nn import functional as F

import dlc_practical_prologue as prologue

In [4]:
torch.round(Tensor([0.5]))

tensor([0.])

In [49]:

index_t = Tensor([[0],[1],[1],[0],[1]])
print(index_t.shape)
def encode_targets(target):
    n = target.size(0)
    result = torch.zeros((n,2))
    return result.scatter(1,target.reshape(n,1).long(),1)
#test
print(encode_targets(index_t))

torch.Size([5, 1])
tensor([[1., 0.],
        [0., 1.],
        [0., 1.],
        [1., 0.],
        [0., 1.]])


In [43]:
N = 1000 # Number of data samples in training and test set

train_input, train_target, train_classes, \
    test_input, test_target, test_classes = prologue.generate_pair_sets(N)

# use 1-hot encoding for targets
train_target = encode_targets(train_target)
test_target = encode_targets(test_target)

print(train_input.shape)
print(train_target.shape)
print(train_classes.shape)

train_classes[:5]

torch.Size([1000, 2, 14, 14])
torch.Size([1000])
torch.Size([1000, 2])


tensor([[5, 7],
        [2, 4],
        [5, 3],
        [4, 9],
        [1, 2]])

In [45]:
def normalize(input, mean, std):
    input.sub_(mean).div_(std)
    
def process_data(img_input, classes, one_hot_classes=False):
    
    n_img = img_input.size(0) 
    img_input_1 = img_input[:,0,:,:].reshape(n_img, 1, 14, 14)
    img_input_2 = img_input[:,1,:,:].reshape(n_img, 1, 14, 14)
    
    img_classes_1 = prologue.convert_to_one_hot_labels(img_input_1, classes[:,0]) if one_hot_classes else classes[:,0]
    img_classes_2 = prologue.convert_to_one_hot_labels(img_input_2, classes[:,1]) if one_hot_classes else classes[:,1]
    
    img_classes_1.reshape(-1,1)
    img_classes_2.reshape(-1,1)
    
    return img_input_1, img_input_2, img_classes_1, img_classes_2

In [46]:
mean = train_input.mean(dim=(0,2,3), keepdim=True)
std = train_input.std(dim=(0,2,3), keepdim=True)

normalize(train_input, mean, std)
normalize(test_input, mean, std)

train_input_1, train_input_2, train_classes_1, train_classes_2 = process_data(train_input, train_classes)
test_input_1, test_input_2, test_classes_1, test_classes_2 = process_data(test_input, test_classes)

In [47]:
class DigitNet(nn.Module):
    def __init__(self, nb_hidden):
        super(DigitNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.fc1 = nn.Linear(256, nb_hidden)
        self.fc2 = nn.Linear(nb_hidden, 10)
        
    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), kernel_size=2))
        x = F.relu(F.max_pool2d(self.conv2(x), kernel_size=2))
        x = F.relu(self.fc1(x.view(-1, 256)))
        x = self.fc2(x)
        return x

In [48]:
class CompNet(torch.nn.Module):
    def __init__(self, digitNet):
        super(CompNet, self).__init__()
        self.digitNet = digitNet
        self.fc1 = nn.Linear(20, 50)
        self.fc2 = nn.Linear(50, 50)
        self.fc3 = nn.Linear(50, 2)
        
    def forward(self, x1, x2,train=True):
        x1 = self.digitNet.forward(x1)
        x2 = self.digitNet.forward(x2)
        x = torch.cat((x1, x2), 1)
        x = F.relu(self.fc1(x))
        x = F.dropout(x,p=0.25,training=train)
        x = F.relu(self.fc2(x))
        x = F.dropout(x,p=0.25,training=train)
        x = F.relu(self.fc3(x))
        x = torch.sigmoid(x)
        return x

In [50]:
train_target = encode_targets(train_target)
print(train_target)

tensor([[0., 1.],
        [0., 1.],
        [1., 0.],
        ...,
        [1., 0.],
        [1., 0.],
        [0., 1.]])


In [57]:
test_target = encode_targets(test_target)

In [53]:
def train_model(model_digit, model_comp, 
                train_input_1, train_input_2, train_classes_1, train_classes_2, train_target, 
                criterion_digit=nn.CrossEntropyLoss(), criterion_comp=nn.BCELoss(), 
                mini_batch_size=25,nb_epochs=50, lr=1e-1):
    
    optimizer_comp = torch.optim.SGD(model_comp.parameters(), lr=lr)
    
    for e in range(nb_epochs):
        if e % 5 == 0:
            print("Epochs {}".format(e))
        for b in range(0, train_input.size(0), mini_batch_size):
            
            # digit classification 
            output_img_1 = model_digit(train_input_1.narrow(0, b, mini_batch_size))
            output_img_2 = model_digit(train_input_2.narrow(0, b, mini_batch_size))
            
            loss_img_1 = criterion_digit(output_img_1, train_classes_1.narrow(0, b, mini_batch_size))
            loss_img_2 = criterion_digit(output_img_2, train_classes_2.narrow(0, b, mini_batch_size))
            loss_img = loss_img_1 + loss_img_2
            
            output_comp = model_comp(train_input_1.narrow(0, b, mini_batch_size), train_input_2.narrow(0, b, mini_batch_size))
            #batch_target = train_target.narrow(0, b, mini_batch_size).reshape(-1,1).float()
            batch_target = train_target.narrow(0,b,mini_batch_size)
            #print(batch_target)
            #print(output_comp)
            loss_comp = criterion_comp(output_comp, batch_target)
            
            loss = loss_img + loss_comp
            
            if b==0:
                print("loss = {}, loss_img = {}, loss_comp = {}".format(loss, loss_img, loss_comp))
                
            model_digit.zero_grad()
            model_comp.zero_grad()
            loss.backward()

            optimizer_comp.step()

In [54]:
model_digit = DigitNet(500)
model_comp = CompNet(model_digit)

print(sum(p.numel() for p in model_digit.parameters() if p.requires_grad))
print(sum(p.numel() for p in model_comp.parameters() if p.requires_grad))
print("training...")

train_model(model_digit=model_digit, model_comp=model_comp,
            train_input_1=train_input_1, train_input_2=train_input_2,
            train_classes_1=train_classes_1, train_classes_2=train_classes_2, 
            train_target=train_target)

152326
156028
training...
Epochs 0
loss = 5.314850330352783, loss_img = 4.6233930587768555, loss_comp = 0.6914573907852173
loss = 2.118473768234253, loss_img = 1.4228087663650513, loss_comp = 0.6956650614738464
loss = 2.359877347946167, loss_img = 1.7132670879364014, loss_comp = 0.6466102004051208
loss = 1.2319447994232178, loss_img = 0.648379921913147, loss_comp = 0.583564817905426
loss = 0.7087023854255676, loss_img = 0.2256259024143219, loss_comp = 0.4830764830112457
Epochs 5
loss = 0.6012747883796692, loss_img = 0.08515288680791855, loss_comp = 0.5161219239234924
loss = 0.4575989544391632, loss_img = 0.058027006685733795, loss_comp = 0.39957195520401
loss = 0.4420558214187622, loss_img = 0.034904494881629944, loss_comp = 0.40715134143829346
loss = 0.48331746459007263, loss_img = 0.06482523679733276, loss_comp = 0.41849222779273987
loss = 0.41554391384124756, loss_img = 0.02033272013068199, loss_comp = 0.39521118998527527
Epochs 10
loss = 0.41421470046043396, loss_img = 0.0098917968

In [98]:
def compute_nb_errors_siamese(model_digit, model_comp,
                              data_input_1, data_input_2, data_target, mini_batch_size=25):

    nb_data_errors = 0

    for b in range(0, data_input_1.size(0), mini_batch_size):
        
        output_comp = model_comp(data_input_1.narrow(0, b, mini_batch_size), data_input_2.narrow(0, b, mini_batch_size),train=False)
       
        output_comp = torch.round(output_comp)
       
       
        for k in range(mini_batch_size):
            #print(torch.eq(data_target[b + k], output_comp[k]))
            if torch.equal(torch.eq(data_target[b + k], output_comp[k]),torch.tensor([True,True])) == False:
                nb_data_errors = nb_data_errors + 1

    return nb_data_errors

In [99]:
def print_error_siamese(model_digit, model_comp, tr_input_1, tr_input_2, tr_target, te_input_1, te_input_2, te_target):
    print('train_error {:.02f}% test_error {:.02f}%'.format(
                compute_nb_errors_siamese(model_digit, model_comp, tr_input_1, tr_input_2, tr_target) / N * 100,
                compute_nb_errors_siamese(model_digit, model_comp, te_input_1, te_input_2, te_target) / N * 100))

In [100]:
print_error_siamese(model_digit, model_comp, train_input_1, train_input_2, train_target, test_input_1, test_input_2, test_target)

train_error 0.30% test_error 8.00%
