In [4]:
import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l

### 自定义残差块（Residual）的定义与使用
[ResNet 残差块结构说明](https://zh.d2l.ai/chapter_convolutional-modern/resnet.html "resnet结构链接")

In [5]:
class Residual(nn.Module):  #@save
    """ResNet residual block: two 3x3 conv + BatchNorm layers plus a shortcut.

    Main path: conv1 (3x3, padding 1, stride=strides) -> bn1 -> ReLU
               -> conv2 (3x3, padding 1) -> bn2.
    Shortcut:  identity, or a 1x1 conv (stride=strides) when ``use_1x1conv``.
    The two paths are summed and passed through a final ReLU.

    Args:
        input_channels: number of channels of the input tensor.
        num_channels: number of output channels of every conv in the block.
        use_1x1conv: if True, project the shortcut with a 1x1 conv so its
            channel count and spatial size match the main path. If False,
            the input is added as-is, so the main-path output must be
            addable to it (typically input_channels == num_channels and
            strides == 1).
        strides: stride of conv1 and of the optional 1x1 shortcut conv;
            a value of 2 roughly halves height and width.
    """

    def __init__(self, input_channels, num_channels,
                 use_1x1conv=False, strides=1):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, num_channels,
                               kernel_size=3, padding=1, stride=strides)
        self.conv2 = nn.Conv2d(num_channels, num_channels,
                               kernel_size=3, padding=1)
        if use_1x1conv:
            # 1x1 conv on the shortcut: changes the channel count and, via the
            # stride, the spatial size so it can be added to the main path.
            self.conv3 = nn.Conv2d(input_channels, num_channels,
                                   kernel_size=1, stride=strides)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm2d(num_channels)
        self.bn2 = nn.BatchNorm2d(num_channels)

    def forward(self, X):
        """Return ReLU(main_path(X) + shortcut(X))."""
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        # Explicit None check instead of relying on nn.Module truthiness.
        if self.conv3 is not None:
            X = self.conv3(X)
        Y += X
        return F.relu(Y)

### 形状检查：先看输入与输出形状一致的情况，再看用 1×1 卷积同时改变通道数并减半高宽的情况

In [6]:
blk = Residual(3, 3)  # input channels = 3, output channels = 3 (identity shortcut)
X = torch.rand(4, 3, 6, 6)  # random batch: 4 samples, 3 channels, 6x6 height/width
Y = blk(X)
assert Y.shape == X.shape  # same channels and stride 1 -> shape is preserved
Y.shape

torch.Size([4, 3, 6, 6])

In [7]:
# 1x1-conv shortcut: 3 -> 6 channels; stride 2 roughly halves height and width.
blk = Residual(input_channels=3, num_channels=6,
               use_1x1conv=True, strides=2)
blk(X).size()

torch.Size([4, 6, 3, 3])