### 残差连接
  - 是一种将输入直接添加到输出的机制
  - 缓解深层网络中梯度消失的问题
  - 使信息可以在不经过所有层的情况下传递
  - 提高深层神经网络的训练效率

In [17]:
import torch
import torch.nn as nn

torch.manual_seed(42) # 设置随机种子，保证每次运行结果一致

<torch._C.Generator at 0x7c350855ae70>

In [18]:
# 定义残差块
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

        self.relu = nn.ReLU(inplace=True)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += identity
        out = self.relu(out)
        return out

In [19]:
# 构建网络模型，包含多个残差块
class ResNet(nn.Module):
    def __init__(self, num_classes=10):
        super(ResNet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )

        self.layer2 = ResidualBlock(64, 64)
        self.layer3 = ResidualBlock(64, 128, stride=2, downsample=nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False),
            nn.BatchNorm2d(128)
        ))
        self.layer4 = ResidualBlock(128, 128)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

In [20]:
# 模拟输入
input_data = torch.randn(1, 3, 32, 32)  # 假设输入是32x32的RGB图像
model = ResNet(num_classes=10)

In [21]:
# 前向传播
output = model(input_data)
print(output)  # 输出模型的预测结果

tensor([[-0.0362, -0.4174, -0.8190,  0.2587,  1.0380,  0.6910,  0.5373, -0.6552,
         -0.0163,  0.0168]], grad_fn=<AddmmBackward0>)


In [22]:
# 损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [23]:
# 模拟训练步骤
target = torch.tensor([3])  # 假设真实标签是类别3
optimizer.zero_grad()
loss = criterion(output, target)
loss.backward()
optimizer.step()
print(f"Loss: {loss.item()}")

Loss: 2.259821891784668


In [24]:
# 参数更新情况
for name, param in model.layer2.named_parameters():
    if param.requires_grad:
        print(f"{name}: {param.grad}")

conv1.weight: tensor([[[[-8.1463e-03,  1.7736e-03,  4.1157e-03],
          [ 1.8000e-02,  1.5645e-03, -1.0918e-02],
          [ 3.9549e-03,  2.2518e-03,  6.2204e-04]],

         [[-6.8123e-03,  8.7411e-03,  9.3745e-03],
          [ 2.4924e-03,  6.8948e-03, -1.3253e-03],
          [-6.5935e-04, -2.4736e-03,  8.1399e-03]],

         [[-2.7409e-03,  1.7637e-03,  7.1153e-03],
          [-1.1372e-02,  2.9999e-03,  1.3397e-02],
          [-1.5648e-03, -4.6558e-03,  1.8985e-03]],

         ...,

         [[-1.0194e-03,  3.2688e-03,  5.0138e-03],
          [ 1.6567e-03, -6.7017e-03, -2.5416e-03],
          [ 1.6570e-03, -7.6416e-03,  4.3841e-03]],

         [[-4.3420e-03,  1.5501e-03, -1.7443e-03],
          [-1.2266e-02,  4.8938e-03,  4.2349e-03],
          [ 7.1624e-03, -1.1381e-03, -2.2394e-03]],

         [[-9.7668e-03,  8.9193e-03, -2.8143e-03],
          [ 9.9086e-03, -1.2890e-03, -3.7529e-03],
          [-8.8053e-03, -8.9718e-04,  9.8506e-03]]],


        [[[-7.0682e-03,  7.5782e-03,  1