##### Copyright 2019 The TensorFlow Authors.


In [1]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Distributed training with Keras

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://www.tensorflow.org/tutorials/distribute/keras"><img src="https://www.tensorflow.org/images/tf_logo_32px.png" />View on TensorFlow.org</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/distribute/keras.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/tensorflow/docs/blob/master/site/en/tutorials/distribute/keras.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
  <td>
    <a href="https://storage.googleapis.com/tensorflow_docs/docs/site/en/tutorials/distribute/keras.ipynb"><img src="https://www.tensorflow.org/images/download_logo_32px.png" />Download notebook</a>
  </td>
</table>

## Overview

The `tf.distribute.Strategy` API provides an abstraction for distributing your training across multiple processing units. It allows you to carry out distributed training using existing models and training code with minimal changes.

This tutorial demonstrates how to use the `tf.distribute.MirroredStrategy` to perform in-graph replication with _synchronous training on many GPUs on one machine_. The strategy essentially copies all of the model's variables to each processor. Then, it uses [all-reduce](http://mpitutorial.com/tutorials/mpi-reduce-and-allreduce/) to combine the gradients from all processors, and applies the combined value to all copies of the model.
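The all-reduce step described above can be illustrated without any accelerator. The following is a conceptual sketch only: plain Python lists stand in for per-replica gradient tensors, and the helper name `all_reduce_sum`, the gradient values, and the learning rate are hypothetical. It does not show how `MirroredStrategy` implements the operation internally.

```python
# Conceptual sketch of synchronous data-parallel training with all-reduce.
# Plain Python lists stand in for per-replica gradient tensors.

def all_reduce_sum(per_replica_values):
    """Sum values element-wise across replicas and hand every replica the result."""
    reduced = [sum(column) for column in zip(*per_replica_values)]
    # Every replica receives its own identical copy of the reduced value.
    return [list(reduced) for _ in per_replica_values]

# Hypothetical gradients computed by two replicas on different slices of a batch.
replica_gradients = [
    [0.5, -1.0, 2.0],   # replica 0
    [1.5, 3.0, -2.0],   # replica 1
]

combined = all_reduce_sum(replica_gradients)

# Each replica starts from the same weights and applies the same combined
# gradient, so the mirrored copies stay identical after the update.
learning_rate = 0.1
weights = [[1.0, 1.0, 1.0] for _ in replica_gradients]
updated = [
    [w - learning_rate * g for w, g in zip(replica_weights, replica_grads)]
    for replica_weights, replica_grads in zip(weights, combined)
]

print(combined)
print(updated)
```

Because every replica applies the same combined value to the same starting weights, the copies of the model never drift apart, which is what makes the training synchronous.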

You will use the `tf.keras` APIs to build the model and `Model.fit` for training it. (To learn about distributed training with a custom training loop and the `MirroredStrategy`, check out [this tutorial](custom_training.ipynb).)

`MirroredStrategy` trains your model on multiple GPUs on a single machine. For _synchronous training on many GPUs on multiple workers_, use the `tf.distribute.MultiWorkerMirroredStrategy` [with the Keras Model.fit](multi_worker_with_keras.ipynb) or [a custom training loop](multi_worker_with_ctl.ipynb). To learn about the other available strategies, refer to the [Distributed training with TensorFlow](../../guide/distributed_training.ipynb) guide.

## Setup

In [2]:
import tensorflow_datasets as tfds
import tensorflow as tf

import os

# Load the TensorBoard notebook extension.
%load_ext tensorboard

2021-08-04 01:24:55.165631: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0


In [3]:
print(tf.__version__)

2.5.0


## Download the dataset

Load the MNIST dataset from [TensorFlow Datasets](https://www.tensorflow.org/datasets). This returns a dataset in the `tf.data` format.

Setting the `with_info` argument to `True` includes the metadata for the entire dataset, which is saved here to `info`. Among other things, this metadata object includes the number of train and test examples.

In [4]:
datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)

mnist_train, mnist_test = datasets['train'], datasets['test']

2021-08-04 01:25:00.048530: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcuda.so.1
2021-08-04 01:25:00.691099: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-08-04 01:25:00.691993: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1733] Found device 0 with properties: 
pciBusID: 0000:00:05.0 name: Tesla V100-SXM2-16GB computeCapability: 7.0
coreClock: 1.53GHz coreCount: 80 deviceMemorySize: 15.78GiB deviceMemoryBandwidth: 836.37GiB/s
2021-08-04 01:25:00.692033: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0
2021-08-04 01:25:00.695439: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcublas.so.11
2021-08-04 01:25:00.695536: I tensorflow/stream_executor/platfo

## Define the distribution strategy

Create a `MirroredStrategy` object. This will handle distribution and provide a context manager (`MirroredStrategy.scope`) to build your model inside.

In [5]:
strategy = tf.distribute.MirroredStrategy()





INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)




In [6]:
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

Number of devices: 1


## Set up the input pipeline

When training a model with multiple GPUs, you can use the extra computing power effectively by increasing the batch size. In general, use the largest batch size that fits the GPU memory and tune the learning rate accordingly.

In [7]:
# You can also do info.splits.total_num_examples to get the total
# number of examples in the dataset.

num_train_examples = info.splits['train'].num_examples
num_test_examples = info.splits['test'].num_examples

BUFFER_SIZE = 10000

BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
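To get a feel for how the global batch size grows with the number of replicas, the arithmetic can be sketched in plain Python. The per-replica batch size and the 60,000-example MNIST training split come from this tutorial. The linear learning-rate scaling and the names `BASE_LEARNING_RATE` and `global_batch_config` are assumptions added for illustration: linear scaling is a common starting heuristic, not something this tutorial prescribes, and the rate should still be tuned.

```python
import math

# Values taken from this tutorial.
BATCH_SIZE_PER_REPLICA = 64
NUM_TRAIN_EXAMPLES = 60000  # size of the MNIST training split

# Assumption for illustration: scale the learning rate linearly with the
# number of replicas. Treat this only as a starting point for tuning.
BASE_LEARNING_RATE = 1e-3


def global_batch_config(num_replicas):
    """Return (global batch size, steps per epoch, scaled learning rate)."""
    global_batch_size = BATCH_SIZE_PER_REPLICA * num_replicas
    # The last batch of an epoch may be smaller, hence the ceiling.
    steps_per_epoch = math.ceil(NUM_TRAIN_EXAMPLES / global_batch_size)
    scaled_learning_rate = BASE_LEARNING_RATE * num_replicas
    return global_batch_size, steps_per_epoch, scaled_learning_rate


for num_replicas in (1, 2, 4, 8):
    print(num_replicas, global_batch_config(num_replicas))
```

With a single replica this gives a global batch size of 64 and 938 steps per epoch, which matches the step count shown in the training progress output of this notebook.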

Define a function that normalizes the image pixel values from the `[0, 255]` range to the  `[0, 1]` range ([feature scaling](https://en.wikipedia.org/wiki/Feature_scaling)):

In [8]:
def scale(image, label):
  image = tf.cast(image, tf.float32)
  image /= 255

  return image, label

Apply this `scale` function to the training and test data, and then use the `tf.data.Dataset` APIs to shuffle the training data (`Dataset.shuffle`), and batch it (`Dataset.batch`). Notice that you are also keeping an in-memory cache of the training data to improve performance (`Dataset.cache`).

In [9]:
train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

## Create the model

Create and compile the Keras model in the context of `Strategy.scope`:

In [10]:
with strategy.scope():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10)
  ])

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=['accuracy'])

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).




## Define the callbacks


Define the following `tf.keras.callbacks`:

- `tf.keras.callbacks.TensorBoard`: writes a log for TensorBoard, which allows you to visualize the graphs.
- `tf.keras.callbacks.ModelCheckpoint`: saves the model at a certain frequency, such as after every epoch.
- `tf.keras.callbacks.LearningRateScheduler`: schedules the learning rate to change after, for example, every epoch/batch.

For illustrative purposes, add a custom callback called `PrintLR` to display the *learning rate* in the notebook.

In [11]:
# Define the checkpoint directory to store the checkpoints.
checkpoint_dir = './training_checkpoints'
# Define the name of the checkpoint files.
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")
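The `{epoch}` placeholder in `checkpoint_prefix` is a standard Python format field that the checkpoint callback fills in when it saves. The sketch below shows the resulting file name pattern using only `str.format`. The helper `checkpoint_path_for` is hypothetical, and the 1-based epoch numbering is an assumption about how the callback numbers epochs; check the `ModelCheckpoint` documentation for your TensorFlow version.

```python
import os

checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")


def checkpoint_path_for(epoch_index):
    """Hypothetical helper: expand the template for a 0-based epoch index.

    Assumption: the callback substitutes a 1-based epoch number.
    """
    return checkpoint_prefix.format(epoch=epoch_index + 1)


paths = [checkpoint_path_for(i) for i in range(3)]
print(paths)
```

Keeping the epoch number in the name means each epoch writes to a separate checkpoint instead of overwriting the previous one.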

In [12]:
# Define a function for decaying the learning rate.
# You can define any decay function you need.
def decay(epoch):
  if epoch < 3:
    return 1e-3
  elif epoch < 7:
    return 1e-4
  else:
    return 1e-5
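To check what a schedule like this produces before training, you can evaluate it for every epoch index. The sketch below repeats a function equivalent to the tutorial's `decay` so that it runs on its own. `LearningRateScheduler` passes a 0-based epoch index, so index 0 corresponds to the epoch reported as "Epoch 1" during training. The `EPOCHS` value of 12 matches the one used later in this tutorial.

```python
def decay(epoch):
    """Step-wise learning-rate schedule, equivalent to the tutorial's decay."""
    if epoch < 3:
        return 1e-3
    elif epoch < 7:
        return 1e-4
    else:
        return 1e-5


EPOCHS = 12

# Evaluate the schedule for each 0-based epoch index.
schedule = [decay(epoch) for epoch in range(EPOCHS)]

for epoch, lr in enumerate(schedule):
    print('epoch index {:2d} -> learning rate {:g}'.format(epoch, lr))
```

The first three epochs use `1e-3`, the next four use `1e-4`, and the remaining five use `1e-5`.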

In [13]:
# Define a callback for printing the learning rate at the end of each epoch.
class PrintLR(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    # Use the model attached to this callback instead of the global `model`.
    print('\nLearning rate for epoch {} is {}'.format(
        epoch + 1, self.model.optimizer.lr.numpy()))
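When this callback runs, it prints values such as `0.0010000000474974513` rather than exactly `0.001`. A likely explanation, stated here as an assumption, is that the optimizer stores the learning rate as a 32-bit float and the value is widened to a 64-bit Python float when printed. The sketch below reproduces the effect with the standard library only. The helper name `as_float32` is hypothetical.

```python
import struct


def as_float32(value):
    """Round a Python float to the nearest 32-bit float and widen it back."""
    return struct.unpack('f', struct.pack('f', value))[0]


stored = as_float32(1e-3)

# The 32-bit value is close to, but not exactly, 0.001.
print(stored)
print(stored == 1e-3)
```

The difference is far below anything that affects training, so the printed value can be read as `1e-3`.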

In [14]:
# Put all the callbacks together.
callbacks = [
    tf.keras.callbacks.TensorBoard(log_dir='./logs'),
    tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,
                                       save_weights_only=True),
    tf.keras.callbacks.LearningRateScheduler(decay),
    PrintLR()
]

2021-08-04 01:25:02.054144: I tensorflow/core/profiler/lib/profiler_session.cc:126] Profiler session initializing.
2021-08-04 01:25:02.054179: I tensorflow/core/profiler/lib/profiler_session.cc:141] Profiler session started.
2021-08-04 01:25:02.054232: I tensorflow/core/profiler/internal/gpu/cupti_tracer.cc:1611] Profiler found 1 GPUs
2021-08-04 01:25:02.098001: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcupti.so.11.2


2021-08-04 01:25:02.288095: I tensorflow/core/profiler/lib/profiler_session.cc:159] Profiler session tear down.


2021-08-04 01:25:02.292220: I tensorflow/core/profiler/internal/gpu/cupti_tracer.cc:1743] CUPTI activity buffer flushed


## Train and evaluate

Now, train the model in the usual way by calling `Model.fit` on the model and passing in the dataset created at the beginning of the tutorial. This step is the same whether you are distributing the training or not.

In [15]:
EPOCHS = 12

model.fit(train_dataset, epochs=EPOCHS, callbacks=callbacks)

2021-08-04 01:25:02.342811: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:461] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
2021-08-04 01:25:02.389307: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:176] None of the MLIR Optimization Passes are enabled (registered 2)
2021-08-04 01:25:02.389734: I tensorflow/core/platform/profile_utils/cpu_utils.cc:114] CPU Frequency: 2000179999 Hz


Epoch 1/12
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).




2021-08-04 01:25:05.851687: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudnn.so.8


2021-08-04 01:25:07.965516: I tensorflow/stream_executor/cuda/cuda_dnn.cc:359] Loaded cuDNN version 8100


2021-08-04 01:25:13.166255: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcublas.so.11


2021-08-04 01:25:13.566160: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcublasLt.so.11



  1/938 [..............................] - ETA: 3:09:47 - loss: 2.2850 - accuracy: 0.1094

2021-08-04 01:25:14.615346: I tensorflow/core/profiler/lib/profiler_session.cc:126] Profiler session initializing.
2021-08-04 01:25:14.615388: I tensorflow/core/profiler/lib/profiler_session.cc:141] Profiler session started.



  2/938 [..............................] - ETA: 7:27 - loss: 2.2403 - accuracy: 0.2188   


  3/938 [..............................] - ETA: 4:21 - loss: 2.1694 - accuracy: 0.3333



2021-08-04 01:25:15.082713: I tensorflow/core/profiler/lib/profiler_session.cc:66] Profiler session collecting data.
2021-08-04 01:25:15.085886: I tensorflow/core/profiler/internal/gpu/cupti_tracer.cc:1743] CUPTI activity buffer flushed
2021-08-04 01:25:15.122453: I tensorflow/core/profiler/internal/gpu/cupti_collector.cc:673]  GpuTracer has collected 96 callback api events and 93 activity events. 
2021-08-04 01:25:15.126946: I tensorflow/core/profiler/lib/profiler_session.cc:159] Profiler session tear down.
2021-08-04 01:25:15.138108: I tensorflow/core/profiler/rpc/client/save_profile.cc:137] Creating directory: ./logs/train/plugins/profile/2021_08_04_01_25_15
2021-08-04 01:25:15.146767: I tensorflow/core/profiler/rpc/client/save_profile.cc:143] Dumped gzipped tool data for trace.json.gz to ./logs/train/plugins/profile/2021_08_04_01_25_15/kokoro-gcp-ubuntu-prod-1251741625.trace.json.gz
2021-08-04 01:25:15.154434: I tensorflow/core/profiler/rpc/client/save_profile.cc:137] Creating dire






 16/938 [..............................] - ETA: 37s - loss: 1.5066 - accuracy: 0.6318 


 29/938 [..............................] - ETA: 21s - loss: 1.1650 - accuracy: 0.6913


 43/938 [>.............................] - ETA: 15s - loss: 0.9408 - accuracy: 0.7424


 57/938 [>.............................] - ETA: 12s - loss: 0.8171 - accuracy: 0.7749


 71/938 [=>............................] - ETA: 10s - loss: 0.7317 - accuracy: 0.7934


 85/938 [=>............................] - ETA: 8s - loss: 0.6709 - accuracy: 0.8074 


 99/938 [==>...........................] - ETA: 7s - loss: 0.6220 - accuracy: 0.8202


113/938 [==>...........................] - ETA: 7s - loss: 0.5829 - accuracy: 0.8301


127/938 [===>..........................] - ETA: 6s - loss: 0.5569 - accuracy: 0.8383


141/938 [===>..........................] - ETA: 6s - loss: 0.5306 - accuracy: 0.8460


155/938 [===>..........................] - ETA: 5s - loss: 0.5042 - accuracy: 0.8541


169/938 [====>.........................] - ETA: 5s - loss: 0.4820 - accuracy: 0.8595


183/938 [====>.........................] - ETA: 5s - loss: 0.4664 - accuracy: 0.8639


198/938 [=====>........................] - ETA: 4s - loss: 0.4465 - accuracy: 0.8696


212/938 [=====>........................] - ETA: 4s - loss: 0.4336 - accuracy: 0.8734

Learning rate for epoch 1 is 0.0010000000474974513


Epoch 2/12

  1/938 [..............................] - ETA: 16s - loss: 0.0625 - accuracy: 0.9688


 19/938 [..............................] - ETA: 2s - loss: 0.0787 - accuracy: 0.9762 


 37/938 [>.............................] - ETA: 2s - loss: 0.0888 - accuracy: 0.9751


 56/938 [>.............................] - ETA: 2s - loss: 0.0974 - accuracy: 0.9735


 74/938 [=>............................] - ETA: 2s - loss: 0.0986 - accuracy: 0.9726


 92/938 [=>............................] - ETA: 2s - loss: 0.0930 - accuracy: 0.9740


110/938 [==>...........................] - ETA: 2s - loss: 0.0912 - accuracy: 0.9739


129/938 [===>..........................] - ETA: 2s - loss: 0.0908 - accuracy: 0.9740


147/938 [===>..........................] - ETA: 2s - loss: 0.0884 - accuracy: 0.9749


166/938 [====>.........................] - ETA: 2s - loss: 0.0880 - accuracy: 0.9754


185/938 [====>.........................] - ETA: 2s - loss: 0.0852 - accuracy: 0.9759


204/938 [=====>........................] - ETA: 2s - loss: 0.0829 - accuracy: 0.9765

Learning rate for epoch 2 is 0.0010000000474974513


Epoch 3/12

  1/938 [..............................] - ETA: 16s - loss: 0.0366 - accuracy: 0.9844


 18/938 [..............................] - ETA: 2s - loss: 0.0364 - accuracy: 0.9887 


 36/938 [>.............................] - ETA: 2s - loss: 0.0474 - accuracy: 0.9861


 54/938 [>.............................] - ETA: 2s - loss: 0.0503 - accuracy: 0.9855


 71/938 [=>............................] - ETA: 2s - loss: 0.0464 - accuracy: 0.9866


 89/938 [=>............................] - ETA: 2s - loss: 0.0480 - accuracy: 0.9860


106/938 [==>...........................] - ETA: 2s - loss: 0.0492 - accuracy: 0.9854


123/938 [==>...........................] - ETA: 2s - loss: 0.0507 - accuracy: 0.9855


141/938 [===>..........................] - ETA: 2s - loss: 0.0510 - accuracy: 0.9853


158/938 [====>.........................] - ETA: 2s - loss: 0.0491 - accuracy: 0.9861


176/938 [====>.........................] - ETA: 2s - loss: 0.0502 - accuracy: 0.9852


194/938 [=====>........................] - ETA: 2s - loss: 0.0495 - accuracy: 0.9854


212/938 [=====>........................] - ETA: 2s - loss: 0.0507 - accuracy: 0.9852

Learning rate for epoch 3 is 0.0010000000474974513


Epoch 4/12

  1/938 [..............................] - ETA: 15s - loss: 0.0183 - accuracy: 1.0000


 18/938 [..............................] - ETA: 2s - loss: 0.0357 - accuracy: 0.9931 


 35/938 [>.............................] - ETA: 2s - loss: 0.0323 - accuracy: 0.9920


 52/938 [>.............................] - ETA: 2s - loss: 0.0300 - accuracy: 0.9919


 70/938 [=>............................] - ETA: 2s - loss: 0.0288 - accuracy: 0.9922


 88/938 [=>............................] - ETA: 2s - loss: 0.0275 - accuracy: 0.9922


106/938 [==>...........................] - ETA: 2s - loss: 0.0279 - accuracy: 0.9923


124/938 [==>...........................] - ETA: 2s - loss: 0.0308 - accuracy: 0.9917


142/938 [===>..........................] - ETA: 2s - loss: 0.0307 - accuracy: 0.9916


160/938 [====>.........................] - ETA: 2s - loss: 0.0308 - accuracy: 0.9913


178/938 [====>.........................] - ETA: 2s - loss: 0.0304 - accuracy: 0.9913


196/938 [=====>........................] - ETA: 2s - loss: 0.0305 - accuracy: 0.9914


214/938 [=====>........................] - ETA: 2s - loss: 0.0309 - accuracy: 0.9912

Learning rate for epoch 4 is 9.999999747378752e-05


Epoch 5/12

  1/938 [..............................] - ETA: 15s - loss: 0.0030 - accuracy: 1.0000


 19/938 [..............................] - ETA: 2s - loss: 0.0188 - accuracy: 0.9942 


 37/938 [>.............................] - ETA: 2s - loss: 0.0217 - accuracy: 0.9941


 54/938 [>.............................] - ETA: 2s - loss: 0.0195 - accuracy: 0.9948


 71/938 [=>............................] - ETA: 2s - loss: 0.0206 - accuracy: 0.9943


 89/938 [=>............................] - ETA: 2s - loss: 0.0243 - accuracy: 0.9933


107/938 [==>...........................] - ETA: 2s - loss: 0.0257 - accuracy: 0.9933


125/938 [==>...........................] - ETA: 2s - loss: 0.0245 - accuracy: 0.9934


143/938 [===>..........................] - ETA: 2s - loss: 0.0242 - accuracy: 0.9933


161/938 [====>.........................] - ETA: 2s - loss: 0.0243 - accuracy: 0.9934


179/938 [====>.........................] - ETA: 2s - loss: 0.0256 - accuracy: 0.9928


197/938 [=====>........................] - ETA: 2s - loss: 0.0255 - accuracy: 0.9926


215/938 [=====>........................] - ETA: 2s - loss: 0.0252 - accuracy: 0.9928

Learning rate for epoch 5 is 9.999999747378752e-05


Epoch 6/12

  1/938 [..............................] - ETA: 15s - loss: 0.0244 - accuracy: 0.9844


 19/938 [..............................] - ETA: 2s - loss: 0.0222 - accuracy: 0.9918 


 37/938 [>.............................] - ETA: 2s - loss: 0.0278 - accuracy: 0.9920


 56/938 [>.............................] - ETA: 2s - loss: 0.0257 - accuracy: 0.9922


 74/938 [=>............................] - ETA: 2s - loss: 0.0246 - accuracy: 0.9922


 90/938 [=>............................] - ETA: 2s - loss: 0.0242 - accuracy: 0.9927


107/938 [==>...........................] - ETA: 2s - loss: 0.0232 - accuracy: 0.9931


124/938 [==>...........................] - ETA: 2s - loss: 0.0226 - accuracy: 0.9933


141/938 [===>..........................] - ETA: 2s - loss: 0.0221 - accuracy: 0.9938


159/938 [====>.........................] - ETA: 2s - loss: 0.0229 - accuracy: 0.9936


177/938 [====>.........................] - ETA: 2s - loss: 0.0233 - accuracy: 0.9934


194/938 [=====>........................] - ETA: 2s - loss: 0.0237 - accuracy: 0.9935


212/938 [=====>........................] - ETA: 2s - loss: 0.0234 - accuracy: 0.9934

Learning rate for epoch 6 is 9.999999747378752e-05


Epoch 7/12

  1/938 [..............................] - ETA: 16s - loss: 0.0050 - accuracy: 1.0000


 18/938 [..............................] - ETA: 2s - loss: 0.0196 - accuracy: 0.9931 


 36/938 [>.............................] - ETA: 2s - loss: 0.0233 - accuracy: 0.9905


 54/938 [>.............................] - ETA: 2s - loss: 0.0221 - accuracy: 0.9922


 73/938 [=>............................] - ETA: 2s - loss: 0.0213 - accuracy: 0.9932


 91/938 [=>............................] - ETA: 2s - loss: 0.0255 - accuracy: 0.9928


110/938 [==>...........................] - ETA: 2s - loss: 0.0244 - accuracy: 0.9936


129/938 [===>..........................] - ETA: 2s - loss: 0.0252 - accuracy: 0.9933


148/938 [===>..........................] - ETA: 2s - loss: 0.0241 - accuracy: 0.9936


166/938 [====>.........................] - ETA: 2s - loss: 0.0234 - accuracy: 0.9938


184/938 [====>.........................] - ETA: 2s - loss: 0.0239 - accuracy: 0.9935


202/938 [=====>........................] - ETA: 2s - loss: 0.0239 - accuracy: 0.9933

Learning rate for epoch 7 is 9.999999747378752e-05


Epoch 8/12

  1/938 [..............................] - ETA: 15s - loss: 0.0035 - accuracy: 1.0000


 19/938 [..............................] - ETA: 2s - loss: 0.0237 - accuracy: 0.9934 


 37/938 [>.............................] - ETA: 2s - loss: 0.0234 - accuracy: 0.9932


 55/938 [>.............................] - ETA: 2s - loss: 0.0258 - accuracy: 0.9932


 73/938 [=>............................] - ETA: 2s - loss: 0.0228 - accuracy: 0.9942


 91/938 [=>............................] - ETA: 2s - loss: 0.0223 - accuracy: 0.9942


108/938 [==>...........................] - ETA: 2s - loss: 0.0226 - accuracy: 0.9938


125/938 [==>...........................] - ETA: 2s - loss: 0.0220 - accuracy: 0.9940


144/938 [===>..........................] - ETA: 2s - loss: 0.0223 - accuracy: 0.9938


162/938 [====>.........................] - ETA: 2s - loss: 0.0210 - accuracy: 0.9941


180/938 [====>.........................] - ETA: 2s - loss: 0.0208 - accuracy: 0.9942


198/938 [=====>........................] - ETA: 2s - loss: 0.0206 - accuracy: 0.9943


216/938 [=====>........................] - ETA: 2s - loss: 0.0202 - accuracy: 0.9946

Learning rate for epoch 8 is 9.999999747378752e-06


Epoch 9/12

  1/938 [..............................] - ETA: 15s - loss: 0.0040 - accuracy: 1.0000
 19/938 [..............................] - ETA: 2s - loss: 0.0230 - accuracy: 0.9918
 37/938 [>.............................] - ETA: 2s - loss: 0.0203 - accuracy: 0.9949
 55/938 [>.............................] - ETA: 2s - loss: 0.0177 - accuracy: 0.9960
 73/938 [=>............................] - ETA: 2s - loss: 0.0172 - accuracy: 0.9964
 91/938 [=>............................] - ETA: 2s - loss: 0.0189 - accuracy: 0.9952
109/938 [==>...........................] - ETA: 2s - loss: 0.0185 - accuracy: 0.9951
128/938 [===>..........................] - ETA: 2s - loss: 0.0186 - accuracy: 0.9950
147/938 [===>..........................] - ETA: 2s - loss: 0.0189 - accuracy: 0.9948
165/938 [====>.........................] - ETA: 2s - loss: 0.0185 - accuracy: 0.9949
182/938 [====>.........................] - ETA: 2s - loss: 0.0183 - accuracy: 0.9951
200/938 [=====>........................] - ETA: 2s - loss: 0.0179 - accuracy: 0.9951

Learning rate for epoch 9 is 9.999999747378752e-06


Epoch 10/12

  1/938 [..............................] - ETA: 15s - loss: 0.0221 - accuracy: 1.0000
 19/938 [..............................] - ETA: 2s - loss: 0.0233 - accuracy: 0.9951
 38/938 [>.............................] - ETA: 2s - loss: 0.0196 - accuracy: 0.9951
 56/938 [>.............................] - ETA: 2s - loss: 0.0196 - accuracy: 0.9950
 74/938 [=>............................] - ETA: 2s - loss: 0.0182 - accuracy: 0.9951
 92/938 [=>............................] - ETA: 2s - loss: 0.0178 - accuracy: 0.9958
110/938 [==>...........................] - ETA: 2s - loss: 0.0178 - accuracy: 0.9957
128/938 [===>..........................] - ETA: 2s - loss: 0.0189 - accuracy: 0.9956
145/938 [===>..........................] - ETA: 2s - loss: 0.0182 - accuracy: 0.9954
163/938 [====>.........................] - ETA: 2s - loss: 0.0178 - accuracy: 0.9953
181/938 [====>.........................] - ETA: 2s - loss: 0.0179 - accuracy: 0.9953
199/938 [=====>........................] - ETA: 2s - loss: 0.0186 - accuracy: 0.9953
217/938 [=====>........................] - ETA: 2s - loss: 0.0185 - accuracy: 0.9952

Learning rate for epoch 10 is 9.999999747378752e-06


Epoch 11/12

  1/938 [..............................] - ETA: 14s - loss: 0.0044 - accuracy: 1.0000
 19/938 [..............................] - ETA: 2s - loss: 0.0139 - accuracy: 0.9967
 36/938 [>.............................] - ETA: 2s - loss: 0.0243 - accuracy: 0.9948
 54/938 [>.............................] - ETA: 2s - loss: 0.0210 - accuracy: 0.9948
 72/938 [=>............................] - ETA: 2s - loss: 0.0209 - accuracy: 0.9950
 90/938 [=>............................] - ETA: 2s - loss: 0.0182 - accuracy: 0.9958
108/938 [==>...........................] - ETA: 2s - loss: 0.0178 - accuracy: 0.9958
126/938 [===>..........................] - ETA: 2s - loss: 0.0185 - accuracy: 0.9958
144/938 [===>..........................] - ETA: 2s - loss: 0.0178 - accuracy: 0.9960
162/938 [====>.........................] - ETA: 2s - loss: 0.0177 - accuracy: 0.9959
180/938 [====>.........................] - ETA: 2s - loss: 0.0188 - accuracy: 0.9951
198/938 [=====>........................] - ETA: 2s - loss: 0.0187 - accuracy: 0.9949
216/938 [=====>........................] - ETA: 2s - loss: 0.0190 - accuracy: 0.9949

Learning rate for epoch 11 is 9.999999747378752e-06


Epoch 12/12



  1/938 [..............................] - ETA: 15s - loss: 0.0906 - accuracy: 0.9844
 19/938 [..............................] - ETA: 2s - loss: 0.0239 - accuracy: 0.9926
 37/938 [>.............................] - ETA: 2s - loss: 0.0253 - accuracy: 0.9949
 55/938 [>.............................] - ETA: 2s - loss: 0.0227 - accuracy: 0.9952
 73/938 [=>............................] - ETA: 2s - loss: 0.0211 - accuracy: 0.9959
 91/938 [=>............................] - ETA: 2s - loss: 0.0224 - accuracy: 0.9952
109/938 [==>...........................] - ETA: 2s - loss: 0.0200 - accuracy: 0.9958
127/938 [===>..........................] - ETA: 2s - loss: 0.0190 - accuracy: 0.9958
145/938 [===>..........................] - ETA: 2s - loss: 0.0177 - accuracy: 0.9962
163/938 [====>.........................] - ETA: 2s - loss: 0.0182 - accuracy: 0.9962
181/938 [====>.........................] - ETA: 2s - loss: 0.0180 - accuracy: 0.9961
200/938 [=====>........................] - ETA: 2s - loss: 0.0183 - accuracy: 0.9959
218/938 [=====>........................] - ETA: 2s - loss: 0.0183 - accuracy: 0.9958

Learning rate for epoch 12 is 9.999999747378752e-06


<tensorflow.python.keras.callbacks.History at 0x7f4e5c176dd0>
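
The logged learning rate `9.999999747378752e-06` is not a rounding bug. A likely explanation, assuming the optimizer stores its learning rate as a 32-bit float and the schedule's nominal value for these epochs is `1e-5`, is that the log shows the nearest float32 to `1e-5` after it has been widened to a Python float. The sketch below reproduces that round trip with the standard library only; `to_float32` is a hypothetical helper, not a TensorFlow API.

```python
import struct


def to_float32(value: float) -> float:
    """Round a Python float to the nearest float32, then widen it back."""
    return struct.unpack('<f', struct.pack('<f', value))[0]


nominal_lr = 1e-5  # Assumption: nominal learning rate for the late epochs.
stored_lr = to_float32(nominal_lr)

# The stored value differs from the nominal one only by float32 rounding.
relative_error = abs(stored_lr - nominal_lr) / nominal_lr

print(stored_lr)
print(relative_error)
```

The relative difference is far below anything that would affect training, so the two values can be treated as the same learning rate.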

Check for saved checkpoints:

In [16]:
# Check the checkpoint directory.
!ls {checkpoint_dir}

checkpoint		     ckpt_4.data-00000-of-00001
ckpt_1.data-00000-of-00001   ckpt_4.index
ckpt_1.index		     ckpt_5.data-00000-of-00001
ckpt_10.data-00000-of-00001  ckpt_5.index
ckpt_10.index		     ckpt_6.data-00000-of-00001
ckpt_11.data-00000-of-00001  ckpt_6.index
ckpt_11.index		     ckpt_7.data-00000-of-00001
ckpt_12.data-00000-of-00001  ckpt_7.index
ckpt_12.index		     ckpt_8.data-00000-of-00001
ckpt_2.data-00000-of-00001   ckpt_8.index
ckpt_2.index		     ckpt_9.data-00000-of-00001
ckpt_3.data-00000-of-00001   ckpt_9.index
ckpt_3.index
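
`ls` sorts names lexicographically, which is why `ckpt_10` appears before `ckpt_2` in the listing above. To view the checkpoints in epoch order, one option is to sort on the numeric suffix. This is a minimal sketch that assumes the `ckpt_{epoch}` naming visible in the listing; `checkpoint_prefixes_in_epoch_order` is a hypothetical helper, not part of TensorFlow.

```python
import re
from typing import Iterable, List

# Matches index files such as 'ckpt_12.index' and captures prefix and epoch.
_CKPT_INDEX = re.compile(r'^(ckpt_(\d+))\.index$')


def checkpoint_prefixes_in_epoch_order(filenames: Iterable[str]) -> List[str]:
    """Return checkpoint prefixes (for example 'ckpt_3') sorted by epoch number."""
    found = []
    for name in filenames:
        match = _CKPT_INDEX.match(name)
        if match:
            found.append((int(match.group(2)), match.group(1)))
    return [prefix for _, prefix in sorted(found)]


# A few names copied from the listing above.
listing = ['checkpoint', 'ckpt_1.index', 'ckpt_10.index', 'ckpt_12.index',
           'ckpt_2.index', 'ckpt_2.data-00000-of-00001', 'ckpt_9.index']
ordered = checkpoint_prefixes_in_epoch_order(listing)
print(ordered)
```

In a notebook, `os.listdir(checkpoint_dir)` could supply the file names. Note that `tf.train.latest_checkpoint` does not depend on this ordering: it reads the `checkpoint` state file stored in the same directory.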


To check how well the model performs, load the latest checkpoint and call `Model.evaluate` on the test data:

In [17]:
model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))

eval_loss, eval_acc = model.evaluate(eval_dataset)

print('Eval loss: {}, Eval accuracy: {}'.format(eval_loss, eval_acc))

2021-08-04 01:25:49.277864: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:461] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.



  1/157 [..............................] - ETA: 4:18 - loss: 0.0786 - accuracy: 0.9688
 17/157 [==>...........................] - ETA: 0s - loss: 0.0379 - accuracy: 0.9899
 34/157 [=====>........................] - ETA: 0s - loss: 0.0319 - accuracy: 0.9903

Eval loss: 0.03712465986609459, Eval accuracy: 0.987500011920929
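
The step counts in the progress bars (`938` for training, `157` for evaluation) follow from dividing the dataset size by the global batch size and rounding up, because the final partial batch still counts as a step. The sketch below checks this under assumptions that are not stated in this section: MNIST's 60,000 training and 10,000 test examples, and a global batch size of 64 (64 per replica on a single replica). `steps_per_epoch` is a hypothetical helper.

```python
import math


def steps_per_epoch(num_examples: int, global_batch_size: int) -> int:
    """Number of batches needed to cover the dataset once, keeping the remainder."""
    return math.ceil(num_examples / global_batch_size)


GLOBAL_BATCH_SIZE = 64  # Assumption: 64 per replica x 1 replica.

train_steps = steps_per_epoch(60_000, GLOBAL_BATCH_SIZE)
eval_steps = steps_per_epoch(10_000, GLOBAL_BATCH_SIZE)

print(train_steps)  # 938
print(eval_steps)   # 157
```

With more replicas the global batch size grows proportionally, so the number of steps per epoch would shrink accordingly.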


To visualize the output, launch TensorBoard and view the logs:

In [None]:
%tensorboard --logdir=logs


In [18]:
!ls -sh ./logs

total 4.0K
4.0K train


## Export to SavedModel

Export the graph and the variables to the platform-agnostic SavedModel format using `Model.save`. After your model is saved, you can load it with or without the `Strategy.scope`.

In [19]:
path = 'saved_model/'

In [20]:
model.save(path, save_format='tf')

2021-08-04 01:25:51.983973: W tensorflow/python/util/util.cc:348] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.


INFO:tensorflow:Assets written to: saved_model/assets
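
A SavedModel export is a directory rather than a single file: it typically contains a `saved_model.pb` file with the serialized graph, plus `variables/` and `assets/` subdirectories. The following is a minimal sketch for sanity-checking an export directory before loading it, using only the standard library. `looks_like_saved_model` is a hypothetical helper, not a TensorFlow API, and it checks only the basic layout, not whether the contents are valid.

```python
import os
import tempfile


def looks_like_saved_model(path: str) -> bool:
    """Return True if `path` has the basic SavedModel directory layout."""
    has_graph = os.path.isfile(os.path.join(path, 'saved_model.pb'))
    has_variables = os.path.isdir(os.path.join(path, 'variables'))
    return has_graph and has_variables


# Demonstrate on a throwaway directory that mimics the layout.
with tempfile.TemporaryDirectory() as tmp:
    empty_result = looks_like_saved_model(tmp)
    open(os.path.join(tmp, 'saved_model.pb'), 'wb').close()
    os.mkdir(os.path.join(tmp, 'variables'))
    populated_result = looks_like_saved_model(tmp)

print(empty_result, populated_result)
```

In this tutorial, the check would be `looks_like_saved_model(path)` with `path = 'saved_model/'`, run after `Model.save` returns.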


Now, load the model without `Strategy.scope`:

In [21]:
unreplicated_model = tf.keras.models.load_model(path)

unreplicated_model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.Adam(),
    metrics=['accuracy'])

eval_loss, eval_acc = unreplicated_model.evaluate(eval_dataset)

print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))


  1/157 [..............................] - ETA: 28s - loss: 0.0786 - accuracy: 0.9688
 27/157 [====>.........................] - ETA: 0s - loss: 0.0309 - accuracy: 0.9907

Eval loss: 0.03712465986609459, Eval Accuracy: 0.987500011920929


Load the model with `Strategy.scope`:

In [22]:
with strategy.scope():
  # Load and compile inside the scope so the variables are mirrored across replicas.
  replicated_model = tf.keras.models.load_model(path)
  replicated_model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      optimizer=tf.keras.optimizers.Adam(),
      metrics=['accuracy'])

  eval_loss, eval_acc = replicated_model.evaluate(eval_dataset)
  print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))

2021-08-04 01:25:53.544239: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:461] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.



  1/157 [..............................] - ETA: 3:27 - loss: 0.0786 - accuracy: 0.9688
 21/157 [===>..........................] - ETA: 0s - loss: 0.0345 - accuracy: 0.9903

Eval loss: 0.03712465986609459, Eval Accuracy: 0.987500011920929


### Additional resources

More examples that use different distribution strategies with the Keras `Model.fit` API:

1. The [Solve GLUE tasks using BERT on TPU](https://www.tensorflow.org/text/tutorials/bert_glue) tutorial uses `tf.distribute.MirroredStrategy` for training on GPUs and `tf.distribute.TPUStrategy` for training on TPUs.
1. The [Save and load a model using a distribution strategy](save_and_load.ipynb) tutorial demonstrates how to use the SavedModel APIs with `tf.distribute.Strategy`.
1. The [official TensorFlow models](https://github.com/tensorflow/models/tree/master/official) can be configured to run with multiple distribution strategies.

To learn more about TensorFlow distribution strategies:

1. The [Custom training with tf.distribute.Strategy](custom_training.ipynb) tutorial shows how to use the `tf.distribute.MirroredStrategy` for single-worker training with a custom training loop.
1. The [Multi-worker training with Keras](multi_worker_with_keras.ipynb) tutorial shows how to use the `MultiWorkerMirroredStrategy` with `Model.fit`.
1. The [Custom training loop with Keras and MultiWorkerMirroredStrategy](multi_worker_with_ctl.ipynb) tutorial shows how to use the `MultiWorkerMirroredStrategy` with Keras and a custom training loop.
1. The [Distributed training in TensorFlow](https://www.tensorflow.org/guide/distributed_training) guide provides an overview of the available distribution strategies.
1. The [Better performance with tf.function](../../guide/function.ipynb) guide provides information about other strategies and tools, such as the [TensorFlow Profiler](../../guide/profiler.md), that you can use to optimize the performance of your TensorFlow models.

Note: `tf.distribute.Strategy` is under active development, and TensorFlow will be adding more examples and tutorials in the near future. Please give it a try. Your feedback is welcome; feel free to submit it via [issues on GitHub](https://github.com/tensorflow/tensorflow/issues/new).