From 99fd651b7ffc5c977ffb0dac1c254ac42bdbccb5 Mon Sep 17 00:00:00 2001
From: Lee Yang
Date: Fri, 9 Nov 2018 11:08:45 -0800
Subject: [PATCH 1/2] update pipeline.py to latest 1.12 API; add workaround to mnist_mlp_estimator.py

---
 examples/mnist/keras/mnist_mlp_estimator.py | 16 ++++++++++++++++
 tensorflowonspark/pipeline.py               |  5 +++--
 2 files changed, 19 insertions(+), 2 deletions(-)

diff --git a/examples/mnist/keras/mnist_mlp_estimator.py b/examples/mnist/keras/mnist_mlp_estimator.py
index 7f3329b6..570dad4b 100644
--- a/examples/mnist/keras/mnist_mlp_estimator.py
+++ b/examples/mnist/keras/mnist_mlp_estimator.py
@@ -104,6 +104,22 @@ def predict_input_fn():
     for result in predictions:
       tf_feed.batch_results([result])
 
+  # WORKAROUND FOR https://github.com/tensorflow/tensorflow/issues/21745
+  # wait for all other nodes to complete (via done files)
+  done_dir = "{}/{}/done".format(ctx.absolute_path(args.model), args.mode)
+  print("Writing done file to: {}".format(done_dir))
+  tf.gfile.MakeDirs(done_dir)
+  with tf.gfile.GFile("{}/{}".format(done_dir, ctx.task_index), 'w') as done_file:
+    done_file.write("done")
+
+  for i in range(60):
+    if len(tf.gfile.ListDirectory(done_dir)) < len(ctx.cluster_spec['worker']):
+      print("{} Waiting for other nodes {}".format(datetime.now().isoformat(), i))
+      time.sleep(1)
+    else:
+      print("{} All nodes done".format(datetime.now().isoformat()))
+      break
+
 
 if __name__ == '__main__':
   import argparse
diff --git a/tensorflowonspark/pipeline.py b/tensorflowonspark/pipeline.py
index e5a7325b..912d628f 100755
--- a/tensorflowonspark/pipeline.py
+++ b/tensorflowonspark/pipeline.py
@@ -23,7 +23,8 @@
 from pyspark.sql import Row, SparkSession
 
 import tensorflow as tf
-from tensorflow.contrib.saved_model.python.saved_model import reader, signature_def_utils
+
+from tensorflow.contrib.saved_model.python.saved_model import reader
 from tensorflow.python.saved_model import loader
 
 from . import TFCluster, gpu_info, dfutil
@@ -503,7 +504,7 @@ def _run_model(iterator, args, tf_args):
     assert args.export_dir, "Inferencing with signature_def_key requires --export_dir argument"
     logging.info("===== loading meta_graph_def for tag_set ({0}) from saved_model: {1}".format(args.tag_set, args.export_dir))
     meta_graph_def = get_meta_graph_def(args.export_dir, args.tag_set)
-    signature = signature_def_utils.get_signature_def_by_key(meta_graph_def, args.signature_def_key)
+    signature = meta_graph_def.signature_def[args.signature_def_key]
     logging.debug("signature: {}".format(signature))
     inputs_tensor_info = signature.inputs
     logging.debug("inputs_tensor_info: {0}".format(inputs_tensor_info))

From c771175976b96af66f7c27cf4b9b900eccacb176 Mon Sep 17 00:00:00 2001
From: Lee Yang
Date: Fri, 9 Nov 2018 11:46:14 -0800
Subject: [PATCH 2/2] update spark version

---
 pom.xml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pom.xml b/pom.xml
index 5bcf35e8..99f3d86e 100644
--- a/pom.xml
+++ b/pom.xml
@@ -5,7 +5,7 @@
   4.0.0
   com.yahoo.ml
   tensorflowonspark
-  1.0
+  1.0.1
   jar
   tensorflowonspark
   Spark Scala inferencing for TensorFlowOnSpark
@@ -25,7 +25,7 @@
     3.5.3
     2.1
     2.20.1
-    2.1.0
+    [2.2.0,)
     2.11.8
     3.2.1
     1.1.0
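
Note (not part of either patch above): the pipeline.py change drops the contrib helper signature_def_utils.get_signature_def_by_key in favor of indexing the MetaGraphDef's signature_def map directly, which still works on TF 1.12. Below is a minimal standalone sketch of that lookup, roughly following what pipeline.py's get_meta_graph_def helper does; the export directory, tag set, and signature key are hypothetical placeholders, and only the signature_def[...] indexing mirrors the actual change.

    from tensorflow.contrib.saved_model.python.saved_model import reader

    export_dir = "/tmp/mnist_export"   # hypothetical SavedModel location
    tag_set = {"serve"}                # hypothetical tag set (args.tag_set equivalent)

    # read_saved_model returns the SavedModel proto; pick the MetaGraphDef whose tags match
    saved_model = reader.read_saved_model(export_dir)
    meta_graph_def = next(m for m in saved_model.meta_graphs
                          if set(m.meta_info_def.tags) == tag_set)

    # signature_def is a plain protobuf map on MetaGraphDef, so no contrib helper is needed
    signature = meta_graph_def.signature_def["serving_default"]   # hypothetical key
    print(signature.inputs)    # input name -> TensorInfo
    print(signature.outputs)   # output name -> TensorInfo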