<a href="https://colab.research.google.com/github/visahan1/Tensorflow/blob/main/DsitributedTraining.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Custom training with tf.distribute.Strategy

We will implement a distribution strategy to train on the [Oxford Flowers 102](https://www.tensorflow.org/datasets/catalog/oxford_flowers102) dataset. As the name suggests, distribution strategies allow you to setup training across multiple devices. We are just using a single device in this lab but the syntax you'll apply should also work when you have a multi-device setup. Let's begin!

In [1]:
import tensorflow as tf
import tensorflow_hub as hub

# Helper libraries
import numpy as np
import os
from tqdm import tqdm

In [3]:
import tensorflow_datasets as tfds
tfds.disable_progress_bar()

In [4]:
splits = ['train[:80%]', 'train[80%:90%]', 'train[90%:]']

(train_examples, validation_examples, test_examples), info = tfds.load('oxford_flowers102', with_info=True, as_supervised=True, split = splits, data_dir='data/')

num_examples = info.splits['train'].num_examples
num_classes = info.features['label'].num_classes

[1mDownloading and preparing dataset oxford_flowers102/2.1.1 (download: 328.90 MiB, generated: 331.34 MiB, total: 660.25 MiB) to data/oxford_flowers102/2.1.1...[0m
Shuffling and writing examples to data/oxford_flowers102/2.1.1.incompleteNQ5AHV/oxford_flowers102-train.tfrecord
Shuffling and writing examples to data/oxford_flowers102/2.1.1.incompleteNQ5AHV/oxford_flowers102-test.tfrecord
Shuffling and writing examples to data/oxford_flowers102/2.1.1.incompleteNQ5AHV/oxford_flowers102-validation.tfrecord
[1mDataset oxford_flowers102 downloaded and prepared to data/oxford_flowers102/2.1.1. Subsequent calls will reuse this data.[0m


How does `tf.distribute.MirroredStrategy` strategy work?

*   All the variables and the model graph are replicated on the replicas.
*   Input is evenly distributed across the replicas.
*   Each replica calculates the loss and gradients for the input it received.
*   The gradients are synced across all the replicas by summing them.
*   After the sync, the same update is made to the copies of the variables on each replica.

In [5]:
# If the list of devices is not specified in the
# `tf.distribute.MirroredStrategy` constructor, it will be auto-detected.
strategy = tf.distribute.MirroredStrategy()





INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)


INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)


In [6]:
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

Number of devices: 1


## Setup input pipeline

Set some constants, including the buffer size, number of epochs, and the image size.

In [28]:
BUFFER_SIZE = num_examples
EPOCHS = 10
pixels = 224
MODULE_HANDLE = "https://tfhub.dev/google/imagenet/resnet_v2_50/feature_vector/5"
IMAGE_SIZE = (pixels, pixels)
print("Using {} with input size {}".format(MODULE_HANDLE, IMAGE_SIZE))

Using https://tfhub.dev/google/imagenet/resnet_v2_50/feature_vector/5 with input size (224, 224)


Define a function to format the image (resizes the image and scales the pixel values to range from [0,1].

In [8]:
def format_image(image, label):
    image = tf.image.resize(image, IMAGE_SIZE) / 255.0
    return  image, label

In [10]:
def set_global_batch_size(batch_size_per_replica, strategy):
    '''
    Args:
        batch_size_per_replica (int) - batch size per replica
        strategy (tf.distribute.Strategy) - distribution strategy
    '''
    global_batch_size = strategy.num_replicas_in_sync * batch_size_per_replica
    return global_batch_size

In [11]:
BATCH_SIZE_PER_REPLICA = 64
GLOBAL_BATCH_SIZE = set_global_batch_size(BATCH_SIZE_PER_REPLICA, strategy)

print(GLOBAL_BATCH_SIZE)

64


Create the datasets using the global batch size and distribute the batches for training, validation and test batches

In [12]:
train_batches = train_examples.shuffle(num_examples).map(format_image).batch(BATCH_SIZE_PER_REPLICA).prefetch(1)
validation_batches = validation_examples.map(format_image).batch(BATCH_SIZE_PER_REPLICA).prefetch(1)
test_batches = test_examples.map(format_image).batch(1)

## Define the distributed datasets 

---



In [13]:
def distribute_datasets(strategy, train_batches, validation_batches, test_batches):
    
    train_dist_dataset = strategy.experimental_distribute_dataset(train_batches)
    val_dist_dataset = strategy.experimental_distribute_dataset(validation_batches)
    test_dist_dataset = strategy.experimental_distribute_dataset(test_batches)
    
    return train_dist_dataset, val_dist_dataset, test_dist_dataset

In [14]:
train_dist_dataset, val_dist_dataset, test_dist_dataset = distribute_datasets(strategy, train_batches, validation_batches, test_batches)

In [15]:
print(type(train_dist_dataset))
print(type(val_dist_dataset))
print(type(test_dist_dataset))

<class 'tensorflow.python.distribute.input_lib.DistributedDataset'>
<class 'tensorflow.python.distribute.input_lib.DistributedDataset'>
<class 'tensorflow.python.distribute.input_lib.DistributedDataset'>


In [16]:
# Take a look at a single batch from the train_dist_dataset
x = iter(train_dist_dataset).get_next()
    
print(f"x is a tuple that contains {len(x)} values ")
print(f"x[0] contains the features, and has shape {x[0].shape}")
print(f"  so it has {x[0].shape[0]} examples in the batch, each is an image that is {x[0].shape[1:]}")
print(f"x[1] contains the labels, and has shape {x[1].shape}")

x is a tuple that contains 2 values 
x[0] contains the features, and has shape (64, 224, 224, 3)
  so it has 64 examples in the batch, each is an image that is (224, 224, 3)
x[1] contains the labels, and has shape (64,)


## Create the model

Use the Model Subclassing API to create model `ResNetModel` as a subclass of `tf.keras.Model`.

In [29]:
class ResNetModel(tf.keras.Model):
    def __init__(self, classes):
        super(ResNetModel, self).__init__()
        self._feature_extractor = hub.KerasLayer(MODULE_HANDLE,
                                                 trainable=False) 
        self._classifier = tf.keras.layers.Dense(classes, activation='softmax')

    def call(self, inputs):
        x = self._feature_extractor(inputs)
        x = self._classifier(x)
        return x

In [19]:
# Create a checkpoint directory to store the checkpoints.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

## Define the loss function

You'll define the `loss_object` and `compute_loss` within the `strategy.scope()`.
- `loss_object` will be used later to calculate the loss on the test set.
- `compute_loss` will be used later to calculate the average loss on the training data.

You will be using these two loss calculations later.

In [20]:
with strategy.scope():
    # Set reduction to `NONE` so we can do the reduction afterwards and divide by
    # global batch size.
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
        reduction=tf.keras.losses.Reduction.NONE)
    # or loss_fn = tf.keras.losses.sparse_categorical_crossentropy
    def compute_loss(labels, predictions):
        per_example_loss = loss_object(labels, predictions)
        return tf.nn.compute_average_loss(per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)

    test_loss = tf.keras.metrics.Mean(name='test_loss')

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


## Define the metrics to track loss and accuracy

These metrics track the test loss and training and test accuracy. 
- You can use `.result()` to get the accumulated statistics at any time, for example, `train_accuracy.result()`.

In [21]:
with strategy.scope():
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        name='train_accuracy')
    test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        name='test_accuracy')

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


In [30]:
# model and optimizer must be created under `strategy.scope`.
with strategy.scope():
    model = ResNetModel(classes=num_classes)
    optimizer = tf.keras.optimizers.Adam()
    checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)

## Training loop
You will define a regular training step and test step, which could work without a distributed strategy.  You can then use `strategy.run` to apply these functions in a distributed manner.
- Notice that you'll define `train_step` and `test_step` inside another function `train_testp_step_fns`, which will then return these two functions.

### Define train_step
Within the strategy's scope, define `train_step(inputs)`
- `inputs` will be a tuple containing `(images, labels)`.
- Create a gradient tape block.
- Within the gradient tape block: 
  - Call the model, passing in the images and setting training to be `True` (complete this part).
  - Call the `compute_loss` function (defined earlier) to compute the training loss (complete this part).
  - Use the gradient tape to calculate the gradients.
  - Use the optimizer to update the weights using the gradients.
  
### Define test_step
Also within the strategy's scope, define `test_step(inputs)`
- `inputs` is a tuple containing `(images, labels)`.
  - Call the model, passing in the images and set training to `False`, because the model is not going to train on the test data. (complete this part).
  - Use the `loss_object`, which will compute the test loss.  Check `compute_loss`, defined earlier, to see what parameters to pass into `loss_object`. (complete this part).
  - Next, update `test_loss` (the running test loss) with the `t_loss` (the loss for the current batch).
  - Also update the `test_accuracy`.

In [54]:
def train_test_step_fns(strategy, model, compute_loss, optimizer, train_accuracy, loss_object, test_loss, test_accuracy):
    with strategy.scope():
        def train_step(inputs):
            images, labels = inputs

            with tf.GradientTape() as tape:
                predictions = model(images,training=True)
                loss = compute_loss(labels,predictions)
            gradients = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(gradients, model.trainable_variables))

            train_accuracy.update_state(labels, predictions)
            return loss 

        def test_step(inputs):
            images, labels = inputs
            predictions = model(images)
            t_loss = compute_loss(labels,predictions)
            test_loss.update_state(t_loss)
            test_accuracy.update_state(labels, predictions)
        
        return train_step, test_step

In [55]:
train_step, test_step = train_test_step_fns(strategy, model, compute_loss, optimizer, train_accuracy, loss_object, test_loss, test_accuracy)

In [56]:
def distributed_train_test_step_fns(strategy, train_step, test_step, model, compute_loss, optimizer, train_accuracy, loss_object, test_loss, test_accuracy):
    with strategy.scope():
        @tf.function
        def distributed_train_step(dataset_inputs):
            per_replica_losses = strategy.run(train_step,args=(dataset_inputs,))
            return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
                                   axis=None) 

        @tf.function
        def distributed_test_step(dataset_inputs):
            return strategy.run(test_step,args=(dataset_inputs,))

        return distributed_train_step, distributed_test_step

Call the function that you just defined to get the distributed train step function and distributed test step function.

In [57]:
distributed_train_step, distributed_test_step = distributed_train_test_step_fns(strategy, train_step, test_step, model, compute_loss, optimizer, train_accuracy, loss_object, test_loss, test_accuracy)

In [58]:
# Running this cell in Coursera takes around 20 mins
with strategy.scope():
    for epoch in range(EPOCHS):
        # TRAIN LOOP
        total_loss = 0.0
        num_batches = 0
        for x in tqdm(train_dist_dataset):
            total_loss += distributed_train_step(x)
            num_batches += 1
        train_loss = total_loss / num_batches

        # TEST LOOP
        for x in test_dist_dataset:
            distributed_test_step(x)

        template = ("Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, "
                    "Test Accuracy: {}")
        print (template.format(epoch+1, train_loss,
                               train_accuracy.result()*100, test_loss.result(),
                               test_accuracy.result()*100))

        test_loss.reset_states()
        train_accuracy.reset_states()
        test_accuracy.reset_states()



0it [00:00, ?it/s][A[A

1it [00:03,  3.09s/it][A[A

2it [00:03,  2.21s/it][A[A

3it [00:03,  1.59s/it][A[A

4it [00:03,  1.16s/it][A[A

5it [00:03,  1.17it/s][A[A

6it [00:03,  1.55it/s][A[A

7it [00:04,  2.02it/s][A[A

8it [00:04,  2.56it/s][A[A

9it [00:04,  3.13it/s][A[A

10it [00:04,  3.73it/s][A[A

11it [00:04,  4.31it/s][A[A

12it [00:04,  4.85it/s][A[A

13it [00:10,  1.26it/s]


0it [00:00, ?it/s][A[A

Epoch 1, Loss: 0.7265761494636536, Accuracy: 62.5, Test Loss: 0.031014518812298775, Test Accuracy: 57.843135833740234




1it [00:01,  1.04s/it][A[A

2it [00:01,  1.30it/s][A[A

3it [00:01,  1.71it/s][A[A

4it [00:01,  2.20it/s][A[A

5it [00:01,  2.76it/s][A[A

6it [00:01,  3.35it/s][A[A

7it [00:01,  3.92it/s][A[A

8it [00:02,  4.49it/s][A[A

9it [00:02,  4.97it/s][A[A

10it [00:02,  5.39it/s][A[A

11it [00:02,  5.74it/s][A[A

12it [00:02,  6.02it/s][A[A

13it [00:02,  4.63it/s]


0it [00:00, ?it/s][A[A

Epoch 2, Loss: 0.6489842534065247, Accuracy: 95.5882339477539, Test Loss: 0.030527425929903984, Test Accuracy: 56.86274719238281




1it [00:01,  1.03s/it][A[A

2it [00:01,  1.31it/s][A[A

3it [00:01,  1.72it/s][A[A

4it [00:01,  2.22it/s][A[A

5it [00:01,  2.77it/s][A[A

6it [00:01,  3.35it/s][A[A

7it [00:01,  3.93it/s][A[A

8it [00:02,  4.50it/s][A[A

9it [00:02,  4.98it/s][A[A

10it [00:02,  5.40it/s][A[A

11it [00:02,  5.70it/s][A[A

12it [00:02,  6.00it/s][A[A

13it [00:02,  4.63it/s]


0it [00:00, ?it/s][A[A

Epoch 3, Loss: 0.6279276609420776, Accuracy: 95.71078491210938, Test Loss: 0.030340885743498802, Test Accuracy: 56.86274719238281




1it [00:01,  1.03s/it][A[A

2it [00:01,  1.30it/s][A[A

3it [00:01,  1.72it/s][A[A

4it [00:01,  2.21it/s][A[A

5it [00:01,  2.76it/s][A[A

6it [00:01,  3.34it/s][A[A

7it [00:01,  3.93it/s][A[A

8it [00:02,  4.47it/s][A[A

9it [00:02,  4.93it/s][A[A

10it [00:02,  5.38it/s][A[A

11it [00:02,  5.72it/s][A[A

12it [00:02,  6.01it/s][A[A

13it [00:02,  4.62it/s]


0it [00:00, ?it/s][A[A

Epoch 4, Loss: 0.6181926131248474, Accuracy: 95.83332824707031, Test Loss: 0.03023962676525116, Test Accuracy: 57.843135833740234




1it [00:01,  1.04s/it][A[A

2it [00:01,  1.29it/s][A[A

3it [00:01,  1.70it/s][A[A

4it [00:01,  2.20it/s][A[A

5it [00:01,  2.75it/s][A[A

6it [00:01,  3.32it/s][A[A

7it [00:01,  3.91it/s][A[A

8it [00:02,  4.48it/s][A[A

9it [00:02,  5.01it/s][A[A

10it [00:02,  5.43it/s][A[A

11it [00:02,  5.77it/s][A[A

12it [00:02,  6.08it/s][A[A

13it [00:02,  4.63it/s]


0it [00:00, ?it/s][A[A

Epoch 5, Loss: 0.611754298210144, Accuracy: 95.95588684082031, Test Loss: 0.030147826299071312, Test Accuracy: 57.843135833740234




1it [00:01,  1.04s/it][A[A

2it [00:01,  1.29it/s][A[A

3it [00:01,  1.70it/s][A[A

4it [00:01,  2.19it/s][A[A

5it [00:01,  2.74it/s][A[A

6it [00:01,  3.31it/s][A[A

7it [00:01,  3.89it/s][A[A

8it [00:02,  4.44it/s][A[A

9it [00:02,  4.92it/s][A[A

10it [00:02,  5.34it/s][A[A

11it [00:02,  5.65it/s][A[A

12it [00:02,  5.95it/s][A[A

13it [00:02,  4.58it/s]


0it [00:00, ?it/s][A[A

Epoch 6, Loss: 0.6059982776641846, Accuracy: 96.20098114013672, Test Loss: 0.030070481821894646, Test Accuracy: 57.843135833740234




1it [00:01,  1.03s/it][A[A

2it [00:01,  1.30it/s][A[A

3it [00:01,  1.71it/s][A[A

4it [00:01,  2.21it/s][A[A

5it [00:01,  2.76it/s][A[A

6it [00:01,  3.32it/s][A[A

7it [00:01,  3.90it/s][A[A

8it [00:02,  4.47it/s][A[A

9it [00:02,  4.94it/s][A[A

10it [00:02,  5.39it/s][A[A

11it [00:02,  5.75it/s][A[A

12it [00:02,  6.04it/s][A[A

13it [00:05,  2.52it/s]


0it [00:00, ?it/s][A[A

Epoch 7, Loss: 0.6003167629241943, Accuracy: 96.32353210449219, Test Loss: 0.029993543401360512, Test Accuracy: 57.843135833740234




1it [00:01,  1.04s/it][A[A

2it [00:01,  1.29it/s][A[A

3it [00:01,  1.70it/s][A[A

4it [00:01,  2.19it/s][A[A

5it [00:01,  2.74it/s][A[A

6it [00:01,  3.30it/s][A[A

7it [00:01,  3.88it/s][A[A

8it [00:02,  4.44it/s][A[A

9it [00:02,  4.90it/s][A[A

10it [00:02,  5.32it/s][A[A

11it [00:02,  5.68it/s][A[A

12it [00:02,  5.96it/s][A[A

13it [00:02,  4.57it/s]


0it [00:00, ?it/s][A[A

Epoch 8, Loss: 0.5949960947036743, Accuracy: 96.32353210449219, Test Loss: 0.029914744198322296, Test Accuracy: 57.843135833740234




1it [00:01,  1.05s/it][A[A

2it [00:01,  1.28it/s][A[A

3it [00:01,  1.69it/s][A[A

4it [00:01,  2.18it/s][A[A

5it [00:01,  2.73it/s][A[A

6it [00:01,  3.29it/s][A[A

7it [00:01,  3.87it/s][A[A

8it [00:02,  4.44it/s][A[A

9it [00:02,  4.95it/s][A[A

10it [00:02,  5.38it/s][A[A

11it [00:02,  5.72it/s][A[A

12it [00:02,  6.02it/s][A[A

13it [00:02,  4.59it/s]


0it [00:00, ?it/s][A[A

Epoch 9, Loss: 0.5895390510559082, Accuracy: 96.5686264038086, Test Loss: 0.029833536595106125, Test Accuracy: 57.843135833740234




1it [00:01,  1.05s/it][A[A

2it [00:01,  1.28it/s][A[A

3it [00:01,  1.68it/s][A[A

4it [00:01,  2.16it/s][A[A

5it [00:01,  2.70it/s][A[A

6it [00:01,  3.27it/s][A[A

7it [00:01,  3.82it/s][A[A

8it [00:02,  4.39it/s][A[A

9it [00:02,  4.87it/s][A[A

10it [00:02,  5.32it/s][A[A

11it [00:02,  5.66it/s][A[A

12it [00:02,  5.92it/s][A[A

13it [00:02,  4.53it/s]


Epoch 10, Loss: 0.5841795206069946, Accuracy: 96.5686264038086, Test Loss: 0.02976243570446968, Test Accuracy: 57.843135833740234
