# **Callbacks**

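**Callbacks** are hooks that run at specific points during training (for example, at the end of every epoch). They are available in `keras.callbacks`.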

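**Common callbacks**
- EarlyStopping: stops training once the monitored metric stops improving, which helps avoid overfitting
- ModelCheckpoint: saves the model at checkpoints during training
- ReduceLROnPlateau: lowers the learning rate when the monitored metric plateaus
- TensorBoard: writes training logs that can be viewed in TensorBoard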
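
ReduceLROnPlateau is listed above but is not used later in this notebook, so here is a rough, framework-free sketch of the idea behind it. It is a simplified approximation, not the Keras implementation: the function name `simulate_reduce_lr` and the loss values are made up for illustration, and features such as cooldown are left out.

```python
def simulate_reduce_lr(val_losses, initial_lr=1.0, factor=0.5,
                       patience=2, min_delta=1e-4, min_lr=1e-6):
    """Return the learning rate in effect after each epoch (simplified sketch)."""
    best = float("inf")   # best validation loss seen so far
    wait = 0              # epochs since the last improvement
    lr = initial_lr
    lrs = []
    for loss in val_losses:
        if loss < best - min_delta:
            # Improvement: remember it and reset the counter
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                # Plateau detected: shrink the learning rate, but not below min_lr
                lr = max(lr * factor, min_lr)
                wait = 0
        lrs.append(lr)
    return lrs


# Hypothetical validation losses, chosen only to trigger two plateaus
losses = [1.0, 0.8, 0.8, 0.8, 0.7, 0.7, 0.7]
lrs = simulate_reduce_lr(losses)
print(lrs)
```

In Keras itself the corresponding callback would be created with something like `keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=2)` and added to the `callbacks` list passed to `model.fit`.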

In [1]:
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import sklearn
import pandas as pd
import os, sys, time
import tensorflow as tf
from tensorflow import keras

In [2]:
fashion_mnist = keras.datasets.fashion_mnist
(X_train_all, y_train_all), (X_test, y_test) = fashion_mnist.load_data()
X_valid, X_train = X_train_all[:5000], X_train_all[5000:]
y_valid, y_train = y_train_all[:5000], y_train_all[5000:]
print(X_train.shape, y_train.shape, X_valid.shape, y_valid.shape)

(55000, 28, 28) (55000,) (5000, 28, 28) (5000,)


In [3]:
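from sklearn.preprocessing import StandardScaler
std_scaler = StandardScaler()
# Flatten all pixels into one column so a single global mean/std is learned from the training set only
std_scaler.fit(X_train.astype(np.float32).reshape(-1, 1))
# Apply the same training-set statistics to both splits, then restore the image shape
X_train_scaled = std_scaler.transform(X_train.astype(np.float32).reshape(-1, 1)).reshape(-1, 28, 28)
X_valid_scaled = std_scaler.transform(X_valid.astype(np.float32).reshape(-1, 1)).reshape(-1, 28, 28)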
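
To make the `reshape(-1, 1)` trick in the cell above more concrete, the following sketch reproduces the same arithmetic with plain NumPy on a tiny made-up array (the array `X` is a stand-in, not Fashion-MNIST data). Because every pixel ends up in one column, standardization uses one global mean and one global standard deviation.

```python
import numpy as np

# Hypothetical stand-in for a batch of images: 4 "images" of 2x2 pixels
X = np.arange(16, dtype=np.float32).reshape(4, 2, 2)

# One column holding every pixel value, as in reshape(-1, 1) above
flat = X.reshape(-1, 1)

# Global statistics; np.std defaults to the population std (ddof=0),
# which is also what StandardScaler uses
mean, std = flat.mean(), flat.std()

# Standardize, then restore the original image shape
X_scaled = ((flat - mean) / std).reshape(X.shape)
print(X_scaled.mean(), X_scaled.std())
```

The scaled array should have a mean close to 0 and a standard deviation close to 1, up to floating-point error.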

In [4]:
# tf.keras.models.Sequential
model = keras.models.Sequential()

# Add layers
model.add(keras.layers.Flatten(input_shape=[28, 28]))
model.add(keras.layers.Dense(300, activation='relu'))
model.add(keras.layers.Dense(100, activation='relu'))
model.add(keras.layers.Dense(10, activation='softmax'))
# Alternatively, Sequential accepts a list of layers directly in its constructor
# compile
model.compile(loss='sparse_categorical_crossentropy', optimizer='Adam', metrics=['accuracy'])

In [5]:
# Tensorboard EarlyStopping ModelCheckpoint
# Directory that holds both the TensorBoard logs and the saved model
log_dir = 'callbacks'
os.makedirs(log_dir, exist_ok=True)
output_model = os.path.join(log_dir, 'fashion_mnist_model.h5')
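callbacks = [
    # Write training logs to log_dir for TensorBoard
    keras.callbacks.TensorBoard(log_dir),
    # Overwrite the saved model only when the monitored metric (val_loss by default) improves
    keras.callbacks.ModelCheckpoint(output_model, save_best_only=True),
    # Stop if val_loss fails to improve by at least 1e-3 for 5 consecutive epochs
    keras.callbacks.EarlyStopping(min_delta=1e-3, patience=5),
]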
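
How `min_delta` and `patience` interact in EarlyStopping can be illustrated with a small, framework-free sketch. This is a simplified approximation of the rule rather than the Keras source; the helper name `early_stopping_epoch` and the loss values are hypothetical.

```python
def early_stopping_epoch(val_losses, min_delta=1e-3, patience=5):
    """Return the 1-based epoch at which training would stop, or None."""
    best = float("inf")   # best validation loss seen so far
    wait = 0              # consecutive epochs without a sufficient improvement
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best - min_delta:
            # Only an improvement larger than min_delta resets the counter
            best, wait = loss, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


# Hypothetical losses: three real improvements, then changes smaller than min_delta
losses = [0.50, 0.40, 0.35] + [0.3495] * 6
stop_epoch = early_stopping_epoch(losses)
print(stop_epoch)
```

With these made-up values the counter starts at epoch 4 and reaches the patience of 5 at epoch 8. Note that with `epochs=10` and `patience=5`, as in this notebook, early stopping can only trigger if the validation loss stalls by epoch 5 at the latest.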

In [6]:
history = model.fit(X_train_scaled, y_train, epochs=10, validation_data=(X_valid_scaled, y_valid), callbacks=callbacks)

Train on 55000 samples, validate on 5000 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
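
`model.fit` returns a `History` object whose `history` attribute is a dict mapping each metric name to a list with one value per epoch. The per-epoch numbers are not shown in the output above, so the sketch below uses placeholder values purely to show how the best epoch could be looked up; the helper `best_epoch` is a hypothetical name, not part of Keras.

```python
def best_epoch(history_dict, key="val_loss"):
    """Return (1-based epoch, value) where the given metric is lowest."""
    values = history_dict[key]
    idx = min(range(len(values)), key=values.__getitem__)
    return idx + 1, values[idx]


# Placeholder numbers for illustration only, not results from this notebook
placeholder_history = {"loss": [0.9, 0.6, 0.5], "val_loss": [0.8, 0.55, 0.6]}
epoch, value = best_epoch(placeholder_history)
print(epoch, value)
```

With a real run, the call would be `best_epoch(history.history)`. The logs written by the TensorBoard callback can be inspected by running `tensorboard --logdir callbacks` from the notebook's working directory.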
