In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# Define the baseline network
class BaselineNet(nn.Module):
    """Small two-stage convolutional classifier.

    Architecture: two ``3x3 conv -> ReLU -> 2x2 max-pool`` stages (16 and 32
    output channels), followed by two fully connected layers (128 hidden
    units, then ``num_classes`` outputs).

    The 3x3 convolutions use ``padding=1`` and so keep the spatial size; each
    pool halves it. ``fc1`` expects ``32 * 8 * 8`` features, i.e. 8x8 feature
    maps after both pools, which corresponds to 3-channel 32x32 inputs.

    Args:
        num_classes: Number of output logits. Defaults to 10.
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()
        # Conv stage 1: 3 input channels -> 16 output channels, 3x3 kernel.
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Conv stage 2: 16 input channels -> 32 output channels, 3x3 kernel.
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Fully connected layer 1: 32 * 8 * 8 flattened features -> 128.
        self.fc1 = nn.Linear(32 * 8 * 8, 128)
        self.relu_fc1 = nn.ReLU()

        # Fully connected layer 2: 128 -> one logit per class.
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits for a batch of images.

        Args:
            x: Tensor of shape ``(N, 3, 32, 32)``. A single unbatched image of
                shape ``(3, 32, 32)`` is also accepted and treated as a batch
                of one.

        Returns:
            Tensor of shape ``(N, num_classes)`` holding raw (un-normalized)
            logits.

        Raises:
            RuntimeError: If the spatial size of ``x`` does not produce
                ``32 * 8 * 8`` features per sample (raised by ``fc1``).
        """
        if x.dim() == 3:
            # Promote an unbatched image to a batch of one so the output
            # keeps a leading batch dimension.
            x = x.unsqueeze(0)

        x = self.conv1(x)
        x = self.relu1(x)
        x = self.pool1(x)

        x = self.conv2(x)
        x = self.relu2(x)
        x = self.pool2(x)

        # Flatten everything except the batch dimension. Unlike
        # ``view(-1, 32 * 8 * 8)``, this never folds extra spatial positions
        # into the batch dimension: a wrong input size fails loudly in fc1
        # instead of silently returning the wrong number of rows.
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = self.relu_fc1(x)
        x = self.fc2(x)

        return x

# Data preprocessing: convert each image to a tensor, then normalize it
# per channel with the given mean and standard deviation.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.4914, 0.4822, 0.4465),
        std=(0.2023, 0.1994, 0.2010),
    ),
])

# Load the CIFAR-10 training split. download=False means the files must
# already be present under ./data.
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=False,
    transform=transform,
)
trainloader = torch.utils.data.DataLoader(
    dataset=trainset,
    batch_size=128,
    shuffle=True,
    num_workers=2,
)

# Load the CIFAR-10 test split; it is iterated in a fixed order (no shuffle).
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=False,
    transform=transform,
)
testloader = torch.utils.data.DataLoader(
    dataset=testset,
    batch_size=100,
    shuffle=False,
    num_workers=2,
)

# Select the compute device: CUDA when available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the baseline network and move its parameters to the chosen device.
net = BaselineNet()
net = net.to(device)

# Loss function and optimizer: multi-class cross-entropy, optimized with
# SGD plus momentum.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=net.parameters(), lr=0.01, momentum=0.9)

# Training: run EPOCHS passes over the training loader.
EPOCHS = 20
for epoch in range(EPOCHS):
    net.train()
    running_loss = 0.0  # loss accumulated since the last progress print
    for i, (inputs, labels) in enumerate(trainloader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Clear gradients left over from the previous step.
        optimizer.zero_grad()

        # Forward pass and loss.
        outputs = net(inputs)
        loss = criterion(outputs, labels)

        # Backward pass and parameter update.
        loss.backward()
        optimizer.step()

        # Accumulate the loss; report the mean after every 100th batch.
        running_loss += loss.item()
        if (i + 1) % 100 != 0:
            continue
        print(f'Epoch {epoch+1}, Batch {i+1}, Loss: {running_loss / 100:.3f}')
        running_loss = 0.0

# Evaluation: count correct top-1 predictions over the test loader.
net.eval()
total = 0
correct = 0
with torch.no_grad():  # no gradients are needed for evaluation
    for images, labels in testloader:
        images = images.to(device)
        labels = labels.to(device)
        # The predicted class is the index of the largest logit in each row.
        predicted = torch.max(net(images), dim=1).indices
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

print(f'Accuracy on the test set: {100 * correct / total:.2f}%')

Epoch 1, Batch 100, Loss: 1.999
Epoch 1, Batch 200, Loss: 1.592
Epoch 1, Batch 300, Loss: 1.423
Epoch 2, Batch 100, Loss: 1.219
Epoch 2, Batch 200, Loss: 1.185
Epoch 2, Batch 300, Loss: 1.123
Epoch 3, Batch 100, Loss: 1.007
Epoch 3, Batch 200, Loss: 0.994
Epoch 3, Batch 300, Loss: 0.967
Epoch 4, Batch 100, Loss: 0.871
Epoch 4, Batch 200, Loss: 0.871
Epoch 4, Batch 300, Loss: 0.879
Epoch 5, Batch 100, Loss: 0.742
Epoch 5, Batch 200, Loss: 0.791
Epoch 5, Batch 300, Loss: 0.783
Epoch 6, Batch 100, Loss: 0.663
Epoch 6, Batch 200, Loss: 0.677
Epoch 6, Batch 300, Loss: 0.705
Epoch 7, Batch 100, Loss: 0.574
Epoch 7, Batch 200, Loss: 0.604
Epoch 7, Batch 300, Loss: 0.616
Epoch 8, Batch 100, Loss: 0.490
Epoch 8, Batch 200, Loss: 0.517
Epoch 8, Batch 300, Loss: 0.540
Epoch 9, Batch 100, Loss: 0.423
Epoch 9, Batch 200, Loss: 0.455
Epoch 9, Batch 300, Loss: 0.463
Epoch 10, Batch 100, Loss: 0.348
Epoch 10, Batch 200, Loss: 0.384
Epoch 10, Batch 300, Loss: 0.406
Epoch 11, Batch 100, Loss: 0.299
Epoc