diff --git a/examples/mnist/keras/mnist_mlp_estimator.py b/examples/mnist/keras/mnist_mlp_estimator.py index 7f3329b6..570dad4b 100644 --- a/examples/mnist/keras/mnist_mlp_estimator.py +++ b/examples/mnist/keras/mnist_mlp_estimator.py @@ -104,6 +104,22 @@ def predict_input_fn(): for result in predictions: tf_feed.batch_results([result]) + # WORKAROUND FOR https://github.com/tensorflow/tensorflow/issues/21745 + # wait for all other nodes to complete (via done files) + done_dir = "{}/{}/done".format(ctx.absolute_path(args.model), args.mode) + print("Writing done file to: {}".format(done_dir)) + tf.gfile.MakeDirs(done_dir) + with tf.gfile.GFile("{}/{}".format(done_dir, ctx.task_index), 'w') as done_file: + done_file.write("done") + + for i in range(60): + if len(tf.gfile.ListDirectory(done_dir)) < len(ctx.cluster_spec['worker']): + print("{} Waiting for other nodes {}".format(datetime.now().isoformat(), i)) + time.sleep(1) + else: + print("{} All nodes done".format(datetime.now().isoformat())) + break + if __name__ == '__main__': import argparse diff --git a/pom.xml b/pom.xml index 5bcf35e8..99f3d86e 100644 --- a/pom.xml +++ b/pom.xml @@ -5,7 +5,7 @@ 4.0.0 com.yahoo.ml tensorflowonspark - 1.0 + 1.0.1 jar tensorflowonspark Spark Scala inferencing for TensorFlowOnSpark @@ -25,7 +25,7 @@ 3.5.3 2.1 2.20.1 - 2.1.0 + [2.2.0,) 2.11.8 3.2.1 1.1.0