## MNIST の手書き文字の数字を分類する

In [1]:
import torch
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms

## 1. 前処理

In [2]:
image_path = "./"
transform = transforms.Compose([transforms.ToTensor()])
mnist_train_dataset = torchvision.datasets.MNIST(
    root=image_path,
    train=True,
    transform=transform,
    download=False,  # switch to True for first time
)
mnist_test_dataset = torchvision.datasets.MNIST(
    root=image_path,
    train=False,
    transform=transform,
    download=False,  # switch to True for first time
)

batch_size = 64
torch.manual_seed(1)
train_dl = DataLoader(mnist_train_dataset, batch_size, shuffle=True)

## 2. モデルの構築

In [3]:
import torch.nn as nn

hidden_units = [32, 16]
image_size = mnist_train_dataset[0][0].shape
input_size = image_size[0] * image_size[1] * image_size[2]
all_layers = [nn.Flatten()]
for hidden_unit in hidden_units:
    layer = nn.Linear(input_size, hidden_unit)
    all_layers.append(layer)
    all_layers.append(nn.ReLU())
    input_size = hidden_unit

all_layers.append(nn.Linear(hidden_units[-1], 10))
all_layers.append(nn.Softmax(dim=1))
model = nn.Sequential(*all_layers)
model

Sequential(
  (0): Flatten(start_dim=1, end_dim=-1)
  (1): Linear(in_features=784, out_features=32, bias=True)
  (2): ReLU()
  (3): Linear(in_features=32, out_features=16, bias=True)
  (4): ReLU()
  (5): Linear(in_features=16, out_features=10, bias=True)
  (6): Softmax(dim=1)
)

In [4]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [5]:
torch.manual_seed(1)
num_epochs = 20
for epoch in range(num_epochs):
    accuracy_hist_train = 0
    for x_batch, y_batch in train_dl:
        pred = model(x_batch)
        loss = loss_fn(pred, y_batch)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        is_correct = (torch.argmax(pred, dim=1) == y_batch).float()
        accuracy_hist_train += is_correct.sum()

    accuracy_hist_train /= len(train_dl.dataset)
    print(f"{epoch=} Accuracy: {accuracy_hist_train:.4f}")

epoch=0 Accuracy: 0.7754
epoch=1 Accuracy: 0.9152
epoch=2 Accuracy: 0.9299
epoch=3 Accuracy: 0.9384
epoch=4 Accuracy: 0.9433
epoch=5 Accuracy: 0.9470
epoch=6 Accuracy: 0.9508
epoch=7 Accuracy: 0.9538
epoch=8 Accuracy: 0.9558
epoch=9 Accuracy: 0.9572
epoch=10 Accuracy: 0.9598
epoch=11 Accuracy: 0.9610
epoch=12 Accuracy: 0.9628
epoch=13 Accuracy: 0.9645
epoch=14 Accuracy: 0.9654
epoch=15 Accuracy: 0.9673
epoch=16 Accuracy: 0.9676
epoch=17 Accuracy: 0.9695
epoch=18 Accuracy: 0.9696
epoch=19 Accuracy: 0.9711


In [6]:
pred = model(mnist_test_dataset.data / 255.0)
is_correct = (torch.argmax(pred, dim=1) == mnist_test_dataset.targets).float()
print(f"Test Accuracy: {is_correct.mean():.4f}")

Test Accuracy: 0.9566
