<a href="https://colab.research.google.com/github/prashant-bande/Keras-Custom-Callbacks/blob/master/Keras_Custom_Callbacks.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from __future__ import absolute_import, division, print_function, unicode_literals

# Import libraries
import numpy as np
import matplotlib.pyplot as plt

try:
  # %tensorflow_version only exists in Colab.
  %tensorflow_version 2.x
except Exception:
  pass
import tensorflow as tf

TensorFlow 2.x selected.


In [0]:
# Define the Keras model to add callbacks to
def get_model():
  model = tf.keras.Sequential()
  model.add(tf.keras.layers.Dense(1, activation='linear', input_dim=784))
  # `learning_rate` replaces the deprecated `lr` argument in TF 2.x optimizers.
  model.compile(optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.1), loss='mean_squared_error', metrics=['mae'])
  return model

In [0]:
# Load example MNIST data and pre-process it
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype('float32') / 255
x_test = x_test.reshape(10000, 784).astype('float32') / 255
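The preprocessing above flattens each 28x28 MNIST image into a 784-element row (matching `input_dim=784` in `get_model()`) and rescales the `uint8` pixel values from [0, 255] to [0.0, 1.0]. The minimal sketch below applies the same two steps to a small random stand-in array rather than the real dataset. It uses `reshape(-1, 784)` so that NumPy infers the row count instead of hard-coding it.

```python
import numpy as np

# Hypothetical stand-in for MNIST: 3 random 28x28 "images" with uint8 pixels.
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(3, 28, 28)).astype(np.uint8)

# Same preprocessing as above; -1 lets NumPy infer the number of rows.
flat = images.reshape(-1, 784).astype('float32') / 255

print(flat.shape)                            # (3, 784)
print(flat.min() >= 0.0, flat.max() <= 1.0)  # True True
```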

In [0]:
# Define a simple custom callback that logs the start and end of every training and evaluation batch.
# Each call prints the index of the current batch and the current time.

import datetime

class MyCustomCallback(tf.keras.callbacks.Callback):
  def on_train_batch_begin(self, batch, logs=None):
    print('Training: batch {} begins at {}'.format(batch, datetime.datetime.now().time()))
  def on_train_batch_end(self, batch, logs=None):
    print('Training: batch {} ends at {}'.format(batch, datetime.datetime.now().time()))
  def on_test_batch_begin(self, batch, logs=None):
    print('Evaluating: batch {} begins at {}'.format(batch, datetime.datetime.now().time()))
  def on_test_batch_end(self, batch, logs=None):
    print('Evaluating: batch {} ends at {}'.format(batch, datetime.datetime.now().time()))

In [5]:
# Providing a callback to model methods such as tf.keras.Model.fit()

model = get_model()
_ = model.fit(x_train, 
              y_train, 
              batch_size=64, 
              epochs=1, 
              steps_per_epoch=5, 
              verbose=0, 
              callbacks=[MyCustomCallback()])

Training: batch 0 begins at 10:23:54.555681
Training: batch 0 ends at 10:23:54.810132
Training: batch 1 begins at 10:23:54.810573
Training: batch 1 ends at 10:23:54.812571
Training: batch 2 begins at 10:23:54.813202
Training: batch 2 ends at 10:23:54.814715
Training: batch 3 begins at 10:23:54.815062
Training: batch 3 ends at 10:23:54.818405
Training: batch 4 begins at 10:23:54.818670
Training: batch 4 ends at 10:23:54.821177
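The timestamps printed by `MyCustomCallback` can be turned into per-batch durations. The sketch below parses the values copied from the output above. Batch 0 takes about 0.25 s, while batches 1 to 4 each take roughly 1.5 to 3.5 ms. The gap is plausibly due to one-time setup on the first batch, such as tracing the training step into a graph, but that explanation is an assumption, not something this output proves.

```python
from datetime import datetime

# Timestamps copied from the output above: batch index -> (begin, end).
timestamps = {
    0: ('10:23:54.555681', '10:23:54.810132'),
    1: ('10:23:54.810573', '10:23:54.812571'),
    2: ('10:23:54.813202', '10:23:54.814715'),
    3: ('10:23:54.815062', '10:23:54.818405'),
    4: ('10:23:54.818670', '10:23:54.821177'),
}


def duration(begin, end):
    """Seconds between two 'HH:MM:SS.ffffff' strings taken on the same day."""
    fmt = '%H:%M:%S.%f'
    return (datetime.strptime(end, fmt) - datetime.strptime(begin, fmt)).total_seconds()


durations = {batch: duration(b, e) for batch, (b, e) in timestamps.items()}
for batch, secs in durations.items():
    print('batch {}: {:.4f} s'.format(batch, secs))
```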


## Model methods that take callbacks
Users can supply a list of callbacks to the following `tf.keras.Model` methods:
#### [`fit()`](https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit), [`fit_generator()`](https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit_generator)
Trains the model for a fixed number of epochs (iterations over a dataset, or over data yielded batch-by-batch by a Python generator).
#### [`evaluate()`](https://www.tensorflow.org/api_docs/python/tf/keras/Model#evaluate), [`evaluate_generator()`](https://www.tensorflow.org/api_docs/python/tf/keras/Model#evaluate_generator)
Evaluates the model on the given data or data generator and returns the loss and metric values.
#### [`predict()`](https://www.tensorflow.org/api_docs/python/tf/keras/Model#predict), [`predict_generator()`](https://www.tensorflow.org/api_docs/python/tf/keras/Model#predict_generator)
Generates output predictions for the input data or data generator.

The `*_generator()` variants are deprecated in later TensorFlow 2.x releases (2.1 onward), where `fit()`, `evaluate()`, and `predict()` accept generators directly.


In [6]:
_ = model.evaluate(x_test, y_test, batch_size=128, verbose=0, steps=5, callbacks=[MyCustomCallback()])

Evaluating: batch 0 begins at 10:23:54.864761
Evaluating: batch 0 ends at 10:23:54.918712
Evaluating: batch 1 begins at 10:23:54.919302
Evaluating: batch 1 ends at 10:23:54.921205
Evaluating: batch 2 begins at 10:23:54.922316
Evaluating: batch 2 ends at 10:23:54.924346
Evaluating: batch 3 begins at 10:23:54.924687
Evaluating: batch 3 ends at 10:23:54.928068
Evaluating: batch 4 begins at 10:23:54.928314
Evaluating: batch 4 ends at 10:23:54.930340


## Overview of callback methods


### Common methods for training/testing/predicting
For training, testing, and predicting, the following methods can be overridden.
#### `on_(train|test|predict)_begin(self, logs=None)`
Called at the beginning of `fit`/`evaluate`/`predict`.
#### `on_(train|test|predict)_end(self, logs=None)`
Called at the end of `fit`/`evaluate`/`predict`.
#### `on_(train|test|predict)_batch_begin(self, batch, logs=None)`
Called right before a batch is processed during training/testing/predicting. In the TensorFlow release this notebook was run on, `logs` is a dict with the keys `batch` and `size`, which hold the current batch number and the size of the batch. Newer releases may pass no data here.
#### `on_(train|test|predict)_batch_end(self, batch, logs=None)`
Called at the end of training/testing/predicting a batch. Within this method, `logs` is a dict containing the metric results.
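To make the calling order concrete, here is a simplified, framework-free sketch of how a loop invokes these hooks for one mode. `run_loop` and `RecordingCallback` are hypothetical illustrations, not Keras code. The real Keras loop does more work, such as filling `logs` with metric values and, in `fit()`, calling the epoch hooks described in the next section.

```python
def run_loop(mode, callbacks, steps):
    """Call the hooks for mode 'train', 'test' or 'predict' in order.

    Simplified: a single pass over `steps` batches with empty logs (no epochs).
    """
    for cb in callbacks:
        getattr(cb, 'on_{}_begin'.format(mode))(logs={})
    for batch in range(steps):
        for cb in callbacks:
            getattr(cb, 'on_{}_batch_begin'.format(mode))(batch, logs={})
        # ... a real loop would process the batch here ...
        for cb in callbacks:
            getattr(cb, 'on_{}_batch_end'.format(mode))(batch, logs={})
    for cb in callbacks:
        getattr(cb, 'on_{}_end'.format(mode))(logs={})


class RecordingCallback:
    """Hypothetical callback that records the name of every hook called on it."""

    def __init__(self):
        self.events = []

    def __getattr__(self, name):
        # Only reached for attributes not found normally, i.e. the on_* hooks.
        if not name.startswith('on_'):
            raise AttributeError(name)

        def hook(*args, logs=None):
            self.events.append((name,) + args)

        return hook


cb = RecordingCallback()
run_loop('test', [cb], steps=2)
for event in cb.events:
    print(event)
```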

### Training specific methods
In addition, the following methods are called only during training.
#### `on_epoch_begin(self, epoch, logs=None)`
Called at the beginning of an epoch during training.
#### `on_epoch_end(self, epoch, logs=None)`
Called at the end of an epoch during training.
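Hooks are looked up by name, so a misspelled override such as `on_train_batch_ends` is never called and raises no error. The hypothetical helper below flags such typos by comparing a class's own `on_*` methods against the names listed in this section. The list is built from the patterns above plus the legacy `on_batch_begin`/`on_batch_end` aliases that `tf.keras.callbacks.Callback` also defines. Treat the list as an assumption and check it against your TensorFlow version.

```python
import itertools

# Hook names built from the documented patterns (an assumption to verify
# against your TensorFlow version), plus the legacy on_batch_* aliases.
MODES = ('train', 'test', 'predict')
EVENTS = ('begin', 'end', 'batch_begin', 'batch_end')
KNOWN_HOOKS = {'on_{}_{}'.format(m, e) for m, e in itertools.product(MODES, EVENTS)}
KNOWN_HOOKS |= {'on_epoch_begin', 'on_epoch_end', 'on_batch_begin', 'on_batch_end'}


def unknown_hooks(cls):
    """Return the sorted on_* methods defined directly on cls that no loop will call."""
    names = {name for name, value in vars(cls).items()
             if name.startswith('on_') and callable(value)}
    return sorted(names - KNOWN_HOOKS)


class TypoCallback:  # hypothetical; in practice subclass tf.keras.callbacks.Callback
    def on_epoch_end(self, epoch, logs=None):
        pass

    def on_train_batch_ends(self, batch, logs=None):  # typo: 'ends'
        pass


print(len(KNOWN_HOOKS))             # 16
print(unknown_hooks(TypoCallback))  # ['on_train_batch_ends']
```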


In [0]:
# The logs dict contains the loss value and all the metrics at the end of a batch or epoch.
# This example prints the loss and the mean absolute error (mae).

class LossAndErrorPrintingCallback(tf.keras.callbacks.Callback):
  def on_train_batch_end(self, batch, logs=None):
    print('For batch {}, loss is {:7.2f}.'.format(batch, logs['loss']))
  def on_test_batch_end(self, batch, logs=None):
    print('For batch {}, loss is {:7.2f}.'.format(batch, logs['loss']))
  def on_epoch_end(self, epoch, logs=None):
    print('The average loss for epoch {} is {:7.2f} and mean absolute error is {:7.2f}.'.format(epoch, logs['loss'], logs['mae']))

In [8]:
model = get_model()
_ = model.fit(x_train, 
              y_train, 
              batch_size=64, 
              steps_per_epoch=5, 
              epochs=3, 
              verbose=0, 
              callbacks=[LossAndErrorPrintingCallback()])

For batch 0, loss is   25.51.
For batch 1, loss is  951.36.
For batch 2, loss is   17.48.
For batch 3, loss is    7.58.
For batch 4, loss is    9.36.
The average loss for epoch 0 is  202.26 and mean absolute error is    8.25.
For batch 0, loss is    6.33.
For batch 1, loss is    6.94.
For batch 2, loss is    4.42.
For batch 3, loss is    6.11.
For batch 4, loss is    5.11.
The average loss for epoch 1 is    5.78 and mean absolute error is    1.96.
For batch 0, loss is    5.29.
For batch 1, loss is   11.82.
For batch 2, loss is   13.02.
For batch 3, loss is   12.89.
For batch 4, loss is   16.33.
The average loss for epoch 2 is   11.87 and mean absolute error is    2.74.
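In this run, the epoch-level `loss` is the mean of the five per-batch losses. The check below uses the values copied from the output above, and all three epochs agree to two decimals. This reflects the TensorFlow release the notebook ran on. Newer releases instead report a running average of the epoch so far in the batch-level `logs`, so their per-batch numbers will look different.

```python
# Per-batch losses and epoch averages copied from the output above.
batch_losses = {
    0: [25.51, 951.36, 17.48, 7.58, 9.36],
    1: [6.33, 6.94, 4.42, 6.11, 5.11],
    2: [5.29, 11.82, 13.02, 12.89, 16.33],
}
reported_epoch_loss = {0: 202.26, 1: 5.78, 2: 11.87}

for epoch, losses in batch_losses.items():
    mean = sum(losses) / len(losses)
    # The printed values are rounded to two decimals, so compare at that precision.
    print('epoch {}: mean of batch losses = {:.2f}, reported = {:.2f}'.format(
        epoch, mean, reported_epoch_loss[epoch]))
```

The single spike at batch 1 of epoch 0 (951.36) contributes about 190 of that epoch's 202.26 average.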


In [9]:
# Similarly, one can provide callbacks in evaluate() calls
_ = model.evaluate(x_test, y_test, batch_size=128, verbose=0, steps=20, callbacks=[LossAndErrorPrintingCallback()])

For batch 0, loss is   19.43.
For batch 1, loss is   17.19.
For batch 2, loss is   18.79.
For batch 3, loss is   21.12.
For batch 4, loss is   21.87.
For batch 5, loss is   18.93.
For batch 6, loss is   20.69.
For batch 7, loss is   18.88.
For batch 8, loss is   19.97.
For batch 9, loss is   19.36.
For batch 10, loss is   21.17.
For batch 11, loss is   22.01.
For batch 12, loss is   20.44.
For batch 13, loss is   21.44.
For batch 14, loss is   20.79.
For batch 15, loss is   18.24.
For batch 16, loss is   22.60.
For batch 17, loss is   22.56.
For batch 18, loss is   19.94.
For batch 19, loss is   20.48.


## Examples of Keras callback applications

### Early stopping at minimum loss

In [0]:
import numpy as np

class EarlyStoppingAtMinLoss(tf.keras.callbacks.Callback):
  """Stop training when the loss reaches its minimum, i.e. the loss stops decreasing.

  Arguments:
      patience: Number of epochs to wait after the minimum has been hit. After
      this many epochs without improvement, training stops and the weights from
      the best epoch are restored.
  """

  def __init__(self, patience=0):
    super(EarlyStoppingAtMinLoss, self).__init__()

    self.patience = patience
    self.best_weights = None

  def on_train_begin(self, logs=None):
    self.wait = 0
    self.stopped_epoch = 0
    self.best = np.inf  # np.Inf was removed in NumPy 2.0

  def on_epoch_end(self, epoch, logs=None):
    current = logs.get('loss')
    if np.less(current, self.best):
      self.best = current
      self.wait = 0
      self.best_weights = self.model.get_weights()
    else:
      self.wait += 1
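      if self.wait >= self.patience:
        self.stopped_epoch = epoch
        self.model.stop_training = True
        # best_weights stays None if no epoch ever improved (e.g. a NaN loss).
        if self.best_weights is not None:
          print('Restoring model weights from the end of the best epoch.')
          self.model.set_weights(self.best_weights)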

  def on_train_end(self, logs=None):
    if self.stopped_epoch > 0:
      print('Epoch %05d: early stopping' % (self.stopped_epoch + 1))
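The patience bookkeeping can be traced without a model. The hypothetical `trace_early_stopping` function below replays the `on_epoch_end` logic of `EarlyStoppingAtMinLoss` on a made-up list of epoch losses (not from a real run). It reports the epoch at which training would stop and the epoch whose weights would be restored.

```python
import math


def trace_early_stopping(losses, patience):
    """Replay EarlyStoppingAtMinLoss's on_epoch_end logic on a list of epoch losses.

    Returns (stopped_epoch, best_epoch); stopped_epoch is None if training
    would run through every epoch in `losses`.
    """
    best, best_epoch, wait = math.inf, None, 0
    for epoch, current in enumerate(losses):
        if current < best:
            best, best_epoch, wait = current, epoch, 0  # new minimum: weights saved
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch  # stop; restore weights from best_epoch
    return None, best_epoch


# Hypothetical epoch losses: improvement until epoch 3, then no new minimum.
losses = [10.0, 6.0, 5.0, 4.0, 9.0, 5.0, 6.0, 4.5, 8.0, 7.0]
print(trace_early_stopping(losses, patience=5))   # (8, 3)
print(trace_early_stopping(losses, patience=10))  # (None, 3)
```

With `patience=5`, training stops at epoch index 8, which the callback's `on_train_end` would print as `Epoch 00009: early stopping`, and the weights from epoch index 3 are restored.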

In [25]:
model = get_model()
_ = model.fit(x_train, 
              y_train, 
              batch_size=64, 
              steps_per_epoch=5, 
              epochs=30, 
              verbose=0, 
              callbacks=[LossAndErrorPrintingCallback(), EarlyStoppingAtMinLoss(patience=5)])

For batch 0, loss is   26.54.
For batch 1, loss is  847.82.
For batch 2, loss is   23.19.
For batch 3, loss is   10.24.
For batch 4, loss is   10.92.
The average loss for epoch 0 is  183.74 and mean absolute error is    8.14.
For batch 0, loss is    6.47.
For batch 1, loss is    6.42.
For batch 2, loss is    9.08.
For batch 3, loss is    7.48.
For batch 4, loss is    4.85.
The average loss for epoch 1 is    6.86 and mean absolute error is    2.15.
For batch 0, loss is    6.47.
For batch 1, loss is    6.31.
For batch 2, loss is    6.11.
For batch 3, loss is    5.67.
For batch 4, loss is    3.70.
The average loss for epoch 2 is    5.65 and mean absolute error is    1.93.
For batch 0, loss is    4.73.
For batch 1, loss is    4.19.
For batch 2, loss is    3.57.
For batch 3, loss is    5.10.
For batch 4, loss is    6.08.
The average loss for epoch 3 is    4.73 and mean absolute error is    1.75.
For batch 0, loss is   14.56.
For batch 1, loss is   44.60.
For batch 2, loss is  132.68.
For ba