diff --git a/tensorflowonspark/TFNode.py b/tensorflowonspark/TFNode.py
index b4b5f5d..0a36cd3 100644
--- a/tensorflowonspark/TFNode.py
+++ b/tensorflowonspark/TFNode.py
@@ -23,7 +23,10 @@ from . import compat, marker
 
 logger = logging.getLogger(__name__)
 
-TF_VERSION = pkg_resources.get_distribution('tensorflow').version
+try:
+  TF_VERSION = pkg_resources.get_distribution('tensorflow').version
+except pkg_resources.DistributionNotFound:
+  TF_VERSION = pkg_resources.get_distribution('tensorflow-cpu').version
 
 
 def hdfs_path(ctx, path):
diff --git a/tensorflowonspark/TFParallel.py b/tensorflowonspark/TFParallel.py
index aef7c88..bb37d99 100644
--- a/tensorflowonspark/TFParallel.py
+++ b/tensorflowonspark/TFParallel.py
@@ -47,7 +47,7 @@ def _run(it):
     nodes = [t.address for t in tasks]
     num_workers = len(nodes)
   else:
-    nodes = None
+    nodes = []
     num_workers = num_executors
 
   # use the placement info to help allocate GPUs
diff --git a/tensorflowonspark/TFSparkNode.py b/tensorflowonspark/TFSparkNode.py
index 995a0f5..54a05cd 100644
--- a/tensorflowonspark/TFSparkNode.py
+++ b/tensorflowonspark/TFSparkNode.py
@@ -31,7 +31,10 @@ from . import util
 
 logger = logging.getLogger(__name__)
 
-TF_VERSION = pkg_resources.get_distribution('tensorflow').version
+try:
+  TF_VERSION = pkg_resources.get_distribution('tensorflow').version
+except pkg_resources.DistributionNotFound:
+  TF_VERSION = pkg_resources.get_distribution('tensorflow-cpu').version
 
 
 def _has_spark_resource_api():
@@ -502,7 +505,7 @@ def _train(iter):
     joinThr = Thread(target=queue.join)
     joinThr.start()
     timeout = feed_timeout
-    while (joinThr.isAlive()):
+    while (joinThr.is_alive()):
       if (not equeue.empty()):
         e_str = equeue.get()
         raise Exception("Exception in worker:\n" + e_str)
@@ -570,7 +573,7 @@ def _inference(iter):
     joinThr = Thread(target=queue_in.join)
     joinThr.start()
     timeout = feed_timeout
-    while (joinThr.isAlive()):
+    while (joinThr.is_alive()):
       if (not equeue.empty()):
         e_str = equeue.get()
         raise Exception("Exception in worker:\n" + e_str)
diff --git a/tensorflowonspark/pipeline.py b/tensorflowonspark/pipeline.py
index 159bdda..b3da0ef 100755
--- a/tensorflowonspark/pipeline.py
+++ b/tensorflowonspark/pipeline.py
@@ -31,7 +31,10 @@
 
 logger = logging.getLogger(__name__)
 
-TF_VERSION = pkg_resources.get_distribution('tensorflow').version
+try:
+  TF_VERSION = pkg_resources.get_distribution('tensorflow').version
+except pkg_resources.DistributionNotFound:
+  TF_VERSION = pkg_resources.get_distribution('tensorflow-cpu').version
 
 
 # TensorFlowOnSpark Params
diff --git a/tensorflowonspark/util.py b/tensorflowonspark/util.py
index fb4cc73..875ffa4 100644
--- a/tensorflowonspark/util.py
+++ b/tensorflowonspark/util.py
@@ -30,7 +30,7 @@ def single_node_env(num_gpus=1, worker_index=-1, nodes=[]):
 
   if gpu_info.is_gpu_available() and num_gpus > 0:
     # reserve GPU(s), if requested
-    if worker_index >= 0 and len(nodes) > 0:
+    if worker_index >= 0 and nodes and len(nodes) > 0:
       # compute my index relative to other nodes on the same host, if known
       my_addr = nodes[worker_index]
       my_host = my_addr.split(':')[0]
diff --git a/tests/test_TFSparkNode.py b/tests/test_TFSparkNode.py
index 12ce8f4..cd823c4 100644
--- a/tests/test_TFSparkNode.py
+++ b/tests/test_TFSparkNode.py
@@ -46,8 +46,10 @@ def fn(args, ctx):
     map_fn = TFSparkNode.run(fn, tf_args, self.cluster_meta, self.tensorboard, self.log_dir, self.queues, self.background)
     map_fn([0])
 
-  def test_gpu_unavailable(self):
+  @patch('tensorflowonspark.gpu_info.is_gpu_available')
+  def test_gpu_unavailable(self, mock_available):
     """Request GPU with no GPUs available, expecting an exception"""
+    mock_available.return_value = False
     self.parser.add_argument("--num_gpus", help="number of gpus to use", type=int)
     tf_args = self.parser.parse_args(["--num_gpus", "1"])