In [1]:
import torch
from torch import nn

In [2]:
class IdentityShortcut(nn.Module):
    """Residual block whose shortcut is the identity (basic form).

    Computes ``conv(x) + x``. The 3x3 convolution uses stride 1 and padding 1,
    so the spatial size is preserved. It maps ``channels -> channels``, so the
    element-wise addition with the unmodified input is well-defined and the
    output shape equals the input shape.

    Args:
        channels: Number of input (and output) channels. Defaults to 3, the
            value this block previously hard-coded.
    """

    def __init__(self, channels: int = 3):
        super().__init__()
        # Feature extractor. Equal in/out channels plus "same" padding keep the
        # output shape identical to the input, as the identity add requires.
        self.conv = nn.Conv2d(in_channels=channels, out_channels=channels,
                              kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Return ``conv(x) + x``.

        Args:
            x: Input tensor with ``channels`` channels; presumably batched
                (N, channels, H, W).

        Returns:
            Tensor with the same shape as ``x``.
        """
        return self.conv(x) + x

In [3]:
# Simulate an input batch: 1 image, 3 channels, 64x64.
# Seed first so the random input and the conv initialisation are reproducible.
torch.manual_seed(0)
x_identity = torch.randn(1, 3, 64, 64)
# Instantiate the block
identity_sc = IdentityShortcut()
# Forward pass
y_identity = identity_sc(x_identity)
# The output shape equals the input shape: preserving shape is the defining
# property of an identity shortcut.
assert y_identity.shape == x_identity.shape
print(y_identity.shape)

torch.Size([1, 3, 64, 64])


In [4]:
class ProjectionShortcut(nn.Module):
    """Residual block whose shortcut is a learned 1x1 projection.

    Used when the main branch changes the tensor shape (channel count and/or
    spatial size), so the input cannot be added back unchanged. The shortcut is
    a 1x1 convolution with the same stride as the main 3x3 convolution. For any
    input size, both branches therefore produce outputs of identical shape.

    Args:
        in_channels: Input channels. Defaults to 3.
        out_channels: Output channels. Defaults to 18.
        stride: Stride shared by the main conv and the shortcut conv.
            Defaults to 2.
    """

    def __init__(self, in_channels: int = 3, out_channels: int = 18, stride: int = 2):
        super().__init__()
        # Main branch: ordinary feature-extracting convolution.
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                              kernel_size=3, stride=stride, padding=1)
        # Shortcut: 1x1 conv that only adjusts channels/resolution so it can be
        # added to the main branch.
        self.shortcut = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                  kernel_size=1, stride=stride, padding=0)

    def forward(self, x):
        """Return ``conv(x) + shortcut(x)`` (main branch plus projected input).

        Args:
            x: Input tensor with ``in_channels`` channels; presumably batched
                (N, in_channels, H, W).

        Returns:
            Tensor with ``out_channels`` channels, spatially downsampled by
            ``stride``.
        """
        return self.conv(x) + self.shortcut(x)

In [5]:
# Simulate an input batch: 1 image, 3 channels, 64x64.
# Seed first so the random input and the conv initialisation are reproducible.
torch.manual_seed(0)
x_projection = torch.randn(1, 3, 64, 64)
# Instantiate the block
projection_sc = ProjectionShortcut()
# Forward pass
y_projection = projection_sc(x_projection)
# The shape changes: channels 3 -> 18 and, because stride=2, spatial size
# 64 -> 32. A change in shape is what makes a projection shortcut necessary.
assert y_projection.shape == (1, 18, 32, 32)
print(y_projection.shape)

torch.Size([1, 18, 32, 32])


### 基本 ResBlock - 整合以上两种情况

In [6]:
import torch.nn as nn
import torch

class BasicResBlock(nn.Module):
    """Basic two-convolution ResNet block covering both shortcut cases.

    Main branch: conv3x3(stride) -> BN -> ReLU -> conv3x3(stride 1) -> BN.
    Shortcut: identity when the shape is unchanged (``stride == 1`` and
    ``in_channels == out_channels``). Otherwise it is a 1x1 conv (same stride)
    followed by BN, which projects the input to the main branch's output shape.
    A final ReLU is applied after the two branches are added.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output tensor.
        stride: Stride of the first 3x3 conv and of the projection shortcut.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # bias=False: the BatchNorm that follows removes any per-channel
        # offset, so a conv bias is redundant. This is consistent with conv2
        # and the projection conv below.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        if stride != 1 or in_channels != out_channels:
            # Shape changes: project the input so it can be added to the main branch.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            # Shape unchanged: pass the input through untouched.
            self.shortcut = nn.Identity()

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``.

        Args:
            x: Input tensor with ``in_channels`` channels; presumably batched
                (N, in_channels, H, W).

        Returns:
            Non-negative tensor with ``out_channels`` channels, spatially
            downsampled by ``stride``.
        """
        # torch.relu avoids constructing a new nn.ReLU module on every call.
        out = torch.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return torch.relu(out + self.shortcut(x))


if __name__ == "__main__":
    # Demo: 3 -> 16 channels with stride 1. The channel change selects the
    # projection shortcut, and stride 1 keeps the 32x32 spatial size.
    torch.manual_seed(0)
    basic_block = BasicResBlock(in_channels=3, out_channels=16, stride=1)
    # Batch of 4 three-channel 32x32 inputs.
    x_basic = torch.randn(4, 3, 32, 32)
    y_basic = basic_block(x_basic)
    assert y_basic.shape == (4, 16, 32, 32)
    print("输出shape:", y_basic.shape)

输出shape: torch.Size([4, 16, 32, 32])
