In [None]:
#导入模块与超参数设置
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Training hyperparameters
batch_size = 64
epochs = 10
learning_rate = 0.01
momentum = 0.9
log_interval = 100  # print the running training loss every 100 batches

# Seed torch's global RNG so weight initialisation and DataLoader shuffling
# repeat across Restart & Run All (GPU kernels may still be nondeterministic).
SEED = 42
torch.manual_seed(SEED)

# Per-channel normalisation statistics (one value per input channel).
# NOTE(review): these std values are the widely copied (0.2023, 0.1994, 0.2010);
# the dataset-wide per-channel std is reportedly closer to (0.2470, 0.2435, 0.2616).
# Confirm which is intended before changing it — doing so alters training.
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2023, 0.1994, 0.2010)

In [4]:
# Data preprocessing
# Convert each image to a float tensor, then normalise every channel with the
# mean/std tuples from the configuration cell.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(cifar10_mean, cifar10_std),
    ]
)

# Both splits live under ./CIFAR10; download=True fetches them on first use.
train_dataset = datasets.CIFAR10(
    root='./CIFAR10', train=True, download=True, transform=transform
)
test_dataset = datasets.CIFAR10(
    root='./CIFAR10', train=False, download=True, transform=transform
)

# Shuffle the training split each epoch; iterate the test split in a fixed order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Model definition

class Net(nn.Module):
    """Small LeNet-style CNN for 3x32x32 images.

    Two conv(5x5) -> ReLU -> 2x2 max-pool stages reduce a 32x32 input to
    32 feature maps of 5x5, followed by three fully connected layers.

    Args:
        num_classes: number of output logits (default 10, the CIFAR-10 classes).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Expected input: (batch, 3, 32, 32)
        self.conv1 = nn.Conv2d(3, 16, kernel_size=5)   # 32x32 -> 28x28, pooled to 14x14
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5)  # 14x14 -> 10x10, pooled to 5x5
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Return unnormalised class logits of shape (batch, num_classes)."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten per sample, keeping the batch dimension. Unlike
        # view(-1, 800), a wrongly sized input now fails in fc1 instead of
        # being silently reshaped into a different batch size.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

model = Net()
# Prefer the GPU when one is available; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model  # display the architecture

Net(
  (conv1): Conv2d(3, 16, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=800, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

In [6]:
# Loss function and optimizer
# CrossEntropyLoss consumes raw logits, which is what the model's final layer emits.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    params=model.parameters(),
    lr=learning_rate,
    momentum=momentum,
)

In [7]:
# Training function
def train(epoch):
    """Train ``model`` for one pass over ``train_loader``.

    After every ``log_interval`` batches, prints the mean loss of that window
    and resets the running total. Batches after the last full window of the
    epoch are trained on but not reported.

    Args:
        epoch: zero-based epoch index; printed as ``epoch + 1``.
    """
    model.train()
    window_loss = 0.0
    # Count batches from 1 so the logging check and the printed index agree.
    for step, (images, targets) in enumerate(train_loader, start=1):
        images = images.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        logits = model(images)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()

        window_loss += batch_loss.item()
        if step % log_interval == 0:
            print(f'[Epoch {epoch + 1}, Batch {step}] loss: {window_loss / log_interval:.3f}')
            window_loss = 0.0

In [8]:
# Evaluation function
def test():
    """Evaluate ``model`` on ``test_loader`` and report top-1 accuracy.

    Runs in eval mode under ``torch.no_grad()``, prints the accuracy as a
    percentage, and returns it so callers can track it across epochs.

    Returns:
        float: accuracy in percent (0-100).

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            # The index of the largest logit is the predicted class; no_grad
            # already disables autograd tracking, so `.data` is unnecessary.
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError('test_loader yielded no samples; cannot compute accuracy')

    accuracy = 100 * correct / total
    print(f'Accuracy on test set: {accuracy:.2f}%')
    return accuracy

In [9]:
# Start training: each epoch trains on the full training split, then evaluates.
for epoch in range(epochs):
    train(epoch)  # prints the windowed training loss
    test()        # prints test-set accuracy

[Epoch 1, Batch 100] loss: 2.237
[Epoch 1, Batch 200] loss: 1.932
[Epoch 1, Batch 300] loss: 1.707
[Epoch 1, Batch 400] loss: 1.559
[Epoch 1, Batch 500] loss: 1.507
[Epoch 1, Batch 600] loss: 1.444
[Epoch 1, Batch 700] loss: 1.421
Accuracy on test set: 50.34%
[Epoch 2, Batch 100] loss: 1.322
[Epoch 2, Batch 200] loss: 1.297
[Epoch 2, Batch 300] loss: 1.258
[Epoch 2, Batch 400] loss: 1.229
[Epoch 2, Batch 500] loss: 1.201
[Epoch 2, Batch 600] loss: 1.178
[Epoch 2, Batch 700] loss: 1.162
Accuracy on test set: 57.59%
[Epoch 3, Batch 100] loss: 1.104
[Epoch 3, Batch 200] loss: 1.050
[Epoch 3, Batch 300] loss: 1.068
[Epoch 3, Batch 400] loss: 1.031
[Epoch 3, Batch 500] loss: 1.073
[Epoch 3, Batch 600] loss: 1.010
[Epoch 3, Batch 700] loss: 1.003
Accuracy on test set: 61.18%
[Epoch 4, Batch 100] loss: 0.937
[Epoch 4, Batch 200] loss: 0.934
[Epoch 4, Batch 300] loss: 0.916
[Epoch 4, Batch 400] loss: 0.906
[Epoch 4, Batch 500] loss: 0.935
[Epoch 4, Batch 600] loss: 0.926
[Epoch 4, Batch 700] l