In [1]:
%reload_ext autoreload
%autoreload 1
import torch 
import sys
sys.path.append('..')
from torch import nn 
from torch.nn import functional as F
from torch import optim
from utils.loader import load
from utils.loader import PairSetMNIST
from torch.utils.data import Dataset, DataLoader

In [2]:
# Load the train / test pair sets as Dataset objects (consumed by DataLoader in train_aux / test_aux).
# NOTE(review): swap_channel=True is only requested for the training split — presumably it adds
# channel-swapped copies of each pair as augmentation; confirm in utils.loader.PairSetMNIST.
train_data = PairSetMNIST(
    train=True,
    swap_channel=True,
)
test_data = PairSetMNIST(test=True)

In [92]:
class conv_block(nn.Module) :
    """
    basic 2d concolution with batch norm
    """
    
    def __init__(self, in_channels,out_channels,kernel_size = 1,stride =1, padding = 0) :
        super(conv_block,self).__init__()
        self.conv = nn.Conv2d(in_channels,out_channels,kernel_size,stride ,padding)
        self.bn = nn.BatchNorm2d(out_channels)
    
    def forward(self,x) :
        x = self.bn(self.conv(x))
        return x

In [94]:
class Inception_block(nn.Module):
    """
    Inception-style block: four parallel branches concatenated along the channel dimension.

    Branches (each ends in a ``conv_block``; ReLU is applied to each in ``forward``):
      * ``conv1x1`` - a single 1x1 convolution;
      * ``conv3x3`` - 1x1 reduction, then a 3x3 receptive field factorised as (1,3) -> (3,1);
      * ``conv5x5`` - 1x1 reduction, then two factorised 3x3 stages (5x5 receptive field);
      * ``pool``    - 3x3 max-pool (stride 1, padding 1) followed by a 1x1 convolution.

    Every branch keeps the spatial size, so the output has ``4 * out_channels`` channels
    and the same height / width as the input.

    Args:
        in_channels:  channels of the input tensor.
        out_channels: channels produced by each branch.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        def factorized_3x3():
            # (1,3) conv then (3,1) conv == 3x3 receptive field; padding keeps H and W unchanged.
            return [
                conv_block(out_channels, out_channels, kernel_size=(1, 3), padding=(0, 1)),
                conv_block(out_channels, out_channels, kernel_size=(3, 1), padding=(1, 0)),
            ]

        self.conv1x1 = conv_block(in_channels, out_channels, kernel_size=1)
        self.conv3x3 = nn.Sequential(
            conv_block(in_channels, out_channels, kernel_size=1),
            *factorized_3x3(),
        )
        self.conv5x5 = nn.Sequential(
            conv_block(in_channels, out_channels, kernel_size=1),
            *factorized_3x3(),
            *factorized_3x3(),
        )
        self.pool = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True),
            conv_block(in_channels, out_channels, kernel_size=1),
        )

    def forward(self, x):
        """Run every branch on ``x`` and concatenate the ReLU'd results channel-wise."""
        branches = (self.conv1x1, self.conv3x3, self.conv5x5, self.pool)
        return torch.cat([F.relu(branch(x)) for branch in branches], 1)

In [95]:
class Auxiliary_loss(nn.Module):
    """
    Auxiliary classification head mapping a feature map to per-image class logits.

    Pipeline: adaptive avg-pool to 4x4 -> 1x1 ``conv_block`` to 128 channels ->
    flatten (128 * 4 * 4 = 2048) -> Linear(2048, 1024) + ReLU -> dropout ->
    Linear(1024, nb_classes).

    Args:
        in_channels: number of channels of the input feature map.
        nb_classes:  number of output logits (default 10).
        dropout:     dropout probability before the last linear layer (default 0.7,
                     the value previously hard-coded). Only active in training mode.
    """

    def __init__(self, in_channels, nb_classes=10, dropout=0.7):
        super().__init__()
        self.dropout_p = dropout

        self.conv = conv_block(in_channels, 128, kernel_size=1)

        self.fc1 = nn.Linear(128 * 4 * 4, 1024)
        self.fc2 = nn.Linear(1024, nb_classes)

    def forward(self, x):
        # N x C_in x H x W  ->  N x C_in x 4 x 4
        x = F.adaptive_avg_pool2d(x, (4, 4))
        # -> N x 128 x 4 x 4  (conv + batch norm, no activation)
        x = self.conv(x)
        # -> N x 2048
        x = torch.flatten(x, 1)
        # -> N x 1024
        x = F.relu(self.fc1(x), inplace=True)
        # F.dropout defaults to training=True: without this flag dropout would stay
        # active after model.eval() and make test-time predictions random.
        x = F.dropout(x, self.dropout_p, training=self.training)
        # -> N x nb_classes
        return self.fc2(x)

In [96]:
class Google_Net(nn.Module):
    """
    Weight-sharing network comparing the two images of a pair.

    Both images go through the *same* Inception block and auxiliary head, giving
    per-image class logits. The two logit vectors are concatenated and passed
    through a small MLP that outputs 2 logits for the binary target.

    Args:
        nb_classes: number of per-image classes predicted by the auxiliary head (default 10).
    """

    def __init__(self, nb_classes=10):
        super().__init__()

        # NOTE(review): conv1 is never used in forward(); it is kept only so the module's
        # parameters / state_dict keys stay unchanged. Consider removing it.
        self.conv1 = conv_block(1, 32, kernel_size=3, padding=(3 - 1) // 2)
        # shared inception block: 1 input channel -> 4 branches x 64 = 256 channels
        self.inception = Inception_block(1, 64)
        # shared auxiliary head: 256 channels -> nb_classes logits per image
        # (nb_classes was previously ignored and the head always produced 10 logits)
        self.auxiliary = Auxiliary_loss(256, nb_classes)

        # binary classifier on the concatenated per-image logits (2 * nb_classes inputs)
        self.fc1 = nn.Linear(2 * nb_classes, 60)
        self.fc2 = nn.Linear(60, 90)
        self.fc3 = nn.Linear(90, 2)

    def forward(self, input_):
        """
        Args:
            input_: tensor of shape N x 2 x H x W (14 x 14 in this notebook), one image per channel.

        Returns:
            (x, y, z): class logits of image 0 and image 1 (N x nb_classes each)
            and the binary logits (N x 2).
        """
        # Split into two single-channel images (N x 1 x H x W); slicing keeps the
        # channel dim without hard-coding the 14 x 14 size.
        x = input_[:, 0:1]
        y = input_[:, 1:2]

        # shared inception block
        x = self.inception(x)
        y = self.inception(y)

        # shared auxiliary head -> per-image class logits
        x = self.auxiliary(x)
        y = self.auxiliary(y)

        # concatenate both predictions and classify the pair
        z = torch.cat([x, y], 1)
        z = F.relu(self.fc1(z))
        z = F.relu(self.fc2(z))
        z = self.fc3(z)

        return x, y, z
        

In [34]:
# Smoke test: one forward pass over the full training tensor to check output shapes.
model = Google_Net()
with torch.no_grad():  # shape check only, no need to build an autograd graph for every sample
    x, y, z = model(train_data.x_)
n_pairs = train_data.x_.shape[0]
assert x.shape == (n_pairs, 10) and y.shape == (n_pairs, 10), (x.shape, y.shape)
assert z.shape == (n_pairs, 2), z.shape

torch.Size([2000, 10])
torch.Size([2000, 10])


In [99]:
##### train function ######

def train_aux(model, train_data, mini_batch_size=100, optimizer=optim.SGD,
              criterion=nn.CrossEntropyLoss(), n_epochs=50, eta=1e-1, lambda_l2=0, alpha=0.5, beta=0.5):
    """
    Train a network with auxiliary loss + weight sharing.

    The model must return ``(class_1, class_2, out)``: class logits for each image of
    the pair and the logits of the binary target. The optimised loss per batch is

        alpha * criterion(out, target)
        + beta * (criterion(class_1, classes[:, 0]) + criterion(class_2, classes[:, 1]))
        + lambda_l2 * sum_p ||p||^2          (only when lambda_l2 != 0)

    Args:
        model: network to train (updated in place, left in training mode).
        train_data: Dataset yielding ``(input, target, classes)`` triples.
        mini_batch_size: batch size of the (shuffled) DataLoader.
        optimizer: optimizer *class*, instantiated here with ``lr=eta``.
        criterion: loss applied to all three outputs.
        n_epochs: number of passes over ``train_data``.
        eta: learning rate.
        lambda_l2: weight of the L2 penalty on all model parameters (0 disables it).
        alpha, beta: weights of the main loss and of the auxiliary losses.

    Prints the summed loss of each epoch.
    """
    train_loader = DataLoader(train_data, batch_size=mini_batch_size, shuffle=True)

    model.train()
    optimizer = optimizer(model.parameters(), lr=eta)

    for e in range(n_epochs):
        epoch_loss = 0.0

        for input_, target_, classes_ in train_loader:
            class_1, class_2, out = model(input_)
            aux_loss1 = criterion(class_1, classes_[:, 0])
            aux_loss2 = criterion(class_2, classes_[:, 1])
            out_loss = criterion(out, target_)
            net_loss = alpha * out_loss + beta * (aux_loss1 + aux_loss2)

            if lambda_l2 != 0:
                # The penalty must be part of the back-propagated loss; previously it was
                # only added to the logged total and had no effect on training.
                net_loss = net_loss + lambda_l2 * sum(p.pow(2).sum() for p in model.parameters())

            optimizer.zero_grad()
            net_loss.backward()
            optimizer.step()

            # .item() detaches: summing the tensor itself would keep every batch's
            # autograd graph alive until the end of the epoch.
            epoch_loss += net_loss.item()

        print('Train Epoch: {}  | Loss {:.6f}'.format(e, epoch_loss))
        
#########################################################################################################################
#########################################################################################################################

### test function  ###

def test_aux(model, test_data, mini_batch_size=100, criterion = nn.CrossEntropyLoss()):
    
    """
    Test function to calculate prediction accuracy of a cnn with auxiliary loss
    
    """
    
    # create test laoder
    test_loader = DataLoader(test_data, batch_size=mini_batch_size, shuffle=True)
    
    model.eval()
    test_loss = 0
    nb_errors=0
    
    with torch.no_grad():
        
        for i, data in enumerate(test_loader, 0):
            input_, target_, classes_ = data
            
            _,_ ,output = model(input_) 
            batch_loss = criterion(output, target_)
            test_loss += batch_loss
            
            _, predicted_classes = output.max(1)
            for k in range(mini_batch_size):
                if target_[k] != predicted_classes[k]:
                    nb_errors = nb_errors + 1
                                   
             
        print('\nTest set | Loss: {:.4f} | Accuracy: {:.0f}% | # misclassified : {}/{}\n'.format(
        test_loss.item(), 100 * (len(test_data)-nb_errors)/len(test_data), nb_errors, len(test_data)))

In [98]:
# Fix the RNG so weight init, dropout and DataLoader shuffling are reproducible across runs.
torch.manual_seed(0)
model1 = Google_Net()
train_aux(model1, train_data)
test_aux(model1, test_data)

Train Epoch: 0  | Loss 29.767370
Train Epoch: 1  | Loss 13.516229
Train Epoch: 2  | Loss 8.503138
Train Epoch: 3  | Loss 6.766137
Train Epoch: 4  | Loss 6.159487
Train Epoch: 5  | Loss 4.913076
Train Epoch: 6  | Loss 4.067536
Train Epoch: 7  | Loss 3.638935
Train Epoch: 8  | Loss 3.654038
Train Epoch: 9  | Loss 3.028535
Train Epoch: 10  | Loss 2.505662
Train Epoch: 11  | Loss 2.409002
Train Epoch: 12  | Loss 2.404171
Train Epoch: 13  | Loss 1.905158
Train Epoch: 14  | Loss 1.695484
Train Epoch: 15  | Loss 1.579560
Train Epoch: 16  | Loss 1.456288
Train Epoch: 17  | Loss 1.149295
Train Epoch: 18  | Loss 1.258273
Train Epoch: 19  | Loss 1.108944
Train Epoch: 20  | Loss 0.973030
Train Epoch: 21  | Loss 0.992214
Train Epoch: 22  | Loss 0.861399
Train Epoch: 23  | Loss 1.920845
Train Epoch: 24  | Loss 1.003562
Train Epoch: 25  | Loss 0.763496
Train Epoch: 26  | Loss 0.588748
Train Epoch: 27  | Loss 0.613650
Train Epoch: 28  | Loss 0.558315
Train Epoch: 29  | Loss 0.466551
Train Epoch: 30  |