# tf.keras.callbacks

tf.keras.callbacks provides hooks into the training loop. A callback runs extra operations, such as saving the model or changing the learning rate, at defined points during training, for example at the start or end of each epoch.

Several built-in callbacks are already available. If a custom callback is needed, define a new class that inherits from tf.keras.callbacks.Callback.

The Callback class has two attributes: params and model. "params" is a dict of the training parameters, such as the number of epochs. "model" is a reference to the model being trained.
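
As a minimal sketch of how these two attributes can be read, the hypothetical `InspectCallback` below copies `params` and `model` when training starts. The class name, the toy model, and the random data are assumptions made for illustration. The exact keys in `params` may differ between Keras versions; `epochs` is one that is commonly present.

```python
import numpy as np
import tensorflow as tf


class InspectCallback(tf.keras.callbacks.Callback):
    """Hypothetical callback that records what fit() hands to it."""

    def on_train_begin(self, logs=None):
        # params is set by fit() before training starts.
        self.seen_params = dict(self.params)
        # model is the model that fit() was called on.
        self.seen_model = self.model


# Toy model and random data, used only to trigger the callback.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer="sgd", loss="mse")

x = np.random.rand(16, 4).astype("float32")
y = np.random.rand(16, 1).astype("float32")

inspector = InspectCallback()
model.fit(x, y, epochs=2, batch_size=8, verbose=0, callbacks=[inspector])

print(sorted(inspector.seen_params))   # keys of the params dict
print(inspector.seen_model is model)   # whether the callback saw the same model object
```

Copying the values in `on_train_begin` keeps the sketch independent of anything that happens later in training.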

## 1. Built-in callback functions

    BaseLogger: accumulates metrics over the batches of each epoch; added by default
    History: records the metrics of every epoch; added by default
    EarlyStopping: stops training when a monitored metric stops improving
    TensorBoard: writes logs for visualization in TensorBoard
    ModelCheckpoint: saves the model after every epoch
    ReduceLROnPlateau: reduces the learning rate when a monitored metric stops improving
    TerminateOnNaN: terminates training if the loss becomes NaN
    LearningRateScheduler: sets the learning rate at the start of each epoch from a schedule function
    CSVLogger: writes the logs of each epoch to a CSV file
    ProgbarLogger: prints logs to stdout after each epoch
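
Built-in callbacks are passed to `fit` through its `callbacks` argument. The sketch below combines three of the callbacks listed above on a toy regression problem. The model, the random data, the file name `training_log.csv`, and the chosen `patience` and `factor` values are all assumptions made for illustration, not recommended settings.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Toy model and random data, used only to exercise the callbacks.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer="sgd", loss="mse")

x = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")

# Write the CSV log into a temporary directory.
log_path = os.path.join(tempfile.mkdtemp(), "training_log.csv")

callback_list = [
    # Stop if the training loss has not improved for 2 epochs.
    tf.keras.callbacks.EarlyStopping(monitor="loss", patience=2),
    # Halve the learning rate if the loss has not improved for 1 epoch.
    tf.keras.callbacks.ReduceLROnPlateau(monitor="loss", factor=0.5, patience=1),
    # Append one row of logs per epoch to the CSV file.
    tf.keras.callbacks.CSVLogger(log_path),
]

# fit() returns the History callback, which is added without being listed.
history = model.fit(x, y, epochs=5, batch_size=8, verbose=0,
                    callbacks=callback_list)
print(history.history["loss"])
```

Monitoring `loss` rather than `val_loss` keeps the sketch free of a validation split; with real data, a validation metric is usually the more meaningful quantity to monitor.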

## 2. Custom callback function

The following example shows how LearningRateScheduler can be implemented as a custom callback.


import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models, losses, metrics, callbacks
import tensorflow.keras.backend as K

class LearningRateScheduler(callbacks.Callback):

    def __init__(self, schedule, verbose=0):
        super(LearningRateScheduler, self).__init__()
        self.schedule = schedule
        self.verbose = verbose

    def on_epoch_begin(self, epoch, logs=None):
        if not hasattr(self.model.optimizer, 'lr'):
            raise ValueError('Optimizer must have a "lr" attribute.')
        try:  
            lr = float(K.get_value(self.model.optimizer.lr))
            lr = self.schedule(epoch, lr)
        except TypeError:  # Support for old API for backward compatibility
            lr = self.schedule(epoch)
        if not isinstance(lr, (tf.Tensor, float, np.float32, np.float64)):
            raise ValueError('The output of the "schedule" function '
                             'should be float.')
        if isinstance(lr, tf.Tensor) and not lr.dtype.is_floating:
            raise ValueError('The dtype of Tensor should be float')
        K.set_value(self.model.optimizer.lr, K.get_value(lr))
        if self.verbose > 0:
            print('\nEpoch %05d: LearningRateScheduler setting learning '
                  'rate to %s.' % (epoch + 1, lr))

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        logs['lr'] = K.get_value(self.model.optimizer.lr)
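
The callback above calls `schedule(epoch, lr)` with the zero-based epoch index and the current learning rate, and expects a float back. As a minimal sketch of such a function, the hypothetical `step_decay` below halves the current learning rate every 10 epochs; the function name, the interval, and the factor are assumptions chosen for illustration.

```python
def step_decay(epoch, lr):
    """Hypothetical schedule: halve the current learning rate every 10 epochs."""
    if epoch > 0 and epoch % 10 == 0:
        return lr * 0.5
    return lr


# Hypothetical usage (assumes a compiled Keras `model` and data `x`, `y`):
# model.fit(x, y, epochs=30,
#           callbacks=[LearningRateScheduler(step_decay, verbose=1)])

# Simulate what the callback does over 21 epochs, starting from lr = 0.1.
lr = 0.1
for epoch in range(21):
    lr = step_decay(epoch, lr)
print(lr)
```

Because the function receives the current learning rate rather than the initial one, each reduction compounds on the previous one.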