In [1]:
#导入模块与超参数设置
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import os

# 训练超参数
batch_size = 64
epochs = 100
lr = 0.001
momentum = 0.9
save_dir = "results/task2"

os.makedirs(save_dir, exist_ok=True)

In [2]:
#数据预处理
from datasets import CIFAR10Dataset

train_dataset = CIFAR10Dataset(root_dir='./DS/CIFAR10', train=True)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)

test_dataset = CIFAR10Dataset(root_dir='./DS/CIFAR10', train=False)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)

print(f"训练集: {len(train_dataset)} 张图片")
print(f"测试集: {len(test_dataset)} 张图片")
print(f"类别: {train_dataset.classes}")

训练集: 50000 张图片
测试集: 10000 张图片
类别: ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


In [3]:
#定义模型

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # CIFAR10 输入: (batch, 3, 32, 32)
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv6 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)

        self.fc1 = nn.Linear(128 * 4 * 4, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)

        self.dropout_conv = nn.Dropout(0.25)
        self.dropout_fc = nn.Dropout(0.5)
    
    def forward(self, x):
        # 第一组: 32x32 -> 16x16
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = self.dropout_conv(x)

        # 第二组: 16x16 -> 8x8
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = self.pool(x)
        x = self.dropout_conv(x)

        # 第三组: 8x8 -> 4x4
        x = F.relu(self.conv5(x))
        x = F.relu(self.conv6(x))
        x = self.pool(x)
        x = self.dropout_conv(x)

        # 全连接层
        x = x.view(-1, 128 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = self.dropout_fc(x)
        x = F.relu(self.fc2(x))
        x = self.dropout_fc(x)
        x = self.fc3(x)
        return x

#准备模型、Loss、优化器
model = Net()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

In [4]:
#开始训练
print(f"开始训练！批次大小：{batch_size}, 共 {epochs} 轮\n")

for epoch in range(epochs):
    model.train()
    epoch_loss = 0
    correct = 0
    total = 0

    for i, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    # 计算平均指标
    avg_loss = epoch_loss / len(train_loader)
    train_acc = 100 * correct / total

    # 测试集评估
    model.eval()
    test_correct = 0
    test_total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            test_total += labels.size(0)
            test_correct += (predicted == labels).sum().item()
    test_acc = 100 * test_correct / test_total

    print(f"Epoch [{epoch+1:3d}/{epochs}] Loss: {avg_loss:.5f} | Train Acc: {train_acc:.2f}% | Test Acc: {test_acc:.2f}%")

#保存模型权重
torch.save(model.state_dict(), f"{save_dir}/model_epoch_{epochs}.pth")
print(f"\n训练完成 :)")
print(f"模型权重已保存：{save_dir}/model_epoch_{epochs}.pth")

开始训练！批次大小：64, 共 100 轮

Epoch [  1/100] Loss: 2.30324 | Train Acc: 9.71% | Test Acc: 10.00%
Epoch [  2/100] Loss: 2.30268 | Train Acc: 9.97% | Test Acc: 11.57%
Epoch [  3/100] Loss: 2.30245 | Train Acc: 10.23% | Test Acc: 10.00%
Epoch [  4/100] Loss: 2.30196 | Train Acc: 10.65% | Test Acc: 13.62%
Epoch [  5/100] Loss: 2.29566 | Train Acc: 12.38% | Test Acc: 12.13%
Epoch [  6/100] Loss: 2.21499 | Train Acc: 17.61% | Test Acc: 23.82%
Epoch [  7/100] Loss: 2.10994 | Train Acc: 23.20% | Test Acc: 26.63%
Epoch [  8/100] Loss: 2.03112 | Train Acc: 25.06% | Test Acc: 28.21%
Epoch [  9/100] Loss: 1.93735 | Train Acc: 26.72% | Test Acc: 31.50%
Epoch [ 10/100] Loss: 1.85568 | Train Acc: 28.90% | Test Acc: 32.88%
Epoch [ 11/100] Loss: 1.79688 | Train Acc: 31.32% | Test Acc: 35.13%
Epoch [ 12/100] Loss: 1.73280 | Train Acc: 34.25% | Test Acc: 39.66%
Epoch [ 13/100] Loss: 1.67069 | Train Acc: 37.33% | Test Acc: 41.63%
Epoch [ 14/100] Loss: 1.62774 | Train Acc: 39.03% | Test Acc: 43.95%
Epoch [ 15/10

In [5]:
#可视化
import random
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np

model.eval()
correct_imgs = []  # (img_tensor, true_label, pred_label)
wrong_imgs = []

classes = train_dataset.classes

# CIFAR10 归一化参数（用于反归一化）
cifar10_mean = torch.tensor([0.4914, 0.4822, 0.4465]).view(3, 1, 1)
cifar10_std = torch.tensor([0.2023, 0.1994, 0.2010]).view(3, 1, 1)

with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)

        for i in range(inputs.size(0)):
            # 反归一化还原原图
            img = inputs[i].cpu() * cifar10_std + cifar10_mean
            img = torch.clamp(img, 0, 1)
            true_label = labels[i].item()
            pred_label = predicted[i].item()

            if pred_label == true_label and len(correct_imgs) < 50:
                correct_imgs.append((img, true_label, pred_label))
            elif pred_label != true_label and len(wrong_imgs) < 50:
                wrong_imgs.append((img, true_label, pred_label))

        if len(correct_imgs) >= 50 and len(wrong_imgs) >= 50:
            break

# 随机选取5张
random.seed(42)
correct_samples = random.sample(correct_imgs, min(5, len(correct_imgs)))
wrong_samples = random.sample(wrong_imgs, min(5, len(wrong_imgs)))

# 保存可视化结果：2行5列，每个子图单独放大渲染
fig, axes = plt.subplots(2, 5, figsize=(20, 9), dpi=300)

# 第一行：预测成功
for i, (img, true_label, pred_label) in enumerate(correct_samples):
    axes[0, i].imshow(img.permute(1, 2, 0).numpy(), interpolation='nearest')
    axes[0, i].set_title(f'{classes[true_label]}', color='green', fontsize=14, fontweight='bold')
    axes[0, i].axis('off')
axes[0, 0].set_ylabel('Correct', fontsize=16, color='green', fontweight='bold')

# 第二行：预测失败
for i, (img, true_label, pred_label) in enumerate(wrong_samples):
    axes[1, i].imshow(img.permute(1, 2, 0).numpy(), interpolation='nearest')
    axes[1, i].set_title(f'Pred: {classes[pred_label]}\nReal: {classes[true_label]}', color='red', fontsize=11, fontweight='bold')
    axes[1, i].axis('off')
axes[1, 0].set_ylabel('Wrong', fontsize=16, color='red', fontweight='bold')

plt.suptitle('CIFAR-10 Predictions', fontsize=20, fontweight='bold')
plt.tight_layout()
plt.savefig(f'{save_dir}/predictions.png', dpi=300, bbox_inches='tight')
plt.show()
print(f"可视化结果已保存：{save_dir}/predictions.png")

可视化结果已保存：results/task2/predictions.png


  plt.show()
