In [17]:
import torch
import torch.nn as nn

class Net(nn.Module):
  """Two-layer fully connected classifier.

  Each sample is flattened, passed through a hidden linear layer with a ReLU
  activation, and mapped to one raw score (logit) per class.

  Args:
    in_features: Number of values in one flattened sample (default 28 * 28).
    hidden_features: Width of the hidden layer (default 256).
    num_classes: Number of scores produced per sample (default 10).
  """

  def __init__(self, in_features=28 * 28, hidden_features=256, num_classes=10):
    super().__init__()
    self.fc1 = nn.Linear(in_features, hidden_features)
    self.fc2 = nn.Linear(hidden_features, num_classes)

  def forward(self, x):
    """Return the raw class scores, one row per flattened sample."""
    # reshape (unlike view) also accepts non-contiguous tensors, e.g. the
    # result of a transpose; for contiguous input it is the same no-copy view.
    x = x.reshape(-1, self.fc1.in_features)
    x = torch.relu(self.fc1(x))
    return self.fc2(x)

  def predict(self, x):
    """Return the index of the highest-scoring class for every sample.

    Runs under torch.no_grad(), so the result carries no autograd history.
    """
    with torch.no_grad():
      # Call the module itself (not .forward) so registered hooks still run.
      return torch.argmax(self(x), dim=1)

In [19]:
from torchvision import datasets, transforms

# MNIST training split, stored under ./data, images converted with ToTensor.
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)

# MNIST test split, same storage root and the same conversion.
test_dataset = datasets.MNIST(
    root='./data',
    train=False,
    transform=transforms.ToTensor(),
    download=True,
)

In [18]:
from torch.utils.data import DataLoader

# Batch loaders over both splits (32 samples per batch): shuffling is enabled
# for the training data and disabled for the test data.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=32)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=32)

# Fresh model with its loss function and optimizer.
net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

num_epochs = 5
log_every = 100  # print the current batch loss once every this many iterations
# Number of batches the loader yields per epoch. Taking it from the loader
# keeps the logged total correct if the batch size changes or the dataset
# size is not a multiple of it.
iters_per_epoch = len(train_loader)

for epoch in range(num_epochs):
  for i, (images, labels) in enumerate(train_loader):
    optimizer.zero_grad()  # drop gradients accumulated by the previous step
    outputs = net(images)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    if (i + 1) % log_every == 0:
      print('Epoch [%d/%d], Iter [%d/%d] Loss: %.4f'
            % (epoch + 1, num_epochs, i + 1, iters_per_epoch, loss.item()))

Epoch [1/5], Iter [100/1875] Loss: 0.3530
Epoch [1/5], Iter [200/1875] Loss: 0.3198
Epoch [1/5], Iter [300/1875] Loss: 0.3670
Epoch [1/5], Iter [400/1875] Loss: 0.5181
Epoch [1/5], Iter [500/1875] Loss: 0.1216
Epoch [1/5], Iter [600/1875] Loss: 0.4395
Epoch [1/5], Iter [700/1875] Loss: 0.2626
Epoch [1/5], Iter [800/1875] Loss: 0.1678
Epoch [1/5], Iter [900/1875] Loss: 0.2031
Epoch [1/5], Iter [1000/1875] Loss: 0.0776
Epoch [1/5], Iter [1100/1875] Loss: 0.2535
Epoch [1/5], Iter [1200/1875] Loss: 0.0487
Epoch [1/5], Iter [1300/1875] Loss: 0.2448
Epoch [1/5], Iter [1400/1875] Loss: 0.3708
Epoch [1/5], Iter [1500/1875] Loss: 0.2410
Epoch [1/5], Iter [1600/1875] Loss: 0.1500
Epoch [1/5], Iter [1700/1875] Loss: 0.1091
Epoch [1/5], Iter [1800/1875] Loss: 0.0263
Epoch [2/5], Iter [100/1875] Loss: 0.0541
Epoch [2/5], Iter [200/1875] Loss: 0.0643
Epoch [2/5], Iter [300/1875] Loss: 0.0944
Epoch [2/5], Iter [400/1875] Loss: 0.0250
Epoch [2/5], Iter [500/1875] Loss: 0.1402
Epoch [2/5], Iter [600/18

In [22]:
# Write the trained weights to disk, then read them back into a newly
# constructed model.
torch.save(net.state_dict(), 'model.pth')
saved_weights = torch.load('model.pth')
net = Net()
net.load_state_dict(saved_weights)

In [31]:
# Evaluate the model on the test set.
net.eval()  # switch to evaluation mode
correct = 0
total = 0
for images, labels in test_loader:
  # Net.predict runs the forward pass under torch.no_grad() and returns the
  # arg-max class index per sample, so no .data access or outer no_grad
  # block is needed here.
  predicted = net.predict(images)
  total += labels.size(0)
  correct += (predicted == labels).sum().item()

# Report the number of images actually counted instead of a hard-coded figure.
# %d drops the fractional part of the percentage.
print('Accuracy of the network on the %d test images: %d %%' % (
    total, 100 * correct / total))

Accuracy of the network on the 10000 test images: 97 %
