diff --git a/examples/mnist/spark/mnist_dist.py b/examples/mnist/spark/mnist_dist.py
index 031b2dc1..85dd51cf 100755
--- a/examples/mnist/spark/mnist_dist.py
+++ b/examples/mnist/spark/mnist_dist.py
@@ -9,16 +9,40 @@
 from __future__ import nested_scopes
 from __future__ import print_function
 
+from datetime import datetime
+import tensorflow as tf
+from tensorflowonspark import TFNode
+
 
 def print_log(worker_num, arg):
   print("{0}: {1}".format(worker_num, arg))
 
 
+class ExportHook(tf.train.SessionRunHook):
+  def __init__(self, export_dir, input_tensor, output_tensor):
+    self.export_dir = export_dir
+    self.input_tensor = input_tensor
+    self.output_tensor = output_tensor
+
+  def end(self, session):
+    print("{} ======= Exporting to: {}".format(datetime.now().isoformat(), self.export_dir))
+    signatures = {
+      tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: {
+        'inputs': {'image': self.input_tensor},
+        'outputs': {'prediction': self.output_tensor},
+        'method_name': tf.saved_model.signature_constants.PREDICT_METHOD_NAME
+      }
+    }
+    TFNode.export_saved_model(session,
+                              self.export_dir,
+                              tf.saved_model.tag_constants.SERVING,
+                              signatures)
+    print("{} ======= Done exporting".format(datetime.now().isoformat()))
+
+
 def map_fun(args, ctx):
-  from datetime import datetime
   import math
   import numpy
-  import tensorflow as tf
   import time
 
   worker_num = ctx.worker_num
@@ -105,7 +129,6 @@ def feed_dict(batch):
     logdir = ctx.absolute_path(args.model)
     print("tensorflow model path: {0}".format(logdir))
-    hooks = [tf.train.StopAtStepHook(last_step=100000)]
 
     if job_name == "worker" and task_index == 0:
       summary_writer = tf.summary.FileWriter(logdir, graph=tf.get_default_graph())
 
@@ -115,11 +138,11 @@ def feed_dict(batch):
     with tf.train.MonitoredTrainingSession(master=server.target,
                                            is_chief=(task_index == 0),
                                            checkpoint_dir=logdir,
-                                           hooks=hooks) as mon_sess:
-
+                                           hooks=[tf.train.StopAtStepHook(last_step=args.steps)],
+                                           chief_only_hooks=[ExportHook(ctx.absolute_path(args.export_dir), x, prediction)]) as mon_sess:
       step = 0
       tf_feed = ctx.get_data_feed(args.mode == "train")
-      while not mon_sess.should_stop() and not tf_feed.should_stop() and step < args.steps:
+      while not mon_sess.should_stop() and not tf_feed.should_stop():
         # Run a training step asynchronously
         # See `tf.train.SyncReplicasOptimizer` for additional details on how to
         # perform *synchronous* training.
diff --git a/examples/mnist/spark/mnist_spark.py b/examples/mnist/spark/mnist_spark.py
index 4b7e119e..9c6a4415 100755
--- a/examples/mnist/spark/mnist_spark.py
+++ b/examples/mnist/spark/mnist_spark.py
@@ -25,6 +25,7 @@
 parser = argparse.ArgumentParser()
 parser.add_argument("--batch_size", help="number of records per batch", type=int, default=100)
 parser.add_argument("--epochs", help="number of epochs", type=int, default=1)
+parser.add_argument("--export_dir", help="HDFS path to export saved_model", default="mnist_export")
 parser.add_argument("--format", help="example format: (csv|pickle|tfr)", choices=["csv", "pickle", "tfr"], default="csv")
 parser.add_argument("--images", help="HDFS path to MNIST images in parallelized format")
 parser.add_argument("--labels", help="HDFS path to MNIST labels in parallelized format")
@@ -71,6 +72,7 @@ def toNumpy(bytestr):
 else:
   labelRDD = cluster.inference(dataRDD)
   labelRDD.saveAsTextFile(args.output)
-cluster.shutdown()
+
+cluster.shutdown(grace_secs=30)
 
 print("{0} ===== Stop".format(datetime.now().isoformat()))
diff --git a/tensorflowonspark/TFCluster.py b/tensorflowonspark/TFCluster.py
index f21232c9..ffd0abbd 100644
--- a/tensorflowonspark/TFCluster.py
+++ b/tensorflowonspark/TFCluster.py
@@ -109,11 +109,12 @@ def inference(self, dataRDD, qname='input'):
     assert(qname in self.queues)
     return dataRDD.mapPartitions(TFSparkNode.inference(self.cluster_info, qname))
 
-  def shutdown(self, ssc=None):
+  def shutdown(self, ssc=None, grace_secs=0):
     """Stops the distributed TensorFlow cluster.
 
     Args:
       :ssc: *For Streaming applications only*. Spark StreamingContext
+      :grace_secs: Grace period to wait before terminating the Spark application, e.g. to allow the chief worker to perform any final/cleanup duties like exporting or evaluating the model.
     """
     logging.info("Stopping TensorFlow nodes")
@@ -146,12 +147,13 @@
         count += 1
         time.sleep(5)
 
-    # shutdown queues and managers for "worker" executors.
-    # note: in SPARK mode, this job will immediately queue up behind the "data feeding" job.
-    # in TENSORFLOW mode, this will only run after all workers have finished.
-    workers = len(worker_list)
-    workerRDD = self.sc.parallelize(range(workers), workers)
-    workerRDD.foreachPartition(TFSparkNode.shutdown(self.cluster_info, self.queues))
+      # shutdown queues and managers for "worker" executors.
+      # note: in SPARK mode, this job will immediately queue up behind the "data feeding" job.
+      # in TENSORFLOW mode, this will only run after all workers have finished.
+      workers = len(worker_list)
+      workerRDD = self.sc.parallelize(range(workers), workers)
+      workerRDD.foreachPartition(TFSparkNode.shutdown(self.cluster_info, self.queues))
+    time.sleep(grace_secs)
 
     # exit Spark application w/ err status if TF job had any errors
     if 'error' in tf_status: