In [1]:
# MNIST Handwritten Digits Classification
# Standard feedfoward neural network

In [2]:
import numpy as np
import pandas as pd
from scipy import linalg
from sklearn.preprocessing import OneHotEncoder
import matplotlib.pyplot as plt

np.set_printoptions(threshold=100)

In [3]:
# データの読み込み（MNIST手書き文字）
train_df = pd.read_csv('data/mnist-in-csv/mnist_train.csv', sep=',')
test_df = pd.read_csv('data/mnist-in-csv/mnist_test.csv', sep=',')

train_data = train_df.iloc[:,1:].to_numpy(dtype='float')
train_target = train_df.iloc[:,0].to_numpy(dtype='int')
train_target = OneHotEncoder(sparse=False).fit_transform(train_target.reshape(-1, 1))

test_data = test_df.iloc[:,1:].to_numpy(dtype='float')
test_target = test_df.iloc[:,0].to_numpy(dtype='int')
test_target = OneHotEncoder(sparse=False).fit_transform(test_target.reshape(-1, 1))

In case you used a LabelEncoder before this OneHotEncoder to convert the categories to integers, then you can now use the OneHotEncoder directly.
In case you used a LabelEncoder before this OneHotEncoder to convert the categories to integers, then you can now use the OneHotEncoder directly.


In [4]:
# 重要パラメータの定義
# データの性質関連
TRAIN_SIZE = 2000        # 訓練データ数
TEST_SIZE = 2000         # テストデータ数
NUM_INPUT_NODES = 784    # 入力の次元数
NUM_OUTPUT_NODES = 10    # 出力の次元数

LEARNING_RATE = 0.01

In [5]:
train_data = train_data[:TRAIN_SIZE]
test_data = test_data[:TEST_SIZE]
train_target = train_target[:TRAIN_SIZE]
test_target = test_target[:TEST_SIZE]

In [6]:
train_data.shape

(2000, 784)

In [7]:
train_target.shape

(2000, 10)

In [8]:
test_data.shape

(2000, 784)

In [9]:
test_target.shape

(2000, 10)

In [10]:
class ReLU:
    def forward(self, u): 
        self.u = u
        return np.maximum(0, u)
    
    def backward(self, dout):
        return dout * (self.u > 0).astype(float)
    
    def update_weights(self, lr=0.1):
        pass


class LinearLayer:
    def __init__(self, I, O):
        self.I = I
        self.O = O
        self.W = np.random.randn(I,O) / np.sqrt(I)   # 重みの初期化
        self.b = np.zeros(O)
        self.grad_W = np.zeros((I,O))
        self.grad_b = np.zeros(O)
        
    def forward(self, x):
        self.x = x
        u = x @ self.W + self.b
        return u
    
    def backward(self, dout):   # dout shape: (O)
        din = self.W @ dout   # shape: (I)
        self.grad_W = self.x.reshape(self.I,1) @ dout.reshape(1, self.O)   # shape: (I, O)
        self.grad_b = dout    # shape: (O)
        return din
    
    def update_weights(self, lr=0.1):
        self.W = self.W - lr * self.grad_W
        self.b = self.b - lr * self.grad_b
        
    
class Softmax_CrossEntropy:   # uとtの次元数が一致する前提
    def forward(self, u):
        self.y = np.exp(u)
        self.y = self.y / np.sum(self.y)
        return self.y
    
    def calculate_error(self, t):
        self.t = t
        error = -np.sum(t * np.log(self.y))
        return error
    
    def backward(self, dout=1.0):
        return dout*(self.y - self.t)
    
    def update_weights(self, lr=0.1):
        pass
        
        

In [11]:
# データの正規化
def normalize(data):
    mean = np.mean(data, axis=1).reshape(-1,1)
    var = np.var(data, axis=1).reshape(-1,1)
    return (data - mean) / (np.sqrt(var) + 1e-6)

train_data = normalize(train_data)
test_data = normalize(test_data)


In [12]:
# ネットワークの定義

l1 = LinearLayer(784, 100)
f1 = ReLU()
l2 = LinearLayer(100, 10)
out = Softmax_CrossEntropy()

layers = [l1, f1, l2, out]


# 訓練

for epoch in range(1000):
    error = 0
    
    for i in range(TRAIN_SIZE):
        x = train_data[i]
        t = train_target[i]
        dout = 1.0

        for layer in layers:
            x = layer.forward(x)
        
        error += layers[-1].calculate_error(t)
        
        for layer in reversed(layers):
            dout = layer.backward(dout)

        for layer in layers:
            layer.update_weights(LEARNING_RATE)
    
    if epoch % 1 == 0:
        print("Epoch no. {}, error is {}".format(epoch, error))
    

Epoch no. 0, error is 1180.5444771623847
Epoch no. 1, error is 483.2379750210976
Epoch no. 2, error is 198.40040775992495
Epoch no. 3, error is 71.37271489491705
Epoch no. 4, error is 31.82124579749911
Epoch no. 5, error is 12.972716444845748
Epoch no. 6, error is 7.525744694883897
Epoch no. 7, error is 5.571671694762204
Epoch no. 8, error is 4.594994490480153
Epoch no. 9, error is 3.926072285855946
Epoch no. 10, error is 3.4469191890916586
Epoch no. 11, error is 3.079100338475735
Epoch no. 12, error is 2.7742183419617943
Epoch no. 13, error is 2.537750453491975
Epoch no. 14, error is 2.330996406786321
Epoch no. 15, error is 2.1591007067906296
Epoch no. 16, error is 2.0121495037318984
Epoch no. 17, error is 1.8816156082703877
Epoch no. 18, error is 1.7667968969416161
Epoch no. 19, error is 1.6649882247578653
Epoch no. 20, error is 1.5787387658631105
Epoch no. 21, error is 1.4947002112037702
Epoch no. 22, error is 1.4234883261952542
Epoch no. 23, error is 1.3560156443275246
Epoch no. 24

Epoch no. 192, error is 0.14338778006317715
Epoch no. 193, error is 0.14257443595858318
Epoch no. 194, error is 0.14180142933499568
Epoch no. 195, error is 0.14100940085138286
Epoch no. 196, error is 0.1402557030034522
Epoch no. 197, error is 0.13948430016346577
Epoch no. 198, error is 0.13872318507094092
Epoch no. 199, error is 0.1379798732635702
Epoch no. 200, error is 0.13723288356920402
Epoch no. 201, error is 0.13651928787855042
Epoch no. 202, error is 0.1357806877645557
Epoch no. 203, error is 0.13507972545171015
Epoch no. 204, error is 0.13435358471031866
Epoch no. 205, error is 0.13365193740467918
Epoch no. 206, error is 0.13297255456074827
Epoch no. 207, error is 0.13225710004616095
Epoch no. 208, error is 0.13158966402288305
Epoch no. 209, error is 0.1309212613944314
Epoch no. 210, error is 0.130238786609704
Epoch no. 211, error is 0.12957947311524978
Epoch no. 212, error is 0.12892172170575994
Epoch no. 213, error is 0.12827509911880544
Epoch no. 214, error is 0.127637659123

Epoch no. 379, error is 0.06892960879602926
Epoch no. 380, error is 0.06873720117794836
Epoch no. 381, error is 0.06853709510633156
Epoch no. 382, error is 0.06834500396305437
Epoch no. 383, error is 0.06815335189470481
Epoch no. 384, error is 0.06796317470247579
Epoch no. 385, error is 0.06777036015764008
Epoch no. 386, error is 0.06758036868164104
Epoch no. 387, error is 0.06738986765841098
Epoch no. 388, error is 0.06720630626664023
Epoch no. 389, error is 0.06701550225408721
Epoch no. 390, error is 0.06683212920175698
Epoch no. 391, error is 0.0666456795206498
Epoch no. 392, error is 0.0664631184459903
Epoch no. 393, error is 0.06627816519319796
Epoch no. 394, error is 0.06609856369642904
Epoch no. 395, error is 0.06591741595867434
Epoch no. 396, error is 0.06573860740596606
Epoch no. 397, error is 0.06555608328361312
Epoch no. 398, error is 0.06538014362965555
Epoch no. 399, error is 0.06520293505242367
Epoch no. 400, error is 0.06502503736763623
Epoch no. 401, error is 0.06485233

Epoch no. 565, error is 0.04474865420720332
Epoch no. 566, error is 0.04466216844303094
Epoch no. 567, error is 0.04457794940494632
Epoch no. 568, error is 0.044493220399421385
Epoch no. 569, error is 0.04440598215547088
Epoch no. 570, error is 0.04432301442809112
Epoch no. 571, error is 0.044240179578934204
Epoch no. 572, error is 0.044155397455621095
Epoch no. 573, error is 0.04407242041766962
Epoch no. 574, error is 0.04398903999536896
Epoch no. 575, error is 0.04390503589269975
Epoch no. 576, error is 0.043824662052212096
Epoch no. 577, error is 0.043740944928920604
Epoch no. 578, error is 0.043659805367433635
Epoch no. 579, error is 0.043578185785021525
Epoch no. 580, error is 0.043495057914985646
Epoch no. 581, error is 0.043415842133774314
Epoch no. 582, error is 0.04333578033116618
Epoch no. 583, error is 0.043254304903910606
Epoch no. 584, error is 0.043174452649605906
Epoch no. 585, error is 0.04309335601841459
Epoch no. 586, error is 0.04301580079760867
Epoch no. 587, error 

Epoch no. 751, error is 0.03287746710393465
Epoch no. 752, error is 0.032829091797531365
Epoch no. 753, error is 0.03278312491983129
Epoch no. 754, error is 0.03273572487429666
Epoch no. 755, error is 0.03268737151332859
Epoch no. 756, error is 0.03264142274582041
Epoch no. 757, error is 0.03259424542779631
Epoch no. 758, error is 0.03254893188114306
Epoch no. 759, error is 0.03250118117408805
Epoch no. 760, error is 0.032455189565458696
Epoch no. 761, error is 0.03240872527293513
Epoch no. 762, error is 0.03236288135474846
Epoch no. 763, error is 0.032316419728481115
Epoch no. 764, error is 0.03227091864407442
Epoch no. 765, error is 0.032225249896536474
Epoch no. 766, error is 0.03217915764321485
Epoch no. 767, error is 0.032133701217457976
Epoch no. 768, error is 0.03208833542564205
Epoch no. 769, error is 0.032042756790069175
Epoch no. 770, error is 0.03199832218312527
Epoch no. 771, error is 0.03195349658755716
Epoch no. 772, error is 0.03190695412467727
Epoch no. 773, error is 0.

Epoch no. 935, error is 0.02591349733917673
Epoch no. 936, error is 0.02588282944851183
Epoch no. 937, error is 0.02585337310085296
Epoch no. 938, error is 0.025822379331160966
Epoch no. 939, error is 0.025793930824711358
Epoch no. 940, error is 0.025763426782838793
Epoch no. 941, error is 0.025733188092204245
Epoch no. 942, error is 0.025704478151141637
Epoch no. 943, error is 0.025674229811957636
Epoch no. 944, error is 0.02564450030072799
Epoch no. 945, error is 0.025615809258489212
Epoch no. 946, error is 0.025585826922360583
Epoch no. 947, error is 0.025556639927077327
Epoch no. 948, error is 0.02552756549702419
Epoch no. 949, error is 0.025497697958143536
Epoch no. 950, error is 0.025469172703714727
Epoch no. 951, error is 0.02544004518375947
Epoch no. 952, error is 0.025410669270802964
Epoch no. 953, error is 0.025382266435964303
Epoch no. 954, error is 0.025352870504172474
Epoch no. 955, error is 0.025323964417093628
Epoch no. 956, error is 0.025295466500741354
Epoch no. 957, e

In [15]:
# 検証

correct_number = 0

for i in range(TEST_SIZE):
    x = test_data[i]
    t = test_target[i]

    for layer in layers:
        x = layer.forward(x)

    predict = np.argmax(x)
    correct_value = np.argmax(t)
    
    if predict == correct_value:
        correct_number += 1
        
print("Accuracy: {}".format(correct_number*1.0/TEST_SIZE))

Accuracy: 0.8875
