In [17]:
from d2l import torch as d2l
import torch
import torch.nn as nn
from torch.nn import functional as F

In [20]:
class ResidualBlock(nn.Module):
    """ResNet-style residual block: two 3x3 conv + BatchNorm layers plus a shortcut.

    Main path: conv3x3(stride) -> BN -> ReLU -> conv3x3 -> BN.
    Shortcut: the input itself (identity), or, when ``use_1x1conv`` is True,
    a 1x1 convolution with the same stride. The two paths are summed and passed
    through a final ReLU.

    Args:
        in_channels: number of channels in the input.
        out_channels: number of channels produced by both 3x3 convs (and by the
            1x1 shortcut conv, if enabled).
        use_1x1conv: if True, project the shortcut with a 1x1 conv so that its
            channel count and spatial size match the main path.
        stride: stride of the first 3x3 conv and of the 1x1 shortcut conv.
        **kwargs: forwarded to ``nn.Module.__init__``.

    Note:
        With ``use_1x1conv=False`` the identity shortcut is added directly to the
        main-path output. Use it only with ``in_channels == out_channels`` and
        ``stride == 1``. Otherwise the shapes in ``forward`` generally will not
        match.
    """

    def __init__(self, in_channels, out_channels, use_1x1conv=False, stride=1, **kwargs):
        super().__init__(**kwargs)
        # padding=1 with a 3x3 kernel keeps H/W unchanged at stride 1.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        if use_1x1conv:
            # Projection shortcut: matches the main path's channels and downsampling.
            self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, X):
        """Apply the block to ``X``, a batch of images (e.g. (N, C, H, W) as used below)."""
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        # Explicit None check: truth-testing a Module is fragile. Modules that
        # define __len__ (e.g. nn.Sequential / nn.ModuleList) are falsy when
        # empty, which would silently skip a configured shortcut.
        if self.conv3 is not None:
            X = self.conv3(X)
        return F.relu(Y + X)

In [21]:
# Identity shortcut: same channel count and stride 1, so the output shape
# should equal the input shape.
blk = ResidualBlock(in_channels=3, out_channels=3)
X = torch.rand(4, 3, 6, 6)
blk(X).shape

torch.Size([4, 3, 6, 6])

In [22]:
# Projection shortcut: the 1x1 conv maps 3 -> 6 channels, and stride 2 halves
# the 6x6 spatial size to 3x3.
blk = ResidualBlock(in_channels=3, out_channels=6, use_1x1conv=True, stride=2)
blk(X).shape

torch.Size([4, 6, 3, 3])

tensor([[[[0.3192, 0.7543, 0.7524, 0.1057, 0.6375, 0.7714],
          [0.2758, 0.3655, 0.1091, 0.9191, 0.9257, 0.5517],
          [0.8287, 0.7907, 0.5358, 0.8027, 0.1735, 0.5176],
          [0.3264, 0.0596, 0.3561, 0.6728, 0.2772, 0.0255],
          [0.9977, 0.1335, 0.9940, 0.7991, 0.3089, 0.8546],
          [0.7190, 0.0806, 0.2935, 0.1088, 0.5587, 0.6426]],

         [[0.0505, 0.5286, 0.6906, 0.9192, 0.0605, 0.0035],
          [0.5606, 0.8227, 0.4698, 0.5525, 0.8571, 0.0031],
          [0.4166, 0.8401, 0.4291, 0.4696, 0.6857, 0.1742],
          [0.1667, 0.6807, 0.9049, 0.9334, 0.7008, 0.6355],
          [0.6085, 0.1185, 0.0343, 0.0119, 0.9306, 0.5655],
          [0.5955, 0.6872, 0.8875, 0.4815, 0.2036, 0.9485]],

         [[0.6015, 0.7476, 0.8075, 0.7672, 0.1897, 0.1044],
          [0.8670, 0.3114, 0.2989, 0.3729, 0.7604, 0.5015],
          [0.8457, 0.3998, 0.4948, 0.2822, 0.7239, 0.4228],
          [0.2850, 0.5908, 0.8189, 0.3541, 0.0798, 0.1406],
          [0.5396, 0.5243, 0.4080, 0