<a href="https://colab.research.google.com/github/tony3ynot/ResNet-for-ImageNet-and-CIFAR/blob/main/ResNet_(for_ImageNet%2C_CIFAR_10_100).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
from torch import nn

In [3]:
# Mount Google Drive at /content/drive using the google.colab helper.
# NOTE(review): nothing else in the visible part of this file reads from the
# mounted drive; presumably it is used for datasets/checkpoints elsewhere.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# 1. Original ResNet (designed for ImageNet classification)

In [4]:
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions (used by the ResNet-18/34 factories).

    The residual branch is conv3x3 -> BN -> ReLU -> conv3x3 -> BN; its output is
    added to the shortcut and passed through a final ReLU.

    Args:
        in_channels: number of channels of the input feature map.
        inner_channels: number of channels produced by the first convolution.
            The block outputs ``inner_channels * expansion`` channels.
        stride: stride of the first 3x3 convolution (spatial downsampling).
        projection: optional module applied to the input to build the shortcut;
            when ``None`` the input itself is used (identity shortcut).
    """

    # ratio between the block's output channels and ``inner_channels``
    expansion = 1

    def __init__(self, in_channels, inner_channels, stride = 1, projection = None):
        super().__init__()

        out_channels = inner_channels * self.expansion

        # The last BatchNorm must be sized like the output of the last convolution
        # (out_channels), not inner_channels; the two only coincide when expansion == 1.
        self.residual = nn.Sequential(nn.Conv2d(in_channels, inner_channels, 3, stride=stride, padding=1, bias=False),
                                      nn.BatchNorm2d(inner_channels),
                                      nn.ReLU(inplace=True),
                                      nn.Conv2d(inner_channels, out_channels, 3, padding=1, bias=False),
                                      nn.BatchNorm2d(out_channels))
        self.projection = projection
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(residual(x) + shortcut(x))``."""
        residual = self.residual(x)

        # projection shortcut when one was supplied, identity shortcut otherwise
        if self.projection is not None:
            shortcut = self.projection(x)
        else:
            shortcut = x

        out = self.relu(residual + shortcut)
        return out


# Bottleneck block for ResNet-50/101/152.
# The downsampling stride is applied at the 3x3 convolution rather than at the first 1x1 convolution (as in the original paper).
# This variant is also known as ResNet V1.5 (https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch)
class Bottleneck(nn.Module):
    """Residual block with a 1x1 -> 3x3 -> 1x1 bottleneck.

    The first 1x1 convolution maps ``in_channels`` to ``inner_channels``, the 3x3
    convolution (which carries the stride) works at ``inner_channels``, and the
    last 1x1 convolution widens to ``inner_channels * expansion``. Every
    convolution is followed by batch normalisation; ReLU follows the first two.

    Args:
        in_channels: number of channels of the input feature map.
        inner_channels: width of the bottleneck.
        stride: stride of the 3x3 convolution.
        projection: optional module producing the shortcut from the input;
            ``None`` selects the identity shortcut.
    """

    expansion = 4

    def __init__(self, in_channels, inner_channels, stride=1, projection=None):
        super().__init__()

        widened = inner_channels * self.expansion

        # Layers are instantiated in forward order so parameter creation order is unchanged.
        squeeze = [
            nn.Conv2d(in_channels, inner_channels, 1, bias=False),
            nn.BatchNorm2d(inner_channels),
            nn.ReLU(inplace=True),
        ]
        spatial = [
            nn.Conv2d(inner_channels, inner_channels, 3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(inner_channels),
            nn.ReLU(inplace=True),
        ]
        widen = [
            nn.Conv2d(inner_channels, widened, 1, bias=False),
            nn.BatchNorm2d(widened),
        ]

        self.residual = nn.Sequential(*squeeze, *spatial, *widen)
        self.projection = projection
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Add the residual branch to the shortcut and apply the final ReLU."""
        branch = self.residual(x)
        # identity shortcut unless a projection module was given
        shortcut = x if self.projection is None else self.projection(x)
        return self.relu(branch + shortcut)


# ResNet Architecture (ImageNet)
class ResNet(nn.Module):
    """ResNet with a 7x7 stem, max-pooling, four residual stages and a linear classifier.

    Args:
        block: residual block class. It must expose an ``expansion`` class attribute,
            be constructible as ``block(in_channels, inner_channels, stride, projection)``
            and, when ``zero_init_residual`` is set, hold its residual branch in a
            ``residual`` sequence whose last module has a ``weight``.
        num_block_list: number of blocks per stage; entries 0-3 are read.
        num_classes: number of outputs of the final linear layer.
        zero_init_residual: if true, set the weight of the last module of every
            block's residual branch to zero.
    """

    def __init__(self, block, num_block_list, num_classes=1000, zero_init_residual=True):
        super().__init__()

        # channels entering the next block to be built; advanced by make_stage
        self.in_channels = 64

        # stem
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # residual stages: only the first keeps the spatial resolution
        self.stage1 = self.make_stage(block, 64, num_block_list[0], stride=1)
        self.stage2 = self.make_stage(block, 128, num_block_list[1], stride=2)
        self.stage3 = self.make_stage(block, 256, num_block_list[2], stride=2)
        self.stage4 = self.make_stage(block, 512, num_block_list[3], stride=2)

        # head
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Kaiming-normal (fan-out, ReLU gain) initialisation of every convolution
        for conv in (m for m in self.modules() if isinstance(m, nn.Conv2d)):
            nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')

        # Zero the last BN weight of each residual branch;
        # 0.2~0.3%p accuracy gain according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for res_block in (m for m in self.modules() if isinstance(m, block)):
                nn.init.constant_(res_block.residual[-1].weight, 0)

    def make_stage(self, block, inner_channels, num_blocks, stride=1):
        """Build one stage as an ``nn.Sequential`` of ``block`` instances.

        The first block receives ``stride`` and, when the stride is not 1 or the
        channel count changes, a 1x1 conv + BN projection shortcut. The remaining
        ``num_blocks - 1`` blocks use the defaults. ``self.in_channels`` is updated
        to the stage's output width.
        """
        stage_width = inner_channels * block.expansion

        # identity shortcut only when neither resolution nor width changes
        projection = None
        if not (stride == 1 and self.in_channels == stage_width):
            projection = nn.Sequential(
                nn.Conv2d(self.in_channels, stage_width, 1, stride=stride, bias=False),
                nn.BatchNorm2d(stage_width))

        head = block(self.in_channels, inner_channels, stride, projection)
        self.in_channels = stage_width
        tail = [block(self.in_channels, inner_channels) for _ in range(num_blocks - 1)]

        return nn.Sequential(head, *tail)

    def forward(self, x):
        """Run stem, the four stages, global average pooling and the classifier."""
        stem = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        features = self.stage4(self.stage3(self.stage2(self.stage1(stem))))
        pooled = torch.flatten(self.avgpool(features), 1)
        return self.fc(pooled)

In [5]:
# ResNet models
def resnet18(**kwargs):
    """Build a ResNet from ``BasicBlock`` with 2, 2, 2, 2 blocks per stage; ``kwargs`` go to ``ResNet``."""
    stage_depths = [2, 2, 2, 2]
    return ResNet(BasicBlock, stage_depths, **kwargs)

def resnet34(**kwargs):
    """Build a ResNet from ``BasicBlock`` with 3, 4, 6, 3 blocks per stage; ``kwargs`` go to ``ResNet``."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(BasicBlock, stage_depths, **kwargs)

def resnet50(**kwargs):
    """Build a ResNet from ``Bottleneck`` with 3, 4, 6, 3 blocks per stage; ``kwargs`` go to ``ResNet``."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(Bottleneck, stage_depths, **kwargs)

def resnet101(**kwargs):
    """Build a ResNet from ``Bottleneck`` with 3, 4, 23, 3 blocks per stage; ``kwargs`` go to ``ResNet``."""
    stage_depths = [3, 4, 23, 3]
    return ResNet(Bottleneck, stage_depths, **kwargs)

def resnet152(**kwargs):
    """Build a ResNet from ``Bottleneck`` with 3, 8, 36, 3 blocks per stage; ``kwargs`` go to ``ResNet``."""
    stage_depths = [3, 8, 36, 3]
    return ResNet(Bottleneck, stage_depths, **kwargs)

# 2. Modified ResNet (designed for CIFAR-10/100 dataset)

In [7]:
# Only use Basic Block for CIFAR-10/100
class BasicBlock(nn.Module):
    """Residual block of two 3x3 convolutions for the CIFAR ResNets.

    The residual branch is conv3x3 -> BN -> ReLU -> conv3x3 -> BN; its output is
    added to the shortcut and passed through a final ReLU.

    This class reuses the name of the ``BasicBlock`` defined earlier in the file
    and therefore replaces it once this definition has run. ``expansion`` is
    declared so that code reading ``block.expansion`` (the ImageNet ``ResNet``
    class does) keeps working with this definition too.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by both convolutions.
        stride: stride of the first 3x3 convolution.
        projection: optional module producing the shortcut from the input;
            ``None`` selects the identity shortcut.
    """

    # output channels == out_channels * expansion; kept for interchangeability
    # with the earlier BasicBlock that this definition shadows
    expansion = 1

    def __init__(self, in_channels, out_channels, stride = 1, projection = None):
        super().__init__()

        self.residual = nn.Sequential(nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1, bias=False),
                                      nn.BatchNorm2d(out_channels),
                                      nn.ReLU(inplace=True),
                                      nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
                                      nn.BatchNorm2d(out_channels))
        self.projection = projection
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(residual(x) + shortcut(x))``."""
        residual = self.residual(x)

        # projection shortcut when one was supplied, identity shortcut otherwise
        if self.projection is not None:
            shortcut = self.projection(x)
        else:
            shortcut = x

        out = self.relu(residual + shortcut)
        return out


# ResNet Architecture (CIFAR-10 by default; pass num_classes=100 for CIFAR-100)
class ResNet_c(nn.Module):
    """ResNet variant with a 3x3 stem and three residual stages (16, 32, 64 channels).

    Args:
        block: residual block class, constructible as
            ``block(in_channels, out_channels, stride, projection)``.
        num_block_list: number of blocks per stage; entries 0-2 are read.
        num_classes: number of outputs of the final linear layer
            (10 by default; e.g. 100 for CIFAR-100).
    """

    def __init__(self, block, num_block_list, num_classes=10):
        super().__init__()

        # channels entering the next block to be built; advanced by make_stage
        self.in_channels = 16

        # stem: a single 3x3 convolution with 16 output channels, stride 1
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)

        # three stages; for 32x32 inputs the feature maps are
        # 32x32x16, 16x16x32 and 8x8x64 respectively
        self.stage1 = self.make_stage(block, 16, num_block_list[0], stride=1)
        self.stage2 = self.make_stage(block, 32, num_block_list[1], stride=2)
        self.stage3 = self.make_stage(block, 64, num_block_list[2], stride=2)

        # head
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, num_classes)

        # Kaiming-normal for convolutions, unit scale / zero shift for batch norm
        for layer in self.modules():
            if isinstance(layer, nn.BatchNorm2d):
                nn.init.constant_(layer.weight, 1)
                nn.init.constant_(layer.bias, 0)
            elif isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')

    def make_stage(self, block, out_channels, num_blocks, stride=1):
        """Build one stage as an ``nn.Sequential`` of ``block`` instances.

        The first block receives ``stride`` and, when the stride is not 1 or the
        channel count changes, a 1x1 conv + BN projection shortcut. The remaining
        ``num_blocks - 1`` blocks use the defaults. ``self.in_channels`` is set to
        ``out_channels`` afterwards.
        """
        # identity shortcut only when neither resolution nor width changes
        projection = None
        if not (stride == 1 and self.in_channels == out_channels):
            projection = nn.Sequential(
                nn.Conv2d(self.in_channels, out_channels, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels))

        head = block(self.in_channels, out_channels, stride, projection)
        self.in_channels = out_channels
        tail = [block(self.in_channels, out_channels) for _ in range(num_blocks - 1)]

        return nn.Sequential(head, *tail)

    def forward(self, x):
        """Run stem, the three stages, global average pooling and the classifier."""
        stem = self.relu(self.bn1(self.conv1(x)))
        features = self.stage3(self.stage2(self.stage1(stem)))
        pooled = torch.flatten(self.avgpool(features), 1)
        return self.fc(pooled)

In [8]:
# ResNet models (CIFAR-10)
def resnet20(**kwargs):
    """Build a ``ResNet_c`` from ``BasicBlock`` with 3 blocks in each of the three stages.

    Keyword arguments are forwarded to ``ResNet_c`` (for example
    ``num_classes=100`` for CIFAR-100); with none given the 10-class default is used.
    """
    return ResNet_c(BasicBlock, [3, 3, 3], **kwargs)

def resnet32(**kwargs):
    """Build a ``ResNet_c`` from ``BasicBlock`` with 5 blocks in each of the three stages.

    Keyword arguments are forwarded to ``ResNet_c`` (for example
    ``num_classes=100`` for CIFAR-100); with none given the 10-class default is used.
    """
    return ResNet_c(BasicBlock, [5, 5, 5], **kwargs)

def resnet44(**kwargs):
    """Build a ``ResNet_c`` from ``BasicBlock`` with 7 blocks in each of the three stages.

    Keyword arguments are forwarded to ``ResNet_c`` (for example
    ``num_classes=100`` for CIFAR-100); with none given the 10-class default is used.
    """
    return ResNet_c(BasicBlock, [7, 7, 7], **kwargs)

def resnet56(**kwargs):
    """Build a ``ResNet_c`` from ``BasicBlock`` with 9 blocks in each of the three stages.

    Keyword arguments are forwarded to ``ResNet_c`` (for example
    ``num_classes=100`` for CIFAR-100); with none given the 10-class default is used.
    """
    return ResNet_c(BasicBlock, [9, 9, 9], **kwargs)

def resnet110(**kwargs):
    """Build a ``ResNet_c`` from ``BasicBlock`` with 18 blocks in each of the three stages.

    Keyword arguments are forwarded to ``ResNet_c`` (for example
    ``num_classes=100`` for CIFAR-100); with none given the 10-class default is used.
    """
    return ResNet_c(BasicBlock, [18, 18, 18], **kwargs)

def resnet1202(**kwargs):
    """Build a ``ResNet_c`` from ``BasicBlock`` with 200 blocks in each of the three stages.

    Keyword arguments are forwarded to ``ResNet_c`` (for example
    ``num_classes=100`` for CIFAR-100); with none given the 10-class default is used.
    """
    return ResNet_c(BasicBlock, [200, 200, 200], **kwargs)