In [1]:
import os
import torch
import torchvision
from torch import nn
from torch.nn import  functional as F
from d2l import torch as d2l

In [2]:
class Residual(nn.Module):
    """Two 3x3 conv + batch-norm layers with an identity or 1x1-projected shortcut.

    Computes ``relu(bn2(conv2(relu(bn1(conv1(X))))) + shortcut(X))``, where
    ``shortcut`` is the identity, or ``conv3`` (a 1x1 convolution) when
    ``use_1x1conv`` is True.

    Args:
        input_channels: number of channels in the input tensor.
        num_channels: number of channels produced by the block.
        use_1x1conv: put a 1x1 convolution on the shortcut. Required whenever
            ``input_channels != num_channels`` or ``strides`` downsamples;
            otherwise the residual addition sees mismatched shapes.
        strides: stride of ``conv1`` (and of the shortcut conv, if present).
    """

    def __init__(self, input_channels, num_channels, use_1x1conv=False, strides=(1, 1)):
        super().__init__()
        # Creation order (conv1, conv2, conv3, bn1, bn2) fixes both the order
        # in which conv weights draw from the RNG and the state_dict key order.
        self.conv1 = nn.Conv2d(input_channels, num_channels,
                               kernel_size=3, padding=1, stride=strides)
        self.conv2 = nn.Conv2d(num_channels, num_channels,
                               kernel_size=3, padding=1)
        # None (not a module) when no projection is requested, so it is not
        # registered as a child and does not appear in print(net).
        self.conv3 = (nn.Conv2d(input_channels, num_channels, kernel_size=1, stride=strides)
                      if use_1x1conv else None)
        self.bn1 = nn.BatchNorm2d(num_channels)
        self.bn2 = nn.BatchNorm2d(num_channels)

    def forward(self, X):
        """Apply the block to ``X`` (assumed NCHW with ``input_channels`` channels)."""
        out = F.relu(self.bn1(self.conv1(X)))
        out = self.bn2(self.conv2(out))
        shortcut = X if self.conv3 is None else self.conv3(X)
        return F.relu(out + shortcut)

In [3]:
# Sanity check: a block with a 1x1 projection shortcut and the default stride
# (1, 1) keeps the spatial size (512 x 6) and maps 1 input channel to 3.
blk = Residual(1, 3, use_1x1conv=True)
X = torch.rand(4, 1, 512, 6)
Y = blk(X)
assert Y.shape == (4, 3, 512, 6), f"unexpected output shape {tuple(Y.shape)}"
print(Y.shape)

torch.Size([4, 3, 512, 6])


In [4]:
def resnet_block(input_channels, num_channels, num_residuals, first_block=False,
                 strides=(2, 1)):
    """Build one ResNet stage as a list of ``Residual`` blocks.

    Args:
        input_channels: channels of the tensor entering the stage.
        num_channels: channels produced by every block in the stage.
        num_residuals: number of ``Residual`` blocks to stack (0 gives ``[]``).
        first_block: if True, the stage does not downsample (its first block
            uses stride 1).
        strides: stride used by the first block of a downsampling stage.
            The default ``(2, 1)`` halves the height and keeps the width.

    Returns:
        A list of ``Residual`` modules, meant to be unpacked into ``nn.Sequential``.
    """
    blk = []
    for i in range(num_residuals):
        if i == 0 and not first_block:
            # Downsampling + channel change: the 1x1 projection keeps the
            # shortcut shape-compatible with the main path.
            blk.append(Residual(input_channels, num_channels,
                                use_1x1conv=True, strides=strides))
        elif i == 0 and input_channels != num_channels:
            # Non-downsampling first stage that still changes the channel
            # count: a plain Residual(num_channels, num_channels) would expect
            # num_channels inputs and fail on an input_channels-wide tensor.
            blk.append(Residual(input_channels, num_channels, use_1x1conv=True))
        else:
            blk.append(Residual(num_channels, num_channels))
    return blk

In [5]:
# Width of the final Linear layer — presumably the number of target classes; confirm.
NUM_CLASSES = 11

# Stem: a 1x1 conv (stride 1, no padding) lifts the single input channel to
# 64 feature maps without changing the spatial size.
b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=1, stride=1), nn.BatchNorm2d(64), nn.ReLU())
# Residual stages. b2 keeps the resolution (first_block=True); b3-b5 each
# downsample the height via resnet_block's default stride (2, 1) while
# doubling the channel count.
b2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))
b3 = nn.Sequential(*resnet_block(64, 128, 2))
b4 = nn.Sequential(*resnet_block(128, 256, 2))
b5 = nn.Sequential(*resnet_block(256, 512, 2))

net = nn.Sequential(
    b1, b2, b3, b4, b5,
    nn.AdaptiveAvgPool2d((1, 1)),  # global average pool -> (batch, 512, 1, 1)
    nn.Flatten(),                  # -> (batch, 512)
    nn.Linear(512, NUM_CLASSES),
)

In [6]:
# Show the full module tree: stage nesting, channel counts, kernel sizes and strides.
print(str(net))

Sequential(
  (0): Sequential(
    (0): Conv2d(1, 64, kernel_size=(1, 1), stride=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (1): Sequential(
    (0): Residual(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): Residual(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (2): Sequential(
    (0): Resid

In [7]:
# Trace a dummy batch through each top-level stage to check output shapes.
# In train mode every BatchNorm layer would fold this random batch into its
# running statistics, so the probe runs under eval() + no_grad() (no autograd
# graph either). The previous train/eval mode is restored afterwards.
X = torch.rand(size=(32, 1, 512, 6))
was_training = net.training
net.eval()
try:
    with torch.no_grad():
        for layer in net:
            X = layer(X)
            print(layer.__class__.__name__, "output shape:\t", X.shape)
finally:
    net.train(was_training)

Sequential
Sequential output shape:	 torch.Size([32, 64, 512, 6])
Sequential
Sequential output shape:	 torch.Size([32, 64, 512, 6])
Sequential
Sequential output shape:	 torch.Size([32, 128, 256, 6])
Sequential
Sequential output shape:	 torch.Size([32, 256, 128, 6])
Sequential
Sequential output shape:	 torch.Size([32, 512, 64, 6])
AdaptiveAvgPool2d
AdaptiveAvgPool2d output shape:	 torch.Size([32, 512, 1, 1])
Flatten
Flatten output shape:	 torch.Size([32, 512])
Linear
Linear output shape:	 torch.Size([32, 11])
