# ResNet Implementation

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init # for Kaiming Normal

## Custom Lambda Layer

In [None]:
class LambdaLayer(nn.Module):
    """
    Wraps an arbitrary callable so that it can be used like any other nn.Module.
    Its intended use is to express skip connections.
    """
    def __init__(self, lambd):
        """
        lambd: the callable that forward() applies to its input; it is kept on the instance as self.lambd
        """
        super().__init__()
        self.lambd = lambd

    def forward(self, x):
        """
        Passes x through the stored callable and returns whatever the callable returns.
        Note that layer(x) ends up calling layer.forward(x).
        """
        transform = self.lambd
        return transform(x)

## Basic Building Block
This BasicBlock design is based on Figure 2 in the original ResNet paper. 
- A block consists of two weight layers with a ReLU in between.
- This is followed by a skip connection summed with the output, and another ReLU.
- Each weight layer is a convolution layer followed by batch normalization.

In [None]:
class BasicBlock(nn.Module):
    """
    This class implements the basic building blocks of ResNet.
    A building block has two methods:
    > __init__() which constructs the block itself
    > forward() which explains how inputs are processed by the block.

    The block computes relu(bn2(conv2(relu(bn1(conv1(x))))) + skip(x)).
    """
    def __init__(self, in_channels, out_channels, stride):
        """
        in_channels: the number of planes/filters that this block takes
        out_channels: the number of planes/filters that this block outputs (must be >= in_channels)
        stride: used for downsampling, used only in the first weight layer, in the first BasicBlock of a ResNet layer
                (an int, or a (stride_h, stride_w) pair)

        Raises ValueError if out_channels < in_channels, because the parameter-free shortcut
        can only add zero channels, never remove channels.
        """
        # Call the super class initializer
        super().__init__()

        if out_channels < in_channels:
            raise ValueError(
                "BasicBlock uses a zero-padding shortcut and requires out_channels >= in_channels, "
                f"got in_channels={in_channels}, out_channels={out_channels}"
            )

        # With a 3x3 kernel and stride = 1, padding = 1 keeps the height and width unchanged
        # A fast formula for this is (F - 1) / 2 = (3 - 1) / 2 = 1
        # If stride = 2, then the height and width are halved (the number of channels is set by out_channels)

        # Start with the first weight layer
        # This is where downsampling MIGHT happen
        self.conv1 = nn.Conv2d(in_channels=in_channels,
                               out_channels=out_channels,
                               kernel_size=3,
                               stride=stride,
                               padding=1,
                               bias=False)

        self.bn1 = nn.BatchNorm2d(num_features=out_channels)

        # Then make the second weight layer, which never changes the shape
        self.conv2 = nn.Conv2d(in_channels=out_channels,
                               out_channels=out_channels,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               bias=False)

        self.bn2 = nn.BatchNorm2d(num_features=out_channels)

        # Finally, we implement the skip connection for option A (page 4, paragraph 1)
        # The input must be reshaped whenever conv1 changes the spatial size or the channel count:
        # > spatially, keep every stride-th pixel, which yields the same height/width as conv1
        #   (kernel 3, padding 1) produces for that stride
        # > channel-wise, add zero channels in front and behind until there are out_channels of them
        # For the usual case (stride = 2, out_channels = 2 * in_channels) this is exactly
        # x[:, :, ::2, ::2] padded with out_channels // 4 zero channels on each side
        stride_h, stride_w = (stride, stride) if isinstance(stride, int) else stride
        extra_channels = out_channels - in_channels
        pad_front = extra_channels // 2
        pad_back = extra_channels - pad_front # takes the remainder when extra_channels is odd

        if extra_channels != 0 or (stride_h, stride_w) != (1, 1):
            # F.pad reads the pad tuple from the last dimension backwards: (W, W, H, H, C, C)
            self.skip = LambdaLayer(lambda x: F.pad(x[:, :, ::stride_h, ::stride_w],
                                    pad=(0, 0, 0, 0, pad_front, pad_back),
                                    mode="constant",
                                    value=0))
        # Otherwise the block keeps the shape of its input, so the input can be added as is
        else:
            self.skip = nn.Sequential() # Just forwards the input as is

    def forward(self, x):
        """
        Class method that implements the computation of a BasicBlock.
        Takes the input feature map x and returns the activated sum of the residual branch and the skip branch.
        """
        # First weight layer, then activation with ReLU
        outputs = F.relu(self.bn1(self.conv1(x)))

        # Second weight layer and join the skip connection
        # Call the module itself rather than .forward() so that registered hooks still run
        outputs = self.bn2(self.conv2(outputs))
        outputs += self.skip(x)

        # Activation with ReLU
        return F.relu(outputs)

## Kaiming Normal Initialization
The paper mentions that weight initialization follows that of "K. He, X. Zhang, S. Ren, and J. Sun. Delving deep into rectifiers: Surpassing human-level performance on imagenet classification. In ICCV, 2015."

How does `model.apply(_weights_init)` work?
> When you call `model.apply(_weights_init)`, PyTorch walks the whole module tree rooted at `model`.
It recursively visits every single submodule inside `model` (children first), and finally `model` itself.
For each module it encounters (let's call it m), it calls the function `_weights_init(m)`, passing that specific module m as the argument.

In [4]:
def _weights_init(m: nn.Module):
    """
    Applies Kaiming normal initialization (default arguments) to the weight of m.
    Only nn.Linear and nn.Conv2d modules are touched; any other module is left as it is.
    The bias, if there is one, is not modified.
    """
    # Guard clause: skip every module that is neither fully connected nor convolutional
    if not isinstance(m, (nn.Linear, nn.Conv2d)):
        return
    init.kaiming_normal_(m.weight)

## ResNet Class
Based on the original paper, at Section 4.2 CIFAR-10 and Analysis (Para. 1), the inputs are 32x32 RGB images (i.e. 3 filters/channels/planes) with per-pixel mean subtracted. Moreover, the chosen architecture is as follows. **Note that n is the number of BasicBlock objects in a ResNet layer.**
1. The first layer is a 3x3 kernel convolution, with feature (output) map size 32x32, with 16 channels. (This is `self.conv` and `self.bn`)
2. The next 2n layers have feature (output) map size 32x32, 16 channels. (We call this `self.layer1`)
3. The next 2n layers have feature (output) map size 16x16, 32 channels. (We call this `self.layer2`)
4. The final 2n layers have feature (output) map size 8x8, 64 channels. (We call this `self.layer3`)
5. The network ends with a global average pooling, a 10-way fully-connected layer, and softmax. (This is `self.linear`)

Why does a ResNet layer of n BasicBlocks contribute 2n towards the total number of layers?
> It's because each BasicBlock comprises two conv-bn pairs, i.e. 2 weight layers per block and 2n per ResNet layer.

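Why don't we need SoftMax in `forward()`?
> It's because we are using cross-entropy loss. PyTorch's `nn.CrossEntropyLoss` applies log-softmax to the raw scores (logits) internally, so the network only has to output the logits.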

In [None]:
class ResNet(nn.Module):
    """
    ResNet for CIFAR-style inputs: one 3x3 convolution, three ResNet layers with
    16, 32 and 64 channels, global average pooling, and a fully connected classifier.
    """
    def __init__(self, block, num_blocks, num_classes=10): # For CIFAR-10, we have 10 output classes
        """
        block: the class used to build each block; it is called as block(in_channels, out_channels, stride)
        num_blocks: a sequence of three ints, the number of blocks in layer1, layer2 and layer3
        num_classes: the number of outputs of the final fully connected layer
        """
        super().__init__()

        # With a 3x3 kernel and stride = 1, padding = 1 keeps the height and width of the input
        # A fast formula for this is (F - 1) / 2 = (3 - 1) / 2 = 1
        self.conv = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(num_features=16) # num_features = out_channels from previous layer

        # Running channel count, read and updated by _make_layer() while the layers are built
        # It starts at 16 to match the output of the convolution and batch normalization above
        self.in_channels = 16

        # Stride = 1, so this layer keeps the feature map size (32x32 for a 32x32 input)
        self.layer1 = self._make_layer(block, num_blocks[0], out_channels=16, stride=1)

        # Stride = 2, so this layer halves the feature map size (16x16 for a 32x32 input)
        self.layer2 = self._make_layer(block, num_blocks[1], out_channels=32, stride=2)

        # Stride = 2, so this layer halves the feature map size again (8x8 for a 32x32 input)
        self.layer3 = self._make_layer(block, num_blocks[2], out_channels=64, stride=2)

        # After global average pooling, each 64-channel feature map is squashed into a 64-dimensional vector
        # That vector will be the input to this linear transformation
        self.linear = nn.Linear(in_features=64, out_features=num_classes)

        # Initialize weights for convolutional and fully connected layers
        # This recursively goes through all the submodules
        self.apply(_weights_init)

    def _make_layer(self, block, num_blocks, out_channels, stride):
        """
        Constructs a ResNet layer comprising of building blocks.

        block: the class used to build each block
        num_blocks: how many blocks the layer should contain
        out_channels: the number of channels every block of this layer outputs
        stride: the stride of the first block only; all later blocks use stride 1

        Returns an nn.Sequential holding the blocks, and leaves self.in_channels equal to out_channels.
        NOTE: the first block is always created, so num_blocks < 1 still yields one block.
        """
        # If a layer does downsampling, it's done in the first block
        # The other blocks have stride=1 so that the feature map sizes don't decrease anywhere else
        strides = [stride] + [1] * (num_blocks - 1)

        # Create the blocks
        blocks = []
        for block_stride in strides:
            # Example: input is 32x32, 16 channels, we want to output 16x16, 32 channels
            # We have 3 blocks, then Block 1 will have 16 channels as input, 32 channels as output, stride = 2
            # Block 2 will have 32 channels as input, and 32 channels as output, stride = 1
            # Block 3 will have 32 channels as input, and 32 channels as output, stride = 1
            blocks.append(block(self.in_channels, out_channels, block_stride))
            self.in_channels = out_channels

        # Stack the blocks into a sequential layer, using Python list unpacking
        return nn.Sequential(*blocks)

    def forward(self, x):
        """
        Class method that implements the computation of a ResNet.
        x is a batch of 3-channel images laid out as [batch, channels, height, width].
        Returns the raw class scores of shape [batch, num_classes] (no softmax is applied).
        """
        # Part 1
        output = F.relu(self.bn(self.conv(x)))

        # Part 2 - 4
        output = self.layer1(output)
        output = self.layer2(output)
        output = self.layer3(output)

        # Part 5
        # Recall the shape [batch, channels, height, width]; for a 32x32 input this is [batch, 64, 8, 8]
        # Global average pooling: average over the whole height x width plane of every channel
        # Adaptive pooling to a 1x1 output does this for any feature map size, square or not
        output = F.adaptive_avg_pool2d(output, output_size=1)

        # The pooling layer created a tensor of size [batch, 64, 1, 1], technically a 64-D vector
        # Now we flatten the tensor into a vector
        output = torch.flatten(output, 1) # flattens from idx 1 onwards

        # Throw it into the feedforward network
        return self.linear(output)



## Factory Functions for Fast Construction

In [None]:
def resnet20(num_classes: int = 10): # n = 3
    """
    Builds a ResNet-20 out of BasicBlocks.
    num_classes: the number of outputs of the final fully connected layer (10 by default, as for CIFAR-10)
    """
    # The ResNet has a single conv layer to start with
    # Then here we specify each ResNet layer has 3 BasicBlocks
    # And one final fully connected layer
    # So the total is 1 + (2 x 3) x 3 + 1 = 20
    return ResNet(BasicBlock, [3, 3, 3], num_classes=num_classes)


def resnet32(num_classes: int = 10): # n = 5
    """
    Builds a ResNet-32 out of BasicBlocks.
    num_classes: the number of outputs of the final fully connected layer (10 by default, as for CIFAR-10)
    """
    # The ResNet has a single conv layer to start with
    # Then here we specify each ResNet layer has 5 BasicBlocks
    # And one final fully connected layer
    # So the total is 1 + (2 x 5) x 3 + 1 = 32
    return ResNet(BasicBlock, [5, 5, 5], num_classes=num_classes)


def resnet44(num_classes: int = 10): # n = 7
    """
    Builds a ResNet-44 out of BasicBlocks.
    num_classes: the number of outputs of the final fully connected layer (10 by default, as for CIFAR-10)
    """
    # The ResNet has a single conv layer to start with
    # Then here we specify each ResNet layer has 7 BasicBlocks
    # And one final fully connected layer
    # So the total is 1 + (2 x 7) x 3 + 1 = 44
    return ResNet(BasicBlock, [7, 7, 7], num_classes=num_classes)


def resnet56(num_classes: int = 10): # n = 9
    """
    Builds a ResNet-56 out of BasicBlocks.
    num_classes: the number of outputs of the final fully connected layer (10 by default, as for CIFAR-10)
    """
    # The ResNet has a single conv layer to start with
    # Then here we specify each ResNet layer has 9 BasicBlocks
    # And one final fully connected layer
    # So the total is 1 + (2 x 9) x 3 + 1 = 56
    return ResNet(BasicBlock, [9, 9, 9], num_classes=num_classes)


def resnet110(num_classes: int = 10): # n = 18
    """
    Builds a ResNet-110 out of BasicBlocks.
    num_classes: the number of outputs of the final fully connected layer (10 by default, as for CIFAR-10)
    """
    # The ResNet has a single conv layer to start with
    # Then here we specify each ResNet layer has 18 BasicBlocks
    # And one final fully connected layer
    # So the total is 1 + (2 x 18) x 3 + 1 = 110
    return ResNet(BasicBlock, [18, 18, 18], num_classes=num_classes)


def resnet1202(num_classes: int = 10): # n = 200
    """
    Builds a ResNet-1202 out of BasicBlocks.
    num_classes: the number of outputs of the final fully connected layer (10 by default, as for CIFAR-10)
    """
    # The ResNet has a single conv layer to start with
    # Then here we specify each ResNet layer has 200 BasicBlocks
    # And one final fully connected layer
    # So the total is 1 + (2 x 200) x 3 + 1 = 1202
    return ResNet(BasicBlock, [200, 200, 200], num_classes=num_classes)