<a href="https://colab.research.google.com/github/s1scottd/ColabStorage/blob/main/TensorFlow%20Examples/Callbacks.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np

from tensorflow import keras
from keras.datasets import mnist
from keras import layers

**Load the MNIST dataset.**

In [None]:
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Report the shape of each split; the prefixes are padded so the shapes line up.
for prefix, array in (
    ("x_train shape: ", x_train),
    ("y_train shape: ", y_train),
    ("x_test shape:  ", x_test),
    ("y_test shape:  ", y_test),
):
  print(f"{prefix}{array.shape}")

**Reshape the images to (number of samples, 28, 28, 1), adding a single channel axis, and divide the pixel values by 255 so they become floating-point numbers in [0, 1].**

In [None]:
def reshape_and_nornmalize(data, height=28, width=28):
  """Add a channel axis to a batch of images and scale the pixels by 1/255.

  Args:
    data: array whose first axis indexes samples and whose remaining elements
      can be laid out as (height, width) per sample, e.g. a stack of MNIST
      images.
    height: image height in pixels (default 28, the MNIST image size).
    width: image width in pixels (default 28, the MNIST image size).

  Returns:
    A floating-point array of shape (number of samples, height, width, 1)
    equal to the input values divided by 255.

  Raises:
    ValueError: if `data` cannot be reshaped to that shape.
  """
  # The trailing axis of size 1 is the single (greyscale) channel.
  images = np.reshape(data, (data.shape[0], height, width, 1))
  # Dividing by the float 255. also promotes integer pixels to floating point.
  return images / 255.


# Correctly spelled alias; the original (misspelled) name is kept for callers.
reshape_and_normalize = reshape_and_nornmalize

In [None]:
# reshape_and_nornmalize returns 4-D floating-point arrays. Skip arrays that are
# already in that form, so re-running this cell does not divide by 255 twice.
if not (x_train.ndim == 4 and np.issubdtype(x_train.dtype, np.floating)):
  x_train = reshape_and_nornmalize(x_train)
if not (x_test.ndim == 4 and np.issubdtype(x_test.dtype, np.floating)):
  x_test = reshape_and_nornmalize(x_test)

print(f"x_train shape: {x_train.shape}")
print(f"x_test shape:  {x_test.shape}")

**Create the model and compile it.**

In [None]:
model = tf.keras.models.Sequential()
# Flatten each image into a single vector for the fully connected layers.
model.add(tf.keras.layers.Flatten())
# Hidden layer: 128 ReLU units.
model.add(tf.keras.layers.Dense(128, activation="relu"))
# Output layer: 10 units with a softmax over them.
model.add(tf.keras.layers.Dense(10, activation="softmax"))

model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

**Create a callback that stops training once the model's training accuracy reaches 99.5%.**

In [None]:
class my_callback(tf.keras.callbacks.Callback):
  """Keras callback that stops training once accuracy reaches a target.

  At the end of every epoch it reads the "accuracy" entry of the epoch's logs
  (which requires the model to be compiled with metrics=["accuracy"]). When
  that value is at or above the target, it sets `model.stop_training`.
  """

  def __init__(self, target_accuracy=0.995):
    """
    Args:
      target_accuracy: accuracy, as a fraction in [0, 1], at or above which
        training is stopped. Defaults to 0.995 (99.5%).
    """
    super().__init__()
    self.target_accuracy = target_accuracy

  def on_epoch_end(self, epoch, logs=None):
    """Request a stop when this epoch's accuracy reaches the target.

    Args:
      epoch: zero-based index of the epoch that just finished.
      logs: dict of metric values for this epoch. It may be None, or may lack
        "accuracy"; in that case no stop is requested.
    """
    # Use None rather than a mutable default, and fall back to an empty dict so
    # that .get() is always available.
    logs = logs or {}
    accuracy = logs.get("accuracy")
    # A missing metric would otherwise raise TypeError when compared to a float.
    if accuracy is not None and accuracy >= self.target_accuracy:
      self.model.stop_training = True
      # `epoch` is zero-based, so epoch + 1 epochs have completed.
      print(f"\nReached {self.target_accuracy:.1%} accuracy so cancelling training after {epoch + 1} epochs.\n")

**Fit the model.**

In [None]:
print("\nFit Model:\n")

# Early-stopping callback instance; training ends after 10 epochs at the latest.
callbacks = my_callback()

# `history` keeps the object returned by fit for later inspection.
history = model.fit(
    x=x_train,
    y=y_train,
    epochs=10,
    callbacks=[callbacks],
)

**Evaluate the model on the test set.**

In [None]:
# Compute the compiled loss and metrics on the held-out test split; the bare
# name on the last line displays the result as the cell output.
test_results = model.evaluate(x=x_test, y=y_test)
test_results

**Use the model to predict the first 10 test images.**

In [None]:
# Model outputs for every test image, one row per image.
classifications = model.predict(x_test)
# Index of the highest-scoring output for each image, i.e. the predicted label.
predicted_labels = np.argmax(classifications, axis=1)

for i in range(10):
  # Drop the size-1 channel axis, (28, 28, 1) -> (28, 28), so imshow gets a
  # plain 2-D image to render through the colormap.
  digit = np.squeeze(x_test[i])
  fig, ax = plt.subplots()
  ax.imshow(digit, cmap=plt.cm.binary)
  ax.set_title(f"x_test[{i}]: predicted {predicted_labels[i]}, actual {y_test[i]}")
  ax.axis("off")
  plt.show()
  # Release the figure so the loop does not accumulate open figures.
  plt.close(fig)
  print(f"\nx_test[{i}]: {classifications[i]}")
  print(f"predicted label: {predicted_labels[i]}")
  print(f"y_test[{i}]: {y_test[i]}\n\n")