In [1]:
import tensorflow_datasets as tfds
import tensorflow as tf
# Turn off tensorflow_datasets' progress-bar output.
tfds.disable_progress_bar()

import os, json

In [2]:
# Publish the cluster layout through the TF_CONFIG environment variable:
# two worker addresses on localhost, with this process registered as the
# worker at index 0.
os.environ['TF_CONFIG'] = json.dumps(dict(
    cluster=dict(worker=['localhost:12345', 'localhost:23456']),
    task=dict(type='worker', index=0),
))

In [3]:
# Create the multi-worker strategy. This runs right after TF_CONFIG is set in
# the previous cell; the log output below shows the strategy reporting that
# same two-worker cluster spec with this process as worker task 0.
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

INFO:tensorflow:Enabled multi-worker collective ops with available devices: ['/job:worker/replica:0/task:0/device:CPU:0', '/job:worker/replica:0/task:0/device:XLA_CPU:0', '/job:worker/replica:0/task:0/device:GPU:0', '/job:worker/replica:0/task:0/device:XLA_GPU:0']
INFO:tensorflow:Using MirroredStrategy with devices ('/job:worker/task:0/device:GPU:0',)
INFO:tensorflow:MultiWorkerMirroredStrategy with cluster_spec = {'worker': ['localhost:12345', 'localhost:23456']}, task_type = 'worker', task_id = 0, num_workers = 2, local_devices = ('/job:worker/task:0/device:GPU:0',), communication = CollectiveCommunication.AUTO


In [4]:
# Size of the shuffle buffer applied to the training examples.
BUFFER_SIZE = 10_000
# Examples per batch for the single-worker input pipeline.
BATCH_SIZE = 64

def scale(image, label):
    """Cast `image` to float32 and divide it by 255; `label` is passed through.

    Returns:
        A `(scaled_image, label)` tuple.
    """
    scaled = tf.cast(image, tf.float32) / 255
    return scaled, label

# Load MNIST as (image, label) pairs, together with the dataset metadata.
datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)

# Training split: rescale, cache, then shuffle. Kept unbatched so that later
# cells can pick their own batch size.
train_datasets_unbatched = (
    datasets['train']
    .map(scale)
    .cache()
    .shuffle(BUFFER_SIZE)
)

# Batched variant consumed by the single-worker training cell.
train_datasets = train_datasets_unbatched.batch(BATCH_SIZE)

In [5]:
def build_and_compile_cnn_model(input_shape=(28, 28, 1), num_classes=10,
                                learning_rate=0.001):
    """Build and compile a small convolutional classifier.

    Architecture: Conv2D(32 filters, kernel 3, ReLU) -> MaxPooling2D ->
    Flatten -> Dense(64, ReLU) -> Dense(num_classes, softmax). The model is
    compiled with sparse categorical cross-entropy, an SGD optimizer and an
    accuracy metric.

    Calling the function with no arguments builds exactly the model that was
    previously hard-coded, so existing call sites are unaffected.

    Args:
        input_shape: Shape passed as `input_shape` to the first layer.
            Defaults to (28, 28, 1).
        num_classes: Number of units in the final softmax layer.
            Defaults to 10.
        learning_rate: Learning rate handed to the SGD optimizer.
            Defaults to 0.001.

    Returns:
        The compiled `tf.keras.Sequential` model.
    """
    layers = tf.keras.layers
    model = tf.keras.Sequential([
        layers.Conv2D(32, 3, activation='relu', input_shape=input_shape),
        layers.MaxPooling2D(),
        layers.Flatten(),
        layers.Dense(64, activation='relu'),
        # Softmax output pairs with the (non-logit) sparse cross-entropy loss.
        layers.Dense(num_classes, activation='softmax'),
    ])
    model.compile(
        loss=tf.keras.losses.sparse_categorical_crossentropy,
        optimizer=tf.keras.optimizers.SGD(learning_rate=learning_rate),
        metrics=['accuracy'],
    )
    return model

In [6]:
# Baseline run: build the model outside any strategy scope and train it on the
# BATCH_SIZE-batched dataset for three epochs.
single_worker_model = build_and_compile_cnn_model()
single_worker_model.fit(train_datasets, epochs=3)

Epoch 1/3
Epoch 2/3
Epoch 3/3


<tensorflow.python.keras.callbacks.History at 0x1e846124c88>

In [None]:
# Worker count, read from the TF_CONFIG cluster spec set earlier in this
# notebook, so it cannot drift from the configured cluster.
NUM_WORKERS = len(json.loads(os.environ['TF_CONFIG'])['cluster']['worker'])

# The dataset below is batched with the global batch size: the per-worker
# BATCH_SIZE times the number of workers (64 * 2 = 128 with the current
# configuration), rather than a repeated literal 64.
GLOBAL_BATCH_SIZE = BATCH_SIZE * NUM_WORKERS

# NOTE: this rebinds `train_datasets` to the globally-batched dataset; the
# cells that follow use this rebound version.
train_datasets = train_datasets_unbatched.batch(GLOBAL_BATCH_SIZE)

# Build and compile the model inside the strategy scope.
with strategy.scope():
    multi_worker_model = build_and_compile_cnn_model()

multi_worker_model.fit(x=train_datasets, epochs=3)

In [None]:
# Dataset options with automatic sharding switched off.
# NOTE(review): the name of this option appears to differ between TensorFlow
# releases (newer ones use `auto_shard_policy`) — confirm against the
# installed version.
options = tf.data.Options()
options.experimental_distribute.auto_shard = False

# Variant of `train_datasets` with those options attached.
train_datasets_no_auto_shard = train_datasets.with_options(options)

In [None]:
# Keras checkpoint callback writing to /tmp/keras-ckpt.
callbacks = [
    tf.keras.callbacks.ModelCheckpoint(filepath='/tmp/keras-ckpt'),
]

# Rebuild the model under the strategy scope, then train it with the
# checkpoint callback attached.
with strategy.scope():
    multi_worker_model = build_and_compile_cnn_model()

multi_worker_model.fit(train_datasets, epochs=3, callbacks=callbacks)