In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import utils
from random import randint
from functools import partial
# from dataclasses import dataclass
from collections import OrderedDict

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU so the
# notebook still runs end-to-end on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Cross-entropy loss module shared by the rest of the notebook.
criterion = nn.CrossEntropyLoss()

In [2]:
class Conv2Auto(nn.Conv2d):
    """
    nn.Conv2d whose padding is derived from the kernel size.

    Accepts exactly the same arguments as nn.Conv2d. After construction the
    padding is replaced by (kernel_h // 2, kernel_w // 2), regardless of any
    `padding` argument that was passed in.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # dynamic padding based on the kernel size
        self.padding = (self.kernel_size[0] // 2, self.kernel_size[1] // 2)
        # nn.Conv2d caches a reversed, twice-repeated copy of the padding at
        # construction time and uses that copy (not self.padding) whenever
        # padding_mode is not 'zeros'. Keep it in sync, otherwise the automatic
        # padding is silently ignored for e.g. padding_mode='reflect'.
        if hasattr(self, '_reversed_padding_repeated_twice'):
            self._reversed_padding_repeated_twice = tuple(
                pad for pad in reversed(self.padding) for _ in range(2)
            )

# Factory for 3x3, bias-free Conv2Auto layers; callers supply the channel counts.
conv3x3 = partial(Conv2Auto, kernel_size=3, bias=False)
# Sanity check: show the layer the factory builds without keeping a module-level reference to it.
print(conv3x3(in_channels=32, out_channels=64))

Conv2Auto(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)


### Create a dictionary with different activation functions

In [3]:
def activation_func(activation):
    """
    Build the activation module registered under the name `activation`.

    Supported names:
        'relu'       -> nn.ReLU(inplace=True)
        'leaky_relu' -> nn.LeakyReLU(negative_slope=0.01, inplace=True)
        'selu'       -> nn.SELU(inplace=True)
        'none'       -> nn.Identity()

    Every call returns a new module instance, and only the requested module
    is constructed.

    Raises:
        KeyError: if `activation` is not one of the supported names.
    """
    # Factories instead of ready-made modules, so unused activations are never built.
    factories = {
        'relu': lambda: nn.ReLU(inplace=True),
        'leaky_relu': lambda: nn.LeakyReLU(negative_slope=0.01, inplace=True),
        'selu': lambda: nn.SELU(inplace=True),
        'none': lambda: nn.Identity(),
    }
    try:
        build = factories[activation]
    except KeyError:
        # Same exception type as a plain dictionary lookup, but with an actionable message.
        raise KeyError(
            'unknown activation {!r}; expected one of {}'.format(activation, list(factories))
        ) from None
    return build()

#### [Read about ModuleDict](https://towardsdatascience.com/pytorch-how-and-when-to-use-module-sequential-modulelist-and-moduledict-7a54597b5f17)

### Residual block

In [4]:
class ResidualBlock(nn.Module):
    """
    Generic residual block: activate(blocks(x) + residual).

    `blocks` and `shortcut` are both nn.Identity here; subclasses replace them.
    The residual is `shortcut(x)` when `should_apply_shortcut` is true and the
    unmodified input otherwise.

    Args:
        in_channels: number of channels of the input.
        out_channels: number of channels of the output.
        activation: name of the activation applied after the addition
            (looked up through `activation_func`).
    """
    def __init__(self, in_channels, out_channels, activation='relu'):
        super().__init__()
        self.in_channels, self.out_channels, self.activation = in_channels, out_channels, activation
        self.blocks = nn.Identity()
        self.activate = activation_func(activation)
        self.shortcut = nn.Identity()

    def forward(self, x):
        """Return activate(blocks(x) + residual) without modifying the tensor passed in."""
        residual = self.shortcut(x) if self.should_apply_shortcut else x
        x = self.blocks(x)
        # Out-of-place add: when `blocks` is an identity, `x` is still the caller's
        # tensor, so `x += residual` would overwrite the input and would raise for
        # a leaf tensor that requires grad.
        x = x + residual
        return self.activate(x)

    @property
    def should_apply_shortcut(self):
        """True when the input and output channel counts differ."""
        return self.in_channels != self.out_channels

# Smoke-test ResidualBlock on a single-element tensor of ones
dummy = torch.ones(1, 1, 1, 1)
block = ResidualBlock(in_channels=1, out_channels=64)
block(dummy)

tensor([[[[2.]]]])

In [6]:
class ResNetResidualBlock(ResidualBlock):
    """
    Residual block with a projection shortcut (1x1 conv + batchnorm).

    The shortcut is built only when the skip path must change shape: either the
    channel count changes (in_channels != out_channels * expansion) or the block
    is strided (downsampling != 1). Otherwise `shortcut` is None and the input
    is added back unchanged.

    Args:
        in_channels: number of channels of the input.
        out_channels: base number of output channels.
        expansion: multiplier applied to out_channels to get the real output width.
        downsampling: stride used by the shortcut convolution (stored for subclasses).
        conv: convolution factory stored on the instance for subclasses to build `blocks`.
        *args, **kwargs: forwarded to ResidualBlock (e.g. `activation`).
    """
    def __init__(self, in_channels, out_channels, expansion=1, downsampling=1, conv=conv3x3, *args, **kwargs):
        super().__init__(in_channels, out_channels, *args, **kwargs)
        # These must be set before the shortcut is built: both properties below read them.
        self.expansion, self.downsampling, self.conv = expansion, downsampling, conv
        self.shortcut = nn.Sequential(
            nn.Conv2d(self.in_channels, self.expanded_channels, kernel_size=1, stride=self.downsampling, bias=False),
            nn.BatchNorm2d(self.expanded_channels)) if self.should_apply_shortcut else None

    @property
    def expanded_channels(self):
        """Actual number of output channels: out_channels * expansion."""
        return self.out_channels * self.expansion

    @property
    def should_apply_shortcut(self):
        """True when the skip path has to change channel count or spatial size."""
        # A strided block shrinks the spatial size of the main path, so the skip
        # path needs the strided projection even when the channel counts match;
        # otherwise the residual addition fails on mismatched shapes.
        return self.in_channels != self.expanded_channels or self.downsampling != 1

# Inspect a ResNetResidualBlock whose channel count changes (32 -> 64)
ResNetResidualBlock(in_channels=32, out_channels=64)

ResNetResidualBlock(
  (blocks): Identity()
  (activate): ReLU(inplace=True)
  (shortcut): Sequential(
    (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)

### Basic Blocks

In [7]:
def conv_bn(in_channels, out_channels, conv, *args, **kwargs):
    """
    Stack a convolution and a 2d batch-norm into one nn.Sequential.

    `conv` is called as conv(in_channels, out_channels, *args, **kwargs) to
    build the first layer; the batch-norm is sized to `out_channels`.
    """
    conv_layer = conv(in_channels, out_channels, *args, **kwargs)
    norm_layer = nn.BatchNorm2d(out_channels)
    return nn.Sequential(conv_layer, norm_layer)

In [11]:
class ResNetBasicBlock(ResNetResidualBlock):
    """
    Basic ResNet block: conv/batchnorm -> activation -> conv/batchnorm.

    Both stages are built with `self.conv` (bias disabled); the first stage is
    strided by `self.downsampling`, the second maps to `expanded_channels`.
    """
    expansion = 1

    def __init__(self, in_channels, out_channels, *args, **kwargs):
        super().__init__(in_channels, out_channels, *args, **kwargs)
        # Stages are created in the order they run.
        first_stage = conv_bn(self.in_channels, self.out_channels, conv=self.conv, bias=False, stride=self.downsampling)
        inner_activation = activation_func(self.activation)
        second_stage = conv_bn(self.out_channels, self.expanded_channels, conv=self.conv, bias=False)
        self.blocks = nn.Sequential(first_stage, inner_activation, second_stage)

# Smoke-test the basic block: run one forward pass, then print its structure
dummy = torch.ones(1, 32, 224, 224)
block = ResNetBasicBlock(in_channels=32, out_channels=64)
block(dummy).shape
print(block)

ResNetBasicBlock(
  (blocks): Sequential(
    (0): Sequential(
      (0): Conv2Auto(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): ReLU(inplace=True)
    (2): Sequential(
      (0): Conv2Auto(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (activate): ReLU(inplace=True)
  (shortcut): Sequential(
    (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)


### Bottleneck block

In [13]:
class ResNetBottleNeckBlock(ResNetBasicBlock):
    """
    Bottleneck block made of three conv/batchnorm stages:
    1. 1x1 conv: in_channels -> out_channels
    2. 3x3 conv: out_channels -> out_channels, strided by `downsampling`
    3. 1x1 conv: out_channels -> expanded_channels (out_channels * 4)
    An activation follows stages 1 and 2.
    """
    expansion = 4

    def __init__(self, in_channels, out_channels, *args, **kwargs):
        super().__init__(in_channels, out_channels, expansion=4, *args, **kwargs)
        # Replaces the `blocks` that the parent constructor has just assigned.
        stages = [
            conv_bn(self.in_channels, self.out_channels, self.conv, kernel_size=1),
            activation_func(self.activation),
            conv_bn(self.out_channels, self.out_channels, self.conv, kernel_size=3, stride=self.downsampling),
            activation_func(self.activation),
            conv_bn(self.out_channels, self.expanded_channels, self.conv, kernel_size=1),
        ]
        self.blocks = nn.Sequential(*stages)

# Smoke-test the bottleneck block: run one forward pass, then print its structure
dummy = torch.ones(1, 32, 10, 10)
block = ResNetBottleNeckBlock(in_channels=32, out_channels=64)
block(dummy).shape
print(block)

ResNetBottleNeckBlock(
  (blocks): Sequential(
    (0): Sequential(
      (0): Conv2Auto(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): ReLU(inplace=True)
    (2): Sequential(
      (0): Conv2Auto(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (3): ReLU(inplace=True)
    (4): Sequential(
      (0): Conv2Auto(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (activate): ReLU(inplace=True)
  (shortcut): Sequential(
    (0): Conv2d(32, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)


### ResNet Layer

In [20]:
class ResNetLayer(nn.Module):
    """
    A stack of `n` residual blocks of type `block`, applied one after the other.

    The first block maps in_channels -> out_channels and uses stride 2 when the
    two differ (stride 1 otherwise). The remaining n-1 blocks take
    out_channels * block.expansion channels as input and are never strided.
    Extra *args / **kwargs are forwarded to every block.
    """
    def __init__(self, in_channels, out_channels, block=ResNetBasicBlock, n=1, *args, **kwargs):
        super().__init__()
        # Downsample with a stride-2 convolution only when the channel count changes
        if in_channels != out_channels:
            downsampling = 2
        else:
            downsampling = 1

        stacked = [block(in_channels, out_channels, *args, downsampling=downsampling, **kwargs)]
        for _ in range(n - 1):
            stacked.append(block(out_channels * block.expansion, out_channels, downsampling=1, *args, **kwargs))
        self.blocks = nn.Sequential(*stacked)

    def forward(self, x):
        """Run the input through every block in order."""
        return self.blocks(x)

# Smoke-test a 3-block layer: 64 -> 128 channels, so the spatial size should halve
dummy = torch.ones(1, 64, 48, 48)
layer = ResNetLayer(in_channels=64, out_channels=128, block=ResNetBasicBlock, n=3)
layer(dummy).shape

torch.Size([1, 128, 24, 24])