In [1]:
import os
import tensorflow as tf
import matplotlib.pyplot
import numpy as np
import tensorflow_datasets as tfds
from tensorflow import keras
from tensorflow.keras import layers

from tensorflow.keras.datasets import mnist

from tensorflow.keras.layers import (
    BatchNormalization, Conv2D, MaxPooling2D, Activation, Flatten, Dropout, Dense)

In [2]:
(ds_train,ds_test),ds_info = tfds.load(
    "mnist",
    split=['train','test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True

)

In [3]:
def normalize_img(image, label):
    """Normalizes images"""
    return tf.cast(image, tf.float32) / 255.0, label

In [4]:

AUTOTUNE = tf.data.experimental.AUTOTUNE
BATCH_SIZE = 128

# Setup for train dataset
ds_train = ds_train.map(normalize_img, num_parallel_calls=AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits["train"].num_examples)
ds_train = ds_train.batch(BATCH_SIZE)
ds_train = ds_train.prefetch(AUTOTUNE)


In [5]:
model = keras.Sequential(
    [
        keras.Input((28, 28, 1)),
        layers.Conv2D(32, 3, activation="relu"),
        layers.Flatten(),
        tf.keras.layers.Dense(10, activation="softmax"),
    ]
)

In [6]:
#call backs
save_callback = keras.callbacks.ModelCheckpoint(
    "checkpoint/", save_weights_only=True, monitor="train_acc", save_best_only=False,
)

lr_scheduler = keras.callbacks.ReduceLROnPlateau(
    monitor="loss", factor=0.1, patience=3, mode="max", verbose=1
)


class OurOwnCallback(keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        if logs.get("accuracy") > 1:
            print("Accuracy over 70%, quitting training")
            self.model.stop_training = True


In [7]:
model.compile(
    optimizer=keras.optimizers.Adam(0.01),
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=["accuracy"],
)

In [8]:
model.fit(
    ds_train,
    epochs=10,
    callbacks=[save_callback, lr_scheduler, OurOwnCallback()],
    verbose=2,
)

Epoch 1/10
469/469 - 13s - loss: 0.1412 - accuracy: 0.9578 - lr: 0.0100 - 13s/epoch - 27ms/step
Epoch 2/10
469/469 - 10s - loss: 0.0562 - accuracy: 0.9826 - lr: 0.0100 - 10s/epoch - 20ms/step
Epoch 3/10
469/469 - 9s - loss: 0.0341 - accuracy: 0.9886 - lr: 0.0100 - 9s/epoch - 19ms/step
Epoch 4/10

Epoch 4: ReduceLROnPlateau reducing learning rate to 0.0009999999776482583.
469/469 - 9s - loss: 0.0240 - accuracy: 0.9918 - lr: 0.0100 - 9s/epoch - 19ms/step
Epoch 5/10
469/469 - 9s - loss: 0.0088 - accuracy: 0.9971 - lr: 1.0000e-03 - 9s/epoch - 20ms/step
Epoch 6/10
469/469 - 10s - loss: 0.0037 - accuracy: 0.9993 - lr: 1.0000e-03 - 10s/epoch - 21ms/step
Epoch 7/10

Epoch 7: ReduceLROnPlateau reducing learning rate to 9.999999310821295e-05.
469/469 - 10s - loss: 0.0022 - accuracy: 0.9997 - lr: 1.0000e-03 - 10s/epoch - 21ms/step
Epoch 8/10
469/469 - 10s - loss: 0.0014 - accuracy: 0.9999 - lr: 1.0000e-04 - 10s/epoch - 21ms/step
Epoch 9/10
469/469 - 9s - loss: 0.0013 - accuracy: 0.9999 - lr: 1.00

<keras.callbacks.History at 0x1f9b99727f0>