In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch import optim
import os

In [2]:
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1)
        self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1)
        self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(in_features=16*5*5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)

    def forward(self, x):
        x = F.tanh(self.conv1(x))
        x = self.pool1(x)
        x = F.tanh(self.conv2(x))
        x = self.pool2(x)

        x = x.view(-1, 16*5*5) # reshape
        x = F.tanh(self.fc1(x))
        output = F.softmax(self.fc2(x), dim=1) # 为什么要dim=1，因为是对每一行进行softmax
        return output

In [3]:
transform = transforms.Compose([
    transforms.Resize((32, 32)),  # 将图片转换成32*32
    transforms.ToTensor(),  # 将图片转换成PyTorch的Tensor
    transforms.Normalize((0.1307,), (0.3081,))  # 归一化处理
])

In [4]:
train_dataset = datasets.MNIST(root='data', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True) # train_loader的形状为[64, 1, 32, 32]

In [5]:
test_dataset = datasets.MNIST(root='data', train=False, transform=transform)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

In [6]:
weight_path = './LeNet5/lenet5.pth'

In [7]:
LeNet5 = Model()
if os.path.exists(weight_path):
    LeNet5.load_state_dict(torch.load(weight_path))
    print('success load weight')
else:
    print('weight not exist')

weight not exist


In [8]:
opt = optim.Adam(LeNet5.parameters())

In [10]:
loss_fn = nn.CrossEntropyLoss()

In [11]:
epoch = 1
while epoch <= 100:
    for batch_idx, (x, y) in enumerate(train_loader): #batch_idx是索引，x是数据，y是标签, 索引是从0开始的
        logits = LeNet5(x)
        loss = loss_fn(logits, y)

        opt.zero_grad()
        loss.backward()
        opt.step()

        if batch_idx % 100 == 0: # 每100个batch打印一次loss
            print('epoch: {}, batch_idx: {}, loss is: {}'.format(epoch, batch_idx, loss.item()))
        
        if epoch % 10 == 0:
            torch.save(LeNet5.state_dict(), weight_path)
    epoch += 1

epoch: 1, batch_idx: 0, loss is: 4.430777072906494
epoch: 1, batch_idx: 100, loss is: 3.754167318344116
epoch: 1, batch_idx: 200, loss is: 3.6368207931518555
epoch: 1, batch_idx: 300, loss is: 3.5305211544036865
epoch: 1, batch_idx: 400, loss is: 3.5200607776641846
epoch: 1, batch_idx: 500, loss is: 3.522352457046509
epoch: 1, batch_idx: 600, loss is: 3.4743878841400146
epoch: 1, batch_idx: 700, loss is: 3.4824447631835938
epoch: 1, batch_idx: 800, loss is: 3.5162723064422607
epoch: 1, batch_idx: 900, loss is: 3.501516342163086
epoch: 2, batch_idx: 0, loss is: 3.4997198581695557
epoch: 2, batch_idx: 100, loss is: 3.478194236755371
epoch: 2, batch_idx: 200, loss is: 3.512770175933838
epoch: 2, batch_idx: 300, loss is: 3.4863193035125732
epoch: 2, batch_idx: 400, loss is: 3.508161783218384
epoch: 2, batch_idx: 500, loss is: 3.46283221244812
epoch: 2, batch_idx: 600, loss is: 3.5059452056884766
epoch: 2, batch_idx: 700, loss is: 3.456531524658203
epoch: 2, batch_idx: 800, loss is: 3.47358

In [12]:
LeNet5.eval()  # 将模型设置为评估模式

test_loss = 0
correct = 0
total = 0

with torch.no_grad():  # 关闭梯度计算
    for data, target in test_loader:
        output = LeNet5(data)  # 获取模型的输出
        test_loss += loss_fn(output, target).item()  # 累计损失
        _, predicted = torch.max(output.data, 1)  # 获取预测结果
        total += target.size(0)  # 更新测试的总数
        correct += (predicted == target).sum().item()  # 更新正确预测的数量

test_loss /= len(test_loader.dataset)  # 计算平均损失
accuracy = correct / total  # 计算准确率

print(f'Test set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{total} ({100. * accuracy:.2f}%)')

Test set: Average loss: 0.0544, Accuracy: 9893/10000 (98.93%)
