# Training, evaluation, and inference

In [None]:
import tensorflow as tf
from tensorflow.keras import layers

In [None]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Data must be either NumPy arrays or Dataset objects.
# Flatten every image into a 784-vector and divide the pixel values by 255.
# `-1` lets NumPy infer the number of samples instead of hard-coding the
# 60000/10000 split sizes.
x_train = x_train.reshape(-1, 784).astype("float32") / 255
x_test = x_test.reshape(-1, 784).astype("float32") / 255

y_train = y_train.astype("float32")
y_test = y_test.astype("float32")

# Functional model: two 64-unit ReLU layers followed by a 10-way softmax.
inputs = tf.keras.Input(shape=(784,), name="digits")
x = layers.Dense(64, activation="relu", name="dense_1")(inputs)
x = layers.Dense(64, activation="relu", name="dense_2")(x)
outputs = layers.Dense(10, activation="softmax", name="predictions")(x)

model = tf.keras.Model(inputs=inputs, outputs=outputs)


In [None]:
# Configure training: RMSprop optimizer, sparse categorical cross-entropy
# loss and sparse categorical accuracy as the monitored metric.
model.compile(
    optimizer=tf.keras.optimizers.RMSprop(),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    # List of metrics to monitor
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

# Train for 2 epochs in batches of 64 samples.
history = model.fit(x_train, y_train, batch_size=64, epochs=2,
    validation_split=0.2,   # use 20% of the data for validation
)

# Cell output: the `history` attribute of the object returned by `fit()`.
history.history

In [None]:
# Evaluate on the test split; with the loss and the single metric compiled
# above, `results` holds the loss first and the accuracy second.
results = model.evaluate(x_test, y_test, verbose=2)
# print("Test loss:", results[0])
# print("Test accuracy:", results[1])
print("test loss, test acc:", results)

In [None]:
# Run inference on the first three test samples.
print("Generate predictions for 3 samples")
predictions = model.predict(x_test[:3])
print("predictions shape:", predictions.shape)

## Compile

#### Built-in optimizers:
- SGD() (with or without momentum)
- RMSprop()
- Adam()

#### Built-in losses:
- MeanSquaredError()
- KLDivergence()
- CosineSimilarity()

#### Built-in metrics:
- AUC()
- Precision()
- Recall()

In [None]:
# If you're satisfied with the default settings, you can pass the optimizer,
# loss and metrics as string shortcuts instead of objects:
model.compile(
    optimizer="rmsprop",
    loss="sparse_categorical_crossentropy",
    metrics=["sparse_categorical_accuracy"],
)

### Custom loss
The first method involves creating a function that accepts inputs y_true and y_pred.

In [None]:
def custom_mean_squared_error(y_true, y_pred):
    """Return the mean of the squared differences between y_true and y_pred."""
    squared_error = tf.square(y_true - y_pred)
    return tf.math.reduce_mean(squared_error)

# A plain function taking (y_true, y_pred) is passed as the loss.
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=custom_mean_squared_error,
)

Second method: If you need a loss function that takes in parameters besides y_true and y_pred, you can **subclass the tf.keras.losses.Loss class** and implement the following two methods:

    __init__(self): 
        accept parameters to pass during the call of your loss function
    call(self, y_true, y_pred): 
        use the targets (y_true) and the model predictions (y_pred) to compute the model's loss


In [None]:
class CustomMSE(tf.keras.losses.Loss):
    """Mean squared error plus a weighted penalty on (0.5 - y_pred) squared."""

    def __init__(self, regularization_factor=0.1, name="custom_mse"):
        super().__init__(name=name)
        # Weight applied to the penalty term computed in `call`.
        self.regularization_factor = regularization_factor

    def call(self, y_true, y_pred):
        """Return the mean squared error plus the scaled penalty term."""
        squared_error = tf.square(y_true - y_pred)
        squared_offset = tf.square(0.5 - y_pred)
        mse = tf.math.reduce_mean(squared_error)
        reg = tf.math.reduce_mean(squared_offset)
        return mse + reg * self.regularization_factor

# Losses with constructor parameters are passed as instances.
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=CustomMSE(),
)

### Custom metrics
You can easily create custom metrics by subclassing the tf.keras.metrics.Metric class. You will need to implement 4 methods:

    __init__(self):
        here you will create state variables for your metric.
    update_state(self, y_true, y_pred, sample_weight=None): 
        uses the targets and the model predictions to update the state variables.
    result(self): 
        uses the state variables to compute the final results.
    reset_state(self): 
        reinitializes the state of the metric.

State update and results computation are kept separate because in some cases, the results computation might be very expensive and would only be done periodically.

In [None]:
class CategoricalTruePositives(tf.keras.metrics.Metric):
    """Running (optionally weighted) count of samples whose argmax prediction equals the target."""

    def __init__(self, name="categorical_true_positives", **kwargs):
        super().__init__(name=name, **kwargs)
        # Single state variable, starting at zero.
        self.true_positives = self.add_weight(name="ctp", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Add the matches found in this batch to the running total."""
        predicted_class = tf.reshape(tf.argmax(y_pred, axis=1), shape=(-1, 1))
        is_match = tf.cast(y_true, "int32") == tf.cast(predicted_class, "int32")
        matches = tf.cast(is_match, "float32")
        if sample_weight is not None:
            weights = tf.cast(sample_weight, "float32")
            matches = tf.multiply(matches, weights)
        batch_total = tf.reduce_sum(matches)
        self.true_positives.assign_add(batch_total)

    def result(self):
        """Return the accumulated total."""
        return self.true_positives

    def reset_state(self):
        """Set the accumulated total back to zero."""
        self.true_positives.assign(0.0)

# Custom metrics are passed to compile() as instances.
model.compile(
    optimizer=tf.keras.optimizers.RMSprop(learning_rate=1e-3),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=[CategoricalTruePositives()],
)


### losses and metrics that don't fit the standard signature
For instance, a regularization loss may only require the activation of a layer (there are no targets in this case), and this activation may not be a model output.

In such cases, you can call self.add_loss(loss_value) from inside the call method of a custom layer. Losses added in this way get added to the "main" loss during training (the one passed to compile()). 

Same goes for logging metric values.

In [None]:
# Activity regularization is already built into all Keras layers;
# this layer only illustrates `add_loss`.
class ActivityRegularizationLayer(layers.Layer):
    """Pass-through layer that registers 0.1 * sum(inputs) as an extra loss."""

    def call(self, inputs):
        activation_sum = tf.reduce_sum(inputs)
        self.add_loss(activation_sum * 0.1)
        # The inputs are returned unchanged.
        return inputs

class MetricLoggingLayer(layers.Layer):
    """Pass-through layer that logs the standard deviation of its inputs."""

    def call(self, inputs):
        activation_std = tf.keras.backend.std(inputs)
        # "mean" controls how the per-batch values are aggregated over an epoch.
        self.add_metric(activation_std, name="std_of_activation", aggregation="mean")
        # The inputs are returned unchanged.
        return inputs

# Same architecture as before, with the two pass-through layers inserted
# after the first Dense layer. The output layer has no activation, so the
# loss below is configured with from_logits=True.
inputs = tf.keras.Input(shape=(784,), name="digits")
x = layers.Dense(64, activation="relu", name="dense_1")(inputs)
x = ActivityRegularizationLayer()(x)
x = MetricLoggingLayer()(x)
x = layers.Dense(64, activation="relu", name="dense_2")(x)
outputs = layers.Dense(10, name="predictions")(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

model.compile(
    optimizer=tf.keras.optimizers.RMSprop(learning_rate=1e-3),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)

# Another example: a layer that computes its own loss and metric.
class LogisticEndpoint(tf.keras.layers.Layer):
    """Endpoint layer that adds a binary cross-entropy loss and an accuracy metric."""

    def __init__(self, name=None):
        super().__init__(name=name)
        self.loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)
        self.accuracy_fn = tf.keras.metrics.BinaryAccuracy()

    def call(self, targets, logits, sample_weights=None):
        """Register loss and accuracy for (targets, logits); return softmax(logits)."""
        endpoint_loss = self.loss_fn(targets, logits, sample_weights)
        self.add_loss(endpoint_loss)

        endpoint_accuracy = self.accuracy_fn(targets, logits, sample_weights)
        self.add_metric(endpoint_accuracy, name="accuracy")

        # This is the tensor seen at inference time (for `.predict()`).
        return tf.nn.softmax(logits)

# The targets are fed in as a second model input so that the endpoint layer
# can compute the loss and the accuracy itself.
inputs = tf.keras.Input(shape=(3,), name="inputs")
targets = tf.keras.Input(shape=(10,), name="targets")
logits = tf.keras.layers.Dense(10)(inputs)
# Argument order must match `LogisticEndpoint.call(targets, logits)`;
# passing (logits, targets) would swap the roles of labels and predictions.
predictions = LogisticEndpoint(name="predictions")(targets, logits)
model = tf.keras.Model(inputs=[inputs, targets], outputs=predictions)

# No loss is passed to compile(): it is added inside the endpoint layer.
model.compile(optimizer="adam")

In the Functional API, you can also call 

    model.add_loss(loss_tensor)
    model.add_metric(metric_tensor, name, aggregation)

In [None]:
# Keep separate names (x1, x2) for the intermediate activations so that x1
# can be referenced after the model has been built.
inputs = tf.keras.Input(shape=(784,), name="digits")
x1 = layers.Dense(64, activation="relu", name="dense_1")(inputs)
x2 = layers.Dense(64, activation="relu", name="dense_2")(x1)
outputs = layers.Dense(10, name="predictions")(x2)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

# Attach an extra loss and a logged metric, both computed from x1.
model.add_loss(tf.reduce_sum(x1) * 0.1)
model.add_metric(tf.keras.backend.std(x1), name="std_of_activation", aggregation="mean")

# No `loss` argument is passed to compile() here.
model.compile(optimizer=tf.keras.optimizers.RMSprop(1e-3))

## Fit

Note that you can only use `validation_split` when training with **NumPy** data. 

In all other cases you need to prepare the validation dataset in advance.

In [None]:
# Split data manually: hold out the last 10,000 training samples for
# validation and keep the rest for training.
x_val = x_train[-10000:]
y_val = y_train[-10000:]
x_train = x_train[:-10000]
y_train = y_train[:-10000]

# Pass the held-out arrays explicitly instead of using `validation_split`.
history = model.fit(x_train, y_train, batch_size=64, epochs=2,
    validation_data=(x_val, y_val),
)

In [None]:
# Train on Dataset objects: shuffle with a 1024-element buffer, then batch.
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)

# The dataset already takes care of batching, so no `batch_size` is passed.
model.fit(train_dataset, epochs=3)

# The validation dataset is batched but not shuffled.
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_dataset.batch(64)
model.fit(train_dataset, epochs=1, validation_data=val_dataset)

Note that the Dataset is reset at the end of each epoch, so it can be reused for the next epoch.

If you want to run training only on a specific number of batches from this Dataset, you can pass the `steps_per_epoch` argument, which specifies how many training steps the model should run using this Dataset before moving on to the next epoch.

If you do this, the dataset is not reset at the end of each epoch, instead we just keep drawing the next batches. The dataset will eventually run out of data (unless it is an infinitely-looping dataset).

If you want to run validation only on a specific number of batches from this dataset, you can pass the `validation_steps` argument, which specifies how many validation steps the model should run with the validation dataset before interrupting validation and moving on to the next epoch.

Note that the validation dataset will be reset after each use (so that you will always be evaluating on the same samples from epoch to epoch).

In [None]:
# Only run 100 training steps (batches) before moving to the next epoch
model.fit(train_dataset, epochs=3, steps_per_epoch=100)

# Only run validation using the first 10 batches of the dataset
model.fit(train_dataset, epochs=1, validation_data=val_dataset,
    validation_steps=10,
)

If you have large datasets and you need to do a lot of custom Python-side processing that cannot be done in TensorFlow (e.g. if you rely on external libraries for data loading or preprocessing), you can use `tf.keras.utils.Sequence` class.

### Sample weighting and class weighting
With the default settings the weight of a sample is decided by its frequency in the dataset. There are two methods to weight the data, independent of sample frequency:

- Class weights
- Sample weights


In [None]:
# For classifiers: map each class index to its weight.
class_weight = {
    0: 1.0,
    1: 1.0,
    2: 1.0,
    3: 1.0,
    4: 1.0,
    5: 2.0,  # make this class 2x more important
    6: 1.0,
    7: 1.0,
    8: 1.0,
    9: 1.0,
}

model.fit(x_train, y_train, batch_size=64, epochs=1,
    class_weight=class_weight)

# For NumPy data: one weight per training sample; samples whose label is 5
# get weight 2.0, all others 1.0.
import numpy as np
sample_weight = np.ones(shape=(len(y_train),))
sample_weight[y_train == 5] = 2.0

model.fit(x_train, y_train, batch_size=64, epochs=1,
    sample_weight=sample_weight)

# For Dataset data: the weights travel as the third element of each tuple.
train_dataset = tf.data.Dataset.from_tensor_slices(
    (x_train, y_train, sample_weight)
)
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)

model.fit(train_dataset, epochs=1)


## Multiple inputs/outputs
When compiling a model with multiple inputs/outputs, you can assign different losses to each output. <br>
You can also assign different weights to each loss -- to modulate their contribution to the total training loss.

If we only passed a single loss function to the model, the same loss function would be applied to every output.

Using explicit names and dicts is recommended if you have more than 2 outputs.

In [None]:
# NOTE(review): this cell is an illustrative sketch, not runnable code. Names
# such as input2, output1..3, img_data, ts_data, title_data and the *_targets
# arrays are not defined anywhere in the visible notebook, and the `...`
# stands for the omitted model definition.
input1 = layers.Dense(1, name="priority")(x)
...
model = tf.keras.Model(inputs=[input1, input2], outputs=[output1, output2, output3])

# Losses and loss weights given as lists.
model.compile(
    optimizer=tf.keras.optimizers.RMSprop(1e-3),
    loss=[
        tf.keras.losses.BinaryCrossentropy(from_logits=True),
        tf.keras.losses.CategoricalCrossentropy(from_logits=True),
    ],
    loss_weights=[1.0, 0.2],
)

# or using names:
model.compile(
    optimizer=tf.keras.optimizers.RMSprop(1e-3),
    loss={
        "priority": tf.keras.losses.BinaryCrossentropy(from_logits=True),
        "department": tf.keras.losses.CategoricalCrossentropy(from_logits=True),
    },
    loss_weights={"priority": 1.0, "department": 0.2},
)

# Dataset built from a tuple of (inputs dict, targets dict)
train_dataset = tf.data.Dataset.from_tensor_slices(
    (
        {"img_input": img_data, "ts_input": ts_data},
        {"score_output": score_targets, "class_output": class_targets},
    )
)

# Fit: either a tuple of lists or a tuple of dictionaries
model.fit(
    [title_data, body_data, tags_data], 
    [priority_targets, dept_targets], 
    batch_size=32, epochs=1,
)

model.fit(
    {"title": title_data, "body": body_data, "tags": tags_data},
    {"priority": priority_targets, "department": dept_targets},
    epochs=2,
    batch_size=32,
)

# Dataset: same dict-based structure, then shuffled and batched
train_dataset = tf.data.Dataset.from_tensor_slices(
    (
        {"img_input": img_data, "ts_input": ts_data},
        {"score_output": score_targets, "class_output": class_targets},
    )
)
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)

model.fit(train_dataset, epochs=1)


same goes for metrics

In [None]:
# Losses and metrics are both given as dicts keyed by output name; an output
# may have several metrics.
model.compile(
    optimizer=tf.keras.optimizers.RMSprop(1e-3),
    loss={
        "score_output": tf.keras.losses.MeanSquaredError(),
        "class_output": tf.keras.losses.CategoricalCrossentropy(),
    },
    metrics={
        "score_output": [
            tf.keras.metrics.MeanAbsolutePercentageError(),
            tf.keras.metrics.MeanAbsoluteError(),
        ],
        "class_output": [tf.keras.metrics.CategoricalAccuracy()],
    },
)

You could also choose not to compute a loss for certain outputs, if these outputs are meant for prediction but not for training.

In [None]:
# List loss version: `None` in the first position means no loss is given
# for the first output.
model.compile(
    optimizer=tf.keras.optimizers.RMSprop(1e-3),
    loss=[None, tf.keras.losses.CategoricalCrossentropy()],
)

# Or dict loss version: only "class_output" is given a loss.
model.compile(
    optimizer=tf.keras.optimizers.RMSprop(1e-3),
    loss={"class_output": tf.keras.losses.CategoricalCrossentropy()},
)

## Callbacks
Callbacks are objects that are called at different points during training (at the start of an epoch, at the end of a batch, at the end of an epoch, etc.). 

    BaseLogger: Callback that accumulates epoch averages of metrics.
    CSVLogger: Callback that streams epoch results to a CSV file.
    Callback: Abstract base class used to build new callbacks.
    CallbackList: Container abstracting a list of callbacks.
    EarlyStopping: Stop training when a monitored metric has stopped improving.
    History: Callback that records events into a History object.
    LambdaCallback: Callback for creating simple, custom callbacks on-the-fly.
    LearningRateScheduler: Learning rate scheduler.
    ModelCheckpoint: Callback to save the Keras model or model weights at some frequency.
    ProgbarLogger: Callback that prints metrics to stdout.
    ReduceLROnPlateau: Reduce learning rate when a metric has stopped improving.
    RemoteMonitor: Callback used to stream events to a server.
    TensorBoard: Enable visualizations for TensorBoard.
    TerminateOnNaN: Callback that terminates training when a NaN loss is encountered.

In [None]:
# restart training from the last saved state if training gets interrupted
import os
def make_or_restore_model():
    """Return the most recently created checkpointed model, or a fresh one.

    Relies on two module-level names that must be defined before the call:
    ``checkpoint_dir`` (directory holding the checkpoints) and
    ``get_model()`` (factory for a new model).
    """
    # A missing directory simply means that nothing has been saved yet.
    checkpoints = []
    if os.path.isdir(checkpoint_dir):
        checkpoints = [
            os.path.join(checkpoint_dir, name) for name in os.listdir(checkpoint_dir)
        ]
    if checkpoints:
        # Pick the entry with the latest creation/metadata-change time.
        latest_checkpoint = max(checkpoints, key=os.path.getctime)
        print("Restoring from", latest_checkpoint)
        return tf.keras.models.load_model(latest_checkpoint)
    print("Creating a new model")
    return get_model()

callbacks = [
    # Stop training when `val_loss` is no longer improving 
    # (no better than 1e-2 less for at least 2 epochs)    
    tf.keras.callbacks.EarlyStopping(monitor="val_loss",
        min_delta=1e-2, patience=2, verbose=1),
    # Save a checkpoint only if the `val_loss` score has improved; the
    # epoch number is part of the file path.
    tf.keras.callbacks.ModelCheckpoint(filepath="mymodel_{epoch}",
        save_best_only=True, monitor="val_loss", verbose=1)
]

# Both callbacks monitor `val_loss`, so validation data is provided via
# `validation_split`.
model.fit(x_train, y_train, epochs=20, batch_size=64,
    callbacks=callbacks, validation_split=0.2)

### Learning decay schedule
The learning decay schedule could be static (fixed in advance, as a function of the current epoch or the current batch index), or dynamic (responding to the current behavior of the model, in particular the validation loss).

#### built-in schedules
    ExponentialDecay
    PiecewiseConstantDecay
    PolynomialDecay
    InverseTimeDecay

In [None]:
# Static learning rate decay schedule: a schedule object is passed to the
# optimizer in place of a constant learning rate.
initial_learning_rate = 0.1
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate, decay_steps=100000, decay_rate=0.96, staircase=True
)

optimizer = tf.keras.optimizers.RMSprop(learning_rate=lr_schedule)

The optimizer does not have access to validation metrics --> you can't use schedule objects for dynamic learning decay --> use callbacks (eg. `ReduceLROnPlateau` callback)

### TensorBoard
The best way to keep an eye on your model during training is to use TensorBoard -- a browser-based application that you can run locally:

    tensorboard --logdir=/full_path_to_your_logs


In [None]:
# Callback that writes logs for TensorBoard; pass it in `callbacks=` of fit().
tf.keras.callbacks.TensorBoard(
    log_dir="/full_path_to_your_logs",
    histogram_freq=0,  # How often to log histogram visualizations
    embeddings_freq=0,  # How often to log embedding visualizations
    update_freq="epoch",  # How often to write logs (default: once per epoch)
)


### LambdaCallback
Arguments:

-    `on_epoch_begin` and `on_epoch_end` expect two positional arguments: `epoch`, `logs`
-    `on_batch_begin` and `on_batch_end` expect two positional arguments: `batch`, `logs`
-    `on_train_begin` and `on_train_end` expect one positional argument: `logs`


In [None]:
# Print the batch number at the beginning of every batch.
batch_print_callback = tf.keras.callbacks.LambdaCallback(
    on_batch_begin=lambda batch,logs: print(batch))

# Stream the epoch loss to a file in JSON format. The file content
# is not well-formed JSON but rather has a JSON object per line.
# The file is opened line-buffered here and closed by the `on_train_end`
# hook, so it stays open for the duration of training.
import json
json_log = open('loss_log.json', mode='wt', buffering=1)
json_logging_callback = tf.keras.callbacks.LambdaCallback(
    on_epoch_end=lambda epoch, logs: json_log.write(
        json.dumps({'epoch': epoch, 'loss': logs['loss']}) + '\n'),
    on_train_end=lambda logs: json_log.close()
)

# Terminate some processes after having finished model training.
# `processes = ...` is a placeholder for a collection of process objects.
processes = ...
cleanup_callback = tf.keras.callbacks.LambdaCallback(
    on_train_end=lambda logs: [
        p.terminate() for p in processes if p.is_alive()])

# The leading `...` is a placeholder for the training data arguments.
model.fit(...,
          callbacks=[batch_print_callback,
                     json_logging_callback,
                     cleanup_callback])
                     
# could be an actual function instead of lambda

### Custom callback

In [None]:
class LossHistory(tf.keras.callbacks.Callback):
    """Callback that records the loss reported at the end of every batch.

    The recorded values are available in ``per_batch_losses`` (one entry per
    batch; ``None`` when no loss was reported).
    """

    def on_train_begin(self, logs=None):
        # Start from an empty list every time training starts.
        self.per_batch_losses = []

    def on_batch_end(self, batch, logs=None):
        # `logs` defaults to None, matching the base-class hook signature.
        logs = logs or {}
        self.per_batch_losses.append(logs.get("loss"))

## Customizing training
- override the training step function of the Model class. This is the function that is called by `fit()` for every batch of data. You will then be able to call `fit()` as usual and benefit from its convenient features, such as callbacks, built-in distribution support, or step fusing.
- write your own training loop from scratch, using the `GradientTape` (more control)

### Overriding the training step

In [None]:
class CustomModel(tf.keras.Model):
    """Model whose per-batch training logic is written out by hand."""

    def train_step(self, data):
        """Run one optimization step on `data` and return the metric values."""
        # `data` is either (x, y) or (x, y, sample_weight), depending on
        # what was passed to `fit()`.
        sample_weight = None
        if len(data) == 3:
            x, y, sample_weight = data
        else:
            x, y = data

        with tf.GradientTape() as tape:
            # Forward pass, then the compiled loss including layer losses.
            y_pred = self(x, training=True)
            loss = self.compiled_loss(
                y,
                y_pred,
                sample_weight=sample_weight,
                regularization_losses=self.losses,
            )

        variables = self.trainable_variables
        grads = tape.gradient(loss, variables)
        self.optimizer.apply_gradients(zip(grads, variables))
        self.compiled_metrics.update_state(y, y_pred, sample_weight=sample_weight)

        # One entry per metric in self.metrics (the loss is tracked there too).
        return {metric.name: metric.result() for metric in self.metrics}


# Construct and compile an instance of CustomModel
inputs = tf.keras.Input(shape=(32,))
outputs = tf.keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])

# You can now use sample_weight argument
# (random data: 1000 samples with 32 features, one target and one weight each)
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
sw = np.random.random((1000, 1))
model.fit(x, y, sample_weight=sw, epochs=3)


You could also just skip passing a loss function in compile(), and instead do everything manually in train_step. Likewise for metrics.

In [None]:
# Module-level metric objects used by the train_step below.
loss_tracker = tf.keras.metrics.Mean(name="loss")
mae_metric = tf.keras.metrics.MeanAbsoluteError(name="mae")


class CustomModel(tf.keras.Model):
    """Model that computes its loss and metrics manually in train_step."""

    def train_step(self, data):
        """Run one optimization step and return the running loss and MAE."""
        x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = tf.keras.losses.mean_squared_error(y, y_pred)

        variables = self.trainable_variables
        grads = tape.gradient(loss, variables)
        self.optimizer.apply_gradients(zip(grads, variables))

        # Update the trackers before reading their current values.
        loss_tracker.update_state(loss)
        mae_metric.update_state(y, y_pred)
        logs = {}
        logs["loss"] = loss_tracker.result()
        logs["mae"] = mae_metric.result()
        return logs

    @property
    def metrics(self):
        # Listing the `Metric` objects here lets their state be reset
        # automatically at the start of each epoch or of `evaluate()`.
        # Without this property you have to reset them yourself at the
        # time of your choosing.
        return [loss_tracker, mae_metric]


# Construct an instance of CustomModel
inputs = tf.keras.Input(shape=(32,))
outputs = tf.keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)

# We don't pass a loss or metrics here.
model.compile(optimizer="adam")

# Just use `fit` as usual -- you can use callbacks, etc.
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=5)


### Customizing evaluation

In [None]:
class CustomModel(tf.keras.Model):
    """Model whose per-batch evaluation logic is written out by hand."""

    def test_step(self, data):
        """Evaluate one batch and return the current metric values."""
        x, y = data
        # Inference-mode forward pass.
        y_pred = self(x, training=False)
        # Calling the compiled loss updates the metrics tracking the loss.
        self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        self.compiled_metrics.update_state(y, y_pred)
        # One entry per metric in self.metrics (the loss is tracked there too).
        return {metric.name: metric.result() for metric in self.metrics}


# Build the model; no optimizer is passed to compile() in this cell and only
# evaluate() is called.
inputs = tf.keras.Input(shape=(32,))
outputs = tf.keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(loss="mse", metrics=["mae"])

x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.evaluate(x, y)


### end-to-end GAN example

In [None]:
# Discriminator: takes 28x28x1 images, applies two stride-2 convolutions,
# global max pooling, and ends in a single unit without activation.
discriminator = tf.keras.Sequential(
    [
        tf.keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.GlobalMaxPooling2D(),
        layers.Dense(1),
    ],
    name="discriminator",
)

# Generator: maps a latent vector to a 7x7x128 tensor, applies two stride-2
# transposed convolutions, and ends in a 1-channel sigmoid convolution.
latent_dim = 128
generator = tf.keras.Sequential(
    [
        tf.keras.Input(shape=(latent_dim,)),
        layers.Dense(7 * 7 * 128),
        layers.LeakyReLU(alpha=0.2),
        layers.Reshape((7, 7, 128)),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid"),
    ],
    name="generator",
)

class GAN(tf.keras.Model):
    """Model wrapping a generator and a discriminator trained in one `train_step`.

    In the labels built below, 1 marks generated images and 0 marks real ones.
    """

    def __init__(self, discriminator, generator, latent_dim):
        super(GAN, self).__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim

    def compile(self, d_optimizer, g_optimizer, loss_fn):   # override
        """Store one optimizer per sub-model and the shared loss function."""
        super(GAN, self).compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, real_images):
        """Update the discriminator, then the generator; return both losses."""
        # When the batch arrives as a tuple, only its first element is used.
        if isinstance(real_images, tuple):
            real_images = real_images[0]

        # Sample random points in the latent space
        batch_size = tf.shape(real_images)[0]
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))

        # Decode them to fake images
        generated_images = self.generator(random_latent_vectors)

        # Combine them with real images
        combined_images = tf.concat([generated_images, real_images], axis=0)

        # Assemble labels discriminating real from fake images
        # (ones for the generated half, zeros for the real half)
        labels = tf.concat(
            [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0
        )
        # Add random noise to the labels - important trick!
        labels += 0.05 * tf.random.uniform(tf.shape(labels))

        # Train the discriminator
        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images)
            d_loss = self.loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_weights)
        )

        # Sample random points in the latent space
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))

        # Assemble labels that say "all real images"
        misleading_labels = tf.zeros((batch_size, 1))

        # Train the generator 
        # (note that we should *not* update the weights of the discriminator)!
        # Gradients are taken only w.r.t. the generator's weights.
        with tf.GradientTape() as tape:
            predictions = self.discriminator(self.generator(random_latent_vectors))
            g_loss = self.loss_fn(misleading_labels, predictions)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))
        return {"d_loss": d_loss, "g_loss": g_loss}
    
# Use every MNIST image (train and test splits together); labels are discarded.
batch_size = 64
(x_train, _), (x_test, _) = tf.keras.datasets.mnist.load_data()
all_digits = np.concatenate([x_train, x_test])
all_digits = all_digits.astype("float32") / 255.0
# Add a trailing channel axis to match the discriminator input (28, 28, 1).
all_digits = np.reshape(all_digits, (-1, 28, 28, 1))
dataset = tf.data.Dataset.from_tensor_slices(all_digits)
dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)

gan = GAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim)
gan.compile(
    d_optimizer=tf.keras.optimizers.Adam(learning_rate=0.0003),
    g_optimizer=tf.keras.optimizers.Adam(learning_rate=0.0003),
    loss_fn=tf.keras.losses.BinaryCrossentropy(from_logits=True),
)

# Train on the first 100 batches only.
gan.fit(dataset.take(100), epochs=1)


### Writing a training loop from scratch

In [None]:
import time

# Uses `model`, `optimizer`, `train_dataset` and `val_dataset` defined in
# earlier cells.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
train_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
val_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()

epochs = 2
for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))
    start_time = time.time()
    # Count the samples actually drawn instead of relying on a `batch_size`
    # global that is only defined in an unrelated cell.
    samples_seen = 0

    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):

        # Record the operations run during the forward pass
        with tf.GradientTape() as tape:
            # Run the forward pass of the model.
            logits = model(x_batch_train, training=True)  # Logits for this minibatch

            # Compute the loss value for this minibatch.
            loss_value = loss_fn(y_batch_train, logits)

        # Retrieve the gradients of the trainable variables with respect to the loss.
        grads = tape.gradient(loss_value, model.trainable_weights)

        # Run 1 step of gradient descent
        optimizer.apply_gradients(zip(grads, model.trainable_weights))

        # Update training metric.
        train_acc_metric.update_state(y_batch_train, logits)
        samples_seen += int(tf.shape(x_batch_train)[0])

        # Log every 200 batches.
        if step % 200 == 0:
            print(
                "Training loss (for one batch) at step %d: %.4f"
                % (step, float(loss_value))
            )
            print("Seen so far: %s samples" % samples_seen)

    # Display metrics at the end of each epoch.
    train_acc = train_acc_metric.result()
    print("Training acc over epoch: %.4f" % (float(train_acc),))

    # Reset training metrics at the end of each epoch
    # (`reset_state` is the method name this notebook's own metric implements).
    train_acc_metric.reset_state()

    # Run a validation loop at the end of each epoch.
    for x_batch_val, y_batch_val in val_dataset:
        val_logits = model(x_batch_val, training=False)
        # Update val metrics
        val_acc_metric.update_state(y_batch_val, val_logits)
    val_acc = val_acc_metric.result()
    val_acc_metric.reset_state()
    print("Validation acc: %.4f" % (float(val_acc),))
    print("Time taken: %.2fs" % (time.time() - start_time))


### Speeding-up your training step
The default runtime in TensorFlow 2 is eager execution. As such, our training loop above executes eagerly.

This is great for debugging, but graph compilation has a definite performance advantage. Describing your computation as a static graph enables the framework to apply global performance optimizations.

You can compile into a static graph any function that takes tensors as input. Just add a `@tf.function` decorator on it. Same goes for evaluation.

In [None]:
@tf.function
def train_step(x, y):
    """Run one optimization step on the batch (x, y) and return its loss."""
    with tf.GradientTape() as tape:
        batch_logits = model(x, training=True)
        batch_loss = loss_fn(y, batch_logits)
    weights = model.trainable_weights
    gradients = tape.gradient(batch_loss, weights)
    optimizer.apply_gradients(zip(gradients, weights))
    train_acc_metric.update_state(y, batch_logits)
    return batch_loss

@tf.function
def test_step(x, y):
    """Update the validation accuracy metric with the batch (x, y)."""
    batch_logits = model(x, training=False)
    val_acc_metric.update_state(y, batch_logits)

for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))
    # Take the timestamp inside the loop so "Time taken" covers this epoch only.
    start_time = time.time()

    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        loss_value = train_step(x_batch_train, y_batch_train)

        # Log every 200 batches.
        if step % 200 == 0:
            print(
                "Training loss (for one batch) at step %d: %.4f"
                % (step, float(loss_value))
            )

    # `train_step` updates the training metric, so report and reset it here;
    # otherwise it would keep accumulating across epochs.
    train_acc = train_acc_metric.result()
    print("Training acc over epoch: %.4f" % (float(train_acc),))
    train_acc_metric.reset_state()

    for x_batch_val, y_batch_val in val_dataset:
        test_step(x_batch_val, y_batch_val)

    val_acc = val_acc_metric.result()
    val_acc_metric.reset_state()
    print("Validation acc: %.4f" % (float(val_acc),))
    print("Time taken: %.2fs" % (time.time() - start_time))


### Low-level handling of losses tracked by the model
Layers & models recursively track any losses created during the forward pass by layers that call self.add_loss(value). The resulting list of scalar loss values are available via the property model.losses at the end of the forward pass.

If you want to use these loss components, you should sum them and add them to the main loss in your training step.

In [None]:
class ActivityRegularizationLayer(layers.Layer):
    """Pass-through layer that registers a loss term from its inputs.

    On every call it adds `1e-2 * sum(inputs)` through `add_loss` and
    returns the inputs unchanged.
    """

    def call(self, inputs):
        penalty = 1e-2 * tf.reduce_sum(inputs)
        self.add_loss(penalty)
        return inputs

# Build a functional model with the activity-regularization layer inserted
# after the first hidden layer. The `tf.keras` prefix is used instead of a
# bare `keras` name so this cell relies only on `tf` and `layers`.
inputs = tf.keras.Input(shape=(784,), name="digits")
x = layers.Dense(64, activation="relu")(inputs)
# Insert activity regularization as a layer
x = ActivityRegularizationLayer()(x)
x = layers.Dense(64, activation="relu")(x)
# No activation is set on the final layer.
outputs = layers.Dense(10, name="predictions")(x)

model = tf.keras.Model(inputs=inputs, outputs=outputs)

@tf.function
def train_step(x, y):
    """Run one optimization step that also counts model-tracked losses.

    The value being minimized is `loss_fn(y, outputs)` plus the sum of
    `model.losses` collected during this forward pass. Gradients are taken
    with respect to the model's trainable weights and applied through
    `optimizer`; `train_acc_metric` is updated with the batch.

    Returns the combined loss value for this batch.
    """
    with tf.GradientTape() as recorder:
        outputs = model(x, training=True)
        # Main loss plus any extra losses created during the forward pass.
        total_loss = loss_fn(y, outputs) + sum(model.losses)

    weights = model.trainable_weights
    gradients = recorder.gradient(total_loss, weights)
    optimizer.apply_gradients(zip(gradients, weights))

    train_acc_metric.update_state(y, outputs)
    return total_loss

Training for the above GAN example:

In [None]:
# Binary cross-entropy configured with from_logits=True.
loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)

# Two independent Adam optimizers: `d_optimizer` for the discriminator,
# `g_optimizer` for the generator.
d_optimizer = tf.keras.optimizers.Adam(learning_rate=3e-4)
g_optimizer = tf.keras.optimizers.Adam(learning_rate=4e-4)


@tf.function
def train_step(real_images):
    """Run one adversarial update: discriminator first, then generator.

    Label convention used below: 1 marks a generated image, 0 a real one.

    Returns a `(d_loss, g_loss, generated_images)` tuple, where
    `generated_images` are the fakes produced for the discriminator update.
    """
    # --- Discriminator update ---
    # Decode random latent points into fake images.
    noise = tf.random.normal(shape=(batch_size, latent_dim))
    generated_images = generator(noise)

    # Fakes go first, real images second; the labels follow the same order.
    combined_images = tf.concat([generated_images, real_images], axis=0)
    fake_labels = tf.ones((batch_size, 1))
    real_labels = tf.zeros((real_images.shape[0], 1))
    labels = tf.concat([fake_labels, real_labels], axis=0)
    # Add random noise to the labels - important trick!
    labels += 0.05 * tf.random.uniform(labels.shape)

    with tf.GradientTape() as d_tape:
        d_loss = loss_fn(labels, discriminator(combined_images))
    d_weights = discriminator.trainable_weights
    d_grads = d_tape.gradient(d_loss, d_weights)
    d_optimizer.apply_gradients(zip(d_grads, d_weights))

    # --- Generator update ---
    # Draw a fresh set of latent points.
    noise = tf.random.normal(shape=(batch_size, latent_dim))
    # All-zero labels, i.e. every generated image is labelled "real".
    misleading_labels = tf.zeros((batch_size, 1))

    with tf.GradientTape() as g_tape:
        g_loss = loss_fn(misleading_labels, discriminator(generator(noise)))
    # Only the generator's weights are updated here; the discriminator's
    # weights are left untouched in this half of the step.
    g_weights = generator.trainable_weights
    g_grads = g_tape.gradient(g_loss, g_weights)
    g_optimizer.apply_gradients(zip(g_grads, g_weights))

    return d_loss, g_loss, generated_images

# Build the input pipeline from the MNIST training and test images together
# (labels are discarded).
batch_size = 64
(x_train, _), (x_test, _) = tf.keras.datasets.mnist.load_data()
# Cast to float32, divide by 255, and reshape to (N, 28, 28, 1).
all_digits = (
    np.concatenate([x_train, x_test]).astype("float32") / 255.0
).reshape(-1, 28, 28, 1)
dataset = (
    tf.data.Dataset.from_tensor_slices(all_digits)
    .shuffle(buffer_size=1024)
    .batch(batch_size)
)

epochs = 1  # In practice you need at least 20 epochs to generate nice digits.
save_dir = "./"

for epoch in range(epochs):
    print("\nStart epoch", epoch)

    for step, real_images in enumerate(dataset):
        # One combined discriminator + generator update on this batch.
        d_loss, g_loss, generated_images = train_step(real_images)

        # On every 200th step (step 0 included): report both losses and
        # write the first generated image of the batch to disk.
        if step % 200 == 0:
            print("discriminator loss at step %d: %.2f" % (step, d_loss))
            print("adversarial loss at step %d: %.2f" % (step, g_loss))

            sample_path = os.path.join(
                save_dir, "generated_img" + str(step) + ".png"
            )
            img = tf.keras.preprocessing.image.array_to_img(
                generated_images[0] * 255.0, scale=False
            )
            img.save(sample_path)

        # Demo-only early exit to limit execution time: the loop is left as
        # soon as `step` exceeds 10.
        # Remove the lines below to actually train the model!
        if step > 10:
            break
