In [1]:
import os
import sys
import struct
import numpy as np
import cv2
import matplotlib.pyplot as plt

In [2]:
# MNIST data files as raw IDX files (download them beforehand)
train_image_file, train_label_file = (
    'mnist/train-images-idx3-ubyte',
    'mnist/train-labels-idx1-ubyte',
)
test_image_file, test_label_file = (
    'mnist/t10k-images-idx3-ubyte',
    'mnist/t10k-labels-idx1-ubyte',
)

In [3]:
def load_images(filename):
    """Read an MNIST image file (IDX format, magic number 2051).

    Args:
        filename: path to a raw (uncompressed) image file such as
            ``train-images-idx3-ubyte``.

    Returns:
        float32 array of shape (n_images, height, width, 1) whose values are
        the stored bytes scaled from [0, 255] to [0, 1].

    Raises:
        ValueError: if the magic number is not 2051, a header size is
            negative, or the file holds fewer pixel bytes than the header
            declares.
        struct.error: if the file is too short to contain the header.
    """
    # `with` guarantees the file is closed even when a check below raises.
    with open(filename, 'rb') as fp:
        # Magic number (big-endian int32) expected for MNIST image files
        magic = struct.unpack('>i', fp.read(4))[0]
        if magic != 2051:
            raise ValueError('Invalid MNIST file!')

        # Header: number of images, rows, columns (big-endian int32 each)
        n_images, height, width = struct.unpack('>iii', fp.read(4 * 3))
        if n_images < 0 or height < 0 or width < 0:
            raise ValueError('Invalid MNIST file!')

        # Pixel payload: one unsigned byte per pixel
        total_pixels = n_images * height * width
        buf = fp.read(total_pixels)

    if len(buf) < total_pixels:
        raise ValueError(
            'Truncated MNIST image file: expected {} pixel bytes, got {}'.format(
                total_pixels, len(buf)))

    # frombuffer views the bytes directly instead of unpacking them one by one.
    images = np.frombuffer(buf, dtype=np.uint8).reshape((n_images, height, width, 1))

    # Rescale to [0, 1]; astype copies, so the result is writable.
    return images.astype('float32') / 255.0

In [4]:
def load_labels(filename):
    """Read an MNIST label file (IDX format, magic number 2049).

    Args:
        filename: path to a raw (uncompressed) label file such as
            ``train-labels-idx1-ubyte``.

    Returns:
        1-D int32 array with one entry per stored label byte.

    Raises:
        ValueError: if the magic number is not 2049, the label count is
            negative, or the file holds fewer label bytes than the header
            declares.
        struct.error: if the file is too short to contain the header.
    """
    # `with` guarantees the file is closed even when a check below raises.
    with open(filename, 'rb') as fp:
        # Magic number (big-endian int32) expected for MNIST label files
        magic = struct.unpack('>i', fp.read(4))[0]
        if magic != 2049:
            raise ValueError('Invalid MNIST file!')

        # Header: number of labels (big-endian int32)
        n_labels = struct.unpack('>i', fp.read(4))[0]
        if n_labels < 0:
            raise ValueError('Invalid MNIST file!')

        # Label payload: one unsigned byte per label
        buf = fp.read(n_labels)

    if len(buf) < n_labels:
        raise ValueError(
            'Truncated MNIST label file: expected {} label bytes, got {}'.format(
                n_labels, len(buf)))

    # frombuffer views the bytes directly; astype produces a writable int32 copy.
    return np.frombuffer(buf, dtype=np.uint8).astype(np.int32)

In [5]:
def to_onehot(labels, num_classes=10):
    """Convert integer class labels to one-hot float rows.

    Args:
        labels: integer array-like of class ids, each in [0, num_classes).
        num_classes: width of each one-hot row (default 10, i.e. MNIST digits).

    Returns:
        float64 array of shape ``np.shape(labels) + (num_classes,)`` with a
        single 1.0 per row at the label's index.

    Raises:
        IndexError: if any label lies outside [0, num_classes).
    """
    label_arr = np.asarray(labels)
    # Negative ids would silently wrap to the last classes via NumPy indexing.
    if np.any(label_arr < 0) or np.any(label_arr >= num_classes):
        raise IndexError('labels must lie in [0, {})'.format(num_classes))
    # Row i of the identity matrix is the one-hot vector for class i.
    return np.identity(num_classes)[labels]

In [6]:
# Load the training split and encode its labels as one-hot target rows
images, labels = load_images(train_image_file), load_labels(train_label_file)
onehot = to_onehot(labels)

In [7]:
# Flatten each image into one feature row; y keeps one target row per sample.
num_data = len(images)
# Images and labels come from separate files; make sure the counts agree so a
# mismatched pair fails here instead of reshaping into misaligned rows.
assert len(onehot) == num_data, 'image/label count mismatch: {} vs {}'.format(num_data, len(onehot))
X = images.reshape((num_data, -1))
y = onehot.reshape((num_data, -1))

In [8]:
def softmax(x, axis=-1):
    """Numerically stable softmax along ``axis``.

    Args:
        x: array-like of scores.
        axis: axis along which the outputs sum to 1 (default: last axis).

    Returns:
        Array of the same shape as ``x`` whose entries along ``axis`` are
        non-negative and sum to 1.
    """
    x = np.asarray(x)
    if x.size == 0:
        # Nothing to normalise; return an empty float array of the same shape.
        return np.exp(x)
    # softmax is shift-invariant; subtracting the max keeps exp() from
    # overflowing to inf (which would otherwise yield inf/inf = nan).
    ex = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return ex / np.sum(ex, axis=axis, keepdims=True)

def log_softmax(x, axis=-1):
    """Numerically stable log-softmax along ``axis``.

    Computes ``x - logsumexp(x)`` without forming ``exp(x)`` directly.

    Args:
        x: array-like of scores.
        axis: axis along which the normalisation happens (default: last axis).

    Returns:
        Array of the same shape as ``x`` holding log-probabilities along ``axis``.
    """
    x = np.asarray(x)
    if x.size == 0:
        # Nothing to normalise; return an empty float array of the same shape.
        return np.exp(x)
    # Shift by the max so exp() cannot overflow; the shift cancels out of the
    # result because log-softmax is invariant to adding a constant.
    shifted = x - np.max(x, axis=axis, keepdims=True)
    return shifted - np.log(np.sum(np.exp(shifted), axis=axis, keepdims=True))

In [9]:
# Mini-batch size
batch_size = 128
# Number of passes over the training set
num_epochs = 1
# Gradient-descent step size
learning_rate = 0.1

# Parameter initialization: logits are t = AA @ x + bb
in_features = X.shape[-1]
out_features = y.shape[-1]
# NOTE(review): unit-variance init over every input feature gives large initial
# logits; a smaller scale would likely converge faster. Left unchanged here.
AA = np.random.normal(size=(out_features, in_features))
bb = np.random.normal(size=(out_features))


def _batch_loss_and_grads(weight, bias, xs, ys):
    """Mean cross-entropy loss, accuracy and parameter gradients for one mini-batch.

    Args:
        weight: (out_features, in_features) weight matrix.
        bias: (out_features,) bias vector.
        xs: (batch, in_features) inputs.
        ys: (batch, out_features) one-hot targets.

    Returns:
        (loss, acc, grad_weight, grad_bias), each averaged over the batch.

    Derivation: with t = weight @ x + bias, p = softmax(t), L = -sum(y * log p),
    chaining dL/dp = -y / p through the softmax Jacobian diag(p) - p p^T
    collapses to dL/dt = p - y (each one-hot row sums to 1). dt/d(bias) is the
    identity matrix, so dL/d(bias) = dL/dt, and dL/d(weight) = outer(dL/dt, x).
    """
    logits = xs @ weight.T + bias
    # Subtract the row max so exp() cannot overflow (log-sum-exp trick); the
    # resulting log-probabilities are finite, so 0 * log p never produces nan.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    probs = np.exp(log_probs)

    n = xs.shape[0]
    loss = -np.sum(ys * log_probs) / n
    acc = np.mean(np.argmax(probs, axis=1) == np.argmax(ys, axis=1))

    dLdt = (probs - ys) / n
    grad_weight = dLdt.T @ xs
    grad_bias = dLdt.sum(axis=0)
    return loss, acc, grad_weight, grad_bias


# Epochs
for epoch in range(num_epochs):
    # Shuffle each epoch so batches are not biased by the file's ordering
    indices = np.random.permutation(num_data)
    # Only full batches are used; a trailing partial batch is skipped
    for b in range(0, num_data - batch_size + 1, batch_size):
        batch_idx = indices[b:b + batch_size]
        xs = X[batch_idx, :]
        ys = y[batch_idx, :]

        loss, acc, grad_AA, grad_bb = _batch_loss_and_grads(AA, bb, xs, ys)

        # Steepest-descent parameter update
        AA -= learning_rate * grad_AA
        bb -= learning_rate * grad_bb

        # sys.stdout.write (rather than print) so '\r' rewrites the same line
        sys.stdout.write('\repoch:{}, step:{}, loss={:.6f}, acc={:.6f}'.format(epoch, b, loss, acc))
        sys.stdout.flush()

epoch:0, step:59776, loss=1.228238, acc=0.757812