diff --git a/examples/mnist/keras/mnist_mlp_estimator.py b/examples/mnist/keras/mnist_mlp_estimator.py
index 7f3329b6..570dad4b 100644
--- a/examples/mnist/keras/mnist_mlp_estimator.py
+++ b/examples/mnist/keras/mnist_mlp_estimator.py
@@ -104,6 +104,22 @@ def predict_input_fn():
for result in predictions:
tf_feed.batch_results([result])
+ # WORKAROUND FOR https://github.com/tensorflow/tensorflow/issues/21745
+ # wait for all other nodes to complete (via done files)
+ done_dir = "{}/{}/done".format(ctx.absolute_path(args.model), args.mode)
+ print("Writing done file to: {}".format(done_dir))
+ tf.gfile.MakeDirs(done_dir)
+ with tf.gfile.GFile("{}/{}".format(done_dir, ctx.task_index), 'w') as done_file:
+ done_file.write("done")
+
+ for i in range(60):
+ if len(tf.gfile.ListDirectory(done_dir)) < len(ctx.cluster_spec['worker']):
+ print("{} Waiting for other nodes {}".format(datetime.now().isoformat(), i))
+ time.sleep(1)
+ else:
+ print("{} All nodes done".format(datetime.now().isoformat()))
+ break
+
if __name__ == '__main__':
import argparse
diff --git a/pom.xml b/pom.xml
index 5bcf35e8..99f3d86e 100644
--- a/pom.xml
+++ b/pom.xml
@@ -5,7 +5,7 @@
4.0.0
com.yahoo.ml
tensorflowonspark
- 1.0
+ 1.0.1
jar
tensorflowonspark
Spark Scala inferencing for TensorFlowOnSpark
@@ -25,7 +25,7 @@
3.5.3
2.1
2.20.1
- 2.1.0
+ [2.2.0,)
2.11.8
3.2.1
1.1.0