Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tensorflowonspark/TFNode.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from packaging import version
from six.moves.queue import Empty
from . import marker
from . import compat, marker

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -90,7 +90,7 @@ def start_cluster_server(ctx, num_gpus=1, rdma=False):
cluster_spec = ctx.cluster_spec
logging.info("{0}: Cluster spec: {1}".format(ctx.worker_num, cluster_spec))

if tf.test.is_built_with_cuda() and num_gpus > 0:
if compat.is_gpu_available() and num_gpus > 0:
# compute my index relative to other nodes placed on the same host (for GPU allocation)
my_addr = cluster_spec[ctx.job_name][ctx.task_index]
my_host = my_addr.split(':')[0]
Expand Down
5 changes: 3 additions & 2 deletions tensorflowonspark/TFSparkNode.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from . import TFManager
from . import TFNode
from . import compat
from . import gpu_info
from . import marker
from . import reservation
Expand Down Expand Up @@ -144,7 +145,7 @@ def _mapfn(iter):
executor_id = i

# check that there are enough available GPUs (if using tensorflow-gpu) before committing reservation on this node
if tf.test.is_built_with_cuda():
if compat.is_gpu_available():
num_gpus = tf_args.num_gpus if 'num_gpus' in tf_args else 1
gpus_to_use = gpu_info.get_gpus(num_gpus)

Expand Down Expand Up @@ -295,7 +296,7 @@ def _mapfn(iter):
os.environ['TF_CONFIG'] = tf_config

# reserve GPU(s) again, just before launching TF process (in case situation has changed)
if tf.test.is_built_with_cuda():
if compat.is_gpu_available():
# compute my index relative to other nodes on the same host (for GPU allocation)
my_addr = cluster_spec[job_name][task_index]
my_host = my_addr.split(':')[0]
Expand Down
7 changes: 7 additions & 0 deletions tensorflowonspark/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,10 @@ def disable_auto_shard(options):
options.experimental_distribute.auto_shard = False
else:
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF


def is_gpu_available():
  """Return True if TensorFlow reports at least one usable GPU.

  TF versions before 2.1.0 do not provide ``tf.config.list_logical_devices``,
  so for those we fall back to checking whether TF was built with CUDA.
  """
  tf_version = version.parse(tf.__version__)
  if tf_version >= version.parse('2.1.0'):
    return len(tf.config.list_logical_devices('GPU')) > 0
  return tf.test.is_built_with_cuda()
5 changes: 2 additions & 3 deletions tensorflowonspark/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,13 @@
import subprocess
import errno
from socket import error as socket_error
from . import gpu_info
from . import compat, gpu_info

logger = logging.getLogger(__name__)


def single_node_env(num_gpus=1, worker_index=-1, nodes=[]):
"""Setup environment variables for Hadoop compatibility and GPU allocation"""
import tensorflow as tf
# ensure expanded CLASSPATH w/o glob characters (required for Spark 2.1 + JNI)
if 'HADOOP_PREFIX' in os.environ and 'TFOS_CLASSPATH_UPDATED' not in os.environ:
classpath = os.environ['CLASSPATH']
Expand All @@ -29,7 +28,7 @@ def single_node_env(num_gpus=1, worker_index=-1, nodes=[]):
os.environ['CLASSPATH'] = classpath + os.pathsep + hadoop_classpath
os.environ['TFOS_CLASSPATH_UPDATED'] = '1'

if tf.test.is_built_with_cuda() and num_gpus > 0:
if compat.is_gpu_available() and num_gpus > 0:
# reserve GPU(s), if requested
if worker_index >= 0 and len(nodes) > 0:
# compute my index relative to other nodes on the same host, if known
Expand Down
1 change: 1 addition & 0 deletions test/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def _spark_train(args, ctx):
import tensorflow as tf
from tensorflowonspark import TFNode

tf.compat.v1.disable_eager_execution()
tf.compat.v1.reset_default_graph()
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

Expand Down