In [20]:
import torch
from torch import nn
from d2l import torch as d2l

# Mini-batch size used for both the training and the test iterator.
batch_size = 256
# Build the Fashion-MNIST data iterators through the d2l helper.
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)

In [21]:
# PyTorch does not reshape inputs implicitly, so nn.Flatten collapses every
# non-batch dimension into a single feature axis before the linear layer.
# The linear layer then maps the 784 flattened features to 10 outputs.
net = nn.Sequential(
    nn.Flatten(),
    nn.Linear(in_features=784, out_features=10),
)

def init_weights(m):
    """Draw the weights of a linear layer from a normal distribution, std 0.01.

    Intended to be passed to ``nn.Module.apply``, which calls it on the module
    itself and on every submodule. Modules that are not linear layers are left
    untouched.

    Args:
        m: The module currently being visited.
    """
    # isinstance, unlike an exact type comparison, also matches subclasses of
    # nn.Linear so they are not silently skipped.
    if isinstance(m, nn.Linear):
        # The mean keeps its default of 0; the bias is not re-initialized.
        nn.init.normal_(m.weight, std=0.01)
        
# Module.apply calls init_weights on net and on each of its submodules.
# As the last expression of the cell, the returned module is echoed as output.
net.apply(init_weights)

Sequential(
  (0): Flatten(start_dim=1, end_dim=-1)
  (1): Linear(in_features=784, out_features=10, bias=True)
)

In [22]:
# Cross-entropy computed from the network's raw class scores, averaged over
# the batch ('mean' is the default reduction; it is spelled out for clarity).
loss = nn.CrossEntropyLoss(reduction='mean')

In [23]:
# Plain mini-batch SGD over all parameters of the network, learning rate 0.1.
trainer = torch.optim.SGD(params=net.parameters(), lr=0.1)

In [24]:
# d2l.train_ch3 is not available in the installed d2l version, so the training and evaluation functions are implemented by hand here.
def evaluate_accuracy(data_iter, net):
    """Compute the classification accuracy of ``net`` over ``data_iter``.

    The model is switched to evaluation mode for the duration of the pass and
    put back into whichever mode it was in before the call.

    Args:
        data_iter: Iterable yielding ``(X, y)`` batches, where ``y`` holds the
            class index of each sample.
        net: ``nn.Module`` whose output carries one score per class along
            dim 1.

    Returns:
        Fraction of samples whose arg-max prediction equals the label.

    Raises:
        ZeroDivisionError: If ``data_iter`` yields no samples.
    """
    # A module without parameters (e.g. a lone nn.Flatten) has no device of
    # its own; its inputs are then left where they are instead of failing.
    first_param = next(net.parameters(), None)
    device = first_param.device if first_param is not None else None
    # Remember the mode so the caller's train/eval state survives this call.
    was_training = net.training
    net.eval()
    acc = 0.0
    n = 0
    try:
        with torch.no_grad():
            for X, y in data_iter:
                if device is not None:
                    X, y = X.to(device), y.to(device)
                y_hat = net(X)
                acc += (y_hat.argmax(dim=1) == y).sum().item()
                n += y.shape[0]
    finally:
        net.train(was_training)
    return acc / n

def train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer):
    """Train ``net`` for ``num_epochs`` epochs and print metrics after each one.

    The model is moved to CUDA when it is available and to the CPU otherwise.
    After every epoch the sample-averaged training loss, the training accuracy
    and the test accuracy (from ``evaluate_accuracy``) are printed.

    Args:
        net: Model to train; it is moved to the selected device.
        train_iter: Iterable yielding ``(X, y)`` training batches.
        test_iter: Iterable yielding ``(X, y)`` batches, handed to
            ``evaluate_accuracy`` at the end of every epoch.
        loss: Callable ``loss(y_hat, y)``. It may return either a scalar that
            is the mean over the batch, or one loss value per sample (as with
            ``reduction='none'``).
        num_epochs: Number of passes over ``train_iter``.
        trainer: Optimizer whose ``zero_grad`` and ``step`` drive the updates.

    Raises:
        ZeroDivisionError: If ``train_iter`` yields no samples.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net.to(device)
    for epoch in range(num_epochs):
        # evaluate_accuracy may leave the model in eval mode, so training mode
        # is re-enabled at the start of every epoch.
        net.train()
        train_loss = 0.0
        train_acc = 0.0
        n = 0
        for X, y in train_iter:
            X, y = X.to(device), y.to(device)
            y_hat = net(X)
            batch_loss = loss(y_hat, y)
            trainer.zero_grad()
            if batch_loss.dim() > 0:
                # Per-sample losses: backward() needs a scalar, so optimize
                # the batch mean and accumulate the plain sum for reporting.
                batch_loss.mean().backward()
                train_loss += batch_loss.sum().item()
            else:
                # Scalar batch mean: scale by the batch size so the epoch
                # figure is an average over samples, not over batches.
                batch_loss.backward()
                train_loss += batch_loss.item() * y.shape[0]
            trainer.step()
            train_acc += (y_hat.argmax(dim=1) == y).sum().item()
            n += y.shape[0]
        test_acc = evaluate_accuracy(test_iter, net)
        print(f'epoch {epoch+1}, loss {train_loss/n:.4f}, train acc {train_acc/n:.3f}, test acc {test_acc:.3f}')

In [None]:
# Number of full passes over the training data.
num_epochs = 10
# The train_ch3 defined in this notebook is used in place of d2l.train_ch3.
train_ch3(
    net=net,
    train_iter=train_iter,
    test_iter=test_iter,
    loss=loss,
    num_epochs=num_epochs,
    trainer=trainer,
)

epoch 1, loss 0.7861, train acc 0.748, test acc 0.782
epoch 2, loss 0.5706, train acc 0.812, test acc 0.799
epoch 3, loss 0.5257, train acc 0.827, test acc 0.819
epoch 4, loss 0.5008, train acc 0.832, test acc 0.819
epoch 5, loss 0.4853, train acc 0.837, test acc 0.822
epoch 6, loss 0.4739, train acc 0.840, test acc 0.829
epoch 7, loss 0.4661, train acc 0.843, test acc 0.827
epoch 8, loss 0.4578, train acc 0.845, test acc 0.830
epoch 9, loss 0.4519, train acc 0.847, test acc 0.831
epoch 10, loss 0.4476, train acc 0.848, test acc 0.821
