diff --git a/docs/source/tensorflowonspark.TFParallel.rst b/docs/source/tensorflowonspark.TFParallel.rst new file mode 100644 index 00000000..c29d00ab --- /dev/null +++ b/docs/source/tensorflowonspark.TFParallel.rst @@ -0,0 +1,7 @@ +tensorflowonspark\.TFParallel module +=================================== + +.. automodule:: tensorflowonspark.TFParallel + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/tensorflowonspark.rst b/docs/source/tensorflowonspark.rst index b99e9178..809a9503 100644 --- a/docs/source/tensorflowonspark.rst +++ b/docs/source/tensorflowonspark.rst @@ -14,6 +14,7 @@ Submodules tensorflowonspark.TFCluster tensorflowonspark.TFManager tensorflowonspark.TFNode + tensorflowonspark.TFParallel tensorflowonspark.TFSparkNode tensorflowonspark.dfutil tensorflowonspark.gpu_info diff --git a/examples/mnist/keras/mnist_inference.py b/examples/mnist/keras/mnist_inference.py index af444d50..21df737a 100644 --- a/examples/mnist/keras/mnist_inference.py +++ b/examples/mnist/keras/mnist_inference.py @@ -21,16 +21,7 @@ import tensorflow as tf -def inference(it, num_workers, args): - from tensorflowonspark import util - - # consume worker number from RDD partition iterator - for i in it: - worker_num = i - print("worker_num: {}".format(i)) - - # setup env for single-node TF - util.single_node_env() +def inference(args, ctx): # load saved_model saved_model = tf.saved_model.load(args.export_dir, tags='serve') @@ -48,14 +39,14 @@ def parse_tfr(example_proto): # define a new tf.data.Dataset (for inferencing) ds = tf.data.Dataset.list_files("{}/part-*".format(args.images_labels)) - ds = ds.shard(num_workers, worker_num) + ds = ds.shard(ctx.num_workers, ctx.worker_num) ds = ds.interleave(tf.data.TFRecordDataset) ds = ds.map(parse_tfr) ds = ds.batch(10) # create an output file per spark worker for the predictions tf.io.gfile.makedirs(args.output) - output_file = tf.io.gfile.GFile("{}/part-{:05d}".format(args.output, worker_num), mode='w') + output_file = tf.io.gfile.GFile("{}/part-{:05d}".format(args.output, ctx.worker_num), mode='w') for batch in ds: predictions = predict(conv2d_input=batch[0]) @@ -70,6 +61,7 @@ def parse_tfr(example_proto): if __name__ == '__main__': from pyspark.context import SparkContext from pyspark.conf import SparkConf + from tensorflowonspark import TFParallel sc = SparkContext(conf=SparkConf().setAppName("mnist_inference")) executors = sc._conf.get("spark.executor.instances") @@ -83,7 +75,5 @@ def parse_tfr(example_proto): args, _ = parser.parse_known_args() print("args: {}".format(args)) - # Not using TFCluster... just running single-node TF instances on each executor - nodes = list(range(args.cluster_size)) - nodeRDD = sc.parallelize(list(range(args.cluster_size)), args.cluster_size) - nodeRDD.foreachPartition(lambda worker_num: inference(worker_num, args.cluster_size, args)) + # Running single-node TF instances on each executor + TFParallel.run(sc, inference, args, args.cluster_size) diff --git a/examples/resnet/README.md b/examples/resnet/README.md index 13ea30af..acd111d1 100644 --- a/examples/resnet/README.md +++ b/examples/resnet/README.md @@ -4,14 +4,16 @@ Original Source: https://github.com/tensorflow/models/tree/master/official/visio This code is based on the Image Classification model from the official [TensorFlow Models](https://github.com/tensorflow/models) repository. This example already supports different forms of distribution via the `DistributionStrategy` API, so there isn't much additional work to convert it to TensorFlowOnSpark. -Notes: +Notes: - This example assumes that Spark, TensorFlow, and TensorFlowOnSpark are already installed. - For simplicity, this just uses a single-node Spark Standalone installation. #### Run the Single-Node Application -First, make sure that you can run the example per the [original instructions](https://github.com/tensorflow/models/tree/68c3c65596b8fc624be15aef6eac3dc8952cbf23/official/vision/image_classification). For now, we'll just use the CIFAR-10 dataset. After cloning the [tensorflow/models](https://github.com/tensorflow/models) repository and downloading the dataset, you should be able to run the training as follows: +First, make sure that you can run the example per the [original instructions](https://github.com/tensorflow/models/tree/68c3c65596b8fc624be15aef6eac3dc8952cbf23/official/vision/image_classification). For now, we'll just use the CIFAR-10 dataset. After cloning the [tensorflow/models](https://github.com/tensorflow/models) repository (checking out the `v2.0` tag with `git checkout v2.0`), and downloading the dataset, you should be able to run the training as follows: ``` +# Note: these instructions have been tested with the `v2.0` tag of tensorflow/models. + export TENSORFLOW_MODELS=/path/to/tensorflow/models export CIFAR_DATA=/path/to/cifar export PYTHONPATH=${PYTHONPATH}:${TENSORFLOW_MODELS} diff --git a/tensorflowonspark/TFNode.py b/tensorflowonspark/TFNode.py index 131c1260..5b96de38 100644 --- a/tensorflowonspark/TFNode.py +++ b/tensorflowonspark/TFNode.py @@ -21,6 +21,7 @@ logger = logging.getLogger(__name__) + def hdfs_path(ctx, path): """Convenience function to create a Tensorflow-compatible absolute HDFS path from relative paths diff --git a/tensorflowonspark/TFParallel.py b/tensorflowonspark/TFParallel.py new file mode 100644 index 00000000..414949d3 --- /dev/null +++ b/tensorflowonspark/TFParallel.py @@ -0,0 +1,64 @@ +# Copyright 2019 Yahoo Inc / Verizon Media +# Licensed under the terms of the Apache 2.0 license. +# Please see LICENSE file in the project root for terms. + +from __future__ import absolute_import +from __future__ import division +from __future__ import nested_scopes +from __future__ import print_function + +import logging +from . import TFSparkNode +from . import gpu_info, util + +logger = logging.getLogger(__name__) + + +def run(sc, map_fn, tf_args, num_executors): + """Runs the user map_fn as parallel, independent instances of TF on the Spark executors. + + Args: + :sc: SparkContext + :map_fun: user-supplied TensorFlow "main" function + :tf_args: ``argparse`` args, or command-line ``ARGV``. These will be passed to the ``map_fun``. + :num_executors: number of Spark executors. This should match your Spark job's ``--num_executors``. + + Returns: + None + """ + + # get default filesystem from spark + defaultFS = sc._jsc.hadoopConfiguration().get("fs.defaultFS") + # strip trailing "root" slash from "file:///" to be consistent w/ "hdfs://..." + if defaultFS.startswith("file://") and len(defaultFS) > 7 and defaultFS.endswith("/"): + defaultFS = defaultFS[:-1] + + def _run(it): + from pyspark import BarrierTaskContext + + for i in it: + worker_num = i + + # use BarrierTaskContext to get placement of all nodes + ctx = BarrierTaskContext.get() + tasks = ctx.getTaskInfos() + nodes = [t.address for t in tasks] + + # use the placement info to help allocate GPUs + num_gpus = tf_args.num_gpus if 'num_gpus' in tf_args else 1 + util.single_node_env(num_gpus=num_gpus, worker_index=worker_num, nodes=nodes) + + # run the user map_fn + ctx = TFSparkNode.TFNodeContext() + ctx.defaultFS = defaultFS + ctx.worker_num = worker_num + ctx.executor_id = worker_num + ctx.num_workers = len(nodes) + + map_fn(tf_args, ctx) + + # return a dummy iterator (since we have to use mapPartitions) + return [0] + + nodeRDD = sc.parallelize(list(range(num_executors)), num_executors) + nodeRDD.barrier().mapPartitions(_run).collect() diff --git a/tensorflowonspark/TFSparkNode.py b/tensorflowonspark/TFSparkNode.py index e3ab3948..08d48d67 100755 --- a/tensorflowonspark/TFSparkNode.py +++ b/tensorflowonspark/TFSparkNode.py @@ -46,7 +46,7 @@ class TFNodeContext: :working_dir: the current working directory for local filesystems, or YARN containers. :mgr: TFManager instance for this Python worker. """ - def __init__(self, executor_id, job_name, task_index, cluster_spec, defaultFS, working_dir, mgr): + def __init__(self, executor_id=0, job_name='', task_index=0, cluster_spec={}, defaultFS='file://', working_dir='.', mgr=None): self.worker_num = executor_id # for backwards-compatibility self.executor_id = executor_id self.job_name = job_name diff --git a/tensorflowonspark/reservation.py b/tensorflowonspark/reservation.py index 94a61835..c6060cad 100644 --- a/tensorflowonspark/reservation.py +++ b/tensorflowonspark/reservation.py @@ -190,9 +190,8 @@ def _listen(self, sock): def get_server_ip(self): return os.getenv(TFOS_SERVER_HOST) if os.getenv(TFOS_SERVER_HOST) else util.get_ip_address() - def start_listening_socket(self): - port_number = int(os.getenv(TFOS_SERVER_PORT)) if os.getenv(TFOS_SERVER_PORT) else 0 + port_number = int(os.getenv(TFOS_SERVER_PORT)) if os.getenv(TFOS_SERVER_PORT) else 0 server_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) server_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) server_sock.bind(('', port_number)) diff --git a/tensorflowonspark/util.py b/tensorflowonspark/util.py index 38d1d117..38fab892 100644 --- a/tensorflowonspark/util.py +++ b/tensorflowonspark/util.py @@ -18,7 +18,7 @@ logger = logging.getLogger(__name__) -def single_node_env(num_gpus=1): +def single_node_env(num_gpus=1, worker_index=-1, nodes=[]): """Setup environment variables for Hadoop compatibility and GPU allocation""" import tensorflow as tf # ensure expanded CLASSPATH w/o glob characters (required for Spark 2.1 + JNI) @@ -29,9 +29,19 @@ def single_node_env(num_gpus=1): os.environ['CLASSPATH'] = classpath + os.pathsep + hadoop_classpath os.environ['TFOS_CLASSPATH_UPDATED'] = '1' - # reserve GPU, if requested - if tf.test.is_built_with_cuda(): - gpus_to_use = gpu_info.get_gpus(num_gpus) + if tf.test.is_built_with_cuda() and num_gpus > 0: + # reserve GPU(s), if requested + if worker_index >= 0 and len(nodes) > 0: + # compute my index relative to other nodes on the same host, if known + my_addr = nodes[worker_index] + my_host = my_addr.split(':')[0] + local_peers = [n for n in nodes if n.startswith(my_host)] + my_index = local_peers.index(my_addr) + else: + # otherwise, just use global worker index + my_index = worker_index + + gpus_to_use = gpu_info.get_gpus(num_gpus, my_index) logger.info("Using gpu(s): {0}".format(gpus_to_use)) os.environ['CUDA_VISIBLE_DEVICES'] = gpus_to_use else: