diff --git a/tensorflowonspark/TFNode.py b/tensorflowonspark/TFNode.py index 2997481e..0e94b75a 100755 --- a/tensorflowonspark/TFNode.py +++ b/tensorflowonspark/TFNode.py @@ -70,7 +70,7 @@ def start_cluster_server(ctx, num_gpus=1, rdma=False): cluster_spec = ctx.cluster_spec logging.info("{0}: Cluster spec: {1}".format(ctx.worker_num, cluster_spec)) - if tf.test.is_built_with_cuda(): + if tf.test.is_built_with_cuda() and num_gpus > 0: # GPU gpu_initialized = False retries = 3