# Custom Callbacks


## Build Simple Model

In [7]:
import tensorflow as tf

x = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
y = tf.constant([[12.0], [20.0]])

x_val = tf.constant([[5,6,7], [8,9,10]])
y_val = tf.constant([[24.0], [27.0]])

model = tf.keras.Sequential()

model.add(tf.keras.layers.Dense(units=10, input_shape=[3]))

model.compile(optimizer='sgd', loss='mean_squared_error')

## Custom Callbacks 1

In [8]:
import datetime

class TimerCallback(tf.keras.callbacks.Callback):
    def on_epoch_begin(self, batch, logs=None):
        print(f'Training: batch{batch}, begin at{datetime.datetime.now()}')
        
    def on_epoch_end(self, batch, logs=None):
        print(f'Training: batch{batch}, end at{datetime.datetime.now()}')
        
callback = TimerCallback()
model.fit(x, y, epochs=3, callbacks=[callback])

Training: batch0, begin at2024-05-19 15:01:54.902834
Epoch 1/3
Training: batch1, begin at2024-05-19 15:01:55.566717
Epoch 2/3
Training: batch2, begin at2024-05-19 15:01:55.575720
Epoch 3/3


<keras.src.callbacks.History at 0x2cffe282c40>

## Custom Callback 2

In [11]:
class DetectOverfittingCallback(tf.keras.callbacks.Callback):
    def __init__(self, treshold):
        super(DetectOverfittingCallback, self).__init__()
        self.treshold = treshold
        
    def on_epoch_end(self, epoch, logs=None):
        ratio = logs['val_loss']/logs['loss']
        print(f'Epoch: {epoch}, val/train loss ratio: {ratio}')
        
        if ratio>self.treshold:
            print('Epoch: {epoch}, Stop training')
            self.model.stop_training = True

model.fit(x, y, epochs=100, validation_data=(x_val, y_val), callbacks=[DetectOverfittingCallback(treshold=0.5)])

Epoch 1/100
Epoch: {epoch}, Stop training


<keras.src.callbacks.History at 0x2cffe3dcbb0>

## Custom Callback 3

In [None]:
import numpy as np

class VisCallback(tf.keras.callbacks.Callback):
    def __init__(self, inputs, ground_truth, display_freq=10, n_samples=10):
        super(VisCallback, self).__init__()
        self.inputs = inputs
        self.ground_truth = ground_truth
        self.images = []
        self.display_freq = display_freq
        self.n_samples = n_samples
        
    def on_epoch_end(self, epoch, logs=None):
        indexes = np.random.choice(len(self.inputs), self.n_samples)
        x_test, y_test = self.inputs[indexes], self.ground_truth[indexes]
        prediction = np.argmax(self.model.predict(x_test), axis =1)
    
        for i in range(self.n_samples):
            self.images.append((x_test[i], y_test[i], prediction[i]))
        
        if epoch % self.display_freq == 0:
            self.display_images()