From 7ae5925e23362a716f7d8fbcb5a4f863061f2a28 Mon Sep 17 00:00:00 2001 From: Lee Yang Date: Thu, 15 Nov 2018 15:42:42 -0800 Subject: [PATCH 1/2] more deterministic GPU scheduling; add timeout for entire Spark application to handle TF hangs --- tensorflowonspark/TFCluster.py | 20 ++++++++++++++++++-- tensorflowonspark/TFNode.py | 9 ++++++++- tensorflowonspark/TFSparkNode.py | 28 ++++++++++++++++++---------- tensorflowonspark/gpu_info.py | 28 ++++++++++++++++++++-------- 4 files changed, 64 insertions(+), 21 deletions(-) diff --git a/tensorflowonspark/TFCluster.py b/tensorflowonspark/TFCluster.py index e37d7530..d68baf2c 100644 --- a/tensorflowonspark/TFCluster.py +++ b/tensorflowonspark/TFCluster.py @@ -25,6 +25,7 @@ import logging import os import random +import signal import sys import threading import time @@ -111,12 +112,16 @@ def inference(self, dataRDD, feed_timeout=600, qname='input'): assert qname in self.queues, "Unknown queue: {}".format(qname) return dataRDD.mapPartitions(TFSparkNode.inference(self.cluster_info, feed_timeout=feed_timeout, qname=qname)) - def shutdown(self, ssc=None, grace_secs=0): + def shutdown(self, ssc=None, grace_secs=0, timeout=259200): """Stops the distributed TensorFlow cluster. + For InputMode.SPARK, this will be executed AFTER the `TFCluster.train()` or `TFCluster.inference()` method completes. + For InputMode.TENSORFLOW, this will be executed IMMEDIATELY after `TFCluster.run()` and will wait until the TF worker nodes complete. + Args: :ssc: *For Streaming applications only*. Spark StreamingContext - :grace_secs: Grace period to wait before terminating the Spark application, e.g. to allow the chief worker to perform any final/cleanup duties like exporting or evaluating the model. + :grace_secs: Grace period to wait after all executors have completed their tasks before terminating the Spark application, e.g. to allow the chief worker to perform any final/cleanup duties like exporting or evaluating the model. Default is 0. + :timeout: Time in seconds to wait for TF cluster to complete before terminating the Spark application. This can be useful if the TF code hangs for any reason. Default is 3 days. 
""" logging.info("Stopping TensorFlow nodes") @@ -125,6 +130,17 @@ def shutdown(self, ssc=None, grace_secs=0): for node in self.cluster_info: (ps_list if node['job_name'] == 'ps' else worker_list).append(node) + # setup execution timeout + def timeout_handler(signum, frame): + logging.error("TensorFlow execution timed out, exiting Spark application with error status") + self.sc.cancelAllJobs() + self.sc.stop() + sys.exit(1) + + signal.signal(signal.SIGALRM, timeout_handler) + signal.alarm(timeout) + + # wait for Spark Streaming termination or TF app completion for InputMode.TENSORFLOW if ssc is not None: # Spark Streaming while not ssc.awaitTerminationOrTimeout(1): diff --git a/tensorflowonspark/TFNode.py b/tensorflowonspark/TFNode.py index 889ff52e..93094d7a 100755 --- a/tensorflowonspark/TFNode.py +++ b/tensorflowonspark/TFNode.py @@ -82,6 +82,13 @@ def start_cluster_server(ctx, num_gpus=1, rdma=False): logging.info("{0}: Cluster spec: {1}".format(ctx.worker_num, cluster_spec)) if tf.test.is_built_with_cuda() and num_gpus > 0: + # compute my index relative to other nodes placed on the same host (for GPU allocation) + my_addr = cluster_spec[ctx.job_name][ctx.task_index] + my_host = my_addr.split(':')[0] + flattened = [v for sublist in cluster_spec.values() for v in sublist] + local_peers = [p for p in flattened if p.startswith(my_host)] + my_index = local_peers.index(my_addr) + # GPU gpu_initialized = False retries = 3 @@ -92,7 +99,7 @@ def start_cluster_server(ctx, num_gpus=1, rdma=False): num_gpus = 1 # Find a free gpu(s) to use - gpus_to_use = gpu_info.get_gpus(num_gpus) + gpus_to_use = gpu_info.get_gpus(num_gpus, my_index) gpu_prompt = "GPU" if num_gpus == 1 else "GPUs" logging.info("{0}: Using {1}: {2}".format(ctx.worker_num, gpu_prompt, gpus_to_use)) diff --git a/tensorflowonspark/TFSparkNode.py b/tensorflowonspark/TFSparkNode.py index f1d6ebf4..d9d6f084 100644 --- a/tensorflowonspark/TFSparkNode.py +++ b/tensorflowonspark/TFSparkNode.py @@ -139,9 +139,10 @@ def _mapfn(iter): for i in iter: executor_id = i - # run quick check of GPU infrastructure if using tensorflow-gpu + # check that there are enough available GPUs (if using tensorflow-gpu) before committing reservation on this node if tf.test.is_built_with_cuda(): - gpus_to_use = gpu_info.get_gpus(1) + num_gpus = tf_args.num_gpus if 'num_gpus' in tf_args else 1 + gpus_to_use = gpu_info.get_gpus(num_gpus) # assign TF job/task based on provided cluster_spec template (or use default/null values) job_name = 'default' @@ -261,7 +262,7 @@ def _mapfn(iter): # construct a TensorFlow clusterspec from cluster_info sorted_cluster_info = sorted(cluster_info, key=lambda k: k['executor_id']) - spec = {} + cluster_spec = {} last_executor_id = -1 for node in sorted_cluster_info: if (node['executor_id'] == last_executor_id): @@ -269,30 +270,37 @@ def _mapfn(iter): last_executor_id = node['executor_id'] logging.info("node: {0}".format(node)) (njob, nhost, nport) = (node['job_name'], node['host'], node['port']) - hosts = [] if njob not in spec else spec[njob] + hosts = [] if njob not in cluster_spec else cluster_spec[njob] hosts.append("{0}:{1}".format(nhost, nport)) - spec[njob] = hosts + cluster_spec[njob] = hosts # update TF_CONFIG if cluster spec has a 'master' node (i.e. 
tf.estimator) - if 'master' in spec: + if 'master' in cluster_spec: tf_config = json.dumps({ - 'cluster': spec, + 'cluster': cluster_spec, 'task': {'type': job_name, 'index': task_index}, 'environment': 'cloud' }) logging.info("export TF_CONFIG: {}".format(tf_config)) os.environ['TF_CONFIG'] = tf_config - # reserve GPU + # reserve GPU(s) again, just before launching TF process (in case situation has changed) if tf.test.is_built_with_cuda(): + # compute my index relative to other nodes on the same host (for GPU allocation) + my_addr = cluster_spec[job_name][task_index] + my_host = my_addr.split(':')[0] + flattened = [v for sublist in cluster_spec.values() for v in sublist] + local_peers = [p for p in flattened if p.startswith(my_host)] + my_index = local_peers.index(my_addr) + num_gpus = tf_args.num_gpus if 'num_gpus' in tf_args else 1 - gpus_to_use = gpu_info.get_gpus(num_gpus) + gpus_to_use = gpu_info.get_gpus(num_gpus, my_index) gpu_str = "GPUs" if num_gpus > 1 else "GPU" logging.debug("Requested {} {}, setting CUDA_VISIBLE_DEVICES={}".format(num_gpus, gpu_str, gpus_to_use)) os.environ['CUDA_VISIBLE_DEVICES'] = gpus_to_use # create a context object to hold metadata for TF - ctx = TFNodeContext(executor_id, job_name, task_index, spec, cluster_meta['default_fs'], cluster_meta['working_dir'], TFSparkNode.mgr) + ctx = TFNodeContext(executor_id, job_name, task_index, cluster_spec, cluster_meta['default_fs'], cluster_meta['working_dir'], TFSparkNode.mgr) # release port reserved for TF as late as possible if tmp_sock is not None: diff --git a/tensorflowonspark/gpu_info.py b/tensorflowonspark/gpu_info.py index 8ef6dd8d..ffc7ae7a 100644 --- a/tensorflowonspark/gpu_info.py +++ b/tensorflowonspark/gpu_info.py @@ -40,13 +40,14 @@ def _get_gpu(): return gpu -def get_gpus(num_gpu=1): +def get_gpus(num_gpu=1, worker_index=-1): """Get list of free GPUs according to nvidia-smi. This will retry for ``MAX_RETRIES`` times until the requested number of GPUs are available. Args: :num_gpu: number of GPUs desired. + :worker_index: index "hint" for allocation of available GPUs. Returns: Comma-delimited string of GPU ids, or raises an Exception if the requested number of GPUs could not be found. 
@@ -63,9 +64,6 @@ def parse_gpu(gpu_str): return cols[5].split(')')[0], cols[1].split(':')[0] gpu_list = [parse_gpu(gpu) for gpu in gpus] - # randomize the search order to get a better distribution of GPUs - random.shuffle(gpu_list) - free_gpus = [] retries = 0 while len(free_gpus) < num_gpu and retries < MAX_RETRIES: @@ -77,19 +75,33 @@ def parse_gpu(gpu_str): free_gpus.append(index) if len(free_gpus) < num_gpu: - # keep trying indefinitely logging.warn("Unable to find available GPUs: requested={0}, available={1}".format(num_gpu, len(free_gpus))) retries += 1 time.sleep(30 * retries) free_gpus = [] - # if still can't find GPUs, raise exception + logging.info("Available GPUs: {}".format(free_gpus)) + + # if still can't find available GPUs, raise exception if len(free_gpus) < num_gpu: smi_output = subprocess.check_output(["nvidia-smi", "--format=csv", "--query-compute-apps=gpu_uuid,pid,process_name,used_gpu_memory"]).decode() logging.info(": {0}".format(smi_output)) - raise Exception("Unable to find free GPU:\n{0}".format(smi_output)) + raise Exception("Unable to find {} free GPU(s)\n{}".format(num_gpu, smi_output)) + + # Get logical placement + num_available = len(free_gpus) + if worker_index == -1: + # use original random placement + random.shuffle(free_gpus) + proposed_gpus = free_gpus[:num_gpu] + else: + # ordered by worker index + if worker_index + num_gpu > num_available: + worker_index = worker_index % num_available + proposed_gpus = free_gpus[worker_index:(worker_index + num_gpu)] + logging.info("Proposed GPUs: {}".format(proposed_gpus)) - return ','.join(free_gpus[:num_gpu]) + return ','.join(str(x) for x in proposed_gpus) # Function to get the gpu information From 3824995dc6fdb14f8382671a4b167f5e42c44bd4 Mon Sep 17 00:00:00 2001 From: Lee Yang Date: Thu, 15 Nov 2018 16:08:52 -0800 Subject: [PATCH 2/2] add option to disable timeout --- tensorflowonspark/TFCluster.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/tensorflowonspark/TFCluster.py b/tensorflowonspark/TFCluster.py index d68baf2c..32e63378 100644 --- a/tensorflowonspark/TFCluster.py +++ b/tensorflowonspark/TFCluster.py @@ -121,7 +121,7 @@ def shutdown(self, ssc=None, grace_secs=0, timeout=259200): Args: :ssc: *For Streaming applications only*. Spark StreamingContext :grace_secs: Grace period to wait after all executors have completed their tasks before terminating the Spark application, e.g. to allow the chief worker to perform any final/cleanup duties like exporting or evaluating the model. Default is 0. - :timeout: Time in seconds to wait for TF cluster to complete before terminating the Spark application. This can be useful if the TF code hangs for any reason. Default is 3 days. + :timeout: Time in seconds to wait for TF cluster to complete before terminating the Spark application. This can be useful if the TF code hangs for any reason. Default is 3 days. Use -1 to disable timeout. 
""" logging.info("Stopping TensorFlow nodes") @@ -131,14 +131,15 @@ def shutdown(self, ssc=None, grace_secs=0, timeout=259200): (ps_list if node['job_name'] == 'ps' else worker_list).append(node) # setup execution timeout - def timeout_handler(signum, frame): - logging.error("TensorFlow execution timed out, exiting Spark application with error status") - self.sc.cancelAllJobs() - self.sc.stop() - sys.exit(1) - - signal.signal(signal.SIGALRM, timeout_handler) - signal.alarm(timeout) + if timeout > 0: + def timeout_handler(signum, frame): + logging.error("TensorFlow execution timed out, exiting Spark application with error status") + self.sc.cancelAllJobs() + self.sc.stop() + sys.exit(1) + + signal.signal(signal.SIGALRM, timeout_handler) + signal.alarm(timeout) # wait for Spark Streaming termination or TF app completion for InputMode.TENSORFLOW if ssc is not None: