In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import style

import torch 
from torch import nn

from torchsummary import summary

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cpu'

In [3]:
class block(nn.Module):
    def __init__(self, in_channels, c1, c2, c3, c4):
        '''
        c1, c2, c3, c4 - (int, tuple, tuple, int)
        are the respective channels for each path
        '''
        super(block, self).__init__()
        self.path1 = nn.Conv2d(in_channels=in_channels, out_channels=c1, kernel_size=(1,1))
        self.path21 = nn.Conv2d(in_channels=in_channels, out_channels=c2[0], kernel_size=(1,1))
        self.path22 = nn.Conv2d(in_channels=c2[0], out_channels=c2[1], kernel_size=(3,3),
                               stride=(1,1), padding=(1,1))
        self.path31 = nn.Conv2d(in_channels=in_channels, out_channels=c3[0], kernel_size=(1,1))
        self.path32 = nn.Conv2d(in_channels=c3[0], out_channels=c3[1], kernel_size=(5,5),
                                                    padding=(2,2))
        self.pool = nn.MaxPool2d(kernel_size=(3,3), stride=(1,1), padding=(1,1))
        self.path4 = nn.Conv2d(in_channels=in_channels, out_channels=c4, kernel_size=(1,1))
        self.relu = nn.ReLU()
        
        
    def forward(self, x):
        op1 = self.relu(self.path1(x))
        op2 = self.relu(self.path21(x))
        op2 = self.relu(self.path22(op2))
        op3 = self.relu(self.path31(x))
        op3 = self.relu(self.path32(op3))
        op4 = self.pool(x)
        op4 = self.path4(op4)
        return torch.cat([op1, op2, op3, op4], axis=1) # axis=1 is channel axis

In [4]:
incep_block = block(in_channels=1, c1=64, c2=(96, 128), c3=(16, 32), c4=32).to(device)
summary(incep_block, (1,96, 96))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 96, 96]             128
              ReLU-2           [-1, 64, 96, 96]               0
            Conv2d-3           [-1, 96, 96, 96]             192
              ReLU-4           [-1, 96, 96, 96]               0
            Conv2d-5          [-1, 128, 96, 96]         110,720
              ReLU-6          [-1, 128, 96, 96]               0
            Conv2d-7           [-1, 16, 96, 96]              32
              ReLU-8           [-1, 16, 96, 96]               0
            Conv2d-9           [-1, 32, 96, 96]          12,832
             ReLU-10           [-1, 32, 96, 96]               0
        MaxPool2d-11            [-1, 1, 96, 96]               0
           Conv2d-12           [-1, 32, 96, 96]              64
Total params: 123,968
Trainable params: 123,968
Non-trainable params: 0
-------------------------------

In [21]:
class inception_net(nn.Module):
    def __init__(self, input_size, layers, num_classes):
        # layers will be a list of simply num_blocks, c1, c2, c3, c4
        in_channels, height, width = input_size
        super(inception_net, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=64,kernel_size=(7,7), 
                               stride=(2,2) ,padding=(3,3))
        self.pool = nn.MaxPool2d(kernel_size=(3,3), stride=(2,2), padding=(1,1))
        
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64,kernel_size=(1,1))
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=192,kernel_size=(3,3),
                              stride=(1,1), padding=(1,1))
        
        in_channels = 192
        conv_layers = []
        for i in range(len(layers)):
            for j in range(len(layers[i])):
                c1, c2, c3, c4 = layers[i][j]
                conv_layers.append(block(in_channels, c1, c2, c3, c4))
                in_channels = c1 + c2[1] + c3[1] + c4
            if i==len(layers)-1: # last layer
                conv_layers.append(nn.AvgPool2d(kernel_size=(2,2)))
            else:
                conv_layers.append(self.pool)
        self.conv = nn.Sequential(*conv_layers)
        self.relu = nn.ReLU()
        
        #print("in_channels = ", in_channels)
        height = height//(64) # 5 pool layers
        width  = width//(64) # 5 pool layers
        self.fc = nn.Linear(in_features=in_channels*height*width, out_features=num_classes)
    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.pool(x)
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.pool(x)
        x = self.relu(self.conv(x))
        x = nn.Flatten()(x)
        x = self.fc(x)
        return x

In [22]:
input_shape = (1, 96, 96)
layer1 = [(64, (96, 128), (16, 32), 32), (128, (128, 192), (32, 96), 64)]
layer2 = [(192, (96, 208), (16, 48), 64), (160, (112, 224), (24, 64), 64), (128, (128, 256), (24, 64), 64),
         (112, (144, 288), (32, 64), 64), (256, (160, 320), (32, 128), 128)]
layer3 = [(256, (160, 320), (32, 128), 128), (384, (192, 384), (48, 128), 128)]
layers = [layer1, layer2, layer3]
num_classes = 10
model = inception_net(input_shape, layers, num_classes).to(device)
summary(model, input_shape)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 48, 48]           3,200
              ReLU-2           [-1, 64, 48, 48]               0
         MaxPool2d-3           [-1, 64, 24, 24]               0
         MaxPool2d-4           [-1, 64, 24, 24]               0
            Conv2d-5           [-1, 64, 24, 24]           4,160
              ReLU-6           [-1, 64, 24, 24]               0
            Conv2d-7          [-1, 192, 24, 24]         110,784
              ReLU-8          [-1, 192, 24, 24]               0
         MaxPool2d-9          [-1, 192, 12, 12]               0
        MaxPool2d-10          [-1, 192, 12, 12]               0
           Conv2d-11           [-1, 64, 12, 12]          12,352
             ReLU-12           [-1, 64, 12, 12]               0
           Conv2d-13           [-1, 96, 12, 12]          18,528
             ReLU-14           [-1, 96,