# Model

In [1]:
import torch
torch.cuda.empty_cache()
torch.cuda.synchronize()

import torchvision

import torchvision.transforms as transforms

from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data.dataloader import DataLoader

import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF

import torch.optim as optim

## ConvNet

In [2]:
class ConvNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=2, batch_size=1):
        super().__init__()
        
        self.switches = []
        
        ## Layer 1
        self.layer1 = nn.Conv2d(in_channels=in_channels, out_channels=96, kernel_size=7, stride=2)
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, return_indices=True)
        self.norm1 = nn.BatchNorm2d(num_features=96)
        
        ## Layer 2
        self.layer2 = nn.Conv2d(in_channels=96, out_channels=256, kernel_size=5, stride=2)
        self.pool2 = nn.MaxPool2d(kernel_size=3, stride=2, return_indices=True)
        self.norm2 = nn.BatchNorm2d(num_features=256)
        
        ## Layer 3
        self.layer3 = nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, stride=1)

        ## Layer 4
        self.layer4 = nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, stride=1)
        
        ## Layer 5
        self.layer5 = nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, stride=1)
        self.pool5 = nn.MaxPool2d(kernel_size=3, stride=2, return_indices=True)
        
        ## Layer 6
        self.layer6 = nn.Linear(9216, 4096)
        
        ## Layer 7
        self.layer7 = nn.Linear(4096, 4096)
        
        ## Output
        self.output = nn.Linear(4096, out_channels)
        
        
    def forward(self, x):
        
        print(f'x_0: {x.shape}')
        
        ## Layer - 1
        x = self.layer1(x)
        x = F.relu(x)
        
        print(f'x_1 : {x.shape}')
        
        x, indices1 = self.pool1(x)
        x = self.norm1(x)
        
        self.switches.append(indices1)
        
        print(f'x_pool_1: {x.shape}')
        
        ## Layer - 2
        x =  self.layer2(x)
        x = F.relu(x)
        
        print(f'x_2: {x.shape}')
        
        x, indices2 = self.pool2(x)
        x = self.norm2(x)
        
        self.switches.append(indices2)
        
        print(f'x_pool_2: {x.shape}')
        
        ## Layer - 3
        x = self.layer3(x)
        x = F.relu(x)
        
        print(f'x_3: {x.shape}')
        
        ## Layer - 4
        x = self.layer4(x)
        x = F.relu(x)
        
        print(f'x_4: {x.shape}')
        
        ## Layer - 5
        x = self.layer5(x)
        x = F.relu(x)
        
        print(f'x_5: {x.shape}')
        
        x, indices5 = self.pool5(x)
        
        self.switches.append(indices5)
        
        print(f'x_pool_5: {x.shape}')
        
        ## Flatten the tensor
        x = torch.flatten(x, 1)
        
        print(f'x_flat: {x.shape}')
        
        ## Layer - 6
        x = self.layer6(x)
        x = F.relu(x)
        
        ## Layer - 7
        x = self.layer7(x)
        x = F.relu(x)
        
        ## Output Layer
        x = self.output(x)
        x = F.softmax(x)
        
        return x

In [3]:
def test():
    x = torch.randn((5, 3, 244, 244))
    model = ConvNet(in_channels = 3, out_channels = 2)
    
    y = model(x)
    print('y.shape:', y.shape)
    print('x.shape:', x.shape)
    print()
    
test()

x_0: torch.Size([5, 3, 244, 244])
x_1 : torch.Size([5, 96, 119, 119])
x_pool_1: torch.Size([5, 96, 59, 59])
x_2: torch.Size([5, 256, 28, 28])
x_pool_2: torch.Size([5, 256, 13, 13])
x_3: torch.Size([5, 384, 11, 11])
x_4: torch.Size([5, 384, 9, 9])
x_5: torch.Size([5, 256, 7, 7])
x_pool_5: torch.Size([5, 256, 3, 3])
x_flat: torch.Size([5, 2304])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (5x2304 and 9216x4096)