# Keras Callbacks

- Callback is a Python class meant to be subclassed to provide specific functionality, with a set of methods called at various stages of training (including batch/epoch start and ends), testing, and predicting
- Callbacks are useful to get a view on internal states and statistics of the model during training. 
- The methods of the callbacks can be called at different stages of training/evaluating/inference.
- The list of Keras's built-in callbacks is [here](https://keras.io/api/callbacks/). We will see some of them in action.

# Import Libraries

In [None]:
from __future__ import absolute_import, division, print_function, unicode_literals


import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import io
from PIL import Image
from IPython.display import Image as IPyImage
import imageio

from tensorflow.keras.callbacks import TensorBoard, EarlyStopping, LearningRateScheduler, ModelCheckpoint, CSVLogger, ReduceLROnPlateau
#%load_ext tensorboard

import os
import matplotlib.pylab as plt
import numpy as np
import math
import datetime
import pandas as pd

# Report the TensorFlow version this notebook is running against
print("Version: ", tf.__version__)
# Set the TensorFlow logger to the INFO level
tf.get_logger().setLevel('INFO')

In [None]:
# Download and prepare the horses or humans dataset from tensorflow datasets

# Three splits are requested: the first 80% of 'train', the remaining 20% of 'train', and 'test'
splits, info = tfds.load('horses_or_humans', as_supervised=True, with_info=True, split=['train[:80%]', 'train[80%:]', 'test'])

# Unpack the datasets in the same order the splits were requested above
(train_examples, validation_examples, test_examples) = splits

# Metadata: example count of the full 'train' split (not just the 80% slice) and number of label classes
num_examples = info.splits['train'].num_examples
num_classes = info.features['label'].num_classes

## Define Some Variables

In [None]:
BATCH_SIZE = 32  # batch size used for the training and validation pipelines
IMAGE_SIZE = (150, 150)  # target size passed to tf.image.resize in format_image

In [None]:
# Image format function
def format_image(image, label):
    """Resize an image to IMAGE_SIZE and divide its values by 255.

    The label is returned untouched, so the function can be mapped over
    ``(image, label)`` pairs.
    """
    resized = tf.image.resize(image, IMAGE_SIZE)
    scaled = resized / 255.0
    return scaled, label

In [None]:
# Let's define train,validation,test batches
# Training pipeline: shuffle (buffer = a quarter of num_examples), format, batch, prefetch one batch
train_batches = train_examples.shuffle(num_examples // 4).map(format_image).batch(BATCH_SIZE).prefetch(1)
# Validation pipeline: same formatting and batching, but no shuffling
validation_batches = validation_examples.map(format_image).batch(BATCH_SIZE).prefetch(1)
# Test pipeline: formatted and served one example per batch
test_batches = test_examples.map(format_image).batch(1)

In [None]:
# Let's see the shape of one batch
# The loop body is intentionally empty: iterating take(1) only serves to bind
# image_batch / label_batch to a single batch so they can be inspected below
for image_batch, label_batch in train_batches.take(1):
  pass
image_batch.shape

In [None]:
# Let's create a function to create the simple models
def build_model(dense_units, input_shape=IMAGE_SIZE + (3,), num_classes=2):
    """Build a small, uncompiled convolutional classifier.

    Args:
        dense_units: Number of units in the hidden fully connected layer.
        input_shape: Shape of a single input image. Defaults to IMAGE_SIZE
            with 3 channels appended.
        num_classes: Number of units in the softmax output layer. Defaults
            to 2, the value that used to be hard-coded.

    Returns:
        A ``tf.keras`` Sequential model; the caller is responsible for
        compiling it.
    """
    model = tf.keras.models.Sequential([
        # Three Conv2D + MaxPooling2D stages with 16, 32 and 64 filters
        tf.keras.layers.Conv2D(16, (3, 3), activation='relu', input_shape=input_shape),
        tf.keras.layers.MaxPooling2D(2, 2),
        tf.keras.layers.Conv2D(32, (3, 3), activation='relu'),
        tf.keras.layers.MaxPooling2D(2, 2),
        tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
        tf.keras.layers.MaxPooling2D(2, 2),
        # Classification head
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(dense_units, activation='relu'),
        tf.keras.layers.Dense(num_classes, activation='softmax'),
    ])
    return model

## [Model Checkpoint](https://keras.io/api/callbacks/model_checkpoint/)

- Model details can be saved epoch by epoch for later inspection, and the saved checkpoints let us monitor progress as training runs.

In [None]:
# tf.keras.callbacks.ModelCheckpoint(
            # filepath,
            # monitor='val_loss', 
            # verbose=0, 
            # save_best_only=False,
            # save_weights_only=False,
            # mode='auto', 
            # save_freq='epoch',     
            # options=None, **kwargs)


In [None]:
# Let's save a checkpoint at the end of every epoch
# NOTE(review): save_weights_only is not passed here, so ModelCheckpoint uses its
# default (False in the signature shown in the previous cell); despite the
# 'weights.' file prefix this presumably saves the full model - confirm intent.
model_check_w = build_model(dense_units=256)

model_check_w.compile(optimizer='sgd',
                   loss='sparse_categorical_crossentropy',
                   metrics=['accuracy'])

model_check_w.fit(train_batches,
                 epochs=5,
                 validation_data=validation_batches,
                 verbose=2,
                 # The file name template embeds the epoch number and the validation loss
                 callbacks=[ModelCheckpoint('weights.{epoch:02d}-{val_loss:.2f}.h5', 
                                             verbose=1,
                                             ),])

In [None]:
# List the current working directory, which is where the relative checkpoint
# paths above are written. On Kaggle this is /kaggle/working; using '.' keeps
# the cell working in other environments too.
os.listdir('.')

In [None]:
# Save the models
# A fresh model is built so this run does not continue from the previous cell's training
model_check_m = build_model(dense_units=256)
model_check_m.compile(
    optimizer='sgd',
    loss='sparse_categorical_crossentropy', 
    metrics=['accuracy'])
  
model_check_m.fit(train_batches, 
          epochs=3, 
          validation_data=validation_batches, 
          verbose=2,
          # A single fixed file name is used, so only one checkpoint file is kept
          callbacks=[ModelCheckpoint('model.h5', 
                                      verbose=1,
                                      save_best_only=True) # just save the best model
          ])

In [None]:
# List the current working directory, which is where 'model.h5' and the other
# relative output paths are written. On Kaggle this is /kaggle/working; using
# '.' keeps the cell working in other environments too.
os.listdir('.')

## [Early stopping](https://keras.io/api/callbacks/early_stopping/)
- It can stop training when a monitored metric has stopped improving
- Put the other way: if there is not enough improvement, it ends training early.

In [None]:
# tf.keras.callbacks.EarlyStopping(
                             #     monitor='val_loss', 
                             #     min_delta=0, 
                             #     patience=0, 
                             #     verbose=0,
                             #     mode='auto', 
                             #     baseline=None, 
                             #     restore_best_weights=False)

In [None]:
# Train for up to 100 epochs, letting EarlyStopping end the run sooner
model_early = build_model(dense_units=256)

model_early.compile(
    optimizer='sgd',
    loss='sparse_categorical_crossentropy', 
    metrics=['accuracy'])

model_early.fit(train_batches,
                epochs=100,
                validation_data= validation_batches,
                verbose=2,
                callbacks=[EarlyStopping(patience=3, # keep training for up to 3 more epochs after the best score
                                         mode ='min',
                                         monitor='val_loss',
                                         # when training stops, the model is rolled back to the weights of the best epoch
                                         restore_best_weights=True,
                                         verbose=1)])

# NOTE(review): the original note said the restored weights came from epoch 24.
# That figure is from one particular run; presumably the best epoch differs between runs.

In [None]:
# List the current working directory, which is where this notebook's relative
# output paths are written. On Kaggle this is /kaggle/working; using '.' keeps
# the cell working in other environments too.
os.listdir('.')

## [CSV Logger](https://keras.io/api/callbacks/csv_logger/)

- It streams epoch results to a CSV File

In [None]:
# Fresh model for the CSVLogger demonstration
model = build_model(dense_units=256)

model.compile(
    optimizer='sgd',
    loss='sparse_categorical_crossentropy', 
    metrics=['accuracy'])
  
# Path of the CSV file the callback writes to; read back in the next cell
csv_file = 'training.csv'

model.fit(train_batches, 
          epochs=3, 
          validation_data=validation_batches, 
          callbacks=[CSVLogger(csv_file)
          ])

In [None]:
# Load the log written by CSVLogger and show its first rows
pd.read_csv(csv_file).head()

## [Learning Rate Scheduler](https://keras.io/api/callbacks/learning_rate_scheduler/)
- Updates the learning rate during training.
- When the learning rate is too large, gradient descent can inadvertently increase rather than decrease the training error.
- When the learning rate is too small, training is not only slower, but may become permanently stuck with a high training error.

![lr](https://www.jeremyjordan.me/content/images/2018/02/Screen-Shot-2018-02-24-at-11.47.09-AM.png)

In [None]:
# Fresh model for the LearningRateScheduler demonstration
model = build_model(dense_units=256)

model.compile(
    optimizer='sgd',
    loss='sparse_categorical_crossentropy', 
    metrics=['accuracy'])

#example 1
def scheduler(epoch):
    """Step-decay schedule: halve a 0.01 base rate once per elapsed epoch.

    The exponent is ``floor((epoch + 1) / step_size)``, so epoch 0 already
    yields half of the base rate.
    """
    base_rate, decay_factor, step_size = 0.01, 0.5, 1
    completed_steps = math.floor((epoch + 1) / step_size)
    return base_rate * math.pow(decay_factor, completed_steps)

# example 2
def scheduler_1(epoch, lr):
    """Hold the learning rate for the first 10 epochs, then decay it exponentially.

    Args:
        epoch: Index of the epoch about to start.
        lr: Current learning rate.

    Returns:
        ``lr`` unchanged while ``epoch < 10``; afterwards ``lr * exp(-0.1)``
        as a plain Python float.
    """
    if epoch < 10:
        return lr
    # math.exp keeps the result a plain float. tf.math.exp would turn it into a
    # tensor, which not every Keras version's LearningRateScheduler accepts.
    return lr * math.exp(-0.1)


# Train with the step-decay `scheduler` (example 1); `scheduler_1` is defined above but not used here
model.fit(train_batches, 
          epochs=5, 
          validation_data=validation_batches, 
          callbacks=[LearningRateScheduler(scheduler, verbose=1)])

## [ReduceLROnPlateau](https://keras.io/api/callbacks/reduce_lr_on_plateau/)

- This callback monitors a quantity and if no improvement is seen for a 'patience' number of epochs, the learning rate is reduced.

In [None]:
# Fresh model for the ReduceLROnPlateau demonstration
model = build_model(dense_units=256)

model.compile(
    optimizer='sgd',
    loss='sparse_categorical_crossentropy', 
    metrics=['accuracy'])
  
model.fit(train_batches, 
          epochs=50, 
          validation_data=validation_batches, 
          # Watches val_loss with a patience of 3 epochs; the learning rate is
          # multiplied by `factor` on a plateau and is not lowered below `min_lr`
          callbacks=[ReduceLROnPlateau(monitor='val_loss', 
                                       factor=0.2, verbose=1,
                                       patience=3, min_lr=0.001)])

# Some Custom Keras Callbacks

- The custom callbacks can still make use of all of the features of the built-in Keras call-backs

In [None]:
# Define the Keras model to add callbacks to
def get_model():
    """Build and compile a single-unit linear model over 784 input features.

    Returns:
        A compiled ``tf.keras.Sequential`` model trained with RMSprop on
        mean squared error, reporting mean absolute error as a metric.
    """
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Dense(1, activation='linear', input_dim=784))
    # `learning_rate` is the supported argument name; the `lr` alias is
    # deprecated and rejected by newer Keras optimizers.
    model.compile(optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.1),
                  loss='mean_squared_error',
                  metrics=['mae'])
    return model

In [None]:
# load the MNIST data from Keras datasets API:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
# Flatten every image to a 784-long float vector and divide by 255.
# -1 lets NumPy infer the number of samples instead of hard-coding 60000 / 10000.
x_train = x_train.reshape(-1, 784).astype('float32') / 255
x_test = x_test.reshape(-1, 784).astype('float32') / 255

### Callback 1

In [None]:
# define a simple custom callback to track the start and end of every batch of data. 
# During those calls, it prints the index of the current batch
class myCallback(tf.keras.callbacks.Callback):
    """Print a timestamped line when each training batch starts and finishes."""

    def on_train_batch_begin(self, batch, logs=None):
        """Report the batch index together with the current wall-clock time."""
        now = datetime.datetime.now().time()
        print('Training: batch {} begins at {}'.format(batch, now))

    def on_train_batch_end(self, batch, logs=None):
        """Report the batch index together with the current wall-clock time."""
        now = datetime.datetime.now().time()
        print('Training: batch {} ends at {}'.format(batch, now))

In [None]:
# One short epoch of 5 steps, silent (verbose=0), so only the callback's own prints appear
model = get_model()
_ = model.fit(x_train, y_train,
          batch_size=64,
          epochs=1,
          steps_per_epoch=5,
          verbose=0,
          callbacks=[myCallback()])

## Usage of `logs` dict
- The `logs` dict contains the loss value, and all the metrics at the end of a batch or epoch. 
- Example includes the loss and mean absolute error.

### Callback 2

- In this example we measure the ratio between the validation loss and the training loss to detect overfitting.
- When the ratio gets too high, we should stop training.

So here we compute the ratio at the end of every epoch, and if it is higher than our threshold value, we stop training.

In [None]:
class DetectOverfittingCallback(tf.keras.callbacks.Callback):
    """Stop training once validation loss exceeds training loss by a set ratio.

    Args:
        threshold: Largest tolerated ``val_loss / loss`` ratio. Training is
            stopped at the end of the first epoch whose ratio is above it.
    """

    def __init__(self, threshold):
        super(DetectOverfittingCallback, self).__init__()
        self.threshold = threshold

    def on_epoch_end(self, epoch, logs=None):
        """Print the val/train loss ratio and request a stop if it is too high.

        Does nothing when either loss is missing from ``logs`` (for example
        when the model is fitted without validation data).
        """
        logs = logs or {}
        train_loss = logs.get("loss")
        val_loss = logs.get("val_loss")
        if train_loss is None or val_loss is None:
            # Nothing to compare against, so the check is skipped for this epoch.
            return

        # A training loss of exactly zero would divide by zero; treat it as an
        # infinitely large ratio instead of crashing the training run.
        ratio = val_loss / train_loss if train_loss else float("inf")
        print("Epoch: {}, Val/Train loss ratio: {:.2f}".format(epoch, ratio))

        if ratio > self.threshold:
            print("Stopping training...")
            self.model.stop_training = True

# Train with validation data so that 'val_loss' is present in the logs the callback reads
model = get_model()
_ = model.fit(x_train, y_train,
              validation_data=(x_test, y_test),
              batch_size=64,
              epochs=100,
              verbose=1,
              callbacks=[DetectOverfittingCallback(threshold=1.1)])

### Callback 3

In [None]:
# Visualization utilities
# Global matplotlib defaults: font size 20 and a 15 x 3 figure
plt.rc('font', size=20)
plt.rc('figure', figsize=(15, 3))

def display_digits(inputs, outputs, ground_truth, epoch, n=10):
    """Draw ``n`` 28x28 images side by side with their predictions as x tick labels.

    Args:
        inputs: Image data that can be reshaped to ``(n, 28, 28)``.
        outputs: Predicted labels, one per image, shown as tick label text.
        ground_truth: Expected labels, compared against ``outputs`` to pick
            the tick label colour (green when equal, red otherwise).
        epoch: Not used by this function; kept so existing callers keep working.
        n: Number of images to draw.
    """
    plt.clf()
    plt.yticks([])

    # Rearrange the n images into a single 28 x (28 * n) strip
    strip = np.reshape(inputs, [n, 28, 28])
    strip = np.swapaxes(strip, 0, 1)
    strip = np.reshape(strip, [28, 28 * n])
    plt.imshow(strip)

    # One tick in the horizontal centre of every image, labelled with its prediction
    plt.xticks([28 * i + 14 for i in range(n)], outputs)
    tick_labels = plt.gca().xaxis.get_ticklabels()
    for tick, predicted, actual in zip(tick_labels, outputs, ground_truth):
        tick.set_color('green' if predicted == actual else 'red')

    # grid(None) toggles the current grid state; pass False to switch it off explicitly
    plt.grid(False)

In [None]:
GIF_PATH = './animation.gif'  # output file for the animation written by VisCallback.on_train_end

In [None]:
class VisCallback(tf.keras.callbacks.Callback):
    """Render predictions on random samples after each epoch and save them as a GIF.

    Args:
        inputs: Array of input samples to draw from.
        ground_truth: Labels matching ``inputs``.
        display_freq: The figure is shown every ``display_freq`` epochs.
        n_samples: Number of samples drawn and plotted per epoch.
    """

    def __init__(self, inputs, ground_truth, display_freq=10, n_samples=10):
        # Initialise the base Callback state before adding our own attributes
        super(VisCallback, self).__init__()
        self.inputs = inputs
        self.ground_truth = ground_truth
        self.images = []  # one rendered frame per epoch
        self.display_freq = display_freq
        self.n_samples = n_samples

    def on_epoch_end(self, epoch, logs=None):
        """Plot predictions for a random sample and store the frame."""
        # Randomly sample data
        indexes = np.random.choice(len(self.inputs), size=self.n_samples)
        X_test, y_test = self.inputs[indexes], self.ground_truth[indexes]
        predictions = np.argmax(self.model.predict(X_test), axis=1)

        # Plot the digits; the image count must equal the number of samples
        # drawn above (it is unrelated to display_freq)
        display_digits(X_test, predictions, y_test, epoch, n=self.n_samples)

        # Render the figure to an in-memory PNG and keep it as an array;
        # the buffer is released as soon as the frame has been copied out
        with io.BytesIO() as buf:
            plt.savefig(buf, format='png')
            buf.seek(0)
            self.images.append(np.array(Image.open(buf)))

        # Display the digits every 'display_freq' number of epochs
        if epoch % self.display_freq == 0:
            plt.show()

    def on_train_end(self, logs=None):
        """Write the collected frames to GIF_PATH."""
        # Nothing to write if training ended before any epoch completed
        if self.images:
            imageio.mimsave(GIF_PATH, self.images, fps=1)

In [None]:
def get_model():
    """Build and compile a two-layer classifier over 784 input features.

    This redefines the ``get_model`` from the earlier cell: a 32-unit linear
    layer followed by a 10-way softmax.

    Returns:
        A compiled ``tf.keras.Sequential`` model trained with RMSprop on
        sparse categorical cross-entropy, reporting accuracy.
    """
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Dense(32, activation='linear', input_dim=784))
    model.add(tf.keras.layers.Dense(10, activation='softmax'))
    # `learning_rate` is the supported argument name; the `lr` alias is
    # deprecated and rejected by newer Keras optimizers.
    model.compile(optimizer=tf.keras.optimizers.RMSprop(learning_rate=1e-4),
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model

In [None]:
# Train silently (verbose=0); VisCallback samples from the test set to draw its frames
model = get_model()
model.fit(x_train, y_train,
          batch_size=64,
          epochs=20,
          verbose=0,
          callbacks=[VisCallback(x_test, y_test)])

In [None]:
SCALE = 60  # multiplier applied to the 15 x 3 proportions used for width and height below

# FYI, the format is set to PNG here to bypass checks for acceptable embeddings
IPyImage(GIF_PATH, format='png', width=15 * SCALE, height=3 * SCALE) 