# オートエンコーダで異常検知をしてみよう！

In [0]:
from keras.models import Model
from keras.layers import Input, Conv2D
from keras.layers import Conv2DTranspose, LeakyReLU
from keras.layers.core import Activation
from keras.optimizers import Adam
from keras import backend as K
import numpy as np
import matplotlib.pyplot as plt
from keras.datasets import cifar10, fashion_mnist

def data_choice(x, label, normal_label=0, anomaly_label=2):
    """Split samples into a "normal" set and an "anomaly" set by class label.

    Args:
        x: Samples, indexed along the first axis.
        label: One label per sample, either 1-D or a column vector of
            shape (N, 1).
        normal_label: Label value treated as normal (default 0, noted as
            "short sleeves" by the original author).
        anomaly_label: Label value treated as anomalous (default 2, noted
            as "long sleeves" by the original author).

    Returns:
        Tuple ``(normal, anomaly)`` of arrays holding the matching samples
        in their original order. Samples with any other label are dropped.
    """
    x = np.asarray(x)
    # Flatten so that column-vector labels of shape (N, 1) give a 1-D mask.
    label = np.asarray(label).reshape(-1)
    # Boolean masking replaces the per-sample Python loop and, unlike
    # np.array([]), keeps the per-sample dimensions when nothing matches.
    return x[label == normal_label], x[label == anomaly_label]

def get_data():
    """Load Fashion-MNIST and return the normal / anomaly image subsets.

    Each image array gets a trailing channel axis, is cast to float32 and is
    divided by 255 before being filtered with ``data_choice``.

    Returns:
        Tuple ``(x_train_normal, x_test_normal, x_test_anomaly)``: the normal
        samples of the training split, then the normal and the anomalous
        samples of the test split. Anomalous training samples are discarded.
    """
    (train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

    def _prepare(images):
        # Append the channel axis, then rescale as float32.
        return np.expand_dims(images, axis=-1).astype('float32') / 255

    train_images = _prepare(train_images)
    test_images = _prepare(test_images)

    # Only the normal class of the training split is kept.
    x_train_normal, _ = data_choice(train_images, train_labels)
    x_test_normal, x_test_anomaly = data_choice(test_images, test_labels)

    return x_train_normal, x_test_normal, x_test_anomaly

def plot_fig(fig1, fig2, fig3, anomaly_detection=False):
    """Show three grayscale images side by side in a single figure.

    Args:
        fig1: Image drawn in the left panel.
        fig2: Image drawn in the middle panel.
        fig3: Image drawn in the right panel; its absolute value is shown.
        anomaly_detection: When True the panels are titled as original /
            reconstruction / difference and the difference title carries a
            score; when False they are titled as train normal / test normal /
            test anomaly samples.
    """
    if anomaly_detection:
        # Score = sum of absolute per-pixel differences, matching both the
        # image shown in the third panel and the anomaly score used in the
        # evaluation section. Summing the signed differences first would let
        # positive and negative errors cancel out.
        score = np.sum(np.abs(fig3))
        titles = (
            "Original data",
            "Reconstruction data",
            "difference data(score=%d)" % score,
        )
    else:
        titles = ("Train normal data", "Test normal data", "Test anomaly data")

    images = (fig1, fig2, np.abs(fig3))

    plt.figure(figsize=(11,45))
    for position, (image, title) in enumerate(zip(images, titles), start=1):
        plt.subplot(1, 3, position)
        plt.imshow(image, cmap="gray")
        plt.title(title)
    plt.show()

In [0]:
# Load the data once: normal training images plus normal and anomalous test images.
x_train_normal, x_test_normal, x_test_anomaly = get_data()

In [0]:
# Preview the first image (channel 0) of each of the three subsets.
plot_fig(x_train_normal[0,:,:,0], x_test_normal[0,:,:,0], x_test_anomaly[0,:,:,0])

## オートエンコーダ

In [0]:
def build_model(x, channel=1):
    """Build a small convolutional autoencoder.

    Args:
        x: Array of samples; only ``x.shape[1:]`` is read, to fix the input
            shape of the model.
        channel: Number of filters of the last transposed convolution, i.e.
            the number of channels of the reconstruction.

    Returns:
        An uncompiled Keras ``Model`` mapping an input to its reconstruction.
    """
    inputs = Input(x.shape[1:])

    # Encoder: two stride-2 convolutions (14x14 then 7x7 for a 28x28 input).
    encoded = Conv2D(32, (5, 5), padding='same', strides=(2,2))(inputs)
    encoded = Activation('relu')(encoded)
    encoded = Conv2D(64, (5, 5), padding='same', strides=(2,2))(encoded)
    encoded = Activation('relu')(encoded)

    # Decoder: two stride-2 transposed convolutions (back to 14x14, then 28x28).
    decoded = Conv2DTranspose(32, (5, 5), padding='same', strides=(2,2))(encoded)
    decoded = Activation('relu')(decoded)
    decoded = Conv2DTranspose(channel, (5, 5), padding='same', strides=(2,2))(decoded)
    # Sigmoid on the output layer of the reconstruction.
    decoded = Activation('sigmoid')(decoded)

    return Model(inputs, decoded)

In [0]:
# Build the autoencoder for the shape of the training images and print its layers.
model = build_model(x_train_normal, channel=1)
model.summary()

参考（転置畳み込み Transposed Convolution の解説）: https://medium.com/apache-mxnet/transposed-convolutions-explained-with-ms-excel-52d13030c7e8

In [0]:
# Mean-squared reconstruction error, optimised with Adam.
model.compile(loss="mse", optimizer=Adam(lr=0.0002, beta_1=0.5))

# Autoencoder training: the target of every sample is the sample itself.
# The normal test images serve as validation data.
hist = model.fit(
    x_train_normal,
    x_train_normal,
    batch_size=128,
    epochs=40,
    verbose=True,
    validation_data=(x_test_normal, x_test_normal),
)

In [0]:
# Reconstruct a single normal test image; predict() expects a batch axis in front.
test = x_test_normal[2]
predict = model.predict(test[np.newaxis])

# Original, reconstruction and their per-pixel difference (channel 0 only).
plot_fig(
    test[:,:,0],
    predict[0,:,:,0],
    test[:,:,0] - predict[0,:,:,0],
    True,
)

In [0]:
# Reconstruct a single anomalous test image; predict() expects a batch axis in front.
test = x_test_anomaly[18]
predict = model.predict(test[np.newaxis])

# Original, reconstruction and their per-pixel difference (channel 0 only).
plot_fig(
    test[:,:,0],
    predict[0,:,:,0],
    test[:,:,0] - predict[0,:,:,0],
    True,
)

## 異常検知性能の評価

In [0]:
# Anomaly score of a sample = sum of absolute reconstruction errors over all
# of its pixels (each sample is flattened before summing).
normal_score = (
    np.abs(model.predict(x_test_normal) - x_test_normal)
    .reshape(len(x_test_normal), -1)
    .sum(axis=1)
)
anomaly_score = (
    np.abs(model.predict(x_test_anomaly) - x_test_anomaly)
    .reshape(len(x_test_anomaly), -1)
    .sum(axis=1)
)

# Plot the per-sample scores of both groups on the same axes.
plt.figure(figsize=(12,6))
for scores, name in ((normal_score, "normal"), (anomaly_score, "anomaly")):
    plt.plot(scores, label=name)
plt.legend()
plt.show()

## AUC

In [0]:
from sklearn import metrics

def get_auc(test_normal, test_anomaly):
    """Plot the ROC curve of the anomaly scores and return its AUC.

    Args:
        test_normal: 1-D sequence of scores of the normal samples (class 0).
        test_anomaly: 1-D sequence of scores of the anomalous samples
            (class 1).

    Returns:
        The area under the ROC curve. The original version computed this
        value but never handed it back to the caller.
    """
    # Ground truth in the same order as the stacked scores: 0 = normal, 1 = anomaly.
    y_true = np.concatenate(
        (np.zeros(len(test_normal)), np.ones(len(test_anomaly)))
    )
    scores = np.hstack((test_normal, test_anomaly))

    # FPR and TPR at every threshold (the thresholds themselves are unused).
    fpr, tpr, _ = metrics.roc_curve(y_true, scores)

    # Area under the ROC curve.
    auc = metrics.auc(fpr, tpr)

    # Draw the ROC curve.
    plt.plot(fpr, tpr, label='metric learning(AUC = %.2f)'%auc)
    plt.legend()
    plt.title('ROC curve')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.grid(True)
    plt.show()

    return auc