# Using Callbacks to Control Training

In this lab, you will use the [Callbacks API](https://keras.io/api/callbacks/) to stop training once a specified metric reaches a threshold you choose. Training then no longer has to run through every remaining epoch after the threshold is met. For example, if you set 1000 epochs and your target accuracy is already reached at epoch 200, training stops automatically at that point. The following sections show how to implement this.

## Load and Normalize the Fashion MNIST dataset

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Flatten, Dense
```

```python
fmnist = tf.keras.datasets.fashion_mnist

(X_train, y_train), (X_test, y_test) = fmnist.load_data()

# Scale pixel values from the 0-255 integer range to floats in [0, 1]
X_train = X_train/255.0
X_test = X_test/255.0
```
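`load_data()` returns the images as `uint8` arrays with pixel values from 0 to 255, so dividing by `255.0` both converts them to floating point and rescales them to the range [0, 1]. The toy array below is a hypothetical three-pixel "image" (not taken from the dataset) that shows the effect on dtype and range:

```python
import numpy as np

# Hypothetical 1x3 "image": darkest, mid-gray and brightest possible uint8 values.
pixels = np.array([[0, 128, 255]], dtype=np.uint8)
scaled = pixels / 255.0  # true division promotes the integers to float64

print(pixels.dtype, '->', scaled.dtype)  # uint8 -> float64
print(scaled.min(), scaled.max())        # 0.0 1.0
```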

## Creating a Callback class

You can create a callback by defining a class that inherits from the [tf.keras.callbacks.Callback](https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/Callback) base class. In that class, you override the methods that correspond to the points in training where the callback should run. Below, for instance, you will use the [on_epoch_end()](https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/Callback#on_epoch_end) method to check the training loss at the end of each epoch.

```python
class myCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        '''
        Halts the training once the training loss falls below 0.4

        Args:
          epoch (integer) - index of epoch (required but unused in the function definition below)
          logs (dict) - metric results from the training epoch
        '''
        # avoid a mutable default argument; Keras may pass None
        logs = logs or {}

        # check the loss (skip the check if it was not logged)
        loss = logs.get('loss')
        if loss is not None and loss < 0.4:

            # stop if threshold is met
            print('\n loss is lower than 0.4. So cancelling training!')
            self.model.stop_training = True

# Instantiate the class
callbacks = myCallback()
```
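The `myCallback` class hardcodes both the metric and the threshold. A minimal sketch of a reusable variant follows. The class name `ThresholdStop` and its `monitor`, `threshold` and `mode` arguments are hypothetical choices for this example, not part of the Keras API. (Keras's built-in `tf.keras.callbacks.EarlyStopping` addresses a related but different problem: it stops when a monitored metric stops improving.)

```python
import tensorflow as tf


class ThresholdStop(tf.keras.callbacks.Callback):
    """Hypothetical helper: stops training once a logged metric crosses a fixed threshold."""

    def __init__(self, monitor='loss', threshold=0.4, mode='below'):
        super().__init__()
        if mode not in ('below', 'above'):
            raise ValueError("mode must be 'below' or 'above'")
        self.monitor = monitor        # key to read from `logs`, e.g. 'loss' or 'accuracy'
        self.threshold = threshold
        self.mode = mode              # 'below' for losses, 'above' for accuracies

    def on_epoch_end(self, epoch, logs=None):
        value = (logs or {}).get(self.monitor)
        if value is None:
            # The metric was not logged (e.g. a typo in `monitor`), so keep training.
            return
        if self.mode == 'below':
            crossed = value < self.threshold
        else:
            crossed = value > self.threshold
        if crossed:
            print(f'\n{self.monitor} = {value:.4f} crossed {self.threshold} '
                  f'at epoch {epoch + 1}, so stopping training.')
            self.model.stop_training = True


# Equivalent to myCallback when used with the model in this lab:
# model.fit(X_train, y_train, epochs=10, callbacks=[ThresholdStop('loss', 0.4, 'below')])
```

Because the metric name is a parameter, the same class can watch any key that appears in `logs`, such as `'accuracy'` when the model is compiled with `metrics=['accuracy']`.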

## Define and compile the model

```python
# define the model
model = Sequential()

model.add(Flatten(input_shape=(28, 28)))
model.add(Dense(units=512, activation='relu'))
model.add(Dense(units=10, activation='softmax'))

# compile the model
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
```
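As a quick sanity check on the architecture, you can count the trainable parameters by hand. `Flatten` has no weights, and each `Dense` layer has one weight per input-unit pair plus one bias per unit. The helper `dense_params` below is hypothetical, written only for this arithmetic; `model.summary()` should report the same totals for the model above.

```python
def dense_params(n_inputs, n_units):
    """Weights (n_inputs * n_units) plus one bias per unit."""
    return n_inputs * n_units + n_units


flattened = 28 * 28                              # Flatten turns each 28x28 image into 784 values
hidden_params = dense_params(flattened, 512)     # 784 * 512 + 512
output_params = dense_params(512, 10)            # 512 * 10 + 10

print(hidden_params, output_params, hidden_params + output_params)  # 401920 5130 407050
```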

## Train the model

Now you are ready to train the model. To attach the callback, pass a list containing the `myCallback` instance you created earlier to the `callbacks` parameter of `model.fit()`. Run the cell below and observe what happens.

```python
# Train the model with a callback
model.fit(X_train, y_train, epochs=10, callbacks=[callbacks])
```

```text
Epoch 1/10
Epoch 2/10
 loss is lower than 0.4. So cancelling training!

<keras.callbacks.History at 0xedb7c402c8>
```

You will notice that training stops before completing all 10 epochs. Because the callback runs at the end of every epoch, it can read the metrics logged for that epoch and check whether they meet the threshold set in the function definition. In this case, training stops as soon as the loss at the end of an epoch falls below `0.4`.
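To see why setting `self.model.stop_training = True` ends training, it can help to look at the control flow in isolation. The sketch below is a deliberately simplified, pure-Python stand-in, not Keras's actual `fit()` implementation: `ToyModel` and `StopBelow` are hypothetical classes, and the loss values are made up to mimic a decreasing training loss. The loop calls each callback's `on_epoch_end()` and then checks the flag before starting the next epoch.

```python
class StopBelow:
    """Hypothetical stand-in for myCallback: flags the model once loss < threshold."""

    def __init__(self, threshold):
        self.threshold = threshold
        self.model = None  # attached by ToyModel.fit, much as Keras does via set_model()

    def on_epoch_end(self, epoch, logs=None):
        loss = (logs or {}).get('loss')
        if loss is not None and loss < self.threshold:
            self.model.stop_training = True


class ToyModel:
    """Hypothetical model whose 'training' replays a fixed, made-up loss curve."""

    def __init__(self, losses):
        self.losses = losses
        self.stop_training = False

    def fit(self, epochs, callbacks=()):
        self.stop_training = False
        for cb in callbacks:
            cb.model = self
        history = []
        for epoch in range(epochs):
            # Pretend one epoch of training happened and produced this loss.
            logs = {'loss': self.losses[epoch]}
            history.append(logs['loss'])
            for cb in callbacks:
                cb.on_epoch_end(epoch, logs)
            # Checked between epochs, so the epoch that crossed the threshold has already finished.
            if self.stop_training:
                break
        return history


toy = ToyModel([0.50, 0.38, 0.33, 0.30])
history = toy.fit(epochs=4, callbacks=[StopBelow(0.4)])
print(history)  # [0.5, 0.38]
```

Since the flag is only consulted after the epoch's callbacks have run, the epoch in which the threshold is crossed always completes, which is consistent with the run above ending right after epoch 2.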

*Optional Challenge: Modify the code to make the training stop when the accuracy metric exceeds 60%.*

That concludes this simple exercise on callbacks!