In [2]:
import torch
import torch.nn as nn

from functools import partial

To create a clean code is mandatory to think about the main building blocks of the application, or of the network in our case.

In [8]:
class Conv2dAuto(nn.Conv2d):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.padding = (self.kernel_size[0] // 2,
                        self.kernel_size[1] // 2) # dynamic add padding based on the kernel_size
        

def activation_func(activation):
    return nn.ModuleDict([
        ['relu', nn.ReLU(inplace = True)],
        ['leaky_relu', nn.LeakyReLU(negative_slope = 0.01, inplace = True)],
        ['selu', nn.SELU(inplace = True)],
        ['none', nn.Identity()]
    ])

conv3x3 = partial(Conv2dAuto, kernel_size = 3, bias = False)

In [6]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, activation = 'relu'):
        super().__init__()
        self.in_channels, self.out_channels, self.activation = in_channels, out_channels, activation
        self.blocks = nn.Identity()
        self.activate = activation_func(activation)
        self.shortcut = nn.Identity()
        
    def forward(self, x):
        residual = x
        if self.should_apply_shortcut: residual = self.shortcut(x)
        x = self.blocks(x)
        x += residual
        x = self.activate(x)
        
        return x
    
    @property
    def should_apply_shortcut(self):
        return self.in_channels != self.out_channels
    
ResidualBlock(32, 64)

ResidualBlock(
  (blocks): Identity()
  (activate): ModuleDict(
    (relu): ReLU(inplace=True)
    (leaky_relu): LeakyReLU(negative_slope=0.01, inplace=True)
    (selu): SELU(inplace=True)
    (none): Identity()
  )
  (shortcut): Identity()
)

In [13]:
class ResNetResidualBlock(ResidualBlock):
    def __init__(self, in_channels, out_channels, expansion = 1, downsampling = 1, 
                 conv = conv3x3, *args, **kwargs):
        super().__init__(in_channels, out_channels, *args, **kwargs)
        self.expansion, self.downsampling, self.conv = expansion, downsampling, conv
        self.shortcut = nn.Sequential(
            nn.Conv2d(self.in_channels, self.expanded_channels, kernel_size = 1,
                      stride = self.downsampling, bias = False),
            nn.BatchNorm2d(self.expanded_channels)
        ) if self.should_apply_shortcut else None
        
    @property
    def expanded_channels(self):
        return self.out_channels * self.expansion
    
    @property
    def should_apply_shortcut(self):
        return self.in_channels != self.expanded_channels

In [17]:
def conv_bn(in_channels, out_channels, conv, *args, **kwargs):
    return nn.Sequential(conv(in_channels, out_channels, *args, **kwargs), nn.BatchNorm2d(out_channels))

In [18]:
class ResNetBasicBlock(ResNetResidualBlock):
    """
        Basic ResNet block composed by two layers of 3x3conv/batchnorm/activation
    """
    expansion = 1
    def __init__(self, in_channels, out_channels, *args, **kwargs):
        super().__init__(in_channels, out_channels, *args, **kwargs)
        self.blocks = nn.Sequential(
            conv_bn(self.in_channels, self.out_channels, conv = self.conv, bias = False, 
                    stride = self.downsampling),
            activation_func(self.activation),
            conv_bn(self.out_channels, self.expanded_channels, conv = self.conv, bias = False)
        )

In [19]:
dummy = torch.ones((1, 32, 224, 224))

block = ResNetBasicBlock(32, 64)
block(dummy).shape
print(block)

NotImplementedError: 