# 7.4 Writing your own training and validation loops

- `fit()` strikes a nice balance between ease of use and flexibility.
- But this built-in workflow is focused solely on *supervised learning*. When we run into *generative learning*, *self-supervised learning*, or *reinforcement learning*, we may need to write our own custom training logic.

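Contents of a typical training loop:
1. Run the forward pass (compute the model's output) inside a gradient tape to obtain a loss value for the current batch.
2. Retrieve the gradients of the loss with respect to the model's weights.
3. Update the weights so as to lower the loss on the current batch.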

This is essentially what `fit()` does under the hood.
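The three steps above can be sketched on a toy problem without any Keras machinery. This is a minimal illustration, not how `fit()` is implemented. The data (points on the line y = 3x + 2), the learning rate, and the update rule are all assumptions chosen for the example. The update uses plain gradient descent via `assign_sub` instead of a Keras optimizer.

```python
import tensorflow as tf

# Toy data (an assumption for illustration): points on the line y = 3x + 2
x = tf.constant([[0.0], [1.0], [2.0], [3.0]])
y = 3.0 * x + 2.0

# Weights of a one-unit linear model
w = tf.Variable([[0.0]])
b = tf.Variable([0.0])
variables = [w, b]
learning_rate = 0.05

for step in range(500):
    # 1. Forward pass inside a gradient tape to get the loss value
    with tf.GradientTape() as tape:
        predictions = tf.matmul(x, w) + b
        loss = tf.reduce_mean(tf.square(y - predictions))
    # 2. Gradients of the loss with respect to the weights
    gradients = tape.gradient(loss, variables)
    # 3. Move each weight a small step against its gradient to lower the loss
    for variable, gradient in zip(variables, gradients):
        variable.assign_sub(learning_rate * gradient)

print(f"w={w.numpy()[0, 0]:.3f}, b={b.numpy()[0]:.3f}, loss={loss.numpy():.6f}")
```

After a few hundred steps, `w` and `b` should approach 3 and 2. `fit()` runs the same three steps per batch, with a Keras optimizer doing step 3.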

Let's reimplement `fit()` from scratch. Doing so teaches us everything we need to write any training algorithm we may come up with.

## 7.4.1 Training versus inference

In training loops that we've seen so far:
* Step 1 (forward pass) is done via `predictions = model(inputs)`.
* Step 2 (gradient retrieval) is done via `gradients = tape.gradient(loss, model.weights)`.

Some Keras layers, such as `Dropout`, behave differently during training and during inference. These layers expose a `training` Boolean argument in their `call()` method, and so do Functional and Sequential models.

* Remember to pass `training=True` when calling a Keras model during the forward pass.
* When retrieving gradients, remember to use `model.trainable_weights` instead of `model.weights` with `GradientTape`.

> Some weights are non-trainable: they are not updated by backpropagation but during the forward pass (for example, the moving mean and variance of a `BatchNormalization` layer).

So, a supervised-learning training step should look like this:

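```python
import tensorflow as tf

# Assumes `model`, `loss_fn`, and `optimizer` are defined elsewhere.
def train_step(inputs, targets):
    with tf.GradientTape() as tape:
        predictions = model(inputs, training=True)  # forward pass in training mode
        loss = loss_fn(targets, predictions)

    # Gradients only with respect to the trainable weights
    gradients = tape.gradient(loss, model.trainable_weights)
    # apply_gradients expects (gradient, variable) pairs
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))
```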
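A small check on two built-in layers shows why `training=True` and `trainable_weights` matter. A `Dropout` layer only drops units in training mode. A `BatchNormalization` layer keeps two non-trainable weights (its moving mean and variance), and it updates them during a training-mode forward pass rather than through gradients. The input shapes and values below are arbitrary choices for the sketch.

```python
import tensorflow as tf
from tensorflow.keras import layers

x = tf.ones((1, 1000))

# Dropout: identity at inference, random zeroing (with rescaling) in training
dropout = layers.Dropout(0.5)
inference_out = dropout(x, training=False)
training_out = dropout(x, training=True)
print("unchanged at inference:", bool(tf.reduce_all(inference_out == x)))
print("distinct values in training:", sorted(set(training_out.numpy().ravel().tolist())))

# BatchNormalization: gamma/beta are trainable, moving mean/variance are not
bn = layers.BatchNormalization()
batch = tf.fill((2, 4), 5.0)
bn(batch, training=False)  # builds the layer; the moving statistics stay untouched
print("trainable:", len(bn.trainable_weights), "non-trainable:", len(bn.non_trainable_weights))
before = bn.moving_mean.numpy().copy()
bn(batch, training=True)   # a training-mode forward pass updates the moving mean
after = bn.moving_mean.numpy()
print("moving mean before:", before, "after:", after)
```

With the default momentum of 0.99, the moving mean moves 1% of the way toward the batch mean of 5. This happens in the forward pass, which is why these weights must stay out of the gradient update.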

## 7.4.2 Low-level usage of metrics

- For low-level training loops (loops where we have manual control over the training process), we may still want to use Keras metrics.
- Metrics API: call `update_state(y_true, y_pred)` for each batch of targets and predictions, then use `result()` to query the current metric value:

```python
import tensorflow
from tensorflow import keras

metric = keras.metrics.SparseCategoricalAccuracy()
targets = [0, 1, 2]
predictions = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
metric.update_state(targets, predictions)
current_result = metric.result()
print(f"result: {current_result:.2f}")
```

```
result: 1.00
```
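For intuition, `SparseCategoricalAccuracy` compares each integer target with the index of the highest predicted score. The NumPy sketch below computes that for a single batch. It is a simplification: it ignores sample weights and the running state the real metric keeps across batches.

```python
import numpy as np

def sparse_categorical_accuracy(targets, predictions):
    # Predicted class = index of the highest score in each row
    predicted_classes = np.argmax(predictions, axis=-1)
    # Fraction of rows where the predicted class matches the integer target
    return float(np.mean(predicted_classes == np.asarray(targets)))

targets = [0, 1, 2]
predictions = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
print(f"result: {sparse_categorical_accuracy(targets, predictions):.2f}")  # prints "result: 1.00"
```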


We may also need to track the average of a scalar value, such as the model's loss. Use the `keras.metrics.Mean` metric for this:

```python
values = [0, 1, 2, 3, 4]
mean_tracker = keras.metrics.Mean()
for value in values:
  mean_tracker.update_state(value)
print(f"mean of values: {mean_tracker.result():.2f}")
```

```
mean of values: 2.00
```


Remember to call `metric.reset_state()` to reset the current results, e.g. at the start of each training epoch or at the start of evaluation.
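Resetting matters because a metric is stateful: `result()` reflects every batch seen since the last reset, not just the latest one. In this sketch the two batches are arbitrary examples. The first batch is fully correct and the second gets one of three right, so the running accuracy is 4 out of 6 until `reset_state()` clears it.

```python
from tensorflow import keras

metric = keras.metrics.SparseCategoricalAccuracy()
predictions = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]

metric.update_state([0, 1, 2], predictions)  # batch 1: 3 of 3 correct
after_batch_1 = float(metric.result())
metric.update_state([0, 0, 0], predictions)  # batch 2: 1 of 3 correct
after_batch_2 = float(metric.result())       # running value over both batches: 4 of 6
metric.reset_state()                         # e.g. at the start of a new epoch
after_reset = float(metric.result())
print(f"{after_batch_1:.4f} {after_batch_2:.4f} {after_reset:.4f}")
```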

## 7.4.3 A complete training and evaluation loop


### Writing a step-by-step training loop: the training step function

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.datasets import mnist

def get_mnist_model():
    inputs = keras.Input(shape=(28 * 28,))
    features = layers.Dense(512, activation="relu")(inputs)
    features = layers.Dropout(0.5)(features)
    outputs = layers.Dense(10, activation="softmax")(features)
    model = keras.Model(inputs, outputs)
    return model
```

```python
model = get_mnist_model()

loss_fn = keras.losses.SparseCategoricalCrossentropy()  # prepare the loss function
optimizer = keras.optimizers.RMSprop()  # prepare the optimizer
metrics = [keras.metrics.SparseCategoricalAccuracy()]  # list of metrics to monitor
loss_tracking_metric = keras.metrics.Mean()  # Mean metric to keep track of the loss average

def train_step(inputs, targets):
    with tf.GradientTape() as tape:
        predictions = model(inputs, training=True)  # forward pass. Note: training=True
        loss = loss_fn(targets, predictions)
    gradients = tape.gradient(loss, model.trainable_weights)  # backward pass. Note: trainable_weights
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))  # weight update. Note: trainable_weights

    # keep track of the metrics
    logs = {}
    for metric in metrics:
        metric.update_state(targets, predictions)
        logs[metric.name] = metric.result()

    # keep track of the loss average
    loss_tracking_metric.update_state(loss)
    logs["loss"] = loss_tracking_metric.result()
    return logs  # return the current values of the metrics and the loss
```

We need to reset the state of our metrics at the start of each epoch and before running the evaluation, so let's write a utility function for that.

### Resetting the metrics

```python
def reset_metrics():
  for metric in metrics:
    metric.reset_state()
  loss_tracking_metric.reset_state()
```

Now let's lay out the complete training loop.

### The loop itself

```python
# let's load the data first
(images, labels), (test_images, test_labels) = mnist.load_data()
images = images.reshape((60000, 28 * 28)).astype("float32") / 255
test_images = test_images.reshape((10000, 28 * 28)).astype("float32") / 255
train_images, val_images = images[10000:], images[:10000]
train_labels, val_labels = labels[10000:], labels[:10000]
```



```python
# this object turns NumPy data into an iterator that iterates over the data in batches of size 32
training_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
training_dataset = training_dataset.batch(32)
epochs = 3
for epoch in range(epochs):
    reset_metrics()
    for inputs_batch, targets_batch in training_dataset:
        logs = train_step(inputs_batch, targets_batch)
    print(f"Results at the end of epoch {epoch}")
    for key, value in logs.items():
        print(f"...{key}: {value:.4f}")
```

```
Results at the end of epoch 0
...sparse_categorical_accuracy: 0.9152
...loss: 0.2894
Results at the end of epoch 1
...sparse_categorical_accuracy: 0.9534
...loss: 0.1602
Results at the end of epoch 2
...sparse_categorical_accuracy: 0.9636
...loss: 0.1317
```


### Now let's write the evaluation loop step-by-step

```python
# test_step() is just train_step() without the updating steps
def test_step(inputs, targets):
    predictions = model(inputs, training=False) # note that training=False here
    loss = loss_fn(targets, predictions)

    logs = {}
    for metric in metrics:
        metric.update_state(targets, predictions)
        logs["val_" + metric.name] = metric.result()

    loss_tracking_metric.update_state(loss)
    logs["val_loss"] = loss_tracking_metric.result()
    return logs

val_dataset = tf.data.Dataset.from_tensor_slices((val_images, val_labels))
val_dataset = val_dataset.batch(32)
reset_metrics()
for inputs_batch, targets_batch in val_dataset:
    logs = test_step(inputs_batch, targets_batch)
print("evaluation results:")
for key, value in logs.items():
    print(f"...{key}: {value:.4f}")
```

```
evaluation results:
...val_sparse_categorical_accuracy: 0.9680
...val_loss: 0.1168
```


`fit()` and `evaluate()` support many more features than our hand-written loops, such as large-scale distributed computation and performance optimizations. Let's look at one of these optimizations: TensorFlow function compilation.

## 7.4.4 Make it fast with tf.function

- Our custom loops run noticeably slower than the built-in methods.
- That's because, by default, TF code is executed line by line, *eagerly*, just like regular Python code.
- Eager execution is good from a debugging point of view, but not from a performance point of view.
- It's more performant to *compile* TF code into a *computation graph*, which can be optimized globally in ways that line-by-line execution cannot.
- To do this, just add the `@tf.function` decorator to the function.
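One consequence of compilation is worth knowing. A `tf.function` runs its Python body only while *tracing* the graph, then reuses that graph for later calls with the same input signature. The sketch below makes this visible with a Python list used as a trace counter. The names `trace_log` and `sum_of_squares` are hypothetical.

```python
import tensorflow as tf

trace_log = []  # Python-side record; only appended to while tracing

@tf.function
def sum_of_squares(x):
    trace_log.append("traced")           # plain Python: runs at trace time only
    return tf.reduce_sum(tf.square(x))   # TensorFlow op: part of the graph, runs on every call

first = sum_of_squares(tf.constant([1.0, 2.0]))   # traces the graph, then runs it
second = sum_of_squares(tf.constant([3.0, 4.0]))  # same shape/dtype: reuses the graph
print(first.numpy(), second.numpy(), len(trace_log))
```

This is also why a plain `print()` inside a decorated function only fires once, while the computation itself runs on every call.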

### Adding a `@tf.function` decorator to our evaluation-step function

```python
@tf.function # the only line that changed
def test_step(inputs, targets):
    predictions = model(inputs, training=False) # note that training=False here
    loss = loss_fn(targets, predictions)

    logs = {}
    for metric in metrics:
        metric.update_state(targets, predictions)
        logs["val_" + metric.name] = metric.result()

    loss_tracking_metric.update_state(loss)
    logs["val_loss"] = loss_tracking_metric.result()
    return logs

val_dataset = tf.data.Dataset.from_tensor_slices((val_images, val_labels))
val_dataset = val_dataset.batch(32)
reset_metrics()
for inputs_batch, targets_batch in val_dataset:
    logs = test_step(inputs_batch, targets_batch)
print("evaluation results:")
for key, value in logs.items():
    print(f"...{key}: {value:.4f}")
```

```
evaluation results:
...val_sparse_categorical_accuracy: 0.9680
...val_loss: 0.1168
```


This ran in less than half the time of the eager version! (Exact timings depend on the hardware and aren't shown here.)

> When debugging, run it eagerly, but once you know the code is working, add `@tf.function`.
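Instead of deleting and re-adding the decorator, we can toggle compilation globally with `tf.config.run_functions_eagerly()`. In this sketch, the hypothetical `eager_log` list records what `tf.executing_eagerly()` reports inside the function. The first call runs compiled, and the second runs eagerly after the switch.

```python
import tensorflow as tf

eager_log = []  # records tf.executing_eagerly() as seen inside the function

@tf.function
def add_one(x):
    eager_log.append(tf.executing_eagerly())  # Python side effect
    return x + 1

add_one(tf.constant(1))                 # compiled: traced into a graph
tf.config.run_functions_eagerly(True)   # debugging switch: run tf.functions as plain Python
add_one(tf.constant(1))                 # now runs eagerly, so print() and breakpoints behave normally
tf.config.run_functions_eagerly(False)  # restore compilation
print(eager_log)
```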

## 7.4.5 Leveraging fit() with a custom training loop

- Writing our own training loop from scratch, as we did above, gives us maximum flexibility. But we have to write a lot of code, and we miss out on many features of `fit()`, such as callbacks.
- There is a middle ground: provide a custom training step function and let the framework do the rest.

* Create a new class that subclasses `keras.Model`.
* Override the `train_step(self, data)` method.
* Implement a `metrics` property that returns the model's `Metric` instances. This lets the model automatically call `reset_state()` on them at the start of each epoch and at the start of a call to `evaluate()`, so we don't have to do it by hand.

```python
# loss function used inside train_step()
loss_fn = keras.losses.SparseCategoricalCrossentropy()
# this metric object tracks the average of per-batch losses during training and evaluation
loss_tracker = keras.metrics.Mean(name="loss")

class CustomModel(keras.Model):
    def train_step(self, data): # overriding the train_step method
        inputs, targets = data
        with tf.GradientTape() as tape:
            # self(...) instead of model(...), since the model is the class itself
            predictions = self(inputs, training=True)
            loss = loss_fn(targets, predictions)
        gradients = tape.gradient(loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_weights))

        loss_tracker.update_state(loss) # update the metric that tracks the average loss
        return {"loss": loss_tracker.result()} # return the average loss so far

    # any metric that should be reset across epochs should be listed here
    @property
    def metrics(self):
        return [loss_tracker]
```

```python
# let's instantiate our custom model
inputs = keras.Input(shape=(28 * 28,))
features = layers.Dense(512, activation="relu")(inputs)
features = layers.Dropout(0.5)(features)
outputs = layers.Dense(10, activation="softmax")(features)
model = CustomModel(inputs, outputs)

model.compile(optimizer=keras.optimizers.RMSprop()) # loss is already defined outside the model
model.fit(train_images, train_labels, epochs=3)
```

```
Epoch 1/3
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 8s 5ms/step - loss: 0.4548
Epoch 2/3
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 8s 5ms/step - loss: 0.1608
Epoch 3/3
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 13s 7ms/step - loss: 0.1320
```



- We can do this whether we build models with `Sequential`, the Functional API, or by subclassing `Model`.
- We don't need to add `@tf.function` when overriding `train_step`; the framework compiles it for us.
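Because `fit()` compiles `train_step()`, stepping through it with `print()` or a debugger won't behave like eager code. The switch for debugging is `run_eagerly=True` in `compile()`. This sketch assumes a recent TensorFlow/Keras. The hypothetical `ProbeModel` records `tf.executing_eagerly()` from inside `train_step()` under both settings, and the tiny random dataset is only a placeholder.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

eager_flags = []  # hypothetical log: was train_step running eagerly?

class ProbeModel(keras.Model):
    def train_step(self, data):
        eager_flags.append(tf.executing_eagerly())  # Python side effect
        return super().train_step(data)             # fall back to the default step

def make_model(run_eagerly):
    inputs = keras.Input(shape=(4,))
    outputs = layers.Dense(1)(inputs)
    model = ProbeModel(inputs, outputs)
    model.compile(optimizer="rmsprop", loss="mse", run_eagerly=run_eagerly)
    return model

# Placeholder data: 8 random samples -> 2 batches of 4
x = np.random.rand(8, 4).astype("float32")
y = np.random.rand(8, 1).astype("float32")

make_model(run_eagerly=False).fit(x, y, batch_size=4, epochs=1, verbose=0)
compiled_flags = list(eager_flags)   # recorded while the step was being traced
eager_flags.clear()
make_model(run_eagerly=True).fit(x, y, batch_size=4, epochs=1, verbose=0)
eager_run_flags = list(eager_flags)  # recorded on every eager call
print(compiled_flags, eager_run_flags)
```

Leaving `run_eagerly` at its default keeps the compiled (faster) path once the step is known to work.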

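After calling `compile()`, we get access to the following (these are `tf.keras` 2.x attributes; Keras 3 deprecates `compiled_loss` and `compiled_metrics`, which is why the run below prints a warning):
- `self.compiled_loss`: the loss function passed to `compile()`.
- `self.compiled_metrics`: a wrapper around the list of metrics we passed, which lets us call `self.compiled_metrics.update_state()` to update all metrics at once.
- `self.metrics`: the list of the model's metrics, which includes a loss tracker (named `"loss"`) alongside the metrics passed to `compile()`.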

Thus:

```python
class CustomModel(keras.Model):
    def train_step(self, data):
        inputs, targets = data
        with tf.GradientTape() as tape:
            predictions = self(inputs, training=True)
            loss = self.compiled_loss(targets, predictions) # compute loss via self.compiled_loss

        gradients = tape.gradient(loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_weights))
        self.compiled_metrics.update_state(targets, predictions) # update model's metrics via self.compiled_metrics
        return {m.name: m.result() for m in self.metrics} # return a dict mapping metric names to their current value
```

```python
# try it:
inputs = keras.Input(shape=(28 * 28,))
features = layers.Dense(512, activation="relu")(inputs)
features = layers.Dropout(0.5)(features)
outputs = layers.Dense(10, activation="softmax")(features)
model = CustomModel(inputs, outputs)

model.compile(optimizer=keras.optimizers.RMSprop(),
              loss=keras.losses.SparseCategoricalCrossentropy(),
              metrics=[keras.metrics.SparseCategoricalAccuracy()])
model.fit(train_images, train_labels, epochs=3)
```

During the first epoch, Keras also prints a warning (only partly captured in this output), raised at `return self._compiled_metrics_update_state(`, that recommends updating the metrics directly instead:

```python
for metric in self.metrics:
    metric.update_state(y, y_pred)
```

```
Epoch 1/3
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 12s 7ms/step - sparse_categorical_accuracy: 0.8654 - loss: 0.1000
Epoch 2/3
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 9s 6ms/step - sparse_categorical_accuracy: 0.9526 - loss: 0.1000
Epoch 3/3
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 10s 6ms/step - sparse_categorical_accuracy: 0.9606 - loss: 0.1000
```

Note that the logged loss stays at exactly 0.1000 in all three epochs while the accuracy improves, so here it does not appear to reflect the actual training loss.


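In Keras 3, `compiled_loss` and `compiled_metrics` are deprecated. What follows is a sketch, assuming Keras 3, of the same `CustomModel` written in the style the Keras 3 documentation recommends. The loss comes from `self.compute_loss()`, and each entry of `self.metrics` is updated directly, with the built-in `"loss"` tracker receiving the loss value. The small random dataset is a placeholder so the example runs without MNIST.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


class CustomModel(keras.Model):
    def train_step(self, data):
        inputs, targets = data
        with tf.GradientTape() as tape:
            predictions = self(inputs, training=True)
            # compute_loss() applies the loss passed to compile()
            loss = self.compute_loss(y=targets, y_pred=predictions)
        gradients = tape.gradient(loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_weights))
        # self.metrics holds the built-in loss tracker (named "loss")
        # plus the metrics passed to compile()
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(targets, predictions)
        return {m.name: m.result() for m in self.metrics}


# Placeholder data: 32 random samples, 8 features, 3 classes
rng = np.random.default_rng(0)
x = rng.random((32, 8)).astype("float32")
y = rng.integers(0, 3, size=(32,))

inputs = keras.Input(shape=(8,))
outputs = layers.Dense(3, activation="softmax")(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer=keras.optimizers.RMSprop(),
              loss=keras.losses.SparseCategoricalCrossentropy(),
              metrics=[keras.metrics.SparseCategoricalAccuracy()])
history = model.fit(x, y, batch_size=8, epochs=2, verbose=0)
print(history.history["loss"])
```

Updating the `"loss"` tracker explicitly is what makes the logged loss follow the value computed in `train_step()`.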

# Summary:

- Keras offers a spectrum of workflows, following the principle of *progressive disclosure of complexity*, and they all interoperate.
- We can build models using `Sequential`, the Functional API, or by subclassing the `Model` class.
- The simplest way to train and evaluate a model is via the default `fit()` and `evaluate()` methods.
- Callbacks let us monitor models during a call to `fit()` and automatically take actions based on the state of the model.
- We can customize what `fit()` does by overriding the `train_step()` method.
- We can also write our own training loops entirely from scratch, which is useful for implementing brand-new training algorithms.