[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/persimmon-persimmon/mnist-multilabel/blob/master/train_and_validation.ipynb)

# マウントとインポート

In [None]:
from google.colab import drive
import glob
import os
drive.mount("/content/drive/")
# Use an absolute path so re-running this cell (when the cwd is already the
# project folder) still works without a catch-all try/except; a missing
# project folder now raises instead of being silently ignored.
os.chdir("/content/drive/MyDrive/mnist_multilabel/")

In [41]:
import os
import csv
import random
import csv
import numpy as np
import pandas as pd
from PIL import Image
import tensorflow as tf
from tensorflow.keras.utils import Sequence
from tensorflow.keras import backend as K
from tensorflow.keras import regularizers
from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.layers import Input,Dense,GlobalAveragePooling2D,Activation,Conv2D,BatchNormalization,AveragePooling2D,Flatten,Add
from tensorflow.keras.models import Model,Sequential
from tensorflow.keras.callbacks import History,LearningRateScheduler,ModelCheckpoint,EarlyStopping
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import Adam,SGD
from keras.layers.core import Dropout
from sklearn.metrics import confusion_matrix,roc_curve,auc,roc_auc_score,recall_score,precision_score,accuracy_score,f1_score,precision_recall_curve,auc
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import clip_ops,math_ops
import matplotlib.pyplot as plt

# 学習

## **モデル定義**

In [13]:
def get_multilabel_model(n_class,n_trainable_layer=1,input_shape=(224,224,3))->Model:
    """
    Build a multi-label classification model.

    Architecture: VGG16 (ImageNet weights, include_top=False)
    -> GlobalAveragePooling2D -> Dense(512, relu) -> Dense(n_class, sigmoid).
    Both Dense layers use an L2 kernel regularizer with factor 0.05.

    :param n_class: number of classes to predict (units of the sigmoid output)
    :param n_trainable_layer: number of layers, counted back from the output
        layer, that stay trainable; all other layers are frozen. 0 freezes
        every layer.
    :param input_shape: input image shape
    :return: an uncompiled keras Model
    """
    # Baseline model (before improvements)
    inputs = Input(input_shape)
    vgg16 = VGG16(include_top=False, weights="imagenet", input_tensor=inputs)
    x = vgg16.output
    # Global pooling layer
    x = GlobalAveragePooling2D()(x)
    x = Dense(512, kernel_regularizer=regularizers.l2(0.05), activation="relu")(x)
    # Prediction layer
    output = Dense(n_class, kernel_regularizer=regularizers.l2(0.05), activation="sigmoid")(x)
    model = Model(inputs=inputs, outputs=output)

    # Freeze everything, then unfreeze the last n_trainable_layer layers.
    # NOTE(review): with the default of 1 only the output Dense is trainable, so
    # the randomly initialised Dense(512) stays frozen - confirm this is intended.
    for layer in model.layers:
        layer.trainable = False
    # Guard against 0: model.layers[-0:] is the *entire* list, which would make
    # every layer trainable instead of none.
    if n_trainable_layer > 0:
        for layer in model.layers[-n_trainable_layer:]:
            layer.trainable = True
    return model

## 損失関数、評価関数の定義

In [16]:
# Mini-batch size shared by the train/validation generators and by the
# steps_per_epoch / validation_steps computation passed to model.fit.
batch_size=32

# metrics functions
def total_acc(y_true, y_pred):
    """
    Exact-match accuracy per sample.

    A score counts as a positive prediction when it is >= 0.5. The result is
    1.0 for a sample only when every label matches y_true, otherwise 0.0.
    """
    is_correct = K.equal(y_true, K.cast(K.greater_equal(y_pred, 0.5), "float"))
    return K.prod(K.cast(is_correct, "float"), axis=-1)

def binary_acc(y_true, y_pred):
    """
    Per-sample fraction of labels predicted correctly.

    A score counts as a positive prediction when it is >= 0.5; the mean is
    taken over the last (label) axis.
    """
    is_correct = K.equal(y_true, K.cast(K.greater_equal(y_pred, 0.5), "float"))
    return K.mean(K.cast(is_correct, "float"), axis=-1)

# precision
def precision(y_true, y_pred):
    """
    Micro precision over the whole batch at a 0.5 threshold.

    true positives / predicted positives, summed over all samples and labels;
    K.epsilon() in the denominator avoids division by zero.
    NOTE(review): as a plain metric function this is a per-batch value that
    Keras presumably averages across batches - not an epoch-level precision.
    """
    predicted_pos = K.greater(K.clip(y_pred, 0, 1), 0.50)
    hits = K.greater(K.clip(y_true * y_pred, 0, 1), 0.5)
    n_hits = K.sum(K.cast(hits, "float32"))
    n_predicted_pos = K.sum(K.cast(predicted_pos, "float32"))
    return n_hits / (n_predicted_pos + K.epsilon())

# recall
def recall(y_true, y_pred):
    """
    Micro recall over the whole batch at a 0.5 threshold.

    true positives / actual positives, summed over all samples and labels;
    K.epsilon() in the denominator avoids division by zero.
    NOTE(review): as a plain metric function this is a per-batch value that
    Keras presumably averages across batches - not an epoch-level recall.
    """
    actual_pos = K.greater(K.clip(y_true, 0, 1), 0.5)
    hits = K.greater(K.clip(y_true * y_pred, 0, 1), 0.5)
    n_actual_pos = K.sum(K.cast(actual_pos, "float32"))
    n_hits = K.sum(K.cast(hits, "float32"))
    return n_hits / (n_actual_pos + K.epsilon())

# Loss Function
def binary_crossentropy_balance(target, output):
    """
    Class-balanced binary cross-entropy for multi-label targets.

    For each class (last axis) the positive term is weighted by
    beta_p = N / n_pos and the negative term by beta_n = N / (N - n_pos).
    N is the number of samples in the current batch and n_pos is the number of
    positive targets for that class in the batch. The weighted terms are summed
    over classes, giving one loss value per sample.

    :param target: 0/1 ground-truth labels, assumed shape (batch, n_class)
    :param output: sigmoid probabilities with the same shape as target
    :return: per-sample loss (sum over the last axis)
    """
    target = K.cast(target, output.dtype)
    # Use the real batch length instead of the global batch_size, so a smaller
    # (e.g. final) batch does not get mis-scaled class weights.
    n = K.cast(K.shape(target)[0], output.dtype)
    n_pos = K.sum(target, axis=0)
    beta_p = n / (K.epsilon() + n_pos)
    beta_n = n / (K.epsilon() + n - n_pos)
    epsilon_ = constant_op.constant(K.epsilon(), output.dtype.base_dtype)
    output = clip_ops.clip_by_value(output, epsilon_, 1. - epsilon_)
    # Compute cross entropy from probabilities.
    bce = target * math_ops.log(output + K.epsilon()) * beta_p
    bce += (1 - target) * math_ops.log(1 - output + K.epsilon()) * beta_n
    return -K.sum(bce, axis=-1)


## データ準備

In [15]:
# Each row: filepath, split name ("train"/"val"/"test"), then the label columns.
# newline="" is how the csv module expects files given to csv.reader to be
# opened (quoted fields containing newlines are otherwise mis-parsed).
with open("data.csv", newline="") as f:
    record = list(csv.reader(f))
train_record=[x for x in record if x[1]=="train"]
val_record=[x for x in record if x[1]=="val"]
test_record=[x for x in record if x[1]=="test"]
len(train_record),len(val_record),len(test_record)

(14000, 4000, 2000)

In [27]:
def get_data_generator(data_arg=True):
    """
    Return an ImageDataGenerator for the VGG16 pipeline.

    Every generator applies tf.keras.applications.vgg16.preprocess_input.
    With data_arg=True (data augmentation) it also applies random rotation
    (10 deg), zoom (0.1) and width/height shifts (0.1); data_arg=False returns
    a generator without augmentation.
    """
    options = {"preprocessing_function": tf.keras.applications.vgg16.preprocess_input}
    if data_arg:
        options.update(
            rotation_range=10,
            zoom_range=0.1,
            width_shift_range=0.1,
            height_shift_range=0.1,
        )
    return ImageDataGenerator(**options)

In [49]:
# Augment training images only. The validation generator must not augment:
# val_loss drives ModelCheckpoint/EarlyStopping and should be computed on
# the unmodified validation images.
datagen=get_data_generator(data_arg=True)
datagen_val=get_data_generator(data_arg=False)

# label columns
label_col=[str(i) for i in range(10)]
# columns of the DataFrame passed to the generators
cols=["filepath","data_type"]+label_col

# training generator
# (no dtype=np.uint8 here: the rows hold file paths and split names as strings;
# the label columns are converted explicitly on the next lines)
df_train=pd.DataFrame(train_record)
df_train.columns=cols
df_train[label_col]=df_train[label_col].astype(np.float32)
gen_train=datagen.flow_from_dataframe(df_train,directory="",x_col="filepath",y_col=label_col,target_size=(224,224),color_mode="rgb",class_mode="raw",batch_size=batch_size)

# validation generator (no augmentation)
df_val=pd.DataFrame(val_record)
df_val.columns=cols
df_val[label_col]=df_val[label_col].astype(np.float32)
gen_val=datagen_val.flow_from_dataframe(df_val,directory="",x_col="filepath",y_col=label_col,target_size=(224,224),color_mode="rgb",class_mode="raw",batch_size=batch_size)

Found 14000 validated image filenames.
Found 4000 validated image filenames.


## 学習の実行

In [50]:
input_shape=(224,224,3)
model=get_multilabel_model(n_class=10,input_shape=input_shape)
# Write the weights to model/weight.h5 whenever val_loss reaches a new minimum.
model_checkpoint = ModelCheckpoint(
    filepath=os.path.join("model", "weight.h5"),
    monitor="val_loss",
    mode="min",
    save_best_only=True,
    verbose=1,
)
# Stop after 5 epochs without val_loss improvement; the best weights live in
# the checkpoint file and are not restored into the in-memory model.
Ecall=EarlyStopping(monitor="val_loss",patience=5,restore_best_weights=False)
model.compile(
    optimizer=Adam(epsilon=K.epsilon()),
    loss=binary_crossentropy_balance,
    metrics=[total_acc,binary_acc,recall,precision],
)

In [None]:
n_epoch=50
vb_index=1
initial_epoch=0
# Train on gen_train and validate on gen_val. The step counts use integer
# division, i.e. the number of full batches of each split.
model.fit(
    x=gen_train,
    validation_data=gen_val,
    epochs=n_epoch,
    initial_epoch=initial_epoch,
    steps_per_epoch=len(train_record)//batch_size,
    validation_steps=len(val_record)//batch_size,
    callbacks=[model_checkpoint,Ecall],
    workers=4,
    use_multiprocessing=False,
    verbose=vb_index,
)

### 確認

In [None]:
# Sanity check: compare label and prediction for one random image of the
# first training batch, and show that image.
x,y=gen_train[0]
idx=random.randrange(len(x))  # batch may be smaller than batch_size
print("true",y[idx])
print("pred",model.predict(x[None,idx]))
# x went through vgg16.preprocess_input (RGB->BGR and ImageNet channel means
# subtracted), so undo that before converting to an 8-bit RGB image;
# scaling by 256 would produce wrapped-around garbage.
vgg16_bgr_mean=np.array([103.939,116.779,123.68],dtype=np.float32)
rgb=(x[idx]+vgg16_bgr_mean)[...,::-1]
pil_image = Image.fromarray(np.clip(rgb,0,255).astype(np.uint8))
pil_image

# 評価

## テストデータの予測CSVを出力

In [None]:
"""
tta=TrueでTTAありの予測
tta=FalseでTTAなしの予測
"""

# tta=True: average predictions over tta_epoch randomly augmented passes (TTA).
# tta=False: a single pass without augmentation.
tta=False
tta_epoch=30 if tta else 1

model=get_multilabel_model(n_class=10,input_shape=input_shape)
model.load_weights(os.path.join("model", "weight.h5"))

# label columns
label_col=[str(i) for i in range(10)]
# columns of the DataFrame passed to the generator
cols=["filepath","data_type"]+label_col

# generator for the test split
test_batch_size=100
df_test=pd.DataFrame(test_record)
df_test.columns=cols
df_test[label_col]=df_test[label_col].astype(np.float32)
datagen=get_data_generator(data_arg=tta)
# shuffle=False: batch i always holds the same images, so every TTA pass adds
# the prediction for the same image into the same row and y_true stays aligned.
gen_test=datagen.flow_from_dataframe(df_test,directory="",x_col="filepath",y_col=label_col,target_size=(224,224),color_mode="rgb",class_mode="raw",batch_size=test_batch_size,shuffle=False)
# Size the buffers from the test generator (the number of images it found),
# not from the validation split.
n_test=gen_test.n
y_pred=np.zeros((n_test,len(label_col)),dtype=np.float32)
y_true=np.zeros((n_test,len(label_col)),dtype=np.float32)
steps=len(gen_test)  # number of batches, including a trailing partial batch
for _ in range(tta_epoch):
    for i in range(steps):
        x,y=gen_test[i]
        start=i*test_batch_size
        y_pred[start:start+len(x)]+=model.predict(x)
        y_true[start:start+len(x)]=y
if tta:
    y_pred /= tta_epoch  # mean over the TTA passes

df_true=pd.DataFrame(y_true,columns=label_col)
df_pred=pd.DataFrame(y_pred,columns=[c+"_pred" for c in label_col])
th=0.5
df=pd.DataFrame()
ary=[]
for col in label_col:
    df[col]=df_true[col]
    df[col+"_pred"]=df_pred[col+"_pred"]
    y_bin=np.where(df[col+"_pred"]>th,1,0)
    d={"target":col}
    try:
        d["roc_auc"]=roc_auc_score(df[col],df[col+"_pred"])
    except ValueError:
        # roc_auc_score raises ValueError when only one class is present.
        d["roc_auc"]=None
    # NOTE: the "recal" column name is kept as-is for existing CSV consumers.
    d["recal"]=recall_score(df[col],y_bin)
    d["precision"]=precision_score(df[col],y_bin)
    d["accuracy"]=accuracy_score(df[col],y_bin)
    ary.append(d)
suffix=f"tta={tta}"
os.makedirs("validation",exist_ok=True)
df.to_csv(f"validation/predict_{suffix}.csv",index=False)
pd.DataFrame(ary).to_csv(f"validation/evaluation_{suffix}.csv",index=False)
pd.DataFrame(ary)

## 予測CSVから評価を計算、可視化

In [None]:
def plot_roc_curve(ax,y_true,y_pred,title):
    """
    Draw the ROC curve of the scores y_pred against binary y_true on ax.

    The area under the curve is shaded and the chance diagonal is drawn as a
    dashed green line. The ROC AUC, rounded to 4 digits, is appended to title.
    """
    fpr, tpr, _ = roc_curve(y_true, y_pred)
    area = auc(fpr, tpr)
    ax.plot(fpr, tpr, color="C0")
    ax.fill_between(fpr, tpr, 0, color="C0", alpha=0.2)
    ax.plot([0, 1], [0, 1], color="green", linestyle="--")
    ax.set_title(f"{title} auc={round(area,4)}", fontname="MS Gothic")
    ax.grid()

def plot_pr_curve(ax,y_true,y_pred,title):
    """
    Draw the precision-recall curve of the scores y_pred against binary y_true on ax.

    The area under the curve is shaded. Its trapezoidal area (sklearn `auc`
    over recall/precision), rounded to 4 digits, is appended to title.
    """
    precisions, recalls, _ = precision_recall_curve(y_true, y_pred)
    area = auc(recalls, precisions)
    ax.plot(recalls, precisions, color="C0")
    ax.fill_between(recalls, precisions, 0, color="C0", alpha=0.2)
    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")
    ax.set_title(f"{title} auc={round(area,4)}", fontname="MS Gothic")
    ax.grid()

def plot_predict_dist(ax,y_true,y_pred,title):
    """
    Draw overlaid histograms of predicted scores for negatives and positives on ax.

    Nothing is drawn when y_true contains no positives. The title shows
    `title` plus recall and precision at a 0.5 threshold (rounded to 3 digits).
    """
    posi=y_pred[y_true==1]
    nega=y_pred[y_true!=1]
    if len(posi)==0:
        return

    ax.hist(nega,bins=20,color="C0",alpha=0.5,label="nega")
    ax.hist(posi,bins=20,color="C1",alpha=0.5,label="posi")
    ax.set_xlim(0,1)
    y_label=np.where(y_pred>0.5,1,0)
    rc=recall_score(y_true,y_label)
    pr=precision_score(y_true,y_label)
    # Use the `title` parameter, not a global loop variable from the caller.
    ax.set_title(f"{title} rc={round(rc,3)} pr={round(pr,3)}")
    ax.legend()

suffix=f"tta={tta}"
df=pd.read_csv(f"validation/predict_{suffix}.csv")

# Label columns that have a matching "<label>_pred" column, in file order.
target_cols=[c for c in df.columns if "pred" not in c and c+"_pred" in df.columns]
# Subplots fill a 3-column grid; at least 5 rows, more if there are >15 labels.
n_rows=max(5,-(-len(target_cols)//3))

# One figure per plot type: ROC curve, PR curve, predicted-score distribution.
for kind,plot_fn in [("roc",plot_roc_curve),("pr",plot_pr_curve),("dist",plot_predict_dist)]:
    fig=plt.figure(figsize=(10,14))
    for cnt,col in enumerate(target_cols):
        ax=plt.subplot2grid((n_rows,3),divmod(cnt,3))
        y_true=np.where(df[col]>0.5,1,0)
        y_pred=df[col+"_pred"].values
        plot_fn(ax,y_true,y_pred,col)
    plt.tight_layout()
    filename=f"validation/evaluation_{kind}_{suffix}.jpg"
    plt.savefig(filename)
    print("SAVE",filename)