In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

print(f"PyTorch 版本: {torch.__version__}")
print(f"torchvision 版本: {torchvision.__version__}")

# Fix the global torch RNG seed so that DataLoader shuffling and model weight
# initialisation in later cells are reproducible under Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Data transform: ToTensor converts each image to a float tensor in [0, 1];
# Normalize((0.5,), (0.5,)) then applies (x - 0.5) / 0.5, mapping pixels to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Mini-batch size shared by the training and test loaders.
BATCH_SIZE = 64

# Load MNIST. With download=True the files are fetched into DATA_ROOT only if
# they are not already present there.
DATA_ROOT = '../foundation/data'
NUM_WORKERS = 2  # background worker processes per DataLoader

train_dataset = torchvision.datasets.MNIST(
    root=DATA_ROOT, train=True, transform=transform, download=True
)
test_dataset = torchvision.datasets.MNIST(
    root=DATA_ROOT, train=False, transform=transform, download=True
)

# Data loaders: shuffle only the training set; the test set keeps a fixed order.
train_loader = DataLoader(
    dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS
)
test_loader = DataLoader(
    dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS
)

print("数据准备完毕。")
print(f"训练集批次数: {len(train_loader)}, 测试集批次数: {len(test_loader)}")

# Sanity-check one training batch: single-channel 28x28 images, one label each.
inputs, labels = next(iter(train_loader))
print(f"一个批次的图像形状: {inputs.shape}")
print(f"一个批次的标签形状: {labels.shape}")
assert inputs.shape[1:] == (1, 28, 28), "unexpected image shape"
assert labels.shape == inputs.shape[:1], "images/labels batch size mismatch"

PyTorch 版本: 2.5.1
torchvision 版本: 0.20.1
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ../foundation/data/MNIST/raw/train-images-idx3-ubyte.gz


100.0%


Extracting ../foundation/data/MNIST/raw/train-images-idx3-ubyte.gz to ../foundation/data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ../foundation/data/MNIST/raw/train-labels-idx1-ubyte.gz


100.0%


Extracting ../foundation/data/MNIST/raw/train-labels-idx1-ubyte.gz to ../foundation/data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ../foundation/data/MNIST/raw/t10k-images-idx3-ubyte.gz


100.0%


Extracting ../foundation/data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../foundation/data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ../foundation/data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100.0%

Extracting ../foundation/data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../foundation/data/MNIST/raw

数据准备完毕。
训练集批次数: 938, 测试集批次数: 157





一个批次的图像形状: torch.Size([64, 1, 28, 28])
一个批次的标签形状: torch.Size([64])


# 模型，损失函数和优化器

In [None]:
# Select the compute device: prefer an NVIDIA GPU (CUDA), then Apple Silicon
# (MPS), and fall back to the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("使用 CUDA GPU")
elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
    device = torch.device("mps")
    print("使用 Apple Silicon GPU (MPS)")
else:
    device = torch.device("cpu")
    print("使用 CPU")

# Model f(x; theta): a small convolutional neural network.
class SimpleCNN(nn.Module):
    """Two conv/ReLU/max-pool stages followed by a two-layer MLP classifier.

    Sized for 28x28 inputs of shape (N, in_channels, 28, 28): the two 2x2
    max-pools shrink the spatial size 28 -> 14 -> 7, which fixes the classifier's
    input width at 32 * 7 * 7. The forward pass returns raw logits of shape
    (N, num_classes).

    Args:
        num_classes: number of output classes (default 10, the MNIST digits).
        in_channels: number of input image channels (default 1, grayscale).
    """

    def __init__(self, num_classes: int = 10, in_channels: int = 1):
        super().__init__()
        self.conv_stack = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 28x28 -> 14x14
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 14x14 -> 7x7
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(32 * 7 * 7, 128),  # 32*7*7 = 1568
            nn.ReLU(),
            nn.Linear(128, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return logits (no softmax applied); pair with nn.CrossEntropyLoss."""
        x = self.conv_stack(x)
        return self.classifier(x)

# Instantiate the model on the selected device, plus the loss and optimizer.
LEARNING_RATE = 1e-3  # Adam step size

model = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss()  # multi-class cross-entropy; consumes raw logits
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

print(f"模型,损失函数和优化器已定义。模型已在 {device} 上。")

使用 Apple Silicon GPU (MPS)
“三位一体”已定义。模型已在 mps 上。


# 训练与验证

In [None]:
# Number of full passes over the training set.
NUM_EPOCHS = 5
# Print the mean training loss every LOG_INTERVAL batches (and at epoch end).
LOG_INTERVAL = 200

print(f"--- 开始训练，共 {NUM_EPOCHS} 个周期 ---")

for epoch in range(NUM_EPOCHS):

    # --- Training phase ---
    model.train()
    running_loss = 0.0
    window_batches = 0  # batches accumulated into running_loss since the last report

    for i, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)

        # The five training steps
        optimizer.zero_grad()              # 1. clear accumulated gradients
        outputs = model(inputs)            # 2. forward pass
        loss = criterion(outputs, labels)  # 3. compute the loss
        loss.backward()                    # 4. backpropagate
        optimizer.step()                   # 5. update parameters

        running_loss += loss.item()
        window_batches += 1
        # Divide by the actual window size, and also flush the final (possibly
        # shorter) window so the tail of each epoch is not dropped from the log.
        if (i + 1) % LOG_INTERVAL == 0 or (i + 1) == len(train_loader):
            print(f'Epoch [{epoch+1}/{NUM_EPOCHS}], 批次 [{i+1}/{len(train_loader)}], 损失: {running_loss / window_batches:.4f}')
            running_loss = 0.0
            window_batches = 0

    # --- Evaluation phase ---
    # NOTE(review): this evaluates on the *test* split every epoch. If these
    # numbers drive any tuning/early-stopping decision, hold out a separate
    # validation split so the test accuracy stays an unbiased estimate.
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():  # gradients are not needed for evaluation
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            predicted = outputs.argmax(dim=1)  # index of the largest logit per sample
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Guard against an empty test loader instead of raising ZeroDivisionError.
    accuracy = 100 * correct / total if total else float('nan')
    print(f'--- Epoch {epoch+1} 验证完毕 ---')
    print(f'测试集准确率: {accuracy:.2f} %')
    print('---------------------------------')

print('--- 训练完成 ---')

--- 开始训练，共 5 个周期 ---
Epoch [1/5], 批次 [200/938], 损失: 0.5381
Epoch [1/5], 批次 [400/938], 损失: 0.1353
Epoch [1/5], 批次 [600/938], 损失: 0.0953
Epoch [1/5], 批次 [800/938], 损失: 0.0751
--- Epoch 1 验证完毕 ---
测试集准确率: 98.19 %
---------------------------------
Epoch [2/5], 批次 [200/938], 损失: 0.0614
Epoch [2/5], 批次 [400/938], 损失: 0.0542
Epoch [2/5], 批次 [600/938], 损失: 0.0515
Epoch [2/5], 批次 [800/938], 损失: 0.0483
--- Epoch 2 验证完毕 ---
测试集准确率: 98.73 %
---------------------------------
Epoch [3/5], 批次 [200/938], 损失: 0.0459
Epoch [3/5], 批次 [400/938], 损失: 0.0339
Epoch [3/5], 批次 [600/938], 损失: 0.0412
Epoch [3/5], 批次 [800/938], 损失: 0.0350
--- Epoch 3 验证完毕 ---
测试集准确率: 98.59 %
---------------------------------
Epoch [4/5], 批次 [200/938], 损失: 0.0263
Epoch [4/5], 批次 [400/938], 损失: 0.0313
Epoch [4/5], 批次 [600/938], 损失: 0.0263
Epoch [4/5], 批次 [800/938], 损失: 0.0312
--- Epoch 4 验证完毕 ---
测试集准确率: 98.80 %
---------------------------------
Epoch [5/5], 批次 [200/938], 损失: 0.0234
Epoch [5/5], 批次 [400/938], 损失: 0.0238
Epoch [5/5]