## Saving and Restoring a Model

When using the Sequential API or the Functional API, saving a trained Keras model is as simple as it gets:

In [1]:
import tensorflow as tf
from tensorflow import keras

import numpy as np
import pandas as pd

In [2]:
fashion_mnist = keras.datasets.fashion_mnist
(X_train_full, y_train_full), (X_test, y_test) = fashion_mnist.load_data()

In [3]:
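# Hold out the first 5,000 images for validation and scale pixel values to [0, 1].
# Note: slicing from index 25000 skips images 5,000 to 24,999, leaving 35,000 training images;
# X_train_full[5000:] would use all 55,000 remaining images.
X_valid, X_train = X_train_full[:5000] / 255.0, X_train_full[25000:] / 255.0
y_valid, y_train = y_train_full[:5000], y_train_full[25000:]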

In [4]:
model = keras.models.Sequential([
    keras.layers.Flatten(input_shape=X_train.shape[1:]),
    keras.layers.Dense(300, activation='relu'),
    keras.layers.Dense(100, activation='relu'),
    keras.layers.Dense(10, activation='softmax')
])

In [5]:
model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
model.fit(X_train, y_train, epochs=30, validation_data=(X_valid, y_valid))

Train on 35000 samples, validate on 5000 samples
Epoch 1/30
...
Epoch 30/30


<tensorflow.python.keras.callbacks.History at 0x14cc92160>

In [6]:
model.save('my_keras_model.h5')

Since the filename ends with .h5, Keras uses the HDF5 format to save both the model's architecture (including every layer's hyperparameters) and the values of all the model parameters for every layer (e.g., connection weights and biases). It also saves the optimizer (including its hyperparameters and any state it may have). You will typically have a script that trains a model and saves it, and one or more scripts (or web services) that load the model and use it to make predictions. Loading the model is just as easy:
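To see for yourself what ends up in the file, you can open it with h5py, the library Keras relies on for HDF5 I/O. The sketch below is a minimal illustration under a few assumptions: it trains a tiny stand-in model on random data instead of the Fashion MNIST model above, and the filename `inspect_me.h5` is made up. The exact set of groups and attributes may differ between Keras versions, so treat the printed names as indicative rather than a file-format specification.

```python
import h5py
import numpy as np
from tensorflow import keras

# Random stand-in data shaped like Fashion MNIST (assumption: not the real dataset)
rng = np.random.default_rng(42)
X = rng.random((64, 28, 28)).astype("float32")
y = rng.integers(0, 10, size=64)

model = keras.models.Sequential([
    keras.Input(shape=(28, 28)),
    keras.layers.Flatten(),
    keras.layers.Dense(30, activation="relu"),
    keras.layers.Dense(10, activation="softmax"),
])
model.compile(loss="sparse_categorical_crossentropy", optimizer="sgd")
model.fit(X, y, epochs=1, verbose=0)
model.save("inspect_me.h5")  # hypothetical filename; the .h5 suffix selects HDF5

with h5py.File("inspect_me.h5", "r") as f:
    group_names = sorted(f.keys())       # e.g. the "model_weights" group (parameters)
    attr_names = sorted(f.attrs.keys())  # e.g. the "model_config" attribute (architecture as JSON)

print("groups:", group_names)
print("attributes:", attr_names)
```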

In [7]:
model = keras.models.load_model('my_keras_model.h5')
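A quick way to convince yourself that the round trip preserved everything that matters is to compare predictions before saving and after loading. This is a sketch, assuming a small stand-in model trained for one epoch on random data and a hypothetical filename `roundtrip_model.h5`; with real data the check works the same way.

```python
import numpy as np
from tensorflow import keras

# Random stand-in data (assumption: replaces Fashion MNIST to keep the example fast)
rng = np.random.default_rng(0)
X = rng.random((64, 28, 28)).astype("float32")
y = rng.integers(0, 10, size=64)

model = keras.models.Sequential([
    keras.Input(shape=(28, 28)),
    keras.layers.Flatten(),
    keras.layers.Dense(30, activation="relu"),
    keras.layers.Dense(10, activation="softmax"),
])
model.compile(loss="sparse_categorical_crossentropy", optimizer="sgd", metrics=["accuracy"])
model.fit(X, y, epochs=1, verbose=0)

preds_before = model.predict(X, verbose=0)
model.save("roundtrip_model.h5")  # hypothetical filename

restored = keras.models.load_model("roundtrip_model.h5")
preds_after = restored.predict(X, verbose=0)

# Same architecture and same weights should give the same outputs (up to float rounding)
print("identical predictions:", np.allclose(preds_before, preds_after, atol=1e-6))
# The optimizer is restored as well, so restored.fit(...) could continue training
print("optimizer:", type(restored.optimizer).__name__)
```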

This works when using the Sequential API or the Functional API, but unfortunately not when using model subclassing (at least not with the HDF5 format used here). You can use save_weights() and load_weights() to at least save and restore the model parameters, but you will need to save and restore everything else yourself.
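For a subclassed model, the weights-only route might look like the sketch below. The class `WideAndShallow` and the filename are hypothetical. Because the architecture is not stored, you recreate the same class yourself and build its variables (by calling it once) before loading. The `.weights.h5` suffix is an assumption chosen because recent Keras versions require it for `save_weights()`, while older tf.keras treats any `.h5` path as HDF5.

```python
import numpy as np
from tensorflow import keras


class WideAndShallow(keras.Model):  # hypothetical subclassed model
    def __init__(self, units=30, **kwargs):
        super().__init__(**kwargs)
        self.hidden = keras.layers.Dense(units, activation="relu")
        self.head = keras.layers.Dense(1)

    def call(self, inputs):
        return self.head(self.hidden(inputs))


rng = np.random.default_rng(1)
X = rng.random((32, 8)).astype("float32")

model = WideAndShallow()
model(X[:1])  # calling the model once creates its weights
preds_before = model.predict(X, verbose=0)
model.save_weights("my_subclassed_model.weights.h5")  # parameters only, no architecture

# The architecture is not in the file, so rebuild the same class yourself
clone = WideAndShallow()
clone(X[:1])  # variables must exist before weights can be loaded into them
clone.load_weights("my_subclassed_model.weights.h5")
preds_after = clone.predict(X, verbose=0)

print("weights restored:", np.allclose(preds_before, preds_after))
```

Anything beyond the parameters (constructor arguments such as `units`, the optimizer state) is exactly the "everything else" you would have to persist on your own.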

But what if training lasts several hours? This is quite common, especially when training on large datasets. In this case, you should not only save your model at the end of training, but also save checkpoints at regular intervals during training, to avoid losing everything if your computer crashes. But how can you tell the fit() method to save checkpoints? Use callbacks.

## Using Callbacks

The fit() method accepts a callbacks argument that lets you specify a list of objects that Keras will call at the start and end of training, at the start and end of each epoch, and even before and after processing each batch. For example, the ModelCheckpoint callback saves checkpoints of your model at regular intervals during training, by default at the end of each epoch:
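To make the "start and end of training, epoch, and batch" ordering concrete, here is a sketch of a hypothetical `HookLogger` callback that simply records each hook as Keras calls it. It assumes a tiny model, 64 random samples, and `batch_size=32`, so each epoch has exactly two batches.

```python
import numpy as np
from tensorflow import keras


class HookLogger(keras.callbacks.Callback):
    """Hypothetical helper: records the name of each training hook as it fires."""

    def __init__(self):
        super().__init__()
        self.events = []

    def on_train_begin(self, logs=None):
        self.events.append("train_begin")

    def on_epoch_begin(self, epoch, logs=None):
        self.events.append(f"epoch_{epoch}_begin")

    def on_train_batch_begin(self, batch, logs=None):
        self.events.append(f"batch_{batch}_begin")

    def on_train_batch_end(self, batch, logs=None):
        self.events.append(f"batch_{batch}_end")

    def on_epoch_end(self, epoch, logs=None):
        self.events.append(f"epoch_{epoch}_end")

    def on_train_end(self, logs=None):
        self.events.append("train_end")


rng = np.random.default_rng(3)
X = rng.random((64, 28, 28)).astype("float32")
y = rng.integers(0, 10, size=64)

model = keras.models.Sequential([
    keras.Input(shape=(28, 28)),
    keras.layers.Flatten(),
    keras.layers.Dense(10, activation="softmax"),
])
model.compile(loss="sparse_categorical_crossentropy", optimizer="sgd")

logger = HookLogger()
model.fit(X, y, epochs=2, batch_size=32, callbacks=[logger], verbose=0)
print("\n".join(logger.events))
```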

In [8]:
# Build and compile the model
model_0 = keras.models.Sequential([
    keras.layers.Flatten(input_shape=X_train.shape[1:]),
    keras.layers.Dense(300, activation='relu'),
    keras.layers.Dense(100, activation='relu'),
    keras.layers.Dense(10, activation='softmax')
])
model_0.compile(loss='sparse_categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
checkpoint_cb = keras.callbacks.ModelCheckpoint("my_keras_model_0.h5")
history_0 = model_0.fit(X_train, y_train, epochs=10, callbacks=[checkpoint_cb])

Train on 35000 samples
Epoch 1/10
...
Epoch 10/10
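With a fixed filename, each checkpoint overwrites the previous one. ModelCheckpoint also accepts format fields in the filepath, such as `{epoch:02d}`, which it fills in when saving, so every epoch gets its own file. The sketch below assumes a tiny model on random data and writes into a temporary directory; the `epoch_XX.h5` naming is just an illustrative choice.

```python
import os
import tempfile

import numpy as np
from tensorflow import keras

rng = np.random.default_rng(2)
X = rng.random((64, 28, 28)).astype("float32")
y = rng.integers(0, 10, size=64)

model = keras.models.Sequential([
    keras.Input(shape=(28, 28)),
    keras.layers.Flatten(),
    keras.layers.Dense(10, activation="softmax"),
])
model.compile(loss="sparse_categorical_crossentropy", optimizer="sgd", metrics=["accuracy"])

ckpt_dir = tempfile.mkdtemp()
# "{epoch:02d}" is replaced by the 1-based epoch number, so no checkpoint is overwritten
checkpoint_cb = keras.callbacks.ModelCheckpoint(os.path.join(ckpt_dir, "epoch_{epoch:02d}.h5"))
model.fit(X, y, epochs=3, callbacks=[checkpoint_cb], verbose=0)

saved = sorted(os.listdir(ckpt_dir))
print(saved)
```

Keeping every checkpoint costs disk space, which is one reason the single-file, `save_best_only` approach is often preferred.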


Moreover, if you use a validation set during training, you can set save_best_only=True when creating the ModelCheckpoint. In this case, it will only save your model when its performance on the validation set is the best so far. This way, you do not need to worry about training for too long and overfitting the training set: simply restore the last model saved after training, and this will be the best model on the validation set. The following code is a simple way to implement early stopping:


In [9]:
model_1 = keras.models.Sequential([
    keras.layers.Flatten(input_shape=X_train.shape[1:]),
    keras.layers.Dense(300, activation='relu'),
    keras.layers.Dense(100, activation='relu'),
    keras.layers.Dense(10, activation='softmax')
])
model_1.compile(loss='sparse_categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
checkpoint_cb = keras.callbacks.ModelCheckpoint('my_keras_model_1.h5', save_best_only=True)
history_1 = model_1.fit(X_train, y_train, epochs=10, validation_data=(X_valid, y_valid),
                        callbacks=[checkpoint_cb])
model = keras.models.load_model('my_keras_model_1.h5')  # roll back to best model

Train on 35000 samples, validate on 5000 samples
Epoch 1/10
...
Epoch 10/10
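Why is "the last model saved" the best one? The bookkeeping behind save_best_only can be sketched in plain Python. This is a simplified model, not Keras's actual implementation: it assumes the monitored quantity is val_loss (lower is better), ignores ModelCheckpoint options such as `mode`, and uses made-up loss values.

```python
def checkpoint_epochs(val_losses):
    """Return the 1-based epochs at which a save_best_only checkpoint would be written,
    assuming the monitored value is val_loss, where lower is better."""
    best = float("inf")
    saved = []
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:  # strict improvement over the best value seen so far
            best = loss
            saved.append(epoch)
    return saved


val_losses = [0.62, 0.48, 0.51, 0.44, 0.47, 0.45]  # made-up numbers for illustration
saved = checkpoint_epochs(val_losses)
print(saved)  # → [1, 2, 4]
```

Every write happens only on a new minimum, so the final file on disk always holds the epoch with the lowest validation loss (epoch 4 here), even though training continued for two more epochs.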


Another way to implement early stopping is to simply use the EarlyStopping callback. It will interrupt training when it measures no progress on the validation set for a number of epochs (defined by the patience argument), and it will optionally roll back to the best model. You can combine both callbacks to save checkpoints of your model (in case your computer crashes) and to interrupt training early when there is no more progress (to avoid wasting time and resources):

In [10]:
model_2 = keras.models.Sequential([
    keras.layers.Flatten(input_shape=X_train.shape[1:]),
    keras.layers.Dense(300, activation='relu'),
    keras.layers.Dense(100, activation='relu'),
    keras.layers.Dense(10, activation='softmax')
])
model_2.compile(loss='sparse_categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
checkpoint_cb_2 = keras.callbacks.ModelCheckpoint('my_keras_model_2.h5', save_best_only=True)
early_stopping_cb = keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)
history_2 = model_2.fit(X_train, y_train, epochs=100, validation_data=(X_valid, y_valid),
                        callbacks=[checkpoint_cb_2, early_stopping_cb])

Train on 35000 samples, validate on 5000 samples
Epoch 1/100
...
Epoch 52/100


The number of epochs can be set to a large value since training will stop automatically when there is no more progress (in the run above, it stopped after 52 of the 100 epochs). In this case, there is no need to restore the best model saved, because with restore_best_weights=True the EarlyStopping callback keeps track of the best weights and restores them for you at the end of training.
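The patience logic can be sketched in plain Python as well. This is a simplified illustration rather than the EarlyStopping source: it assumes val_loss is monitored, treats any strict decrease as progress (as with the default `min_delta=0`), ignores options such as `baseline`, and runs on made-up losses. `best_epoch` is the epoch whose weights `restore_best_weights=True` would hand back.

```python
def early_stopping(val_losses, patience):
    """Simulate patience-based early stopping on per-epoch validation losses.

    Returns (stopped_epoch, best_epoch), both 1-based.
    """
    best = float("inf")
    best_epoch = 0
    wait = 0  # consecutive epochs without improvement
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:  # patience exhausted: stop here
                return epoch, best_epoch
    return len(val_losses), best_epoch  # ran out of epochs without stopping


val_losses = [0.60, 0.50, 0.45, 0.47, 0.46, 0.48, 0.49, 0.44]  # made-up numbers
print(early_stopping(val_losses, patience=3))  # → (6, 3)
print(early_stopping(val_losses, patience=5))  # → (8, 8)
```

With `patience=3` the run stops at epoch 6 and never sees the lower loss at epoch 8; with `patience=5` it keeps going and finds it. That is the trade-off the patience argument controls: a larger value tolerates noisy plateaus at the cost of extra epochs.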

There are many other callbacks available in the keras.callbacks package (see https://keras.io/callbacks/).
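As one example of those built-in callbacks, `keras.callbacks.CSVLogger` appends one row of metrics per epoch to a CSV file, which is handy for long runs you want to inspect later. The sketch assumes a tiny model on random data and a hypothetical log file in a temporary directory.

```python
import csv
import os
import tempfile

import numpy as np
from tensorflow import keras

rng = np.random.default_rng(4)
X = rng.random((64, 28, 28)).astype("float32")
y = rng.integers(0, 10, size=64)

model = keras.models.Sequential([
    keras.Input(shape=(28, 28)),
    keras.layers.Flatten(),
    keras.layers.Dense(10, activation="softmax"),
])
model.compile(loss="sparse_categorical_crossentropy", optimizer="sgd", metrics=["accuracy"])

log_path = os.path.join(tempfile.mkdtemp(), "training_log.csv")  # hypothetical location
csv_logger_cb = keras.callbacks.CSVLogger(log_path)  # writes one row per epoch
model.fit(X, y, epochs=3, callbacks=[csv_logger_cb], verbose=0)

with open(log_path, newline="") as f:
    rows = list(csv.DictReader(f))
print("columns:", list(rows[0].keys()))
print(len(rows), "rows")
```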

If you need extra control, you can easily write your own custom callbacks. As an example of how to do that, the following custom callback will display the ratio between the validation loss and the training loss during training (e.g., to detect overfitting):

In [11]:
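class PrintValTrainRatioCallback(keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        # logs holds this epoch's metrics; 'val_loss' is only present when validation data is given
        print('\nval/train: {:.2f}'.format(logs['val_loss'] / logs['loss']))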
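To use the callback, pass an instance to fit() together with validation data, since it reads `val_loss`. In this sketch the class additionally stores each ratio in a `ratios` list (an addition not in the original) so the values can be inspected after training; the model and the random training/validation split are stand-ins.

```python
import numpy as np
from tensorflow import keras


class PrintValTrainRatioCallback(keras.callbacks.Callback):
    def __init__(self):
        super().__init__()
        self.ratios = []  # not in the original: keeps the values for later inspection

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        ratio = logs["val_loss"] / logs["loss"]
        self.ratios.append(ratio)
        print("\nval/train: {:.2f}".format(ratio))


rng = np.random.default_rng(5)
X = rng.random((96, 28, 28)).astype("float32")
y = rng.integers(0, 10, size=96)
X_tr, y_tr, X_va, y_va = X[:64], y[:64], X[64:], y[64:]  # stand-in train/validation split

model = keras.models.Sequential([
    keras.Input(shape=(28, 28)),
    keras.layers.Flatten(),
    keras.layers.Dense(10, activation="softmax"),
])
model.compile(loss="sparse_categorical_crossentropy", optimizer="sgd")

ratio_cb = PrintValTrainRatioCallback()
model.fit(X_tr, y_tr, epochs=2, validation_data=(X_va, y_va),
          callbacks=[ratio_cb], verbose=0)
```

A ratio that keeps climbing well above 1 suggests the model is fitting the training set better than it generalizes.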

As you might expect, you can implement on_train_begin(), on_train_end(), on_epoch_begin(), on_epoch_end(), on_batch_begin(), and on_batch_end(). Callbacks can also be used during evaluation and predictions, should you ever need them. For evaluation, you should implement on_test_begin(), on_test_end(), on_test_batch_begin(), or on_test_batch_end() (called by evaluate()), and for prediction you should implement on_predict_begin(), on_predict_end(), on_predict_batch_begin(), or on_predict_batch_end() (called by predict()).
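The evaluation and prediction hooks are attached the same way, by passing the callback to evaluate() or predict(). The hypothetical `EvalPredictCounter` below just counts how often each hook fires; with 64 random samples and `batch_size=32`, both evaluate() and predict() should process two batches.

```python
import numpy as np
from tensorflow import keras


class EvalPredictCounter(keras.callbacks.Callback):
    """Hypothetical helper: counts how often each evaluation/prediction hook fires."""

    def __init__(self):
        super().__init__()
        self.counts = {name: 0 for name in (
            "test_begin", "test_batch_end", "test_end",
            "predict_begin", "predict_batch_end", "predict_end")}

    # Hooks called by evaluate()
    def on_test_begin(self, logs=None):
        self.counts["test_begin"] += 1

    def on_test_batch_end(self, batch, logs=None):
        self.counts["test_batch_end"] += 1

    def on_test_end(self, logs=None):
        self.counts["test_end"] += 1

    # Hooks called by predict()
    def on_predict_begin(self, logs=None):
        self.counts["predict_begin"] += 1

    def on_predict_batch_end(self, batch, logs=None):
        self.counts["predict_batch_end"] += 1

    def on_predict_end(self, logs=None):
        self.counts["predict_end"] += 1


rng = np.random.default_rng(6)
X = rng.random((64, 28, 28)).astype("float32")
y = rng.integers(0, 10, size=64)

model = keras.models.Sequential([
    keras.Input(shape=(28, 28)),
    keras.layers.Flatten(),
    keras.layers.Dense(10, activation="softmax"),
])
model.compile(loss="sparse_categorical_crossentropy", optimizer="sgd")

counter = EvalPredictCounter()
model.evaluate(X, y, batch_size=32, callbacks=[counter], verbose=0)  # 64 / 32 = 2 batches
model.predict(X, batch_size=32, callbacks=[counter], verbose=0)      # 2 batches again
print(counter.counts)
```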

Now let's take a look at one more tool from tf.keras: TensorBoard.

## Using TensorBoard for Visualization