In [1]:
%reload_ext autoreload
%autoreload 1
import torch 
import sys
sys.path.append('..')
from torch import nn 
from torch.nn import functional as F
from torch import optim
#from utils.old_loader import load
#from utils.old_loader import PairSetMNIST
from utils.new_loader import load,PairSetMNIST,Training_set,Test_set, Training_set_split,Validation_set
from torch.utils.data import Dataset, DataLoader

In [3]:
# load the dataset as a Dataset object
data = PairSetMNIST( rotate=True,translate=True,swap_channel = True)
train_data_ = Training_set(data)
test_data_ = Test_set(data)
print(train_data_.train_input.shape)
print(test_data_.test_input.shape)

torch.Size([16000, 2, 14, 14])
torch.Size([1000, 2, 14, 14])


In [4]:
class LeNet_aux_sequential(nn.Module):
    """
    Weight sharing + Auxiliary loss
    
    """
    def __init__(self):
        super(LeNet_aux_sequential, self).__init__()
        # convolutional weights for digit reocgnition shared for each image
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        #self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        #self.bn2 = nn.BatchNorm2d(64)
        self.fc1 = nn.Linear(256, 200)
        self.fc2 = nn.Linear(200, 10)
        #self.dropout = nn.Dropout()
        
        # weights for binary classification 
        self.fc3 = nn.Linear(20, 60)
        self.fc4 = nn.Linear(60, 90)
        self.fc5 = nn.Linear(90, 2)
        
    def forward(self, input_):    
        
        # split the 2-channel input into two 14*14 images
        x = input_[:, 0, :, :].view(-1, 1, 14, 14)
        y = input_[:, 1, :, :].view(-1, 1, 14, 14)
        
        # forward pass for the first image 
        x = F.relu(F.max_pool2d(self.conv1(x), kernel_size=2, stride=2))
        x = F.relu(F.max_pool2d(self.conv2(x), kernel_size=2, stride=2))
        x = F.relu(self.fc1(x.view(-1, 256)))
        x = self.fc2(x)
        
        # forward pass for the second image 
        y = F.relu(F.max_pool2d(self.conv1(y), kernel_size=2, stride=2))
        y = F.relu(F.max_pool2d(self.conv2(y), kernel_size=2, stride=2))
        y = F.relu(self.fc1(y.view(-1, 256)))
        y = self.fc2(y)
        
        # concatenate layers  
        z = torch.cat([x, y], 1)
        
        z = F.relu(self.fc3(z))
        z = F.relu(self.fc4(z))
        z = self.fc5(z)
        
        return x, y, z

In [5]:
##### train function ######

def train_aux (model, train_data, mini_batch_size=100, optimizer = optim.SGD,
                criterion = nn.CrossEntropyLoss(), n_epochs=50, eta=1e-1, lambda_l2 = 0, alpha=1.0, beta=0.5):
    
    
    """
    Train network with auxiliary loss + weight sharing
    
    """
    # create data loader
    train_loader = DataLoader(train_data, batch_size=mini_batch_size, shuffle=True)
    
    model.train(True)
    optimizer = optimizer(model.parameters(), lr = eta)
    
    for e in range(n_epochs):
        epoch_loss = 0
        
        for i, data in enumerate(train_loader, 0):
            
            input_, target_, classes_ = data
            class_1, class_2, out = model(input_)
            aux_loss1 = criterion(class_1, classes_[:,0])
            aux_loss2 = criterion(class_2, classes_[:,1])
            out_loss  = criterion(out, target_)
            net_loss = (alpha * (out_loss) + beta * (aux_loss1 + aux_loss2) )
            epoch_loss += net_loss
            
            if lambda_l2 != 0:
                for p in model.parameters():
                    epoch_loss += lambda_l2 * p.pow(2).sum() # add an l2 penalty term to the loss 
            
            optimizer.zero_grad()
            net_loss.backward()
            optimizer.step()
            
        print('Train Epoch: {}  | Loss {:.6f}'.format(
                e, epoch_loss.item()))
        
#########################################################################################################################
#########################################################################################################################

### test function  ###

def test_aux(model, test_data, mini_batch_size=100, criterion = nn.CrossEntropyLoss()):
    
    """
    Test function to calculate prediction accuracy of a cnn with auxiliary loss
    
    """
    
    # create test laoder
    test_loader = DataLoader(test_data, batch_size=mini_batch_size, shuffle=True)
    
    model.train(False)
    test_loss = 0
    nb_errors = 0
    
    with torch.no_grad():
        
        for i, data in enumerate(test_loader, 0):
            input_, target_, classes_ = data
            
            _, _, output = model(input_) 
            batch_loss = criterion(output, target_)
            test_loss += batch_loss
            
            _, predicted_classes = output.max(1)
            for k in range(mini_batch_size):
                if target_[k] != predicted_classes[k]:
                    nb_errors = nb_errors + 1
                                   
             
        print('\nTest set | Loss: {:.4f} | Accuracy: {:.0f}% | # misclassified : {}/{}\n'.format(
        test_loss.item(), 100 * (len(test_data)-nb_errors)/len(test_data), nb_errors, len(test_data)))
        

In [6]:
model = LeNet_aux_sequential()
train_aux(model, train_data,optimizer = optim.Adam, n_epochs = 40,eta = 0.001)
test_aux(model,test_data)

NameError: name 'train_data' is not defined

In [9]:
test_aux(model,test_data_)

tensor([[ -7.6450,   6.0259],
        [-14.3836,  17.4279],
        [-11.4162,  13.0048],
        [  9.7627, -11.0057],
        [-12.2708,  12.8698],
        [-14.8284,  15.0460],
        [ -8.2375,   4.6712],
        [ 11.5759, -12.3448],
        [  5.0701,  -5.7441],
        [-10.9261,   9.3787],
        [ 12.2722, -14.0312],
        [ -9.5158,  11.7247],
        [  2.8210,  -3.7193],
        [ 10.6935, -12.9575],
        [ -6.1383,   7.4739],
        [-13.1602,  15.7676],
        [-12.1984,  11.4457],
        [ 10.4666, -12.9163],
        [-11.8702,  10.9577],
        [  2.7235,  -2.1931],
        [  2.0712,  -2.8221],
        [-10.8293,  10.1250],
        [  3.9596,  -5.8641],
        [-15.4876,  17.5666],
        [ -4.4942,   3.1257],
        [ 12.6149, -15.0957],
        [ 11.0826, -11.7689],
        [-12.8652,  13.0225],
        [ -7.7955,   9.6393],
        [ 16.4775, -20.7470],
        [ -3.0086,   3.4873],
        [-11.0346,  13.6261],
        [ -9.0109,  10.9957],
        [ 

        [-13.6140,  17.0595]])
tensor(0.5717)
tensor([[ -2.2749,   0.7365],
        [-15.6520,  19.2018],
        [ -0.7886,  -0.0968],
        [-13.7030,  16.5409],
        [  3.1169,  -6.2309],
        [ -9.4095,   5.4086],
        [ -1.0080,  -0.9229],
        [ -8.7685,   5.9771],
        [ -5.4428,   6.9744],
        [-15.1411,  15.4742],
        [ 16.3631, -19.1887],
        [ -6.3139,   5.9597],
        [ 10.2049, -12.7251],
        [  8.5274, -10.4413],
        [  3.3269,  -4.2558],
        [  3.1210,  -2.8185],
        [ -9.0053,   7.9375],
        [-20.9535,  22.1994],
        [ -9.7699,   7.8464],
        [-19.1908,  19.1313],
        [  1.4035,   0.1971],
        [-12.2100,  11.2159],
        [-10.1267,  12.1243],
        [-13.0348,  16.4647],
        [ 10.4637, -12.0258],
        [ -7.6591,   7.2662],
        [  0.5039,  -2.0527],
        [ -8.7055,   7.4470],
        [-18.0079,  20.0993],
        [  2.2516,  -3.9247],
        [  7.0077,  -8.0494],
        [ -5.5721,   5.5

        [  2.2909,  -1.5637]])
tensor(0.4300)
tensor([[ -6.3072,   7.1686],
        [-14.4355,  18.0834],
        [ -7.4932,   6.4322],
        [  0.2656,  -2.1846],
        [ -3.9658,   2.6991],
        [-16.3203,  15.5981],
        [ -5.9051,   7.7982],
        [-11.2750,  11.7481],
        [  6.8174,  -8.9065],
        [ -1.6601,  -1.1805],
        [ 12.7239, -16.2744],
        [-10.8555,  11.8378],
        [  0.6413,  -3.1425],
        [ -2.1916,   3.5473],
        [  9.5723, -11.1539],
        [  0.8204,  -1.8447],
        [ -5.2202,   4.1392],
        [  7.3554,  -9.9161],
        [-13.0423,  11.4197],
        [ -7.4955,   9.0128],
        [ -8.2873,  10.4194],
        [-10.9600,  10.4351],
        [ -9.8543,   5.9600],
        [ 10.9597, -15.9880],
        [  5.7018,  -7.7091],
        [  8.5727, -10.1743],
        [ -4.1930,   6.0155],
        [  2.1173,  -3.8518],
        [  5.3745,  -7.3765],
        [ 13.0599, -17.1480],
        [ -5.5758,   6.8285],
        [ 12.7410, -14.8

        [-11.3768,   9.5490]])
tensor(0.4362)
tensor([[  4.9699,  -5.6114],
        [-10.5300,   9.1245],
        [-14.3724,  17.8178],
        [ -3.8646,   2.4962],
        [ -8.3600,   9.3414],
        [  4.0931,  -5.5268],
        [  7.8571,  -8.9813],
        [ -7.2875,   6.4636],
        [ -7.3781,   8.7849],
        [-12.6316,  16.8181],
        [-18.1429,  18.5627],
        [ -4.5186,   5.5733],
        [-14.0643,  11.0320],
        [-14.5263,  10.7458],
        [ -2.1070,  -0.7787],
        [  5.8459,  -8.3998],
        [  8.0678,  -8.3399],
        [ 12.2428, -13.5037],
        [-12.5407,  13.7983],
        [  7.8980,  -9.0606],
        [ -7.5168,   9.8564],
        [ -9.0155,   9.7210],
        [-16.1544,  20.8699],
        [  0.7040,  -0.0537],
        [ -5.4769,   5.9588],
        [  7.3296,  -8.8758],
        [  2.3483,  -5.3956],
        [ -2.7689,   2.6029],
        [  4.6817,  -7.1204],
        [-17.7503,  18.1002],
        [  9.6787, -11.5878],
        [-12.4702,  13.9

In [6]:
test_aux(model,test_data)

tensor(0.5474)
tensor(1.0135)
tensor(0.4625)
tensor(0.3606)
tensor(0.8515)
tensor(0.5545)
tensor(0.1612)
tensor(0.3685)
tensor(0.5138)
tensor(0.6354)

Test set | Loss: 5.4690 | Accuracy: 92% | # misclassified : 81/1000



In [5]:
model_augm = LeNet_aux_sequential()
train_aux(model_augm, train_data,optimizer = optim.Adam, n_epochs = 40,eta = 0.001)
test_aux(model_augm,test_data)

Train Epoch: 0  | Loss 145.986572


KeyboardInterrupt: 

In [7]:
test_aux(model_augm,test_data)

tensor(16.7076)
tensor(51.8632)
tensor(34.4272)
tensor(31.2004)
tensor(30.8285)
tensor(16.4995)
tensor(48.4509)
tensor(4.3330)
tensor(20.7743)
tensor(51.4832)

Test set | Loss: 306.5677 | Accuracy: 93% | # misclassified : 71/1000



In [8]:
model_augm = LeNet_aux_sequential()
train_aux(model_augm, train_data,optimizer = optim.Adam, n_epochs = 40,eta = 0.001)
test_aux(model_augm,test_data)

Train Epoch: 0  | Loss 144.789093
Train Epoch: 1  | Loss 66.049606
Train Epoch: 2  | Loss 46.621475
Train Epoch: 3  | Loss 34.935658
Train Epoch: 4  | Loss 27.213629
Train Epoch: 5  | Loss 21.440855
Train Epoch: 6  | Loss 17.295500
Train Epoch: 7  | Loss 14.031562
Train Epoch: 8  | Loss 12.362149
Train Epoch: 9  | Loss 9.565265
Train Epoch: 10  | Loss 7.930093
Train Epoch: 11  | Loss 5.841302
Train Epoch: 12  | Loss 5.092928
Train Epoch: 13  | Loss 3.984188
Train Epoch: 14  | Loss 3.861875
Train Epoch: 15  | Loss 3.080203
Train Epoch: 16  | Loss 6.638016
Train Epoch: 17  | Loss 6.167600
Train Epoch: 18  | Loss 3.215981
Train Epoch: 19  | Loss 1.586827
Train Epoch: 20  | Loss 0.984653
Train Epoch: 21  | Loss 0.723888
Train Epoch: 22  | Loss 0.616509
Train Epoch: 23  | Loss 0.505288
Train Epoch: 24  | Loss 0.464645
Train Epoch: 25  | Loss 0.422698
Train Epoch: 26  | Loss 0.308345
Train Epoch: 27  | Loss 0.292129
Train Epoch: 28  | Loss 0.262822
Train Epoch: 29  | Loss 0.227796
Train Epoc

AttributeError: 'LeNet_aux_sequential' object has no attribute 'Train'

In [10]:
test_aux(model_augm,test_data)

tensor(39.4132)
tensor(55.3103)
tensor(54.7048)
tensor(50.2726)
tensor(35.1218)
tensor(51.1670)
tensor(42.4003)
tensor(5.7441)
tensor(21.5236)
tensor(97.1354)

Test set | Loss: 452.7932 | Accuracy: 92% | # misclassified : 85/1000



In [13]:
model_augm1 = LeNet_aux_sequential()
train_aux(model_augm1, train_data,optimizer = optim.Adam, n_epochs = 40,eta = 0.001)
test_aux(model_augm1,test_data)

Train Epoch: 0  | Loss 227.319672
Train Epoch: 1  | Loss 94.127136
Train Epoch: 2  | Loss 62.995865
Train Epoch: 3  | Loss 44.641670
Train Epoch: 4  | Loss 35.266556
Train Epoch: 5  | Loss 24.944170
Train Epoch: 6  | Loss 19.440481
Train Epoch: 7  | Loss 13.992777
Train Epoch: 8  | Loss 12.863868
Train Epoch: 9  | Loss 10.495307
Train Epoch: 10  | Loss 7.537069
Train Epoch: 11  | Loss 9.537624
Train Epoch: 12  | Loss 5.903196
Train Epoch: 13  | Loss 2.957053
Train Epoch: 14  | Loss 1.571920
Train Epoch: 15  | Loss 1.129858
Train Epoch: 16  | Loss 0.816844
Train Epoch: 17  | Loss 0.609346
Train Epoch: 18  | Loss 0.477440
Train Epoch: 19  | Loss 0.381855
Train Epoch: 20  | Loss 0.321070
Train Epoch: 21  | Loss 0.277929
Train Epoch: 22  | Loss 0.238281
Train Epoch: 23  | Loss 0.199922
Train Epoch: 24  | Loss 0.174901
Train Epoch: 25  | Loss 0.147749
Train Epoch: 26  | Loss 0.134503
Train Epoch: 27  | Loss 0.115223
Train Epoch: 28  | Loss 0.102293
Train Epoch: 29  | Loss 0.089915
Train Epo

In [21]:
model = LeNet_aux_sequential()
train_aux(model, train_data,optimizer = optim.Adam, n_epochs = 40,eta = 0.001)
test_aux(model,test_data)

Train Epoch: 0  | Loss 52.288834
Train Epoch: 1  | Loss 27.401350
Train Epoch: 2  | Loss 16.158192
Train Epoch: 3  | Loss 11.540616
Train Epoch: 4  | Loss 8.976239
Train Epoch: 5  | Loss 7.268036
Train Epoch: 6  | Loss 5.936383
Train Epoch: 7  | Loss 4.643144
Train Epoch: 8  | Loss 4.066018
Train Epoch: 9  | Loss 3.231797
Train Epoch: 10  | Loss 2.423669
Train Epoch: 11  | Loss 2.132078
Train Epoch: 12  | Loss 2.231938
Train Epoch: 13  | Loss 1.281895
Train Epoch: 14  | Loss 0.918143
Train Epoch: 15  | Loss 0.716443
Train Epoch: 16  | Loss 0.569584
Train Epoch: 17  | Loss 0.486224
Train Epoch: 18  | Loss 0.400498
Train Epoch: 19  | Loss 0.331986
Train Epoch: 20  | Loss 0.277661
Train Epoch: 21  | Loss 0.216684
Train Epoch: 22  | Loss 0.188167
Train Epoch: 23  | Loss 0.157810
Train Epoch: 24  | Loss 0.145424
Train Epoch: 25  | Loss 0.122730
Train Epoch: 26  | Loss 0.101887
Train Epoch: 27  | Loss 0.096355
Train Epoch: 28  | Loss 0.083998
Train Epoch: 29  | Loss 0.074811
Train Epoch: 30 

In [35]:
test_aux(model,test_data)

tensor([[ 1.2606e+02, -2.3549e+02],
        [-3.8099e+02,  1.1610e+02],
        [ 4.6878e+02, -4.4992e+02],
        [-6.3798e+02,  5.8433e+02],
        [-6.8643e+02,  5.0607e+02],
        [-5.3949e+02,  4.1741e+02],
        [ 6.2688e+02, -5.6452e+02],
        [ 9.3298e+02, -8.7079e+02],
        [-8.7116e+02,  5.7909e+02],
        [ 3.0181e+02, -2.5192e+02],
        [-5.8592e+02,  3.6442e+02],
        [-8.1297e+02,  9.0083e+02],
        [ 5.8139e+02, -5.7731e+02],
        [-6.9027e+02,  6.1560e+02],
        [-1.8523e+02,  1.9696e+02],
        [ 1.6001e+02, -2.0363e+02],
        [ 5.2831e+02, -5.7059e+02],
        [ 1.8419e+02, -3.5322e+02],
        [-6.3580e+02,  6.2021e+02],
        [-3.9868e+02,  6.4252e+02],
        [ 9.3004e+02, -7.2959e+02],
        [-4.8559e+02,  4.2954e+02],
        [-1.0966e+03,  1.0614e+03],
        [ 8.7829e+02, -6.1730e+02],
        [-2.7250e+02,  2.2932e+02],
        [-6.6990e+02,  5.3674e+02],
        [ 1.1393e+02, -1.8887e+02],
        [ 6.9202e+02, -6.479

        [  188.3135,  -246.8082]])
tensor(63.5310)
tensor([[  928.9849,  -683.6309],
        [  320.7217,  -416.5012],
        [  811.7279,  -742.1561],
        [  638.7829,  -487.7584],
        [  159.0321,  -191.6135],
        [  815.2189,  -878.2554],
        [   -6.3955,    35.6924],
        [  295.8208,  -207.8483],
        [ -328.6797,   286.3041],
        [-1327.4486,  1119.3771],
        [ -305.5297,   534.7176],
        [  627.5497,  -598.6666],
        [  289.2734,  -355.6018],
        [ -795.9886,   503.5179],
        [  468.1767,  -364.3541],
        [ -235.0885,   104.5526],
        [  258.9228,  -248.7329],
        [ -563.2453,   474.6917],
        [ -587.8633,   567.8807],
        [ -235.8471,   413.9138],
        [ -262.9609,   364.3360],
        [-1657.7849,  1256.0875],
        [ -580.4194,   384.1262],
        [-1373.8788,  1235.5583],
        [  542.0359,  -596.5303],
        [  -87.6832,    -9.4366],
        [ -262.0928,   185.3386],
        [ -419.3000,   281.3993

        [  -70.7619,    52.4617]])
tensor(48.6079)
tensor([[  795.4136,  -628.0977],
        [  403.4852,  -359.6679],
        [  617.4406,  -638.0015],
        [   77.7365,  -105.3470],
        [  680.2120,  -558.2063],
        [ -198.5392,   107.1291],
        [-1195.3448,  1006.0453],
        [  507.8616,  -299.0545],
        [ -510.6568,   386.9263],
        [ -848.0081,   751.9426],
        [-1130.1597,   861.4144],
        [ -550.2872,   513.0112],
        [  617.8946,  -721.9213],
        [ -589.3866,   471.9518],
        [  558.4028,  -473.4546],
        [  327.7458,  -374.3926],
        [  535.6418,  -409.0899],
        [-1324.7732,  1011.8889],
        [ -441.6898,   278.2787],
        [  -71.1150,   205.0027],
        [ -694.3903,   600.2227],
        [ -853.5543,   647.1109],
        [ -471.6271,   418.6380],
        [-1452.7039,  1179.0858],
        [  254.4881,  -196.9819],
        [ 1034.1465,  -940.6293],
        [  606.9845,  -583.7686],
        [-1205.3657,   973.4696

        [  535.2621,  -487.4341]])
tensor(30.3808)
tensor([[  964.6653, -1107.7129],
        [ -752.8280,   613.4736],
        [ -519.3843,   430.3716],
        [ -339.3129,   240.6856],
        [ -673.5342,   561.4585],
        [ 1127.3458, -1095.5304],
        [  896.2012,  -753.2390],
        [  616.9594,  -507.6483],
        [  322.9309,  -478.6940],
        [ -465.6600,   425.4844],
        [ -323.8486,   131.4357],
        [  675.4183,  -594.9023],
        [  314.9901,  -275.9276],
        [   10.6894,  -161.3256],
        [  110.7990,  -282.2849],
        [  824.2146,  -706.9068],
        [ -926.9305,   755.7000],
        [  211.0645,  -311.0150],
        [  201.7484,   -13.3052],
        [  128.2714,   -94.4762],
        [ -944.5203,  1073.9138],
        [  879.1106,  -667.1950],
        [ -877.3083,   776.7773],
        [ -861.0989,   984.0383],
        [ -460.6888,   285.7965],
        [ -119.2853,   270.4872],
        [-1235.7900,  1192.2003],
        [ -778.2770,   854.6116

In [36]:
test_aux(model_augm1,test_data)

tensor([[ 4.3191e+02, -3.9012e+02],
        [-2.7488e+02, -3.0023e+02],
        [-1.0690e+02, -2.5132e+02],
        [-1.7360e+03,  9.4973e+02],
        [-2.5760e+03,  1.8399e+03],
        [ 4.7097e+02, -6.9781e+02],
        [ 1.5339e+03, -1.7109e+03],
        [ 5.2053e+02, -4.5803e+02],
        [ 1.3308e+03, -9.4802e+02],
        [ 1.2168e+03, -1.0882e+03],
        [-2.8960e+02, -4.2389e+02],
        [ 1.8582e+02, -7.6704e+02],
        [-2.3771e+03,  1.2449e+03],
        [-6.8772e+02,  3.7871e+02],
        [ 1.5190e+03, -1.7337e+03],
        [ 1.0042e+03, -1.0888e+03],
        [-5.6819e+02,  5.8973e+02],
        [-4.5069e+02,  2.3102e+00],
        [-3.7436e+02, -1.7293e+02],
        [ 1.5119e+02, -6.7040e+02],
        [ 1.0976e+03, -1.4632e+03],
        [-1.9221e+03,  1.3102e+03],
        [ 6.4187e+02, -1.0158e+03],
        [-7.9721e+02,  5.4816e+02],
        [ 1.9972e+03, -2.3202e+03],
        [ 1.3073e+03, -1.6774e+03],
        [-1.0593e+03,  4.2860e+02],
        [-3.2325e+03,  2.206

        [-1283.6094,   909.5790]])
tensor(68.8898)
tensor([[ 1953.4080, -1814.4861],
        [ 1233.4987, -1511.4124],
        [-3868.7161,  3444.2551],
        [-2186.6619,  1026.2931],
        [ 1215.6813, -1227.6450],
        [ -260.2136,    78.6271],
        [ 1558.4568, -1554.7444],
        [-2839.4358,  1928.4863],
        [-2952.1584,  1930.2213],
        [  264.0259,  -628.5031],
        [ -558.5547,   633.9905],
        [  661.9775,  -871.2388],
        [-1831.1993,  1266.1282],
        [-2553.4780,  2078.2415],
        [-2372.0896,  1510.9519],
        [ 2222.8296, -2510.3357],
        [-2676.3909,  2253.1714],
        [ -639.5616,   427.9326],
        [-2151.4087,  1532.4844],
        [   29.8777,  -688.6841],
        [ 1222.9008, -1413.6522],
        [  143.7982,  -329.1919],
        [  121.0471,  -684.6375],
        [-2112.0632,   960.5212],
        [ -445.9513,  -121.0913],
        [   79.6646,  -261.6071],
        [   23.2941,  -525.8381],
        [-2782.7217,  2241.4080

        [  304.1106,  -343.8319]])
tensor(44.0102)
tensor([[-2857.4238,  1716.9143],
        [ -284.6371,   -19.6728],
        [-1931.3546,  1334.1797],
        [  287.6172,  -654.8644],
        [ -400.4646,   263.4863],
        [  518.2298,  -836.5864],
        [-2300.4871,  2367.6677],
        [  335.0984, -1266.7671],
        [ 1458.0865, -1617.2068],
        [ 1248.5763, -1501.4203],
        [-1210.4401,   170.9094],
        [-2438.4343,  1639.9663],
        [-2418.8687,  1419.8496],
        [-1937.0077,  1336.8892],
        [ 1326.7371, -1414.0204],
        [   35.3080,    99.9286],
        [ 1011.2781, -1164.3051],
        [ 1473.9044, -1479.0037],
        [-1691.1847,  1720.7375],
        [-2271.5630,  2081.9773],
        [-1218.2216,  1145.9031],
        [  599.7632,  -707.4600],
        [  758.6294,  -912.1805],
        [ 1087.9805, -1022.4487],
        [ 1181.5626, -1357.2399],
        [ 1243.5170, -1421.8495],
        [-1459.2328,  1029.8572],
        [-1768.0145,  1249.4604

        [ -457.7156,   388.1994]])
tensor(168.0147)
tensor([[  682.5480, -1053.7289],
        [-1590.2616,  1462.3582],
        [-1528.3640,   799.0468],
        [-1919.5287,  1172.6770],
        [ 1492.4008, -2033.1248],
        [-1168.4430,    -9.2524],
        [  392.5664,  -955.5839],
        [ 1719.7660, -1646.9036],
        [  961.7961, -1030.1614],
        [  675.5200,  -736.3019],
        [-2767.3484,  2103.0935],
        [ -699.5977,   669.8929],
        [-1621.3280,   913.6268],
        [-1095.7317,   628.9807],
        [ -260.4222,  -310.3430],
        [ -776.8713,   514.7580],
        [-1706.8740,  1267.7668],
        [ -160.8735,  -105.7458],
        [  239.6414,  -376.9294],
        [  879.5784, -1054.1864],
        [-1361.4961,   668.4905],
        [-2650.5833,  2335.2354],
        [-2283.5459,  1208.8531],
        [-1120.1545,   736.7142],
        [  875.8707, -1296.2434],
        [  819.5605, -1133.8815],
        [ 1610.3190, -1804.8093],
        [-2310.9695,  1547.364

In [11]:
model_augment = LeNet_aux_sequential()
train_aux(model_augment, train_data_,optimizer = optim.Adam, n_epochs = 40,eta = 0.001)
test_aux(model_augment,test_data_)

Train Epoch: 0  | Loss 236.871338
Train Epoch: 1  | Loss 105.768906
Train Epoch: 2  | Loss 70.415672
Train Epoch: 3  | Loss 50.054382
Train Epoch: 4  | Loss 37.400730
Train Epoch: 5  | Loss 30.210844
Train Epoch: 6  | Loss 24.576738
Train Epoch: 7  | Loss 19.424099
Train Epoch: 8  | Loss 15.611818
Train Epoch: 9  | Loss 12.939423
Train Epoch: 10  | Loss 9.982206
Train Epoch: 11  | Loss 9.027804
Train Epoch: 12  | Loss 8.287337
Train Epoch: 13  | Loss 6.866757
Train Epoch: 14  | Loss 5.736903
Train Epoch: 15  | Loss 3.973396
Train Epoch: 16  | Loss 6.437736
Train Epoch: 17  | Loss 7.067785
Train Epoch: 18  | Loss 2.554822
Train Epoch: 19  | Loss 0.847566
Train Epoch: 20  | Loss 0.875374
Train Epoch: 21  | Loss 0.485584
Train Epoch: 22  | Loss 0.353842
Train Epoch: 23  | Loss 0.282191
Train Epoch: 24  | Loss 0.240473
Train Epoch: 25  | Loss 0.207667
Train Epoch: 26  | Loss 0.183488
Train Epoch: 27  | Loss 0.160679
Train Epoch: 28  | Loss 0.142497
Train Epoch: 29  | Loss 0.125831
Train Ep

        [ -5.4256,   4.6957]])
tensor(0.1351)
tensor([[-19.1355,  14.7993],
        [ 24.5990, -19.3458],
        [ 31.4184, -28.5260],
        [-54.1782,  47.8861],
        [ -7.7839,  10.3008],
        [ 19.1294, -14.4786],
        [ 12.2963, -15.6178],
        [-31.6332,  26.1976],
        [-58.0741,  52.3938],
        [  1.9750,  -2.1067],
        [ -4.8770,   5.3786],
        [-13.8832,  14.5708],
        [-18.1439,  16.6389],
        [ 13.1810, -15.0184],
        [  8.3016,  -9.6691],
        [ 11.6387, -11.4531],
        [  9.7337,  -9.2541],
        [  6.9418,  -6.5994],
        [-21.2682,  20.1542],
        [ 17.6656, -18.4507],
        [-12.8293,  10.9626],
        [-53.4153,  46.2686],
        [ 14.2847, -12.5007],
        [ 24.4070, -22.9794],
        [-27.4622,  24.2651],
        [ -4.8132,   7.1114],
        [-30.0794,  25.4495],
        [-46.9329,  44.0576],
        [-12.9876,  11.3068],
        [ -6.6569,   7.1027],
        [-24.7484,  22.5597],
        [ -0.7596,   1.9

        [ 17.6957, -14.6858]])
tensor(0.3169)
tensor([[-2.0201e+01,  1.9804e+01],
        [-2.9999e+01,  2.7615e+01],
        [-4.6985e+01,  4.4994e+01],
        [ 2.6749e+01, -2.4331e+01],
        [ 2.3691e+01, -2.3942e+01],
        [ 1.1634e+01, -1.3447e+01],
        [ 1.9540e+01, -1.6483e+01],
        [-2.8383e+01,  2.5907e+01],
        [-2.1542e+01,  2.1435e+01],
        [-1.0801e+01,  7.8704e+00],
        [ 1.7108e+01, -1.5829e+01],
        [ 7.0519e+00, -8.7200e+00],
        [-1.5380e+01,  1.5502e+01],
        [ 2.1275e+01, -1.6806e+01],
        [-2.3360e+01,  2.2628e+01],
        [-4.2534e+00,  4.3891e+00],
        [-2.6814e+01,  2.4529e+01],
        [-2.0268e+01,  2.0052e+01],
        [ 2.0168e+01, -1.9573e+01],
        [ 1.7503e+01, -1.4467e+01],
        [-1.7606e+01,  1.2108e+01],
        [ 3.0513e+01, -2.7649e+01],
        [ 1.2023e+01, -1.3857e+01],
        [-1.8827e+01,  1.7791e+01],
        [ 3.4451e+01, -3.4446e+01],
        [ 2.5275e+01, -2.2463e+01],
        [-8.5138e+

        [ 15.8248, -13.2289]])
tensor(0.0441)
tensor([[ 2.4172e+01, -2.4170e+01],
        [-2.2819e+01,  2.0417e+01],
        [ 2.3261e+01, -2.1249e+01],
        [ 2.0301e+01, -1.5642e+01],
        [-4.0008e+01,  3.6451e+01],
        [ 1.7412e+01, -2.0515e+01],
        [-8.7896e+00,  8.4679e+00],
        [-9.5727e+00,  8.6005e+00],
        [-8.7333e+00,  6.2335e+00],
        [-2.4712e+01,  2.1232e+01],
        [ 1.5753e+01, -1.4238e+01],
        [-2.0291e+01,  2.0028e+01],
        [ 1.1737e+01, -1.0522e+01],
        [ 9.8667e+00, -1.1704e+01],
        [-1.5961e+01,  1.6257e+01],
        [ 1.0320e+01, -9.0931e+00],
        [ 8.0410e+00, -8.1831e+00],
        [-2.0786e+01,  1.6337e+01],
        [ 3.0434e+01, -2.4098e+01],
        [ 1.8784e+01, -1.6915e+01],
        [-4.0439e-01, -2.4576e+00],
        [ 1.7039e+01, -1.5638e+01],
        [ 7.5150e+00, -1.1881e+01],
        [ 1.2863e+01, -1.0968e+01],
        [ 2.3081e+01, -2.4669e+01],
        [ 2.0903e+01, -1.5554e+01],
        [ 1.6584e+

In [13]:
test_aux(model_augment,test_data_)

tensor([[ -9.8938,  10.3419],
        [ 16.6709, -17.9201],
        [ 23.3979, -21.5222],
        [-10.7022,   9.4977],
        [ -2.4615,  -0.2242],
        [-38.9516,  35.3290],
        [-14.2416,  12.5090],
        [ 28.3499, -25.2572],
        [ 11.1100, -15.3099],
        [ 12.6073, -14.0770],
        [-13.7808,  15.1127],
        [ 22.4396, -20.7235],
        [-14.8469,  12.1844],
        [ 28.7829, -24.5809],
        [-11.6959,  11.6914],
        [-31.9046,  26.3116],
        [ -6.7414,   5.7983],
        [ 15.4814, -19.0008],
        [-17.5132,  16.6406],
        [ 13.7591, -13.7232],
        [-23.0742,  21.2181],
        [ 13.6384, -13.0415],
        [-26.8974,  24.6837],
        [ 14.9161, -15.5937],
        [ 30.5143, -30.4467],
        [  2.8102,  -1.7902],
        [-12.0056,  11.8404],
        [-13.1279,  12.4646],
        [ 25.1845, -25.7377],
        [  6.4624,  -3.7877],
        [-14.2167,  13.5468],
        [-37.1209,  32.6267],
        [ 15.8446, -14.9609],
        [-

        [-22.7796,  19.4505]])
tensor(0.2920)
tensor([[-18.0990,  19.5720],
        [-17.4219,  17.0000],
        [ 20.5130, -23.2692],
        [ 14.7195, -11.8729],
        [ 32.6331, -34.7847],
        [-28.4701,  25.4606],
        [-44.9104,  42.2885],
        [-20.5326,  14.2207],
        [ 11.3980, -15.3398],
        [-32.3062,  26.5775],
        [-18.7812,  19.5155],
        [ 26.6106, -22.0613],
        [ 23.4914, -20.5450],
        [-15.8236,  12.7666],
        [-61.6389,  55.4752],
        [ 16.1208, -15.7194],
        [-34.2458,  30.6366],
        [ 13.4292, -13.7170],
        [-26.1606,  24.8677],
        [  1.0512,  -0.3350],
        [-14.6101,  11.8230],
        [-34.1749,  32.4132],
        [ -2.0608,   2.2912],
        [-39.5106,  34.9852],
        [-20.9699,  18.0681],
        [-16.2658,  14.5014],
        [ 16.2682, -16.2295],
        [ 12.7339, -15.1743],
        [ 17.1224, -17.0041],
        [ -6.4898,   6.1493],
        [ 22.6115, -21.1935],
        [ -2.7821,   1.1

        [-14.3330,  10.5476]])
tensor(0.0021)
tensor([[ 33.6389, -28.3368],
        [ 23.7768, -22.6439],
        [ 15.8930, -16.0073],
        [-15.6808,  15.5834],
        [-29.6563,  25.7216],
        [  1.7775,  -3.9542],
        [ 10.7207, -12.6899],
        [-26.7399,  19.8209],
        [ 16.7916, -17.7586],
        [-10.0427,  12.2344],
        [-14.6969,  11.3264],
        [  6.1232,  -6.0544],
        [-14.7611,  12.0563],
        [-17.8330,  18.2585],
        [ 21.2263, -20.1556],
        [-36.6902,  33.8161],
        [  7.2757,  -8.9619],
        [ 11.2082,  -9.4659],
        [ 34.7014, -31.9162],
        [-26.3028,  26.1533],
        [-18.5298,  18.9650],
        [-13.5990,  13.1490],
        [-37.2799,  32.0622],
        [ 16.8544, -14.2432],
        [ 13.3711, -15.7879],
        [-24.3483,  25.3167],
        [ 20.4600, -23.2232],
        [-26.5314,  25.1710],
        [  2.2881,  -2.2004],
        [ 33.2998, -36.8558],
        [-25.5657,  23.8542],
        [ -2.2102,   0.5

        [-27.5311,  23.9040]])
tensor(0.1518)
tensor([[ 1.7646e+01, -1.9509e+01],
        [ 1.0345e+01, -1.0152e+01],
        [ 1.8038e+01, -1.6767e+01],
        [ 2.5938e+01, -2.2204e+01],
        [-4.0586e+01,  3.4987e+01],
        [ 1.9275e+01, -1.5550e+01],
        [ 2.4405e+01, -2.1819e+01],
        [-1.9557e+01,  1.5216e+01],
        [-3.2286e+01,  2.8593e+01],
        [-1.7827e+01,  1.6258e+01],
        [ 2.8814e+01, -2.4660e+01],
        [-3.1408e+01,  3.0320e+01],
        [ 2.0660e+01, -1.7312e+01],
        [ 2.5990e+00, -2.4794e+00],
        [ 2.0681e+01, -1.5511e+01],
        [-1.7985e+01,  1.5775e+01],
        [-3.7916e-01, -3.4175e-02],
        [ 2.1619e+01, -2.0920e+01],
        [-3.6345e+01,  3.4139e+01],
        [ 5.3059e+00, -5.9458e+00],
        [ 1.4725e+01, -1.6928e+01],
        [-1.9094e+01,  1.8482e+01],
        [-4.0766e+01,  3.5321e+01],
        [ 2.1210e+01, -1.7647e+01],
        [ 1.4546e+01, -1.3732e+01],
        [ 1.1826e+01, -9.4892e+00],
        [-4.9750e+

In [17]:
test_aux(model_augment,test_data_)


Test set | Loss: 3.5608 | Accuracy: 97% | # misclassified : 33/1000



In [7]:
model_augmented = LeNet_aux_sequential()
train_aux(model_augmented, train_data_,optimizer = optim.Adam, n_epochs = 40,eta = 0.001)
test_aux(model_augmented,test_data_)

Train Epoch: 0  | Loss 231.876205
Train Epoch: 1  | Loss 101.300598
Train Epoch: 2  | Loss 67.565048
Train Epoch: 3  | Loss 50.746643
Train Epoch: 4  | Loss 38.700336
Train Epoch: 5  | Loss 30.765127
Train Epoch: 6  | Loss 24.284731
Train Epoch: 7  | Loss 17.354771
Train Epoch: 8  | Loss 14.435579
Train Epoch: 9  | Loss 12.129961
Train Epoch: 10  | Loss 11.639895
Train Epoch: 11  | Loss 8.301371
Train Epoch: 12  | Loss 6.240178
Train Epoch: 13  | Loss 7.081873
Train Epoch: 14  | Loss 6.639845
Train Epoch: 15  | Loss 3.606985
Train Epoch: 16  | Loss 2.061212
Train Epoch: 17  | Loss 1.022062
Train Epoch: 18  | Loss 0.740710
Train Epoch: 19  | Loss 0.561961
Train Epoch: 20  | Loss 0.446838
Train Epoch: 21  | Loss 0.401379
Train Epoch: 22  | Loss 0.311708
Train Epoch: 23  | Loss 0.272904
Train Epoch: 24  | Loss 0.232222
Train Epoch: 25  | Loss 0.201090
Train Epoch: 26  | Loss 0.171800
Train Epoch: 27  | Loss 0.149026
Train Epoch: 28  | Loss 0.134019
Train Epoch: 29  | Loss 0.115600
Train E