![](../image/lenet.svg)
![](../image/lenet-vert.svg)

网络结构:
- 2个5x5的卷积层(每个卷积层后接sigmoid激活和2x2平均汇聚层), 3个全连接层

其他:

- 使用古老的sigmoid激活函数(而非ReLU), 平均汇聚层(而非最大汇聚层)

In [1]:
import torch
# torchvision.datasets.FashionMNIST
import torchvision
# 修改数据集格式
from torchvision import transforms
# nn块
from torch import nn

In [2]:
# ----------- hyperparameters -----------
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)
# Mini-batch size, learning rate and number of training epochs.
batch_size, lr, num_epochs = 128, 0.1, 20

cuda


In [3]:
# Transform applied to every sample: convert the image to a tensor.
trans = transforms.ToTensor()
# Training split of FashionMNIST, stored under ../data (download enabled).
mnist_train_totensor = torchvision.datasets.FashionMNIST(
    "../data", train=True, transform=trans, download=True)
# Test split, same directory and same transform.
mnist_test_totensor = torchvision.datasets.FashionMNIST(
    "../data", train=False, transform=trans, download=True)

In [4]:
# Shape of the first training image, laid out as [C, H, W].
mnist_train_totensor[0][0].size()

torch.Size([1, 28, 28])

In [5]:
# Mini-batch loaders; num_workers=4 reads the data with 4 worker processes.
# Only the training data is shuffled. The test set is read in a fixed order
# so that repeated evaluations iterate over identical batches.
train_iter = torch.utils.data.DataLoader(
    mnist_train_totensor, batch_size, shuffle=True, num_workers=4)
test_iter = torch.utils.data.DataLoader(
    mnist_test_totensor, batch_size, shuffle=False, num_workers=4)

In [9]:
# Print the shape of the first mini-batch only: [B, C, H, W].
for batch, (X, y) in enumerate(train_iter):
    print(X.size())
    break

torch.Size([128, 1, 28, 28])


In [6]:
# LeNet-style network for 1x28x28 inputs; built first, then moved to `device`.
net = nn.Sequential(
    # Feature extractor: two 5x5 convolutions, each followed by a sigmoid
    # and a 2x2 average pooling with stride 2.
    nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2),
    nn.Sigmoid(),
    nn.AvgPool2d(2, stride=2),
    nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
    nn.Sigmoid(),
    nn.AvgPool2d(2, stride=2),
    # Classifier: flatten the 16x5x5 feature maps (400 values) and apply
    # three fully connected layers ending in 10 outputs.
    nn.Flatten(),
    nn.Linear(in_features=400, out_features=120),
    nn.Sigmoid(),
    nn.Linear(in_features=120, out_features=84),
    nn.Sigmoid(),
    nn.Linear(in_features=84, out_features=10),
)
net = net.to(device)
net

Sequential(
  (0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (1): Sigmoid()
  (2): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (3): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (4): Sigmoid()
  (5): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (6): Flatten(start_dim=1, end_dim=-1)
  (7): Linear(in_features=400, out_features=120, bias=True)
  (8): Sigmoid()
  (9): Linear(in_features=120, out_features=84, bias=True)
  (10): Sigmoid()
  (11): Linear(in_features=84, out_features=10, bias=True)
)

In [7]:
# Push one random 1x1x28x28 input through the network layer by layer and
# print the output shape after every layer.
X = torch.rand(1, 1, 28, 28, dtype=torch.float32).to(device)
print(f'input shape:   {X.shape}')
for layer in net:
    X = layer(X)
    print(f'{layer.__class__.__name__: <15}{X.shape}')

input shape:   torch.Size([1, 1, 28, 28])
Conv2d         torch.Size([1, 6, 28, 28])
Sigmoid        torch.Size([1, 6, 28, 28])
AvgPool2d      torch.Size([1, 6, 14, 14])
Conv2d         torch.Size([1, 16, 10, 10])
Sigmoid        torch.Size([1, 16, 10, 10])
AvgPool2d      torch.Size([1, 16, 5, 5])
Flatten        torch.Size([1, 400])
Linear         torch.Size([1, 120])
Sigmoid        torch.Size([1, 120])
Linear         torch.Size([1, 84])
Sigmoid        torch.Size([1, 84])
Linear         torch.Size([1, 10])


In [8]:
def init_weights(m):
    """Apply Xavier-uniform initialisation to a Linear/Conv2d weight.

    Written as a callback for ``Module.apply``. Only modules whose exact type
    is ``nn.Linear`` or ``nn.Conv2d`` are touched; their biases and every
    other kind of module are left unchanged.

    Args:
        m: The module whose weight is (possibly) re-initialised in place.
    """
    layer_type = type(m)
    if layer_type is nn.Linear or layer_type is nn.Conv2d:
        nn.init.xavier_uniform_(m.weight)


# Loss function: cross-entropy computed on the network outputs.
loss = nn.CrossEntropyLoss()
# Run init_weights on every sub-module of the network.
net.apply(init_weights)
# Plain SGD over all network parameters with learning rate `lr`.
optimizer = torch.optim.SGD(params=net.parameters(), lr=lr)

In [8]:
def train_loop(train_iter, net, loss, optimizer):
    """Train ``net`` for one pass over ``train_iter``.

    Batches are moved to the module-level ``device`` before the forward pass.

    Args:
        train_iter: Sized iterable yielding ``(X, y)`` mini-batches.
        net: Model to optimise; it is switched to training mode first.
        loss: Callable ``loss(y_hat, y)`` returning a scalar tensor.
        optimizer: Optimizer that updates the parameters of ``net``.

    Returns:
        The sum of the per-batch losses divided by the number of batches.

    Raises:
        ZeroDivisionError: If ``train_iter`` contains no batches.
    """
    # Number of mini-batches in one epoch.
    num_batchs = len(train_iter)
    # Running sum of the per-batch losses.
    total_train_loss = 0.0
    # An earlier evaluation may have left the model in eval mode; training
    # must run in training mode (matters for dropout / batch-norm layers).
    net.train()
    for batch, (X, y) in enumerate(train_iter):
        # Move the batch to the target device.
        X, y = X.to(device), y.to(device)
        # Forward pass for this batch.
        y_hat = net(X)

        train_loss = loss(y_hat, y)
        total_train_loss += train_loss.item()

        # Backpropagation: clear old gradients, compute new ones, update.
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()

        # Progress indicator, rewritten in place via the carriage return.
        print(f"\r[{batch+1:>8d}/{num_batchs:>8d}]  ", end='')

    return total_train_loss / num_batchs

In [9]:
# --------- training: one train_loop call per epoch, then report its loss
for epoch in range(num_epochs):
    total_train_loss = train_loop(
        train_iter=train_iter, net=net, loss=loss, optimizer=optimizer)
    print(f'epoch {epoch + 1}, total_train_loss {total_train_loss:f}')

[     469/     469]  epoch 1, total_train_loss 2.308835
[     469/     469]  epoch 2, total_train_loss 2.305790
[     469/     469]  epoch 3, total_train_loss 2.294883
[     469/     469]  epoch 4, total_train_loss 1.825610
[     469/     469]  epoch 5, total_train_loss 1.135702
[     469/     469]  epoch 6, total_train_loss 0.955615
[     469/     469]  epoch 7, total_train_loss 0.871286
[     469/     469]  epoch 8, total_train_loss 0.819198
[     469/     469]  epoch 9, total_train_loss 0.772197
[     469/     469]  epoch 10, total_train_loss 0.724309
[     469/     469]  epoch 11, total_train_loss 0.685572
[     469/     469]  epoch 12, total_train_loss 0.655792
[     469/     469]  epoch 13, total_train_loss 0.634837
[     469/     469]  epoch 14, total_train_loss 0.617224
[     469/     469]  epoch 15, total_train_loss 0.603872
[     469/     469]  epoch 16, total_train_loss 0.591482
[     469/     469]  epoch 17, total_train_loss 0.579464
[     469/     469]  epoch 18, total_tra

In [11]:
# ---------- evaluation
def test_net(test_iter, net, loss):
    """Evaluate ``net`` on ``test_iter`` and print accuracy and average loss.

    Accuracy is the number of correctly classified samples divided by the
    total number of samples, so a shorter final batch is weighted by its
    actual size. The reported loss is the mean of the per-batch losses.
    Batches are moved to the module-level ``device`` before the forward pass.

    Args:
        test_iter: Sized iterable yielding ``(X, y)`` mini-batches.
        net: Model to evaluate; it is switched to evaluation mode.
        loss: Callable ``loss(y_hat, y)`` returning a scalar tensor.

    Raises:
        ZeroDivisionError: If ``test_iter`` contains no batches.
    """
    # Number of mini-batches.
    num_batchs = len(test_iter)
    # Running sums: per-batch losses, correct predictions, samples seen.
    total_test_loss, total_correct, num_samples = 0.0, 0, 0
    # Evaluation mode.
    net.eval()
    # No gradients are needed for evaluation.
    with torch.no_grad():
        for batch, (X, y) in enumerate(test_iter):
            # Move the batch to the target device.
            X, y = X.to(device), y.to(device)
            y_hat = net(X)

            total_test_loss += loss(y_hat, y).item()
            # Predicted class = index of the largest output along dim 1.
            total_correct += (y_hat.argmax(1) == y).sum().item()
            num_samples += len(X)

            # Progress indicator.
            print(f"\r[{batch+1:>8d}/{num_batchs:>8d}]  ", end='')

    total_test_loss /= num_batchs
    accuracy = total_correct / num_samples
    print(
        f"\nTest: Accuracy: {accuracy:.1%}, Avg loss: {total_test_loss:f}")


test_net(test_iter, net, loss)


[       1/      79]  
[       2/      79]  
[       3/      79]  
[       4/      79]  
[       5/      79]  
[       6/      79]  
[       7/      79]  
[       8/      79]  
[       9/      79]  
[      10/      79]  
[      11/      79]  
[      12/      79]  
[      13/      79]  
[      14/      79]  
[      15/      79]  
[      16/      79]  
[      17/      79]  
[      18/      79]  
[      19/      79]  
[      20/      79]  
[      21/      79]  
[      22/      79]  
[      23/      79]  
[      24/      79]  
[      25/      79]  
[      26/      79]  
[      27/      79]  
[      28/      79]  
[      29/      79]  
[      30/      79]  
[      31/      79]  
[      32/      79]  
[      33/      79]  
[      34/      79]  
[      35/      79]  
[      36/      79]  
[      37/      79]  
[      38/      79]  
[      39/      79]  
[      40/      79]  
[      41/      79]  
[      42/      79]  
[      43/      79]  
[      44