### Softmax回归框架实现

1. 定义取数函数
2. 定义模型
3. 初始化变量
4. 定义优化函数
5. 定义损失函数
6. 训练
7. 检测训练效果

In [6]:
import torch
import matplotlib.pyplot as plt
import torch.nn as nn
import torchvision

In [7]:
class Accumulator:
    """Keep ``n`` running totals that are always updated together."""

    def __init__(self, n):
        """Create ``n`` totals, each starting at ``0.0``."""
        self.data = [0.] * n

    def __getitem__(self, i):
        """Return the current value of the ``i``-th total."""
        return self.data[i]

    def add(self, *args):
        """Add one increment to each total, position by position.

        Args:
            *args: Exactly one increment per total, in the same order as
                the totals.

        Raises:
            ValueError: If the number of increments differs from the number
                of totals.
        """
        # zip() stops at the shorter input, so a mismatched call would
        # silently shrink self.data (or drop increments); fail loudly instead.
        if len(args) != len(self.data):
            raise ValueError(
                f"expected {len(self.data)} values, got {len(args)}"
            )
        self.data = [a + b for (a, b) in zip(self.data, args)]

In [8]:
def load_data(batch_size, isTrain=True):
    """Yield flattened Fashion-MNIST mini-batches.

    The dataset is read from ``../data``; no download is requested here, so
    the files are expected to be present already.

    Args:
        batch_size: Number of samples per mini-batch.
        isTrain: When true, use the training split and shuffle it; otherwise
            use the test split without shuffling.

    Yields:
        ``(imgs, labels)`` where ``imgs`` is a float32 tensor holding one
        flattened image per row and ``labels`` is the label tensor for the
        same batch.
    """
    dataset = torchvision.datasets.FashionMNIST(
        root='../data',
        train=isTrain,
        transform=torchvision.transforms.ToTensor(),
    )
    data_loader = torch.utils.data.DataLoader(dataset, batch_size, shuffle=isTrain)
    for imgs, labels in data_loader:
        # Flatten using the batch's real length, not batch_size: the final
        # batch is shorter whenever the dataset size is not a multiple of
        # batch_size, and reshaping to batch_size rows would then either fail
        # or produce rows that no longer line up with labels.
        yield (imgs.reshape(imgs.shape[0], -1).type(torch.float32), labels)
# data_iter = load_data(10, True)
# imgs, lab = next(data_iter)
# imgs.shape

In [9]:
# Softmax regression as a network: flatten each input to a vector, then map
# the 784 features to 10 class scores with a single linear layer.
# The linear layer's weight and bias are initialised via model.apply(...)
# rather than by assigning new Parameter objects by hand.
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(784, 10),
)
def init_weight(m):
    """Initialise a linear layer in place; leave any other module untouched.

    Intended to be passed to ``Module.apply``, which calls it once per
    submodule.

    Args:
        m: The module to (possibly) initialise. Linear layers get weights
            drawn from a normal distribution with std 0.01 and a zero bias.
    """
    # isinstance (not an exact type comparison) so subclasses of Linear are
    # initialised as well.
    if isinstance(m, torch.nn.Linear):
        torch.nn.init.normal_(m.weight, std=0.01)
        torch.nn.init.zeros_(m.bias)
# Module.apply calls init_weight on every submodule of model (and on model
# itself).
model.apply(init_weight)
# CrossEntropyLoss expects unnormalised scores (logits) plus class-index
# targets, which is why the model itself has no softmax layer.
loss = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# Normalise over the last axis: for (batch, classes) scores this gives one
# probability distribution per sample. dim=0 would instead normalise across
# the batch; for a single 1-D score vector dim=-1 and dim=0 are the same.
softmax = torch.nn.Softmax(dim=-1)

In [10]:
epochs = 1
for epoch in range(epochs):
    # --- Training: one pass over the training split with mini-batch SGD ---
    data_iter = load_data(10, isTrain=True)
    accumulator = Accumulator(2)  # [correct predictions, samples evaluated]
    for imgs, labels in data_iter:
        optimizer.zero_grad()
        # Call the module rather than .forward() directly so that any
        # registered hooks are run as well.
        comp_labels = model(imgs)
        batch_loss = loss(comp_labels, labels)
        batch_loss.backward()
        optimizer.step()

    # --- Evaluation: count correct predictions on the test split ---
    data_iter_test = load_data(10, isTrain=False)
    # No gradients are needed for evaluation, so disable tracking once for
    # the whole loop.
    with torch.no_grad():
        for imgs, labels in data_iter_test:
            # The predicted class is the index of the largest score per row.
            comp_labels = torch.argmax(model(imgs), dim=1)
            correct_num = (comp_labels == labels).sum().item()
            accumulator.add(correct_num, len(labels))
    print('epoch', epoch, accumulator[0], accumulator[1])

epoch 0 8206.0 10000.0
