# 3.9 多层感知机的从零开始实现

In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt


## 获取和读取数据

In [2]:
def load_data_fashion_mnist(batch_size=256, resize=None, root='../../datasets'):
    trans = [
        transforms.ToTensor()
    ]
    if resize:
        trans.insert(0, transforms.Resize(resize))
    transform = transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(root=root, train=True, download=True, transform=transform) # shape (1, 28, 28), label = 10
    mnist_test = torchvision.datasets.FashionMNIST(root=root , train=False, download=True, transform=transform)
    return (
        DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=4),
        DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=4)
    )

## 定义模型

In [3]:
class MLP(nn.Module):
    def __init__(self, input_size=28, hidden_size=256, num_classes=10):
        super().__init__()
        blocks = [
            nn.Flatten(),
            nn.Linear(input_size**2, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, num_classes)
        ]
        self.net = nn.Sequential(
            *blocks
        )
        self.net.apply(self._init_weights)
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, std=.01)
    def forward(self, x):
        return self.net(x)

In [5]:
train_loader, test_loader = load_data_fashion_mnist()
epochs = 10
loss_fn = nn.CrossEntropyLoss()
model = MLP()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for epoch in range(epochs):
    # training
    for X, y in train_loader:
        pred = model(X)
        l = loss_fn(pred, y)
        optimizer.zero_grad()
        l.backward()
        optimizer.step()
        
    print(f"Epoch {epoch+1}, Loss: {l.item()}")

    # inference # model.eval()
    model.eval()
    total_correct = 0
    total_count = 0
    with torch.no_grad():
        for X, y in test_loader:
            pred = model(X)
            predicted_labels = pred.argmax(dim=1)
            total_correct += (predicted_labels == y).sum().item()
            total_count += y.size(0)
    print(f"Test Accuracy: {total_correct / total_count:.4f}")
            

Epoch 1, Loss: 0.5912665724754333
Test Accuracy: 0.8218
Epoch 2, Loss: 0.38847729563713074
Test Accuracy: 0.8370
Epoch 3, Loss: 0.4852389395236969
Test Accuracy: 0.8563
Epoch 4, Loss: 0.23695945739746094
Test Accuracy: 0.8565
Epoch 5, Loss: 0.29383352398872375
Test Accuracy: 0.8675
Epoch 6, Loss: 0.2861435115337372
Test Accuracy: 0.8702
Epoch 7, Loss: 0.36800023913383484
Test Accuracy: 0.8668
Epoch 8, Loss: 0.36018475890159607
Test Accuracy: 0.8735
Epoch 9, Loss: 0.27021268010139465
Test Accuracy: 0.8762
Epoch 10, Loss: 0.2968735694885254
Test Accuracy: 0.8779
