# Keras Callbacks: Extending Their Scope And Usage

As a transitional piece between our two groups of blogs, this blog is still a work in progress.

It concerns the various `callbacks` that we can register during training sessions.

Automating our sessions is an important objective of ours, specifically so that we can fine-tune our training `hyper-parameters`. We will therefore keep adding to this blog as the rest of the next group's blogs materialize.

Just as before, we need to prep our environment to run any meaningful code:

In [1]:
from datetime import datetime
import pathlib as pth
import tensorflow as tf
import dataset as qd
import custom as qc
ks = tf.keras
kl = ks.layers

And now we are ready to define our model.

As this blog focuses on the actual training process, our model can be reused directly from a previous blog:

In [2]:
def model_for(ps):
    x = [ks.Input(shape=(), dtype='int32'), ks.Input(shape=(), dtype='int64')]
    x += [ks.Input(shape=(), dtype='int32'), ks.Input(shape=(), dtype='int64')]
    x += [ks.Input(shape=(), dtype='int32'), ks.Input(shape=(), dtype='int64')]
    y = qc.ToRagged()(x)
    y = qc.Frames(ps)(y)
    embed = qc.Embed(ps)
    ye = qc.Encode(ps)(embed(y[:2]))
    yd = qc.Decode(ps)(embed(y[2:]) + [ye[0]])
    y = qc.Debed(ps)(yd)
    m = ks.Model(name='callbacks', inputs=x, outputs=y)
    m.compile(optimizer=ps.optimizer, loss=ps.loss, metrics=[ps.metric])
    m.summary()  # summary() prints the table itself and returns None
    return m

Once the model is defined, we adjust our main calling function.

At this point we define the `callbacks` that should be kept "in the loop" during our training session.

Initially we still want to include the standard Keras `TensorBoard` callback.

Additionally, we want to roll our own checkpointing. We choose to use the newer `tf.train.Checkpoint` and `tf.train.CheckpointManager` classes (see our [blog](./trackable.html) regarding this topic).

For this we define a custom Keras `Callback` class called `CheckpointCB`. As this callback is only used to save or update our current checkpoint files, it only needs to override the `on_epoch_end` method.

In the override it simply calls the manager's `save` method, so one checkpoint is written at the end of every epoch.
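To make the mechanics concrete, here is a minimal, framework-free sketch of the idea: a training loop calls `on_epoch_end` on each registered callback, and the callback delegates to a manager that keeps only the newest few saves. Everything in it (`ToyManager`, `ToyCallback`, `ToyCheckpointCB`, `toy_fit`) is a hypothetical stand-in we made up for illustration; it is not how Keras or `CheckpointManager` are actually implemented.

```python
# Toy stand-ins for illustration only; NOT the Keras or TensorFlow classes.


class ToyManager:
    """Hypothetical manager: numbers each save and keeps only the newest few."""

    def __init__(self, max_to_keep=3):
        self.max_to_keep = max_to_keep
        self.counter = 0
        self.kept = []

    def save(self):
        self.counter += 1
        self.kept.append(f'ckpt-{self.counter}')
        # Drop the oldest entries once we exceed the retention limit.
        while len(self.kept) > self.max_to_keep:
            self.kept.pop(0)
        return self.kept[-1]


class ToyCallback:
    """Base class with a no-op hook, so subclasses override only what they need."""

    def on_epoch_end(self, epoch, logs=None):
        pass


class ToyCheckpointCB(ToyCallback):
    """Saves through the manager every time an epoch finishes."""

    def __init__(self, mgr):
        self.mgr = mgr

    def on_epoch_end(self, epoch, logs=None):
        self.mgr.save()


def toy_fit(callbacks, epochs):
    """Hypothetical training loop: one hook call per callback per epoch."""
    for epoch in range(epochs):
        logs = {'loss': 1.0 / (epoch + 1)}  # made-up placeholder value
        for cb in callbacks:
            cb.on_epoch_end(epoch, logs)


mgr = ToyManager(max_to_keep=3)
toy_fit([ToyCheckpointCB(mgr)], epochs=5)
print(mgr.kept)  # ['ckpt-3', 'ckpt-4', 'ckpt-5']
```

After five toy epochs only the last three saves remain, which is the behavior we expect from our real manager with `max_to_keep=3`.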

To be expanded...

In [3]:
def main_graph(ps, ds, m):
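    # Base directory for this blog's logs and checkpoints.
    b = pth.Path('/tmp/q')
    b.mkdir(parents=True, exist_ok=True)
    # Timestamped TensorBoard log directory, one per run.
    lp = datetime.now().strftime('%Y%m%d-%H%M%S')
    lp = b / f'logs/{lp}'
    # Track the model and manage its checkpoints under model/<model name>.
    c = tf.train.Checkpoint(model=m)
    mp = b / 'model' / f'{m.name}'
    mgr = tf.train.CheckpointManager(c, str(mp), max_to_keep=3)
    # if mgr.latest_checkpoint:
    #     vs = tf.train.list_variables(mgr.latest_checkpoint)
    #     print(f'\n*** checkpoint vars: {vs}')
    # Resume from the latest checkpoint, if one exists.
    c.restore(mgr.latest_checkpoint).expect_partial()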

    class CheckpointCB(ks.callbacks.Callback):
        def on_epoch_end(self, epoch, logs=None):
            mgr.save()

    cbs = [
        CheckpointCB(),
        ks.callbacks.TensorBoard(
            log_dir=str(lp),
            histogram_freq=1,
        ),
    ]
    m.fit(ds, callbacks=cbs, epochs=ps.num_epochs)
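The timestamped log path in `main_graph` gives every run its own TensorBoard directory, so earlier runs are not overwritten and can be compared side by side. A small sketch of just that naming step follows; the helper name `run_log_dir` and its `now` parameter are our own hypothetical additions, introduced only to make the behavior easy to check.

```python
from datetime import datetime
import pathlib as pth


def run_log_dir(base, now=None):
    """Return <base>/logs/<YYYYmmdd-HHMMSS> for the given (or current) time."""
    now = now or datetime.now()
    return pth.Path(base) / 'logs' / now.strftime('%Y%m%d-%H%M%S')


# Two runs started one second apart land in sibling directories.
first = run_log_dir('/tmp/q', datetime(2020, 1, 2, 3, 4, 5))
second = run_log_dir('/tmp/q', datetime(2020, 1, 2, 3, 4, 6))
print(first.name, second.name)  # 20200102-030405 20200102-030406
```

Pointing TensorBoard at the shared `logs` parent then lists each run separately.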

We may also need to update our parameters as they relate to our "callback objectives".

To be expanded...

In [4]:
params = dict(
    dim_batch=5,
    dim_dense=150,
    dim_hidden=6,
    dim_stacks=2,
    dim_vocab=len(qd.vocab),
    loss=ks.losses.SparseCategoricalCrossentropy(from_logits=True),
    metric=ks.metrics.SparseCategoricalCrossentropy(from_logits=True),
    num_epochs=5,
    num_shards=2,
    optimizer=ks.optimizers.Adam(),
    width_dec=15,
    width_enc=25,
)

And now we are ready to start our training session.

The printed summary lets us confirm the model's layers and connections. We can also easily adjust the parameters, `num_epochs` in particular, to tailor the length of the sessions to our objectives.
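One simple way to shorten or lengthen a session is to derive a variant of the parameter dictionary instead of editing it in place. The sketch below assumes a cut-down `base` dictionary with plain values only, and the helper `with_overrides` is a hypothetical name of ours; it is not part of our `dataset` or `custom` modules.

```python
# Cut-down stand-in for the `params` dictionary (plain values only).
base = dict(dim_batch=5, num_epochs=5, num_shards=2)


def with_overrides(base, **overrides):
    """Return a copy of `base` with some values replaced; reject unknown keys."""
    unknown = set(overrides) - set(base)
    if unknown:
        raise KeyError(f'unknown parameter(s): {sorted(unknown)}')
    return {**base, **overrides}


# A quick smoke-test session: a single epoch, everything else unchanged.
quick = with_overrides(base, num_epochs=1)
print(quick['num_epochs'], base['num_epochs'])  # 1 5
```

Rejecting unknown keys catches typos such as `epochs=1`, which would otherwise be silently ignored.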

In [5]:
ps = qd.Params(**params)
main_graph(ps, qc.dset_for(ps), model_for(ps))

Model: "callbacks"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None,)]            0                                            
__________________________________________________________________________________________________
input_2 (InputLayer)            [(None,)]            0                                            
__________________________________________________________________________________________________
input_3 (InputLayer)            [(None,)]            0                                            
__________________________________________________________________________________________________
input_4 (InputLayer)            [(None,)]            0                                            
...

A quick `ls` of our `/tmp/q/model/callbacks` checkpoint directory shows that our manager is in fact updating the checkpoint files and keeping only the last three, just as we expect.
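The same check can be scripted rather than eyeballed. The sketch below assumes the manager's default file naming, where each save produces a `ckpt-N.index` file alongside its data shards; the helper `checkpoint_steps` is a hypothetical name of ours, and the demo runs against a throwaway directory filled with empty, made-up files instead of real checkpoints.

```python
import pathlib as pth
import tempfile


def checkpoint_steps(directory):
    """Return the sorted save numbers N of the `ckpt-N.index` files found."""
    steps = []
    for p in pth.Path(directory).glob('ckpt-*.index'):
        # 'ckpt-4.index' has the stem 'ckpt-4'; keep the part after the dash.
        suffix = p.stem.split('-', 1)[1]
        if suffix.isdigit():
            steps.append(int(suffix))
    return sorted(steps)


# Demo on a temporary directory that imitates the expected layout.
with tempfile.TemporaryDirectory() as d:
    for n in (3, 4, 5):
        (pth.Path(d) / f'ckpt-{n}.index').touch()
        (pth.Path(d) / f'ckpt-{n}.data-00000-of-00001').touch()
    (pth.Path(d) / 'checkpoint').touch()
    found = checkpoint_steps(d)
print(found)  # [3, 4, 5]
```

Against our real directory the call would be `checkpoint_steps('/tmp/q/model/callbacks')`, and with `max_to_keep=3` we would expect at most three numbers back.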

To be expanded...