# ロジスティック回帰 -その1-

In [1]:
import os
import sys
import struct
import numpy as np
import cv2
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# MNISTのファイル (あらかじめダウンロードしておく)
train_image_file = 'mnist/train-images-idx3-ubyte'
train_label_file = 'mnist/train-labels-idx1-ubyte'
test_image_file = 'mnist/t10k-images-idx3-ubyte'
test_label_file = 'mnist/t10k-labels-idx1-ubyte'

## データの読み込み

In [3]:
def load_images(filename):
    """ MNISTの画像データを読み込む """

    fp = open(filename, 'rb')
    
    # マジックナンバー
    magic = struct.unpack('>i', fp.read(4))[0]
    if magic != 2051:
        raise Exception('Invalid MNIST file!')
        
    # 各種サイズ
    n_images, height, width = struct.unpack('>iii', fp.read(4 * 3))
    
    # 画像の読み込み
    total_pixels = n_images * height * width
    images = struct.unpack('>' + 'B' * total_pixels, fp.read(total_pixels))
    
    images = np.asarray(images, dtype='uint8')
    images = images.reshape((n_images, height, width, 1))
    
    # 値の範囲を[0, 1]に変更する
    images = images.astype('float32') / 255.0
    
    fp.close()
    
    return images

In [4]:
def load_labels(filename):
    """ MNISTのラベルデータを読み込む """

    fp = open(filename, 'rb')
    
    # マジックナンバー
    magic = struct.unpack('>i', fp.read(4))[0]
    if magic != 2049:
        raise Exception('Invalid MNIST file!')
        
    # 各種サイズ
    n_labels = struct.unpack('>i', fp.read(4))[0]
    
    # ラベルの読み込み
    labels = struct.unpack('>' + 'B' * n_labels, fp.read(n_labels))
    labels = np.asarray(labels, dtype='int32')
    
    fp.close()
    
    return labels

In [5]:
def to_onehot(labels):
    """ one-hot形式への変換 """
    return np.identity(10, dtype='float32')[labels]

In [6]:
images = load_images(train_image_file)
labels = load_labels(train_label_file)
onehot = to_onehot(labels)

## CNNによる画像認識

In [7]:
class Network(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(Network, self).__init__()
        
        self.net = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, padding=2, stride=1),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            nn.Conv2d(16, 32, kernel_size=5, padding=0, stride=1),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True)
        )
        
        self.fc = nn.Sequential(
            nn.Linear(800, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Linear(128, num_classes)
        )
        
    def forward(self, x):
        num_batch = x.size()[0]
        h = self.net(x)
        h = h.view(num_batch, -1)
        y = self.fc(h)
        
        return torch.log_softmax(y, dim=1)

In [8]:
num_data = len(images)
X = images.copy()
y = to_onehot(labels)

In [9]:
net = Network(1, 10)
optim = torch.optim.RMSprop(net.parameters(), lr=1.0e-1)
criterion = nn.NLLLoss()

In [10]:
num_epochs = 5
batch_size = 32

net.train()

for epoch in range(num_epochs):
    indices = np.random.permutation(np.arange(num_data))
    for b in range(0, num_data, batch_size):
        if b + batch_size > num_data:
            break
            
        X_real = X[indices[b:b+batch_size], :, :, :]
        y_real = y[indices[b:b+batch_size], :]

        for i, img in enumerate(X_real):
            h, w, _ = img.shape
            tx = 0.1 * w * np.random.uniform(-1.0, 1.0)
            ty = 0.1 * h * np.random.uniform(-1.0, 1.0)
            scale = 1.0 + 0.05 * np.random.uniform(-1.0, 1.0)
            M = np.float32([[scale, 0, tx], [0, scale, ty]])
            X_real[i] = np.expand_dims(cv2.warpAffine(img, M, (h, w)), axis=-1)
        
        X_real = np.transpose(X_real, axes=(0, 3, 1, 2))
        X_real = torch.from_numpy(X_real.astype('float32'))
        y_real = torch.from_numpy(y_real.astype('float32'))
        
        net.zero_grad()
        
        y_pred = net(X_real)
        loss = criterion(y_pred, y_real.argmax(dim=1))
        acc = (y_pred.argmax(dim=1) == y_real.argmax(dim=1)).float().mean()
        sys.stdout.write('\repoch: {}, step:{}, loss={:.6f}, acc={:.6f}'.format(epoch, b, loss.item(), acc.item()))
        #sys.stdout.flush()
        
        loss.backward()
        optim.step()

epoch: 4, step:59968, loss=0.011144, acc=1.000000

## テストデータを用いた精度計算

In [11]:
test_images = load_images(test_image_file)
test_labels = load_labels(test_label_file)

In [12]:
num_data = len(test_images)
X = np.transpose(test_images, axes=(0, 3, 1, 2))
y = to_onehot(test_labels)

In [13]:
net.eval()

avg_loss = 0.0
avg_acc = 0.0
count = 0

for b in range(0, num_data, batch_size):
    if b + batch_size > num_data:
        break

    X_real = X[b:b+batch_size, :, :, :]
    y_real = y[b:b+batch_size, :]

    X_real = torch.from_numpy(X_real.astype('float32'))
    y_real = torch.from_numpy(y_real.astype('float32'))

    y_pred = net(X_real)
    loss = criterion(y_pred, y_real.argmax(dim=1))
    acc = (y_pred.argmax(dim=1) == y_real.argmax(dim=1)).float().mean()
    sys.stdout.write('\repoch: {}, step:{}, loss={:.6f}, acc={:.6f}'.format(epoch, b, loss.item(), acc.item()))

    avg_loss += loss.item()
    avg_acc += acc.item()
    count += 1
print('')
    
avg_loss /= count
avg_acc /= count

sys.stdout.write('loss={:.6f}, acc={:.6f}'.format(avg_loss, avg_acc))

epoch: 4, step:9952, loss=0.035342, acc=0.968750
loss=0.064793, acc=0.982272

## モデルの保存

In [14]:
torch.save(net.state_dict(), 'model.pth')