Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/mnist/estimator/README.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# MNIST using Estimator

Original Source: https://www.tensorflow.org/beta/tutorials/distribute/multi_worker_with_estimator
Original Source: https://www.tensorflow.org/tutorials/distribute/multi_worker_with_estimator

This is the [Multi-worker Training with Estimator](https://www.tensorflow.org/beta/tutorials/distribute/multi_worker_with_estimator) example, adapted for TensorFlowOnSpark.
This is the [Multi-worker Training with Estimator](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_estimator) example, adapted for TensorFlowOnSpark.

Note: this example assumes that Spark, TensorFlow, and TensorFlowOnSpark are already installed.

Expand Down
3 changes: 1 addition & 2 deletions examples/mnist/estimator/mnist_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ def input_fn(mode, input_context=None):
ds = tf.data.Dataset.from_generator(rdd_generator, (tf.float32, tf.float32), (tf.TensorShape([28, 28, 1]), tf.TensorShape([1])))
return ds.batch(BATCH_SIZE)
else:
raise Exception("I'm evaluating: mode={}, input_context={}".format(mode, input_context))

# read evaluation data from tensorflow_datasets directly
def scale(image, label):
image = tf.cast(image, tf.float32) / 255.0
return image, label
Expand Down
3 changes: 2 additions & 1 deletion examples/mnist/estimator/mnist_spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ def main_fun(args, ctx):
import tensorflow_datasets as tfds
from tensorflowonspark import TFNode

strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

tfds.disable_progress_bar()

class StopFeedHook(tf.estimator.SessionRunHook):
Expand Down Expand Up @@ -91,7 +93,6 @@ def model_fn(features, labels, mode):
train_op=optimizer.minimize(
loss, tf.compat.v1.train.get_or_create_global_step()))

strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
config = tf.estimator.RunConfig(train_distribute=strategy, save_checkpoints_steps=100)

classifier = tf.estimator.Estimator(
Expand Down
6 changes: 3 additions & 3 deletions examples/mnist/keras/README.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# MNIST using Keras

Original Source: https://www.tensorflow.org/beta/tutorials/distribute/multi_worker_with_keras
Original Source: https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras

This is the [Multi-worker Training with Keras](https://www.tensorflow.org/beta/tutorials/distribute/multi_worker_with_keras) example, adapted for TensorFlowOnSpark.
This is the [Multi-worker Training with Keras](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras) example, adapted for TensorFlowOnSpark.

Notes:
- This example assumes that Spark, TensorFlow, TensorFlow Datasets, and TensorFlowOnSpark are already installed.
Expand Down Expand Up @@ -130,7 +130,7 @@ For batch inferencing use cases, you can use Spark to run multiple single-node T
${TFoS_HOME}/examples/mnist/keras/mnist_inference.py \
--cluster_size ${SPARK_WORKER_INSTANCES} \
--images_labels ${TFoS_HOME}/data/mnist/tfr/test \
--export_dir ${TFoS_HOME}/mnist_export \
--export_dir ${SAVED_MODEL} \
--output ${TFoS_HOME}/predictions

#### Train and Inference via Spark ML Pipeline API
Expand Down
10 changes: 4 additions & 6 deletions examples/mnist/keras/mnist_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
def main_fun(args, ctx):
import numpy as np
import tensorflow as tf
from tensorflowonspark import TFNode
from tensorflowonspark import compat, TFNode

strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

Expand Down Expand Up @@ -65,11 +65,9 @@ def rdd_generator():

multi_worker_model.fit(x=ds, epochs=args.epochs, steps_per_epoch=max_steps_per_worker, callbacks=callbacks)

if ctx.job_name == 'chief':
from tensorflow_estimator.python.estimator.export import export_lib
export_dir = export_lib.get_timestamped_export_dir(args.export_dir)
tf.keras.experimental.export_saved_model(multi_worker_model, export_dir)
# multi_worker_model.save(args.model_dir, save_format='tf')
from tensorflow_estimator.python.estimator.export import export_lib
export_dir = export_lib.get_timestamped_export_dir(args.export_dir)
compat.export_saved_model(multi_worker_model, export_dir, ctx.job_name == 'chief')

# terminating feed tells spark to skip processing further partitions
tf_feed.terminate()
Expand Down
10 changes: 4 additions & 6 deletions examples/mnist/keras/mnist_spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
def main_fun(args, ctx):
import numpy as np
import tensorflow as tf
from tensorflowonspark import TFNode
from tensorflowonspark import compat, TFNode

strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

Expand Down Expand Up @@ -65,11 +65,9 @@ def rdd_generator():

multi_worker_model.fit(x=ds, epochs=args.epochs, steps_per_epoch=max_steps_per_worker, callbacks=callbacks)

if ctx.job_name == 'chief':
from tensorflow_estimator.python.estimator.export import export_lib
export_dir = export_lib.get_timestamped_export_dir(args.export_dir)
tf.keras.experimental.export_saved_model(multi_worker_model, export_dir)
# multi_worker_model.save(args.model_dir, save_format='tf')
from tensorflow_estimator.python.estimator.export import export_lib
export_dir = export_lib.get_timestamped_export_dir(args.export_dir)
compat.export_saved_model(multi_worker_model, export_dir, ctx.job_name == 'chief')

# terminating feed tells spark to skip processing further partitions
tf_feed.terminate()
Expand Down
10 changes: 5 additions & 5 deletions examples/mnist/keras/mnist_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
def main_fun(args, ctx):
import tensorflow_datasets as tfds
import tensorflow as tf
from tensorflowonspark import compat

tfds.disable_progress_bar()

strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
Expand Down Expand Up @@ -60,11 +62,9 @@ def build_and_compile_cnn_model():
multi_worker_model = build_and_compile_cnn_model()
multi_worker_model.fit(x=train_datasets, epochs=args.epochs, steps_per_epoch=args.steps_per_epoch, callbacks=callbacks)

if ctx.job_name == 'chief':
from tensorflow_estimator.python.estimator.export import export_lib
export_dir = export_lib.get_timestamped_export_dir(args.export_dir)
tf.keras.experimental.export_saved_model(multi_worker_model, export_dir)
# multi_worker_model.save(args.model_dir, save_format='tf')
from tensorflow_estimator.python.estimator.export import export_lib
export_dir = export_lib.get_timestamped_export_dir(args.export_dir)
compat.export_saved_model(multi_worker_model, export_dir, ctx.job_name == 'chief')


if __name__ == '__main__':
Expand Down
9 changes: 4 additions & 5 deletions examples/mnist/keras/mnist_tf_ds.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
def main_fun(args, ctx):
"""Example demonstrating loading TFRecords directly from disk (e.g. HDFS) without tensorflow_datasets."""
import tensorflow as tf
from tensorflowonspark import compat

strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

Expand Down Expand Up @@ -86,11 +87,9 @@ def build_and_compile_cnn_model():
multi_worker_model = build_and_compile_cnn_model()
multi_worker_model.fit(x=train_datasets, epochs=args.epochs, steps_per_epoch=steps_per_epoch, callbacks=callbacks)

if ctx.job_name == 'chief':
from tensorflow_estimator.python.estimator.export import export_lib
export_dir = export_lib.get_timestamped_export_dir(args.export_dir)
tf.keras.experimental.export_saved_model(multi_worker_model, export_dir)
# multi_worker_model.save(args.model_dir, save_format='tf')
from tensorflow_estimator.python.estimator.export import export_lib
export_dir = export_lib.get_timestamped_export_dir(args.export_dir)
compat.export_saved_model(multi_worker_model, export_dir, ctx.job_name == 'chief')


if __name__ == '__main__':
Expand Down
4 changes: 2 additions & 2 deletions examples/segmentation/README.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Image Segmentation

Original Source: https://www.tensorflow.org/beta/tutorials/images/segmentation
Original Source: https://www.tensorflow.org/tutorials/images/segmentation

This code is based on the [Image Segmentation](https://www.tensorflow.org/beta/tutorials/images/segmentation) notebook example, converted to a single-node TensorFlow python app, then converted into a distributed TensorFlow app using the `MultiWorkerMirroredStrategy`, and then finally adapted for TensorFlowOnSpark. Compare the different versions to see the conversion steps involved at each stage.
This code is based on the [Image Segmentation](https://www.tensorflow.org/tutorials/images/segmentation) notebook example, converted to a single-node TensorFlow python app, then converted into a distributed TensorFlow app using the `MultiWorkerMirroredStrategy`, and then finally adapted for TensorFlowOnSpark. Compare the different versions to see the conversion steps involved at each stage.

Notes:
- this example assumes that Spark, TensorFlow, and TensorFlowOnSpark are already installed.
Expand Down
17 changes: 7 additions & 10 deletions examples/segmentation/segmentation_spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,18 +159,15 @@ def unet_model(output_channels):
validation_steps=VALIDATION_STEPS,
validation_data=test_dataset)

if ctx.job_name == 'chief':
if tf.__version__ == '2.0.0':
# Workaround for: https://github.com/tensorflow/tensorflow/issues/30251
print("===== saving h5py model")
model.save(args.model_dir + ".h5")
print("===== re-loading model w/o DistributionStrategy")
new_model = tf.keras.models.load_model(args.model_dir + ".h5")
print("===== exporting saved_model")
tf.keras.experimental.export_saved_model(new_model, args.export_dir)
print("===== done exporting")
# Save model locally as h5py and reload it w/o distribution strategy
if ctx.job_name == 'chief':
model.save(args.model_dir + ".h5")
new_model = tf.keras.models.load_model(args.model_dir + ".h5")
tf.keras.experimental.export_saved_model(new_model, args.export_dir)
else:
print("===== sleeping")
time.sleep(90)
model.save(args.export_dir, save_format='tf')


if __name__ == '__main__':
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
h5py>=2.9.0
numpy>=1.14.0
packaging
py4j==0.10.7
pyspark
scipy
Expand Down
135 changes: 131 additions & 4 deletions tensorflowonspark/TFNode.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

import getpass
import logging

from packaging import version
from six.moves.queue import Empty
from . import marker

Expand Down Expand Up @@ -61,8 +63,86 @@ def hdfs_path(ctx, path):


def start_cluster_server(ctx, num_gpus=1, rdma=False):
"""*DEPRECATED*. Use higher-level APIs like `tf.keras` or `tf.estimator`"""
raise Exception("DEPRECATED: Use higher-level APIs like `tf.keras` or `tf.estimator`")
"""Function that wraps the creation of TensorFlow ``tf.train.Server`` for a node in a distributed TensorFlow cluster.

This is intended to be invoked from within the TF ``map_fun``, replacing explicit code to instantiate ``tf.train.ClusterSpec``
and ``tf.train.Server`` objects.

DEPRECATED for TensorFlow 2.x+

Args:
:ctx: TFNodeContext containing the metadata specific to this node in the cluster.
:num_gpu: number of GPUs desired
:rdma: boolean indicating if RDMA 'iverbs' should be used for cluster communications.

Returns:
A tuple of (cluster_spec, server)
"""
import os
import tensorflow as tf
import time
from . import gpu_info

if version.parse(tf.__version__) >= version.parse("2.0.0"):
raise Exception("DEPRECATED: Use higher-level APIs like `tf.keras` or `tf.estimator`")

logging.info("{0}: ======== {1}:{2} ========".format(ctx.worker_num, ctx.job_name, ctx.task_index))
cluster_spec = ctx.cluster_spec
logging.info("{0}: Cluster spec: {1}".format(ctx.worker_num, cluster_spec))

if tf.test.is_built_with_cuda() and num_gpus > 0:
# compute my index relative to other nodes placed on the same host (for GPU allocation)
my_addr = cluster_spec[ctx.job_name][ctx.task_index]
my_host = my_addr.split(':')[0]
flattened = [v for sublist in cluster_spec.values() for v in sublist]
local_peers = [p for p in flattened if p.startswith(my_host)]
my_index = local_peers.index(my_addr)

# GPU
gpu_initialized = False
retries = 3
while not gpu_initialized and retries > 0:
try:
# override PS jobs to only reserve one GPU
if ctx.job_name == 'ps':
num_gpus = 0

# Find a free gpu(s) to use
gpus_to_use = gpu_info.get_gpus(num_gpus, my_index)
gpu_prompt = "GPU" if num_gpus == 1 else "GPUs"
logging.info("{0}: Using {1}: {2}".format(ctx.worker_num, gpu_prompt, gpus_to_use))

# Set GPU device to use for TensorFlow
os.environ['CUDA_VISIBLE_DEVICES'] = gpus_to_use

# Create a cluster from the parameter server and worker hosts.
cluster = tf.train.ClusterSpec(cluster_spec)

# Create and start a server for the local task.
if rdma:
server = tf.train.Server(cluster, ctx.job_name, ctx.task_index, protocol="grpc+verbs")
else:
server = tf.train.Server(cluster, ctx.job_name, ctx.task_index)
gpu_initialized = True
except Exception as e:
print(e)
logging.error("{0}: Failed to allocate GPU, trying again...".format(ctx.worker_num))
retries -= 1
time.sleep(10)
if not gpu_initialized:
raise Exception("Failed to allocate GPU")
else:
# CPU
os.environ['CUDA_VISIBLE_DEVICES'] = ''
logging.info("{0}: Using CPU".format(ctx.worker_num))

# Create a cluster from the parameter server and worker hosts.
cluster = tf.train.ClusterSpec(cluster_spec)

# Create and start a server for the local task.
server = tf.train.Server(cluster, ctx.job_name, ctx.task_index)

return (cluster, server)


def next_batch(mgr, batch_size, qname='input'):
Expand All @@ -71,8 +151,55 @@ def next_batch(mgr, batch_size, qname='input'):


def export_saved_model(sess, export_dir, tag_set, signatures):
"""*DEPRECATED*. Use TF provided APIs instead."""
raise Exception("DEPRECATED: Use TF provided APIs instead.")
"""Convenience function to export a saved_model using provided arguments

The caller specifies the saved_model signatures in a simplified python dictionary form, as follows::

signatures = {
'signature_def_key': {
'inputs': { 'input_tensor_alias': input_tensor_name },
'outputs': { 'output_tensor_alias': output_tensor_name },
'method_name': 'method'
}
}

And this function will generate the `signature_def_map` and export the saved_model.

DEPRECATED for TensorFlow 2.x+.

Args:
:sess: a tf.Session instance
:export_dir: path to save exported saved_model
:tag_set: string tag_set to identify the exported graph
:signatures: simplified dictionary representation of a TensorFlow signature_def_map

Returns:
A saved_model exported to disk at ``export_dir``.
"""
import tensorflow as tf

if version.parse(tf.__version__) >= version.parse("2.0.0"):
raise Exception("DEPRECATED: Use TF provided APIs instead.")

g = sess.graph
g._unsafe_unfinalize() # https://github.com/tensorflow/serving/issues/363
builder = tf.saved_model.builder.SavedModelBuilder(export_dir)

logging.info("===== signatures: {}".format(signatures))
signature_def_map = {}
for key, sig in signatures.items():
signature_def_map[key] = tf.saved_model.signature_def_utils.build_signature_def(
inputs={name: tf.saved_model.utils.build_tensor_info(tensor) for name, tensor in sig['inputs'].items()},
outputs={name: tf.saved_model.utils.build_tensor_info(tensor) for name, tensor in sig['outputs'].items()},
method_name=sig['method_name'] if 'method_name' in sig else key)
logging.info("===== signature_def_map: {}".format(signature_def_map))
builder.add_meta_graph_and_variables(
sess,
tag_set.split(','),
signature_def_map=signature_def_map,
clear_devices=True)
g.finalize()
builder.save()


def batch_results(mgr, results, qname='output'):
Expand Down
23 changes: 23 additions & 0 deletions tensorflowonspark/compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright 2019 Yahoo Inc / Verizon Media
# Licensed under the terms of the Apache 2.0 license.
# Please see LICENSE file in the project root for terms.
"""Helper functions to abstract API changes between TensorFlow versions."""

import tensorflow as tf

TF_VERSION = tf.__version__


def export_saved_model(model, export_dir, is_chief=False):
if TF_VERSION == '2.0.0':
if is_chief:
tf.keras.experimental.export_saved_model(model, export_dir)
else:
model.save(export_dir, save_format='tf')


def disable_auto_shard(options):
if TF_VERSION == '2.0.0':
options.experimental_distribute.auto_shard = False
else:
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
Loading