Train a model using a custom training loop to tackle the Fashion MNIST dataset (see Chapter 10):
a. Display the epoch, iteration, mean training loss, and mean accuracy over each epoch (updated at each iteration), as well as the validation loss and accuracy at the end of each epoch.

b. Try using a different optimizer with a different learning rate for the upper layers and the lower layers.

In [1]:
import tensorflow as tf

In [9]:
import pandas as pd

In [19]:
(x_train_full, y_train_full), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
x_train_full, x_test = x_train_full / 256.0, x_test / 256.0
x_train, x_valid = x_train_full[:-5000], x_train_full[-5000:]
y_train, y_valid = y_train_full[:-5000], y_train_full[-5000:]

In [22]:
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(100, activation='relu', kernel_initializer='he_normal'),
    tf.keras.layers.Dense(30, activation='relu', kernel_initializer='he_normal'),
    tf.keras.layers.Dense(10, activation='softmax')
])

In [24]:
model.compile(
    loss=tf.keras.losses.sparse_categorical_crossentropy,
    optimizer=tf.keras.optimizers.legacy.Nadam(learning_rate=0.0005),
    metrics=[tf.keras.metrics.sparse_categorical_crossentropy]
)

In [25]:
model.fit(x_train, y_train, epochs=10, validation_data=(x_train, y_train))

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


<keras.callbacks.History at 0x29c1f9ed0>

In [26]:
import numpy as np

In [27]:
def random_batch(X, y, batch_size=32):
    idx = np.random.randint(len(X), size=batch_size)
    return X[idx], y[idx]

def print_status_bar(step, total, loss, metrics=None):
    metrics = ' - '.join([f'{m.name}: {m.result():.4f}' for m in [loss] + (metrics or [])])
    end = '' if step < total else '\n'
    print(f'\r{step}/{total} - ' + metrics, end=end)

In [33]:
n_steps = x_train.shape[0] // 32
n_epochs = 10
loss_fn = tf.keras.losses.sparse_categorical_crossentropy
optimizer = tf.keras.optimizers.legacy.Nadam(learning_rate=0.0005)
mean_loss = tf.keras.metrics.Mean(name='mean_loss')
metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]

In [34]:
for epoch in range(1, n_epochs + 1):
    print(f'Epoch {epoch}/{n_epochs}')
    for step in range(1, n_steps + 1):
        x_batch, y_batch = random_batch(x_train, y_train)
        with tf.GradientTape() as tape:
            y_pred = model(x_batch, training=True)
            main_loss = tf.reduce_mean(loss_fn(y_batch, y_pred))
            # We are using add_n since model.losses returns a scalar tensor per loss
            # (in this case the model has regularization loss per layer)
            loss = tf.add_n([main_loss] + model.losses)
        gradients = tape.gradient(loss, model.trainable_variables)
        # Perform a gradient descent step
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        # Add weight constraints
        for variable in model.trainable_variables:
            if variable.constraint is not None:
                variable.assign(variable.constraint(variable))
        # Calculate the loss and metrics for the epoch
        mean_loss(loss)
        for metric in metrics:
            metric(y_batch, y_pred)

        print_status_bar(step, n_steps, mean_loss, metrics)

    # Reset the metrics every epoch
    for metric in [mean_loss] + metrics:
        metric.reset_state()

Epoch 1/10
1875/1875 - mean_loss: 0.1820 - sparse_categorical_accuracy: 0.9324
Epoch 2/10
1875/1875 - mean_loss: 0.1774 - sparse_categorical_accuracy: 0.9350
Epoch 3/10
1875/1875 - mean_loss: 0.1706 - sparse_categorical_accuracy: 0.9359
Epoch 4/10
1875/1875 - mean_loss: 0.1702 - sparse_categorical_accuracy: 0.9376
Epoch 5/10
1875/1875 - mean_loss: 0.1675 - sparse_categorical_accuracy: 0.9395
Epoch 6/10
1875/1875 - mean_loss: 0.1631 - sparse_categorical_accuracy: 0.9397
Epoch 7/10
1875/1875 - mean_loss: 0.1576 - sparse_categorical_accuracy: 0.9420
Epoch 8/10
1875/1875 - mean_loss: 0.1557 - sparse_categorical_accuracy: 0.9432
Epoch 9/10
1875/1875 - mean_loss: 0.1541 - sparse_categorical_accuracy: 0.9423
Epoch 10/10
1875/1875 - mean_loss: 0.1532 - sparse_categorical_accuracy: 0.9444
