# Week 4 Assignment: Custom training with tf.distribute.Strategy

## Imports

In [39]:
from __future__ import absolute_import, division, print_function, unicode_literals

import tensorflow as tf
import tensorflow_hub as hub

# Helper libraries
import numpy as np
import os
from tqdm import tqdm

## Download the dataset

In [40]:
import tensorflow_datasets as tfds
tfds.disable_progress_bar()

In [41]:
splits = ['train[:80%]', 'train[80%:90%]', 'train[90%:]']

(train_examples, validation_examples, test_examples), info = tfds.load('oxford_flowers102', with_info=True, as_supervised=True, split = splits, data_dir='data/')

num_examples = info.splits['train'].num_examples
num_classes = info.features['label'].num_classes

## Create a strategy to distribute the variables and the graph

In [42]:
# If the list of devices is not specified in the
# `tf.distribute.MirroredStrategy` constructor, it will be auto-detected.
strategy = tf.distribute.MirroredStrategy()

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)


INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)


In [43]:
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

Number of devices: 1


## Setup input pipeline

In [44]:
BUFFER_SIZE = num_examples
EPOCHS = 10
pixels = 224
MODULE_HANDLE='https://tfhub.dev/tensorflow/resnet_50/feature_vector/1'
IMAGE_SIZE = (pixels, pixels)
print("Using {} with input size {}".format(MODULE_HANDLE, IMAGE_SIZE))

Using https://tfhub.dev/tensorflow/resnet_50/feature_vector/1 with input size (224, 224)


In [45]:
def format_image(image, label):
    image = tf.image.resize(image, IMAGE_SIZE) / 255.0
    return  image, label

In [46]:
# GRADED FUNCTION
def set_global_batch_size(batch_size_per_replica, strategy):
    '''
    Args:
        batch_size_per_replica (int) - batch size per replica
        strategy (tf.distribute.Strategy) - distribution strategy
    '''
    
    # set the global batch size
    ### START CODE HERE ###
    global_batch_size = batch_size_per_replica * strategy.num_replicas_in_sync
    ### END CODD HERE ###
    
    return global_batch_size

In [47]:
BATCH_SIZE_PER_REPLICA = 64
GLOBAL_BATCH_SIZE = set_global_batch_size(BATCH_SIZE_PER_REPLICA, strategy)

print(GLOBAL_BATCH_SIZE)

64


**Expected Output:**
```
64
```

In [48]:
train_batches = train_examples.shuffle(num_examples // 4).map(format_image).batch(BATCH_SIZE_PER_REPLICA).prefetch(1)
validation_batches = validation_examples.map(format_image).batch(BATCH_SIZE_PER_REPLICA).prefetch(1)
test_batches = test_examples.map(format_image).batch(1)

## Define the distributed datasets

In [49]:
# GRADED FUNCTION
def distribute_datasets(strategy, train_batches, validation_batches, test_batches):
    
    ### START CODE HERE ###
    train_dist_dataset = strategy.experimental_distribute_dataset(train_batches)
    val_dist_dataset = strategy.experimental_distribute_dataset(validation_batches)
    test_dist_dataset = strategy.experimental_distribute_dataset(test_batches)
    ### END CODE HERE ###
    
    return train_dist_dataset, val_dist_dataset, test_dist_dataset

In [50]:
train_dist_dataset, val_dist_dataset, test_dist_dataset = distribute_datasets(strategy, train_batches, validation_batches, test_batches)

In [51]:
print(type(train_dist_dataset))
print(type(val_dist_dataset))
print(type(test_dist_dataset))

<class 'tensorflow.python.distribute.input_lib.DistributedDataset'>
<class 'tensorflow.python.distribute.input_lib.DistributedDataset'>
<class 'tensorflow.python.distribute.input_lib.DistributedDataset'>


**Expected Output:**
```
<class 'tensorflow.python.distribute.input_lib.DistributedDataset'>
<class 'tensorflow.python.distribute.input_lib.DistributedDataset'>
<class 'tensorflow.python.distribute.input_lib.DistributedDataset'>
```

In [52]:
# Take a look at a single batch from the train_dist_dataset
for x in train_dist_dataset:
    # do nothing, just get one batch x
    break
    
print(f"x is a tuple that contains {len(x)} values ")
print(f"x[0] contains the features, and has shape {x[0].shape}")
print(f"  so it has {x[0].shape[0]} examples in the batch, each is an image that is {x[0].shape[1:]}")
print(f"x[1] contains the labels, and has shape {x[1].shape}")

x is a tuple that contains 2 values 
x[0] contains the features, and has shape (64, 224, 224, 3)
  so it has 64 examples in the batch, each is an image that is (224, 224, 3)
x[1] contains the labels, and has shape (64,)


## Create the model

In [53]:
class ResNetModel(tf.keras.Model):
    def __init__(self, classes):
        super(ResNetModel, self).__init__()
        self._feature_extractor = hub.KerasLayer(MODULE_HANDLE,
                                                 trainable=False) 
        self._classifier = tf.keras.layers.Dense(classes, activation='softmax')

    def call(self, inputs):
        x = self._feature_extractor(inputs)
        x = self._classifier(x)
        return x

In [54]:
# Create a checkpoint directory to store the checkpoints.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

## Define the loss function

In [55]:
with strategy.scope():
    # Set reduction to `NONE` so we can do the reduction afterwards and divide by
    # global batch size.
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
        reduction=tf.keras.losses.Reduction.NONE)
    # or loss_fn = tf.keras.losses.sparse_categorical_crossentropy
    def compute_loss(labels, predictions):
        per_example_loss = loss_object(labels, predictions)
        return tf.nn.compute_average_loss(per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)

    test_loss = tf.keras.metrics.Mean(name='test_loss')

## Define the metrics to track loss and accuracy

In [56]:
with strategy.scope():
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        name='train_accuracy')
    test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        name='test_accuracy')

## Instantiate the model, optimizer, and checkpoints

In [57]:
# model and optimizer must be created under `strategy.scope`.
with strategy.scope():
    model = ResNetModel(classes=num_classes)
    optimizer = tf.keras.optimizers.Adam()
    checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)

## Training loop (please complete this section)

You will define a regular training step and test step, which could work without a distributed strategy.  You can then use `strategy.run` to apply these functions in a distributed manner.
- Notice that you'll define `train_step` and `test_step` inside another function `train_testp_step_fns`, which will then return these two functions.

### Define train_step
Within the strategy's scope, define `train_step(inputs)`
- `inputs` will be a tuple containing `(images, labels)`.
- Create a gradient tape block.
- Within the gradient tape block: 
  - Call the model, passing in the images and setting training to be `True` (complete this part).
  - Call the `compute_loss` function (defined earlier) to compute the training loss (complete this part).
  - Use the gradient tape to calculate the gradients.
  - Use the optimizer to update the weights using the gradients.
  
### Define test_step
Also within the strategy's scope, define `test_step(inputs)`
- `inputs` is a tuple containing `(images, labels)`.
  - Call the model, passing in the images and set training to `False`, because the model is not going to train on the test data. (complete this part).
  - Use the `loss_object`, which will compute the test loss.  Check `compute_loss`, defined earlier, to see what parameters to pass into `loss_object`. (complete this part).
  - Next, update `test_loss` (the running test loss) with the `t_loss` (the loss for the current batch).
  - Also update the `test_accuracy`.

In [58]:
# GRADED FUNCTION
def train_test_step_fns(strategy, model, compute_loss, optimizer, train_accuracy, loss_object, test_loss, test_accuracy):
    with strategy.scope():
        def train_step(inputs):
            images, labels = inputs

            with tf.GradientTape() as tape:
                ### START CODE HERE ###
                predictions = model(images, training= True)
                loss = compute_loss(labels, predictions)
                ### END CODE HERE ###

            gradients = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(gradients, model.trainable_variables))

            train_accuracy.update_state(labels, predictions)
            return loss 

        def test_step(inputs):
            images, labels = inputs
            
            ### START CODE HERE ###
            predictions = model(images, training= False)
            t_loss = compute_loss(labels, predictions)
            ### END CODE HERE ###

            test_loss.update_state(t_loss)
            test_accuracy.update_state(labels, predictions)
        
        return train_step, test_step

In [59]:
train_step, test_step = train_test_step_fns(strategy, model, compute_loss, optimizer, train_accuracy, loss_object, test_loss, test_accuracy)

## Distributed training and testing (please complete this section)

The `train_step` and `test_step` could be used in a non-distributed, regular model training.  To apply them in a distributed way, you'll use [strategy.run](https://www.tensorflow.org/api_docs/python/tf/distribute/Strategy#run).

`distributed_train_step`
- Call the `run` function of the `strategy`, passing in the train step function (which you defined earlier), as well as the arguments that go in the train step function.
- The run function is defined like this `run(fn, args=() )`.  
  - `args` will take in the dataset inputs

`distributed_test_step`
- Similar to training, the distributed test step will use the `run` function of your strategy, taking in the test step function as well as the dataset inputs that go into the test step function.

#### Hint:
- You saw earlier that each batch in `train_dist_dataset` is tuple with two values:
  - a batch of features
  - a batch of labels.

Let's think about how you'll want to pass in the dataset inputs into `args` by running this next cell of code:

In [60]:
#See various ways of passing in the inputs 

def fun1(args=()):
    print(f"number of arguments passed is {len(args)}")
    
    
list_of_inputs = [1,2]
print("When passing in args=list_of_inputs:")
fun1(args=list_of_inputs)
print()
print("When passing in args=(list_of_inputs)")
fun1(args=(list_of_inputs))
print()
print("When passing in args=(list_of_inputs,)")
fun1(args=(list_of_inputs,))

When passing in args=list_of_inputs:
number of arguments passed is 2

When passing in args=(list_of_inputs)
number of arguments passed is 2

When passing in args=(list_of_inputs,)
number of arguments passed is 1


Notice that depending on how `list_of_inputs` is passed to `args` affects whether `fun1` sees one or two positional arguments.  
- If you see an error message about positional arguments when running the training code later, please come back to check how you're passing in the inputs to `run`.

Please complete the following function.

In [61]:
def distributed_train_test_step_fns(strategy, train_step, test_step, model, compute_loss, optimizer, train_accuracy, loss_object, test_loss, test_accuracy):
    with strategy.scope():
        @tf.function
        def distributed_train_step(dataset_inputs):
            ### START CODE HERE ###
            per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
            ### END CODE HERE ###
            return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
                                   axis=None)

        @tf.function
        def distributed_test_step(dataset_inputs):
            ### START CODE HERE ###
            return strategy.run(test_step, args=(dataset_inputs,))
            ### END CODE HERE ###
    
        return distributed_train_step, distributed_test_step

In [62]:
distributed_train_step, distributed_test_step = distributed_train_test_step_fns(strategy, train_step, test_step, model, compute_loss, optimizer, train_accuracy, loss_object, test_loss, test_accuracy)

In [64]:
# Running this cell in Coursera takes around 20 mins
with strategy.scope():
    for epoch in range(EPOCHS):
        # TRAIN LOOP
        total_loss = 0.0
        num_batches = 0
        for x in tqdm(train_dist_dataset):
            total_loss += distributed_train_step(x)
            num_batches += 1
        train_loss = total_loss / num_batches

        # TEST LOOP
        for x in test_dist_dataset:
            distributed_test_step(x)

        template = ("Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, "
                    "Test Accuracy: {}")
        print (template.format(epoch+1, train_loss,
                               train_accuracy.result()*100, test_loss.result(),
                               test_accuracy.result()*100))

        test_loss.reset_state()
        train_accuracy.reset_state()
        test_accuracy.reset_state()

13it [00:16,  1.29s/it]


Epoch 1, Loss: 2.387317657470703, Accuracy: 31.617647171020508, Test Loss: 0.04996610805392265, Test Accuracy: 31.86274528503418


13it [00:17,  1.31s/it]


Epoch 2, Loss: 1.3021883964538574, Accuracy: 87.25489807128906, Test Loss: 0.03244062140583992, Test Accuracy: 54.90196228027344


13it [00:16,  1.30s/it]


Epoch 3, Loss: 0.7629187703132629, Accuracy: 94.85294342041016, Test Loss: 0.027234315872192383, Test Accuracy: 65.68627166748047


13it [00:16,  1.27s/it]


Epoch 4, Loss: 0.4917829632759094, Accuracy: 96.81372833251953, Test Loss: 0.025203092023730278, Test Accuracy: 64.70588684082031


13it [00:16,  1.31s/it]


Epoch 5, Loss: 0.34504762291908264, Accuracy: 98.52941131591797, Test Loss: 0.02263624221086502, Test Accuracy: 69.60784149169922


13it [00:16,  1.29s/it]


Epoch 6, Loss: 0.2561875581741333, Accuracy: 99.63235473632812, Test Loss: 0.02171161212027073, Test Accuracy: 68.62745666503906


13it [00:16,  1.25s/it]


Epoch 7, Loss: 0.20192688703536987, Accuracy: 99.63235473632812, Test Loss: 0.020747823640704155, Test Accuracy: 68.62745666503906


13it [00:16,  1.26s/it]


Epoch 8, Loss: 0.16083525121212006, Accuracy: 99.87745666503906, Test Loss: 0.0201865267008543, Test Accuracy: 67.64705657958984


13it [00:16,  1.28s/it]


Epoch 9, Loss: 0.1339055597782135, Accuracy: 99.87745666503906, Test Loss: 0.019844889640808105, Test Accuracy: 67.64705657958984


13it [00:16,  1.29s/it]


Epoch 10, Loss: 0.11087200045585632, Accuracy: 99.87745666503906, Test Loss: 0.01928929053246975, Test Accuracy: 67.64705657958984
