tensorflowonspark/TFCluster.py (21 changes: 19 additions & 2 deletions)
@@ -25,6 +25,7 @@
import logging
import os
import random
import signal
import sys
import threading
import time
@@ -111,12 +112,16 @@ def inference(self, dataRDD, feed_timeout=600, qname='input'):
assert qname in self.queues, "Unknown queue: {}".format(qname)
return dataRDD.mapPartitions(TFSparkNode.inference(self.cluster_info, feed_timeout=feed_timeout, qname=qname))

def shutdown(self, ssc=None, grace_secs=0):
def shutdown(self, ssc=None, grace_secs=0, timeout=259200):
"""Stops the distributed TensorFlow cluster.

For InputMode.SPARK, this will be executed AFTER the `TFCluster.train()` or `TFCluster.inference()` method completes.
For InputMode.TENSORFLOW, this will be executed IMMEDIATELY after `TFCluster.run()` and will wait until the TF worker nodes complete.

Args:
:ssc: *For Streaming applications only*. Spark StreamingContext
:grace_secs: Grace period to wait before terminating the Spark application, e.g. to allow the chief worker to perform any final/cleanup duties like exporting or evaluating the model.
:grace_secs: Grace period to wait after all executors have completed their tasks before terminating the Spark application, e.g. to allow the chief worker to perform any final/cleanup duties like exporting or evaluating the model. Default is 0.
:timeout: Time in seconds to wait for TF cluster to complete before terminating the Spark application. This can be useful if the TF code hangs for any reason. Default is 3 days. Use -1 to disable timeout.
"""
logging.info("Stopping TensorFlow nodes")

@@ -125,6 +130,18 @@ def shutdown(self, ssc=None, grace_secs=0):
for node in self.cluster_info:
(ps_list if node['job_name'] == 'ps' else worker_list).append(node)

# setup execution timeout
if timeout > 0:
def timeout_handler(signum, frame):
logging.error("TensorFlow execution timed out, exiting Spark application with error status")
self.sc.cancelAllJobs()
self.sc.stop()
sys.exit(1)

signal.signal(signal.SIGALRM, timeout_handler)
signal.alarm(timeout)

# wait for Spark Streaming termination or TF app completion for InputMode.TENSORFLOW
if ssc is not None:
# Spark Streaming
while not ssc.awaitTerminationOrTimeout(1):
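For context, a driver program could opt into the new timeout roughly as sketched below. This is not part of the diff: `sc`, `main_fun`, `args`, and the executor counts are hypothetical placeholders for a typical TensorFlowOnSpark driver; only the `shutdown` call illustrates the new parameter.

# Hypothetical driver-side usage of the new `timeout` argument; `sc`, `main_fun`,
# `args`, and the executor counts are placeholders, not part of this PR.
from tensorflowonspark import TFCluster

num_executors, num_ps = 4, 1
cluster = TFCluster.run(sc, main_fun, args, num_executors, num_ps,
                        False, TFCluster.InputMode.TENSORFLOW)
# Fail the Spark application if the TF nodes are still running after one hour,
# instead of waiting for the 3-day default (259200 seconds).
cluster.shutdown(grace_secs=30, timeout=3600)

Note that the timeout is implemented with signal.alarm/SIGALRM in the driver process (Unix-style signals), so it bounds the overall shutdown wait rather than any individual executor.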
tensorflowonspark/TFNode.py (9 changes: 8 additions & 1 deletion)
@@ -82,6 +82,13 @@ def start_cluster_server(ctx, num_gpus=1, rdma=False):
logging.info("{0}: Cluster spec: {1}".format(ctx.worker_num, cluster_spec))

if tf.test.is_built_with_cuda() and num_gpus > 0:
# compute my index relative to other nodes placed on the same host (for GPU allocation)
my_addr = cluster_spec[ctx.job_name][ctx.task_index]
my_host = my_addr.split(':')[0]
flattened = [v for sublist in cluster_spec.values() for v in sublist]
local_peers = [p for p in flattened if p.startswith(my_host)]
my_index = local_peers.index(my_addr)

# GPU
gpu_initialized = False
retries = 3
@@ -92,7 +99,7 @@ def start_cluster_server(ctx, num_gpus=1, rdma=False):
num_gpus = 1

# Find a free gpu(s) to use
gpus_to_use = gpu_info.get_gpus(num_gpus)
gpus_to_use = gpu_info.get_gpus(num_gpus, my_index)
gpu_prompt = "GPU" if num_gpus == 1 else "GPUs"
logging.info("{0}: Using {1}: {2}".format(ctx.worker_num, gpu_prompt, gpus_to_use))

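To make the host-local indexing above concrete, here is a standalone sketch of the same computation with a made-up cluster spec (host names and ports are illustrative only):

# Standalone sketch of the host-local index computation shown above; the
# cluster_spec contents are made up for illustration.
cluster_spec = {
    'ps': ['host1:2222'],
    'worker': ['host1:2223', 'host2:2222', 'host1:2224'],
}
job_name, task_index = 'worker', 2            # this process is host1:2224

my_addr = cluster_spec[job_name][task_index]  # 'host1:2224'
my_host = my_addr.split(':')[0]               # 'host1'
flattened = [v for sublist in cluster_spec.values() for v in sublist]
local_peers = [p for p in flattened if p.startswith(my_host)]
my_index = local_peers.index(my_addr)
print(my_index)  # 2: the ps and one other worker also live on host1

Each process sharing a host thus gets a distinct index, which get_gpus() can use to spread processes across that host's GPUs instead of picking at random.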
tensorflowonspark/TFSparkNode.py (28 changes: 18 additions & 10 deletions)
@@ -139,9 +139,10 @@ def _mapfn(iter):
for i in iter:
executor_id = i

# run quick check of GPU infrastructure if using tensorflow-gpu
# check that there are enough available GPUs (if using tensorflow-gpu) before committing reservation on this node
if tf.test.is_built_with_cuda():
gpus_to_use = gpu_info.get_gpus(1)
num_gpus = tf_args.num_gpus if 'num_gpus' in tf_args else 1
gpus_to_use = gpu_info.get_gpus(num_gpus)

# assign TF job/task based on provided cluster_spec template (or use default/null values)
job_name = 'default'
@@ -261,38 +262,45 @@ def _mapfn(iter):

# construct a TensorFlow clusterspec from cluster_info
sorted_cluster_info = sorted(cluster_info, key=lambda k: k['executor_id'])
spec = {}
cluster_spec = {}
last_executor_id = -1
for node in sorted_cluster_info:
if (node['executor_id'] == last_executor_id):
raise Exception("Duplicate worker/task in cluster_info")
last_executor_id = node['executor_id']
logging.info("node: {0}".format(node))
(njob, nhost, nport) = (node['job_name'], node['host'], node['port'])
hosts = [] if njob not in spec else spec[njob]
hosts = [] if njob not in cluster_spec else cluster_spec[njob]
hosts.append("{0}:{1}".format(nhost, nport))
spec[njob] = hosts
cluster_spec[njob] = hosts

# update TF_CONFIG if cluster spec has a 'master' node (i.e. tf.estimator)
if 'master' in spec:
if 'master' in cluster_spec:
tf_config = json.dumps({
'cluster': spec,
'cluster': cluster_spec,
'task': {'type': job_name, 'index': task_index},
'environment': 'cloud'
})
logging.info("export TF_CONFIG: {}".format(tf_config))
os.environ['TF_CONFIG'] = tf_config

# reserve GPU
# reserve GPU(s) again, just before launching TF process (in case situation has changed)
if tf.test.is_built_with_cuda():
# compute my index relative to other nodes on the same host (for GPU allocation)
my_addr = cluster_spec[job_name][task_index]
my_host = my_addr.split(':')[0]
flattened = [v for sublist in cluster_spec.values() for v in sublist]
local_peers = [p for p in flattened if p.startswith(my_host)]
my_index = local_peers.index(my_addr)

num_gpus = tf_args.num_gpus if 'num_gpus' in tf_args else 1
gpus_to_use = gpu_info.get_gpus(num_gpus)
gpus_to_use = gpu_info.get_gpus(num_gpus, my_index)
gpu_str = "GPUs" if num_gpus > 1 else "GPU"
logging.debug("Requested {} {}, setting CUDA_VISIBLE_DEVICES={}".format(num_gpus, gpu_str, gpus_to_use))
os.environ['CUDA_VISIBLE_DEVICES'] = gpus_to_use

# create a context object to hold metadata for TF
ctx = TFNodeContext(executor_id, job_name, task_index, spec, cluster_meta['default_fs'], cluster_meta['working_dir'], TFSparkNode.mgr)
ctx = TFNodeContext(executor_id, job_name, task_index, cluster_spec, cluster_meta['default_fs'], cluster_meta['working_dir'], TFSparkNode.mgr)

# release port reserved for TF as late as possible
if tmp_sock is not None:
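For reference, the TF_CONFIG exported above has roughly the following shape when the cluster spec contains a 'master' node; the addresses below are made up, not taken from the PR:

import json
import os

# Illustrative only: approximate shape of the exported TF_CONFIG (addresses are made up).
cluster_spec = {
    'master': ['host1:2222'],
    'ps': ['host2:2222'],
    'worker': ['host1:2223', 'host2:2223'],
}
tf_config = json.dumps({
    'cluster': cluster_spec,
    'task': {'type': 'worker', 'index': 0},
    'environment': 'cloud',
})
os.environ['TF_CONFIG'] = tf_config
# => {"cluster": {"master": [...], "ps": [...], "worker": [...]},
#     "task": {"type": "worker", "index": 0}, "environment": "cloud"}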
tensorflowonspark/gpu_info.py (28 changes: 20 additions & 8 deletions)
@@ -40,13 +40,14 @@ def _get_gpu():
return gpu


def get_gpus(num_gpu=1):
def get_gpus(num_gpu=1, worker_index=-1):
"""Get list of free GPUs according to nvidia-smi.

This will retry for ``MAX_RETRIES`` times until the requested number of GPUs are available.

Args:
:num_gpu: number of GPUs desired.
:worker_index: index "hint" for allocation of available GPUs.

Returns:
Comma-delimited string of GPU ids, or raises an Exception if the requested number of GPUs could not be found.
@@ -63,9 +64,6 @@ def parse_gpu(gpu_str):
return cols[5].split(')')[0], cols[1].split(':')[0]
gpu_list = [parse_gpu(gpu) for gpu in gpus]

# randomize the search order to get a better distribution of GPUs
random.shuffle(gpu_list)

free_gpus = []
retries = 0
while len(free_gpus) < num_gpu and retries < MAX_RETRIES:
@@ -77,19 +75,33 @@ def parse_gpu(gpu_str):
free_gpus.append(index)

if len(free_gpus) < num_gpu:
# keep trying indefinitely
logging.warn("Unable to find available GPUs: requested={0}, available={1}".format(num_gpu, len(free_gpus)))
retries += 1
time.sleep(30 * retries)
free_gpus = []

# if still can't find GPUs, raise exception
logging.info("Available GPUs: {}".format(free_gpus))

# if still can't find available GPUs, raise exception
if len(free_gpus) < num_gpu:
smi_output = subprocess.check_output(["nvidia-smi", "--format=csv", "--query-compute-apps=gpu_uuid,pid,process_name,used_gpu_memory"]).decode()
logging.info(": {0}".format(smi_output))
raise Exception("Unable to find free GPU:\n{0}".format(smi_output))
raise Exception("Unable to find {} free GPU(s)\n{}".format(num_gpu, smi_output))

# Get logical placement
num_available = len(free_gpus)
if worker_index == -1:
# use original random placement
random.shuffle(free_gpus)
proposed_gpus = free_gpus[:num_gpu]
else:
# ordered by worker index
if worker_index + num_gpu > num_available:
worker_index = worker_index % num_available
proposed_gpus = free_gpus[worker_index:(worker_index + num_gpu)]
logging.info("Proposed GPUs: {}".format(proposed_gpus))

return ','.join(free_gpus[:num_gpu])
return ','.join(str(x) for x in proposed_gpus)


# Function to get the gpu information
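To trace the new placement logic in isolation, the sketch below reproduces the worker_index branch with a made-up list of free GPU ids; it is a standalone helper for illustration, not the library function itself:

import random

def propose_gpus(free_gpus, num_gpu=1, worker_index=-1):
    # Standalone sketch of the placement branch added above, not the library API.
    num_available = len(free_gpus)
    if worker_index == -1:
        # original behavior: random placement across the free GPUs
        shuffled = list(free_gpus)
        random.shuffle(shuffled)
        return shuffled[:num_gpu]
    # ordered by worker index, wrapping via modulo if the slice would run off the end
    if worker_index + num_gpu > num_available:
        worker_index = worker_index % num_available
    return free_gpus[worker_index:worker_index + num_gpu]

print(propose_gpus(['0', '1', '2', '3'], num_gpu=1, worker_index=2))  # ['2']
print(propose_gpus(['0', '1', '2', '3'], num_gpu=1, worker_index=4))  # ['0'] (wraps around)

With one TF process per GPU on a host, passing the host-local index computed in TFNode/TFSparkNode steers each process toward a different free GPU rather than relying on the earlier random shuffle.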