# Deep Residual Learning for Image Recognition (ResNet)
This is a PyTorch implementation of the paper *Deep Residual Learning for Image Recognition*.

Paper: https://papers.labml.ai/paper/ecbad378ae7311eb9864394904658322

In [2]:
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt

# Linear projections for shortcut connection
当 x 与 F(x) 的维度不同时，不能使用恒等映射，而应使用线性投影来匹配维度。

In [3]:
class ShortcutProjection(nn.Module):
    """Linear projection for the shortcut connection.

    A 1x1 convolution followed by batch normalisation.  The residual blocks
    in this file use it on the skip path whenever the stride is not 1 or the
    channel count changes, so that both branches can be added together.
    """

    def __init__(self, in_channels: int, out_channels: int, stride: int):
        """Create the projection layers.

        Args:
            in_channels: Number of channels of the incoming feature map.
            out_channels: Number of channels produced by the projection.
            stride: Stride of the 1x1 convolution.
        """
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=stride,
        )
        self.bn = nn.BatchNorm2d(num_features=out_channels)

    def forward(self, x: torch.Tensor):
        """Apply the 1x1 convolution, then batch normalisation, to ``x``."""
        projected = self.conv(x)
        return self.bn(projected)


# torch.nn.Identity --> 占位模块，输入什么就输出什么
即恒等函数：当不需要投影时，用它作为捷径连接。

# 残差块 residual mapping

In [4]:
class ResidualBlock(nn.Module):
    """Residual block with two 3x3 convolutions.

    Computes ``ReLU(F(x) + shortcut(x))`` where ``F`` is
    conv3x3 -> BN -> ReLU -> conv3x3 -> BN.  The shortcut is the identity
    when the block keeps both the resolution (``stride == 1``) and the
    channel count; otherwise a ``ShortcutProjection`` is used.
    """

    def __init__(self, in_channels: int, out_channels: int, stride: int):
        """Build the residual branch and the matching shortcut.

        Args:
            in_channels: Number of channels of the input feature map.
            out_channels: Number of channels of the output feature map.
            stride: Stride of the first 3x3 convolution (and of the
                projection shortcut, when one is needed).
        """
        super().__init__()

        # Residual branch F(x): only the first convolution is strided.
        first_conv = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1)
        first_norm = nn.BatchNorm2d(out_channels)
        second_conv = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                                stride=1, padding=1)
        second_norm = nn.BatchNorm2d(out_channels)
        self.res = nn.Sequential(first_conv, first_norm, nn.ReLU(),
                                 second_conv, second_norm)

        # The identity is usable only if F(x) has the same shape as x.
        keeps_shape = stride == 1 and in_channels == out_channels
        self.shortcut = (
            nn.Identity()
            if keeps_shape
            else ShortcutProjection(in_channels, out_channels, stride)
        )

        # Activation applied after the two branches are summed.
        self.act = nn.ReLU()

    def forward(self, x: torch.Tensor):
        """Return ``ReLU(F(x) + shortcut(x))``."""
        skip = self.shortcut(x)
        residual = self.res(x)
        return self.act(residual + skip)

# Bottleneck Residual Block

In [None]:
class BottleneckResidualBlock(nn.Module):
    """Bottleneck residual block (1x1 -> 3x3 -> 1x1 convolutions).

    Computes ``ReLU(F(x) + shortcut(x))`` where ``F`` first maps the input to
    ``bottleneck_channels`` with a 1x1 convolution, applies a (possibly
    strided) 3x3 convolution at that width, and maps to ``out_channels``
    with a final 1x1 convolution.  Each convolution is followed by batch
    normalisation, and the first two also by ReLU.  The shortcut is the
    identity when ``stride == 1`` and the channel count is unchanged;
    otherwise a ``ShortcutProjection`` is used.
    """

    def __init__(self, in_channels: int, bottleneck_channels: int, out_channels: int, stride: int):
        """Build the bottleneck branch and the matching shortcut.

        Args:
            in_channels: Number of channels of the input feature map.
            bottleneck_channels: Channel width of the inner 3x3 convolution.
            out_channels: Number of channels of the output feature map.
            stride: Stride of the 3x3 convolution (and of the projection
                shortcut, when one is needed).
        """
        super().__init__()

        # 1x1 convolution that maps the input to the bottleneck width.
        reduce_conv = nn.Conv2d(in_channels, bottleneck_channels, kernel_size=1, stride=1)
        reduce_norm = nn.BatchNorm2d(bottleneck_channels)
        # 3x3 convolution at the bottleneck width; the only strided layer.
        spatial_conv = nn.Conv2d(bottleneck_channels, bottleneck_channels, kernel_size=3,
                                 stride=stride, padding=1)
        spatial_norm = nn.BatchNorm2d(bottleneck_channels)
        # 1x1 convolution that maps to the output width.
        expand_conv = nn.Conv2d(bottleneck_channels, out_channels, kernel_size=1, stride=1)
        expand_norm = nn.BatchNorm2d(out_channels)

        self.res = nn.Sequential(
            reduce_conv, reduce_norm, nn.ReLU(),
            spatial_conv, spatial_norm, nn.ReLU(),
            expand_conv, expand_norm,
        )

        # The identity is usable only if F(x) has the same shape as x.
        keeps_shape = stride == 1 and in_channels == out_channels
        self.shortcut = (
            nn.Identity()
            if keeps_shape
            else ShortcutProjection(in_channels, out_channels, stride)
        )

        # Activation applied after the two branches are summed.
        self.act = nn.ReLU()

    def forward(self, x: torch.Tensor):
        """Return ``ReLU(F(x) + shortcut(x))``."""
        skip = self.shortcut(x)
        residual = self.res(x)
        return self.act(residual + skip)

# ResNet Model

In [1]:
class ResNetBase(nn.Module):
    """ResNet feature extractor built from residual or bottleneck blocks.

    The network is a strided stem convolution with batch normalisation,
    followed by a sequence of stages.  Stage ``i`` contains ``n_blocks[i]``
    blocks that output ``n_channels[i]`` channels.  The block output is
    averaged over all remaining positions, giving one value per channel.
    """

    def __init__(self, n_blocks, n_channels, bottlenecks=None,
                 img_channels: int = 3, first_kernel_size: int = 7):
        """Assemble the stem and the stack of blocks.

        Args:
            n_blocks: Number of blocks in each stage.
            n_channels: Output channel count of each stage; must have the
                same length as ``n_blocks``.
            bottlenecks: Optional bottleneck widths, one per stage.  When
                given, ``BottleneckResidualBlock`` is used instead of
                ``ResidualBlock``.
            img_channels: Number of channels of the input image.
            first_kernel_size: Kernel size of the stem convolution; half of
                it (floor) is used as padding.

        Raises:
            AssertionError: If the per-stage lists differ in length.
        """
        super().__init__()
        # Block counts and channel counts must pair up one-to-one.
        assert len(n_blocks) == len(n_channels)
        assert bottlenecks is None or len(bottlenecks) == len(n_channels)

        # Stem: strided convolution to the first stage's width, then BN.
        stem_width = n_channels[0]
        self.conv = nn.Conv2d(img_channels, stem_width,
                              kernel_size=first_kernel_size,
                              stride=2,
                              padding=first_kernel_size // 2)
        self.bn = nn.BatchNorm2d(stem_width)

        def make_block(stage: int, width_in: int, width_out: int, stride: int):
            # Plain blocks unless bottleneck widths were supplied.
            if bottlenecks is None:
                return ResidualBlock(width_in, width_out, stride=stride)
            return BottleneckResidualBlock(width_in, bottlenecks[stage], width_out,
                                           stride=stride)

        layers = []
        width_in = stem_width
        for stage, width in enumerate(n_channels):
            # Only the very first block of the network is strided.
            # NOTE(review): the ResNet paper downsamples at the first block
            # of each later stage instead — confirm this rule is intended.
            lead_stride = 1 if layers else 2
            layers.append(make_block(stage, width_in, width, lead_stride))
            width_in = width

            # The remaining blocks of the stage keep width and resolution.
            layers.extend(make_block(stage, width, width, 1)
                          for _ in range(n_blocks[stage] - 1))

        self.blocks = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor):
        """Return per-channel averages of the final feature map.

        For a batched ``(N, C, H, W)`` input the result has shape
        ``(N, n_channels[-1])``.
        """
        features = self.blocks(self.bn(self.conv(x)))
        batch_size, channels = features.shape[0], features.shape[1]
        # Flatten everything after the channel axis, then average over it.
        flat = features.view(batch_size, channels, -1)
        return flat.mean(dim=-1)
