In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# Network definition
class MLP(nn.Module):
    """Three-layer fully connected classifier operating on flattened inputs.

    Architecture: ``in_features -> hidden1 -> hidden2 -> num_classes`` with a
    ReLU after each of the first two linear layers. The default sizes
    reproduce the original network (32x32x3 inputs, 512 and 256 hidden units,
    10 output logits).

    Args:
        in_features: Number of values per sample after flattening.
        hidden1: Output width of the first linear layer.
        hidden2: Output width of the second linear layer.
        num_classes: Number of output logits.
    """

    def __init__(self, in_features=32 * 32 * 3, hidden1=512, hidden2=256, num_classes=10):
        super().__init__()
        # Remembered so forward() can flatten to the matching width.
        self.in_features = in_features
        self.fc1 = nn.Linear(in_features, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.fc3 = nn.Linear(hidden2, num_classes)

    def forward(self, x):
        """Flatten ``x`` to rows of ``in_features`` values and return the logits.

        Args:
            x: Tensor whose total element count is a multiple of
                ``in_features``.

        Returns:
            Tensor of shape ``(rows, num_classes)`` holding the raw outputs of
            the last linear layer (no softmax is applied).
        """
        # reshape behaves like view for contiguous tensors but also accepts
        # non-contiguous ones (e.g. a permuted batch), copying only if needed.
        x = x.reshape(-1, self.in_features)
        x = F.relu(self.fc1(x))  # first linear layer -> ReLU
        x = F.relu(self.fc2(x))  # second linear layer -> ReLU
        return self.fc3(x)  # third linear layer, raw logits

# Load the CIFAR-10 dataset
# Convert each image to a tensor, then normalize every one of the three
# channels with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

# Training and test splits share the same root directory and transform.
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

# Mini-batches of 4 samples; only the training loader shuffles.
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=4,
    shuffle=True,
    num_workers=2,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=4,
    shuffle=False,
    num_workers=2,
)

# Initialize the network, the loss function and the optimizer
# Seed the global RNG first so that weight initialization (and later draws
# from the global RNG, such as the training loader's shuffling) start from the
# same state on every Restart & Run All.
torch.manual_seed(0)
net = MLP()
criterion = nn.CrossEntropyLoss()
# Plain SGD with momentum over all parameters of the network.
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# Train the network
for epoch in range(2):  # loop over the dataset multiple times
    # Training mode. It has no effect on plain Linear/ReLU layers but is
    # required as soon as dropout or batch-norm layers are added to the model.
    net.train()
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(trainloader):
        optimizer.zero_grad()  # clear gradients left over from the previous step
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 2000 == 1999:  # print the mean loss every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

    # Evaluate on the test set at the end of every epoch
    net.eval()  # evaluation mode; switched back by net.train() next epoch
    correct = 0
    total = 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for images, labels in testloader:
            outputs = net(images)
            # Index of the largest logit along dim 1 is the predicted class.
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    # '%d' truncates the percentage to a whole number.
    print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))

print('Finished Training')

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data\cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:26<00:00, 6317600.81it/s]


Extracting ./data\cifar-10-python.tar.gz to ./data
Files already downloaded and verified
[1,  2000] loss: 1.913
[1,  4000] loss: 1.708
[1,  6000] loss: 1.641
[1,  8000] loss: 1.584
[1, 10000] loss: 1.573
[1, 12000] loss: 1.513
Accuracy of the network on the 10000 test images: 46 %
[2,  2000] loss: 1.455
[2,  4000] loss: 1.434
[2,  6000] loss: 1.425
[2,  8000] loss: 1.400
[2, 10000] loss: 1.403
[2, 12000] loss: 1.419
Accuracy of the network on the 10000 test images: 50 %
Finished Training
