Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion tensorflowonspark/TFNode.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
from . import compat, marker

logger = logging.getLogger(__name__)
TF_VERSION = pkg_resources.get_distribution('tensorflow').version
try:
TF_VERSION = pkg_resources.get_distribution('tensorflow').version
except pkg_resources.DistributionNotFound:
TF_VERSION = pkg_resources.get_distribution('tensorflow-cpu').version


def hdfs_path(ctx, path):
Expand Down
2 changes: 1 addition & 1 deletion tensorflowonspark/TFParallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def _run(it):
nodes = [t.address for t in tasks]
num_workers = len(nodes)
else:
nodes = None
nodes = []
num_workers = num_executors

# use the placement info to help allocate GPUs
Expand Down
9 changes: 6 additions & 3 deletions tensorflowonspark/TFSparkNode.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@
from . import util

logger = logging.getLogger(__name__)
TF_VERSION = pkg_resources.get_distribution('tensorflow').version
try:
TF_VERSION = pkg_resources.get_distribution('tensorflow').version
except pkg_resources.DistributionNotFound:
TF_VERSION = pkg_resources.get_distribution('tensorflow-cpu').version


def _has_spark_resource_api():
Expand Down Expand Up @@ -502,7 +505,7 @@ def _train(iter):
joinThr = Thread(target=queue.join)
joinThr.start()
timeout = feed_timeout
while (joinThr.isAlive()):
while (joinThr.is_alive()):
if (not equeue.empty()):
e_str = equeue.get()
raise Exception("Exception in worker:\n" + e_str)
Expand Down Expand Up @@ -570,7 +573,7 @@ def _inference(iter):
joinThr = Thread(target=queue_in.join)
joinThr.start()
timeout = feed_timeout
while (joinThr.isAlive()):
while (joinThr.is_alive()):
if (not equeue.empty()):
e_str = equeue.get()
raise Exception("Exception in worker:\n" + e_str)
Expand Down
5 changes: 4 additions & 1 deletion tensorflowonspark/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@


logger = logging.getLogger(__name__)
TF_VERSION = pkg_resources.get_distribution('tensorflow').version
try:
TF_VERSION = pkg_resources.get_distribution('tensorflow').version
except pkg_resources.DistributionNotFound:
TF_VERSION = pkg_resources.get_distribution('tensorflow-cpu').version


# TensorFlowOnSpark Params
Expand Down
2 changes: 1 addition & 1 deletion tensorflowonspark/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def single_node_env(num_gpus=1, worker_index=-1, nodes=[]):

if gpu_info.is_gpu_available() and num_gpus > 0:
# reserve GPU(s), if requested
if worker_index >= 0 and len(nodes) > 0:
if worker_index >= 0 and nodes and len(nodes) > 0:
# compute my index relative to other nodes on the same host, if known
my_addr = nodes[worker_index]
my_host = my_addr.split(':')[0]
Expand Down
4 changes: 3 additions & 1 deletion tests/test_TFSparkNode.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,10 @@ def fn(args, ctx):
map_fn = TFSparkNode.run(fn, tf_args, self.cluster_meta, self.tensorboard, self.log_dir, self.queues, self.background)
map_fn([0])

def test_gpu_unavailable(self):
@patch('tensorflowonspark.gpu_info.is_gpu_available')
def test_gpu_unavailable(self, mock_available):
"""Request GPU with no GPUs available, expecting an exception"""
mock_available.return_value = False
self.parser.add_argument("--num_gpus", help="number of gpus to use", type=int)
tf_args = self.parser.parse_args(["--num_gpus", "1"])

Expand Down