In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import categorical_crossentropy, sparse_categorical_crossentropy
from tensorflow.keras.layers import Conv2D, BatchNormalization, MaxPool2D, Flatten, Dense, Dropout
from sklearn.metrics import classification_report
import numpy as np
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping

# Load CIFAR-10 as (image, label) pairs together with the dataset metadata.
# NOTE: shuffle_files=True applies to every requested split, including 'test',
# so the read order of test examples is not guaranteed to be identical across
# separate passes over ds_test.
(ds_train, ds_test), ds_info = tfds.load(
    name='cifar10',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)

# Materialize the integer labels of each split (one full pass per split). [1]
label_train = [int(label) for _, label in tfds.as_numpy(ds_train)]
label_test = [int(label) for _, label in tfds.as_numpy(ds_test)]

# Print a short summary rather than dumping every label to the output.
print(f"label_train: {len(label_train)} labels, first 10: {label_train[:10]}")
print(f"label_test: {len(label_test)} labels, first 10: {label_test[:10]}")

# The labels are plain Python lists; `.flatten()` is a NumPy method, so convert
# with np.array(...) first if NumPy-only operations are needed. [2]

CLASS_NAMES = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

# TFDS provides images as tf.uint8, while the model expects tf.float32 inputs,
# so each image is cast and scaled before it reaches the network.
def normalize_img(image, label):
    """Cast a `uint8` image to `float32` scaled by 1/255; return `label` unchanged."""
    as_float = tf.cast(image, tf.float32)
    return as_float / 255., label


AUTO = tf.data.experimental.AUTOTUNE

# Training pipeline: normalize -> cache the normalized examples -> shuffle with a
# buffer as large as the whole split -> batch -> prefetch.
train_examples = ds_info.splits['train'].num_examples
ds_train = (
    ds_train
    .map(normalize_img, num_parallel_calls=AUTO)
    .cache()
    .shuffle(train_examples)
    .batch(128)
    .prefetch(AUTO)
)

# Evaluation pipeline: no shuffling, so element order is stable once cached.
ds_test = (
    ds_test
    .map(normalize_img, num_parallel_calls=AUTO)
    .batch(128)
    .cache()
    .prefetch(AUTO)
)

print(ds_test)

# AlexNet-style CNN applied to 32x32x3 CIFAR-10 images.
# Spatial size per stage: 32 -> 6 (11x11 stride-4 'valid' conv) -> 3 (pool)
# -> 3 (5x5 'same' conv) -> 2 (pool) -> 2 (three 3x3 'same' convs) -> 1 (pool),
# so Flatten produces 256 features for the two 4096-unit dense layers.


def _conv_bn(filters, kernel_size):
    """Return a stride-1, 'same'-padded ReLU Conv2D followed by BatchNormalization."""
    return [
        Conv2D(filters=filters, kernel_size=kernel_size, strides=(1, 1),
               activation=tf.nn.relu, padding="same"),
        BatchNormalization(),
    ]


def _pool():
    """Return the 3x3, stride-2, 'same'-padded max-pool used between stages."""
    return MaxPool2D(pool_size=(3, 3), strides=(2, 2), padding='same')


model = Sequential()

# Stage 1: large-kernel strided stem.
model.add(Conv2D(filters=96, kernel_size=(11, 11), strides=(4, 4),
                 activation=tf.nn.relu, data_format='channels_last',
                 input_shape=(32, 32, 3)))
model.add(BatchNormalization())
model.add(_pool())

# Stage 2: a single 5x5 convolution.
for layer in _conv_bn(256, (5, 5)):
    model.add(layer)
model.add(_pool())

# Stage 3: three stacked 3x3 convolutions.
for filters in (384, 384, 256):
    for layer in _conv_bn(filters, (3, 3)):
        model.add(layer)
model.add(_pool())

# Classifier head: two dropout-regularized dense layers, then a 10-way softmax.
model.add(Flatten())
for _ in range(2):
    model.add(Dense(4096, activation=tf.nn.relu))
    model.add(Dropout(0.5))
model.add(Dense(10, activation=tf.nn.softmax))

# Labels are integer class ids, hence the sparse cross-entropy loss.
optimizer = Adam(0.001)
model.compile(optimizer=optimizer,
              loss=sparse_categorical_crossentropy,
              metrics=['accuracy'])

# Checkpointing and early stopping watch the same *validation* metric, so the
# checkpoint on disk and the restored weights agree. (Monitoring the training
# 'accuracy' would select the epoch that fits the training set best, not the
# one that generalizes best.)
MONITOR = 'val_accuracy'
CHECKPOINT_PATH = 'my_net_early_stopping.h5'

checkpoint_callback = ModelCheckpoint(filepath=CHECKPOINT_PATH,
                                      monitor=MONITOR,
                                      mode='max',
                                      save_best_only=True,
                                      save_weights_only=False,
                                      save_freq='epoch',
                                      verbose=0)

early_stopping_callback = EarlyStopping(monitor=MONITOR,
                                        mode='max',
                                        patience=10,
                                        restore_best_weights=True)

history = model.fit(ds_train,
                    epochs=20,
                    validation_data=ds_test,
                    callbacks=[checkpoint_callback, early_stopping_callback])

# Reload the best checkpoint rather than overwriting it with model.save().
# Depending on the Keras version, restore_best_weights may only apply when early
# stopping actually triggers, so after a full 20-epoch run the in-memory model
# could hold last-epoch weights.
model = tf.keras.models.load_model(CHECKPOINT_PATH)

evaluation = model.evaluate(ds_test)
print(evaluation)

print('\n', 30*"=", '\n')

# Take the ground-truth labels from the same cached, unshuffled ds_test
# pipeline that predict() consumes, so y_true lines up element-by-element with
# the predictions. A separate pass over the raw split (loaded with
# shuffle_files=True) is not guaranteed to produce the same order.
y_true = np.concatenate([labels for _, labels in tfds.as_numpy(ds_test)])

y_prediction = model.predict(ds_test)
print(y_prediction)

print('\n', 30*"=", '\n')

# Reduce per-class probabilities to the index of the most likely class. [1]
y_prediction_classes = np.argmax(y_prediction, axis=1)
print(y_prediction_classes)

print(classification_report(y_true=y_true, y_pred=y_prediction_classes, target_names=CLASS_NAMES))



# Reference:
# 1. https://github.com/keras-team/keras/issues/2607#issuecomment-302365916
# 2. https://stackoverflow.com/a/65874930/14900011

[7, 8, 4, 4, 6, 5, 2, 9, 6, 6, 9, 9, 3, 0, 8, 7, 9, 0, 4, 9, 0, 8, 6, 4, 2, 8, 8, 7, 0, 8, 4, 2, 3, 7, 0, 5, 4, 3, 8, 1, 5, 9, 4, 9, 8, 6, 9, 7, 7, 7, 3, 6, 3, 8, 3, 6, 1, 1, 7, 0, 9, 0, 0, 4, 6, 3, 2, 7, 4, 5, 2, 7, 4, 3, 8, 4, 5, 0, 6, 3, 7, 3, 1, 0, 5, 7, 3, 3, 7, 4, 9, 5, 2, 3, 2, 1, 8, 2, 7, 9, 5, 9, 8, 9, 0, 4, 5, 3, 4, 0, 2, 3, 9, 5, 3, 8, 3, 8, 7, 8, 2, 8, 5, 7, 7, 0, 3, 5, 5, 3, 9, 1, 0, 3, 3, 0, 2, 9, 2, 0, 7, 0, 1, 5, 3, 4, 4, 9, 0, 1, 1, 0, 2, 3, 5, 5, 3, 1, 5, 8, 3, 5, 8, 6, 7, 7, 4, 7, 8, 4, 7, 8, 7, 8, 3, 6, 1, 5, 2, 7, 6, 8, 2, 3, 7, 8, 5, 8, 4, 1, 1, 2, 6, 7, 0, 3, 0, 7, 1, 3, 2, 4, 1, 6, 8, 4, 9, 7, 4, 1, 5, 8, 9, 4, 8, 9, 5, 2, 9, 6, 8, 3, 8, 0, 8, 6, 5, 9, 3, 0, 2, 4, 2, 9, 9, 3, 8, 3, 0, 5, 8, 3, 8, 7, 3, 8, 3, 9, 9, 7, 3, 6, 3, 1, 6, 4, 2, 1, 7, 4, 3, 4, 8, 6, 6, 1, 0, 9, 0, 7, 1, 1, 2, 0, 5, 8, 8, 8, 3, 7, 7, 5, 7, 7, 5, 6, 8, 1, 3, 8, 6, 0, 3, 7, 0, 9, 0, 4, 6, 0, 0, 2, 0, 4, 3, 9, 7, 0, 9, 8, 8, 8, 3, 6, 7, 6, 5, 7, 7, 3, 7, 0, 5, 0, 1, 5, 6, 8, 0, 4, 4, 4, 6, 