In [9]:
%reload_ext autoreload
%autoreload 1
import torch 
import sys
sys.path.append('..')
from torch import nn 
from torch.nn import functional as F
from torch import optim
from utils.loader import load
from utils.loader import PairSetMNIST
from torch.utils.data import Dataset, DataLoader

In [10]:
# Wrap the paired-MNIST train/test splits as torch Dataset objects.
# Only the training split is built with swap_channel=True; the test split is
# constructed with test=True alone.
train_data = PairSetMNIST(swap_channel=True, train=True)
test_data = PairSetMNIST(test=True)

In [26]:
class STN_Net_aux(nn.Module):
    """
    Weight sharing + Auxiliary loss + spatial transformer.

    Expects input of shape (N, 2, 14, 14): channels 0 and 1 are two 14x14
    images. Each image, and its spatial-transformer-warped copy, is passed
    through the SAME convolutional/fully-connected digit branch (shared
    weights). The two 10-way digit outputs are concatenated and fed to a
    shared binary-classification head.

    forward returns a 3-tuple, each entry being the sum of the result on the
    original images and the result on the warped images:
        digit logits for image 1, shape (N, 10)
        digit logits for image 2, shape (N, 10)
        binary logits,            shape (N, 2)

    Args:
        dropout_p: dropout probability used throughout the network. The
            default 0.0 makes dropout a no-op (the original behaviour).
    """
    def __init__(self, dropout_p=0.0):
        super(STN_Net_aux, self).__init__()
        # convolutional weights for digit recognition, shared for each image
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.bn2 = nn.BatchNorm2d(64)
        # 14x14 -> conv3 -> 12 -> pool -> 6 -> conv3 -> 4 -> pool -> 2, so 64 * 2 * 2 = 256 features
        self.fc1 = nn.Linear(256, 200)
        self.fc2 = nn.Linear(200, 10)
        self.dropout = nn.Dropout(dropout_p)

        # weights for binary classification (input: two concatenated 10-way outputs)
        self.fc3 = nn.Linear(20, 230)
        self.fc4 = nn.Linear(230, 100)
        self.fc5 = nn.Linear(100, 2)

        # Spatial transformer localization-network
        # 14x14 -> conv3 -> 12 -> pool -> 6 -> conv3 -> 4 -> pool -> 2, so 10 * 2 * 2 features
        self.localization = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=3),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=3),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True)
        )

        # Regressor for the 3 * 2 affine matrix
        self.fc_loc = nn.Sequential(
            nn.Linear(10 * 2 * 2, 32),
            nn.ReLU(True),
            nn.Linear(32, 3 * 2)
        )

        # Initialize the weights/bias with identity transformation, so that at
        # construction time stn() leaves its input unchanged
        self.fc_loc[2].weight.data.zero_()
        self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

    def stn(self, x):
        """Spatial transformer: warp x (N, 1, 14, 14) by a regressed 2x3 affine matrix."""
        xs = self.localization(x)
        xs = xs.view(-1, 10 * 2 * 2)
        theta = self.fc_loc(xs).view(-1, 2, 3)

        # align_corners=False is PyTorch's current default; passing it explicitly
        # keeps that behaviour and silences the default-change warning
        grid = F.affine_grid(theta, x.size(), align_corners=False)
        return F.grid_sample(x, grid, align_corners=False)

    def _digit_logits(self, img):
        """Shared digit-recognition branch: (N, 1, 14, 14) image -> (N, 10) logits."""
        h = F.relu(F.max_pool2d(self.bn1(self.conv1(img)), kernel_size=2, stride=2))
        h = F.relu(F.max_pool2d(self.bn2(self.conv2(h)), kernel_size=2, stride=2))
        h = F.relu(self.dropout(self.fc1(h.view(-1, 256))))
        return self.dropout(self.fc2(h))

    def _pair_logits(self, digits_a, digits_b):
        """Shared binary head: two (N, 10) digit outputs -> (N, 2) logits."""
        z = torch.cat([digits_a, digits_b], 1)
        z = F.relu(self.dropout(self.fc3(z)))
        z = F.relu(self.dropout(self.fc4(z)))
        return self.dropout(self.fc5(z))

    def forward(self, input_):
        # split the 2-channel input into two (N, 1, 14, 14) images
        x = input_[:, 0, :, :].reshape(-1, 1, 14, 14)
        y = input_[:, 1, :, :].reshape(-1, 1, 14, 14)

        # warp both images with the shared spatial transformer network
        x_t = self.stn(x)
        y_t = self.stn(y)

        # Shared digit branch. The call order (x, y, x_t, y_t) matches the
        # original, so BatchNorm running statistics update identically in train mode.
        x = self._digit_logits(x)
        y = self._digit_logits(y)
        x_t = self._digit_logits(x_t)
        y_t = self._digit_logits(y_t)

        # binary head on the original pair, then on the transformed pair
        z = self._pair_logits(x, y)
        z_t = self._pair_logits(x_t, y_t)

        return (x + x_t), (y + y_t), (z + z_t)

In [27]:
##### train function ######

def train_aux(model, train_data, mini_batch_size=100, optimizer=optim.SGD,
              criterion=nn.CrossEntropyLoss(), n_epochs=50, eta=1e-1, lambda_l2=0, alpha=0.5, beta=0.5):
    """
    Train a network with auxiliary loss + weight sharing.

    The model must return a 3-tuple ``(digit_logits_1, digit_logits_2, out)``.
    Each mini-batch minimises
        alpha * criterion(out, target)
        + beta * (criterion(digit_logits_1, classes[:, 0]) + criterion(digit_logits_2, classes[:, 1]))
        + lambda_l2 * sum over parameters of ||p||^2      (only when lambda_l2 != 0)

    Args:
        model: nn.Module to train; it is put in train mode and updated in place.
        train_data: Dataset yielding (input, target, classes) triples.
        mini_batch_size: DataLoader batch size (batches are shuffled).
        optimizer: optimizer *class*, instantiated here with lr=eta.
        criterion: loss used for both the main and the auxiliary terms.
        n_epochs: number of passes over train_data.
        eta: learning rate.
        lambda_l2: weight of the explicit L2 penalty on all parameters.
        alpha: weight of the main (binary) loss.
        beta: weight of the auxiliary (digit) losses.

    Prints the sum of the mini-batch losses after every epoch.
    """
    train_loader = DataLoader(train_data, batch_size=mini_batch_size, shuffle=True)

    model.train()
    opt = optimizer(model.parameters(), lr=eta)

    for e in range(n_epochs):
        epoch_loss = 0.0

        for input_, target_, classes_ in train_loader:
            class_1, class_2, out = model(input_)
            aux_loss = criterion(class_1, classes_[:, 0]) + criterion(class_2, classes_[:, 1])
            out_loss = criterion(out, target_)
            net_loss = alpha * out_loss + beta * aux_loss

            if lambda_l2 != 0:
                # The penalty must be part of the back-propagated loss; adding
                # it only to the logged total leaves training unregularized.
                net_loss = net_loss + lambda_l2 * sum(p.pow(2).sum() for p in model.parameters())

            opt.zero_grad()
            net_loss.backward()
            opt.step()

            # .item() detaches: summing the tensors themselves would keep every
            # batch's autograd graph alive until the end of the epoch.
            epoch_loss += net_loss.item()

        print('Train Epoch: {}  | Loss {:.6f}'.format(e, epoch_loss))
        
#########################################################################################################################
#########################################################################################################################

### test function  ###

def test_aux(model, test_data, mini_batch_size=100, criterion = nn.CrossEntropyLoss()):
    
    """
    Test function to calculate prediction accuracy of a cnn with auxiliary loss
    
    """
    
    # create test laoder
    test_loader = DataLoader(test_data, batch_size=mini_batch_size, shuffle=True)
    
    model.eval()
    test_loss = 0
    nb_errors=0
    
    with torch.no_grad():
        
        for i, data in enumerate(test_loader, 0):
            input_, target_, classes_ = data
            
            _,_ ,output = model(input_) 
            batch_loss = criterion(output, target_)
            test_loss += batch_loss
            
            _, predicted_classes = output.max(1)
            for k in range(mini_batch_size):
                if target_[k] != predicted_classes[k]:
                    nb_errors = nb_errors + 1
                                   
             
        print('\nTest set | Loss: {:.4f} | Accuracy: {:.0f}% | # misclassified : {}/{}\n'.format(
        test_loss.item(), 100 * (len(test_data)-nb_errors)/len(test_data), nb_errors, len(test_data)))

In [28]:
# Fix the global torch RNG so weight init, DataLoader shuffling and the printed
# numbers are reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

model = STN_Net_aux()
train_aux(model, train_data, optimizer=optim.Adam, eta=0.001)
test_aux(model, test_data)

Train Epoch: 0  | Loss 28.362577
Train Epoch: 1  | Loss 8.908897
Train Epoch: 2  | Loss 4.895871
Train Epoch: 3  | Loss 2.951210
Train Epoch: 4  | Loss 1.772059
Train Epoch: 5  | Loss 1.030429
Train Epoch: 6  | Loss 0.564917
Train Epoch: 7  | Loss 0.336092
Train Epoch: 8  | Loss 0.215338
Train Epoch: 9  | Loss 0.146005
Train Epoch: 10  | Loss 0.107445
Train Epoch: 11  | Loss 0.082549
Train Epoch: 12  | Loss 0.069765
Train Epoch: 13  | Loss 0.058600
Train Epoch: 14  | Loss 0.052406
Train Epoch: 15  | Loss 0.048222
Train Epoch: 16  | Loss 0.039775
Train Epoch: 17  | Loss 0.035433
Train Epoch: 18  | Loss 0.031399
Train Epoch: 19  | Loss 0.028086
Train Epoch: 20  | Loss 0.025468
Train Epoch: 21  | Loss 0.023863
Train Epoch: 22  | Loss 0.021607
Train Epoch: 23  | Loss 0.019417
Train Epoch: 24  | Loss 0.018275
Train Epoch: 25  | Loss 0.016255
Train Epoch: 26  | Loss 0.015382
Train Epoch: 27  | Loss 0.014249
Train Epoch: 28  | Loss 0.013081
Train Epoch: 29  | Loss 0.012044
Train Epoch: 30  | 