In [1]:
import torch
from torch import nn

import dlc_practical_prologue as prologue

In [2]:
N=1000
epochs=25
mini_batch_size = 100

In [3]:
train_input, train_target, train_classes, \
test_input, test_target, test_classes = \
    prologue.generate_pair_sets(N)

In [9]:
class SharedWeight(nn.Module):
    def __init__(self):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3),
            nn.MaxPool2d(2, stride=2),
            nn.Conv2d(16, 32, kernel_size=3),
            nn.MaxPool2d(2, stride=2))
        
        #predict implicitely the class
        self.fc1 = nn.Sequential(
            nn.Linear(32 * 2 * 2, 100),
            nn.ReLU(inplace=True),
            nn.Linear(100, 10),
            nn.Softmax(1))
        
        #predict the output from concatenated classes predictions
        self.fc2 = nn.Sequential(
            nn.Linear(20, 10),
            nn.ReLU(inplace=True),
            nn.Linear(10,2),
            nn.Softmax(1))
        
    def forward_once(self, x):
        #call the classification block on an individual input
        output = self.cnn(x)
        output = output.view(output.size()[0], -1)
        output = self.fc1(output)
        return output
    
    def forward(self, input1, input2):
        #call the network on both input, combine the outputs and final output
        output1 = self.forward_once(input1)
        output2 = self.forward_once(input2)
        output = self.fc2(torch.cat((output1, output2), 1))
        return output, output1, output2

def train_model(model, input_, target_, classes_=None, learn_rate_= 1e-2, lambda_=0.1, mini_batch_size=100, nb_epochs = 25):
    #classification & binary output
    criterion = nn.CrossEntropyLoss()
    
    #stochastic gradient descent
    optim = torch.optim.SGD(model.parameters(), lr=learn_rate_, momentum=0.9)

    for e in range(nb_epochs):
        for b in range(0, input_.size(0), mini_batch_size):
            #split the input in two, return the binary value and the classes of both input
            output, class_0, class_1 = model(input_.narrow(0, b, mini_batch_size)[:,0,:].unsqueeze(1),
                                             input_.narrow(0, b, mini_batch_size)[:,1,:].unsqueeze(1))
            
            if classes_ is None:
                #no auxiliary loss
                loss = criterion(output, target_.narrow(0, b, mini_batch_size))
            else:
                #auxiliary loss
                loss = criterion(output, target_.narrow(0, b, mini_batch_size)) \
                        + lambda_*criterion(class_0, classes_.narrow(0, b, mini_batch_size)[:,0]) \
                        + lambda_*criterion(class_1, classes_.narrow(0, b, mini_batch_size)[:,1])
            
            #backprop
            optim.zero_grad()
            loss.backward()
            optim.step()
        print("epoch ", e, " loss : ", loss)
        
def compute_nb_errors(model, input_, target_, mini_batch_size):
    nb_error = 0
    
    for b in range(0, input_.size(0), mini_batch_size):
        output, _, _ = model(input_.narrow(0, b, mini_batch_size)[:,0,:].unsqueeze(1),
                       input_.narrow(0, b, mini_batch_size)[:,1,:].unsqueeze(1))
        _,pred = output.max(1)
        
        for k in range(mini_batch_size):
            if target_[b+k] != pred[k]:
                nb_error += 1
                
    return nb_error

In [5]:
Siamese = SharedWeight()

In [6]:
train_model(Siamese, train_input, train_target, train_classes, learn_rate_= 1e-2, lambda_=0.1, mini_batch_size=mini_batch_size, nb_epochs = 25)

  allow_unreachable=True)  # allow_unreachable flag


epoch  0  loss :  tensor(1.1372, grad_fn=<AddBackward0>)
epoch  1  loss :  tensor(1.1156, grad_fn=<AddBackward0>)
epoch  2  loss :  tensor(1.1019, grad_fn=<AddBackward0>)
epoch  3  loss :  tensor(1.0952, grad_fn=<AddBackward0>)
epoch  4  loss :  tensor(1.0772, grad_fn=<AddBackward0>)
epoch  5  loss :  tensor(1.0743, grad_fn=<AddBackward0>)
epoch  6  loss :  tensor(1.0643, grad_fn=<AddBackward0>)
epoch  7  loss :  tensor(1.0572, grad_fn=<AddBackward0>)
epoch  8  loss :  tensor(1.0533, grad_fn=<AddBackward0>)
epoch  9  loss :  tensor(1.0451, grad_fn=<AddBackward0>)
epoch  10  loss :  tensor(1.0388, grad_fn=<AddBackward0>)
epoch  11  loss :  tensor(1.0315, grad_fn=<AddBackward0>)
epoch  12  loss :  tensor(1.0263, grad_fn=<AddBackward0>)
epoch  13  loss :  tensor(1.0285, grad_fn=<AddBackward0>)
epoch  14  loss :  tensor(1.0118, grad_fn=<AddBackward0>)
epoch  15  loss :  tensor(1.0091, grad_fn=<AddBackward0>)
epoch  16  loss :  tensor(1.0027, grad_fn=<AddBackward0>)
epoch  17  loss :  tenso

In [10]:
nb_test_errors = compute_nb_errors(Siamese, test_input, test_target, mini_batch_size)
print('test error Net {:0.2f}% {:d}/{:d}'.format((100 * nb_test_errors) / test_input.size(0),
                                                  nb_test_errors, test_input.size(0)))

test error Net 19.90% 199/1000
