In [183]:
# 导入包
import torch
import torch.nn as nn
from torchvision.datasets import KMNIST
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor


In [184]:
# 定义超参数
LR = 0.5 # 学习率
BATCH_SIZE = 16 # 批大小
EPOCHS = 15 # 训练轮数

In [185]:
# 数据集加载
train_data = KMNIST(root='./data', train=True, download=True, transform=ToTensor()) # 训练集
test_data = KMNIST(root='./data', train=False, download=True, transform=ToTensor()) # 测试集
train_dl = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True) # 训练集加载器
test_dl = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False) # 测试集加载器

In [186]:
# print(train_data[0][0])
# print(train_data[0][1])
# print(set(b for a, b in train_data))

In [187]:
# 定义模型
model = nn.Sequential(
    nn.Linear(28*28, 128),
    nn.Sigmoid(),
    nn.Linear(128, 64),
    nn.Sigmoid(),
    nn.Linear(64, 10)
)

In [188]:
# 损失函数和优化器
loss_fn = nn.CrossEntropyLoss()  # 交叉熵损失函数
optimizer = torch.optim.SGD(model.parameters(), lr=LR)  # 梯度下降优化器

In [189]:
for epoch in range(EPOCHS): # 训练 epochs 轮
    # 提取训练数据
    for data, target in train_dl:
        output = model(data.reshape(-1, 28*28)) # 前向传播，得到模型输出
        loss = loss_fn(output, target) # 计算损失
        optimizer.zero_grad() # 清空梯度
        loss.backward() # 反向传播，计算梯度
        optimizer.step() # 更新参数
    print(f"Epoch {epoch}, Loss: {loss.item()}") # 打印损失

Epoch 0, Loss: 0.5514621138572693
Epoch 1, Loss: 0.3386995196342468
Epoch 2, Loss: 0.22018533945083618
Epoch 3, Loss: 0.2274157702922821
Epoch 4, Loss: 0.18193528056144714
Epoch 5, Loss: 0.06758689135313034
Epoch 6, Loss: 0.21472486853599548
Epoch 7, Loss: 0.08726271986961365
Epoch 8, Loss: 0.05774590000510216
Epoch 9, Loss: 0.30983322858810425
Epoch 10, Loss: 0.026087703183293343
Epoch 11, Loss: 0.002676398493349552
Epoch 12, Loss: 0.003774170530959964
Epoch 13, Loss: 0.0019305594032630324
Epoch 14, Loss: 0.0005274285795167089


In [190]:
# 测试训练成果
correct = 0
total = 0
with torch.no_grad(): # 不计算梯度
    for images, labels in test_dl:
        outputs = model(images.reshape(-1, 28*28))
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0) # 总样本数
        correct += (predicted == labels).sum().item() # 正确样本数

print(f'Accuracy: {100 * correct // total} %')

Accuracy: 89 %


|序号|学习率|迭代次数|数据集大小|隐藏层层数|神经元数量|损失值|训练时间|测试成功率|
|--|--|--|--|--|--|--|--|--|
|1|0.001|10|64|1|64|1.9732621908187866|54s|50%|
|2|0.01|10|64|1|64|0.4932018518447876|55s|65%|
|3|0.1|10|64|1|64|0.34229186177253723|54s|80%|
|4|1|10|64|1|64|0.16985097527503967|54s|85%|
|5|10|10|64|1|64|0.5860196948051453|54s|66%|
|6|0.1|20|64|1|64|0.027062786743044853|1m56s|84%|
|7|0.1|10|128|1|64|0.5023661255836487|60s|76%|
|8|0.1|10|32|1|64|0.5465865731239319|60s|84%|
|9|0.1|10|16|1|64|0.07451552897691727|1m12s|86%|
|10|0.1|10|16|1|128|0.0568353570997715|1m12s|88%|
|11|0.1|10|16|1|256|0.05454472452402115|1m15s|88%|
|12|0.1|10|16|2|256+128|0.04783807322382927|1m20s|88%|
|13|0.1|10|16|3|256+128+64|0.43541911244392395|1m24s|86%|
|14|0.1|10|16|4|256+128+64+32|1.7396743297576904|1m45s|40%|
|15|0.1|10|16|2|128+64|0.009450928308069706|1m17s|88%|
|16|0.1|15|16|2|128+64|0.027340466156601906|1m54s|88%|
|17|0.5|15|16|2|128+64|0.0005274285795167089|2m5s|89%|