# 3.5 图像分类数据集（Fashion-MNIST）

In [6]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import time
import sys


## 获取数据集

In [7]:
def load_data_fashion_mnist(batch_size=256, resize=None, root='../../datasets'):
    trans = [
        transforms.ToTensor()
    ]
    if resize:
        trans.insert(0, transforms.Resize(resize))
    transform = transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(root=root, train=True, download=True, transform=transform) # shape (1, 28, 28), label = 10
    mnist_test = torchvision.datasets.FashionMNIST(root=root , train=False, download=True, transform=transform)
    return (
        DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=4),
        DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=4)
    )

## Model

In [9]:
import torch.nn as nn
num_inputs = 28**2
num_outputs = 10

net = nn.Sequential(
    nn.Flatten(),
    nn.Linear(num_inputs, num_outputs)
)
def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)
net.apply(init_weights)
loss = nn.CrossEntropyLoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(), lr=0.1)
num_epochs = 10
train_iter, test_iter = load_data_fashion_mnist(batch_size=256)
def train():
    for i in range(num_epochs):
        for X, y in train_iter:
            y_hat = net(X)
            l = loss(y_hat, y).mean()
            trainer.zero_grad()
            l.backward()
            trainer.step()
        print(f'epoch {i + 1}, loss {l.item():.4f}')
        
        with torch.no_grad():
            correct = 0
            total = 0
            for X, y in test_iter:
                y_hat = net(X)
                predicted = torch.argmax(y_hat, dim=1)
                total += y.size(0)
                correct += (predicted == y).sum().item()
            print(f'Accuracy on test set: {100 * correct / total:.2f}%')
train()

epoch 1, loss 0.6102
Accuracy on test set: 78.71%
epoch 2, loss 0.5434
Accuracy on test set: 81.06%
epoch 3, loss 0.4482
Accuracy on test set: 80.92%
epoch 4, loss 0.5705
Accuracy on test set: 77.96%
epoch 5, loss 0.5182
Accuracy on test set: 82.39%
epoch 6, loss 0.5090
Accuracy on test set: 82.32%
epoch 7, loss 0.4415
Accuracy on test set: 83.08%
epoch 8, loss 0.6517
Accuracy on test set: 83.14%
epoch 9, loss 0.5402
Accuracy on test set: 82.95%
epoch 10, loss 0.4625
Accuracy on test set: 83.48%
