<a href="https://colab.research.google.com/github/sourcecode369/TensorFlow-2.0/blob/master/tensorflow_2.0_docs/TensorFlow%20Core/Tutorials/Distributed%20Training/Custom%20Training%20Loop/TensorFlow_2_0_Custom_Training_tf_distribute_Strategy.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Installing TensorFlow 2.0

In [23]:
!pip install --upgrade tensorflow-gpu

Requirement already up-to-date: tensorflow-gpu in /usr/local/lib/python3.6/dist-packages (2.0.0)


### Importing the libraries

In [0]:
from __future__ import absolute_import, print_function, unicode_literals, division

import tensorflow as tf

import numpy as np
import os

### Setting up TensorFlow

In [25]:
# Report the TensorFlow version, GPU availability and execution mode.
print("TensorFlow version: ",tf.__version__)
print("GPU is","available" if tf.test.is_gpu_available() else "unavailable.")
print("TensorFlow is executing eagerly: ",tf.executing_eagerly())
# Fix TensorFlow's global random seed (the message is printed before the seed is actually set).
print("Random seeds set.")
tf.random.set_seed(1)
print("Setting up TensorFlow soft device placement and device placement.")
# Soft placement is enabled and every op's device is logged; this logging is
# the source of the "Executing op ... in device ..." lines in later cell outputs.
tf.config.set_soft_device_placement(True)
tf.debugging.set_log_device_placement(True)
# Load the TensorBoard notebook extension.
%load_ext tensorboard

TensorFlow version:  2.0.0
GPU is available
TensorFlow is executing eagerly:  True
Random seeds set.
Setting up TensorFlow soft device placement and device placement.
The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


### Setting up the data

In [0]:
# Fetch the Fashion-MNIST train/test split through the Keras dataset helper.
fashion_mnist = tf.keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# Append a trailing channel axis (the Conv2D layers below expect one) and
# divide by a float32 255 so the image arrays come out as float32.
train_images = train_images[..., np.newaxis] / np.float32(255)
test_images = test_images[..., np.newaxis] / np.float32(255)

In [27]:
# Sanity check: both arrays should now carry the trailing channel axis.
print(train_images.shape)
print(test_images.shape)

(60000, 28, 28, 1)
(10000, 28, 28, 1)


### Defining the tf.distribute.Strategy as MirroredStrategy

In [28]:
# Create the MirroredStrategy used by every later cell, with hierarchical-copy
# all-reduce as its cross-device communication op.
strategy = tf.distribute.MirroredStrategy(cross_device_ops=tf.distribute.HierarchicalCopyAllReduce())
# num_replicas_in_sync is also used below to size the global batch.
print(f"Number of parallel devices: {strategy.num_replicas_in_sync}")

Number of parallel devices: 1


In [0]:
# Shuffle buffer covers the whole training set.
BUFFER_SIZE = len(train_images)

# Each replica processes 64 examples per step; the global batch is the
# per-replica batch multiplied by the number of replicas in sync.
BATCH_SIZE_PER_REPLICA = 64
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
# Number of passes over the training data in each training loop below.
EPOCHS = 10

### tf.data and strategy.experimental_distribute_dataset for setting up the dataset

In [30]:
# Training pipeline: shuffle over the full buffer, batch to the global batch
# size, and prefetch with an automatically tuned buffer.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((train_images, train_labels))
    .shuffle(BUFFER_SIZE)
    .batch(GLOBAL_BATCH_SIZE)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

# Evaluation pipeline: batched only, no shuffling.
test_dataset = (
    tf.data.Dataset.from_tensor_slices((test_images, test_labels))
    .batch(GLOBAL_BATCH_SIZE)
)

# Hand both pipelines to the strategy so it can distribute them across replicas.
train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
test_dist_dataset = strategy.experimental_distribute_dataset(test_dataset)

Executing op TensorSliceDataset in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op Equal in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op LogicalAnd in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op Select in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op AnonymousRandomSeedGenerator in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op ShuffleDatasetV2 in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op BatchDatasetV2 in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op PrefetchDataset in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op RebatchDataset in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op AutoShardDataset in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op DeleteRandomSeedGenerator in device /job:localhost/replica:0/task:0/device:CPU:0


In [31]:
# Show the (images, labels) tensor specs of the distributed datasets.
print(train_dist_dataset.element_spec)
print(test_dist_dataset.element_spec)

(TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32, name=None), TensorSpec(shape=(None,), dtype=tf.uint8, name=None))
(TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32, name=None), TensorSpec(shape=(None,), dtype=tf.uint8, name=None))


### Defining the model

In [0]:
def create_model():
  """Build the CNN classifier used throughout the notebook.

  The network has two Conv2D + MaxPooling2D stages, followed by a flatten,
  a 64-unit ReLU dense layer and a 10-way softmax output layer.

  Returns:
    An uncompiled `tf.keras.Sequential` model. No input shape is given, so
    its variables are only created on the first call.
  """
  layers = tf.keras.layers
  stack = [
      layers.Conv2D(32, 3, activation='relu'),
      layers.MaxPooling2D(),
      layers.Conv2D(64, 3, activation='relu'),
      layers.MaxPooling2D(),
      layers.Flatten(),
      layers.Dense(64, activation='relu'),
      layers.Dense(10, activation='softmax'),
  ]
  return tf.keras.Sequential(stack)

In [0]:
# Directory and filename prefix for checkpoints.
# NOTE(review): the CheckpointManager created later writes to './tf_ckpts', not
# to this directory; these two names are not used in the cells visible here.
checkpoint_dir = 'training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_")

### Defining the loss function

In [0]:
with strategy.scope():
  # Reduction.NONE keeps one loss value per example; the averaging over the
  # global batch is done explicitly in compute_loss.
  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
      reduction=tf.keras.losses.Reduction.NONE)

  def compute_loss(labels, predictions):
    """Return this replica's share of the global-batch average loss.

    The per-example losses are summed and divided by GLOBAL_BATCH_SIZE, not
    by the local batch size, so adding the results of all replicas yields the
    mean loss over the global batch.
    """
    return tf.nn.compute_average_loss(
        loss_object(labels, predictions),
        global_batch_size=GLOBAL_BATCH_SIZE)

### Define the metrics to track loss and accuracy

In [0]:
with strategy.scope():
  # Running mean of the per-example test loss.
  test_loss = tf.keras.metrics.Mean(name='test_loss')

  # Accuracy trackers for the training and evaluation passes; they are reset
  # by the training loops at the end of each epoch.
  train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
  test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')

### Training Loop

In [60]:
with strategy.scope():
  # The model and optimizer are created inside the strategy scope.
  model = create_model()
  optimizer = tf.keras.optimizers.Adam()

  # The checkpoint tracks a step counter together with the optimizer and model.
  checkpoint = tf.train.Checkpoint(
      step=tf.Variable(1),
      optimizer=optimizer,
      model=model,
  )

  # The manager writes under ./tf_ckpts and keeps only the three newest saves.
  manager = tf.train.CheckpointManager(
      checkpoint,
      directory='./tf_ckpts',
      max_to_keep=3,
  )

Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op AssignVariableOp in device /job:localhost/replica:0/task:0/device:CPU:0


In [0]:
with strategy.scope():

  def train_step(inputs):
    """Run one optimisation step on a single replica's batch.

    Args:
      inputs: an `(images, labels)` pair for this replica.

    Returns:
      The replica's loss, already scaled by the global batch size.
    """
    images, labels = inputs
    with tf.GradientTape() as tape:
      predictions = model(images, training=True)
      loss = compute_loss(labels, predictions)

    # Read the variable list only after the forward pass: the model is built
    # lazily, so its variables do not exist before the first call.
    trainable = model.trainable_variables
    grads = tape.gradient(loss, trainable)
    optimizer.apply_gradients(zip(grads, trainable))

    train_accuracy.update_state(labels, predictions)
    return loss

  def test_step(inputs):
    """Evaluate one replica's batch and update the test loss/accuracy metrics.

    Args:
      inputs: an `(images, labels)` pair for this replica.
    """
    images, labels = inputs
    predictions = model(images, training=False)
    # loss_object uses Reduction.NONE, so the Mean metric receives the
    # per-example losses and does the averaging itself.
    test_loss.update_state(loss_object(labels, predictions))
    test_accuracy.update_state(labels, predictions)

In [64]:
with strategy.scope():
    @tf.function
    def distributed_train_step(dataset_inputs):
        """Run `train_step` on every replica and return the global-batch loss.

        Each replica's loss is already divided by GLOBAL_BATCH_SIZE inside
        `compute_loss`, so the per-replica values are summed (not averaged).
        """
        per_replica_losses = strategy.experimental_run_v2(train_step,
                                                          args=(dataset_inputs,))
        return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

    @tf.function
    def distributed_test_step(dataset_inputs):
        """Run `test_step` on every replica; results are accumulated in the metrics."""
        return strategy.experimental_run_v2(test_step, args=(dataset_inputs,))

    for epoch in range(EPOCHS):
        total_loss = 0.0
        num_batches = 0

        # Training pass: accumulate the per-batch loss to report an epoch mean.
        for x in train_dist_dataset:
            total_loss += distributed_train_step(x)
            num_batches += 1
        train_loss = total_loss / num_batches

        # Evaluation pass: updates test_loss / test_accuracy as a side effect.
        for x in test_dist_dataset:
            distributed_test_step(x)

        # Advance the step counter that is saved as part of the checkpoint.
        checkpoint.step.assign_add(1)

        # Save a checkpoint on every other epoch (epochs 0, 2, 4, ...).
        if int(epoch)%2==0:
            save_path = manager.save()
            print("Saved checkpoint for step {}:{}".format(int(epoch),save_path))
        template = ("Epoch {}, Loss {}, Accuracy {}, Test Loss {}, Test Accuracy {}")
        print(template.format(epoch+1,
                          train_loss,
                          train_accuracy.result()*100,
                          test_loss.result(),
                          test_accuracy.result() * 100
                          ))

        # Reset every metric so each epoch is reported on its own. The original
        # cell referenced `test_accuracy.reset_states` without calling it, so
        # test accuracy silently accumulated across epochs.
        test_loss.reset_states()
        train_accuracy.reset_states()
        test_accuracy.reset_states()

Executing op GeneratorDataset in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op __inference_distributed_train_step_358389 in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op __inference_distributed_train_step_368860 in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op GeneratorDataset in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op __inference_distributed_test_step_369013 in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op __inference_distributed_test_step_370488 in device /job:localhost/replica:0/task:0/device:GPU:0
Saved checkpoint for step 0:./tf_ckpts/ckpt-2
Epoch 1, Loss 0.2678895592689514, Accuracy 87.17208862304688, Test Loss 0.3373919725418091, Test Accuracy 88.44000244140625
Executing op GeneratorDataset in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op GeneratorDataset in device /job:localhost/replica:0/task:0/device:GPU:0
Epoch 2, Loss 0.2432761937379837, Accuracy 91.001663

In [0]:
# Plain (non-distributed) objects used to evaluate a restored checkpoint
# outside of any strategy scope.
eval_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name="eval_accuracy")
new_model = create_model()
new_optimizer = tf.keras.optimizers.Adam()

# Ordinary batched test dataset (not passed through the strategy).
test_dataset = (
    tf.data.Dataset.from_tensor_slices((test_images, test_labels))
    .batch(GLOBAL_BATCH_SIZE)
)

In [0]:
@tf.function
def eval_step(images, labels):
  """Score one batch with `new_model` and fold the result into `eval_accuracy`."""
  class_probs = new_model(images, training=False)
  eval_accuracy(labels, class_probs)

In [68]:
# List the checkpoint paths the manager currently retains (at most max_to_keep=3).
print(f"Final Checkpoints: {manager.checkpoints}")

Final Checkpoints: ['./tf_ckpts/ckpt-4', './tf_ckpts/ckpt-5', './tf_ckpts/ckpt-6']


In [73]:
# Restore the latest saved checkpoint into the fresh, non-distributed
# model/optimizer pair and measure its accuracy on the test set.
checkpoint = tf.train.Checkpoint(step=tf.Variable(1),optimizer=new_optimizer, model=new_model)
# `manager.latest_checkpoint` is already a checkpoint path prefix (e.g.
# './tf_ckpts/ckpt-6'). Passing it through `tf.train.latest_checkpoint`, which
# expects a *directory*, returns None, and restoring None loads nothing.
checkpoint.restore(manager.latest_checkpoint)

# Start from a clean metric so re-running this cell does not mix in the
# results of an earlier evaluation.
eval_accuracy.reset_states()
for images, labels in test_dataset:
  eval_step(images, labels)

print ('Accuracy after restoring the saved model without strategy: {:.2f}'.format(
    eval_accuracy.result()*100))

Accuracy after restoring the saved model without strategy: 90.79


### Loading Mechanics

> TensorFlow provides various loading mechanics.
    
    * Delayed Restorations
    * Manually inspecting checkpoints
    * List and dictionary tracking

> TensorFlow matches variables to checkpointed values by traversing a directed graph with named edges, starting from the object being loaded. Edge names typically come from attribute names in objects, for example the "l1" in self.l1 = tf.keras.layers.Dense(5). tf.train.Checkpoint uses its keyword argument names, as in the "step" in tf.train.Checkpoint(step=...).

##### Manually inspecting Checkpoints

In [72]:
# List the (name, shape) pairs stored in the newest checkpoint under ./tf_ckpts/.
tf.train.list_variables(tf.train.latest_checkpoint('./tf_ckpts/'))

[('_CHECKPOINTABLE_OBJECT_GRAPH', []),
 ('model/layer-0/bias/.ATTRIBUTES/VARIABLE_VALUE', [32]),
 ('model/layer-0/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
  [32]),
 ('model/layer-0/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
  [32]),
 ('model/layer-0/kernel/.ATTRIBUTES/VARIABLE_VALUE', [3, 3, 1, 32]),
 ('model/layer-0/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
  [3, 3, 1, 32]),
 ('model/layer-0/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
  [3, 3, 1, 32]),
 ('model/layer-2/bias/.ATTRIBUTES/VARIABLE_VALUE', [64]),
 ('model/layer-2/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
  [64]),
 ('model/layer-2/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
  [64]),
 ('model/layer-2/kernel/.ATTRIBUTES/VARIABLE_VALUE', [3, 3, 32, 64]),
 ('model/layer-2/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
  [3, 3, 32, 64]),
 ('model/layer-2/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATT

### Alternate ways of iterating over a dataset

#### Using Iterators

> If we want to iterate over a given number of steps and not through the entire dataset, we can create an iterator using the `iter` call and explicitly call `next` on the iterator. <strong> We can choose to iterate over the dataset both inside and outside the tf.function</strong>. Here is a small snippet demonstrating iteration of the dataset outside the tf.function using an iterator.

In [76]:
with strategy.scope():
    for epoch in range(EPOCHS):
        # A new iterator is created at the start of every epoch.
        train_iter = iter(train_dist_dataset)
        total_loss = 0.0
        num_batches = 0

        # Pull exactly 10 batches by hand instead of exhausting the dataset.
        for _ in range(10):
            step_loss = distributed_train_step(next(train_iter))
            total_loss += step_loss
            num_batches += 1

        average_train_loss = total_loss / num_batches
        template = ("Epoch {}, Loss {}, Accuracy {}")
        print(template.format(
            (epoch+1),
            average_train_loss,
            train_accuracy.result()*100,
        ))
        # Reset so the next epoch's accuracy is reported on its own.
        train_accuracy.reset_states()

Executing op GeneratorDataset in device /job:localhost/replica:0/task:0/device:GPU:0
Epoch 1, Loss 0.08790183067321777, Accuracy 96.328125
Executing op GeneratorDataset in device /job:localhost/replica:0/task:0/device:GPU:0
Epoch 2, Loss 0.1114124059677124, Accuracy 95.78125
Executing op GeneratorDataset in device /job:localhost/replica:0/task:0/device:GPU:0
Epoch 3, Loss 0.1382780224084854, Accuracy 94.53125
Executing op GeneratorDataset in device /job:localhost/replica:0/task:0/device:GPU:0
Epoch 4, Loss 0.11071242392063141, Accuracy 95.78125
Executing op GeneratorDataset in device /job:localhost/replica:0/task:0/device:GPU:0
Epoch 5, Loss 0.11459411680698395, Accuracy 95.46875
Executing op GeneratorDataset in device /job:localhost/replica:0/task:0/device:GPU:0
Epoch 6, Loss 0.08854411542415619, Accuracy 97.1875
Executing op GeneratorDataset in device /job:localhost/replica:0/task:0/device:GPU:0
Epoch 7, Loss 0.1039050966501236, Accuracy 96.25
Executing op GeneratorDataset in device 

#### Iterating Inside a tf.function

> We can also iterate over the entire input train_dist_dataset inside a tf.function using the for x in ... construct or by creating iterators like we did above. The example below demonstrates wrapping one epoch of training in a tf.function and iterating over train_dist_dataset inside the function.

In [81]:
with strategy.scope():
    @tf.function
    def distributed_train_epoch(dataset):
        """Train for one full pass over `dataset` inside a single tf.function.

        Args:
          dataset: the distributed training dataset to iterate over.

        Returns:
          The mean of the per-batch global losses over the epoch.
        """
        total_loss = 0.0
        num_batches = 0
        for x in dataset:
            per_replica_losses = strategy.experimental_run_v2(train_step, args=(x,))
            # `compute_loss` already divides by GLOBAL_BATCH_SIZE, so the
            # per-replica losses must be summed, matching `distributed_train_step`.
            # Reducing with MEAN would shrink the reported loss by the number
            # of replicas.
            total_loss += strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
            num_batches += 1
        return total_loss / tf.cast(num_batches,tf.float32)

    for epoch in range(EPOCHS):
        train_loss = distributed_train_epoch(train_dist_dataset)
        template = ("Epoch {}, Loss {:.2f}, Accuracy {:.2f}")
        print(template.format((epoch+1), train_loss, train_accuracy.result()*100))
        # Reset so the next epoch's accuracy is reported on its own.
        train_accuracy.reset_states()

Executing op __inference_distributed_train_epoch_483235 in device /job:localhost/replica:0/task:0/device:GPU:0
Epoch 1, Loss 0.11, Accuracy 95.75
Epoch 2, Loss 0.10, Accuracy 96.17
Epoch 3, Loss 0.10, Accuracy 96.40
Epoch 4, Loss 0.09, Accuracy 96.78
Epoch 5, Loss 0.08, Accuracy 96.99
Epoch 6, Loss 0.07, Accuracy 97.11
Epoch 7, Loss 0.07, Accuracy 97.63
Epoch 8, Loss 0.06, Accuracy 97.73
Epoch 9, Loss 0.06, Accuracy 97.80
Epoch 10, Loss 0.05, Accuracy 98.00


#### Tracking training loss across replicas

> We do not recommend using tf.metrics.Mean to track the training loss across different replicas, because of the loss scaling computation that is carried out.

> For example, if you run a training job with the following characteristics:
>
> * Two replicas
> * Two samples are processed on each replica
> * Resulting loss values: [2, 3] and [4, 5] on each replica
> * Global batch size = 4