In [2]:

import torch
from torch import nn


##### Shortcut Projection (linear 1×1 projection for the skip connection)

In [3]:
class ShortcutProjection(nn.Module):
    """Linear projection applied on the skip path of a residual block.

    A 1x1 convolution with the given stride, followed by batch norm, so the
    skip tensor can be given `out_channels` channels and a strided spatial size.

    Args:
        in_channels: channels of the incoming tensor.
        out_channels: channels of the projected tensor.
        stride: stride of the 1x1 convolution.
    """

    def __init__(self, in_channels: int, out_channels: int, stride: int):
        super().__init__()
        # Registration order (conv first, then bn) fixes the state_dict layout.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            stride=stride,
        )
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x: torch.Tensor):
        """Return bn(conv(x))."""
        projected = self.conv(x)
        normalized = self.bn(projected)
        return normalized

##### Residual Block

In [4]:
class ResidualBlock(nn.Module):
    """Basic residual block: 3x3 conv -> BN -> ReLU -> 3x3 conv -> BN, plus skip, then ReLU.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the block.
        stride: stride of the first 3x3 convolution (and of the projection shortcut).

    The skip path is an identity when stride is 1 and the channel count is
    unchanged; otherwise it is a ShortcutProjection so that it matches the main
    path before the addition.
    """

    def __init__(self, in_channels: int, out_channels: int, stride: int):
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.rl1 = nn.ReLU()

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # Final activation, applied after the residual addition (constructed once).
        self.rl2 = nn.ReLU()

        # NOTE: the attribute name "shorcut" (sic) is kept so existing state_dict keys still load.
        if stride != 1 or in_channels != out_channels:
            self.shorcut = ShortcutProjection(in_channels, out_channels, stride)
        else:
            self.shorcut = nn.Identity()

    def forward(self, x: torch.Tensor):
        """Return relu(main_path(x) + shortcut(x))."""
        shortcut = self.shorcut(x)
        x = self.rl1(self.bn1(self.conv1(x)))
        x = self.bn2(self.conv2(x))

        return self.rl2(x + shortcut)

In [None]:
class BottleneckResidualBlock(nn.Module):
    """Bottleneck residual block: 1x1 reduce -> 3x3 -> 1x1 expand, plus skip, then ReLU.

    Args:
        in_channels: channels of the input tensor.
        bottlneck_channels: width of the inner 1x1/3x3 layers (name kept for
            keyword-argument compatibility).
        out_channels: channels produced by the block.
        stride: stride applied on the 3x3 convolution and on the projection
            shortcut, so both paths produce the same spatial size.
    """

    def __init__(self, in_channels: int, bottlneck_channels: int, out_channels: int, stride: int):
        super(BottleneckResidualBlock, self).__init__()

        # 1x1 reduce: in_channels -> bottleneck width (must match bn1 / conv2 below).
        self.conv1 = nn.Conv2d(in_channels, bottlneck_channels, kernel_size=1, stride=1)
        self.bn1 = nn.BatchNorm2d(bottlneck_channels)
        self.rl1 = nn.ReLU()

        # 3x3 carries the stride so the main path downsamples exactly like the shortcut.
        self.conv2 = nn.Conv2d(bottlneck_channels, bottlneck_channels, kernel_size=3, stride=stride, padding=1)
        self.bn2 = nn.BatchNorm2d(bottlneck_channels)
        self.rl2 = nn.ReLU()

        # 1x1 expand: bottleneck width -> out_channels.
        self.conv3 = nn.Conv2d(bottlneck_channels, out_channels, kernel_size=1, stride=1)
        self.bn3 = nn.BatchNorm2d(out_channels)

        # NOTE: the attribute name "shorcut" (sic) is kept so existing state_dict keys still load.
        if stride != 1 or in_channels != out_channels:
            self.shorcut = ShortcutProjection(in_channels, out_channels, stride)
        else:
            self.shorcut = nn.Identity()

        # Activation applied after the residual addition.
        self.rl3 = nn.ReLU()

    def forward(self, x: torch.Tensor):
        """Return relu(main_path(x) + shortcut(x))."""
        shortcut = self.shorcut(x)

        x = self.rl1(self.bn1(self.conv1(x)))
        x = self.rl2(self.bn2(self.conv2(x)))
        x = self.bn3(self.conv3(x))

        return self.rl3(x + shortcut)

##### ResNet Model

In [29]:
class ResNet(nn.Module):
    """ResNet feature extractor: stem conv + BN, stacked residual stages, global average pool.

    Args:
        n_blocks: number of residual blocks in each stage (each must be >= 1).
        n_channels: output channels of each stage; n_channels[0] is also the stem width.
        bottlenecks: bottleneck width for each stage (builds BottleneckResidualBlock
            stages), or None to build plain ResidualBlock stages.
        img_channels: channels of the input image.
        first_kernel_size: kernel size of the stride-2 stem convolution.

    The first stage keeps the stem's resolution. The first block of every later
    stage uses stride 2, halving the spatial size and projecting the skip path.
    All other blocks use stride 1.
    """

    def __init__(self, n_blocks: list[int], n_channels: list[int], bottlenecks: list[int] | None, img_channels: int = 3, first_kernel_size: int = 7):
        super(ResNet, self).__init__()

        assert len(n_blocks) == len(n_channels)
        assert bottlenecks is None or len(bottlenecks) == len(n_channels)
        # A stage with 0 blocks would still receive its first block below, so reject it.
        assert all(n >= 1 for n in n_blocks), "every stage needs at least one block"

        self.conv = nn.Conv2d(
            img_channels,
            n_channels[0],
            kernel_size=first_kernel_size,
            stride=2,
            padding=first_kernel_size // 2
        )
        # NOTE(review): the reference ResNet applies ReLU (and max-pooling) after the
        # stem; this model does not — confirm this is intended.
        self.bn = nn.BatchNorm2d(n_channels[0])

        def make_block(in_ch: int, out_ch: int, stage: int, stride: int) -> nn.Module:
            # Plain or bottleneck block depending on whether bottleneck widths were given.
            if bottlenecks is None:
                return ResidualBlock(in_ch, out_ch, stride=stride)
            return BottleneckResidualBlock(in_ch, bottlenecks[stage], out_ch, stride=stride)

        blocks = []
        prev_channels = n_channels[0]

        for i, channels in enumerate(n_channels):
            # Only stages after the first downsample, and only in their first block.
            first_stride = 1 if i == 0 else 2
            blocks.append(make_block(prev_channels, channels, i, first_stride))
            prev_channels = channels

            # Remaining blocks of the stage keep both channel count and resolution.
            for _ in range(n_blocks[i] - 1):
                blocks.append(make_block(channels, channels, i, 1))

        self.blocks = nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor):
        """Map images (N, img_channels, H, W) to pooled features (N, n_channels[-1])."""
        x = self.bn(self.conv(x))
        x = self.blocks(x)
        # Global average pool over all spatial positions; flatten also tolerates
        # non-contiguous tensors, unlike view.
        return x.flatten(start_dim=2).mean(dim=-1)


