In [None]:
import torch
from torch import nn
from d2l import torch as d2l

# Number of examples per mini-batch handed to the d2l data loader.
batch_size = 256
# d2l.load_data_fashion_mnist returns a (train, test) pair of batch iterators.
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

In [None]:
# Layer sizes: `net` flattens each example to `num_inputs` features, maps it
# through one hidden layer of `nums_hiddens` units, and emits `num_outputs`
# scores.
num_inputs, num_outputs = 784, 10
nums_hiddens = 256

# Standard deviation of the initial weights. Unscaled `torch.randn` weights
# (std 1) summed over hundreds of inputs produce very large initial scores,
# which saturates the softmax inside the cross-entropy loss and makes the
# first SGD steps unstable; small weights keep the initial scores near zero.
weight_init_std = 0.01

# nn.Parameter already sets requires_grad=True, so the source tensors do not
# need to request gradients themselves.
W1 = nn.Parameter(torch.randn(num_inputs, nums_hiddens) * weight_init_std)
b1 = nn.Parameter(torch.zeros(nums_hiddens))

W2 = nn.Parameter(torch.randn(nums_hiddens, num_outputs) * weight_init_std)
b2 = nn.Parameter(torch.zeros(num_outputs))

# Everything the optimizer should update, in layer order.
parameters = [W1, b1, W2, b2]

In [None]:
def relu(X):
    """Return the element-wise maximum of ``X`` and zero.

    The result has the same shape and dtype as ``X``.
    """
    floor = torch.zeros_like(X)
    return torch.maximum(X, floor)

In [None]:
def net(X):
    """Forward pass of the one-hidden-layer perceptron.

    ``X`` is flattened to rows of ``num_inputs`` features, passed through the
    hidden layer (``W1``, ``b1``) followed by ``relu``, then through the output
    layer (``W2``, ``b2``). The output-layer result is returned as is; no
    further activation is applied here.
    """
    flat = X.reshape(-1, num_inputs)
    hidden = relu(flat @ W1 + b1)
    scores = hidden @ W2 + b2
    return scores

# Cross-entropy criterion; the training loop calls it as loss(y_hat, y) on the
# outputs of `net` and the batch labels.
loss = nn.CrossEntropyLoss()

In [None]:
import matplotlib.pyplot as plt
from IPython import display


class Animator:
    """Redraw a set of curves on one matplotlib axes as new points arrive.

    Each call to ``add`` appends one x value and one y value per curve, clears
    the axes, replots every curve from the stored history, and displays the
    figure in place of the previous notebook output.

    Args:
        xlabel: Label for the x axis.
        ylabel: Label for the y axis.
        legend: Sequence of curve names. It also fixes the initial number of
            curves; with ``None`` (the default) the curves are created by the
            first call to ``add``.
        xlim: Optional x-axis limits, re-applied after every redraw.
        ylim: Optional y-axis limits, re-applied after every redraw.

    Attributes:
        X: List of the x values recorded so far.
        Y: List of per-curve lists of y values, each the same length as ``X``.
    """
    def __init__(self, xlabel=None, ylabel=None, legend=None,
                 xlim=None, ylim=None):
        self.fig, self.ax = plt.subplots()
        self.xlabel = xlabel
        self.ylabel = ylabel
        self.legend = legend
        self.xlim = xlim
        self.ylim = ylim
        self.X = []
        # `legend` defaults to None, which is not iterable; start with no
        # curves in that case and let `_record` create them on demand.
        self.Y = [[] for _ in (legend or [])]

        self._configure_axes()
        if legend:
            self.ax.legend(legend)

    def _configure_axes(self):
        """Apply the stored axis labels and limits to ``self.ax``."""
        self.ax.set_xlabel(self.xlabel)
        self.ax.set_ylabel(self.ylabel)
        if self.xlim:
            self.ax.set_xlim(self.xlim)
        if self.ylim:
            self.ax.set_ylim(self.ylim)

    def _record(self, x, y):
        """Append ``x`` and the per-curve values ``y`` to the stored history."""
        # Accept a bare scalar for the single-curve case.
        try:
            values = list(y)
        except TypeError:
            values = [y]

        self.X.append(x)
        gap = float('nan')
        # A curve that first appears now is back-filled with NaN so that every
        # series stays the same length as self.X.
        while len(self.Y) < len(values):
            self.Y.append([gap] * (len(self.X) - 1))
        # A curve with no value this step gets NaN for the same reason.
        for i, series in enumerate(self.Y):
            series.append(values[i] if i < len(values) else gap)

    def add(self, x, y):
        """Record one new point per curve and redraw the whole figure.

        Args:
            x: The x coordinate shared by all curves for this step.
            y: An iterable with one value per curve, or a single scalar when
                there is only one curve.
        """
        self._record(x, y)

        # cla() wipes labels and limits along with the lines, so they are
        # re-applied before replotting.
        self.ax.cla()
        self._configure_axes()

        for series in self.Y:
            self.ax.plot(self.X, series)

        # The legend must be attached after plotting so it can pick up the lines.
        if self.legend:
            self.ax.legend(self.legend)

        display.clear_output(wait=True)
        display.display(self.fig)


In [None]:
def accuracy(y_hat, y):
    """Count the entries of ``y_hat`` that agree with ``y``.

    When ``y_hat`` has more than one dimension it is first reduced to class
    indices with ``argmax`` along dimension 1; otherwise it is compared to
    ``y`` directly. The count is returned as a Python float, not a ratio.
    """
    predicted = y_hat.argmax(dim=1) if y_hat.ndim > 1 else y_hat
    hits = predicted == y
    return float(hits.sum())


def evaluate_accuracy(net, data_iter):
    """Return the fraction of examples in ``data_iter`` that ``net`` gets right.

    Args:
        net: A callable mapping a batch ``X`` to predictions accepted by
            ``accuracy``. If it is a ``torch.nn.Module`` that is currently in
            training mode, it is switched to evaluation mode for the duration
            of the call and switched back afterwards.
        data_iter: An iterable of ``(X, y)`` batches.

    Returns:
        Correct predictions divided by the number of labels seen, as a float.

    Raises:
        ZeroDivisionError: If ``data_iter`` yields no examples.
    """
    # Layers such as dropout and batch norm behave differently in training
    # mode, so a module must be scored in eval mode. The previous mode is
    # restored so a surrounding training loop is not left in eval mode.
    restore_training = isinstance(net, torch.nn.Module) and net.training
    if restore_training:
        net.eval()

    metric = [0.0, 0.0]  # correct predictions, examples seen
    try:
        with torch.no_grad():
            for X, y in data_iter:
                metric[0] += accuracy(net(X), y)
                metric[1] += y.numel()
    finally:
        if restore_training:
            net.train()

    return metric[0] / metric[1]


In [None]:
def train_ch3(
    net,
    train_iter,
    test_iter,
    loss,
    num_epochs,
    optimizer
):
    """Train ``net`` for ``num_epochs`` passes over ``train_iter``, plotting progress.

    After every epoch the average training loss, the training accuracy and the
    accuracy on ``test_iter`` are added to an ``Animator`` plot at x = epoch
    number (starting from 1).

    Args:
        net: Callable mapping a batch ``X`` to predictions.
        train_iter: Iterable of ``(X, y)`` training batches.
        test_iter: Iterable of ``(X, y)`` batches passed to ``evaluate_accuracy``.
        loss: Callable ``loss(y_hat, y)`` returning a scalar that supports
            ``backward()``; its value is weighted by the batch size when
            averaging over the epoch.
        num_epochs: Number of passes over ``train_iter``.
        optimizer: Object providing ``zero_grad()`` and ``step()``.
    """
    curve_names = ['train loss', 'train acc', 'test acc']
    animator = Animator(xlabel='epoch', ylabel='value',
                        legend=curve_names, xlim=[1, num_epochs])

    for epoch_number in range(1, num_epochs + 1):
        # Running totals for this epoch.
        loss_total, num_correct, num_examples = 0.0, 0.0, 0.0

        for X, y in train_iter:
            n = y.numel()
            y_hat = net(X)
            batch_loss = loss(y_hat, y)

            # Standard update: clear old gradients, backpropagate, step.
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            loss_total += float(batch_loss) * n
            num_correct += accuracy(y_hat, y)
            num_examples += n

        epoch_stats = (
            loss_total / num_examples,
            num_correct / num_examples,
            evaluate_accuracy(net, test_iter),
        )
        animator.add(epoch_number, epoch_stats)

    plt.show()


In [None]:
# Training configuration: number of passes over the training set and SGD step size.
num_epochs = 3
lr = 0.1
# Plain SGD over the four tensors collected in `parameters`.
trainer = torch.optim.SGD(parameters, lr=lr)
# Runs the training loop; train_ch3 adds a point to the progress plot after each epoch.
train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)