## Saving a model with checkpoints during training in TensorFlow

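With checkpoints, you can save the model periodically while it trains:
* If the training program crashes, you do not have to start over: load a checkpoint and resume training from there.
* A prediction (serving) service can load a newer checkpoint to update its model.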

In [1]:
import tensorflow as tf
from tensorflow import keras

### 1. Load the data and build the model

#### Load the data

In [2]:
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

train_labels = train_labels[:10000]
test_labels = test_labels[:10000]

train_images = train_images[:10000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:10000].reshape(-1, 28 * 28) / 255.0
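The cell above keeps 10,000 training images (MNIST ships 60,000 for training and 10,000 for testing, so the test slice keeps all of them), flattens each 28×28 image into a 784-element vector to match `input_shape=(784,)`, and scales pixel values from 0..255 to 0..1. A minimal NumPy sketch, assuming random synthetic images as a stand-in for the MNIST arrays, shows the resulting shape and range:

```python
import numpy as np

# Synthetic stand-in for MNIST (assumption): 3 images of 28x28 uint8 pixels.
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(3, 28, 28), dtype=np.uint8)

# Same transformation as the cell above: flatten each image, scale to [0, 1].
flat = images.reshape(-1, 28 * 28) / 255.0

print(flat.shape)  # (3, 784)
print(flat.dtype)  # float64: dividing uint8 by a Python float promotes the dtype
```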

In [3]:
# This is a 10-class classification task
train_labels[:5]

array([5, 0, 4, 1, 9], dtype=uint8)
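The labels are integer class indices (0 to 9), not one-hot vectors, which is why the model below is compiled with `sparse_categorical_crossentropy`. A small NumPy sketch with hypothetical softmax outputs shows that the sparse form, `-log(p[label])`, equals ordinary categorical cross-entropy against the one-hot encoding:

```python
import numpy as np

# Hypothetical softmax outputs for 2 samples over 10 classes (rows sum to 1).
probs = np.full((2, 10), 0.05)
probs[0, 5] = 0.55           # sample 0: most probability on class 5
probs[1, 0] = 0.55           # sample 1: most probability on class 0
labels = np.array([5, 0])    # integer labels, like train_labels

# Sparse form: take the probability assigned to the true class directly.
sparse_ce = -np.log(probs[np.arange(len(labels)), labels])

# Categorical form: the same loss computed against one-hot targets.
one_hot = np.eye(10)[labels]
categorical_ce = -(one_hot * np.log(probs)).sum(axis=1)

print(sparse_ce)                               # both entries are -log(0.55), about 0.598
print(np.allclose(sparse_ce, categorical_ce))  # True
```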

#### Define a simple model

In [4]:
# Define a simple sequential model
def create_model():
    model = tf.keras.models.Sequential([
        keras.layers.Dense(512, activation='relu', input_shape=(784,)),
        keras.layers.Dense(10, activation='softmax')
    ])

    model.compile(optimizer='rmsprop',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    return model
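These are the weights each checkpoint stores. As a quick size check, a dense layer has `inputs × units` weights plus one bias per unit; the arithmetic below should agree with the totals `model.summary()` reports for this architecture:

```python
def dense_params(inputs: int, units: int) -> int:
    """Weights (inputs x units) plus one bias per unit."""
    return inputs * units + units

hidden = dense_params(784, 512)  # Dense(512) on 784 inputs
output = dense_params(512, 10)   # Dense(10) on 512 inputs
total = hidden + output
print(f"{hidden:,} + {output:,} = {total:,}")  # 401,920 + 5,130 = 407,050
```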

### 2. Save the model during training

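`tf.keras.callbacks.ModelCheckpoint` is a callback that saves the model, or only its weights, while training runs; by default it saves at the end of every epoch.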
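Saving every epoch is not the only policy. With `ModelCheckpoint(filepath, monitor="val_loss", mode="min", save_best_only=True)`, the callback saves only when the monitored value improves on the best seen so far. The plain-Python sketch below is not Keras code; it only simulates that decision rule on a hypothetical sequence of validation losses:

```python
def epochs_saved(val_losses):
    """Return the 1-based epochs at which a 'save best only' policy
    (monitoring a loss, lower is better) would write a checkpoint."""
    best = float("inf")
    saved = []
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:       # improvement over the best value so far
            best = loss
            saved.append(epoch)
    return saved

# Hypothetical validation losses, not taken from the run in this notebook.
print(epochs_saved([0.30, 0.25, 0.27, 0.22, 0.23]))  # [1, 2, 4]
```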

In [5]:
# Create a basic model instance
model = create_model()

In [7]:
# Create a callback that saves the model's weights
checkpoint_path = "./traing_ckpt/cp_{epoch:02d}.ckpt"
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                 save_weights_only=True,
                                                 verbose=1)
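`{epoch:02d}` in `checkpoint_path` is a Python format field. Keras fills it in with the 1-based epoch number at each save (logged metrics such as `{val_loss:.2f}` can be used too), which is why the log below shows `cp_01.ckpt` through `cp_10.ckpt`. Plain `str.format` calls reproduce the names; the metric value used here is hypothetical:

```python
checkpoint_path = "./traing_ckpt/cp_{epoch:02d}.ckpt"

# One name per epoch; file names count epochs from 1.
names = [checkpoint_path.format(epoch=e) for e in range(1, 11)]
print(names[0])   # ./traing_ckpt/cp_01.ckpt
print(names[-1])  # ./traing_ckpt/cp_10.ckpt

# A template that also embeds a logged metric (hypothetical value).
metric_path = "./traing_ckpt/cp_{epoch:02d}_{val_loss:.2f}.ckpt"
print(metric_path.format(epoch=3, val_loss=0.18734))  # ./traing_ckpt/cp_03_0.19.ckpt
```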

In [8]:
# Train the model; the callback saves a checkpoint at the end of every epoch
model.fit(train_images, 
          train_labels,  
          epochs=10,
          validation_data=(test_images,test_labels),
          callbacks=[cp_callback])

Epoch 1/10
Epoch 00001: saving model to ./traing_ckpt/cp_01.ckpt
Epoch 2/10
Epoch 00002: saving model to ./traing_ckpt/cp_02.ckpt
Epoch 3/10
Epoch 00003: saving model to ./traing_ckpt/cp_03.ckpt
Epoch 4/10
Epoch 00004: saving model to ./traing_ckpt/cp_04.ckpt
Epoch 5/10
Epoch 00005: saving model to ./traing_ckpt/cp_05.ckpt
Epoch 6/10
Epoch 00006: saving model to ./traing_ckpt/cp_06.ckpt
Epoch 7/10
Epoch 00007: saving model to ./traing_ckpt/cp_07.ckpt
Epoch 8/10
Epoch 00008: saving model to ./traing_ckpt/cp_08.ckpt
Epoch 9/10
Epoch 00009: saving model to ./traing_ckpt/cp_09.ckpt
Epoch 10/10
Epoch 00010: saving model to ./traing_ckpt/cp_10.ckpt


<tensorflow.python.keras.callbacks.History at 0x7fd8c5fe8490>

In [10]:
model.evaluate(test_images,  test_labels, verbose=2)

313/313 - 1s - loss: 0.1980 - accuracy: 0.9564


[0.1979878842830658, 0.9563999772071838]

### 3. Use the checkpoint files

In [11]:
# Create a new model (same architecture, freshly initialized weights)
new_model = create_model()

#### Load the weights and predict

In [12]:
# Load the weights; pass the checkpoint prefix, without suffixes such as .index
new_model.load_weights("./traing_ckpt/cp_09.ckpt")

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fd8c5ed47d0>
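In the TensorFlow checkpoint format, one save produces several files sharing a prefix, typically `cp_09.ckpt.index` and `cp_09.ckpt.data-00000-of-00001`, and `load_weights` expects that shared prefix. The helper below is hypothetical, and its shard-suffix pattern is an assumption based on that usual naming; it recovers the prefix from any of those file names:

```python
import re

# Suffixes that TF checkpoint files usually carry after the shared prefix
# (assumed pattern: '.index' or '.data-NNNNN-of-NNNNN').
_SUFFIX = re.compile(r"\.(index|data-\d{5}-of-\d{5})$")

def checkpoint_prefix(filename: str) -> str:
    """Hypothetical helper: strip a checkpoint file suffix to get the prefix."""
    return _SUFFIX.sub("", filename)

print(checkpoint_prefix("./traing_ckpt/cp_09.ckpt.index"))
print(checkpoint_prefix("./traing_ckpt/cp_09.ckpt.data-00000-of-00001"))
# both print ./traing_ckpt/cp_09.ckpt
```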

In [13]:
# Evaluate directly, without any further training
new_model.evaluate(test_images,  test_labels, verbose=2)

313/313 - 0s - loss: 0.1720 - accuracy: 0.9612


[0.17197643220424652, 0.9611999988555908]

In [14]:
# Predict directly
new_model.predict(test_images[:3])

array([[1.8646621e-11, 4.0034703e-15, 6.1754810e-09, 1.1134078e-05,
        1.9557399e-16, 2.0905860e-11, 7.6235390e-19, 9.9998891e-01,
        6.6065271e-11, 2.8822367e-09],
       [1.8763623e-12, 2.7294575e-11, 1.0000000e+00, 5.7642353e-08,
        1.0360668e-21, 3.5209656e-12, 2.5003152e-10, 1.4259344e-16,
        2.9851371e-10, 8.9087966e-21],
       [4.0501280e-10, 9.9966383e-01, 1.7066914e-04, 2.2081480e-05,
        3.0094197e-05, 2.1883482e-08, 1.5438196e-06, 2.7922259e-05,
        8.3738938e-05, 8.2594489e-08]], dtype=float32)
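Each row of the `predict` output is a softmax distribution over the 10 digits, so the predicted class is the index of the largest value; comparing these with `test_labels[:3]` is how accuracy is measured. Using the three rows printed above:

```python
import numpy as np

# The three probability rows printed by new_model.predict(test_images[:3]).
probs = np.array([
    [1.8646621e-11, 4.0034703e-15, 6.1754810e-09, 1.1134078e-05, 1.9557399e-16,
     2.0905860e-11, 7.6235390e-19, 9.9998891e-01, 6.6065271e-11, 2.8822367e-09],
    [1.8763623e-12, 2.7294575e-11, 1.0000000e+00, 5.7642353e-08, 1.0360668e-21,
     3.5209656e-12, 2.5003152e-10, 1.4259344e-16, 2.9851371e-10, 8.9087966e-21],
    [4.0501280e-10, 9.9966383e-01, 1.7066914e-04, 2.2081480e-05, 3.0094197e-05,
     2.1883482e-08, 1.5438196e-06, 2.7922259e-05, 8.3738938e-05, 8.2594489e-08],
])

predicted = probs.argmax(axis=1)  # most probable digit for each image
print(predicted)                  # [7 2 1]
print(np.allclose(probs.sum(axis=1), 1.0, atol=1e-4))  # True: rows sum to ~1
```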

#### Resume training

In [15]:
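# Continue training from the loaded weights: the loss picks up from the
# checkpointed state instead of starting from scratch. No callback is passed
# here, so this run writes no new checkpoints.
new_model.fit(train_images,
              train_labels,
              epochs=3,
              validation_data=(test_images, test_labels))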

Epoch 1/3
Epoch 2/3
Epoch 3/3


<tensorflow.python.keras.callbacks.History at 0x7fd8bc0a1490>
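The resumed run counts epochs from 1 again ("Epoch 1/3"). Keras `fit` accepts `initial_epoch`, and training runs from `initial_epoch` up to `epochs`, so continuing after epoch 9 for three more epochs would use `initial_epoch=9, epochs=12`; a `ModelCheckpoint` passed to that call would then keep numbering files from `cp_10`. To find the most recent checkpoint in a directory, `tf.train.latest_checkpoint("./traing_ckpt")` returns its prefix. The helper below is a hypothetical sketch that derives the `fit` arguments from a prefix named with the template used above:

```python
import re

def resume_fit_kwargs(prefix: str, extra_epochs: int) -> dict:
    """Hypothetical helper: parse the epoch from a prefix such as
    './traing_ckpt/cp_09.ckpt' and build initial_epoch/epochs for model.fit."""
    match = re.search(r"cp_(\d+)\.ckpt$", prefix)
    if match is None:
        raise ValueError(f"unexpected checkpoint name: {prefix}")
    done = int(match.group(1))  # epochs already completed
    return {"initial_epoch": done, "epochs": done + extra_epochs}

kwargs = resume_fit_kwargs("./traing_ckpt/cp_09.ckpt", extra_epochs=3)
print(kwargs)  # {'initial_epoch': 9, 'epochs': 12}
# Possible use: new_model.fit(train_images, train_labels, callbacks=[cp_callback], **kwargs)
```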