# CNNによる手書き文字認識

In [3]:
import os
import sys
import struct
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

In [4]:
# MNISTのファイル (あらかじめダウンロードしておく)
train_image_file = 'mnist/train-images-idx3-ubyte'
train_label_file = 'mnist/train-labels-idx1-ubyte'
test_image_file = 'mnist/t10k-images-idx3-ubyte'
test_label_file = 'mnist/t10k-labels-idx1-ubyte'

## データの読み込み

In [5]:
class MnistDataset(torch.utils.data.Dataset):
    def __init__(self, root_dir, mode='train'):
        super(MnistDataset, self).__init__()

        self.root_dir = root_dir
        self.mode = mode

        if self.mode == 'train':
            self.image_file = 'train-images-idx3-ubyte'
            self.label_file = 'train-labels-idx1-ubyte'
        elif self.mode == 'test':
            self.image_file = 't10k-images-idx3-ubyte'
            self.label_file = 't10k-labels-idx1-ubyte'
        else:
            raise Exception('MNIST dataset mode must be "train" or "test"')
        
        self.image_data = self._load_images(os.path.join(self.root_dir, self.image_file))
        self.label_data = self._load_labels(os.path.join(self.root_dir, self.label_file))

    def __len__(self):
        return len(self.image_data)

    def __getitem__(self, idx):
        return {
            'images': self.image_data[idx],
            'labels': self.label_data[idx]
        }

    def _load_images(self, filename):
        with open(filename, 'rb') as fp:
            magic = struct.unpack('>i', fp.read(4))[0]
            if magic != 2051:
                raise Exception('Magic number does not match!')

            n_images, height, width = struct.unpack('>iii', fp.read(4 * 3))

            n_pixels = n_images * height * width
            pixels = struct.unpack('>' + 'B' * n_pixels, fp.read(n_pixels))
            pixels = np.asarray(pixels, dtype='uint8').reshape((n_images, 1, height, width))

            # 画像サイズを2べきにしておく
            pixels = np.pad(pixels, [(0, 0), (0, 0), (2, 2), (2, 2)], mode='constant', constant_values=0)
            pixels = (pixels / 255.0).astype('float32')

        return pixels

    def _load_labels(self, filename):
        with open(filename, 'rb') as fp:
            magic = struct.unpack('>i', fp.read(4))[0]
            if magic != 2049:
                raise Exception('Magic number does not match!')

            n_labels = struct.unpack('>i', fp.read(4))[0]
            labels = struct.unpack('>' + 'B' * n_labels, fp.read(n_labels))
            labels = np.asarray(labels, dtype='int64')

        return labels

In [6]:
def to_onehot(labels):
    """ one-hot形式への変換 """
    return np.identity(10, dtype='float32')[labels]

## CNNによる画像認識

In [24]:
class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()

        net0 = nn.Sequential(
            nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=0),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Sigmoid(),
            nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Sigmoid()
        )

        net1 = nn.Sequential(
            nn.Linear(5 * 5 * 16, 120),
            nn.Sigmoid(),
            nn.Linear(120, 84),
            nn.Sigmoid(),
            nn.Linear(84, 10)
        )
        
        self.nets = nn.ModuleList([net0, net1])

    def forward(self, x):
        n_batches, _, _, _ = x.size()
        x = self.nets[0](x)
        x = x.view(n_batches, -1)
        y = self.nets[1](x)
        return torch.log_softmax(y, dim=1)

In [26]:
class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        
        self.net = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, padding=0, stride=1),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            nn.Conv2d(16, 32, kernel_size=5, padding=0, stride=1),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True)
        )
        
        self.fc = nn.Sequential(
            nn.Linear(800, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 10)
        )
        
    def forward(self, x):
        num_batch = x.size()[0]
        h = self.net(x)
        h = h.view(num_batch, -1)
        y = self.fc(h)
        
        return torch.log_softmax(y, dim=1)

In [35]:
# パラメータ
num_epochs = 1
num_batch = 128

In [36]:
# データセットの用意
train_dataset = MnistDataset(root_dir='./mnist', mode='train')
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=num_batch, shuffle=True, drop_last=True)

In [37]:
# デバイス
if torch.cuda.is_available():
    device = torch.device('cuda', 0)
else:
    device = torch.device('cpu')
print(device)

cpu


In [42]:
# ネットワークの構築
#net = LeNet5()
net = Network()
net.to(device)
optim = torch.optim.Adam(net.parameters(), lr=1.0e-3)
criterion = nn.NLLLoss()

In [43]:
steps = 0
for epoch in range(num_epochs):
    status_bar = tqdm(train_loader, file=sys.stdout)
    for data in status_bar:
            
        X_real = data['images'].to(device)
        y_real = data['labels'].to(device)
                
        net.train()
        
        y_pred = net(X_real)
        loss = criterion(y_pred, y_real)
        acc = (y_pred.argmax(dim=1) == y_real).float().mean()
        status_bar.set_description('epoch: {}, step:{}, loss={:.6f}, acc={:.6f}'.format(epoch, steps, loss.item(), acc.item()))
        
        net.zero_grad()
        loss.backward()
        optim.step()
        
        steps += 1

epoch: 0, step:467, loss=0.028063, acc=0.992188: 100%|██████████| 468/468 [00:29<00:00, 16.03it/s]


## テストデータを用いた精度計算

In [44]:
test_dataset = MnistDataset(root_dir='./mnist', mode='test')
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=num_batch, shuffle=False, drop_last=False)

In [45]:
avg_loss = 0.0
avg_acc = 0.0
count = 0

for data in tqdm(train_loader, file=sys.stdout):

    X_real = data['images'].to(device)
    y_real = data['labels'].to(device)

    net.eval()

    y_pred = net(X_real)
    loss = criterion(y_pred, y_real)
    acc = (y_pred.argmax(dim=1) == y_real).float().mean()

    avg_loss += loss.item()
    avg_acc += acc.item()
    count += 1
    
avg_loss /= count
avg_acc /= count

sys.stdout.write('loss={:.6f}, acc={:.6f}'.format(avg_loss, avg_acc))

100%|██████████| 468/468 [00:07<00:00, 64.31it/s]
loss=0.031780, acc=0.991019

## モデルの保存

In [None]:
torch.save(net.state_dict(), 'model.pth')