# Custom Training Loop using TensorFlow Keras: Distributed Training

In this notebook, we describe the steps for creating a custom training loop using TensorFlow Keras for **distributed training**. Specifically, we use the **synchronous data parallelism** technique, where a single model is replicated on multiple devices or multiple machines (e.g., multiple GPUs). Each device processes a different slice of every batch of data, and then the replicas merge their results (the gradients). Thus, the replicas of the model stay in sync after each batch. This synchronization keeps the convergence behavior of the model equivalent to what we would see for single-device training with the same global batch size.
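To see why synchronous data parallelism matches single-device training, consider the gradient of one global batch. The sketch below is a framework-free illustration in plain Python, not TensorFlow code. It uses a hypothetical 1-D linear model `y_hat = w * x` with a squared-error loss and a global batch of 8 examples split evenly across simulated replicas. Each replica sums its per-example gradients and divides by the *global* batch size. Summing those per-replica results then reproduces the full-batch gradient.

```python
# Plain-Python illustration (no TensorFlow): one gradient computed on a single
# device equals the sum of per-replica gradients when every replica divides its
# per-example sum by the GLOBAL batch size.

def grad_sum(w, xs, ys):
    """Sum over examples of d/dw (w*x - y)**2 = 2*x*(w*x - y)."""
    return sum(2.0 * x * (w * x - y) for x, y in zip(xs, ys))

def single_device_grad(w, xs, ys):
    # Mean gradient over the whole global batch on one device.
    return grad_sum(w, xs, ys) / len(xs)

def data_parallel_grad(w, xs, ys, num_replicas):
    global_batch = len(xs)
    assert global_batch % num_replicas == 0, "the batch must split evenly"
    shard = global_batch // num_replicas
    per_replica = []
    for r in range(num_replicas):
        xs_r = xs[r * shard:(r + 1) * shard]
        ys_r = ys[r * shard:(r + 1) * shard]
        # Each replica scales by the GLOBAL batch size, not by its shard size.
        per_replica.append(grad_sum(w, xs_r, ys_r) / global_batch)
    # Stand-in for the all-reduce: sum the per-replica gradients.
    return sum(per_replica)

xs = [float(i) for i in range(8)]
ys = [3.0 * x + 1.0 for x in xs]
w = 0.5
print(single_device_grad(w, xs, ys))
print(data_parallel_grad(w, xs, ys, num_replicas=4))
```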

### Single-Host, Multi-Device Training
For implementing the synchronous data parallelism-based distributed training, we use the tf.distribute.MirroredStrategy API to train TensorFlow Keras models on multiple GPUs installed on a single machine.

    -- The number and type of GPUs should be requested in the SLURM job script (.sh file).


### How does the tf.distribute.MirroredStrategy Strategy Work?

- All the variables and the model graph are replicated across the replicas (e.g., GPUs).
- Input data is evenly distributed across the replicas.
- Each replica calculates the loss and gradients for the input it received.
- The gradients are synced across all the replicas by summing them.
- After the sync, the same update is made to the copies of the variables on each replica.
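The five steps above can be simulated without TensorFlow. This plain-Python sketch is an illustration under simplifying assumptions, not how MirroredStrategy is implemented. Each simulated replica keeps its own copy of a single weight `w` of a hypothetical linear model. On every step, the global batch is split evenly, the per-replica gradients are summed (standing in for the all-reduce), and every copy receives the same update. The copies therefore never diverge, and they follow the same trajectory as a one-replica run.

```python
# Plain-Python simulation of the MirroredStrategy steps (illustration only).

def replica_grad(w, xs, ys, global_batch):
    # Gradient of the squared error for y_hat = w * x on this replica's shard,
    # scaled by the GLOBAL batch size.
    return sum(2.0 * x * (w * x - y) for x, y in zip(xs, ys)) / global_batch

def mirrored_train(xs, ys, num_replicas, steps, lr=0.01, w0=0.0):
    weights = [w0] * num_replicas                  # 1. replicate the variable
    global_batch = len(xs)
    shard = global_batch // num_replicas
    for _ in range(steps):
        grads = []
        for r in range(num_replicas):              # 2. split the input evenly
            sl = slice(r * shard, (r + 1) * shard)
            # 3. each replica computes its own gradient with its own copy of w
            grads.append(replica_grad(weights[r], xs[sl], ys[sl], global_batch))
        total = sum(grads)                         # 4. sync by summing (all-reduce)
        weights = [w - lr * total for w in weights]  # 5. same update on every copy
    return weights

xs = [float(i) / 4 for i in range(16)]
ys = [2.0 * x for x in xs]
mirrored = mirrored_train(xs, ys, num_replicas=4, steps=50)
single = mirrored_train(xs, ys, num_replicas=1, steps=50)
print(mirrored, single)
```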

### A Summary of the Steps for Single-host, Multi-device Synchronous Training

- Instantiate a MirroredStrategy. By default, the strategy will use all GPUs available on a single machine.

- Create a tf.data.Dataset which is required to load data in a multi-device or distributed workflow. Use tf.distribute.Strategy.experimental_distribute_dataset to convert the tf.data.Dataset to something that produces "per-replica" values. 

- Use the strategy object to open a scope, and within this scope, create all the Keras objects that contain variables. These include the model, optimizer, checkpoint, loss functions, and metrics. 

- Define functions to train and validate the model on a single batch of data. Then, use the tf.distribute.Strategy.run method of the mirrored strategy to run these functions once per replica. The run method takes "per-replica" values (e.g., from a tf.distribute.DistributedDataset object) and returns "per-replica" values. It is executed in the "replica context", which means each operation is performed separately on each replica.

- Finally, use a method (such as tf.distribute.Strategy.reduce) to convert the resulting "per-replica" values into ordinary Tensors.


More information on distributed training: https://keras.io/guides/distributed_training/

More information on distributed training for custom training loop: https://www.tensorflow.org/tutorials/distribute/custom_training



## Summary of the Other Techniques/Tools Used in this Notebook

In addition to describing the design of a custom training loop for distributed training, we use the following techniques/tools that are useful in practical deep learning tasks.

- Store and pre-process the data in the TensorFlow Dataset format.

- Build a model by defining a custom layer.

- Utilize various optimizers and schedulers.

- Model serialization.

        -- Serialize the final model (with custom layers) in TensorFlow's SavedModel format. 
        -- Serialize the intermediate model checkpoints (only the parameters of the model) using the tf.train.Checkpoint class. Customize the checkpoint object.

- Loading the saved model.
        
        -- Final SavedModel
        -- Intermediate model checkpoints

- Monitor the Training Process

        -- Use comet ml for monitoring the training in real time.

In [None]:
# Import comet_ml at the top of your file.
from comet_ml import Experiment

# Create an experiment with your api key.
experiment = Experiment(
    api_key="",
    project_name="",
    workspace="",
    auto_histogram_weight_logging=True,
    auto_histogram_gradient_logging=True,
    auto_histogram_activation_logging=True
)

In [None]:
# Import python libraries.
import os
import time
import numpy as np

from sklearn.model_selection import train_test_split

import torch # torch is used to get GPU information.
import tensorflow as tf
import tensorflow_addons as tfa # required for some optimizers, e.g., AdamW, LAMB.

# Display the TensorFlow version.
print("\nTensorFlow Version: " , tf.__version__)

In [None]:
# Variable to store the number of available GPUs.
num_of_gpu = 0

# Determine the number of GPUs available.
if torch.cuda.is_available():    
    # Tell torch to use the GPU.    
    device = torch.device("cuda")
    
    # Get the number of GPUs.
    num_of_gpu = torch.cuda.device_count()
    print("Number of available GPU(s) %d." % num_of_gpu)

    print("GPU Name: ", torch.cuda.get_device_name(0))
else:
    print("No GPU available, using the CPU")
    device = torch.device("cpu")

In [None]:
'''
Create a MirroredStrategy.
'''
mirrored_strategy = tf.distribute.MirroredStrategy()
print("\nNumber of replicas (GPUs): {}".format(mirrored_strategy.num_replicas_in_sync))

## Load & Scale the Dataset

In [None]:
'''
Load the dataset (training & test).
'''
(X_train_all, y_train_all), (X_test, y_test) = tf.keras.datasets.cifar10.load_data()


# Convert datatype to float32.
X_train_all = X_train_all.astype('float32')
X_test = X_test.astype('float32')

# Create a validation subset.
X_train, X_val, y_train, y_val = train_test_split(X_train_all, y_train_all, test_size=0.15, random_state=42)

# Scale the data.
X_train = X_train/255.0
X_val = X_val/255.0
X_test = X_test/255.0

print("No. of Training Samples: ", X_train.shape)
print("No. of Training Labels: ", y_train.shape)

print("\nNo. of Validation Samples: ", X_val.shape)
print("No. of Validation Labels: ", y_val.shape)

print("\nNo. of Testing Samples: ", X_test.shape)
print("No. of Testing Labels: ", y_test.shape)

print("\nX type: ", X_train.dtype)
print("y type: ", y_train.dtype)


# Create one-hot encoded labels.
y_train = tf.keras.utils.to_categorical(y_train, num_classes=10)
y_val = tf.keras.utils.to_categorical(y_val, num_classes=10)
y_test = tf.keras.utils.to_categorical(y_test, num_classes=10)

print("\ny train (shape): ", y_train.shape)
print("y val (shape): ", y_val.shape)
print("y test (shape): ", y_test.shape)

## Create TensorFlow Dataset Objects

In [None]:
print("\n########################## Create TensorFlow Dataset objects ##########################\n\n")

'''
For loading the data dynamically and pre-processing (if required), 
we use TensorFlow's Data API (tf.data).
Specifically, from the NumPy arrays (feature & label), 
we construct TensorFlow Dataset (TensorSliceDataset) objects.
'''
train_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train))
val_dataset = tf.data.Dataset.from_tensor_slices((X_val, y_val))
test_dataset = tf.data.Dataset.from_tensor_slices((X_test, y_test))

# cardinality() returns the number of elements without iterating over (and materializing) the whole dataset.
train_dataset_size = train_dataset.cardinality().numpy()
print("\nTrain Dataset (shape of a single sample and its type):\n", train_dataset)
print("Train Dataset (size): ", train_dataset_size)

val_dataset_size = val_dataset.cardinality().numpy()
print("\nValidation Dataset (shape of a single sample and its type):\n", val_dataset)
print("Validation Dataset (size): ", val_dataset_size)

test_dataset_size = test_dataset.cardinality().numpy()
print("\nTest Dataset (shape of a single sample and its type):\n", test_dataset)
print("Test Dataset (size): ", test_dataset_size)

## Dataset Object Pre-Processing

In [None]:
'''
Function to shuffle & batch the elements of the Dataset object.

- The shuffle method randomly shuffles the elements of the Dataset object. 
First, it fills a buffer with the Dataset with buffer_size elements. 
Then, it randomly samples elements from this buffer, replacing the selected elements with new elements. 
For perfect shuffling, the buffer_size should be greater than or equal to the size of the Dataset. 
However, for large Dataset objects, this isn't possible. So, we will use a large enough buffer_size.


- In the batch method, we set "drop_remainder" to True so that every batch contains exactly 
   batch_size elements; the final partial batch (if any) is dropped.

'''
def prepare_dataset(ds, mini_batch, repeat=1, shuffle=False, buffer_size=0):
    '''
    Cache the Dataset elements in memory
    '''
    ds = ds.cache()
  
    '''
    Shuffle the elements of Dataset
    '''
    if shuffle:
        ds = ds.shuffle(buffer_size)
      
    '''
    Repeat the elements of the shuffled Dataset
    '''  
    ds = ds.repeat(count=repeat)

    '''
    Batch the elements of the Dataset
    '''
    ds = ds.batch(mini_batch, drop_remainder=True)
    
    '''
    Use buffered prefetching on all elements of the Dataset
    '''
    return ds.prefetch(buffer_size=tf.data.AUTOTUNE)
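The shuffle-buffer and drop_remainder behavior described in the docstring above can be sketched in plain Python. This is a simulation of the described algorithm, not tf.data's implementation; `buffered_shuffle` and `batch` are hypothetical helper names. A buffer of size 1 leaves the order unchanged. A buffer at least as large as the dataset can produce any permutation.

```python
import random

def buffered_shuffle(elements, buffer_size, seed=0):
    """Simulate a shuffle buffer: fill it, sample a random slot, refill that slot."""
    if buffer_size < 1:
        raise ValueError("buffer_size must be at least 1")
    rng = random.Random(seed)
    it = iter(elements)
    buffer = []
    for _ in range(buffer_size):              # fill the buffer
        try:
            buffer.append(next(it))
        except StopIteration:
            break
    out = []
    while buffer:
        i = rng.randrange(len(buffer))        # pick a random buffered element
        out.append(buffer[i])
        try:
            buffer[i] = next(it)              # replace it with the next input element
        except StopIteration:
            buffer.pop(i)                     # input exhausted: shrink the buffer
    return out

def batch(elements, batch_size, drop_remainder=True):
    batches = [elements[i:i + batch_size] for i in range(0, len(elements), batch_size)]
    if drop_remainder and batches and len(batches[-1]) < batch_size:
        batches.pop()                         # discard the final partial batch
    return batches

data = list(range(10))
print(buffered_shuffle(data, buffer_size=1))   # buffer of 1: order unchanged
print(buffered_shuffle(data, buffer_size=10))  # buffer >= dataset size: full shuffle
print(batch(data, 4))                          # last partial batch [8, 9] dropped
```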

In [None]:

print("\n########################## Dataset Object Pre-Processing ##########################\n\n")

'''
Determine the GLOBAL_BATCH_SIZE (the total size of the mini-batch across all GPUs) for training.
It is a multiple of BATCH_SIZE_PER_REPLICA.
The multiplication factor is the number of replicas (mirrored_strategy.num_replicas_in_sync),
which equals the number of GPUs used by the strategy, or 1 when only the CPU is available.
'''

BATCH_SIZE_PER_REPLICA = 128

GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync


'''
Size of the validation mini-batch. 
'''
size_of_mini_batch_val = GLOBAL_BATCH_SIZE

print("\nSize of training mini-batch: ", GLOBAL_BATCH_SIZE)
print("Size of validation mini-batch: ", size_of_mini_batch_val)


'''
The "buffer_size" variable is used by the "shuffle" method. 
For a small Dataset object, it should be equal to or larger than the training set.
'''
buffer_size = train_dataset_size
print("\nBuffer Size: ", buffer_size)

'''
Set the number of training epochs.
This is required by the repeat method.
'''
no_of_epochs = 100
print("\nEpochs: ", no_of_epochs)

'''
Perform data pre-processing on the CPU.
Placing the input pipeline on the CPU keeps the GPUs free for model training.
'''
with tf.device('/cpu:0'):
    train_loader = prepare_dataset(train_dataset, GLOBAL_BATCH_SIZE, repeat=1, 
                                   shuffle=True, buffer_size=buffer_size)
    val_loader = prepare_dataset(val_dataset, size_of_mini_batch_val)
    test_loader = prepare_dataset(test_dataset, size_of_mini_batch_val)
    no_of_steps_per_epoch = train_dataset_size//GLOBAL_BATCH_SIZE
    print("Number of steps/epoch: ", no_of_steps_per_epoch)
    total_no_of_steps = train_loader.cardinality().numpy()
    print("Data available to run for %d epochs" % (total_no_of_steps//no_of_steps_per_epoch))
    print("Unlimited data available: ", (train_loader.cardinality() == tf.data.INFINITE_CARDINALITY).numpy())
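The batch bookkeeping in the cell above is simple arithmetic. The sketch below assumes a training set of 42,500 images, which is the 85% of CIFAR-10's 50,000 training images kept by the 85/15 split. It also assumes a per-replica batch of 128, and the replica counts shown are examples. Because drop_remainder=True, each epoch runs only full batches and skips the leftover examples.

```python
# Batch-size arithmetic sketch; PER_REPLICA and TRAIN_SIZE are assumptions
# matching this notebook's settings (128 per replica, 42,500 training images).
PER_REPLICA = 128
TRAIN_SIZE = 42_500

def batch_plan(num_replicas, per_replica=PER_REPLICA, n=TRAIN_SIZE):
    global_batch = per_replica * num_replicas
    steps_per_epoch = n // global_batch             # full batches only (drop_remainder=True)
    dropped = n - steps_per_epoch * global_batch    # examples skipped in each epoch
    return global_batch, steps_per_epoch, dropped

for replicas in (1, 2, 4, 8):
    print(replicas, batch_plan(replicas))
```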

## Distribute the Dataset Objects Over the GPUs

In [None]:
'''
Distribute the Dataset objects based on the mirrored strategy by using the 
tf.distribute.Strategy.experimental_distribute_dataset method.

This method converts a tf.data.Dataset to something that produces "per-replica" values. 

Alternatively, we can manually specify how the dataset should be partitioned across replicas. 
For this, we may use the tf.distribute.Strategy.distribute_datasets_from_function.
'''
dist_train_loader = mirrored_strategy.experimental_distribute_dataset(train_loader)
dist_val_loader = mirrored_strategy.experimental_distribute_dataset(val_loader)

## Create the Model

We build a VGGNet model by defining a custom layer (for the VGG block). 

- First, define a VGG block (layer) class using the subclass of the tf.keras.layers.Layer class. This custom layer uses a get_config method for serialization.
- Then, define the model function by flexibly utilizing the VGG block class.
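Keras rebuilds a custom layer from its config; by default, `Layer.from_config` calls `cls(**config)`. So `get_config` should return the constructor arguments rather than the sub-layer objects. The plain-Python class below is a hypothetical stand-in (`ToyBlock` is not a Keras class). It only illustrates this round-trip contract.

```python
# Plain-Python illustration of the get_config / from_config contract.
# ToyBlock is hypothetical; it mimics only the pattern Keras relies on.

class ToyBlock:
    def __init__(self, conv_block_number, num_of_channels, weight_decay, name=None):
        self.conv_block_number = conv_block_number
        self.num_of_channels = num_of_channels
        self.weight_decay = weight_decay
        self.name = name or "toy_block"
        # Sub-objects are derived from the arguments, so they are NOT stored in the config.
        self.layers = [f"conv{num_of_channels}"] * conv_block_number

    def get_config(self):
        # Only plain constructor arguments, so that cls(**config) can rebuild the object.
        return {
            "conv_block_number": self.conv_block_number,
            "num_of_channels": self.num_of_channels,
            "weight_decay": self.weight_decay,
            "name": self.name,
        }

    @classmethod
    def from_config(cls, config):
        return cls(**config)

block = ToyBlock(2, 64, 0.001, name="vgg_block_1")
clone = ToyBlock.from_config(block.get_config())
print(clone.get_config() == block.get_config(), clone.layers)
```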


### Optimizers and Schedulers

We provide options for utilizing various optimizers and schedulers.

- Optimizers: 

        -- SGD (Stochastic Gradient Descent)
        -- Adam (Adaptive moment estimation)
        -- Nadam (Adam with Nesterov momentum)
        -- AdamW (a variant of Adam where the weight decay is performed only after controlling the parameter-wise step size), useful for adapting the learning rate in small batch settings
        -- LAMB (Layer-wise Adaptive Moments optimizer for Batch training), useful for adapting the learning rate in large batch settings

- Schedulers for SGD:

        -- Exponential Decay
        -- Piecewise Constant Decay
        -- Cosine Decay Restarts
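To build intuition for the three SGD schedulers, here are plain-Python re-implementations of the formulas as described in the TensorFlow API documentation. For CosineDecayRestarts we assume the defaults t_mul=2.0, m_mul=1.0, alpha=0.0. These functions are sketches for inspecting values, not replacements for the tf.keras.optimizers.schedules classes. The steps-per-epoch value in the usage lines is also an assumption (one GPU, 128 samples per step).

```python
import math

def exponential_decay(step, initial_lr, decay_steps, decay_rate, staircase=True):
    p = step / decay_steps
    if staircase:
        p = math.floor(p)                 # decay in discrete jumps
    return initial_lr * decay_rate ** p

def piecewise_constant(step, boundaries, values):
    # values[0] while step <= boundaries[0], values[1] while step <= boundaries[1], ...
    for boundary, value in zip(boundaries, values):
        if step <= boundary:
            return value
    return values[-1]

def cosine_decay_restarts(step, initial_lr, first_decay_steps,
                          t_mul=2.0, m_mul=1.0, alpha=0.0):
    frac = step / first_decay_steps
    if t_mul == 1.0:
        i_restart = math.floor(frac)
        frac -= i_restart
    else:
        # Periods grow geometrically: first_decay_steps * t_mul**i for restart i.
        i_restart = math.floor(math.log(1.0 - frac * (1.0 - t_mul)) / math.log(t_mul))
        sum_r = (1.0 - t_mul ** i_restart) / (1.0 - t_mul)
        frac = (frac - sum_r) / t_mul ** i_restart
    m_fac = m_mul ** i_restart
    cosine = 0.5 * m_fac * (1.0 + math.cos(math.pi * frac))
    return initial_lr * ((1.0 - alpha) * cosine + alpha)

steps_per_epoch = 332                     # assumption: 42,500 samples // 128 per step
first_decay_steps = steps_per_epoch * 20
for epoch in (0, 10, 20, 40):
    lr = cosine_decay_restarts(epoch * steps_per_epoch, 0.01, first_decay_steps)
    print(epoch, round(lr, 5))
```

With t_mul=2.0, each restart period is twice as long as the previous one: the first cosine cycle spans 20 epochs, the second 40 epochs, and so on.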

##### The model, the optimizer, and the checkpoint must be created under the mirrored strategy scope.


#### Save the Parameters (Variables) of the Model Periodically

After the training is complete, we serialize the final model in TensorFlow's SavedModel format. SavedModel is a comprehensive save format that saves the model architecture, parameters (weights), and the traced TensorFlow subgraphs of the call functions. 

However, during training we save **only the variables of the model (and of the optimizer) at regular intervals**, e.g., after every fixed number of epochs. These saved objects are called model checkpoints. For this periodic saving (and restoring), we use the tf.train.Checkpoint class. 

We customize the saving process of model checkpoints by using the tf.train.CheckpointManager class.

- To keep only the last few checkpoints, e.g., the last 5 checkpoints.
- To use the epoch number for numbering the saved checkpoints. By default, checkpoints are numbered from 1.
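The retention policy described above can be simulated in plain Python: save every 10 epochs, number each checkpoint by its epoch, and keep the last 5. This sketch does not call tf.train.CheckpointManager; it only models which checkpoint numbers survive.

```python
from collections import deque

# Simulation of a "keep the last N" retention policy with epoch-numbered checkpoints.
def simulate_checkpoints(num_epochs, save_every, max_to_keep):
    kept = deque()
    deleted = []
    for epoch in range(1, num_epochs + 1):
        if epoch % save_every == 0:
            kept.append(epoch)                  # save, numbered by the epoch
            if len(kept) > max_to_keep:
                deleted.append(kept.popleft())  # the oldest checkpoint is removed
    return list(kept), deleted

kept, deleted = simulate_checkpoints(num_epochs=100, save_every=10, max_to_keep=5)
print("kept:", kept)
print("deleted:", deleted)
```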

In [None]:
'''
Define the VGG block class using Keras Sequential API.
The VGG_Block class takes two arguments:
- conv_block_number: number of convolutional layers 
- num_of_channels: number of output channels 
'''
class VGG_Block(tf.keras.layers.Layer):
    def __init__(self, conv_block_number, num_of_channels, weight_decay, **kwargs):
        super().__init__(**kwargs)
        self.conv_layers = [] 
        for _ in range(conv_block_number):
            self.conv_layers.append(tf.keras.layers.Conv2D(filters=num_of_channels, kernel_size=(3, 3), strides=1,
                                padding='same', kernel_regularizer=tf.keras.regularizers.l2(weight_decay), use_bias=False))
            self.conv_layers.append(tf.keras.layers.BatchNormalization())
            self.conv_layers.append(tf.keras.layers.Activation("relu"))
        
        self.pool_layer = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=2, padding='valid')

    def call(self, inputs):
        Z = inputs
        for layer in self.conv_layers:
            Z = layer(Z)
            
        Z = self.pool_layer(Z)
        return Z
        
    # Required for the custom object's serialization.
    def get_config(self):
        config = super().get_config().copy()
        config.update({
            "conv_layers": self.conv_layers,
            "pool_layer": self.pool_layer,
        })
        return config


'''
Function to create the VGGNet model.
'''
def create_model(conv_blocks, width, height, channels, num_classes, weight_decay, augmentation=False, **kwargs):
    
    vgg_net = tf.keras.models.Sequential(name='VGG')
    
    vgg_net.add(tf.keras.layers.InputLayer(input_shape=(width, height, channels)))
    
    # Data augmentation layer.
    # NOTE: data_augmentation_layer is not defined in this notebook; define it
    # before calling create_model with augmentation=True.
    if(augmentation):
        vgg_net.add(data_augmentation_layer(**kwargs))
    
    # Convolutional layers based on the VGG_Block object.
    for (conv_block_number, num_of_channels) in conv_blocks:
        vgg_net.add(VGG_Block(conv_block_number, num_of_channels, weight_decay))
    
    # Flatten the convnet output to feed it with fully-connected layers.
    vgg_net.add(tf.keras.layers.Flatten())
    
    # Fully-connected layers.
    # The output layer has no softmax activation: it returns logits, which matches
    # the loss function defined with from_logits=True.
    vgg_net.add(tf.keras.layers.Dense(units=64, activation='relu'))
    vgg_net.add(tf.keras.layers.Dropout(0.5))
    vgg_net.add(tf.keras.layers.Dense(units=num_classes))
    
    return vgg_net

In [None]:
print("\n########################## Create the Model ##########################\n\n")

'''
Reset all state generated by Keras.
This clears the global state held by Keras (e.g., the graph and layer-name counters),
which prevents memory from accumulating when models are created repeatedly in the same session.
'''
tf.keras.backend.clear_session()

'''
To make the results reproducible across runs, we use fixed seeds for random number generation. 
'''
np.random.seed(42)
tf.random.set_seed(42)


'''
Create the model and optimizer inside the strategy's scope. 
This ensures that any variables created with the model and optimizer are mirrored variables.
'''
with mirrored_strategy.scope():
    
    '''
    Instantiate the model.
    '''
    layer_info = ((2, 64), (2, 128), (4, 256), (4, 512), (4, 512)) # for the VGGNet model
    model = create_model(layer_info, 32, 32, channels=3, num_classes=10,
                           weight_decay=0.001)
    
    '''
    Display a summary of the model architecture.
    '''
    model.summary()
    
    '''
    Instantiate a learning rate scheduler for the SGD optimizer.
    There are 3 choices. 
    - Exponential decay
    - Piecewise constant decay
    - Cosine decay with restarts

    NOTE: Uncomment only one scheduler if the SGD optimizer is used.
    '''

    '''
    The initial learning rate is used by the optimizers, e.g., SGD, ADAM, NADAM, etc.
    Some SGD schedulers also require an initial learning rate (e.g., exponential decay, cosine decay).
    '''
    initial_learning_rate=0.01

    '''
    Scheduler: ExponentialDecay
    '''
    # decay_steps=no_of_steps_per_epoch * 50
    # lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    #     initial_learning_rate=initial_learning_rate,
    #     decay_steps=decay_steps,
    #     decay_rate=0.1,
    #     staircase=True)


    '''
    Scheduler: PiecewiseConstantDecay
    '''
    # boundaries = [no_of_steps_per_epoch * 150, no_of_steps_per_epoch*250]
    # values = [0.5, 0.1, 0.01]
    # lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(boundaries, values)

    '''
    Scheduler: CosineDecayRestarts
    '''
    first_decay_steps = no_of_steps_per_epoch * 20
    lr_schedule = tf.keras.optimizers.schedules.CosineDecayRestarts(initial_learning_rate,
                                                                    first_decay_steps)

    '''
    Instantiate an optimizer. Use one of the following choices.
    - Fixed LR: learning_rate=initial_learning_rate
    - Scheduled LR: learning_rate=lr_schedule
    '''
    optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule, momentum=0.9, nesterov=False)
    #optimizer=tf.keras.optimizers.Nadam(learning_rate=initial_learning_rate)
    #optimizer=tf.keras.optimizers.Adam(learning_rate=initial_learning_rate)
    #optimizer=tfa.optimizers.AdamW(learning_rate=initial_learning_rate, weight_decay=0.001)
    #optimizer=tfa.optimizers.LAMB(learning_rate=initial_learning_rate, weight_decay_rate=0.001)
    
    
    '''
    The checkpoint model name variable is used for serializing model checkpoints
    '''
    checkpoint_model_name='Model-Distributed-Checkpoint'

    '''
    Path to the directory in which the model checkpoints will be serialized
    '''
    checkpoint_model_directory='./Checkpoint_Models/'

    '''
    Instantiate the checkpoint object
    '''
    checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)

    '''
    Instantiate the checkpoint manager object
    '''
    manager = tf.train.CheckpointManager(checkpoint, directory=checkpoint_model_directory, 
                                         checkpoint_name=checkpoint_model_name, max_to_keep=5)

## Utility Functions for Computing Loss and Accuracy

In [None]:
print("\n\n######### Functions: Utility Functions for Loss & Accuracy ###########\n\n")

with mirrored_strategy.scope():
    '''
    Define two functions to compute the training loss (per iteration/step).
    
    Function "loss_object"
    Set reduction to `NONE` so the reduction is done afterwards and divide by global batch size.
    The value `SUM_OVER_BATCH_SIZE` is disallowed because currently it would only divide by per replica batch size.
    
    Function "compute_loss"
    We should sum the per example losses and divide the sum by the GLOBAL_BATCH_SIZE. 
    For this, we use tf.nn.compute_average_loss which takes the per example loss, 
    and GLOBAL_BATCH_SIZE as arguments and returns the scaled loss.
    '''
    loss_object=tf.keras.losses.CategoricalCrossentropy(
        from_logits=True,
        reduction=tf.keras.losses.Reduction.NONE)

    
    def compute_loss(labels, predictions, model_losses):
        per_example_loss = loss_object(labels, predictions)
        loss = tf.nn.compute_average_loss(per_example_loss,
                                          global_batch_size=GLOBAL_BATCH_SIZE)
        
        '''
        If there are additional losses incurred by the model (e.g., weight regularizer loss), 
        then we should sum them up and divide the sum by the number of replicas. 
        This is accomplished by using the tf.nn.scale_regularization_loss function.
        This function scales the sum of the given regularization losses by number of replicas.
        '''
        if model_losses:
            loss += tf.nn.scale_regularization_loss(tf.add_n(model_losses))
    
        return loss
    
    # Instantiate a function to compute mean validation loss (loss per epoch).
    val_loss_epoch = tf.keras.metrics.Mean(name='val_loss')

    # Instantiate functions to compute train & val accuracies per epoch.   
    train_acc_epoch = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy')
    val_acc_epoch = tf.keras.metrics.CategoricalAccuracy(name='val_accuracy')
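The scaling in compute_loss can be checked with a few numbers. The plain-Python sketch below mimics the arithmetic of tf.nn.compute_average_loss, which sums the per-example losses and divides by the global batch size. It also mimics tf.nn.scale_regularization_loss, which divides the regularization loss by the number of replicas. The loss values are made up. Summing the per-replica results, as a SUM reduce does, recovers the global mean loss plus one copy of the regularization term. Averaging within each replica instead would overcount the data loss by the number of replicas.

```python
# Plain-Python sketch of the loss scaling used by compute_loss.
# The per-example losses and the regularization value are made-up numbers.

def average_loss(per_example_loss, global_batch_size):
    # Mimics tf.nn.compute_average_loss: sum, then divide by the GLOBAL batch size.
    return sum(per_example_loss) / global_batch_size

def scaled_regularization(reg_loss, num_replicas):
    # Mimics tf.nn.scale_regularization_loss: divide by the number of replicas.
    return reg_loss / num_replicas

replica_losses = [[0.2, 0.4], [0.6, 0.8]]   # 2 replicas x 2 examples each
global_batch = 4
reg_loss = 0.1                              # e.g., the L2 weight penalty (same on every replica)
num_replicas = len(replica_losses)

# What each replica returns, and what a SUM reduce across replicas produces.
per_replica = [average_loss(l, global_batch) + scaled_regularization(reg_loss, num_replicas)
               for l in replica_losses]
summed = sum(per_replica)

# Reference: mean over the whole global batch plus one copy of the penalty.
reference = sum(sum(l) for l in replica_losses) / global_batch + reg_loss

# Averaging within each replica (per-replica batch size) overcounts the data loss.
overcounted = sum(sum(l) / len(l) for l in replica_losses)

print(summed, reference, overcounted)
```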
    

## Utility Functions for Displaying the Training Status

In [None]:
print("\n\n######### Functions: Utility Functions for Displaying Training Status ###########\n\n")

# Function to display training statistics per step 
def print_status_per_step(iteration, total, train_loss):
    metrics = " - ".join(["{}: {:.4f}".format("loss", train_loss)])
    end = "" if iteration < total else "\n"
    print("\r{}/{} - ".format(iteration, total) + metrics,
          end=end)
    
# Function to display training statistics per epoch
def print_status_per_epoch(iteration, total, train_loss, train_acc,
                          val_loss, val_acc, lr, time_per_epoch):
    metrics = " - ".join(["{}: {:.4f} - {}: {:.4f} - {}: {:.4f} - {}: {:.4f} - {}: {:.4f} {:.2f}s"\
                          .format("loss", train_loss,
                                  "acc", train_acc,
                                  "val loss", val_loss,
                                  "val acc", val_acc,
                                  "lr", lr,
                                  time_per_epoch),
                         ])
    end = "" if iteration < total else "\n"
    print("\r{}/{} ".format(iteration, total) + metrics,
          end=end)

## Functions for Training & Validation of the Model For Each Batch Per-Replica

We define two "per-replica" functions (train_step & val_step) for training and validation on each batch of data (during an epoch). 

        -- These functions are executed on each replica.

The train_step function performs the following tasks.

- Define the tf.GradientTape() block. 

- Inside the block, make a prediction for one batch (using the model as a function), and compute the loss. The loss consists of the main loss plus the other losses (e.g., weight regularizer loss). Note that, to save memory, we put only the minimum necessary operations inside the tf.GradientTape() block. The tape is automatically erased immediately after we call its gradient() method.

- Ask the tape to compute the gradient of the loss with respect to each trainable variable (by using the gradient() method).

- Pass the gradients to the optimizer to perform a gradient descent step (by using the apply_gradients() method).

- Update the training accuracy metric (over the current epoch). The mean training loss is accumulated in the training loop from the values returned by train_step.
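The steps of train_step can be mirrored in a plain-Python sketch for a hypothetical scalar model `y_hat = w * x`. A hand-derived gradient stands in for `tape.gradient`. The update follows the SGD-with-momentum rule that the Keras SGD documentation gives for `nesterov=False`, the setting used in this notebook: `velocity = momentum * velocity - learning_rate * g`, then `w = w + velocity`.

```python
# Plain-Python sketch of one training step (forward pass, loss, gradient, update)
# for a hypothetical scalar model y_hat = w * x.

def sgd_momentum_step(w, velocity, xs, ys, lr=0.01, momentum=0.9):
    n = len(xs)
    preds = [w * x for x in xs]                                   # forward pass
    loss = sum((p - y) ** 2 for p, y in zip(preds, ys)) / n       # mean squared error
    grad = sum(2.0 * x * (p - y) for x, p, y in zip(xs, preds, ys)) / n  # dloss/dw
    velocity = momentum * velocity - lr * grad                    # momentum update
    w = w + velocity                                              # apply the update
    return w, velocity, loss

w, v = 0.0, 0.0
xs = [1.0, 2.0, 3.0]
ys = [2.0, 4.0, 6.0]
w1, _, _ = sgd_momentum_step(w, v, xs, ys)    # first step from w = 0
losses = []
for step in range(200):
    w, v, loss = sgd_momentum_step(w, v, xs, ys)
    losses.append(loss)
print(w1, round(w, 4), losses[0], losses[-1])
```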


In [None]:
'''
Functions for defining train & val step functions.
'''
def train_step(X, y):
    '''
    Open a GradientTape to record the operations run during the forward pass.
    This enables auto-differentiation for computing the loss gradient.

    Run the forward pass of the model.
    Trainable variables are automatically tracked by the GradientTape,
    i.e., the operations on them are recorded on the tape.
    '''
    with tf.GradientTape() as tape:
        # Predict the logits for the current minibatch.
        y_pred = model(X, training=True) 
        
        # Compute the loss for the current minibatch.
        loss = compute_loss(y, y_pred, model.losses)
        
    '''
    Compute the gradient of the loss with respect to the trainable variables (weights & biases).
    Use the GradientTape object to automatically retrieve
    the gradients of the loss with respect to the trainable variables.
    '''
    gradients = tape.gradient(loss, model.trainable_variables)
    
    '''
    Update the weights by using the loss gradients.
    Run one iteration/step of gradient descent by updating
    the value of the variables to minimize the loss.
    '''
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    
    '''
    Update the training accuracy metric for the current epoch.
    This function is executed by tf.distribute.Strategy.run in the distributed_train_step function.
    '''
    train_acc_epoch.update_state(y, y_pred)
    
    return loss
        
        
def val_step(X_val, y_val):
    # Predict the logits for the current validation minibatch.
    y_val_pred = model(X_val, training=False)
        
    # Compute the loss for the current validation minibatch.
    loss_val = loss_object(y_val, y_val_pred)
    
    '''
    Compute the validation loss for the current epoch.
    This method gets executed by tf.distribute.Strategy.run in the distributed_val_step function.
    '''
    val_loss_epoch.update_state(loss_val)
        
    '''
    Compute the validation accuracy for the current epoch.
    This method gets executed by tf.distribute.Strategy.run in the distributed_val_step function.
    '''
    val_acc_epoch.update_state(y_val, y_val_pred)

## Functions for Distributed Training & Validation of the Model For Each Batch 

The key idea is to replicate the train_step and val_step functions over the GPUs and run them
with the distributed input. This is accomplished by the mirrored strategy's "run" method.
tf.distribute.Strategy.run returns the results from each local replica in the strategy. 
Then, we use tf.distribute.Strategy.reduce to get an aggregated value. 

To compile the functions into a static graph, add a @tf.function decorator
to both the distributed_train_step and distributed_val_step functions.

Describing the computation as a static graph enables the framework to apply global performance optimizations. 
This is impossible when the framework is constrained to greedily execute one operation after another, 
with no knowledge of what comes next.
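The run/reduce pattern can be sketched with plain-Python stand-ins. `toy_run` and `toy_reduce` are hypothetical helpers, not the tf.distribute API. `toy_run` calls a function once per replica on that replica's inputs and returns the per-replica results, and `toy_reduce` collapses them with a SUM or MEAN. Because each replica divides by the global batch size, the SUM reduce yields the mean loss over the whole global batch.

```python
# Plain-Python stand-ins for the run/reduce pattern (hypothetical helpers).

def toy_run(fn, per_replica_args):
    # One call per replica; returns a list of "per-replica" results.
    return [fn(*args) for args in per_replica_args]

def toy_reduce(op, per_replica_values):
    if op == "SUM":
        return sum(per_replica_values)
    if op == "MEAN":
        return sum(per_replica_values) / len(per_replica_values)
    raise ValueError(f"unsupported reduce op: {op}")

TOY_GLOBAL_BATCH = 6

def scaled_loss(per_example_losses):
    # Each replica returns its summed loss divided by the GLOBAL batch size.
    return sum(per_example_losses) / TOY_GLOBAL_BATCH

shards = [([1.0, 2.0, 3.0],), ([4.0, 5.0, 6.0],)]   # 2 replicas x 3 examples
per_replica = toy_run(scaled_loss, shards)
print(per_replica, toy_reduce("SUM", per_replica))
```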

In [None]:
print("\n\n########### Functions for Distributed Training & Validation of the Model For Each Batch #############\n\n")

@tf.function
def distributed_train_step(X, y):
    per_replica_losses = mirrored_strategy.run(train_step, args=(X, y,))
    
    return mirrored_strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
                         axis=None)

@tf.function
def distributed_val_step(X, y):
    return mirrored_strategy.run(val_step, args=(X, y,))

## Train the Model

The distributed_train_step and distributed_val_step functions are used to define a custom training loop.

The following tasks are performed.

- Create two nested loops: one for the epochs, and the other for the batches within an epoch.

      -- Within the inner loop (for iterating through the batches within an epoch), call the distributed_train_step function. It will train the model on all replicas using distributed batches of data within an epoch.

      -- Display the status bar to show the training statistics for each iteration and/or epoch.

- The distributed_val_step function is executed within the outer loop for each epoch.
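The control flow of the loop can be sketched in plain Python with a hypothetical `dummy_step` standing in for distributed_train_step. The sketch keeps a running mean of the loss within each epoch and records the epochs at which a checkpoint would be saved. Because `enumerate` counts from 0, the number of samples processed after a step is `(step + 1)` times the batch size.

```python
# Skeleton of the nested epoch/batch loop; dummy_step is a hypothetical stand-in.

def dummy_step(batch):
    # Pretend loss: the mean of the batch values.
    return sum(batch) / len(batch)

def train_loop(batches, num_epochs, save_every=10):
    history, saved = [], []
    for epoch in range(1, num_epochs + 1):
        total_loss, num_batches, seen, train_loss = 0.0, 0, 0, 0.0
        for step, batch in enumerate(batches):
            total_loss += dummy_step(batch)
            num_batches += 1
            train_loss = total_loss / num_batches   # running mean within the epoch
            seen = (step + 1) * len(batch)          # samples processed so far
        history.append((epoch, train_loss, seen))
        if epoch % save_every == 0:
            saved.append(epoch)                     # where a checkpoint would be saved
    return history, saved

batches = [[1.0, 3.0], [5.0, 7.0]]
history, saved = train_loop(batches, num_epochs=20)
print(history[0], saved)
```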

In [None]:
print("\n\n########################## Train the Model ##########################\n\n")

# This variable keeps a count of the total number of iterations/steps for training until end (for all epochs).
iterations_total = 0

# Perform training and validation.
for epoch in range(1, no_of_epochs + 1):
    print("Epoch {}/{}".format(epoch, no_of_epochs))

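    # Get the decayed learning rate at the beginning of each epoch.
    # This learning rate will be used by the "optimizer" in the train_step function.
    # NOTE: _decayed_lr is a private method of the OptimizerV2-based Keras optimizers
    # used before TF 2.11; it may be unavailable in newer versions.
    lr_epoch = optimizer._decayed_lr(tf.float32).numpy()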
    
    total_loss = 0.0 # Stores the total loss during an epoch
    num_batches = 0  # Stores the number of batches during the training in an epoch

    # Get the start time for each epoch.
    start_time = time.time()

    # Capture the training statistics via the Comet ML experiment object.
    with experiment.train():
        # Iterate over the batches of the dataset.
        for step, (X_batch, y_batch) in enumerate(dist_train_loader):

            '''
            Compute the training loss per step.
            The scaled loss is the return value of the distributed_train_step. 
            This value is aggregated across replicas using the tf.distribute.Strategy.reduce call,
            and then across batches by summing the return value of the tf.distribute.Strategy.reduce calls.
            '''
            total_loss += distributed_train_step(X_batch, y_batch)

            # Increment the step count.
            iterations_total = iterations_total + 1
            
            # Increment the batch count within the current epoch
            num_batches = num_batches + 1
            
            # Compute the training loss (up to the current batch).
            train_loss = total_loss / num_batches
        
            # Display the training loss per step.
            # enumerate starts at 0, so (step + 1) batches have been processed.
            print_status_per_step((step + 1) * GLOBAL_BATCH_SIZE, len(y_train), train_loss)
            
        
        # Compute the training accuracy at the end of each epoch.
        train_acc = train_acc_epoch.result()

        # Reset the training accuracy metric at the end of each epoch.
        train_acc_epoch.reset_states()

    # Capture the validation statistics via the Comet ML experiment object.
    with experiment.test():
        # Run a validation loop at the end of each epoch.
        for X_batch_val, y_batch_val in dist_val_loader:
            distributed_val_step(X_batch_val, y_batch_val)

        # Compute the mean val loss at the end of each epoch.
        val_loss = val_loss_epoch.result()

        # Compute the validation accuracy at the end of each epoch.
        val_acc = val_acc_epoch.result()

    # Compute the time taken for each epoch.
    time_per_epoch = time.time() - start_time

    # Display the training statistics at the end of each epoch.
    print_status_per_epoch(len(y_train), len(y_train), train_loss, train_acc,
                          val_loss, val_acc, lr_epoch, time_per_epoch)


    # Reset validation metrics at the end of each epoch.
    val_acc_epoch.reset_states()
    val_loss_epoch.reset_states()

    # Save checkpoints after every 10 epochs.
    if epoch % 10 == 0:
        # The epoch number is used to number the saved checkpoint.
        manager.save(checkpoint_number=epoch)

    # Define a set of metrics to be stored via the comet ml object.
    metrics = {
        'loss':train_loss,
        'accuracy':train_acc,
        'val_loss':val_loss,
        'val_accuracy':val_acc,
        'learning_rate':lr_epoch,
        'epoch':epoch,
        'iterations': iterations_total
    }
    # Log the metrics via the comet ml object.
    experiment.log_metrics(metrics)

# Log the following hyperparameters via the Comet ML object.
params={'batch_size':GLOBAL_BATCH_SIZE,
        'epochs':no_of_epochs,
        'iterations': iterations_total,
        'optimizer':optimizer,
        'scheduler': lr_schedule
}
experiment.log_parameters(params)
experiment.end()

print("\n\nTraining completed successfully! :)\n")    


'''
Save the final model to disk in the SavedModel format.
SavedModel is a comprehensive save format that saves the model architecture, 
weights, and the traced TensorFlow subgraphs of the call functions. 
This enables Keras to restore both built-in layers as well as custom objects.
'''

print("\nSaving the fully trained model in the SavedModel format ... \n\n")

# The model name variable is used for model serialization.
final_model_name='Model-Distributed'

# Path to the directory in which the FINAL model will be serialized.
final_model_directory='./Saved_Models/'

# Path name of the final model.
final_model_path = os.path.join(final_model_directory, final_model_name)


'''
Save the final model.

When saving the model to a local I/O device while training on remote devices (e.g., GPUs on a remote node),
we need to set the I/O device to localhost.
This is done by using the option experimental_io_device in tf.saved_model.SaveOptions.
'''
save_options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
model.save(final_model_path, options=save_options)
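The training loop above calls manager.save(checkpoint_number=epoch) every 10 epochs. The following is a minimal sketch of how that numbering interacts with the max_to_keep limit of tf.train.CheckpointManager. It assumes a hypothetical single-variable checkpoint (demo_state) in place of the notebook's model and optimizer, a temporary directory, and an arbitrary max_to_keep=3; the notebook's own manager is configured in an earlier cell.

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical stand-in for the notebook's checkpointed state (model + optimizer).
demo_state = tf.Variable(0.0, name="demo_state")
demo_ckpt = tf.train.Checkpoint(state=demo_state)

demo_ckpt_dir = tempfile.mkdtemp()
# max_to_keep=3 is an arbitrary choice for this sketch; older checkpoints are deleted.
demo_manager = tf.train.CheckpointManager(demo_ckpt, directory=demo_ckpt_dir, max_to_keep=3)

saved_paths = []
for epoch in range(50):
    demo_state.assign(float(epoch))  # stands in for one epoch of training
    if epoch % 10 == 0:
        # checkpoint_number sets the numeric suffix of the saved prefix: ckpt-<epoch>.
        saved_paths.append(demo_manager.save(checkpoint_number=epoch))

print([os.path.basename(p) for p in demo_manager.checkpoints])  # ['ckpt-20', 'ckpt-30', 'ckpt-40']
print(os.path.basename(tf.train.latest_checkpoint(demo_ckpt_dir)))  # ckpt-40
```

With the default checkpoint_name, each file prefix is ckpt-<number>, so numbering checkpoints by epoch makes it easy to tell which epoch a restored checkpoint came from.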

## Model Evaluation

In [None]:
print("\n###################### Model Evaluation using Test Data ######################\n\n")

# Predict test labels.
y_test_pred = model(X_test, training=False)

print("\nTest Data Predictions (shape): ", y_test_pred.shape)
print("Test Data (shape): ", y_test.shape)

# Compute the test accuracy.
acc_fn = tf.keras.metrics.CategoricalAccuracy()
acc_fn.update_state(y_test, y_test_pred)
test_accuracy = acc_fn.result().numpy()
print("\nTest Accuracy: ", test_accuracy)

# Compute the test loss.
test_loss = compute_loss(y_test, y_test_pred, model.losses).numpy()
print("Test Loss: ", test_loss)
print("\n\n")
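The accuracy above comes from a stateful Keras metric: update_state() adds to running totals, and result() reports the value over everything seen since the last reset. This is also why the training loop resets train_acc_epoch and the validation metrics at the end of each epoch. A minimal sketch with hypothetical labels and predictions (4 samples, 3 classes) shows that batched updates give the same value as a single update, and that the value matches a manual argmax comparison:

```python
import numpy as np
import tensorflow as tf

# Hypothetical one-hot labels and softmax outputs for 4 samples and 3 classes.
y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 1, 0]], dtype=np.float32)
y_pred = np.array([[0.8, 0.1, 0.1],   # argmax 0 -> correct
                   [0.2, 0.7, 0.1],   # argmax 1 -> correct
                   [0.5, 0.3, 0.2],   # argmax 0 -> wrong (true class is 2)
                   [0.1, 0.6, 0.3]],  # argmax 1 -> correct
                  dtype=np.float32)

demo_acc = tf.keras.metrics.CategoricalAccuracy()

# Two batched updates: the metric keeps running totals across update_state() calls.
demo_acc.update_state(y_true[:2], y_pred[:2])
demo_acc.update_state(y_true[2:], y_pred[2:])
batched_acc = float(demo_acc.result())

# Reset before reusing the metric; otherwise the new update mixes with the old totals.
demo_acc.reset_state()
demo_acc.update_state(y_true, y_pred)
full_acc = float(demo_acc.result())

# The same quantity by hand: fraction of samples whose argmax positions match.
manual_acc = float(np.mean(np.argmax(y_true, axis=1) == np.argmax(y_pred, axis=1)))
print(batched_acc, full_acc, manual_acc)  # 0.75 0.75 0.75
```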

## Model Evaluation (using the Saved Model)

In [None]:
print("\n\n######################### Model Evaluation (using the Saved Model) #########################\n\n")

'''
We evaluate the saved model on the test dataset by calling it directly (with training=False).

The saved model (stored in the SavedModel format) can be loaded 
by using the tf.keras.models.load_model() method. 

NOTE: If the model contains custom layers, 
then we need to set the "custom_objects" argument of the "load_model" method. 
It should be a dictionary mapping names (strings) to custom classes or functions to be 
considered during deserialization.
'''

print("\n\nLoading the saved model...\n\n")


'''
Load the final model.

When loading a model that was trained on remote devices (e.g., GPUs on a remote node) from a local I/O device,
we need to set the I/O device to localhost.
This is done by using the option experimental_io_device in tf.saved_model.LoadOptions.
'''
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
    load_options = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
    saved_model = tf.keras.models.load_model(final_model_path, options=load_options,
                                            custom_objects={"vgg": create_model})


# Predict the test labels (with the "training" argument set to False).
y_test_pred = saved_model(X_test, training=False)

print("\nTest Data Predictions (shape): ", y_test_pred.shape)
print("Test Data (shape): ", y_test.shape)

# Compute the test accuracy.
acc_fn = tf.keras.metrics.CategoricalAccuracy()
acc_fn.update_state(y_test, y_test_pred)
test_accuracy = acc_fn.result().numpy()
print("\nTest Accuracy: ", test_accuracy)

# Compute the test loss.
test_loss = compute_loss(y_test, y_test_pred, saved_model.losses).numpy()
print("Test Loss: ", test_loss)
print("\n\n")
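The save and load cells above route file I/O through '/job:localhost' via tf.saved_model.SaveOptions and tf.saved_model.LoadOptions. The sketch below runs the same pair of options through a save/load round trip on a single machine. It uses the lower-level tf.saved_model.save and tf.saved_model.load functions and a hypothetical tf.Module (TinyDense) instead of the Keras VGG model, so that it does not depend on the installed Keras version, and then checks that the reloaded module reproduces the original outputs.

```python
import os
import tempfile

import tensorflow as tf


class TinyDense(tf.Module):
    """Hypothetical stand-in for the trained model: y = x @ w + b."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(tf.random.normal([4, 3], seed=1))
        self.b = tf.Variable(tf.zeros([3]))

    @tf.function(input_signature=[tf.TensorSpec([None, 4], tf.float32)])
    def __call__(self, x):
        return tf.matmul(x, self.w) + self.b


demo_module = TinyDense()
export_path = os.path.join(tempfile.mkdtemp(), "tiny_dense")

# Perform the file writes on the local host, mirroring the notebook's SaveOptions.
save_opts = tf.saved_model.SaveOptions(experimental_io_device="/job:localhost")
tf.saved_model.save(demo_module, export_path, options=save_opts)

# Read the files back through the local host as well.
load_opts = tf.saved_model.LoadOptions(experimental_io_device="/job:localhost")
restored = tf.saved_model.load(export_path, options=load_opts)

x = tf.ones([2, 4])
original_out = demo_module(x).numpy()
restored_out = restored(x).numpy()
# Largest absolute difference between original and reloaded outputs.
print(abs(original_out - restored_out).max())
```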

## Model Evaluation (using the Model Checkpoint)
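The cell below restores only the model from a checkpoint that was written together with additional objects, and calls expect_partial() to silence the resulting warnings. A minimal sketch of that pattern, assuming a hypothetical tf.Module (TinyModel) in place of the VGG model and a fake step counter in place of the optimizer state:

```python
import os
import tempfile

import tensorflow as tf


class TinyModel(tf.Module):
    """Hypothetical stand-in for the VGG model returned by create_model()."""

    def __init__(self, init_value):
        super().__init__()
        self.kernel = tf.Variable(tf.fill([2, 2], init_value))


# Training side: the checkpoint stores the model plus extra state
# (a fake step counter standing in for optimizer variables).
trained_model = TinyModel(3.0)
extra_state = tf.Variable(42, dtype=tf.int64)
train_ckpt = tf.train.Checkpoint(model=trained_model, step=extra_state)
ckpt_path = train_ckpt.save(os.path.join(tempfile.mkdtemp(), "ckpt"))

# Evaluation side: a freshly built model with different initial weights,
# tracked by a checkpoint that only knows about the model.
fresh_model = TinyModel(0.0)
eval_ckpt = tf.train.Checkpoint(model=fresh_model)

# "step" exists in the file but not in eval_ckpt; expect_partial() declares that
# this is intentional, which silences the unused-key warnings.
eval_ckpt.restore(ckpt_path).expect_partial()

# Every entry of the fresh kernel now holds the checkpointed value 3.0.
print(fresh_model.kernel.numpy())
```

In eager mode, restoring into an already-built model assigns the saved values immediately, so the fresh model's initial weights are simply overwritten.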

In [None]:
print("\n\nLoading the model checkpoint...\n")

# Instantiate the model.
layer_info = ((2, 64), (2, 128), (4, 256), (4, 512), (4, 512)) # for the VGGNet model
new_model = create_model(layer_info, 32, 32, channels=3, num_classes=10,
                         weight_decay=0.001)

# Instantiate the Checkpoint and specify the model instance to be restored.
checkpoint = tf.train.Checkpoint(model=new_model)


'''
Restore the parameter values of the model instance.

NOTE: Use the function expect_partial() on the loaded status, 
since a checkpoint saved from a Keras model often contains extra keys that the restored object does not use. 
Otherwise, the program prints many warnings about unused keys at exit time.
'''
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_model_directory)).expect_partial()


# Predict the test labels (with the "training" argument set to False) using the NumPy test dataset.
y_test_pred = new_model(X_test, training=False)


# Compute the test accuracy.
acc_fn = tf.keras.metrics.CategoricalAccuracy()
acc_fn.update_state(y_test, y_test_pred)
test_accuracy = acc_fn.result().numpy()
print("\nTest Accuracy (using NumPy test dataset): ", test_accuracy)

# Compute the test loss.
test_loss = compute_loss(y_test, y_test_pred, new_model.losses).numpy()
print("Test Loss (using NumPy test dataset): ", test_loss)
print("\n")


'''
Alternatively, we can use the test Dataset object (test_loader) to compute the test accuracy & loss.
First, we define a function that processes one batch.

acc_fn is reset first, because it still holds the state accumulated from the NumPy test data above.
A Mean metric averages the per-batch losses, so the reported loss covers all batches
rather than only the last one.
'''
acc_fn.reset_states()
test_loss_fn = tf.keras.metrics.Mean()

@tf.function
def test_step(X, y):
    # Predict the labels.
    y_pred = new_model(X, training=False)
    # Update the running accuracy.
    acc_fn.update_state(y, y_pred)
    # Compute the batch loss (with new_model's regularization losses) and update the running mean.
    test_loss_fn.update_state(compute_loss(y, y_pred, new_model.losses))


'''
Compute the test accuracy & loss using the test Dataset object.
'''
for X, y in test_loader:
    test_step(X, y)


print("Test Accuracy (using test Dataset object): ", acc_fn.result().numpy())
print("Test Loss (using test Dataset object): ", test_loss_fn.result().numpy())
print("\n\n")
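When a loss is computed batch by batch, as with test_loader above, the per-batch values have to be aggregated to obtain a test-set figure. Keeping only the last value reports just the final batch, and a plain Mean of batch losses gives a smaller final batch the same weight as a full one. The sketch below uses hypothetical per-sample losses 0, 1, ..., 9 in batches of 4 to compare these options with a Mean weighted by batch size (passed as sample_weight):

```python
import numpy as np
import tensorflow as tf

# Hypothetical per-sample losses for a 10-sample test set, split into batches of 4, 4 and 2.
per_sample_loss = np.arange(10, dtype=np.float32)  # 0, 1, ..., 9
demo_dataset = tf.data.Dataset.from_tensor_slices(per_sample_loss).batch(4)

equal_weight = tf.keras.metrics.Mean()
size_weighted = tf.keras.metrics.Mean()
last_batch_loss = None

for batch in demo_dataset:
    batch_loss = tf.reduce_mean(batch)  # mean loss of this batch
    last_batch_loss = float(batch_loss)
    # Each batch counts once, regardless of its size.
    equal_weight.update_state(batch_loss)
    # Each batch counts in proportion to the number of samples it contains.
    size_weighted.update_state(batch_loss, sample_weight=tf.cast(tf.shape(batch)[0], tf.float32))

print(last_batch_loss)                # 8.5 (only the final batch)
print(float(equal_weight.result()))   # ≈ 5.1667 (mean of 1.5, 5.5 and 8.5)
print(float(size_weighted.result()))  # 4.5 (same as the mean over all 10 samples)
```

The size-weighted mean equals the mean over all samples; whether the difference matters in practice depends on how unevenly the last batch is filled.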