In [1]:
import os
import struct
import numpy as np
import cv2
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression

In [2]:
# MNISTのファイル (あらかじめダウンロードしておく)
train_image_file = 'mnist/train-images-idx3-ubyte'
train_label_file = 'mnist/train-labels-idx1-ubyte'
test_image_file = 'mnist/t10k-images-idx3-ubyte'
test_label_file = 'mnist/t10k-labels-idx1-ubyte'

In [3]:
def load_images(filename):
    """ MNISTの画像データを読み込む """

    fp = open(filename, 'rb')
    
    # マジックナンバー
    magic = struct.unpack('>i', fp.read(4))[0]
    if magic != 2051:
        raise Exception('Invalid MNIST file!')
        
    # 各種サイズ
    n_images, height, width = struct.unpack('>iii', fp.read(4 * 3))
    
    # 画像の読み込み
    total_pixels = n_images * height * width
    images = struct.unpack('>' + 'B' * total_pixels, fp.read(total_pixels))
    
    images = np.asarray(images, dtype='uint8')
    images = images.reshape((n_images, height, width, 1))
    
    # 値の範囲を[0, 1]に変更する
    images = images.astype('float32') / 255.0
    
    fp.close()
    
    return images

In [4]:
def load_labels(filename):
    """ MNISTのラベルデータを読み込む """

    fp = open(filename, 'rb')
    
    # マジックナンバー
    magic = struct.unpack('>i', fp.read(4))[0]
    if magic != 2049:
        raise Exception('Invalid MNIST file!')
        
    # 各種サイズ
    n_labels = struct.unpack('>i', fp.read(4))[0]
    
    # ラベルの読み込み
    labels = struct.unpack('>' + 'B' * n_labels, fp.read(n_labels))
    labels = np.asarray(labels, dtype='int32')
    
    fp.close()
    
    return labels

In [5]:
def to_onehot(labels):
    """ one-hot形式への変換 """
    return np.identity(10)[labels]

In [6]:
images = load_images(train_image_file)
labels = load_labels(train_label_file)
onehot = to_onehot(labels)

In [31]:
num_data = len(images)
X = images.reshape((num_data, -1))
y = onehot.reshape((num_data, -1))

In [33]:
def softmax(x, axis=-1):
    ex = np.exp(x)
    return ex / np.sum(ex, axis=axis, keepdims=True)

In [85]:
def dLdy(y_pred, y_real):
    return -y_real / y_pred

In [70]:
def dydt(y, t):
    return np.diag(y) - np.einsum('i,j->ij', y, y)

In [89]:
def dtdA(A, x):
    rows, cols = A.shape
    delta = np.identity(rows)
    return np.einsum('ik,j->ijk', delta, x)

In [97]:
def dtdb(b):
    dims = b.shape[-1]
    return np.ones((dims, dims))

In [107]:
batch_size = 128

in_features = X.shape[-1]
out_features = y.shape[-1]
AA = np.random.normal(size=(out_features, in_features))
bb = np.random.normal(size=(out_features))

for epoch in range(4):
    indices = np.random.permutation(np.arange(num_data))
    for b in range(0, num_data, batch_size):
        if b + batch_size > num_data:
            break
        
        xs = X[indices[b:b+batch_size], :]
        ys = y[indices[b:b+batch_size], :]

        loss = 0.0
        acc = 0.0
        grad_AA = np.zeros_like(AA)
        grad_bb = np.zeros_like(bb)
        
        for x, y_real in zip(xs, ys):
            t = np.dot(AA, x) + bb
            y_pred = softmax(t)
            L = np.sum(-y_real * np.log(y_pred))
            
            d1 = dLdy(y_pred, y_real)
            d2 = dydt(y_pred, t)
            d3 = dtdA(AA, x)
            d4 = dtdb(bb)
            grad = np.dot(d2, d1)
            ga = np.dot(d3, grad)
            gb = np.dot(d4, grad)
            
            y_pred_id = np.argmax(y_pred)
            y_real_id = np.argmax(y_real)
            acc += 1.0 if y_pred_id == y_real_id else 0.0
            
            loss += L
            grad_AA += ga
            grad_bb += gb
            
        loss /= batch_size
        acc /= batch_size
        grad_AA /= batch_size
        grad_bb /= batch_size
        
        AA -= 0.05 * grad_AA
        bb -= 0.05 * grad_bb
        
        print('loss={:.6f}, acc={:.6f}'.format(loss, acc))

loss=13.465839, acc=0.070312
loss=13.312373, acc=0.117188
loss=13.544442, acc=0.101562
loss=12.650878, acc=0.140625
loss=13.582631, acc=0.031250
loss=12.323411, acc=0.078125
loss=12.020447, acc=0.117188
loss=12.504057, acc=0.093750
loss=11.457681, acc=0.085938
loss=12.023361, acc=0.054688
loss=12.212145, acc=0.046875
loss=11.712150, acc=0.085938
loss=10.017981, acc=0.078125
loss=11.960637, acc=0.125000
loss=10.926295, acc=0.117188
loss=10.184981, acc=0.070312
loss=10.388682, acc=0.125000
loss=9.932249, acc=0.171875
loss=10.174155, acc=0.085938
loss=10.324137, acc=0.070312
loss=10.665778, acc=0.109375
loss=10.360761, acc=0.109375
loss=10.049808, acc=0.093750
loss=10.508294, acc=0.078125
loss=9.006433, acc=0.109375
loss=9.640614, acc=0.109375
loss=10.322716, acc=0.078125
loss=8.848739, acc=0.101562
loss=9.977149, acc=0.078125
loss=9.355763, acc=0.148438
loss=8.637555, acc=0.148438
loss=8.902130, acc=0.117188
loss=8.158489, acc=0.140625
loss=7.547782, acc=0.156250
loss=8.374686, acc=0.109