In [None]:
import numpy as np
import tensorflow as tf
import json
import os
import requests
import tempfile
import time

from ray.train import Trainer
from ray import serve
import ray

In [None]:
# Path of the HDF5 model file inside the OS temp directory. The training
# function saves the trained model here and the same path is handed to the
# Serve deployment, which loads the model from it.
# NOTE(review): this is a node-local path -- on a multi-node cluster the
# process that saves and the process that loads presumably need to share a
# filesystem; confirm for the target cluster.
_MODEL_FILENAME = "mnist_model.h5"
TRAINED_MODEL_PATH = os.path.join(tempfile.gettempdir(), _MODEL_FILENAME)

@serve.deployment(route_prefix="/mnist")
class TFMnistModel:
    """Ray Serve deployment that answers HTTP requests with a saved Keras model.

    The deployment is registered under the ``/mnist`` route prefix. A request
    body must be a JSON object whose ``"array"`` entry holds the pixel values
    of a single image: 28 * 28 = 784 numbers, which are reshaped to
    ``(1, 28, 28)`` before being passed to the model.
    """

    def __init__(self, model_path):
        """Load the Keras model stored at ``model_path``.

        Args:
            model_path: Path of the saved model, as accepted by
                ``tf.keras.models.load_model``. It is also echoed back in
                every response under the ``"file"`` key.
        """
        # Local import so the name is bound wherever the constructor runs
        # (presumably inside the Serve replica process -- TODO confirm).
        import tensorflow as tf
        self.model_path = model_path
        self.model = tf.keras.models.load_model(model_path)

    async def __call__(self, starlette_request):
        """Run the model on the image contained in one HTTP request.

        Args:
            starlette_request: Incoming request object; its awaited ``json()``
                result must be a mapping with an ``"array"`` entry.

        Returns:
            A dict with ``"prediction"`` (the model output converted to nested
            Python lists) and ``"file"`` (the path the model was loaded from).

        Raises:
            KeyError: If the JSON object has no ``"array"`` entry.
            ValueError: If the values cannot be converted to a float32 array
                or reshaped to ``(1, 28, 28)``.
        """
        # Step 1: transform HTTP request -> tensorflow input
        # Here we define the request schema to be a json array.
        payload = await starlette_request.json()
        # JSON numbers decode to Python ints/floats, so without an explicit
        # dtype the array would be int64 or float64 depending on the client.
        # Casting to float32 matches the dtype of the training inputs built in
        # this notebook and rejects ragged/non-numeric payloads up front.
        # NOTE(review): no rescaling is done here; clients presumably send
        # values already in the [0, 1] range used for training -- confirm.
        input_array = np.asarray(payload["array"], dtype=np.float32)
        reshaped_array = input_array.reshape((1, 28, 28))

        # Step 2: tensorflow input -> tensorflow output
        prediction = self.model(reshaped_array)

        # Step 3: tensorflow output -> web output
        return {
            "prediction": prediction.numpy().tolist(),
            "file": self.model_path
        }


In [None]:
def mnist_dataset(batch_size):
    """Load the MNIST training split and build a batched input pipeline.

    Args:
        batch_size: Number of examples per batch in the returned dataset.

    Returns:
        A tuple ``(images, labels, dataset)``: the rescaled training images,
        the labels cast to int64, and a ``tf.data.Dataset`` over both that is
        shuffled, repeated indefinitely and batched with ``batch_size``.
    """
    # Only the training split is used; the test split is discarded.
    (images, labels), _ = tf.keras.datasets.mnist.load_data()

    # The images are uint8 with values in [0, 255]; dividing by a float32
    # scalar converts them to float32 values in [0, 1].
    pixel_max = np.float32(255)
    images = images / pixel_max
    labels = labels.astype(np.int64)

    dataset = tf.data.Dataset.from_tensor_slices((images, labels))
    dataset = dataset.shuffle(60000)
    dataset = dataset.repeat()
    dataset = dataset.batch(batch_size)
    return images, labels, dataset


def build_and_compile_cnn_model():
    """Create and compile the CNN used for MNIST classification.

    The network takes ``(28, 28)`` inputs, adds a channel axis, applies one
    3x3 convolution with 32 filters, and finishes with a 128-unit dense layer
    followed by a 10-unit output layer that produces logits.

    Returns:
        A compiled ``tf.keras.Sequential`` model using sparse categorical
        cross-entropy on logits, SGD with learning rate 0.001, and the
        ``accuracy`` metric.
    """
    layers = tf.keras.layers
    network = tf.keras.Sequential([
        layers.InputLayer(input_shape=(28, 28)),
        layers.Reshape(target_shape=(28, 28, 1)),
        layers.Conv2D(32, 3, activation='relu'),
        layers.Flatten(),
        layers.Dense(128, activation='relu'),
        layers.Dense(10),
    ])

    # The last layer has no activation, so the loss must interpret its
    # outputs as logits.
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    sgd = tf.keras.optimizers.SGD(learning_rate=0.001)
    network.compile(loss=loss_fn, optimizer=sgd, metrics=['accuracy'])
    return network


def train_func_distributed():
    """Train the MNIST CNN with ``MultiWorkerMirroredStrategy`` and save it.

    Takes no arguments; it is handed to ``trainer.run`` as the training
    function. The number of workers is read from the ``TF_CONFIG`` environment
    variable, the model is fitted for 3 epochs of 70 steps, and the result is
    written to ``TRAINED_MODEL_PATH`` in h5 format on the local file system.

    Raises:
        KeyError: If ``TF_CONFIG`` is not set in the environment.
    """
    # This environment variable will be set by Ray Train.
    cluster_spec = json.loads(os.environ['TF_CONFIG'])['cluster']
    num_workers = len(cluster_spec['worker'])

    # The dataset is batched with the global batch size: one per-worker batch
    # for every worker in the cluster.
    per_worker_batch_size = 64
    global_batch_size = per_worker_batch_size * num_workers

    strategy = tf.distribute.MultiWorkerMirroredStrategy()

    # mnist_dataset also returns the raw arrays; only the pipeline is used.
    _, _, multi_worker_dataset = mnist_dataset(global_batch_size)

    # Model building/compiling need to be within `strategy.scope()`.
    with strategy.scope():
        multi_worker_model = build_and_compile_cnn_model()

    multi_worker_model.fit(multi_worker_dataset, epochs=3, steps_per_epoch=70)
    multi_worker_model.summary()

    # Save the model in h5 format in local file system
    multi_worker_model.save(TRAINED_MODEL_PATH)

In [None]:
# Connect this notebook to the Ray cluster through the ray:// address of its
# head service.
ray.init(address="ray://example-cluster-head-svc:10001")

In [None]:
# Run the training function on 4 workers with the TensorFlow backend.
# To train on GPUs, additionally pass use_gpu=True to Trainer.
trainer = Trainer(backend="tensorflow", num_workers=4)
trainer.start()
try:
    results = trainer.run(train_func_distributed)
finally:
    # Always release the training workers, even when the run raises;
    # otherwise they stay allocated on the cluster.
    trainer.shutdown()

In [None]:
# Start a detached Serve instance whose HTTP server binds to 0.0.0.0 (all
# network interfaces) instead of the default host, then deploy the model.
# NOTE(review): binding to 0.0.0.0 presumably makes the endpoint reachable
# from outside the node -- confirm this exposure is intended.
serve_http_options = {"host": "0.0.0.0"}
serve.start(detached=True, http_options=serve_http_options)

# The saved-model path is the deployment's constructor argument.
TFMnistModel.deploy(TRAINED_MODEL_PATH)