Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 29 additions & 6 deletions examples/mnist/spark/mnist_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,40 @@
from __future__ import nested_scopes
from __future__ import print_function

from datetime import datetime
import tensorflow as tf
from tensorflowonspark import TFNode


def print_log(worker_num, arg):
  """Print *arg* to stdout, prefixed with the number of the worker that produced it."""
  message = "{0}: {1}".format(worker_num, arg)
  print(message)


class ExportHook(tf.train.SessionRunHook):
  """Session hook that exports a saved_model when the monitored session ends.

  Intended as a chief-only hook: on `end()` it writes the graph reachable
  from the given input/output tensors to `export_dir` via TFNode.
  """

  def __init__(self, export_dir, input_tensor, output_tensor):
    self.export_dir = export_dir        # destination path for the exported saved_model
    self.input_tensor = input_tensor    # tensor fed with input images at serving time
    self.output_tensor = output_tensor  # tensor holding the model's predictions

  def end(self, session):
    print("{} ======= Exporting to: {}".format(datetime.now().isoformat(), self.export_dir))
    # Single default serving signature mapping image -> prediction.
    serving_key = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
    signatures = {
        serving_key: {
            'inputs': {'image': self.input_tensor},
            'outputs': {'prediction': self.output_tensor},
            'method_name': tf.saved_model.signature_constants.PREDICT_METHOD_NAME
        }
    }
    TFNode.export_saved_model(session,
                              self.export_dir,
                              tf.saved_model.tag_constants.SERVING,
                              signatures)
    print("{} ======= Done exporting".format(datetime.now().isoformat()))


def map_fun(args, ctx):
from datetime import datetime
import math
import numpy
import tensorflow as tf
import time

worker_num = ctx.worker_num
Expand Down Expand Up @@ -105,7 +129,6 @@ def feed_dict(batch):

logdir = ctx.absolute_path(args.model)
print("tensorflow model path: {0}".format(logdir))
hooks = [tf.train.StopAtStepHook(last_step=100000)]

if job_name == "worker" and task_index == 0:
summary_writer = tf.summary.FileWriter(logdir, graph=tf.get_default_graph())
Expand All @@ -115,11 +138,11 @@ def feed_dict(batch):
with tf.train.MonitoredTrainingSession(master=server.target,
is_chief=(task_index == 0),
checkpoint_dir=logdir,
hooks=hooks) as mon_sess:

hooks=[tf.train.StopAtStepHook(last_step=args.steps)],
chief_only_hooks=[ExportHook(ctx.absolute_path(args.export_dir), x, prediction)]) as mon_sess:
step = 0
tf_feed = ctx.get_data_feed(args.mode == "train")
while not mon_sess.should_stop() and not tf_feed.should_stop() and step < args.steps:
while not mon_sess.should_stop() and not tf_feed.should_stop():
# Run a training step asynchronously
# See `tf.train.SyncReplicasOptimizer` for additional details on how to
# perform *synchronous* training.
Expand Down
4 changes: 3 additions & 1 deletion examples/mnist/spark/mnist_spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
parser = argparse.ArgumentParser()
parser.add_argument("--batch_size", help="number of records per batch", type=int, default=100)
parser.add_argument("--epochs", help="number of epochs", type=int, default=1)
parser.add_argument("--export_dir", help="HDFS path to export saved_model", default="mnist_export")
parser.add_argument("--format", help="example format: (csv|pickle|tfr)", choices=["csv", "pickle", "tfr"], default="csv")
parser.add_argument("--images", help="HDFS path to MNIST images in parallelized format")
parser.add_argument("--labels", help="HDFS path to MNIST labels in parallelized format")
Expand Down Expand Up @@ -71,6 +72,7 @@ def toNumpy(bytestr):
else:
labelRDD = cluster.inference(dataRDD)
labelRDD.saveAsTextFile(args.output)
cluster.shutdown()

cluster.shutdown(grace_secs=30)

print("{0} ===== Stop".format(datetime.now().isoformat()))
16 changes: 9 additions & 7 deletions tensorflowonspark/TFCluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,12 @@ def inference(self, dataRDD, qname='input'):
assert(qname in self.queues)
return dataRDD.mapPartitions(TFSparkNode.inference(self.cluster_info, qname))

def shutdown(self, ssc=None):
def shutdown(self, ssc=None, grace_secs=0):
"""Stops the distributed TensorFlow cluster.

Args:
:ssc: *For Streaming applications only*. Spark StreamingContext
:grace_secs: Grace period to wait before terminating the Spark application, e.g. to allow the chief worker to perform any final/cleanup duties like exporting or evaluating the model.
"""
logging.info("Stopping TensorFlow nodes")

Expand Down Expand Up @@ -146,12 +147,13 @@ def shutdown(self, ssc=None):
count += 1
time.sleep(5)

# shutdown queues and managers for "worker" executors.
# note: in SPARK mode, this job will immediately queue up behind the "data feeding" job.
# in TENSORFLOW mode, this will only run after all workers have finished.
workers = len(worker_list)
workerRDD = self.sc.parallelize(range(workers), workers)
workerRDD.foreachPartition(TFSparkNode.shutdown(self.cluster_info, self.queues))
# shutdown queues and managers for "worker" executors.
# note: in SPARK mode, this job will immediately queue up behind the "data feeding" job.
# in TENSORFLOW mode, this will only run after all workers have finished.
workers = len(worker_list)
workerRDD = self.sc.parallelize(range(workers), workers)
workerRDD.foreachPartition(TFSparkNode.shutdown(self.cluster_info, self.queues))
time.sleep(grace_secs)

# exit Spark application w/ err status if TF job had any errors
if 'error' in tf_status:
Expand Down