diff --git a/examples/mnist/spark/mnist_dist.py b/examples/mnist/spark/mnist_dist.py index d870e407..cdf0a8c2 100755 --- a/examples/mnist/spark/mnist_dist.py +++ b/examples/mnist/spark/mnist_dist.py @@ -117,7 +117,7 @@ def rdd_generator(): if args.mode == "train": _, summary, step = sess.run([train_op, summary_op, global_step]) - if (step % 100 == 0): + if (step % 100 == 0) and (not sess.should_stop()): print("{} step: {} accuracy: {}".format(datetime.now().isoformat(), step, sess.run(accuracy))) if task_index == 0: summary_writer.add_summary(summary, step) @@ -129,6 +129,9 @@ def rdd_generator(): print("{} stopping MonitoredTrainingSession".format(datetime.now().isoformat())) + if sess.should_stop() or step >= args.steps: + tf_feed.terminate() + # WORKAROUND FOR https://github.com/tensorflow/tensorflow/issues/21745 # wait for all other nodes to complete (via done files) done_dir = "{}/{}/done".format(ctx.absolute_path(args.model), args.mode)