In [255]:
import torch
from torch import nn

from torchsummary import summary

In [256]:
class ShortcutProjection(nn.Module):
    """Skip-path projection: a 1x1 convolution followed by batch normalisation.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels the projection produces.
        stride: stride of the 1x1 convolution.
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        # Positional Conv2d arguments: in_channels, out_channels, kernel_size, stride.
        self.conv = nn.Conv2d(in_channels, out_channels, 1, stride)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Run the convolution and normalise its output."""
        projected = self.conv(x)
        return self.bn(projected)

In [257]:
class ResidualBlock(nn.Module):
    """Residual block made of two 3x3 convolutions, each followed by batch norm.

    ReLU is applied after the first batch norm and again after the main path
    has been added to the shortcut.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by both convolutions.
        stride: stride of the first convolution (the second always uses 1).
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()

        # Main path: the first convolution carries the stride.
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=stride, padding=1
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.act1 = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # The input can be added back unchanged only when the main path keeps
        # both the channel count and the stride; otherwise it is projected.
        shape_preserved = stride == 1 and in_channels == out_channels
        if shape_preserved:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = ShortcutProjection(
                in_channels=in_channels, out_channels=out_channels,
                stride=stride
            )
        self.act2 = nn.ReLU()

    def forward(self, x):
        """Return ReLU(main_path(x) + shortcut(x))."""
        skip = self.shortcut(x)

        main = self.conv1(x)
        main = self.bn1(main)
        main = self.act1(main)
        main = self.conv2(main)
        main = self.bn2(main)

        return self.act2(main + skip)

In [258]:
class BottleneckResidualBlock(nn.Module):
    """Bottleneck residual block: 1x1 -> 3x3 -> 1x1 convolutions with a shortcut.

    Every convolution is followed by batch norm. ReLU follows the first two
    batch norms and is applied once more after the shortcut has been added.

    Args:
        in_channels: number of channels of the input feature map.
        bottleneck_channels: number of channels used by the middle 3x3 convolution.
        out_channels: number of channels produced by the block.
        stride: stride of the 3x3 convolution (both 1x1 convolutions use 1).
    """

    def __init__(self, in_channels, bottleneck_channels, out_channels, stride):
        super().__init__()

        # 1x1 convolution into the bottleneck width.
        self.conv1 = nn.Conv2d(in_channels, bottleneck_channels, kernel_size=1)
        self.bn1 = nn.BatchNorm2d(bottleneck_channels)
        self.act1 = nn.ReLU()

        # 3x3 convolution at the bottleneck width; this one carries the stride.
        self.conv2 = nn.Conv2d(
            bottleneck_channels, bottleneck_channels,
            kernel_size=3, stride=stride, padding=1
        )
        self.bn2 = nn.BatchNorm2d(bottleneck_channels)
        self.act2 = nn.ReLU()

        # 1x1 convolution out to the block's output width.
        self.conv3 = nn.Conv2d(bottleneck_channels, out_channels, kernel_size=1)
        self.bn3 = nn.BatchNorm2d(out_channels)

        # Project the input only when the main path changes its shape.
        shape_preserved = stride == 1 and in_channels == out_channels
        if shape_preserved:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = ShortcutProjection(
                in_channels=in_channels, out_channels=out_channels,
                stride=stride
            )

        self.act3 = nn.ReLU()

    def forward(self, x):
        """Return ReLU(main_path(x) + shortcut(x))."""
        skip = self.shortcut(x)

        main = self.act1(self.bn1(self.conv1(x)))
        main = self.act2(self.bn2(self.conv2(main)))
        main = self.conv3(main)
        main = self.bn3(main)

        return self.act3(main + skip)

In [259]:
# Random batch for a smoke test: 5 samples, 3 channels, 64x64 spatial size.
x2 = torch.randn((5, 3, 64, 64))

In [260]:
# Small bottleneck block for a smoke test. in_channels != out_channels here,
# so the skip path goes through a ShortcutProjection rather than nn.Identity.
bottleneck = BottleneckResidualBlock(
    in_channels=3,
    bottleneck_channels=4,
    out_channels=6,
    stride=1,
)

In [261]:
# Shape check: the 3 input channels become 6 and, with stride=1, the 64x64
# spatial size is unchanged.
bottleneck(x2).shape

torch.Size([5, 6, 64, 64])

### ResNet

In [262]:
# class ResNetBase(nn.Module):
#     def __init__(
#         self, n_blocks, n_channels, bottlenecks,
#         img_channels, first_kernel_size
#     ):
#         super().__init__()
#         assert len(n_blocks) == len(n_channels)
#         assert bottlenecks is None or len(bottlenecks) == len(n_channels)
        
#         self.conv = nn.Conv2d(
#             in_channels=img_channels, out_channels=n_channels[0],
#             kernel_size=first_kernel_size, stride=1, padding=first_kernel_size//2
#         )
#         self.bn = nn.BatchNorm2d(n_channels[0])
        
#         blocks = []
#         prev_channels = n_channels[0]
        
#         for i, channels in enumerate(n_channels):
#             stride = 2 if len(blocks) == 0 else 1
#             if bottlenecks is None:
#                 blocks.append(ResidualBlock(
#                     in_channels=prev_channels, out_channels=channels,
#                     stride=stride
#                 ))
#             else:
#                 blocks.append(BottleneckResidualBlock(
#                     in_channels=prev_channels, bottleneck_channels=bottlenecks[i],
#                     out_channels=channels, stride=stride
#                 ))
            
#             prev_channels = channels
            
#             for _ in range(n_blocks[i] - 1):
#                 if bottlenecks is None:
#                     blocks.append(ResidualBlock(
#                         in_channels=channels, out_channels=channels,
#                         stride=1
#                     ))
#                 else:
#                     blocks.append(BottleneckResidualBlock(
#                         in_channels=channels, bottleneck_channels=bottlenecks[i],
#                         out_channels=channels, stride=1
#                     ))
        
#         self.blocks = nn.Sequential(*blocks)

In [303]:
class ResNetBase(nn.Module):
    """Initial convolution + batch norm, a stack of ResidualBlocks, then a
    per-channel mean over all spatial positions.

    Args:
        img_channels: number of channels of the input images.
        n_blocks: number of residual blocks to build for each entry of
            ``n_channels``; must have the same length as ``n_channels``.
        n_channels: channel width of each group of residual blocks.
            ``n_channels[0]`` is also the width of the initial convolution.
        first_kernel_size: kernel size of the initial convolution; it is
            padded with ``first_kernel_size // 2``.
    """

    def __init__(
        self, img_channels, n_blocks, n_channels, first_kernel_size
    ):
        super().__init__()
        assert len(n_blocks) == len(n_channels)

        stem_width = n_channels[0]
        self.conv = nn.Conv2d(
            img_channels, stem_width, first_kernel_size,
            stride=1, padding=first_kernel_size // 2
        )
        self.bn = nn.BatchNorm2d(stem_width)

        layers = []
        width_in = stem_width
        for stage, width in enumerate(n_channels):
            # The leading block of each group adapts the channel count; only
            # the very first block of the whole stack uses stride 2.
            layers.append(ResidualBlock(
                in_channels=width_in, out_channels=width,
                stride=2 if stage == 0 else 1
            ))
            # Remaining blocks of the group keep both width and resolution.
            layers.extend(
                ResidualBlock(in_channels=width, out_channels=width, stride=1)
                for _ in range(n_blocks[stage] - 1)
            )
            width_in = width

        self.blocks = nn.Sequential(*layers)

    def forward(self, x):
        """Return a (batch, channels) tensor of spatially averaged features."""
        features = self.blocks(self.bn(self.conv(x)))
        batch, channels = features.shape[0], features.shape[1]
        # Collapse every spatial position into one axis, then average over it.
        flattened = features.view(batch, channels, -1)

        return flattened.mean(dim=-1)

In the paper, there is one initial convolution block followed by 16 residual blocks.

In [292]:
# n_blocks = [1] + [2]

# n_channels = [
#     64, # first convolution
#     # 64,
#     128
# ]

In [293]:
# Residual blocks built per entry of n_channels: one for the first entry,
# two for each of the remaining 16 entries.
n_blocks = [1, *([2] * 16)]

In [294]:
# Channel width of each group of residual blocks handed to ResNetBase.
n_channels = (
    [64]           # n_channels[0]: also the width of the first convolution
    + [64] * 3
    + [128] * 4
    + [256] * 6
    + [512] * 3
)

In [295]:
# Must match len(n_blocks): ResNetBase asserts that both lists have the same length.
len(n_channels)

17

In [298]:
# Build the network for 3-channel images with a 7x7 first convolution,
# using the n_blocks / n_channels configuration.
model = ResNetBase(
    img_channels=3,
    n_blocks=n_blocks,
    n_channels=n_channels,
    first_kernel_size=7,
)

In [299]:
# Random batch for a forward-pass check: 5 samples, 3 channels, 32x32 spatial size.
x = torch.randn((5, 3, 32, 32))

In [302]:
# Forward pass: ResNetBase.forward averages over all spatial positions, so the
# result has one value per (sample, channel).
model(x)

tensor([[1.6366, 1.5169, 1.9354,  ..., 1.6094, 1.6022, 1.6405],
        [1.4903, 1.4202, 1.4570,  ..., 1.3805, 1.5191, 1.5782],
        [1.6366, 1.4834, 1.5600,  ..., 1.5906, 1.6244, 1.6091],
        [1.6598, 1.6788, 1.4578,  ..., 1.5197, 1.5401, 1.4592],
        [1.5593, 1.5807, 1.3840,  ..., 1.6151, 1.4956, 1.4736]],
       grad_fn=<MeanBackward1>)

In [288]:
# Pick the device at runtime: a hard-coded 'cuda' makes model.to() raise on a
# machine without a GPU. The same device is passed to summary() so the dummy
# input it creates sits on the same device as the model.
# Note: model.to() moves the model in place, so later cells must feed it
# inputs on this device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
summary(model.to(device), (3, 32, 32), device=device)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 32, 32]           9,472
       BatchNorm2d-2           [-1, 64, 32, 32]             128
            Conv2d-3           [-1, 64, 16, 16]           4,160
       BatchNorm2d-4           [-1, 64, 16, 16]             128
ShortcutProjection-5           [-1, 64, 16, 16]               0
            Conv2d-6           [-1, 64, 16, 16]          36,928
       BatchNorm2d-7           [-1, 64, 16, 16]             128
              ReLU-8           [-1, 64, 16, 16]               0
            Conv2d-9           [-1, 64, 16, 16]          36,928
      BatchNorm2d-10           [-1, 64, 16, 16]             128
             ReLU-11           [-1, 64, 16, 16]               0
    ResidualBlock-12           [-1, 64, 16, 16]               0
         Identity-13           [-1, 64, 16, 16]               0
           Conv2d-14           [-1, 64,