# Resnet Implementation 
I have basically followed this [blog](https://towardsdatascience.com/residual-network-implementing-resnet-a7da63c7b278). The author builds the network up step by step, which makes the whole concept really easy to understand. 

### Importing some basic libraries

In [7]:
import numpy as np
from functools import partial
import matplotlib.pyplot as plt
import torch.nn as nn
import torch
%matplotlib inline

### Basic Block
This adds "same"-style padding to PyTorch's Conv2d by deriving the padding from the kernel size. Unlike some other frameworks, PyTorch did not offer this as a simple parameter when this was written. 

In [8]:
class Conv2dAuto(nn.Conv2d):
    """``nn.Conv2d`` whose padding is derived from its kernel size.

    After the regular ``nn.Conv2d`` construction, ``padding`` is overwritten
    with half of each kernel dimension (integer division), so any ``padding``
    argument supplied by the caller is ignored.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # nn.Conv2d stores kernel_size as a (height, width) pair.
        kernel_h, kernel_w = self.kernel_size
        self.padding = (kernel_h // 2, kernel_w // 2)

# 3x3 convolution factory: Conv2dAuto pre-configured with kernel_size=3 and bias=False.
conv3x3 = partial(Conv2dAuto, kernel_size=3, bias=False)
# Quick check: build one layer and print it to see the padding that was applied.
conv = conv3x3(in_channels=32, out_channels=64)
print(conv)
del conv  # the instance was only needed for the printout

Conv2dAuto(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)


### ModuleDict 
We create a dictionary with different activation functions; this will be handy later. 
The ModuleDict comes in very useful if we want to change the activation functions. Also, the concept of ModuleDict is pretty cool. 

In [9]:
def activation_func(activation):
    """Return the activation module registered under the name ``activation``.

    Supported names are ``'relu'``, ``'leaky_relu'``, ``'selu'`` and
    ``'none'`` (identity). A fresh ``nn.ModuleDict`` is built on every call,
    so each call hands back a new module instance. An unknown name raises
    ``KeyError``.
    """
    available = nn.ModuleDict({
        'relu': nn.ReLU(inplace=True),
        'leaky_relu': nn.LeakyReLU(negative_slope=0.01, inplace=True),
        'selu': nn.SELU(inplace=True),
        'none': nn.Identity(),
    })
    return available[activation]

### Residual Block 
This class has a very basic structure: it only defines the skeleton of the residual block. We will add functionality to this block step by step. 
`nn.Identity` acts as a placeholder for other modules. 

In [10]:
class ResidualBlock(nn.Module):
    """Skeleton residual block computing ``activate(blocks(x) + residual)``.

    ``blocks`` and ``shortcut`` start out as ``nn.Identity`` placeholders;
    subclasses are expected to replace them with real layers.

    Args:
        in_channels: Number of channels of the input.
        out_channels: Number of channels the block is meant to produce.
        activation: Name passed to ``activation_func`` to select the
            activation applied after the residual sum.
    """

    def __init__(self, in_channels, out_channels, activation='relu'):
        super().__init__()
        self.in_channels, self.out_channels, self.activation = in_channels, out_channels, activation
        self.blocks = nn.Identity()
        self.activate = activation_func(activation)
        self.shortcut = nn.Identity()

    def forward(self, x):
        """Return the activated sum of the main path and the residual path.

        The input tensor ``x`` is never modified.
        """
        residual = x
        if self.should_apply_shortcut:
            residual = self.shortcut(x)
        x = self.blocks(x)
        # Out-of-place add on purpose: while ``blocks`` is still an Identity,
        # ``x`` is the caller's own tensor. ``x += residual`` would overwrite
        # the caller's data and fails on leaf tensors that require grad.
        x = x + residual
        x = self.activate(x)
        return x

    @property
    def should_apply_shortcut(self):
        """True when input and output channel counts differ."""
        return self.in_channels != self.out_channels

# Smoke test: with Identity placeholders the block adds the input to itself and applies ReLU.
dummy = torch.ones((1, 1, 1, 1))
block = ResidualBlock(1, 64)
block(dummy)
        
        

tensor([[[[2.]]]])

### ResNetResidualBlock 
This block is an extension of the basic structure of ResidualBlock. In this class we define the self.shortcut of the residual block. 

In [11]:
class ResNetResidualBlock(ResidualBlock):
    """Residual block whose shortcut is a 1x1 convolution plus batch norm.

    The projection shortcut is only created when the input channel count
    differs from ``expanded_channels``; otherwise ``shortcut`` is ``None``.

    Args:
        in_channels: Number of channels of the input.
        out_channels: Base number of output channels.
        expansion: Multiplier applied to ``out_channels`` to obtain
            ``expanded_channels``.
        downsampling: Stride of the shortcut convolution.
        conv: Convolution factory stored on ``self.conv`` for subclasses.
        *args, **kwargs: Forwarded to ``ResidualBlock``.
    """

    def __init__(self, in_channels, out_channels, expansion=1, downsampling=1, conv=conv3x3, *args, **kwargs):
        super().__init__(in_channels, out_channels, *args, **kwargs)
        self.expansion = expansion
        self.downsampling = downsampling
        self.conv = conv
        if self.should_apply_shortcut:
            projection = nn.Conv2d(self.in_channels, self.expanded_channels,
                                   kernel_size=1, stride=self.downsampling, bias=False)
            self.shortcut = nn.Sequential(projection, nn.BatchNorm2d(self.expanded_channels))
        else:
            self.shortcut = None

    @property
    def should_apply_shortcut(self):
        """True when the input channels differ from the expanded output channels."""
        return self.in_channels != self.expanded_channels

    @property
    def expanded_channels(self):
        """Output channel count after applying the expansion factor."""
        return self.out_channels * self.expansion

# 32 input channels differ from 64 * 1 expanded channels, so a projection shortcut is built.
ResNetResidualBlock(32, 64)

ResNetResidualBlock(
  (blocks): Identity()
  (activate): ReLU(inplace=True)
  (shortcut): Sequential(
    (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)

### ResNet Basic Block 
Now we extend ResNetResidualBlock to define self.blocks,
thereby completing a very basic residual block. 

In [12]:
def conv_bn(in_channels, out_channels, conv, *args, **kwargs):
    """Return ``nn.Sequential(conv(...), nn.BatchNorm2d(out_channels))``.

    Args:
        in_channels: Input channels passed to ``conv``.
        out_channels: Output channels passed to ``conv`` and used as the
            batch-norm feature count.
        conv: Callable that builds the convolution layer.
        *args, **kwargs: Extra arguments forwarded to ``conv``.
    """
    convolution = conv(in_channels, out_channels, *args, **kwargs)
    batch_norm = nn.BatchNorm2d(out_channels)
    return nn.Sequential(convolution, batch_norm)

class ResNetBasicBlock(ResNetResidualBlock):
    """Basic residual block: conv+BN, activation, conv+BN on the main path."""

    expansion = 1

    def __init__(self, in_channels, out_channels, *args, **kwargs):
        super().__init__(in_channels, out_channels, *args, **kwargs)
        # Only the first stage receives the block's stride (self.downsampling).
        first_stage = conv_bn(self.in_channels, self.out_channels,
                              conv=self.conv, bias=False, stride=self.downsampling)
        inner_activation = activation_func(self.activation)
        second_stage = conv_bn(self.out_channels, self.expanded_channels,
                               conv=self.conv, bias=False)
        self.blocks = nn.Sequential(first_stage, inner_activation, second_stage)
    
# Run a forward pass on a 32-channel input (the shape is not printed), then print the block.
dummy = torch.ones((1, 32, 224, 224))
block = ResNetBasicBlock(32, 64)
block(dummy).shape
print(block)

ResNetBasicBlock(
  (blocks): Sequential(
    (0): Sequential(
      (0): Conv2dAuto(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): ReLU(inplace=True)
    (2): Sequential(
      (0): Conv2dAuto(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (activate): ReLU(inplace=True)
  (shortcut): Sequential(
    (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)


### ResNetBottleNeckBlock 
ResNetBottleNeckBlock extends the ResNetResidualBlock to match the authors' description of the bottleneck block. 

In [13]:
class ResNetBottleNeckBlock(ResNetResidualBlock):
    """Bottleneck residual block: 1x1 -> 3x3 -> 1x1 convolutions with expansion 4.

    Each convolution is followed by batch norm, and an activation sits
    between consecutive stages. Only the 3x3 stage receives the block's
    stride (``self.downsampling``); the last stage widens the output to
    ``expanded_channels``.
    """

    expansion = 4

    def __init__(self, in_channels, out_channels, *args, **kwargs):
        super().__init__(in_channels, out_channels, *args, expansion=4, **kwargs)
        reduce_stage = conv_bn(self.in_channels, self.out_channels, self.conv, kernel_size=1)
        reduce_activation = activation_func(self.activation)
        spatial_stage = conv_bn(self.out_channels, self.out_channels, self.conv,
                                kernel_size=3, stride=self.downsampling)
        spatial_activation = activation_func(self.activation)
        expand_stage = conv_bn(self.out_channels, self.expanded_channels, self.conv, kernel_size=1)
        self.blocks = nn.Sequential(
            reduce_stage,
            reduce_activation,
            spatial_stage,
            spatial_activation,
            expand_stage,
        )

In [14]:
# Run a forward pass on a small 32-channel input (the shape is not printed), then print the block.
dummy = torch.ones((1, 32, 10, 10))
block = ResNetBottleNeckBlock(32, 64)
block(dummy).shape
print(block)

ResNetBottleNeckBlock(
  (blocks): Sequential(
    (0): Sequential(
      (0): Conv2dAuto(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): ReLU(inplace=True)
    (2): Sequential(
      (0): Conv2dAuto(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (3): ReLU(inplace=True)
    (4): Sequential(
      (0): Conv2dAuto(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (activate): ReLU(inplace=True)
  (shortcut): Sequential(
    (0): Conv2d(32, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)


### ResNetLayer 
Finally we are moving towards building a resnet layer. In the ResNetLayer we just stack some ResNetBasicBlocks to get a layer.

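Doubt: why is the downsampling defined this way? (The first block of a layer uses stride 2 whenever the channel count changes, so the spatial size is halved as the number of channels grows; see the 48 → 24 output below.)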

In [41]:
class ResNetLayer(nn.Module):
    """A stack of ``n`` blocks of the same type applied one after another.

    The first block maps ``in_channels`` to ``out_channels`` and is given
    ``downsampling=2`` when the two differ (``1`` otherwise). The remaining
    ``n - 1`` blocks take ``out_channels * block.expansion`` input channels
    and always use ``downsampling=1``.

    Args:
        in_channels: Channels entering the layer.
        out_channels: Base output channels of every block.
        block: Block class to instantiate; must expose ``expansion``.
        n: Number of blocks in the layer.
        *args, **kwargs: Forwarded to every block constructor.
    """

    def __init__(self, in_channels, out_channels, block=ResNetBasicBlock, n=1, *args, **kwargs):
        super().__init__()
        if in_channels != out_channels:
            downsampling = 2
        else:
            downsampling = 1
        stacked = [block(in_channels, out_channels, *args, **kwargs, downsampling=downsampling)]
        for _ in range(n - 1):
            stacked.append(
                block(out_channels * block.expansion, out_channels, downsampling=1, *args, **kwargs)
            )
        self.blocks = nn.Sequential(*stacked)

    def forward(self, x):
        """Run ``x`` through all stacked blocks."""
        return self.blocks(x)


# 64 -> 128 channels with three basic blocks; the first block is built with downsampling=2.
dummy = torch.ones((1, 64, 48, 48))

layer = ResNetLayer(64, 128, block=ResNetBasicBlock, n=3)
print(layer)
layer(dummy).shape

ResNetLayer(
  (blocks): Sequential(
    (0): ResNetBasicBlock(
      (blocks): Sequential(
        (0): Sequential(
          (0): Conv2dAuto(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): ReLU(inplace=True)
        (2): Sequential(
          (0): Conv2dAuto(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (activate): ReLU(inplace=True)
      (shortcut): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): ResNetBasicBlock(
      (blocks): Sequential(
        (0): Sequential(
          (0): Conv2dAuto(128, 128, kernel_size=(3, 3), stride=(1, 1), padding

torch.Size([1, 128, 24, 24])

### Encoder 
Now the encoder is composed of multiple layers.

In [44]:
class ResNetEncoder(nn.Module):
    """Stem ("gate") followed by a sequence of ``ResNetLayer`` stages.

    Args:
        in_channels: Channels of the input tensor.
        blocks_sizes: Output width of each stage; defaults to
            ``[64, 128, 256, 512]``.
        deepths: Number of blocks in each stage; defaults to
            ``[2, 2, 2, 2]``.
        activation: Name understood by ``activation_func``.
        block: Block class used inside every ``ResNetLayer``; must expose an
            ``expansion`` attribute.
        *args, **kwargs: Forwarded to every ``ResNetLayer``.
    """

    def __init__(self, in_channels=3, blocks_sizes=None, deepths=None,
                 activation='relu', block=ResNetBasicBlock, *args, **kwargs):
        super().__init__()
        # None sentinels plus copies replace the former list defaults, so no
        # list object is shared between instances (or with the caller)
        # through ``self.blocks_sizes``.
        blocks_sizes = [64, 128, 256, 512] if blocks_sizes is None else list(blocks_sizes)
        deepths = [2, 2, 2, 2] if deepths is None else list(deepths)
        self.blocks_sizes = blocks_sizes

        self.gate = nn.Sequential(
            nn.Conv2d(in_channels, self.blocks_sizes[0], kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(self.blocks_sizes[0]),
            activation_func(activation),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

        # Consecutive (input width, output width) pairs for stages 2..N.
        self.in_out_block_sizes = list(zip(blocks_sizes, blocks_sizes[1:]))
        self.blocks = nn.ModuleList([
            # The first stage keeps the width produced by the gate.
            ResNetLayer(blocks_sizes[0], blocks_sizes[0], n=deepths[0], activation=activation,
                        block=block, *args, **kwargs),
            # Later stages receive the previous stage's expanded width.
            *[ResNetLayer(stage_in * block.expansion,
                          stage_out, n=depth, activation=activation,
                          block=block, *args, **kwargs)
              for (stage_in, stage_out), depth in zip(self.in_out_block_sizes, deepths[1:])]
        ])

    def forward(self, x):
        """Apply the gate and then every stage in order."""
        x = self.gate(x)
        for block in self.blocks:
            x = block(x)
        return x

### Decoder
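The last component of the ResNet, finally!!
Doubt: what is adaptive average pooling? (`nn.AdaptiveAvgPool2d((1, 1))` averages each channel's feature map down to a fixed 1x1 output, whatever the input's spatial size is.)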

In [45]:
class ResnetDecoder(nn.Module):
    """Classification head: adaptive average pooling followed by a linear layer.

    Args:
        in_features: Input size of the linear layer.
        n_classes: Output size of the linear layer.
    """

    def __init__(self, in_features, n_classes):
        super().__init__()
        self.avg = nn.AdaptiveAvgPool2d((1, 1))
        self.decoder = nn.Linear(in_features, n_classes)

    def forward(self, x):
        """Pool to 1x1, flatten everything after the first dimension, then project."""
        pooled = self.avg(x)
        flattened = pooled.view(pooled.size(0), -1)
        return self.decoder(flattened)

### ResNet 
Finally defining the ResNet!!

In [46]:
class ResNet(nn.Module):
    """Full network: a ``ResNetEncoder`` followed by a ``ResnetDecoder``.

    Args:
        in_channels: Input channels handed to the encoder.
        n_classes: Number of outputs of the decoder.
        *args, **kwargs: Forwarded to ``ResNetEncoder``.
    """

    def __init__(self, in_channels, n_classes, *args, **kwargs):
        super().__init__()
        self.encoder = ResNetEncoder(in_channels, *args, **kwargs)
        # The decoder's input size is the width of the encoder's very last block.
        last_block = self.encoder.blocks[-1].blocks[-1]
        self.decoder = ResnetDecoder(last_block.expanded_channels, n_classes)

    def forward(self, x):
        """Encode ``x`` and decode the resulting features."""
        features = self.encoder(x)
        return self.decoder(features)

In [47]:
def resnet18(in_channels, n_classes, block=ResNetBasicBlock, *args, **kwargs):
    """Build a ``ResNet`` with stage depths ``[2, 2, 2, 2]``; extra arguments are forwarded."""
    depths = [2, 2, 2, 2]
    return ResNet(in_channels, n_classes, *args, block=block, deepths=depths, **kwargs)

def resnet34(in_channels, n_classes, block=ResNetBasicBlock, *args, **kwargs):
    """Build a ``ResNet`` with stage depths ``[3, 4, 6, 3]``; extra arguments are forwarded."""
    depths = [3, 4, 6, 3]
    return ResNet(in_channels, n_classes, *args, block=block, deepths=depths, **kwargs)

def resnet50(in_channels, n_classes, block=ResNetBottleNeckBlock, *args, **kwargs):
    """Build a bottleneck ``ResNet`` with stage depths ``[3, 4, 6, 3]``; extra arguments are forwarded."""
    depths = [3, 4, 6, 3]
    return ResNet(in_channels, n_classes, *args, block=block, deepths=depths, **kwargs)

def resnet101(in_channels, n_classes, block=ResNetBottleNeckBlock, *args, **kwargs):
    """Build a bottleneck ``ResNet`` with stage depths ``[3, 4, 23, 3]``; extra arguments are forwarded."""
    depths = [3, 4, 23, 3]
    return ResNet(in_channels, n_classes, *args, block=block, deepths=depths, **kwargs)

def resnet152(in_channels, n_classes, block=ResNetBottleNeckBlock, *args, **kwargs):
    """Build a bottleneck ``ResNet`` with stage depths ``[3, 8, 36, 3]``; extra arguments are forwarded."""
    depths = [3, 8, 36, 3]
    return ResNet(in_channels, n_classes, *args, block=block, deepths=depths, **kwargs)

In [59]:
# End-to-end check: default encoder settings, 3 input channels, 2 output classes.
res1 = ResNet(3, 2)
dummy = torch.ones((1, 3, 224, 224))
print(res1(dummy))

tensor([[-0.1159, -0.2839]], grad_fn=<AddmmBackward>)
