

# Training Callbacks

This notebook builds on the Fashion MNIST example to explore using callbacks during training.


In [1]:
import tensorflow as tf
import numpy as np

# Input

Training input comes from in-memory NumPy arrays loaded from the TensorFlow example dataset `fashion_mnist`.

In [2]:
(image_train, label_train), (image_test, label_test) = tf.keras.datasets.fashion_mnist.load_data()

# Normalize the pixel data from 8-bit integer representation [0, 255] to the floating point range [0, 1].
image_train = image_train / 255.0
image_test = image_test / 255.0

# Model Definition


In [3]:
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(512, activation=tf.nn.relu),
  tf.keras.layers.Dense(10, activation=tf.nn.softmax)
])

In [4]:
model.compile(optimizer=tf.optimizers.Adam(),
              loss=tf.keras.losses.sparse_categorical_crossentropy,
              metrics=['accuracy'])

# Early Stopping

Callbacks are added to the model.fit() call here, but can also be added to model.evaluate() and model.predict().

In [5]:
class CustomCallback(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    '''Stop training once the accuracy metric exceeds the target.'''
    logs = logs or {}  # logs defaults to None; fall back to an empty dict.
    keys = list(logs.keys())
    print('\nEnd epoch {} of training; got log keys: {}\n'.format(epoch, keys))
    # 'accuracy' here is the training accuracy reported at the end of the epoch.
    if logs.get('accuracy', 0.0) > .89:
      print('\nStopping training!\n')
      self.model.stop_training = True

callback = CustomCallback()

model.fit(image_train, label_train, epochs=5, callbacks=[callback])

Epoch 1/5
End epoch 0 of training; got log keys: ['loss', 'accuracy']

Epoch 2/5
End epoch 1 of training; got log keys: ['loss', 'accuracy']

Epoch 3/5
End epoch 2 of training; got log keys: ['loss', 'accuracy']

Epoch 4/5
End epoch 3 of training; got log keys: ['loss', 'accuracy']

Epoch 5/5
End epoch 4 of training; got log keys: ['loss', 'accuracy']


Stopping training!



<tensorflow.python.keras.callbacks.History at 0x7f1e2aaf05c0>
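As noted at the start of this section, callbacks can also be passed to `model.evaluate()` and `model.predict()`. The following is a minimal sketch of that, not part of the original notebook: `EventRecorder`, `tiny_model`, and the random data are hypothetical names and values chosen only so the example runs quickly without downloading a dataset. It overrides the `on_test_*` and `on_predict_*` hooks of `tf.keras.callbacks.Callback` and records the order in which they fire.

```python
import numpy as np
import tensorflow as tf


class EventRecorder(tf.keras.callbacks.Callback):
    """Hypothetical callback that records which evaluate/predict hooks fire."""

    def __init__(self):
        super().__init__()
        self.events = []

    def on_test_begin(self, logs=None):
        self.events.append("test_begin")

    def on_test_end(self, logs=None):
        self.events.append("test_end")

    def on_predict_begin(self, logs=None):
        self.events.append("predict_begin")

    def on_predict_end(self, logs=None):
        self.events.append("predict_end")


# Random stand-in data (an assumption for illustration, not fashion_mnist).
rng = np.random.default_rng(0)
x = rng.random((32, 4)).astype("float32")
y = rng.integers(0, 3, size=32)

# A deliberately tiny classifier so the example finishes in seconds.
tiny_model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(3, activation="softmax"),
])
tiny_model.compile(optimizer="adam",
                   loss="sparse_categorical_crossentropy",
                   metrics=["accuracy"])

recorder = EventRecorder()
tiny_model.evaluate(x, y, verbose=0, callbacks=[recorder])
predictions = tiny_model.predict(x, verbose=0, callbacks=[recorder])
print(recorder.events)
```

The same callback instance can be reused across calls; only the hooks matching the running method (`on_test_*` for `evaluate()`, `on_predict_*` for `predict()`) are expected to fire.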

# Save Model at Minimum Loss

See [Keras docs example](https://colab.research.google.com/github/keras-team/keras-io/blob/master/guides/ipynb/writing_your_own_callbacks.ipynb#scrollTo=_KEUtZgeYOLA) for more.


In [6]:
class CustomCallback(tf.keras.callbacks.Callback):
    def __init__(self, patience=0):
        super(CustomCallback, self).__init__()
        self.patience = patience
        # best_weights to store the weights at which the minimum loss occurs.
        self.best_weights = None

    def on_train_begin(self, logs=None):
        # Number of epochs waited since the loss last reached a new minimum.
        self.wait = 0
        # The epoch at which training stops.
        self.stopped_epoch = 0
        # Initialize the best loss as infinity (np.Inf was removed in NumPy 2.0).
        self.best = np.inf

    def on_epoch_end(self, epoch, logs=None):
        current = logs.get("loss")
        if np.less(current, self.best):
            self.best = current
            self.wait = 0
            # Record the weights whenever the current loss is a new minimum.
            self.best_weights = self.model.get_weights()
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                self.model.stop_training = True
                print("Restoring model weights from the end of the best epoch.")
                self.model.set_weights(self.best_weights)

    def on_train_end(self, logs=None):
        if self.stopped_epoch > 0:
            print("Epoch %05d: early stopping" % (self.stopped_epoch + 1))

callback = CustomCallback()

model.fit(image_train, label_train, epochs=100, callbacks=[callback])

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 00027: early stopping


<tensorflow.python.keras.callbacks.History at 0x7f1e308acba8>
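The patience bookkeeping in the callback above can be traced without training a model. The sketch below replays the same `wait`/`best` logic over a list of per-epoch losses; `simulate_patience` and the loss values are hypothetical, made up purely for illustration, and the function is a plain-Python stand-in rather than anything Keras provides.

```python
import math


def simulate_patience(losses, patience=0):
    """Replay the callback's wait/best bookkeeping over per-epoch losses.

    Returns (stopped_epoch, best_epoch) as 0-based epoch indices.
    stopped_epoch is None if training would run through every epoch.
    """
    best = math.inf
    best_epoch = None
    wait = 0
    for epoch, current in enumerate(losses):
        if current < best:
            # New minimum: remember it and reset the wait counter.
            best = current
            best_epoch = epoch
            wait = 0
        else:
            # No improvement: count the epoch, then compare with patience.
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    return None, best_epoch


# Hypothetical loss curve: improves, stalls for two epochs, improves, then stalls.
losses = [0.50, 0.40, 0.41, 0.42, 0.35, 0.36, 0.37, 0.38]

for patience in (0, 1, 2, 3):
    print(patience, simulate_patience(losses, patience))
```

Because `wait` is incremented before it is compared with `patience`, `patience=0` and `patience=1` stop at the same epoch in this implementation: the first epoch whose loss fails to improve. That matches the run above, where the default `patience=0` ended training at epoch 27 of 100.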

# Learning Rate Schedule

Although a schedule adds more hyperparameters to tune, it can be useful to adjust the learning rate as training progresses, generally starting large and decreasing it.

In [7]:

LR_SCHEDULE = [
    # (epoch to start, learning rate) tuples
    (3, 0.05),
    (6, 0.01),
    (9, 0.005),
    (12, 0.001),
]


def lr_schedule(epoch, lr):
    """Helper function to retrieve the scheduled learning rate based on epoch."""
    if epoch < LR_SCHEDULE[0][0] or epoch > LR_SCHEDULE[-1][0]:
        return lr
    for i in range(len(LR_SCHEDULE)):
        if epoch == LR_SCHEDULE[i][0]:
            return LR_SCHEDULE[i][1]
    return lr

callback = tf.keras.callbacks.LearningRateScheduler(lr_schedule, verbose=1)

model.fit(image_train, label_train, epochs=15, callbacks=[callback])


Epoch 00000: Learning rate is 0.0010.
Epoch 1/15

Epoch 00001: Learning rate is 0.0010.
Epoch 2/15

Epoch 00002: Learning rate is 0.0010.
Epoch 3/15

Epoch 00003: Learning rate is 0.0500.
Epoch 4/15

Epoch 00004: Learning rate is 0.0500.
Epoch 5/15

Epoch 00005: Learning rate is 0.0500.
Epoch 6/15

Epoch 00006: Learning rate is 0.0100.
Epoch 7/15

Epoch 00007: Learning rate is 0.0100.
Epoch 8/15

Epoch 00008: Learning rate is 0.0100.
Epoch 9/15

Epoch 00009: Learning rate is 0.0050.
Epoch 10/15

Epoch 00010: Learning rate is 0.0050.
Epoch 11/15

Epoch 00011: Learning rate is 0.0050.
Epoch 12/15

Epoch 00012: Learning rate is 0.0010.
Epoch 13/15

Epoch 00013: Learning rate is 0.0010.
Epoch 14/15

Epoch 00014: Learning rate is 0.0010.
Epoch 15/15


<tensorflow.python.keras.callbacks.History at 0x7f1e284a37b8>
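The rates printed above can be reproduced without TensorFlow. `LearningRateScheduler` calls the schedule function at the start of each epoch with the epoch index and the optimizer's current learning rate, then applies whatever the function returns. The sketch below mimics that loop in plain Python; `trace_schedule` is a hypothetical helper, `lr_schedule` is condensed from the cell above (the early-return bounds check is dropped because the loop already falls through to `return lr`), and the starting rate of 0.001 is assumed from the first line of the output.

```python
LR_SCHEDULE = [
    # (epoch to start, learning rate) tuples
    (3, 0.05),
    (6, 0.01),
    (9, 0.005),
    (12, 0.001),
]


def lr_schedule(epoch, lr):
    """Return the scheduled rate at a listed epoch, else keep the current rate."""
    for start_epoch, scheduled_lr in LR_SCHEDULE:
        if epoch == start_epoch:
            return scheduled_lr
    return lr


def trace_schedule(schedule_fn, initial_lr, epochs):
    """Feed each epoch's returned rate back in as the next epoch's current rate."""
    lr = initial_lr
    rates = []
    for epoch in range(epochs):
        lr = schedule_fn(epoch, lr)
        rates.append(lr)
    return rates


# Assumed starting rate, taken from "Epoch 00000: Learning rate is 0.0010."
rates = trace_schedule(lr_schedule, initial_lr=0.001, epochs=15)
for epoch, rate in enumerate(rates):
    print(f"Epoch {epoch:05d}: learning rate is {rate:.4f}")
```

The trace makes one property of this schedule easy to see: a rate set at a listed epoch persists until the next listed epoch only because the function returns the current rate unchanged in between.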