In [1]:
import torch
from torch import nn

In [8]:
class ShortcutProjection(nn.Module):
    """1x1 convolution followed by batch norm, used on a residual shortcut.

    Projects the shortcut input to ``out_channels`` channels and applies the
    given ``stride`` so the result can be added to the output of a main
    branch that changes channel count and/or spatial resolution.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the projection.
        stride: Stride of the 1x1 convolution.
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        # The stride must be forwarded to the convolution; otherwise the
        # projection keeps the input resolution regardless of `stride`.
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size=1, stride=stride,
        )
        self.bn = nn.BatchNorm2d(num_features=out_channels)

    def forward(self, x):
        """Return ``bn(conv(x))`` for a 4D input ``(N, in_channels, H, W)``."""
        return self.bn(self.conv(x))

In [5]:
class ResidualBlock(nn.Module):
    """Two-layer 3x3 residual block with a skip connection.

    Main branch: conv3x3(stride) -> BN -> ReLU -> conv3x3(stride=1) -> BN.
    The shortcut is the identity when ``stride == 1`` and
    ``in_channels == out_channels``; otherwise the input is passed through a
    ``ShortcutProjection`` constructed with the same channels and stride.
    The two branches are summed and passed through a final ReLU.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the block.
        stride: Stride of the first 3x3 convolution (and of the shortcut
            projection, when one is used).
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels=in_channels, out_channels=out_channels,
            kernel_size=3, stride=stride, padding=1
        )
        self.bn1 = nn.BatchNorm2d(num_features=out_channels)
        self.act1 = nn.ReLU()
        self.conv2 = nn.Conv2d(
            in_channels=out_channels, out_channels=out_channels,
            kernel_size=3, stride=1, padding=1
        )
        self.bn2 = nn.BatchNorm2d(num_features=out_channels)

        if stride != 1 or in_channels != out_channels:
            self.shortcut = ShortcutProjection(
                in_channels=in_channels, out_channels=out_channels,
                stride=stride
            )
        else:
            self.shortcut = nn.Identity()
        self.act2 = nn.ReLU()

    def forward(self, x):
        """Apply the block to a 4D input ``(N, in_channels, H, W)``.

        Returns ``act2(main_branch(x) + shortcut(x))``.
        """
        residual = self.shortcut(x)
        out = self.act1(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Sum the main-branch output (not the raw input) with the shortcut,
        # then apply the final activation defined in __init__ (act2).
        return self.act2(out + residual)

In [7]:
class BottleneckResidualBlock(nn.Module):
    """Three-layer bottleneck residual block with a skip connection.

    Main branch:
        conv1x1 (in -> bottleneck) -> BN -> ReLU
        conv3x3 (bottleneck -> bottleneck, stride) -> BN -> ReLU
        conv1x1 (bottleneck -> out) -> BN
    The shortcut is the identity when ``stride == 1`` and
    ``in_channels == out_channels``; otherwise the input is passed through a
    ``ShortcutProjection`` constructed with the same channels and stride.
    The two branches are summed and passed through a final ReLU.

    Args:
        in_channels: Number of channels of the input tensor.
        bottleneck_channels: Number of channels inside the bottleneck.
        out_channels: Number of channels produced by the block.
        stride: Stride of the 3x3 convolution (and of the shortcut
            projection, when one is used).
    """

    def __init__(self, in_channels, bottleneck_channels, out_channels, stride):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels=in_channels, out_channels=bottleneck_channels,
            kernel_size=1, stride=1
        )
        self.bn1 = nn.BatchNorm2d(num_features=bottleneck_channels)
        self.act1 = nn.ReLU()
        self.conv2 = nn.Conv2d(
            in_channels=bottleneck_channels, out_channels=bottleneck_channels,
            kernel_size=3, stride=stride, padding=1
        )
        self.bn2 = nn.BatchNorm2d(num_features=bottleneck_channels)
        self.act2 = nn.ReLU()
        self.conv3 = nn.Conv2d(
            in_channels=bottleneck_channels, out_channels=out_channels,
            kernel_size=1, stride=1
        )
        self.bn3 = nn.BatchNorm2d(num_features=out_channels)
        if stride != 1 or in_channels != out_channels:
            self.shortcut = ShortcutProjection(
                in_channels=in_channels, out_channels=out_channels,
                stride=stride
            )
        else:
            self.shortcut = nn.Identity()

        self.act3 = nn.ReLU()

    def forward(self, x):
        """Apply the block to a 4D input ``(N, in_channels, H, W)``.

        Returns ``act3(main_branch(x) + shortcut(x))``.
        """
        residual = self.shortcut(x)
        out = self.act1(self.bn1(self.conv1(x)))
        out = self.act2(self.bn2(self.conv2(out)))
        # Third stage expands bottleneck -> out_channels with conv3/bn3.
        out = self.bn3(self.conv3(out))

        # Sum the main-branch output (not the raw input) with the shortcut.
        return self.act3(out + residual)