In [1]:
import torch
from torch import nn

# 1. 残差块

In [2]:
class Residual(nn.Module):
    """Basic two-convolution residual block, the smallest unit of the residual network.

    Computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by the block.
        stride: stride of the first 3x3 convolution and of the projection
            shortcut; stride 2 reduces the spatial size H to ceil(H / 2).
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        # Main path: 3x3 conv (possibly strided) -> BN -> ReLU -> 3x3 conv -> BN.
        # Submodules are created in this exact order so that parameter
        # initialization consumes the global RNG in the same sequence.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Shortcut path: identity (empty Sequential) when shapes already match,
        # otherwise a 1x1 projection + BN matching channels and spatial size.
        needs_projection = stride != 1 or in_channels != out_channels
        self.shortcut = (
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels),
            )
            if needs_projection
            else nn.Sequential()
        )

    def forward(self, x):
        """Return ``relu(F(x) + shortcut(x))``, where F is the conv-BN main path."""
        residual = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))
        return self.relu(residual + self.shortcut(x))

# 2. 残差模块

In [3]:
class ResnetBlock(nn.Module):
    """One stage of the residual network: ``num_res`` Residual blocks stacked in sequence.

    The first block maps ``in_channels`` -> ``out_channels``. Unless this is
    the network's first stage, it uses stride 2 to downsample spatially.
    All following blocks keep ``out_channels`` channels and stride 1.

    Args:
        in_channels: number of channels entering the stage.
        out_channels: number of channels produced by every block in the stage.
        num_res: number of Residual blocks to stack; must be >= 1.
        first_block: True for the network's first stage, whose first block
            uses stride 1 (no downsampling).

    Raises:
        ValueError: if ``num_res`` is less than 1.
    """

    def __init__(self, in_channels, out_channels, num_res, first_block=False):
        super().__init__()
        if num_res < 1:
            # An empty stage would silently act as the identity and skip the
            # in_channels -> out_channels change that the next stage expects.
            raise ValueError(f"num_res must be >= 1, got {num_res}")

        # Only the first block may change channel count / resolution.
        first_stride = 1 if first_block else 2
        layers = [Residual(in_channels, out_channels, stride=first_stride)]
        layers += [Residual(out_channels, out_channels, stride=1) for _ in range(num_res - 1)]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run the input through every Residual block of the stage in order."""
        return self.net(x)

# 3. 残差网络

In [4]:
class ResNet(nn.Module):
    """Residual network: input stem, four residual stages, and a classification head.

    Args:
        num_blocks: number of Residual blocks in each of the four stages, e.g.
            ``(2, 2, 2, 2)``. Any sequence of exactly four ints is accepted.
        num_classes: number of output classes (size of the final logits).
        in_channels: number of channels of the input images. Defaults to 1,
            the value previously hard-coded for grayscale input.

    Raises:
        ValueError: if ``num_blocks`` does not contain exactly four entries.
    """

    def __init__(self, num_blocks=(3, 4, 6, 3), num_classes=10, in_channels=1):
        super().__init__()
        # Tuple default avoids a shared mutable default; copy caller input too.
        num_blocks = tuple(num_blocks)
        if len(num_blocks) != 4:
            raise ValueError(
                f"num_blocks must contain exactly 4 stage sizes, got {len(num_blocks)}"
            )

        # Input stem: 7x7/2 conv -> BN -> ReLU -> 3x3/2 max-pool.
        self.input_layer = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(num_features=64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

        # Four residual stages with doubling widths. Only the first keeps the
        # stem's spatial resolution (first_block=True); the others downsample.
        stage_widths = (64, 128, 256, 512)
        stages = []
        stage_in = 64
        for idx, (width, depth) in enumerate(zip(stage_widths, num_blocks)):
            stages.append(ResnetBlock(in_channels=stage_in, out_channels=width,
                                      num_res=depth, first_block=(idx == 0)))
            stage_in = width
        self.resblock_layer = nn.Sequential(*stages)

        # Head: global average pooling -> flatten -> linear classifier.
        self.output_layer = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(start_dim=1, end_dim=-1),
            nn.Linear(in_features=stage_widths[-1], out_features=num_classes),
        )

    def forward(self, x):
        """Map a batch of images to logits of shape (batch, num_classes).

        Presumably x is (batch, in_channels, H, W) -- TODO confirm with callers.
        """
        x = self.input_layer(x)
        x = self.resblock_layer(x)
        return self.output_layer(x)

# 4. 数据读取

In [5]:
import numpy as np  # NOTE(review): np is not used in the visible cells
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

# Data-pipeline settings.
DATA_ROOT = './data'
BATCH_SIZE = 128
TRAIN_FRACTION = 0.9  # the remaining part of the MNIST training set is held out for validation
SPLIT_SEED = 42       # fixes which samples land in the validation split

# Preprocessing: resize every image to 224x224 and convert it to a tensor
# (ToTensor also scales pixel values to [0, 1]).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Download (if missing) and load MNIST.
full_train_dataset = datasets.MNIST(root=DATA_ROOT, train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root=DATA_ROOT, train=False, download=True, transform=transform)

# Hold out part of the training set for validation. A seeded generator makes
# the split identical on every run. Validation scores therefore stay comparable
# across re-runs, and re-running this cell without re-training cannot move
# already-trained samples into the validation split.
train_size = int(TRAIN_FRACTION * len(full_train_dataset))
val_size = len(full_train_dataset) - train_size
train_dataset, val_dataset = random_split(
    full_train_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# 5. 模型构建与训练/验证函数

In [6]:
import torch.optim as optim

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed the global torch RNG before building the model. This makes weight
# initialization reproducible across Restart & Run All. It also fixes
# DataLoader shuffling that does not use its own generator. Nondeterministic
# GPU kernels may still cause small run-to-run differences.
SEED = 0
torch.manual_seed(SEED)

mynet = ResNet(num_blocks=[2, 2, 2, 2], num_classes=10).to(device)

criterion = nn.CrossEntropyLoss()
# Plain SGD (momentum explicitly disabled) with a fixed learning rate.
LEARNING_RATE = 0.01
optimizer = optim.SGD(mynet.parameters(), lr=LEARNING_RATE, momentum=0.0)

def train(model, loader, optimizer, criterion, device):
    """Run one training epoch over ``loader`` and return (mean loss, accuracy).

    Args:
        model: network to train; it is switched to training mode by this call.
        loader: iterable of ``(data, target)`` batches (e.g. a DataLoader).
        optimizer: optimizer stepped once per batch.
        criterion: loss function; assumed to return the batch-mean loss
            (reduction='mean', the CrossEntropyLoss default).
        device: device each batch is moved to before the forward pass.

    Returns:
        ``(avg_loss, avg_acc)`` averaged over the samples actually iterated.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    total_loss, total_correct, num_samples = 0.0, 0, 0
    for data, target in loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        batch_size = data.size(0)
        # Re-weight the batch-mean loss by batch size to get a per-sample average.
        total_loss += loss.item() * batch_size
        total_correct += (output.argmax(1) == target).sum().item()
        num_samples += batch_size

    # Divide by the samples actually seen rather than len(loader.dataset).
    # The two differ with drop_last=True or a sampler over a subset.
    if num_samples == 0:
        raise ValueError("train(): loader yielded no samples")
    return total_loss / num_samples, total_correct / num_samples

def evaluate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: network to evaluate; it is switched to eval mode by this call.
        loader: iterable of ``(data, target)`` batches (e.g. a DataLoader).
        criterion: loss function; assumed to return the batch-mean loss
            (reduction='mean', the CrossEntropyLoss default).
        device: device each batch is moved to before the forward pass.

    Returns:
        ``(avg_loss, avg_acc)`` averaged over the samples actually iterated.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    total_loss, total_correct, num_samples = 0.0, 0, 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)

            batch_size = data.size(0)
            # Re-weight the batch-mean loss by batch size to get a per-sample average.
            total_loss += loss.item() * batch_size
            total_correct += (output.argmax(1) == target).sum().item()
            num_samples += batch_size

    # Divide by the samples actually seen rather than len(loader.dataset).
    # The two differ with drop_last=True or a sampler over a subset.
    if num_samples == 0:
        raise ValueError("evaluate(): loader yielded no samples")
    return total_loss / num_samples, total_correct / num_samples

# 6. 训练与验证

In [7]:
NUM_EPOCHS = 3

# Per-epoch metrics, kept so learning curves can be inspected or plotted later.
history = []
for epoch in range(NUM_EPOCHS):
    train_loss, train_acc = train(mynet, train_loader, optimizer, criterion, device)
    val_loss, val_acc = evaluate(mynet, val_loader, criterion, device)
    history.append({'epoch': epoch + 1,
                    'train_loss': train_loss, 'train_acc': train_acc,
                    'val_loss': val_loss, 'val_acc': val_acc})
    print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f}, Train Acc={train_acc:.4f}, Val Loss={val_loss:.4f}, Val Acc={val_acc:.4f}")

Epoch 1: Train Loss=0.3896, Train Acc=0.9046, Val Loss=0.1039, Val Acc=0.9733
Epoch 2: Train Loss=0.0632, Train Acc=0.9837, Val Loss=0.0587, Val Acc=0.9847
Epoch 3: Train Loss=0.0387, Train Acc=0.9901, Val Loss=0.0425, Val Acc=0.9867


# 7. 测试集评估

In [8]:
# Final, held-out evaluation of the trained network on the MNIST test split.
test_loss, test_acc = evaluate(
    model=mynet,
    loader=test_loader,
    criterion=criterion,
    device=device,
)
print(f"Test Loss={test_loss:.4f}, Test Acc={test_acc:.4f}")

Test Loss=0.0390, Test Acc=0.9884
