In [2]:
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Conv2D
from tensorflow.keras import Model
from tensorflow.keras.datasets import mnist

# Load MNIST and scale pixel values into [0, 1].
# Cast to float32 *before* dividing: plain `uint8 / 255.0` yields float64
# arrays, which double host memory and do not match the model's default
# float32 dtype (the layers would otherwise cast every batch).
(x_train, y_train), (x_val, y_val) = mnist.load_data()
x_train = x_train.astype('float32') / 255.0
x_val = x_val.astype('float32') / 255.0

# Add a trailing channels dimension so the images fit Conv2D's
# (batch, height, width, channels) input layout.
x_train = x_train[..., tf.newaxis]
x_val = x_val[..., tf.newaxis]

# Convert the integer labels (0-9) to one-hot vectors, as required by
# CategoricalCrossentropy / CategoricalAccuracy.
y_train = tf.keras.utils.to_categorical(y_train, 10)
y_val = tf.keras.utils.to_categorical(y_val, 10)

# Batch the data; only the training split is shuffled.
BATCH_SIZE = 100
SHUFFLE_BUFFER = 10000
train_ds = tf.data.Dataset.from_tensor_slices(
    (x_train, y_train)).shuffle(SHUFFLE_BUFFER).batch(BATCH_SIZE)
val_ds = tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(BATCH_SIZE)


class MyModel(Model):
    """Small CNN classifier: one conv layer followed by two dense layers.

    The final layer applies a softmax over 10 classes, so the model outputs
    probabilities rather than logits.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are kept stable because Keras
        # tracks weights (and checkpoint keys) through them.
        self.conv1 = Conv2D(32, 3, activation='relu')
        self.flatten = Flatten()
        self.d1 = Dense(128, activation='relu')
        self.d2 = Dense(10, activation='softmax')

    def call(self, x):
        # Feature extraction, then flatten into a vector per example.
        features = self.flatten(self.conv1(x))
        # Hidden dense layer, then per-class softmax probabilities.
        return self.d2(self.d1(features))


# Model under training.
model = MyModel()

# The model ends in a softmax, so the loss consumes probabilities
# (CategoricalCrossentropy's default from_logits=False) against one-hot labels.
loss_object, optimizer = (
    tf.keras.losses.CategoricalCrossentropy(),
    tf.keras.optimizers.Adam(),
)

# Running loss/accuracy accumulators; they are reset at the start of each
# epoch by the training loop.
train_loss, train_accuracy = (
    tf.keras.metrics.Mean(name='train_loss'),
    tf.keras.metrics.CategoricalAccuracy(name='train_accuracy'),
)
val_loss, val_accuracy = (
    tf.keras.metrics.Mean(name='val_loss'),
    tf.keras.metrics.CategoricalAccuracy(name='val_accuracy'),
)


@tf.function
def train_step(images, labels):
    """Run one optimization step on a batch and update the training metrics.

    Args:
        images: batch of input images fed to ``model``.
        labels: matching one-hot labels.
    """
    # Record the forward pass so the loss can be differentiated.
    with tf.GradientTape() as tape:
        predictions = model(images)
        loss = loss_object(labels, predictions)

    # Read the variables only after the forward pass: the subclassed model
    # creates its weights lazily on the first call.
    params = model.trainable_variables
    grads = tape.gradient(loss, params)
    optimizer.apply_gradients(zip(grads, params))

    # Accumulate batch statistics for the epoch summary.
    train_loss(loss)
    train_accuracy(labels, predictions)


@tf.function
def val_step(images, labels):
    """Evaluate one batch and update the validation metrics (no weight update).

    Args:
        images: batch of input images fed to ``model``.
        labels: matching one-hot labels.
    """
    outputs = model(images)
    val_loss(loss_object(labels, outputs))
    val_accuracy(labels, outputs)


EPOCHS = 10

# Metrics accumulate across batches, so they are cleared at the start of
# every epoch. `reset_state()` is the current Keras API; the old
# `reset_states()` alias is deprecated and no longer exists in Keras 3.
_epoch_metrics = (train_loss, train_accuracy, val_loss, val_accuracy)

for epoch in range(EPOCHS):
    for metric in _epoch_metrics:
        metric.reset_state()

    # One pass over the (shuffled) training batches, updating the weights.
    for images, labels in train_ds:
        train_step(images, labels)

    # Evaluate on the validation split; weights are not updated here.
    for val_images, val_labels in val_ds:
        val_step(val_images, val_labels)

    print(
        f'Epoch {epoch + 1}, '
        f'Loss: {train_loss.result()}, '
        f'Accuracy: {train_accuracy.result() * 100}, '
        f'Val Loss: {val_loss.result()}, '
        f'Val Accuracy: {val_accuracy.result() * 100}'
    )

Epoch 1, Loss: 0.16048148274421692, Accuracy: 95.23833465576172, Val Loss: 0.07720496505498886, Val Accuracy: 97.44999694824219
Epoch 2, Loss: 0.04901094362139702, Accuracy: 98.48999786376953, Val Loss: 0.06016826257109642, Val Accuracy: 97.91999816894531
Epoch 3, Loss: 0.028814751654863358, Accuracy: 99.05166625976562, Val Loss: 0.05165985971689224, Val Accuracy: 98.29999542236328
Epoch 4, Loss: 0.014962086454033852, Accuracy: 99.52999877929688, Val Loss: 0.0494658425450325, Val Accuracy: 98.32999420166016
Epoch 5, Loss: 0.011475426144897938, Accuracy: 99.62999725341797, Val Loss: 0.0589199922978878, Val Accuracy: 98.31999969482422
Epoch 6, Loss: 0.008025946095585823, Accuracy: 99.74666595458984, Val Loss: 0.05263926088809967, Val Accuracy: 98.50999450683594
Epoch 7, Loss: 0.005170242860913277, Accuracy: 99.8550033569336, Val Loss: 0.05774210765957832, Val Accuracy: 98.47000122070312
Epoch 8, Loss: 0.005907862912863493, Accuracy: 99.7933349609375, Val Loss: 0.057167455554008484, Val A