In [1]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

### Residual Block

In [29]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, identity_downsample = None, stride = 1):
        super(ResidualBlock, self).__init__()
        
        self.expansion = 4
        self.relu = nn.ReLU()
        
        self.ResBlockLayer1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias = False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )
        
        self.ResBlockLayer2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias = False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )
        
        self.ResBlockLayer3 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels * self.expansion, kernel_size=1, stride=1, padding=0, bias = False),
            nn.BatchNorm2d(out_channels * self.expansion)
        )
        
        self.identity_downsample = identity_downsample
        self.stride = stride
    
    def forward(self, x):
        identity = x.clone()
        x = self.ResBlockLayer1(x)
        x = self.ResBlockLayer2(x)
        x = self.ResBlockLayer3(x)
        
        if self.identity_downsample is not None:
            identity = self.identity_downsample(identity)
        
        x += identity
        x = self.relu(x)
        
        return x

### ResNet

In [24]:
class ResNet(nn.Module): # ex) layers = [3, 4, 6, 3] for ResNet-50
    def __init__(self, block, layers, image_channels, num_classes):
        super(ResNet, self).__init__()
        
        self.in_channels = 64
        
        # conv1 - NOT Residual Block
        self.conv1 = nn.Sequential(
            nn.Conv2d(image_channels, 64, kernel_size=7, stride=2, padding=3, bias = False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        
        # ResNet Layers
        self.ResLayers = nn.Sequential(
            self._make_layer(block, layers[0], out_channels=64, stride = 1),
            self._make_layer(block, layers[1], out_channels=128, stride = 1),
            self._make_layer(block, layers[2], out_channels=256, stride = 1),
            self._make_layer(block, layers[3], out_channels=1024, stride = 1)
        )
        
        self.ResFinal = nn.Sequential(
            nn.AdaptiveAvgPool2d((1,1)),
            nn.Flatten(),
            nn.Linear(512 * 4, num_classes)
        )
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.ResLayers(x)
        x = self.ResFinal(x)
        return x

    # _make_layer function
    def _make_layer(self, block, num_residual_blocks, out_channels, stride):
        identity_downsample = None
        layers = []

        # stride != 1 means dimension reduction, we use expansion = 4, so if in == out * 4, identity mapping possible
        if stride != 1 or self.in_channels != out_channels * 4:
            identity_downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, out_channels*4, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * 4)
            )

        layers.append(ResidualBlock(self.in_channels, out_channels, identity_downsample, stride))
        self.in_channels = out_channels * 4

        for i in range(num_residual_blocks - 1):
            layers.append(ResidualBlock(self.in_channels, out_channels)) # No need identity_downsampling -> observe the architecture in the paper by He et al(2015)

        return nn.Sequential(*layers)
                

In [25]:
def ResNet50(img_channels = 3, num_classes = 1000):
    return ResNet(ResidualBlock, [3,4,6,3], img_channels, num_classes)

In [30]:
def test():
    net = ResNet50()
    x = torch.randn(4,3,224,224)
    y = net(x)

test()

RuntimeError: mat1 and mat2 shapes cannot be multiplied (4x4096 and 2048x1000)