In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
        

## Conv Block Module

In [None]:
# Class for the conv block module: Conv2d -> BatchNorm2d -> ReLU.
class ConvBlock(nn.Module):
    """Convolution followed by batch normalization and an in-place ReLU.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of filters produced by the convolution.
        kernel_size: forwarded unchanged to ``nn.Conv2d``.
        stride: forwarded unchanged to ``nn.Conv2d``.
        padding: forwarded unchanged to ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        # Initialise nn.Module bookkeeping before registering any submodules.
        super(ConvBlock, self).__init__()

        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.bn = nn.BatchNorm2d(out_channels)
        # In-place ReLU overwrites the batch-norm output to avoid an extra buffer.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply conv -> batch norm -> ReLU and return the activated tensor."""
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        # The result must be returned; without this the block yields None.
        return x
        

## Residual Block Module

In [None]:
class ResidualBlock(nn.Module):
    """Bottleneck residual block: 1x1 -> 3x3 -> 1x1 convolutions plus a shortcut.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the block.
        first: set for the first block of the first stage, where the channel
            count grows but the spatial resolution is kept.

    Bottleneck width / stride per case:
        * ``first=True``:                 width = in_channels,      stride 1
        * ``in_channels != out_channels``: width = in_channels // 2, stride 2
        * otherwise (identity block):     width = in_channels // 4, stride 1

    Whenever ``in_channels != out_channels`` the shortcut is projected by a
    1x1 convolution (``self.conv``) that uses the same stride as the residual
    path, so both branches have identical shapes when added.
    """

    def __init__(self, in_channels, out_channels, first=False):
        super(ResidualBlock, self).__init__()

        # A projection shortcut is needed whenever the channel count changes.
        self.projection = in_channels != out_channels

        if first:
            # Keep resolution, only widen the channels.
            res_channels = in_channels
            stride = 1
        elif self.projection:
            # Stage transition: downsample spatially with stride 2.
            res_channels = in_channels // 2
            stride = 2
        else:
            res_channels = in_channels // 4
            stride = 1

        if self.projection:
            # Shortcut must match the residual path in channels AND stride,
            # otherwise torch.add in forward() fails on mismatched shapes.
            self.conv = nn.Conv2d(in_channels, out_channels, 1, stride, 0)

        # Bottleneck: reduce (1x1) -> spatial (3x3, carries the stride) -> expand (1x1).
        self.conv1 = nn.Conv2d(in_channels, res_channels, 1, 1, 0)
        self.conv2 = nn.Conv2d(res_channels, res_channels, 3, stride, 1)
        self.conv3 = nn.Conv2d(res_channels, out_channels, 1, 1, 0)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(F(x) + shortcut(x))`` where F is the three-conv bottleneck."""
        # Each conv consumes the previous conv's output, not the block input.
        f = self.relu(self.conv1(x))
        f = self.relu(self.conv2(f))
        f = self.conv3(f)

        shortcut = self.conv(x) if self.projection else x
        return self.relu(torch.add(f, shortcut))
    
    

    

## ResNet module 

In [None]:
class Resnet(nn.Module):
    """Bottleneck ResNet with four stages of widths 256/512/1024/2048.

    Args:
        no_blocks: number of residual blocks in each of the four stages,
            e.g. ``[3, 4, 6, 3]``; every entry must be >= 1.
        in_channels: channels of the input images (e.g. 3 for RGB).
        classes: number of output logits.

    Raises:
        ValueError: if ``no_blocks`` does not have exactly four entries or
            contains a value smaller than 1.
    """

    def __init__(self, no_blocks, in_channels, classes=10):
        super(Resnet, self).__init__()

        out_features = [256, 512, 1024, 2048]
        if len(no_blocks) != len(out_features):
            raise ValueError(
                f"no_blocks must have {len(out_features)} entries, got {len(no_blocks)}"
            )
        if any(n < 1 for n in no_blocks):
            raise ValueError(f"every stage needs at least one block, got {list(no_blocks)}")

        # Each stage: one transition block (first=True for stage 0, which keeps
        # resolution after the max-pool; a strided projection otherwise),
        # followed by (n - 1) identity blocks of the same width.
        self.blocks = nn.ModuleList()
        prev_channels = 64
        for stage, (channels, n) in enumerate(zip(out_features, no_blocks)):
            self.blocks.append(ResidualBlock(prev_channels, channels, first=(stage == 0)))
            for _ in range(n - 1):
                self.blocks.append(ResidualBlock(channels, channels))
            prev_channels = channels

        # Stem: 7x7/2 conv block followed by a 3x3/2 max-pool.
        self.conv1 = ConvBlock(in_channels, 64, 7, 2, 3)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(out_features[-1], classes)

        # Kept as an attribute for backward compatibility; conv1 (a ConvBlock)
        # already ends in a ReLU, so forward() does not apply it again.
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map an image batch of shape (N, in_channels, H, W) to logits (N, classes)."""
        x = self.conv1(x)
        x = self.maxpool(x)

        for block in self.blocks:
            x = block(x)

        x = self.avgpool(x)          # (N, 2048, 1, 1)
        x = torch.flatten(x, 1)      # (N, 2048): nn.Linear acts on the last dim
        return self.fc(x)