In [61]:
from importlib.metadata import entry_points

import torch
import torch.nn as nn
from torch import functional as F
from d2l import torch as d2l

In [62]:
batch_size = 128
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=24)

In [63]:
class Model(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.layer1 = nn.Linear(input_size, hidden_size)
        self.layer2 = nn.Linear(hidden_size, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = torch.relu(self.layer1(x))
        x = self.layer2(x)
        x = self.relu(x)
        return x


In [64]:
MLP = nn.Sequential(nn.Flatten(),
                    Model(24 * 24, 128, 20),
                    nn.Linear(20, 10))

In [65]:
X = torch.randn(batch_size, 24 * 24)

for layer in MLP:
    X = layer(X)
    print(layer.__class__.__name__, "\t", X.shape)

Flatten 	 torch.Size([128, 576])
Model 	 torch.Size([128, 20])
Linear 	 torch.Size([128, 10])


In [66]:
loss = nn.CrossEntropyLoss()

In [67]:
trainer = torch.optim.SGD(MLP.parameters(), lr=0.01)

In [68]:
epochs = 10

for epoch in range(epochs):
    for X, y in train_iter:
        y_hat = MLP(X)
        l = loss(y_hat, y)
        trainer.zero_grad()
        l.backward()
        trainer.step()
    train_feature, train_label = next(iter(train_iter))
    print("loss:", loss(MLP(train_feature), train_label).item(), "accuracy: ")

loss: tensor(1.4840, grad_fn=<NllLossBackward0>)
loss: tensor(0.9873, grad_fn=<NllLossBackward0>)
loss: tensor(0.8792, grad_fn=<NllLossBackward0>)
loss: tensor(0.7267, grad_fn=<NllLossBackward0>)
loss: tensor(0.5463, grad_fn=<NllLossBackward0>)
loss: tensor(0.5239, grad_fn=<NllLossBackward0>)
loss: tensor(0.5120, grad_fn=<NllLossBackward0>)
loss: tensor(0.6463, grad_fn=<NllLossBackward0>)
loss: tensor(0.6278, grad_fn=<NllLossBackward0>)
loss: tensor(0.5028, grad_fn=<NllLossBackward0>)


In [74]:
logits = MLP(train_feature)               # [batch, num_classes]
pred = torch.argmax(logits, dim=1)
pred

tensor([0, 9, 7, 5, 0, 0, 3, 4, 1, 8, 8, 4, 0, 1, 6, 6, 4, 1, 7, 5, 6, 1, 3, 1,
        8, 8, 9, 3, 8, 9, 3, 8, 1, 2, 0, 8, 3, 5, 2, 5, 0, 1, 0, 3, 5, 5, 8, 7,
        3, 4, 2, 5, 5, 3, 0, 0, 9, 1, 6, 8, 0, 7, 4, 0, 8, 1, 5, 7, 0, 3, 9, 0,
        0, 6, 5, 4, 8, 9, 2, 4, 2, 6, 0, 3, 7, 6, 9, 6, 9, 3, 7, 1, 6, 6, 5, 9,
        5, 4, 6, 9, 1, 9, 9, 1, 7, 1, 1, 7, 1, 5, 8, 5, 8, 3, 3, 3, 4, 3, 1, 6,
        9, 9, 2, 2, 7, 2, 9, 9])

In [77]:
correct = (pred == train_label).sum().item()        # 正确数
correct, train_label.shape

(101, torch.Size([128]))

In [78]:
acc = correct / train_label.size(0)                 # 准确率

In [79]:
print("accuracy:", acc)

accuracy: 0.7890625
