# 3.9 多层感知机的从零开始实现

In [3]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt


## 3.9.1 获取和读取数据

In [2]:
def load_data_fashion_mnist(batch_size=256, resize=None, root='../../datasets',
                            num_workers=4):
    """Load Fashion-MNIST and return ``(train_iter, test_iter)`` DataLoaders.

    The dataset is downloaded into ``root`` if it is not already there.

    Args:
        batch_size: mini-batch size used by both loaders.
        resize: optional size forwarded to ``transforms.Resize``; when truthy,
            resizing is applied before ``ToTensor``.
        root: directory where torchvision stores / looks up the dataset.
        num_workers: worker processes per DataLoader. Defaults to 4, the value
            previously hard-coded; pass 0 to load in the main process.

    Returns:
        A tuple ``(train_loader, test_loader)``. The training loader shuffles
        every epoch; the test loader keeps a fixed order.
    """
    trans = [transforms.ToTensor()]
    if resize:
        # Resize operates on the PIL image, so it must run before ToTensor.
        trans.insert(0, transforms.Resize(resize))
    transform = transforms.Compose(trans)

    # Images are (1, 28, 28) before any resize; labels come from 10 classes.
    mnist_train = torchvision.datasets.FashionMNIST(
        root=root, train=True, download=True, transform=transform)
    mnist_test = torchvision.datasets.FashionMNIST(
        root=root, train=False, download=True, transform=transform)
    return (
        DataLoader(mnist_train, batch_size=batch_size, shuffle=True,
                   num_workers=num_workers),
        DataLoader(mnist_test, batch_size=batch_size, shuffle=False,
                   num_workers=num_workers),
    )

## 3.9.2 定义模型参数

In [3]:
# 784 = 28 * 28 flattened pixels per image; 10 label classes; 256 hidden units.
num_inputs, num_outputs, num_hiddens = 784, 10, 256

# Weights ~ N(0, 0.01^2), biases start at zero. Sampled with torch directly:
# the previous `np.random.normal` call failed because numpy is never imported.
W1 = torch.randn(num_inputs, num_hiddens, dtype=torch.float) * 0.01
b1 = torch.zeros(num_hiddens, dtype=torch.float)
W2 = torch.randn(num_hiddens, num_outputs, dtype=torch.float) * 0.01
b2 = torch.zeros(num_outputs, dtype=torch.float)

params = [W1, b1, W2, b2]
# Turn on autograd only after initialization, so each parameter is a leaf
# tensor whose `.grad` gets populated by `backward()`.
for param in params:
    param.requires_grad_(requires_grad=True)

## 3.9.3 定义激活函数

In [4]:
def relu(X):
    """Return the element-wise ReLU of ``X``, i.e. ``max(X, 0)``."""
    zero = torch.tensor(0.0)
    return torch.max(X, zero)

## 3.9.4 定义模型

In [5]:
def net(X):
    """One-hidden-layer MLP forward pass.

    ``X`` is flattened to ``(-1, num_inputs)`` rows. The function returns raw
    (un-normalized) class scores with ``num_outputs`` columns.
    """
    flat = X.view((-1, num_inputs))
    hidden = relu(flat @ W1 + b1)
    return hidden @ W2 + b2

## 3.9.5 定义损失函数

In [6]:
# Cross-entropy on raw class scores, averaged over the mini-batch
# ('mean' is PyTorch's default reduction, spelled out for clarity).
loss = torch.nn.CrossEntropyLoss(reduction='mean')

## 3.9.6 训练模型

In [7]:
# The cell previously called `d2l.train_ch3(...)` with `train_iter`,
# `test_iter` and `batch_size`, none of which were defined or imported here.
# Build the loaders with the local helper and train with a self-contained loop.
batch_size = 256
num_epochs, lr = 5, 100.0
train_iter, test_iter = load_data_fashion_mnist(batch_size)


def _sgd(params, lr, batch_size):
    """In-place SGD step: ``p -= lr * p.grad / batch_size`` for every param.

    The division by ``batch_size`` stacks on top of the batch-mean loss, so
    the effective step is ``lr / batch_size``. That is why ``lr`` is as large
    as 100.0 above.
    """
    with torch.no_grad():
        for param in params:
            param -= lr * param.grad / batch_size


def _evaluate_accuracy(data_iter, net):
    """Return the fraction of samples in ``data_iter`` that ``net`` classifies correctly."""
    correct, total = 0, 0
    with torch.no_grad():
        for X, y in data_iter:
            correct += (net(X).argmax(dim=1) == y).sum().item()
            total += y.shape[0]
    return correct / total


def train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size,
              params, lr):
    """Train ``net`` with mini-batch SGD and print per-epoch metrics.

    The printed "loss" is the sum of per-batch losses divided by the number of
    samples. With a batch-mean loss this is roughly (mean loss) / batch_size.
    This accounting is kept so the numbers stay comparable with the recorded
    output of this cell.
    """
    for epoch in range(num_epochs):
        train_l_sum, train_acc_sum, n = 0.0, 0.0, 0
        for X, y in train_iter:
            y_hat = net(X)
            l = loss(y_hat, y).sum()
            # Gradients accumulate across backward() calls; clear them first.
            for param in params:
                if param.grad is not None:
                    param.grad.zero_()
            l.backward()
            _sgd(params, lr, batch_size)
            train_l_sum += l.item()
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()
            n += y.shape[0]
        test_acc = _evaluate_accuracy(test_iter, net)
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f'
              % (epoch + 1, train_l_sum / n, train_acc_sum / n, test_acc))


train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, params, lr)

epoch 1, loss 0.0030, train acc 0.714, test acc 0.753
epoch 2, loss 0.0019, train acc 0.821, test acc 0.777
epoch 3, loss 0.0017, train acc 0.842, test acc 0.834
epoch 4, loss 0.0015, train acc 0.857, test acc 0.839
epoch 5, loss 0.0014, train acc 0.865, test acc 0.845
