In [1]:
%reload_ext autoreload
%autoreload 1
import torch 
import sys
sys.path.append('..')
from torch import nn 
from torch.nn import functional as F
from torch import optim
from torch.utils.data import Dataset, DataLoader
from utils.loader import load
from utils.loader import PairSetMNIST

In [2]:
# Load the training and test splits as Dataset objects.
# NOTE(review): PairSetMNIST comes from utils.loader (not shown here); the
# `train=True` / `test=True` keywords presumably select the split — confirm.
train_data = PairSetMNIST(train=True)
test_data  = PairSetMNIST(test=True) 

In [3]:
# baseline models

class Net2C(nn.Module):
    """Baseline CNN that takes its input as a two-channel 14*14 image.

    Two convolutions (each followed by max pooling and a ReLU) feed two fully
    connected layers; the last layer emits 2 raw scores per sample.
    """

    def __init__(self, nb_hidden):
        """Build the layers; ``nb_hidden`` is the width of the hidden FC layer."""
        super().__init__()
        self.conv1 = nn.Conv2d(2, 32, kernel_size=5)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.fc1 = nn.Linear(256, nb_hidden)
        self.fc2 = nn.Linear(nb_hidden, 2)

    def forward(self, x):
        """Return the raw (un-normalised) output scores for the batch ``x``."""
        # First stage: convolution -> overlapping 3x3 pooling (stride 1) -> ReLU.
        pooled = F.max_pool2d(self.conv1(x), kernel_size=3, stride=1)
        features = F.relu(pooled)

        # Second stage: convolution -> non-overlapping 3x3 pooling -> ReLU.
        pooled = F.max_pool2d(self.conv2(features), kernel_size=3, stride=3)
        features = F.relu(pooled)

        # Flatten into rows of 256 values, the input width expected by fc1.
        hidden = F.relu(self.fc1(features.view(-1, 256)))
        return self.fc2(hidden)

# Inspired by LeNet-5, but takes a single input image (the two channels concatenated into one)

class Netcat(nn.Module):
    """CNN that first merges the two input channels into one single-channel image.

    ``forward`` concatenates ``x[:, 0]`` and ``x[:, 1]`` along axis 2 and then
    inserts a new axis at position ``dim``; per the original note the resulting
    set is 1000 * 1 * 14 * 28.
    """

    def __init__(self, dim):
        """Store ``dim`` (the axis given to ``unsqueeze`` in ``forward``) and build the layers."""
        super().__init__()
        self.d = dim
        self.conv1 = nn.Conv2d(1, 6, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=3, stride=1)
        self.fc1 = nn.Linear(480, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 2)

    def forward(self, x):
        """Return the raw (un-normalised) output scores for the batch ``x``."""
        # Join the two channels into one image, then add the channel axis back.
        first, second = x[:, 0], x[:, 1]
        merged = torch.cat((first, second), dim=2)
        merged = merged.unsqueeze(dim=self.d)

        stage1 = F.relu(F.max_pool2d(self.conv1(merged), kernel_size=2, stride=2))
        stage2 = F.relu(F.max_pool2d(self.conv2(stage1), kernel_size=2, stride=1))

        # Flatten into rows of 480 values, the input width expected by fc1.
        hidden = F.relu(self.fc1(stage2.view(-1, 480)))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

In [4]:
def train_binary(model, train_data, mini_batch_size=100, optimizer = optim.SGD,
                 criterion = nn.CrossEntropyLoss(), n_epochs=50, eta=1e-1, lambda_l2 = 0):
    """
    Train `model` on the `target` labels of `train_data` with mini-batch updates.

    Parameters
    ----------
    model : nn.Module
        Network called as ``model(input)``; its output is passed to `criterion`.
    train_data : Dataset
        Dataset whose items unpack as ``(input, target, classes)``. The
        ``classes`` element is ignored here.
    mini_batch_size : int
        Batch size given to the DataLoader (batches are reshuffled every epoch).
    optimizer : callable
        Optimizer class, instantiated as ``optimizer(model.parameters(), lr=eta)``.
    criterion : callable
        Loss called as ``criterion(output, target)``.
    n_epochs : int
        Number of passes over `train_data`.
    eta : float
        Learning rate.
    lambda_l2 : float
        Weight of the squared-L2 penalty on all model parameters; 0 disables it.

    Returns
    -------
    None. Prints, for each epoch, the sum of the batch losses (penalty included).
    """
    train_loader = DataLoader(train_data, batch_size=mini_batch_size, shuffle=True)

    model.train()
    opt = optimizer(model.parameters(), lr=eta)

    for e in range(n_epochs):
        # Plain float accumulator: summing loss tensors would keep every
        # batch's autograd graph alive until the end of the epoch.
        epoch_loss = 0.0

        for input_, target_, _ in train_loader:
            loss = criterion(model(input_), target_)

            if lambda_l2 != 0:
                # The penalty has to be part of the tensor that is
                # differentiated, otherwise it never influences the weights.
                loss = loss + lambda_l2 * sum(p.pow(2).sum() for p in model.parameters())

            opt.zero_grad()
            loss.backward()
            opt.step()

            epoch_loss += loss.item()

        print('Train Epoch: {}  | Loss {:.6f}'.format(e, epoch_loss))
    
# test function #
    
def test_binary(model, test_data, mini_batch_size=100, criterion = nn.CrossEntropyLoss()):
    
    """
    Test function to calculate prediction accuracy of a cnn with auxiliary loss
    
    """
    # create tes laoder
    test_loader = DataLoader(test_data, batch_size=mini_batch_size, shuffle=True)
    model.eval()
    test_loss = 0
    nb_errors=0
    
    with torch.no_grad():
        
        for i, data in enumerate(test_loader, 0):
            
            input_, target_, classes_ = data
            output = model(input_) 
            batch_loss = criterion(output, target_)
            test_loss += batch_loss
            
            _, predicted_classes = output.max(1)
            for k in range(mini_batch_size):
                if target_[k] != predicted_classes[k]:
                    nb_errors = nb_errors + 1
                                   
             
        print('\nTest set | Loss: {:.4f} | Accuracy: {:.0f}% | # misclassified : {}/{}\n'.format(
        test_loss.item(), 100 * (len(test_data)-nb_errors)/len(test_data), nb_errors, len(test_data)))
        

In [5]:
###############################
###### Binary Classifier ######
###############################

# Build the Net2C baseline with 200 hidden units, then train and evaluate it
# using the default arguments of train_binary / test_binary.
model_1 = Net2C(200)
train_binary(model_1, train_data)
test_binary(model_1, test_data)

Train Epoch: 0  | Loss 7.012794
Train Epoch: 1  | Loss 6.691864
Train Epoch: 2  | Loss 6.455741
Train Epoch: 3  | Loss 6.505395
Train Epoch: 4  | Loss 6.116726
Train Epoch: 5  | Loss 5.560391


KeyboardInterrupt: 