42 changes: 22 additions & 20 deletions tensorflowonspark/TFCluster.py
@@ -34,6 +34,8 @@
from . import TFManager
from . import TFSparkNode

+logger = logging.getLogger(__name__)

# status of TF background job
tf_status = {}

@@ -73,7 +75,7 @@ def train(self, dataRDD, num_epochs=0, feed_timeout=600, qname='input'):
:feed_timeout: number of seconds after which data feeding times out (600 sec default)
:qname: *INTERNAL USE*.
"""
logging.info("Feeding training data")
logger.info("Feeding training data")
assert self.input_mode == InputMode.SPARK, "TFCluster.train() requires InputMode.SPARK"
assert qname in self.queues, "Unknown queue: {}".format(qname)
assert num_epochs >= 0, "num_epochs cannot be negative"
@@ -107,7 +109,7 @@ def inference(self, dataRDD, feed_timeout=600, qname='input'):
Returns:
A Spark RDD representing the output of the TensorFlow inferencing
"""
logging.info("Feeding inference data")
logger.info("Feeding inference data")
assert self.input_mode == InputMode.SPARK, "TFCluster.inference() requires InputMode.SPARK"
assert qname in self.queues, "Unknown queue: {}".format(qname)
return dataRDD.mapPartitions(TFSparkNode.inference(self.cluster_info, feed_timeout=feed_timeout, qname=qname))
@@ -123,7 +125,7 @@ def shutdown(self, ssc=None, grace_secs=0, timeout=259200):
:grace_secs: Grace period to wait after all executors have completed their tasks before terminating the Spark application, e.g. to allow the chief worker to perform any final/cleanup duties like exporting or evaluating the model. Default is 0.
:timeout: Time in seconds to wait for TF cluster to complete before terminating the Spark application. This can be useful if the TF code hangs for any reason. Default is 3 days. Use -1 to disable timeout.
"""
logging.info("Stopping TensorFlow nodes")
logger.info("Waiting for TensorFlow nodes to complete...")

# identify ps/workers
ps_list, worker_list, eval_list = [], [], []
@@ -133,7 +135,7 @@ def shutdown(self, ssc=None, grace_secs=0, timeout=259200):
# setup execution timeout
if timeout > 0:
def timeout_handler(signum, frame):
logging.error("TensorFlow execution timed out, exiting Spark application with error status")
logger.error("TensorFlow execution timed out, exiting Spark application with error status")
self.sc.cancelAllJobs()
self.sc.stop()
sys.exit(1)
@@ -146,7 +148,7 @@ def timeout_handler(signum, frame):
# Spark Streaming
while not ssc.awaitTerminationOrTimeout(1):
if self.server.done:
logging.info("Server done, stopping StreamingContext")
logger.info("Server done, stopping StreamingContext")
ssc.stop(stopSparkContext=False, stopGraceFully=True)
break
elif self.input_mode == InputMode.TENSORFLOW:
@@ -175,12 +177,12 @@ def timeout_handler(signum, frame):

# exit Spark application w/ err status if TF job had any errors
if 'error' in tf_status:
logging.error("Exiting Spark application with error status.")
logger.error("Exiting Spark application with error status.")
self.sc.cancelAllJobs()
self.sc.stop()
sys.exit(1)

logging.info("Shutting down cluster")
logger.info("Shutting down cluster")
# shutdown queues and managers for "PS" executors.
# note: we have to connect/shutdown from the spark driver, because these executors are "busy" and won't accept any other tasks.
for node in ps_list + eval_list:
@@ -230,7 +232,7 @@ def run(sc, map_fun, tf_args, num_executors, num_ps, tensorboard=False, input_mo
Returns:
A TFCluster object representing the started cluster.
"""
logging.info("Reserving TFSparkNodes {0}".format("w/ TensorBoard" if tensorboard else ""))
logger.info("Reserving TFSparkNodes {0}".format("w/ TensorBoard" if tensorboard else ""))

if driver_ps_nodes and input_mode != InputMode.TENSORFLOW:
raise Exception('running PS nodes on driver locally is only supported in InputMode.TENSORFLOW')
@@ -263,7 +265,7 @@ def run(sc, map_fun, tf_args, num_executors, num_ps, tensorboard=False, input_mo
if num_workers > 0:
cluster_template['worker'] = executors[:num_workers]

logging.info("cluster_template: {}".format(cluster_template))
logger.info("cluster_template: {}".format(cluster_template))

# get default filesystem from spark
defaultFS = sc._jsc.hadoopConfiguration().get("fs.defaultFS")
@@ -279,7 +281,7 @@ def run(sc, map_fun, tf_args, num_executors, num_ps, tensorboard=False, input_mo
server_addr = server.start()

# start TF nodes on all executors
logging.info("Starting TensorFlow on executors")
logger.info("Starting TensorFlow on executors")
cluster_meta = {
'id': random.getrandbits(64),
'cluster_template': cluster_template,
@@ -295,7 +297,7 @@ def run(sc, map_fun, tf_args, num_executors, num_ps, tensorboard=False, input_mo

if driver_ps_nodes:
def _start_ps(node_index):
logging.info("starting ps node locally %d" % node_index)
logger.info("starting ps node locally %d" % node_index)
TFSparkNode.run(map_fun,
tf_args,
cluster_meta,
@@ -319,7 +321,7 @@ def _start(status):
queues,
background=(input_mode == InputMode.SPARK)))
except Exception as e:
logging.error("Exception in TF background thread")
logger.error("Exception in TF background thread")
status['error'] = str(e)

t = threading.Thread(target=_start, args=(tf_status,))
@@ -329,23 +331,23 @@ def _start(status):
t.start()

# wait for executors to register and start TFNodes before continuing
logging.info("Waiting for TFSparkNodes to start")
logger.info("Waiting for TFSparkNodes to start")
cluster_info = server.await_reservations(sc, tf_status, reservation_timeout)
logging.info("All TFSparkNodes started")
logger.info("All TFSparkNodes started")

# print cluster_info and extract TensorBoard URL
tb_url = None
for node in cluster_info:
-logging.info(node)
+logger.info(node)
if node['tb_port'] != 0:
tb_url = "http://{0}:{1}".format(node['host'], node['tb_port'])

if tb_url is not None:
logging.info("========================================================================================")
logging.info("")
logging.info("TensorBoard running at: {0}".format(tb_url))
logging.info("")
logging.info("========================================================================================")
logger.info("========================================================================================")
logger.info("")
logger.info("TensorBoard running at: {0}".format(tb_url))
logger.info("")
logger.info("========================================================================================")

# since our "primary key" for each executor's TFManager is (host, executor_id), sanity check for duplicates
# Note: this may occur if Spark retries failed Python tasks on the same executor.
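The change above replaces calls on the root `logging` module with a module-level `logger = logging.getLogger(__name__)`. One practical consequence is that applications can tune TensorFlowOnSpark's log output through the standard logger hierarchy without reconfiguring the root logger. A minimal driver-side sketch (the level and configuration choices below are illustrative, not part of this change):

```python
import logging

# Application-wide defaults for the root logger.
logging.basicConfig(level=logging.WARNING)

# TFCluster and TFNode now log via logging.getLogger(__name__), so their
# records fall under the "tensorflowonspark" namespace and can be adjusted
# independently of other libraries.
logging.getLogger("tensorflowonspark").setLevel(logging.INFO)
```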
19 changes: 10 additions & 9 deletions tensorflowonspark/TFNode.py
@@ -19,6 +19,7 @@
from six.moves.queue import Empty
from . import marker

+logger = logging.getLogger(__name__)

def hdfs_path(ctx, path):
"""Convenience function to create a Tensorflow-compatible absolute HDFS path from relative paths
@@ -54,7 +55,7 @@ def hdfs_path(ctx, path):
elif ctx.defaultFS.startswith("file://"):
return "{0}/{1}/{2}".format(ctx.defaultFS, ctx.working_dir[1:], path)
else:
logging.warn("Unknown scheme {0} with relative path: {1}".format(ctx.defaultFS, path))
logger.warn("Unknown scheme {0} with relative path: {1}".format(ctx.defaultFS, path))
return "{0}/{1}".format(ctx.defaultFS, path)


@@ -120,21 +121,21 @@ def next_batch(self, batch_size):
Returns:
A batch of items or a dictionary of tensors.
"""
logging.debug("next_batch() invoked")
logger.debug("next_batch() invoked")
queue = self.mgr.get_queue(self.qname_in)
tensors = [] if self.input_tensors is None else {tensor: [] for tensor in self.input_tensors}
count = 0
while count < batch_size:
item = queue.get(block=True)
if item is None:
# End of Feed
logging.info("next_batch() got None")
logger.info("next_batch() got None")
queue.task_done()
self.done_feeding = True
break
elif type(item) is marker.EndPartition:
# End of Partition
logging.info("next_batch() got EndPartition")
logger.info("next_batch() got EndPartition")
queue.task_done()
if not self.train_mode and count > 0:
break
@@ -147,7 +148,7 @@ def next_batch(self, batch_size):
tensors[self.input_tensors[i]].append(item[i])
count += 1
queue.task_done()
logging.debug("next_batch() returning {0} items".format(count))
logger.debug("next_batch() returning {0} items".format(count))
return tensors

def should_stop(self):
Expand All @@ -163,11 +164,11 @@ def batch_results(self, results):
Args:
:results: array of output data for the equivalent batch of input data.
"""
logging.debug("batch_results() invoked")
logger.debug("batch_results() invoked")
queue = self.mgr.get_queue(self.qname_out)
for item in results:
queue.put(item, block=True)
logging.debug("batch_results() returning data")
logger.debug("batch_results() returning data")

def terminate(self):
"""Terminate data feeding early.
Expand All @@ -177,7 +178,7 @@ def terminate(self):
to terminate an RDD operation early, so the extra partitions will still be sent to the executors (but will be ignored). Because
of this, you should size your input data accordingly to avoid excessive overhead.
"""
logging.info("terminate() invoked")
logger.info("terminate() invoked")
self.mgr.set('state', 'terminating')

# drop remaining items in the queue
Expand All @@ -190,5 +191,5 @@ def terminate(self):
queue.task_done()
count += 1
except Empty:
logging.info("dropped {0} items from queue".format(count))
logger.info("dropped {0} items from queue".format(count))
done = True
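For context on how the `next_batch()`/`batch_results()` methods above are driven from user code, here is a rough sketch of an InputMode.SPARK inference `map_fun`. It assumes the feeding class shown in this diff is exposed as `TFNode.DataFeed` and that the executor context provides its TFManager as `ctx.mgr` (names taken from the TensorFlowOnSpark API, treated here as assumptions); the `predict` helper is purely hypothetical:

```python
from tensorflowonspark import TFNode

def predict(item):
    # Stand-in for real model inference; replace with actual model code.
    return item

def map_fun(args, ctx):
    # Inference loop fed by TFCluster.inference() via the executor's queues.
    tf_feed = TFNode.DataFeed(ctx.mgr, train_mode=False)
    while not tf_feed.should_stop():
        batch = tf_feed.next_batch(batch_size=64)
        if len(batch) == 0:
            # End of feed: next_batch() returned no items.
            break
        # Return one result per input item so TFCluster.inference()
        # can collect the outputs as a Spark RDD.
        tf_feed.batch_results([predict(item) for item in batch])
```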