In [1]:
import numpy as np
import tensorflow as tf
import json
import os
from ray.train import Trainer

import ray

In [2]:
def mnist_dataset(batch_size):
    (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
    # The `x` arrays are in uint8 and have values in the [0, 255] range.
    # You need to convert them to float32 with values in the [0, 1] range.
    x_train = x_train / np.float32(255)
    y_train = y_train.astype(np.int64)
    train_dataset = tf.data.Dataset.from_tensor_slices(
        (x_train, y_train)).shuffle(60000).repeat().batch(batch_size)
    return train_dataset


def build_and_compile_cnn_model():
    model = tf.keras.Sequential([
        tf.keras.layers.InputLayer(input_shape=(28, 28)),
        tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
        tf.keras.layers.Conv2D(32, 3, activation='relu'),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(10)
    ])
    model.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
        metrics=['accuracy'])
    return model


def train_func_distributed():
    per_worker_batch_size = 64
    # This environment variable will be set by Ray Train.
    tf_config = json.loads(os.environ['TF_CONFIG'])
    num_workers = len(tf_config['cluster']['worker'])

    strategy = tf.distribute.MultiWorkerMirroredStrategy()

    global_batch_size = per_worker_batch_size * num_workers
    multi_worker_dataset = mnist_dataset(global_batch_size)

    with strategy.scope():
        # Model building/compiling need to be within `strategy.scope()`.
        multi_worker_model = build_and_compile_cnn_model()

    multi_worker_model.fit(multi_worker_dataset, epochs=3, steps_per_epoch=70)

In [3]:
ray.init("ray://example-cluster-ray-head-svc:10001")
#ray.init("example-cluster-ray-head-svc.ricliu.svc.local:10001")
#ray.init()

ClientContext(dashboard_url='10.48.0.48:8265', python_version='3.7.7', ray_version='1.13.1', ray_commit='da2a91cd34ac58df4c49b2fa65a5bd25bc1e2057', protocol_version='2022-03-16', _num_clients=2, _context_to_restore=<ray.util.client._ClientContext object at 0x7f77b4966850>)

In [4]:
trainer = Trainer(backend="tensorflow", num_workers=4)
trainer.start()
results = trainer.run(train_func_distributed)
trainer.shutdown()

2022-08-12 22:09:24,101	INFO trainer.py:243 -- Trainer logs will be logged in: /home/jovyan/ray_results/train_2022-08-12_22-09-24
[2m[36m(BackendExecutor pid=185, ip=10.48.1.46)[0m 2022-08-12 15:09:54.398997: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
[2m[36m(BackendExecutor pid=185, ip=10.48.1.46)[0m 2022-08-12 15:09:54.399052: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.
[2m[36m(BaseWorkerMixin pid=727)[0m 2022-08-12 15:10:14.910930: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
[2m[36m(BaseWorkerMixin pid=727)[0m 2022-08-12 15:10:14.910967: I tensorflow/stream_executor/cuda/cuda

[2m[36m(BaseWorkerMixin pid=183, ip=10.48.0.49)[0m Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
[2m[36m(BaseWorkerMixin pid=727)[0m Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
[2m[36m(BaseWorkerMixin pid=184, ip=10.48.1.47)[0m Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
[2m[36m(BaseWorkerMixin pid=245, ip=10.48.1.46)[0m Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
    8192/11490434 [..............................] - ETA: 0s


[2m[36m(BaseWorkerMixin pid=245, ip=10.48.1.46)[0m 2022-08-12 15:10:22.634476: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:776] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2"
[2m[36m(BaseWorkerMixin pid=245, ip=10.48.1.46)[0m op: "TensorSliceDataset"
[2m[36m(BaseWorkerMixin pid=245, ip=10.48.1.46)[0m input: "Placeholder/_0"
[2m[36m(BaseWorkerMixin pid=245, ip=10.48.1.46)[0m input: "Placeholder/_1"
[2m[36m(BaseWorkerMixin pid=245, ip=10.48.1.46)[0m attr {
[2m[36m(BaseWorkerMixin pid=245, ip=10.48.1.46)[0m   key: "Toutput_types"
[2m[36m(BaseWorkerMixin pid=245, ip=10.48.1.46)[0m   value {
[2m[36m(BaseWorkerMixin pid=245, ip=10.48.1.46)[0m     list {
[2m[36m(BaseWorkerMixin pid=245, ip=10.48.1.46)[0m       type: DT_FLOAT
[2m[36m(BaseWorkerMixin pid=245, ip=10.48.1.46)[0m       type: DT_INT64
[2m[36

[2m[36m(BaseWorkerMixin pid=245, ip=10.48.1.46)[0m Epoch 1/3
[2m[36m(BaseWorkerMixin pid=184, ip=10.48.1.47)[0m Epoch 1/3
[2m[36m(BaseWorkerMixin pid=183, ip=10.48.0.49)[0m Epoch 1/3
[2m[36m(BaseWorkerMixin pid=727)[0m Epoch 1/3
 1/70 [..............................] - ETA: 5:54 - loss: 2.3070 - accuracy: 0.0977
 1/70 [..............................] - ETA: 5:54 - loss: 2.3070 - accuracy: 0.0977
 1/70 [..............................] - ETA: 6:23 - loss: 2.3070 - accuracy: 0.0977
 1/70 [..............................] - ETA: 6:23 - loss: 2.3070 - accuracy: 0.0977
 2/70 [..............................] - ETA: 18s - loss: 2.3121 - accuracy: 0.1055 
 2/70 [..............................] - ETA: 20s - loss: 2.3121 - accuracy: 0.1055 
 2/70 [..............................] - ETA: 19s - loss: 2.3121 - accuracy: 0.1055 
 2/70 [..............................] - ETA: 18s - loss: 2.3121 - accuracy: 0.1055 
 3/70 [>.............................] - ETA: 16s - loss: 2.3142 - accuracy: 0.

[2m[36m(BaseWorkerMixin pid=183, ip=10.48.0.49)[0m 2022-08-12 15:11:18.740128: E tensorflow/core/common_runtime/base_collective_executor.cc:249] BaseCollectiveExecutor::StartAbort UNAVAILABLE: failed to connect to all addresses
[2m[36m(BaseWorkerMixin pid=183, ip=10.48.0.49)[0m Additional GRPC error information from remote target /job:worker/replica:0/task:0:
[2m[36m(BaseWorkerMixin pid=183, ip=10.48.0.49)[0m :{"created":"@1660342278.739932345","description":"Failed to pick subchannel","file":"external/com_github_grpc_grpc/src/core/ext/filters/client_channel/client_channel.cc","file_line":3940,"referenced_errors":[{"created":"@1660342278.734870307","description":"failed to connect to all addresses","file":"external/com_github_grpc_grpc/src/core/ext/filters/client_channel/lb_policy/pick_first/pick_first.cc","file_line":392,"grpc_status":14}]}
