
# Control Training using Callbacks

In [2]:
# import libraries
import tensorflow as tf

Load the Fashion-MNIST dataset and normalize its pixel values from [0, 255] to [0, 1]

In [3]:
# reference the Fashion-MNIST dataset module
fmnist = tf.keras.datasets.fashion_mnist

# load dataset
(x_train, y_train), (x_test, y_test) = fmnist.load_data()

# normalize pixel values
x_train, x_test = (x_train / 255.0), (x_test / 255.0)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz


Create a custom callback by defining a class that inherits from the [tf.keras.callbacks.Callback](https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/Callback) base class.

From there, you can override the callback's hook methods to choose when it runs during training. In this example, we override [on_epoch_end()](https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/Callback#on_epoch_end), which Keras calls at the end of every epoch.

Args:
*   epoch (integer) - index of the epoch (required by the method signature but unused in the function below)
*   logs (dict) - metric results for the epoch that just finished, e.g. `loss` and `accuracy` for the model compiled below
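To make the roles of `logs` and `self.model.stop_training` concrete, here is a simplified, framework-free sketch of how a training loop might dispatch `on_epoch_end` and then check the stop flag. The names `FakeModel`, `Callback`, `run_fit`, and `LossThreshold` are hypothetical, and the per-epoch losses are made-up numbers rather than training results. It mirrors the general shape of the Keras loop (callbacks run after each epoch, then the loop checks `stop_training`) without reproducing the actual Keras implementation.

```python
# Hypothetical, simplified stand-in for a Keras-style training loop.
# It illustrates the callback protocol only; it is not Keras's actual code.


class FakeModel:
    """Hypothetical model that only carries the stop_training flag."""

    def __init__(self):
        self.stop_training = False


class Callback:
    """Minimal base class shaped like tf.keras.callbacks.Callback."""

    def set_model(self, model):
        self.model = model

    def on_epoch_end(self, epoch, logs=None):
        pass


def run_fit(model, losses, callbacks):
    """Run one 'epoch' per precomputed loss value; return the epochs that ran."""
    for cb in callbacks:
        cb.set_model(model)
    epochs_run = []
    for epoch, loss in enumerate(losses):
        epochs_run.append(epoch)
        logs = {"loss": loss}  # Keras fills this dict with the epoch's metrics
        for cb in callbacks:
            cb.on_epoch_end(epoch, logs)
        if model.stop_training:  # checked once every callback has run
            break
    return epochs_run


class LossThreshold(Callback):
    """Same check as myCallback: stop once the loss drops below 0.4."""

    def on_epoch_end(self, epoch, logs=None):
        loss = (logs or {}).get("loss")
        if loss is not None and loss < 0.4:
            self.model.stop_training = True


model = FakeModel()
ran = run_fit(model, losses=[0.55, 0.38, 0.30, 0.25], callbacks=[LossThreshold()])
print(ran)  # [0, 1]: the loop stops right after the epoch where loss < 0.4
```

Because the flag is only inspected after each epoch, training always finishes the epoch in which the threshold is crossed.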


In [13]:
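class myCallback(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    # Keras passes the epoch's metrics in `logs`; avoid a mutable default argument
    logs = logs or {}

    # Check the training loss reported for this epoch
    loss = logs.get('loss')
    if loss is not None and loss < 0.4:

      # Stop training once the threshold is met
      print("\nLoss is lower than 0.4 so cancelling training!")
      self.model.stop_training = True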
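The threshold and metric above are hard-coded. A possible generalization, sketched below, takes the metric name, threshold, and direction as constructor arguments. `MetricThreshold` and its `monitor`, `threshold`, and `mode` parameters are hypothetical names chosen for illustration; only `tf.keras.callbacks.Callback`, `on_epoch_end`, and `self.model.stop_training` come from Keras.

```python
import tensorflow as tf


class MetricThreshold(tf.keras.callbacks.Callback):
    """Hypothetical callback: stop once `monitor` crosses `threshold`.

    mode="min" stops when the value drops below the threshold (e.g. loss);
    mode="max" stops when it rises above it (e.g. accuracy).
    """

    def __init__(self, monitor="loss", threshold=0.4, mode="min"):
        super().__init__()
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.monitor = monitor
        self.threshold = threshold
        self.mode = mode

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        value = logs.get(self.monitor)
        if value is None:
            return  # metric not reported this epoch; keep training
        if self.mode == "min":
            reached = value < self.threshold
        else:
            reached = value > self.threshold
        if reached:
            print(f"\n{self.monitor} reached {value:.4f}, stopping training.")
            self.model.stop_training = True
```

For example, `MetricThreshold(monitor="accuracy", threshold=0.9, mode="max")` would stop once training accuracy exceeds 0.9, assuming the model is compiled with `metrics=['accuracy']` so that `logs` contains an `accuracy` key. The built-in `tf.keras.callbacks.EarlyStopping` targets a different criterion: it stops when the monitored metric stops improving, rather than when it crosses a fixed value.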

In [14]:
# define the model
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(512, activation=tf.nn.relu),
    tf.keras.layers.Dense(10, activation=tf.nn.softmax)
])

# compile the model
model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(),
              metrics=['accuracy'])

In [15]:
# instantiate the callback class
callback = myCallback()

# train the model with the callback
model.fit(x_train, y_train, epochs=10, callbacks=[callback])

Epoch 1/10
Epoch 2/10
Loss is lower than 0.4 so cancelling training!


<keras.callbacks.History at 0x79819008f9a0>
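The `History` object printed above is what `model.fit()` returns. Its `epoch` list and its `history` dict (one list per metric) only cover the epochs that actually ran, so they show where a callback cut training short. The self-contained sketch below uses random arrays as a stand-in for Fashion-MNIST (made-up data, so the loss values mean nothing) and a hypothetical `StopBelowLoss` callback with a deliberately loose threshold so that it fires after the first epoch. In the notebook itself, you would instead capture the result with something like `history = model.fit(x_train, y_train, epochs=10, callbacks=[...])`.

```python
import numpy as np
import tensorflow as tf


class StopBelowLoss(tf.keras.callbacks.Callback):
    """Hypothetical variant of myCallback with a configurable loss threshold."""

    def __init__(self, threshold):
        super().__init__()
        self.threshold = threshold

    def on_epoch_end(self, epoch, logs=None):
        loss = (logs or {}).get("loss")
        if loss is not None and loss < self.threshold:
            self.model.stop_training = True


# Random stand-in data shaped like Fashion-MNIST (28x28 inputs, 10 classes)
rng = np.random.default_rng(0)
x = rng.random((256, 28, 28)).astype("float32")
y = rng.integers(0, 10, size=(256,))

model = tf.keras.models.Sequential([
    tf.keras.Input(shape=(28, 28)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10, activation="softmax"),
])
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

# A threshold far above any plausible cross-entropy value, so training stops after epoch 0
history = model.fit(x, y, epochs=10,
                    callbacks=[StopBelowLoss(threshold=100.0)], verbose=0)

print(history.epoch)                 # indices of the epochs that actually ran
print(len(history.history["loss"]))  # one loss value per completed epoch
```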