In [1]:
import sys
from time import perf_counter, time

import matplotlib.pyplot as plt
import numpy as np
import renom as rm
from numpy import random
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.neighbors import KNeighborsClassifier

# Project-local modules live in ../src relative to this notebook.
sys.path.append('../src')
from network import network
from func import Mahalanobis

In [2]:
# Load MNIST and prepare it for the network.
# The file is indexed as data[split][field]: split 0 = train, 1 = test;
# field 0 = labels (y), field 1 = images (x).
# Train and test arrays have different shapes, so the saved array is an object
# array. Object arrays are stored via pickle, and NumPy >= 1.16.3 refuses to
# load them unless allow_pickle=True is passed. Only do this for a trusted,
# locally produced file: unpickling can execute arbitrary code.
data = np.load('../data/mnist.npy', allow_pickle=True)
y_train = data[0][0]
x_train = data[0][1].astype('float32') / 255.  # scale to [0, 1], assuming raw 0..255 pixels
y_test = data[1][0]
x_test = data[1][1].astype('float32') / 255.
# Flatten each image into a 28*28 = 784-dim vector for the fully-connected input.
x_train = x_train.reshape(-1, 28*28)
x_test = x_test.reshape(-1, 28*28)

In [3]:
# Reproducibility: seed numpy's global RNG (`random` is numpy.random here).
random.seed(10)

# Hyperparameters: latent-space dimensionality, training epochs, mini-batch size.
latent_dim, epoch, batch_size = 10, 10, 256
# Gradient-based optimizer; it is passed to both training and inference calls below.
opt = rm.Adam()

# Network definition: input is a batch of flattened 28x28 images.
ae = network(
    (batch_size, 28*28),
    epoch=epoch,
    latent_dim=latent_dim,
)

In [4]:
# Train the network on the training set, with the test set passed alongside
# (presumably for per-epoch validation; TODO confirm in network.train).
# According to the original author's note, learning curves etc. are saved under
# notebook/result/, and latent-space snapshots are skipped when the latent
# dimension is 3 or more.
ae.train(opt, x_train, x_test, y_train, y_test)

#    1/   10 KL:3.049 ReconE:23.938 ETA:0.0sec 0.02,0.07,0.09            
*    2/   10 KL:4.745 ReconE:15.904 ETA:0.0sec 0.02,0.08,0.10            
-----------------------------------
#    3/   10 KL:5.342 ReconE:14.573 ETA:0.0sec 0.02,0.09,0.12            
*    4/   10 KL:5.337 ReconE:13.665 ETA:0.0sec 0.02,0.09,0.11            
-----------------------------------
#    5/   10 KL:5.547 ReconE:13.242 ETA:0.0sec 0.02,0.09,0.11            
*    6/   10 KL:5.315 ReconE:12.738 ETA:0.0sec 0.02,0.10,0.12            
-----------------------------------
#    7/   10 KL:5.431 ReconE:12.520 ETA:0.0sec 0.02,0.10,0.12            
*    8/   10 KL:5.150 ReconE:12.122 ETA:0.0sec 0.02,0.10,0.12            
-----------------------------------
#    9/   10 KL:5.226 ReconE:11.970 ETA:0.0sec 0.02,0.10,0.12            
*   10/   10 KL:4.923 ReconE:11.614 ETA:0.0sec 0.02,0.10,0.12            
-----------------------------------


In [5]:
# Inference on the training data (no training step, inference=True).
# mini_batch returns a 3-tuple. The first element is discarded. z_train holds the
# latent vectors used to fit the Mahalanobis model below. xz_train is presumably
# the reconstruction; TODO confirm against network.mini_batch.
_, z_train, xz_train = ae.mini_batch(opt, x_train, inference=True)

=59904/60000 KL:4.971 ReconE:11.531 ETA:0.0sec 0.02,0.00,0.02            

In [6]:
# Fit the Mahalanobis-distance model: per the original note, it computes
# covariance matrices from the training latent vectors and their labels.
f = Mahalanobis(z_train, y_train)

# Outlier tolerance for labels (original note: "how much to tolerate label
# outliers"). The exact meaning of this value is defined by
# func.Mahalanobis.set_th; TODO confirm, e.g. whether it is a quantile.
outlier_tolerance = 0.9998
f.set_th(outlier_tolerance)

Computing Dist
 6.198883056640625e-06sec


In [7]:
# Inference on the test data, mirroring the training-data inference cell.
# z_test holds the test latent vectors to classify. xz_test is presumably the
# reconstruction; TODO confirm against network.mini_batch.
_, z_test, xz_test = ae.mini_batch(opt, x_test, inference=True)

= 9984/10000 KL:4.998 ReconE:11.284 ETA:0.0sec 0.01,0.00,0.01           

In [8]:
# Classify the test latent vectors with the Mahalanobis model and report metrics.
# f.predict presumably returns one score (distance) per class, shape
# (n_test, n_classes). The predicted label is the column with the smallest
# value, which assumes the columns are ordered as labels 0..K-1. TODO confirm
# against func.Mahalanobis.
# perf_counter is monotonic and high-resolution, unlike wall-clock time(),
# so it is the right clock for measuring elapsed time.
process_t = perf_counter()
pred = np.argmin(f.predict(z_test), 1)
print('{:.2f}sec'.format(perf_counter() - process_t))
print(confusion_matrix(y_test, pred))
print(classification_report(y_test, pred))

7.11sec
[[ 954    0    8    2    0    9    2    1    4    0]
 [   0 1070   13    7    1    2    2    7   32    1]
 [  10    0  970   15    4    7    4    7   15    0]
 [   2    0   18  936    0   22    0    6   21    5]
 [   3    0    8    0  923    4    4    0    3   37]
 [   9    0    3   28    6  829    1    2   11    3]
 [  20    2    7    2    5   20  897    0    5    0]
 [   0    4   39    4    5    5    0  922    5   44]
 [   5    0    9   39    6   18    2    1  884   10]
 [   7    2    3   17   38    7    0   16   11  908]]
             precision    recall  f1-score   support

          0       0.94      0.97      0.96       980
          1       0.99      0.94      0.97      1135
          2       0.90      0.94      0.92      1032
          3       0.89      0.93      0.91      1010
          4       0.93      0.94      0.94       982
          5       0.90      0.93      0.91       892
          6       0.98      0.94      0.96       958
          7       0.96      0.90    

In [9]:
# k-nearest-neighbours baseline on the same latent vectors, for comparison with
# the Mahalanobis classifier. KNeighborsClassifier is imported in the
# top-of-notebook import cell. n_neighbors=5 is scikit-learn's default, made
# explicit so the setting is visible when reading the results.
# Labels are flattened because scikit-learn's fit expects a 1-D target
# (y_train may be stored as a column vector; TODO confirm its shape).
knn = KNeighborsClassifier(n_neighbors=5).fit(z_train, y_train.reshape(-1))

In [10]:
# Classify the test latent vectors with the kNN baseline and report the same
# metrics as the Mahalanobis evaluation, so the two can be compared directly.
# perf_counter is monotonic and high-resolution, unlike wall-clock time(),
# so it is the right clock for measuring elapsed time.
process_t = perf_counter()
pred = knn.predict(z_test)
print('{:.2f}sec'.format(perf_counter() - process_t))
print(confusion_matrix(y_test, pred))
print(classification_report(y_test, pred))

1.92sec
[[ 964    1    1    1    0    1   10    1    1    0]
 [   0 1127    3    2    1    0    1    0    1    0]
 [   6    3  995    6    2    0    3    4   13    0]
 [   1    0   15  949    0   13    0    7   23    2]
 [   1    0    3    0  917    1    8    3    5   44]
 [   5    3    2   17    4  830   10    1   16    4]
 [   5    3    1    0    4    4  936    0    5    0]
 [   2   14   14    1    5    0    0  964    1   27]
 [   4    0    6   26    1   16    1    2  916    2]
 [   3    5    3   12   35    2    1   11   12  925]]
             precision    recall  f1-score   support

          0       0.97      0.98      0.98       980
          1       0.97      0.99      0.98      1135
          2       0.95      0.96      0.96      1032
          3       0.94      0.94      0.94      1010
          4       0.95      0.93      0.94       982
          5       0.96      0.93      0.94       892
          6       0.96      0.98      0.97       958
          7       0.97      0.94    