## 数据统计函数

In [1]:
import torch
import torch.nn as nn

In [1]:
def accuracy(y_pre, y):
    """Return the number of correct predictions in a batch (a count, not a rate).

    Used for both training and test batches; callers divide the accumulated
    count by the number of samples to obtain an accuracy.

    Args:
        y_pre: model outputs. If it has more than one dimension and more than
            one column, it is treated as per-class scores and reduced with
            ``argmax`` over dim 1 (classification).
        y: targets to compare against.

    Returns:
        ``float`` count of positions where the prediction equals the target.
    """
    if y_pre.dim() > 1 and y_pre.shape[1] > 1:
        y_pre = y_pre.argmax(dim=1)  # classification: pick the highest-scoring class
    # Align shapes when the element counts agree. Otherwise predictions of shape
    # (batch,) compared with labels of shape (batch, 1) (or vice versa) would
    # broadcast to a (batch, batch) matrix and overcount matches.
    if isinstance(y, torch.Tensor) and y_pre.numel() == y.numel():
        y_pre = y_pre.reshape(y.shape)
    cmp = y_pre == y
    return float(cmp.sum())

class Accumulator:
    """Keep running float totals for a fixed number of quantities.

    Slot ``i`` holds the sum of every ``i``-th value passed to :meth:`add`.
    """

    def __init__(self, n):
        """Create ``n`` slots, all starting at ``0.0``."""
        self.data = [0.0 for _ in range(n)]

    def add(self, *args):
        """Add ``float(args[i])`` to slot ``i``.

        NOTE: values are paired with slots via ``zip``, so passing fewer values
        than there are slots drops the trailing slots from ``self.data``.
        """
        totals = []
        for running, increment in zip(self.data, args):
            totals.append(running + float(increment))
        self.data = totals

    def reset(self):
        """Zero every slot, keeping the current number of slots."""
        self.data = [0.0 for _ in self.data]

    def __getitem__(self, idx):
        """Return slot ``idx`` (slices are forwarded to the underlying list)."""
        totals = self.data
        return totals[idx]
    
def evaluate_test(net, data_loader, loss):
    """Evaluate ``net`` on a dataset and return per-sample averages.

    Args:
        net: the model. If it is an ``nn.Module`` it is switched to eval mode
            (side effect: it stays in eval mode after this call).
        data_loader: iterable of ``(X, y)`` batches.
        loss: callable ``loss(out, y)``; the ``.sum()`` of its result is
            accumulated. Presumably an unreduced (``reduction='none'``) loss so
            the sum is over samples — TODO confirm with callers.

    Returns:
        ``(average loss, accuracy)``, each divided by the number of samples
        actually seen; ``(0.0, 0.0)`` for an empty loader.
    """
    if isinstance(net, nn.Module):
        net.eval()  # disable dropout / use BatchNorm running statistics
    metric = Accumulator(3)  # summed loss, correct predictions, number of samples
    with torch.no_grad():  # evaluation needs no autograd graph
        for X, y in data_loader:
            out = net(X)
            # Align targets with single-output predictions (e.g. shape (batch, 1));
            # class-index labels for multi-class outputs are left as they are,
            # since reshaping them to the output shape would fail.
            if y.numel() == out.numel():
                y = y.reshape(out.shape)
            l = loss(out, y)
            metric.add(float(l.sum()), accuracy(out, y), y.shape[0])
    if metric[2] == 0:
        return 0.0, 0.0
    # Divide by the samples actually seen instead of relying on a global size.
    return metric[0] / metric[2], metric[1] / metric[2]

## 训练函数

In [3]:
def train(net, train_loader, loss, optimizer):
    """Run one training epoch and return per-sample averages.

    Args:
        net: the model. If it is an ``nn.Module`` it is switched to train mode.
        train_loader: iterable of ``(X, y)`` batches.
        loss: callable ``loss(y_pre, y)``; the ``.sum()`` of its result is
            accumulated. Presumably an unreduced (``reduction='none'``) loss so
            the sum is over samples — TODO confirm with callers.
        optimizer: either a ``torch.optim.Optimizer`` or a callable updater.
            A callable updater is invoked as ``optimizer(batch_size)`` after
            ``l.sum().backward()``.

    Returns:
        ``(average loss, accuracy)``, each divided by the number of samples
        actually seen; ``(0.0, 0.0)`` for an empty loader.
    """
    if isinstance(net, nn.Module):
        net.train()
    metric = Accumulator(3)  # summed loss, correct predictions, number of samples
    for X, y in train_loader:
        y_pre = net(X)
        l = loss(y_pre, y)
        if isinstance(optimizer, torch.optim.Optimizer):
            # Using PyTorch in-built optimizer & loss criterion.
            optimizer.zero_grad()
            # mean() leaves an already-reduced scalar loss unchanged and lets
            # unreduced (per-sample) losses be backpropagated as well.
            l.mean().backward()
            optimizer.step()
        elif callable(optimizer):
            # Hand-written updater: backprop the summed loss and pass the batch
            # size (presumably used to normalize the step).
            l.sum().backward()
            optimizer(X.shape[0])
        metric.add(float(l.sum()), accuracy(y_pre, y), y.shape[0])
    if metric[2] == 0:
        return 0.0, 0.0
    # Divide by the samples actually seen instead of relying on a global size.
    return metric[0] / metric[2], metric[1] / metric[2]