7 changes: 7 additions & 0 deletions docs/source/tensorflowonspark.TFParallel.rst
@@ -0,0 +1,7 @@
tensorflowonspark\.TFParallel module
===================================

.. automodule:: tensorflowonspark.TFParallel
:members:
:undoc-members:
:show-inheritance:
1 change: 1 addition & 0 deletions docs/source/tensorflowonspark.rst
@@ -14,6 +14,7 @@ Submodules
tensorflowonspark.TFCluster
tensorflowonspark.TFManager
tensorflowonspark.TFNode
tensorflowonspark.TFParallel
tensorflowonspark.TFSparkNode
tensorflowonspark.dfutil
tensorflowonspark.gpu_info
22 changes: 6 additions & 16 deletions examples/mnist/keras/mnist_inference.py
@@ -21,16 +21,7 @@
import tensorflow as tf


def inference(it, num_workers, args):
from tensorflowonspark import util

# consume worker number from RDD partition iterator
for i in it:
worker_num = i
print("worker_num: {}".format(i))

# setup env for single-node TF
util.single_node_env()
def inference(args, ctx):

# load saved_model
saved_model = tf.saved_model.load(args.export_dir, tags='serve')
@@ -48,14 +39,14 @@ def parse_tfr(example_proto):

# define a new tf.data.Dataset (for inferencing)
ds = tf.data.Dataset.list_files("{}/part-*".format(args.images_labels))
ds = ds.shard(num_workers, worker_num)
ds = ds.shard(ctx.num_workers, ctx.worker_num)
ds = ds.interleave(tf.data.TFRecordDataset)
ds = ds.map(parse_tfr)
ds = ds.batch(10)

# create an output file per spark worker for the predictions
tf.io.gfile.makedirs(args.output)
output_file = tf.io.gfile.GFile("{}/part-{:05d}".format(args.output, worker_num), mode='w')
output_file = tf.io.gfile.GFile("{}/part-{:05d}".format(args.output, ctx.worker_num), mode='w')

for batch in ds:
predictions = predict(conv2d_input=batch[0])
@@ -70,6 +61,7 @@ def parse_tfr(example_proto):
if __name__ == '__main__':
from pyspark.context import SparkContext
from pyspark.conf import SparkConf
from tensorflowonspark import TFParallel

sc = SparkContext(conf=SparkConf().setAppName("mnist_inference"))
executors = sc._conf.get("spark.executor.instances")
@@ -83,7 +75,5 @@ def parse_tfr(example_proto):
args, _ = parser.parse_known_args()
print("args: {}".format(args))

# Not using TFCluster... just running single-node TF instances on each executor
nodes = list(range(args.cluster_size))
nodeRDD = sc.parallelize(list(range(args.cluster_size)), args.cluster_size)
nodeRDD.foreachPartition(lambda worker_num: inference(worker_num, args.cluster_size, args))
# Running single-node TF instances on each executor
TFParallel.run(sc, inference, args, args.cluster_size)
6 changes: 4 additions & 2 deletions examples/resnet/README.md
@@ -4,14 +4,16 @@ Original Source: https://github.com/tensorflow/models/tree/master/official/visio

This code is based on the Image Classification model from the official [TensorFlow Models](https://github.com/tensorflow/models) repository. This example already supports different forms of distribution via the `DistributionStrategy` API, so there isn't much additional work to convert it to TensorFlowOnSpark.

Notes:
Notes:
- This example assumes that Spark, TensorFlow, and TensorFlowOnSpark are already installed.
- For simplicity, this just uses a single-node Spark Standalone installation.

#### Run the Single-Node Application

First, make sure that you can run the example per the [original instructions](https://github.com/tensorflow/models/tree/68c3c65596b8fc624be15aef6eac3dc8952cbf23/official/vision/image_classification). For now, we'll just use the CIFAR-10 dataset. After cloning the [tensorflow/models](https://github.com/tensorflow/models) repository and downloading the dataset, you should be able to run the training as follows:
First, make sure that you can run the example per the [original instructions](https://github.com/tensorflow/models/tree/68c3c65596b8fc624be15aef6eac3dc8952cbf23/official/vision/image_classification). For now, we'll just use the CIFAR-10 dataset. After cloning the [tensorflow/models](https://github.com/tensorflow/models) repository (checking out the `v2.0` tag with `git checkout v2.0`) and downloading the dataset, you should be able to run the training as follows:
```
# Note: these instructions have been tested with the `v2.0` tag of tensorflow/models.

export TENSORFLOW_MODELS=/path/to/tensorflow/models
export CIFAR_DATA=/path/to/cifar
export PYTHONPATH=${PYTHONPATH}:${TENSORFLOW_MODELS}
1 change: 1 addition & 0 deletions tensorflowonspark/TFNode.py
@@ -21,6 +21,7 @@

logger = logging.getLogger(__name__)


def hdfs_path(ctx, path):
"""Convenience function to create a Tensorflow-compatible absolute HDFS path from relative paths

64 changes: 64 additions & 0 deletions tensorflowonspark/TFParallel.py
@@ -0,0 +1,64 @@
# Copyright 2019 Yahoo Inc / Verizon Media
# Licensed under the terms of the Apache 2.0 license.
# Please see LICENSE file in the project root for terms.

from __future__ import absolute_import
from __future__ import division
from __future__ import nested_scopes
from __future__ import print_function

import logging
from . import TFSparkNode
from . import gpu_info, util

logger = logging.getLogger(__name__)


def run(sc, map_fn, tf_args, num_executors):
"""Runs the user map_fn as parallel, independent instances of TF on the Spark executors.

Args:
:sc: SparkContext
:map_fn: user-supplied TensorFlow "main" function
:tf_args: ``argparse`` args, or command-line ``ARGV``. These will be passed to the ``map_fn``.
:num_executors: number of Spark executors. This should match your Spark job's ``--num-executors``.

Returns:
None
"""

# get default filesystem from spark
defaultFS = sc._jsc.hadoopConfiguration().get("fs.defaultFS")
# strip trailing "root" slash from "file:///" to be consistent w/ "hdfs://..."
if defaultFS.startswith("file://") and len(defaultFS) > 7 and defaultFS.endswith("/"):
defaultFS = defaultFS[:-1]

def _run(it):
from pyspark import BarrierTaskContext

for i in it:
worker_num = i

# use BarrierTaskContext to get placement of all nodes
ctx = BarrierTaskContext.get()
tasks = ctx.getTaskInfos()
nodes = [t.address for t in tasks]

# use the placement info to help allocate GPUs
num_gpus = tf_args.num_gpus if 'num_gpus' in tf_args else 1
util.single_node_env(num_gpus=num_gpus, worker_index=worker_num, nodes=nodes)

# run the user map_fn
ctx = TFSparkNode.TFNodeContext()
ctx.defaultFS = defaultFS
ctx.worker_num = worker_num
ctx.executor_id = worker_num
ctx.num_workers = len(nodes)

map_fn(tf_args, ctx)

# return a dummy iterator (since we have to use mapPartitions)
return [0]

nodeRDD = sc.parallelize(list(range(num_executors)), num_executors)
nodeRDD.barrier().mapPartitions(_run).collect()
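
For orientation, here is a minimal usage sketch of the new `TFParallel.run` API. It is a sketch only: it assumes a live `SparkContext` named `sc` and a Spark job launched with a matching number of executors; the `main_fn` and `cluster_size` names are illustrative and not part of the diff.

```python
from argparse import Namespace
from tensorflowonspark import TFParallel

def main_fn(args, ctx):
    # ctx is the TFNodeContext populated by TFParallel.run, exposing
    # worker_num, executor_id, num_workers, and defaultFS
    print("worker {} of {}".format(ctx.worker_num, ctx.num_workers))

args = Namespace(cluster_size=4)                       # hypothetical args
TFParallel.run(sc, main_fn, args, args.cluster_size)   # sc: existing SparkContext
```

Unlike `TFCluster.run`, no TensorFlow cluster spec is built here; each executor runs `main_fn` as an independent, single-node TF instance, and the barrier stage is only used to discover node placement for GPU allocation.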
2 changes: 1 addition & 1 deletion tensorflowonspark/TFSparkNode.py
@@ -46,7 +46,7 @@ class TFNodeContext:
:working_dir: the current working directory for local filesystems, or YARN containers.
:mgr: TFManager instance for this Python worker.
"""
def __init__(self, executor_id, job_name, task_index, cluster_spec, defaultFS, working_dir, mgr):
def __init__(self, executor_id=0, job_name='', task_index=0, cluster_spec={}, defaultFS='file://', working_dir='.', mgr=None):
self.worker_num = executor_id # for backwards-compatibility
self.executor_id = executor_id
self.job_name = job_name
3 changes: 1 addition & 2 deletions tensorflowonspark/reservation.py
@@ -190,9 +190,8 @@ def _listen(self, sock):
def get_server_ip(self):
return os.getenv(TFOS_SERVER_HOST) if os.getenv(TFOS_SERVER_HOST) else util.get_ip_address()


def start_listening_socket(self):
port_number = int(os.getenv(TFOS_SERVER_PORT)) if os.getenv(TFOS_SERVER_PORT) else 0
port_number = int(os.getenv(TFOS_SERVER_PORT)) if os.getenv(TFOS_SERVER_PORT) else 0
server_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
server_sock.bind(('', port_number))
18 changes: 14 additions & 4 deletions tensorflowonspark/util.py
@@ -18,7 +18,7 @@
logger = logging.getLogger(__name__)


def single_node_env(num_gpus=1):
def single_node_env(num_gpus=1, worker_index=-1, nodes=[]):
"""Setup environment variables for Hadoop compatibility and GPU allocation"""
import tensorflow as tf
# ensure expanded CLASSPATH w/o glob characters (required for Spark 2.1 + JNI)
@@ -29,9 +29,19 @@ def single_node_env(num_gpus=1):
os.environ['CLASSPATH'] = classpath + os.pathsep + hadoop_classpath
os.environ['TFOS_CLASSPATH_UPDATED'] = '1'

# reserve GPU, if requested
if tf.test.is_built_with_cuda():
gpus_to_use = gpu_info.get_gpus(num_gpus)
if tf.test.is_built_with_cuda() and num_gpus > 0:
# reserve GPU(s), if requested
if worker_index >= 0 and len(nodes) > 0:
# compute my index relative to other nodes on the same host, if known
my_addr = nodes[worker_index]
my_host = my_addr.split(':')[0]
local_peers = [n for n in nodes if n.startswith(my_host)]
my_index = local_peers.index(my_addr)
else:
# otherwise, just use global worker index
my_index = worker_index

gpus_to_use = gpu_info.get_gpus(num_gpus, my_index)
logger.info("Using gpu(s): {0}".format(gpus_to_use))
os.environ['CUDA_VISIBLE_DEVICES'] = gpus_to_use
else:
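
For reference, a small standalone sketch (not part of the diff; the `nodes` values below are hypothetical) of how `single_node_env` derives a per-host GPU slot from the barrier task placement:

```python
def local_gpu_index(nodes, worker_index):
    """Index of this worker among the workers scheduled on the same host."""
    my_addr = nodes[worker_index]                     # e.g. "host1:40003"
    my_host = my_addr.split(':')[0]                   # e.g. "host1"
    local_peers = [n for n in nodes if n.startswith(my_host)]
    return local_peers.index(my_addr)

# two executors per host: the second worker on host1 lands on GPU slot 1
nodes = ["host1:40001", "host2:40002", "host1:40003", "host2:40004"]
assert local_gpu_index(nodes, 2) == 1
```

This per-host index is what gets passed to `gpu_info.get_gpus(num_gpus, my_index)`, so workers sharing a host can be steered toward different GPUs instead of all contending for the same device.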