In [3]:
import torch
from torchvision.transforms.v2 import ToTensor
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import KMNIST

In [62]:
# 定义超参数
LR = 1e-1
epochs = 20
BATCH_SIZE = 100

# 数据集加载
train_data = KMNIST(root="./Kuzushiji_data", train=True, download=True, transform=ToTensor())
test_data = KMNIST(root="./Kuzushiji_data", train=False, download=True, transform=ToTensor())

# type_list = []
# for data in train_data:
#     type_list.append(data[1])
# print(type_list)

# 数据打包
train_dl = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)

# 定义模型
model = nn.Sequential(
    nn.Linear(784, 64),
    nn.Sigmoid(),
    nn.Linear(64, 32),
    nn.Sigmoid(),
    nn.Linear(32, 10),
)

# 损失函数&优化器
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)

# 训练数据
for epoch in range(epochs):
    for data, target in train_dl:
        # 前向运算
        output = model(data.reshape(-1, 784))
        # 计算损失
        loss = loss_fn(output, target)
        # 反向传播
        optimizer.zero_grad() # 参数梯度清零
        loss.backward()
        optimizer.step()
    print(f'epoch: {epoch}, loss: {loss}')

epoch: 0, loss: 2.059077739715576
epoch: 1, loss: 1.5281038284301758
epoch: 2, loss: 1.2852203845977783
epoch: 3, loss: 0.8322193026542664
epoch: 4, loss: 0.6077675819396973
epoch: 5, loss: 0.8588521480560303
epoch: 6, loss: 0.5500332117080688
epoch: 7, loss: 0.5008435845375061
epoch: 8, loss: 0.274258017539978
epoch: 9, loss: 0.43217793107032776
epoch: 10, loss: 0.4323316812515259
epoch: 11, loss: 0.37454017996788025
epoch: 12, loss: 0.25164636969566345
epoch: 13, loss: 0.35695502161979675
epoch: 14, loss: 0.31820449233055115
epoch: 15, loss: 0.2749771177768707
epoch: 16, loss: 0.4822560250759125
epoch: 17, loss: 0.2647589147090912
epoch: 18, loss: 0.16048145294189453
epoch: 19, loss: 0.15815521776676178


In [63]:
test_dl = DataLoader(test_data, batch_size=BATCH_SIZE)

correct = 0
total = 0

with torch.no_grad():
    for data, target in test_dl:
        output = model(data.reshape(-1, 784))
        _,predict = torch.max(output, 1)
        total += target.size(0)
        correct += (predict == target).sum().item()
print(f'acc: {correct/total*100}%')

acc: 82.19999999999999%
