In [1]:
import tensorflow as tf

# Helper libraries
import numpy as np
import os

In [2]:
fashion_mnist = tf.keras.datasets.fashion_mnist

# `load_data()` yields the train and test splits as (images, labels) pairs.
(train_images, train_labels), (test_images, test_labels) = (
    fashion_mnist.load_data())

# Two preprocessing steps, applied identically to both splits:
#   1. Append a trailing channel axis, so each image becomes (28, 28, 1).
#      The first layer of the model is convolutional and needs 4D input
#      (batch_size, height, width, channels); the batch_size axis is added
#      later on, when the data is batched.
#   2. Divide by a float32 255 to bring the pixel values into the [0, 1] range.
train_images = np.expand_dims(train_images, axis=-1) / np.float32(255)
test_images = np.expand_dims(test_images, axis=-1) / np.float32(255)

In [3]:
# Distribution strategy shared by the rest of the notebook (variable creation,
# dataset distribution and the per-replica train/test steps). No device list is
# passed, so device selection is left to MirroredStrategy's defaults.
strategy = tf.distribute.MirroredStrategy()

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1')


In [4]:
# Report how many replicas the strategy keeps in sync.
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

Number of devices: 2


In [28]:
# Shuffle buffer size: one slot per training example.
BUFFER_SIZE = train_images.shape[0]

# The global batch is the per-replica batch multiplied by the number of
# replicas the strategy keeps in sync.
BATCH_SIZE_PER_REPLICA = 16
GLOBAL_BATCH_SIZE = strategy.num_replicas_in_sync * BATCH_SIZE_PER_REPLICA

# Number of passes over the training data.
EPOCHS = 1

In [29]:
# Training pipeline: slice into examples, shuffle with a BUFFER_SIZE buffer,
# then group into batches of GLOBAL_BATCH_SIZE.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((train_images, train_labels))
    .shuffle(BUFFER_SIZE)
    .batch(GLOBAL_BATCH_SIZE)
)

# Test pipeline: same slicing and batching, but no shuffling.
test_dataset = (
    tf.data.Dataset.from_tensor_slices((test_images, test_labels))
    .batch(GLOBAL_BATCH_SIZE)
)

# Let the strategy distribute both datasets; these distributed versions are
# what the training and evaluation loops iterate over.
train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
test_dist_dataset = strategy.experimental_distribute_dataset(test_dataset)

In [30]:
def create_model():
  """Build a fresh, uncompiled Keras classifier.

  The layer stack is: Conv2D(32, kernel 3, relu) -> MaxPooling2D ->
  Conv2D(64, kernel 3, relu) -> MaxPooling2D -> Flatten ->
  Dense(64, relu) -> Dense(10) with no activation on the final layer.

  Returns:
    A new `tf.keras.Sequential` model; no input shape is declared here.
  """
  layers = tf.keras.layers
  stack = [
      layers.Conv2D(32, 3, activation='relu'),
      layers.MaxPooling2D(),
      layers.Conv2D(64, 3, activation='relu'),
      layers.MaxPooling2D(),
      layers.Flatten(),
      layers.Dense(64, activation='relu'),
      layers.Dense(10),
  ]
  return tf.keras.Sequential(stack)

In [31]:
with strategy.scope():
  # Reduction is disabled on the loss object so that it returns per-example
  # values; the averaging is done explicitly in `compute_loss`, dividing by the
  # global batch size.
  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
      reduction=tf.keras.losses.Reduction.NONE,
      from_logits=True)

  def compute_loss(labels, predictions):
    """Return the loss for one batch, scaled by the global batch size.

    Args:
      labels: target labels passed straight to `loss_object`.
      predictions: model outputs (logits, since `from_logits=True`).

    Returns:
      The result of `tf.nn.compute_average_loss` over the per-example losses
      with `global_batch_size=GLOBAL_BATCH_SIZE`.
    """
    return tf.nn.compute_average_loss(
        loss_object(labels, predictions),
        global_batch_size=GLOBAL_BATCH_SIZE)



In [32]:
with strategy.scope():
  # Running mean of the loss values recorded during evaluation.
  test_loss = tf.keras.metrics.Mean(name='test_loss')

  # One accuracy metric for the training pass and one for the test pass,
  # created in that order.
  train_accuracy, test_accuracy = (
      tf.keras.metrics.SparseCategoricalAccuracy(name=metric_name)
      for metric_name in ('train_accuracy', 'test_accuracy'))

In [33]:
# The model and the optimizer have to be instantiated inside `strategy.scope`.
with strategy.scope():
  model = create_model()
  optimizer = tf.keras.optimizers.Adam()

  # Checkpoint object tracking both the optimizer and the model.
  checkpoint = tf.train.Checkpoint(
      optimizer=optimizer,
      model=model,
  )

In [34]:
with strategy.scope():
  def train_step(inputs):
    """Run one optimisation step on a batch and update `train_accuracy`.

    Args:
      inputs: an `(images, labels)` pair.

    Returns:
      A `(loss, predictions)` tuple: the value returned by `compute_loss` and
      the model outputs for `images` (computed with `training=True`).
    """
    images, labels = inputs

    # Record the forward pass so the loss can be differentiated.
    with tf.GradientTape() as tape:
      predictions = model(images, training=True)
      loss = compute_loss(labels, predictions)

    # Read the variable list only after the forward pass has run.
    variables = model.trainable_variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))

    train_accuracy.update_state(labels, predictions)
    return loss, predictions

  def test_step(inputs):
    """Evaluate one batch and fold it into `test_loss` / `test_accuracy`.

    Args:
      inputs: an `(images, labels)` pair.

    Returns:
      None; the results are accumulated in the metric objects.
    """
    images, labels = inputs
    predictions = model(images, training=False)

    # `loss_object` has reduction disabled, so the Mean metric receives the
    # unreduced loss values.
    test_loss.update_state(loss_object(labels, predictions))
    test_accuracy.update_state(labels, predictions)

In [35]:
with strategy.scope():
  # `Strategy.run` is the current name of the per-replica launcher. The older
  # `experimental_run_v2` spelling is deprecated, so it is only used as a
  # fallback on TensorFlow releases that do not provide `run` yet. The `or`
  # short-circuits, so the old attribute is never touched when `run` exists.
  run_on_replicas = (
      getattr(strategy, 'run', None) or strategy.experimental_run_v2)

  @tf.function
  def distributed_train_step(dataset_inputs):
    """Run `train_step` on every replica with the distributed input.

    Args:
      dataset_inputs: one element yielded by `train_dist_dataset`.

    Returns:
      A `(loss, per_replica_predictions)` tuple: the per-replica losses
      summed across replicas, and the un-reduced per-replica predictions
      returned by `train_step`.
    """
    per_replica_losses, per_replica_predictions = run_on_replicas(
        train_step, args=(dataset_inputs,))
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
                           axis=None), per_replica_predictions

  @tf.function
  def distributed_test_step(dataset_inputs):
    """Run `test_step` on every replica with the distributed input.

    Args:
      dataset_inputs: one element yielded by `test_dist_dataset`.

    Returns:
      Whatever the strategy returns for `test_step`, which itself returns
      nothing; results are accumulated in `test_loss` / `test_accuracy`.
    """
    return run_on_replicas(test_step, args=(dataset_inputs,))

  for epoch in range(EPOCHS):
    # TRAIN LOOP
    total_loss = 0.0
    num_batches = 0
    for x in train_dist_dataset:
      # `per_replica_predictions` keeps the outputs of the most recent batch;
      # after the loop it holds those of the final training batch.
      loss, per_replica_predictions = distributed_train_step(x)
      total_loss += loss
      num_batches += 1
    # Epoch loss is the mean of the per-step losses.
    train_loss = total_loss / num_batches

    # TEST LOOP
    for x in test_dist_dataset:
      distributed_test_step(x)


    template = ("Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, "
                "Test Accuracy: {}")
    print(template.format(epoch+1, train_loss,
                          train_accuracy.result()*100, test_loss.result(),
                          test_accuracy.result()*100))

    # Clear the metric state so the next epoch starts from zero.
    test_loss.reset_states()
    train_accuracy.reset_states()
    test_accuracy.reset_states()

INFO:tensorflow:batch_all_reduce: 8 all-reduces with algorithm = nccl, num_packs = 1, agg_small_grads_max_bytes = 0 and agg_small_grads_max_group = 10
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:batch_all_reduce: 8 all-reduces with algorithm = nccl, num_packs = 1, agg_small_grads_max_bytes = 0 and agg_small_grads_max_group = 10
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
Epoch 1, Loss: 0.4464421570301056, Accuracy: 83.72000122070312, Test Loss: 0.3529491126537323, Test Accuracy: 87.55000305175781


In [41]:
# Stitch the per-replica predictions of the last training batch back into one
# tensor along the batch axis. `experimental_local_results` is the supported
# replacement for the deprecated `strategy.unwrap`.
tf.concat(strategy.experimental_local_results(per_replica_predictions), axis=0)

<tf.Tensor: shape=(32, 10), dtype=float32, numpy=
array([[  6.7312727 ,  -1.6091244 ,  -0.06497218,   0.8819937 ,
         -2.9817572 ,  -4.9757714 ,   3.5792334 ,  -6.100062  ,
         -1.09286   ,  -4.037592  ],
       [ -1.2379279 ,  -3.938323  ,   6.472806  ,  -0.41371906,
          5.9494867 ,  -6.394777  ,   1.2372278 ,  -8.913586  ,
         -0.04357997,  -8.745478  ],
       [ -6.4974494 , -14.12088   ,  -8.516809  ,  -3.5889118 ,
         -7.969909  ,   5.0643444 ,  -6.79579   ,  10.441473  ,
          0.1867151 ,   5.8311043 ],
       [ -1.9673445 ,   0.48762375,  -3.1698112 ,   5.6242065 ,
         -2.4876537 ,  -0.35391805,  -2.1810849 ,  -3.108228  ,
         -1.7661748 ,  -2.3138983 ],
       [ -5.747141  ,  -9.4875965 ,  -7.071219  ,  -6.3860908 ,
         -8.004895  ,   2.742964  ,  -3.3828218 ,   4.1335053 ,
          0.22667333,  12.409351  ],
       [ -7.5411143 , -10.38132   ,  -8.168939  ,  -7.113113  ,
         -6.584156  ,   4.0573215 ,  -2.9296384 ,   5.668868 

In [39]:
# Shape of one replica's share of the predictions. `experimental_local_results`
# replaces the deprecated `strategy.unwrap`. Indexing with -1 picks the last
# replica, which is the same element as [1] on two replicas but does not
# raise when only a single replica is available.
strategy.experimental_local_results(per_replica_predictions)[-1].shape

TensorShape([16, 10])