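## Why save a model
- Training can resume from the last saved point instead of starting over, so a long run is not lost.
- Models can be shared, so others can build on a model that someone has already trained.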


In [1]:
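import h5py  # required by Keras to write HDF5 files (model.save, ModelCheckpoint)
import yaml  # required by model.to_yaml()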

In [2]:
from __future__ import absolute_import, division, print_function
import os

import tensorflow as tf
from tensorflow import keras

tf.__version__

'1.8.0'

## Load dataset

In [3]:
(train_data,train_labels),(test_data,test_labels) = keras.datasets.mnist.load_data()
train_data.shape

(60000, 28, 28)

In [4]:
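# Keep only the first 1000 examples to keep training short
train_labels = train_labels[:1000]
test_labels = test_labels[:1000]

In [5]:
# Flatten each 28x28 image into a 784-element vector and scale pixels from 0-255 to 0-1
train_data = train_data[:1000].reshape(-1, 28*28) / 255
test_data = test_data[:1000].reshape(-1, 28*28) / 255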

In [6]:
train_data.shape

(1000, 784)

## Create model

In [7]:
def create_model():
    model = keras.models.Sequential()
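    # 784 inputs (one flattened image) -> 512 hidden units
    model.add(keras.layers.Dense(512, activation=tf.nn.relu, input_shape=(784,)))
    # Randomly drop 20% of the hidden activations during training
    model.add(keras.layers.Dropout(0.2))
    # One softmax probability per digit class
    model.add(keras.layers.Dense(10, activation=tf.nn.softmax))

    model.compile(optimizer=keras.optimizers.Adam(),
                  loss=keras.losses.sparse_categorical_crossentropy,
                  metrics=["accuracy"])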
    
    return model
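
The model is compiled with `sparse_categorical_crossentropy`, so the labels can stay as integer class IDs (0-9) rather than one-hot vectors. Below is a minimal plain-Python sketch of what that loss computes; it omits numerical-stability details the real implementation may add, and `sparse_ce` / `mean_sparse_ce` are illustrative names, not Keras functions:

```python
import math

def sparse_ce(probs, label):
    """Cross-entropy for one sample: probs is a softmax output, label an int class id."""
    # Only the probability assigned to the true class contributes to the loss.
    return -math.log(probs[label])

def mean_sparse_ce(batch_probs, labels):
    """Average the per-sample losses over a batch."""
    losses = [sparse_ce(p, y) for p, y in zip(batch_probs, labels)]
    return sum(losses) / len(losses)

# Made-up softmax outputs over 3 classes, with integer labels
print(round(sparse_ce([0.1, 0.7, 0.2], 1), 4))  # 0.3567
print(round(mean_sparse_ce([[0.1, 0.7, 0.2], [0.8, 0.1, 0.1]], [1, 0]), 4))  # 0.2899
```

The more probability the model puts on the correct class, the closer the loss gets to 0.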

In [8]:
model = create_model()
model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_1 (Dense)              (None, 512)               401920    
_________________________________________________________________
dropout_1 (Dropout)          (None, 512)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 10)                5130      
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________
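
The `Param #` column follows directly from the layer sizes: a `Dense` layer with n inputs and m units has an n x m weight matrix plus m biases, and `Dropout` has no weights. A quick check of the numbers above (`dense_params` is a hypothetical helper):

```python
def dense_params(n_in, n_out):
    # Weight matrix (n_in x n_out) plus one bias per output unit
    return n_in * n_out + n_out

hidden = dense_params(784, 512)  # dense_1
output = dense_params(512, 10)   # dense_2
dropout = 0                      # Dropout has no trainable weights

print(hidden, output, hidden + dropout + output)  # 401920 5130 407050
```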


## Checkpoint save

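Save checkpoints during training with a callback.
The checkpoint file name can include values from training, such as the epoch number and logged metrics, e.g. "model_{epoch:02d}-{val_acc:.2f}.hdf5".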
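
`ModelCheckpoint` fills these placeholders with Python's `str.format`, passing the 1-based epoch number and the metrics logged for that epoch (in this TF version the validation accuracy key is `val_acc`, as the training log shows). The sketch below reproduces that formatting by hand, using the epoch-5 values from the log as the `logs` dict:

```python
template = "model_{epoch:02d}-{val_acc:.2f}.hdf5"

# Keras counts epochs from 0 internally and adds 1 when naming files.
epoch_index = 4
logs = {"loss": 0.1509, "acc": 0.9680, "val_loss": 0.4195, "val_acc": 0.8610}

print(template.format(epoch=epoch_index + 1, **logs))  # model_05-0.86.hdf5

# The path used in this notebook only uses the epoch placeholder.
print("./output/checkpoint-{epoch:04d}.ckpt".format(epoch=epoch_index + 1, **logs))
# ./output/checkpoint-0005.ckpt
```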

In [9]:
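# {epoch:04d} is replaced by the epoch number, e.g. checkpoint-0005.ckpt
checkpoint_path = "./output/checkpoint-{epoch:04d}.ckpt"
checkpoint_directory = os.path.dirname(checkpoint_path)

# save_weights_only=False stores the full model (architecture, weights, optimizer state);
# period=5 writes a checkpoint every 5 epochs.
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path,
                                                         save_weights_only=False,
                                                         verbose=1,
                                                         period=5)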

model.fit(train_data,
          train_labels,
          epochs=50,
          validation_data=(test_data,test_labels),
          callbacks=[checkpoint_callback],
          verbose=2)



Train on 1000 samples, validate on 1000 samples
Epoch 1/50
 - 3s - loss: 1.1796 - acc: 0.6730 - val_loss: 0.7535 - val_acc: 0.7540
Epoch 2/50
 - 0s - loss: 0.4236 - acc: 0.8770 - val_loss: 0.5326 - val_acc: 0.8300
Epoch 3/50
 - 0s - loss: 0.2876 - acc: 0.9200 - val_loss: 0.5329 - val_acc: 0.8380
Epoch 4/50
 - 0s - loss: 0.2154 - acc: 0.9420 - val_loss: 0.4339 - val_acc: 0.8600
Epoch 5/50
 - 0s - loss: 0.1509 - acc: 0.9680 - val_loss: 0.4195 - val_acc: 0.8610

Epoch 00005: saving model to ./output/checkpoint-0005.ckpt
Epoch 6/50
 - 0s - loss: 0.1205 - acc: 0.9700 - val_loss: 0.4083 - val_acc: 0.8610
Epoch 7/50
 - 0s - loss: 0.0858 - acc: 0.9890 - val_loss: 0.4041 - val_acc: 0.8670
Epoch 8/50
 - 0s - loss: 0.0668 - acc: 0.9920 - val_loss: 0.3959 - val_acc: 0.8680
Epoch 9/50
 - 0s - loss: 0.0540 - acc: 0.9930 - val_loss: 0.4475 - val_acc: 0.8600
Epoch 10/50
 - 0s - loss: 0.0456 - acc: 0.9930 - val_loss: 0.3979 - val_acc: 0.8670

Epoch 00010: saving model to ./output/checkpoint-0010.ckpt
...

<tensorflow.python.keras._impl.keras.callbacks.History at 0x17700730e10>
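
With `period=5` the callback writes a file only every fifth epoch, which is why the log shows saves at epochs 5 and 10 and why the 50-epoch run ends with `checkpoint-0050.ckpt`. A plain-Python simulation of that counter (a simplified sketch of how `ModelCheckpoint` tracks epochs since the last save, not its actual source):

```python
def saved_epochs(total_epochs, period):
    """Return the 1-based epochs at which a checkpoint would be written."""
    saved = []
    epochs_since_last_save = 0
    for epoch in range(total_epochs):      # epochs counted from 0
        epochs_since_last_save += 1
        if epochs_since_last_save >= period:
            epochs_since_last_save = 0
            saved.append(epoch + 1)        # file names use epoch + 1
    return saved

epochs = saved_epochs(50, 5)
print(epochs)  # [5, 10, 15, 20, 25, 30, 35, 40, 45, 50]
print("checkpoint-{epoch:04d}.ckpt".format(epoch=epochs[-1]))  # checkpoint-0050.ckpt
```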

## Restore from checkpoint

In [10]:
del model

In [11]:
model = create_model()
loss, acc = model.evaluate(test_data, test_labels)
# A freshly initialized model scores near chance level (about 10% for 10 classes)
print("Untrained model, accuracy: {:.2f}%".format(100*acc))

Untrained model, accuracy: 13.80%


In [12]:
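# latest_checkpoint returns None here: it reads the "checkpoint" state file written by
# TensorFlow-format savers, but with save_weights_only=False ModelCheckpoint saves each
# checkpoint as an HDF5 file via model.save(), so no such state file exists in ./output.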
ck_dir = os.path.normpath(os.path.realpath(checkpoint_directory))
latest_path = tf.train.latest_checkpoint(checkpoint_dir=ck_dir)
print(latest_path)

None
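
When `tf.train.latest_checkpoint` returns `None` like this, one workaround is to pick the newest file yourself by parsing the epoch number from its name. The sketch below assumes the `checkpoint-{epoch:04d}.ckpt` naming used above; `find_latest_checkpoint` is a hypothetical helper, not a TensorFlow API:

```python
import os
import re

def find_latest_checkpoint(directory, pattern=r"checkpoint-(\d+)\.ckpt"):
    """Return the path of the checkpoint with the highest epoch number, or None."""
    if not os.path.isdir(directory):
        return None
    best_epoch, best_path = -1, None
    for name in os.listdir(directory):
        match = re.fullmatch(pattern, name)  # ignore files that do not follow the pattern
        if match:
            epoch = int(match.group(1))
            if epoch > best_epoch:
                best_epoch, best_path = epoch, os.path.join(directory, name)
    return best_path

# Possible use in this notebook (assumes checkpoint_directory from the cell above):
# latest_path = find_latest_checkpoint(checkpoint_directory)
# model.load_weights(latest_path)
```

Comparing the parsed integers rather than the raw strings keeps the result correct even if the zero padding were dropped from the file names.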


In [13]:
# Load the last checkpoint (epoch 50) directly by file name. load_weights can read the
# weights stored in the full-model HDF5 file that ModelCheckpoint wrote.
model.load_weights("output/checkpoint-0050.ckpt")
loss, acc = model.evaluate(test_data, test_labels)
print("Restored model, accuracy: {:.2f}%".format(100*acc))

Restored model, accuracy: 87.20%
