In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds

In [2]:
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True
)

In [3]:
def normalize_img(image, label):
  return tf.cast(image, tf.float32) / 255., label

ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(512)
ds_train = ds_train.cache()
ds_train = ds_train.prefetch(tf.data.AUTOTUNE)

ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(512)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.AUTOTUNE)

In [4]:
from keras import models
from keras import layers

model = models.Sequential()
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(32, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.Flatten())
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(10, activation='sigmoid'))    

model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

In [5]:
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping

filename = 'MNIST.h5'

checkpoint = ModelCheckpoint(filename,             
                             monitor='val_loss',   
                             verbose=1,            
                             save_best_only=True,  
                             mode='auto'           
                            )

earlystopping = EarlyStopping(monitor='val_loss',  
                              patience=10,         
                             )


In [6]:
history =  model.fit(
    ds_train,
    epochs=60,
    validation_data=ds_test,
    callbacks=[checkpoint, earlystopping],
)

Epoch 1/60

Epoch 00001: val_loss improved from inf to 0.15027, saving model to MNIST.h5
Epoch 2/60

Epoch 00002: val_loss improved from 0.15027 to 0.09139, saving model to MNIST.h5
Epoch 3/60

Epoch 00003: val_loss improved from 0.09139 to 0.06323, saving model to MNIST.h5
Epoch 4/60

Epoch 00004: val_loss improved from 0.06323 to 0.05368, saving model to MNIST.h5
Epoch 5/60

Epoch 00005: val_loss improved from 0.05368 to 0.04678, saving model to MNIST.h5
Epoch 6/60

Epoch 00006: val_loss improved from 0.04678 to 0.04488, saving model to MNIST.h5
Epoch 7/60

Epoch 00007: val_loss improved from 0.04488 to 0.04345, saving model to MNIST.h5
Epoch 8/60

Epoch 00008: val_loss improved from 0.04345 to 0.04113, saving model to MNIST.h5
Epoch 9/60

Epoch 00009: val_loss improved from 0.04113 to 0.03986, saving model to MNIST.h5
Epoch 10/60

Epoch 00010: val_loss improved from 0.03986 to 0.03765, saving model to MNIST.h5
Epoch 11/60

Epoch 00011: val_loss improved from 0.03765 to 0.03650, savi