From 3cce0e1c88c6e7dea9d3d0cd454a13bd4f8c5cd8 Mon Sep 17 00:00:00 2001 From: Lee Yang Date: Fri, 8 Mar 2019 13:28:06 -0800 Subject: [PATCH] add back code to terminate feed --- examples/mnist/spark/mnist_dist.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/mnist/spark/mnist_dist.py b/examples/mnist/spark/mnist_dist.py index d870e407..cdf0a8c2 100755 --- a/examples/mnist/spark/mnist_dist.py +++ b/examples/mnist/spark/mnist_dist.py @@ -117,7 +117,7 @@ def rdd_generator(): if args.mode == "train": _, summary, step = sess.run([train_op, summary_op, global_step]) - if (step % 100 == 0): + if (step % 100 == 0) and (not sess.should_stop()): print("{} step: {} accuracy: {}".format(datetime.now().isoformat(), step, sess.run(accuracy))) if task_index == 0: summary_writer.add_summary(summary, step) @@ -129,6 +129,9 @@ def rdd_generator(): print("{} stopping MonitoredTrainingSession".format(datetime.now().isoformat())) + if sess.should_stop() or step >= args.steps: + tf_feed.terminate() + # WORKAROUND FOR https://github.com/tensorflow/tensorflow/issues/21745 # wait for all other nodes to complete (via done files) done_dir = "{}/{}/done".format(ctx.absolute_path(args.model), args.mode)