diff --git a/examples/mnist/estimator/README.md b/examples/mnist/estimator/README.md
index 256f3265..abc45c4e 100644
--- a/examples/mnist/estimator/README.md
+++ b/examples/mnist/estimator/README.md
@@ -1,8 +1,8 @@
 # MNIST using Estimator
 
-Original Source: https://www.tensorflow.org/beta/tutorials/distribute/multi_worker_with_estimator
+Original Source: https://www.tensorflow.org/tutorials/distribute/multi_worker_with_estimator
 
-This is the [Multi-worker Training with Estimator](https://www.tensorflow.org/beta/tutorials/distribute/multi_worker_with_estimator) example, adapted for TensorFlowOnSpark.
+This is the [Multi-worker Training with Estimator](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_estimator) example, adapted for TensorFlowOnSpark.
 
 Note: this example assumes that Spark, TensorFlow, and TensorFlowOnSpark are already installed.
 
diff --git a/examples/mnist/estimator/mnist_pipeline.py b/examples/mnist/estimator/mnist_pipeline.py
index b05d1c7c..0939df1f 100644
--- a/examples/mnist/estimator/mnist_pipeline.py
+++ b/examples/mnist/estimator/mnist_pipeline.py
@@ -45,8 +45,7 @@ def input_fn(mode, input_context=None):
       ds = tf.data.Dataset.from_generator(rdd_generator, (tf.float32, tf.float32), (tf.TensorShape([28, 28, 1]), tf.TensorShape([1])))
       return ds.batch(BATCH_SIZE)
     else:
-      raise Exception("I'm evaluating: mode={}, input_context={}".format(mode, input_context))
-
+      # read evaluation data from tensorflow_datasets directly
       def scale(image, label):
         image = tf.cast(image, tf.float32) / 255.0
         return image, label
diff --git a/examples/mnist/estimator/mnist_spark.py b/examples/mnist/estimator/mnist_spark.py
index 3433e2fc..4e787e8c 100644
--- a/examples/mnist/estimator/mnist_spark.py
+++ b/examples/mnist/estimator/mnist_spark.py
@@ -9,6 +9,8 @@ def main_fun(args, ctx):
   import tensorflow_datasets as tfds
   from tensorflowonspark import TFNode
 
+  strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
+
   tfds.disable_progress_bar()
 
   class StopFeedHook(tf.estimator.SessionRunHook):
@@ -91,7 +93,6 @@ def model_fn(features, labels, mode):
         train_op=optimizer.minimize(
             loss, tf.compat.v1.train.get_or_create_global_step()))
 
-  strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
   config = tf.estimator.RunConfig(train_distribute=strategy, save_checkpoints_steps=100)
 
   classifier = tf.estimator.Estimator(
diff --git a/examples/mnist/keras/README.md b/examples/mnist/keras/README.md
index b67224dc..105f21be 100644
--- a/examples/mnist/keras/README.md
+++ b/examples/mnist/keras/README.md
@@ -1,8 +1,8 @@
 # MNIST using Keras
 
-Original Source: https://www.tensorflow.org/beta/tutorials/distribute/multi_worker_with_keras
+Original Source: https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras
 
-This is the [Multi-worker Training with Keras](https://www.tensorflow.org/beta/tutorials/distribute/multi_worker_with_keras) example, adapted for TensorFlowOnSpark.
+This is the [Multi-worker Training with Keras](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras) example, adapted for TensorFlowOnSpark.
 
 Notes:
 - This example assumes that Spark, TensorFlow, TensorFlow Datasets, and TensorFlowOnSpark are already installed.
@@ -130,7 +130,7 @@ For batch inferencing use cases, you can use Spark to run multiple single-node T
     ${TFoS_HOME}/examples/mnist/keras/mnist_inference.py \
     --cluster_size ${SPARK_WORKER_INSTANCES} \
     --images_labels ${TFoS_HOME}/data/mnist/tfr/test \
-    --export_dir ${TFoS_HOME}/mnist_export \
+    --export_dir ${SAVED_MODEL} \
     --output ${TFoS_HOME}/predictions
 
 #### Train and Inference via Spark ML Pipeline API
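The estimator hunks above move the `MultiWorkerMirroredStrategy` construction to the top of `main_fun`, before any other TensorFlow calls. A minimal sketch of why that ordering matters, assuming a single local worker (the `TF_CONFIG` value here is illustrative; TensorFlowOnSpark normally builds it per executor from the Spark cluster metadata):

```python
import json
import os

import tensorflow as tf

# Illustrative TF_CONFIG for a one-worker cluster; in TFoS this is set for
# each executor before main_fun runs.
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {'worker': ['localhost:2222']},
    'task': {'type': 'worker', 'index': 0}
})

# The strategy reads TF_CONFIG and configures collective ops when it is
# constructed, so it must exist before the ops it will distribute.
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
```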
diff --git a/examples/mnist/keras/mnist_pipeline.py b/examples/mnist/keras/mnist_pipeline.py
index 75845bbc..00365070 100644
--- a/examples/mnist/keras/mnist_pipeline.py
+++ b/examples/mnist/keras/mnist_pipeline.py
@@ -6,7 +6,7 @@ def main_fun(args, ctx):
   import numpy as np
   import tensorflow as tf
-  from tensorflowonspark import TFNode
+  from tensorflowonspark import compat, TFNode
 
   strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
 
@@ -65,11 +65,9 @@ def rdd_generator():
     multi_worker_model.fit(x=ds, epochs=args.epochs, steps_per_epoch=max_steps_per_worker, callbacks=callbacks)
 
-    if ctx.job_name == 'chief':
-      from tensorflow_estimator.python.estimator.export import export_lib
-      export_dir = export_lib.get_timestamped_export_dir(args.export_dir)
-      tf.keras.experimental.export_saved_model(multi_worker_model, export_dir)
-      # multi_worker_model.save(args.model_dir, save_format='tf')
+    from tensorflow_estimator.python.estimator.export import export_lib
+    export_dir = export_lib.get_timestamped_export_dir(args.export_dir)
+    compat.export_saved_model(multi_worker_model, export_dir, ctx.job_name == 'chief')
 
     # terminating feed tells spark to skip processing further partitions
     tf_feed.terminate()
diff --git a/examples/mnist/keras/mnist_spark.py b/examples/mnist/keras/mnist_spark.py
index ab6ad340..25e82d39 100644
--- a/examples/mnist/keras/mnist_spark.py
+++ b/examples/mnist/keras/mnist_spark.py
@@ -6,7 +6,7 @@ def main_fun(args, ctx):
   import numpy as np
   import tensorflow as tf
-  from tensorflowonspark import TFNode
+  from tensorflowonspark import compat, TFNode
 
   strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
 
@@ -65,11 +65,9 @@ def rdd_generator():
     multi_worker_model.fit(x=ds, epochs=args.epochs, steps_per_epoch=max_steps_per_worker, callbacks=callbacks)
 
-    if ctx.job_name == 'chief':
-      from tensorflow_estimator.python.estimator.export import export_lib
-      export_dir = export_lib.get_timestamped_export_dir(args.export_dir)
-      tf.keras.experimental.export_saved_model(multi_worker_model, export_dir)
-      # multi_worker_model.save(args.model_dir, save_format='tf')
+    from tensorflow_estimator.python.estimator.export import export_lib
+    export_dir = export_lib.get_timestamped_export_dir(args.export_dir)
+    compat.export_saved_model(multi_worker_model, export_dir, ctx.job_name == 'chief')
 
     # terminating feed tells spark to skip processing further partitions
     tf_feed.terminate()
diff --git a/examples/mnist/keras/mnist_tf.py b/examples/mnist/keras/mnist_tf.py
index 7eda3dc3..00ad7740 100644
--- a/examples/mnist/keras/mnist_tf.py
+++ b/examples/mnist/keras/mnist_tf.py
@@ -6,6 +6,8 @@ def main_fun(args, ctx):
   import tensorflow_datasets as tfds
   import tensorflow as tf
+  from tensorflowonspark import compat
+
   tfds.disable_progress_bar()
 
   strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
@@ -60,11 +62,9 @@ def build_and_compile_cnn_model():
   multi_worker_model = build_and_compile_cnn_model()
   multi_worker_model.fit(x=train_datasets, epochs=args.epochs, steps_per_epoch=args.steps_per_epoch, callbacks=callbacks)
 
-  if ctx.job_name == 'chief':
-    from tensorflow_estimator.python.estimator.export import export_lib
-    export_dir = export_lib.get_timestamped_export_dir(args.export_dir)
-    tf.keras.experimental.export_saved_model(multi_worker_model, export_dir)
-    # multi_worker_model.save(args.model_dir, save_format='tf')
+  from tensorflow_estimator.python.estimator.export import export_lib
+  export_dir = export_lib.get_timestamped_export_dir(args.export_dir)
+  compat.export_saved_model(multi_worker_model, export_dir, ctx.job_name == 'chief')
 
 
 if __name__ == '__main__':
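The keras hunks above replace the chief-only export with `compat.export_saved_model(..., ctx.job_name == 'chief')`, which every worker now calls. A sketch of the underlying pattern, assuming TF 2.1+ semantics where `model.save()` under `MultiWorkerMirroredStrategy` is a collective operation (the function and scratch-path names here are illustrative, not from the patch):

```python
import os
import tempfile

def save_on_all_workers(model, export_dir, is_chief):
    # Every worker must call model.save() so the collective ops line up;
    # only the chief writes to the real export path, the rest write to a
    # throwaway scratch directory.
    target = export_dir if is_chief else os.path.join(tempfile.mkdtemp(), 'workertemp')
    model.save(target, save_format='tf')
```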
diff --git a/examples/mnist/keras/mnist_tf_ds.py b/examples/mnist/keras/mnist_tf_ds.py
index eaf8bcc6..1b7a5cd8 100644
--- a/examples/mnist/keras/mnist_tf_ds.py
+++ b/examples/mnist/keras/mnist_tf_ds.py
@@ -6,6 +6,7 @@ def main_fun(args, ctx):
   """Example demonstrating loading TFRecords directly from disk (e.g. HDFS) without tensorflow_datasets."""
   import tensorflow as tf
+  from tensorflowonspark import compat
 
   strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
 
@@ -86,11 +87,9 @@ def build_and_compile_cnn_model():
   multi_worker_model = build_and_compile_cnn_model()
   multi_worker_model.fit(x=train_datasets, epochs=args.epochs, steps_per_epoch=steps_per_epoch, callbacks=callbacks)
 
-  if ctx.job_name == 'chief':
-    from tensorflow_estimator.python.estimator.export import export_lib
-    export_dir = export_lib.get_timestamped_export_dir(args.export_dir)
-    tf.keras.experimental.export_saved_model(multi_worker_model, export_dir)
-    # multi_worker_model.save(args.model_dir, save_format='tf')
+  from tensorflow_estimator.python.estimator.export import export_lib
+  export_dir = export_lib.get_timestamped_export_dir(args.export_dir)
+  compat.export_saved_model(multi_worker_model, export_dir, ctx.job_name == 'chief')
 
 
 if __name__ == '__main__':
diff --git a/examples/segmentation/README.md b/examples/segmentation/README.md
index 1faebe68..6cdaf299 100644
--- a/examples/segmentation/README.md
+++ b/examples/segmentation/README.md
@@ -1,8 +1,8 @@
 # Image Segmentation
 
-Original Source: https://www.tensorflow.org/beta/tutorials/images/segmentation
+Original Source: https://www.tensorflow.org/tutorials/images/segmentation
 
-This code is based on the [Image Segmentation](https://www.tensorflow.org/beta/tutorials/images/segmentation) notebook example, converted to a single-node TensorFlow python app, then converted into a distributed TensorFlow app using the `MultiWorkerMirroredStrategy`, and then finally adapted for TensorFlowOnSpark. Compare the different versions to see the conversion steps involved at each stage.
+This code is based on the [Image Segmentation](https://www.tensorflow.org/tutorials/images/segmentation) notebook example, converted to a single-node TensorFlow python app, then converted into a distributed TensorFlow app using the `MultiWorkerMirroredStrategy`, and then finally adapted for TensorFlowOnSpark. Compare the different versions to see the conversion steps involved at each stage.
 
 Notes:
 - this example assumes that Spark, TensorFlow, and TensorFlowOnSpark are already installed.
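The segmentation README describes a single-node → distributed → TFoS conversion. The structural core of the distributed step, sketched with a toy model (the layers are illustrative, not the U-Net from the example): model construction and compilation move inside `strategy.scope()`, and everything else stays the same.

```python
import tensorflow as tf

strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

with strategy.scope():
    # variables created inside the scope are mirrored across workers
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(16, 3, activation='relu', input_shape=(128, 128, 3)),
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(10, activation='softmax'),
    ])
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
```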
diff --git a/examples/segmentation/segmentation_spark.py b/examples/segmentation/segmentation_spark.py
index 3743d3ca..becbd065 100644
--- a/examples/segmentation/segmentation_spark.py
+++ b/examples/segmentation/segmentation_spark.py
@@ -159,18 +159,15 @@ def unet_model(output_channels):
                             validation_steps=VALIDATION_STEPS,
                             validation_data=test_dataset)
 
-  if ctx.job_name == 'chief':
+  if tf.__version__ == '2.0.0':
     # Workaround for: https://github.com/tensorflow/tensorflow/issues/30251
-    print("===== saving h5py model")
-    model.save(args.model_dir + ".h5")
-    print("===== re-loading model w/o DistributionStrategy")
-    new_model = tf.keras.models.load_model(args.model_dir + ".h5")
-    print("===== exporting saved_model")
-    tf.keras.experimental.export_saved_model(new_model, args.export_dir)
-    print("===== done exporting")
+    # Save the model locally in HDF5 format, then reload it w/o the distribution strategy
+    if ctx.job_name == 'chief':
+      model.save(args.model_dir + ".h5")
+      new_model = tf.keras.models.load_model(args.model_dir + ".h5")
+      tf.keras.experimental.export_saved_model(new_model, args.export_dir)
   else:
-    print("===== sleeping")
-    time.sleep(90)
+    model.save(args.export_dir, save_format='tf')
 
 
 if __name__ == '__main__':
diff --git a/requirements.txt b/requirements.txt
index 6be7c9b0..3da04035 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,6 @@
 h5py>=2.9.0
 numpy>=1.14.0
+packaging
 py4j==0.10.7
 pyspark
 scipy
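`packaging` joins the requirements because the TFNode and pipeline changes below gate behavior on the TensorFlow version, and naive string comparison of version numbers is unreliable. A quick demonstration:

```python
from packaging import version

# packaging compares version components numerically...
assert version.parse('2.10.0') > version.parse('2.2.0')

# ...while plain string comparison gets this wrong ('1' < '2' at the third char)
assert not ('2.10.0' > '2.2.0')
```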
diff --git a/tensorflowonspark/TFNode.py b/tensorflowonspark/TFNode.py
index 5b96de38..977b8252 100644
--- a/tensorflowonspark/TFNode.py
+++ b/tensorflowonspark/TFNode.py
@@ -16,6 +16,8 @@
 
 import getpass
 import logging
+
+from packaging import version
 from six.moves.queue import Empty
 from . import marker
 
@@ -61,8 +63,86 @@ def hdfs_path(ctx, path):
 
 
 def start_cluster_server(ctx, num_gpus=1, rdma=False):
-  """*DEPRECATED*. Use higher-level APIs like `tf.keras` or `tf.estimator`"""
-  raise Exception("DEPRECATED: Use higher-level APIs like `tf.keras` or `tf.estimator`")
+  """Function that wraps the creation of a TensorFlow ``tf.train.Server`` for a node in a distributed TensorFlow cluster.
+
+  This is intended to be invoked from within the TF ``map_fun``, replacing explicit code to instantiate ``tf.train.ClusterSpec``
+  and ``tf.train.Server`` objects.
+
+  DEPRECATED for TensorFlow 2.x+.
+
+  Args:
+    :ctx: TFNodeContext containing the metadata specific to this node in the cluster.
+    :num_gpus: number of GPUs desired.
+    :rdma: boolean indicating if the RDMA 'grpc+verbs' protocol should be used for cluster communications.
+
+  Returns:
+    A tuple of (cluster_spec, server)
+  """
+  import os
+  import tensorflow as tf
+  import time
+  from . import gpu_info
+
+  if version.parse(tf.__version__) >= version.parse("2.0.0"):
+    raise Exception("DEPRECATED: Use higher-level APIs like `tf.keras` or `tf.estimator`")
+
+  logging.info("{0}: ======== {1}:{2} ========".format(ctx.worker_num, ctx.job_name, ctx.task_index))
+  cluster_spec = ctx.cluster_spec
+  logging.info("{0}: Cluster spec: {1}".format(ctx.worker_num, cluster_spec))
+
+  if tf.test.is_built_with_cuda() and num_gpus > 0:
+    # compute my index relative to other nodes placed on the same host (for GPU allocation)
+    my_addr = cluster_spec[ctx.job_name][ctx.task_index]
+    my_host = my_addr.split(':')[0]
+    flattened = [v for sublist in cluster_spec.values() for v in sublist]
+    local_peers = [p for p in flattened if p.startswith(my_host)]
+    my_index = local_peers.index(my_addr)
+
+    # GPU
+    gpu_initialized = False
+    retries = 3
+    while not gpu_initialized and retries > 0:
+      try:
+        # override PS jobs to not reserve any GPUs
+        if ctx.job_name == 'ps':
+          num_gpus = 0
+
+        # Find free GPU(s) to use
+        gpus_to_use = gpu_info.get_gpus(num_gpus, my_index)
+        gpu_prompt = "GPU" if num_gpus == 1 else "GPUs"
+        logging.info("{0}: Using {1}: {2}".format(ctx.worker_num, gpu_prompt, gpus_to_use))
+
+        # Set GPU device to use for TensorFlow
+        os.environ['CUDA_VISIBLE_DEVICES'] = gpus_to_use
+
+        # Create a cluster from the parameter server and worker hosts.
+        cluster = tf.train.ClusterSpec(cluster_spec)
+
+        # Create and start a server for the local task.
+        if rdma:
+          server = tf.train.Server(cluster, ctx.job_name, ctx.task_index, protocol="grpc+verbs")
+        else:
+          server = tf.train.Server(cluster, ctx.job_name, ctx.task_index)
+        gpu_initialized = True
+      except Exception as e:
+        print(e)
+        logging.error("{0}: Failed to allocate GPU, trying again...".format(ctx.worker_num))
+        retries -= 1
+        time.sleep(10)
+
+    if not gpu_initialized:
+      raise Exception("Failed to allocate GPU")
+  else:
+    # CPU
+    os.environ['CUDA_VISIBLE_DEVICES'] = ''
+    logging.info("{0}: Using CPU".format(ctx.worker_num))
+
+    # Create a cluster from the parameter server and worker hosts.
+    cluster = tf.train.ClusterSpec(cluster_spec)
+
+    # Create and start a server for the local task.
+    server = tf.train.Server(cluster, ctx.job_name, ctx.task_index)
+
+  return (cluster, server)
 
 
 def next_batch(mgr, batch_size, qname='input'):
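The GPU-assignment logic above deserves a standalone illustration: each process ranks itself among the cluster members that share its host, and that rank selects a distinct GPU slot. With an illustrative cluster spec (hostnames and ports are made up):

```python
cluster_spec = {'ps': ['hostA:2222'], 'worker': ['hostA:2223', 'hostB:2222', 'hostA:2224']}
my_addr = 'hostA:2224'

my_host = my_addr.split(':')[0]
flattened = [addr for addrs in cluster_spec.values() for addr in addrs]
local_peers = [p for p in flattened if p.startswith(my_host)]
my_index = local_peers.index(my_addr)

print(my_index)  # 2 -> the third process on hostA, so it claims GPU slot 2
```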
+ """ + import tensorflow as tf + + if version.parse(tf.__version__) >= version.parse("2.0.0"): + raise Exception("DEPRECATED: Use TF provided APIs instead.") + + g = sess.graph + g._unsafe_unfinalize() # https://github.com/tensorflow/serving/issues/363 + builder = tf.saved_model.builder.SavedModelBuilder(export_dir) + + logging.info("===== signatures: {}".format(signatures)) + signature_def_map = {} + for key, sig in signatures.items(): + signature_def_map[key] = tf.saved_model.signature_def_utils.build_signature_def( + inputs={name: tf.saved_model.utils.build_tensor_info(tensor) for name, tensor in sig['inputs'].items()}, + outputs={name: tf.saved_model.utils.build_tensor_info(tensor) for name, tensor in sig['outputs'].items()}, + method_name=sig['method_name'] if 'method_name' in sig else key) + logging.info("===== signature_def_map: {}".format(signature_def_map)) + builder.add_meta_graph_and_variables( + sess, + tag_set.split(','), + signature_def_map=signature_def_map, + clear_devices=True) + g.finalize() + builder.save() def batch_results(mgr, results, qname='output'): diff --git a/tensorflowonspark/compat.py b/tensorflowonspark/compat.py new file mode 100644 index 00000000..4fc9ca18 --- /dev/null +++ b/tensorflowonspark/compat.py @@ -0,0 +1,23 @@ +# Copyright 2019 Yahoo Inc / Verizon Media +# Licensed under the terms of the Apache 2.0 license. +# Please see LICENSE file in the project root for terms. +"""Helper functions to abstract API changes between TensorFlow versions.""" + +import tensorflow as tf + +TF_VERSION = tf.__version__ + + +def export_saved_model(model, export_dir, is_chief=False): + if TF_VERSION == '2.0.0': + if is_chief: + tf.keras.experimental.export_saved_model(model, export_dir) + else: + model.save(export_dir, save_format='tf') + + +def disable_auto_shard(options): + if TF_VERSION == '2.0.0': + options.experimental_distribute.auto_shard = False + else: + options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF diff --git a/tensorflowonspark/pipeline.py b/tensorflowonspark/pipeline.py index bc7e37c6..d97fe39e 100755 --- a/tensorflowonspark/pipeline.py +++ b/tensorflowonspark/pipeline.py @@ -20,13 +20,15 @@ from pyspark.ml.pipeline import Estimator, Model from pyspark.sql import Row, SparkSession -import tensorflow as tf -from . import TFCluster, util - import argparse import copy import logging import sys +import tensorflow as tf + +from . 
diff --git a/tensorflowonspark/pipeline.py b/tensorflowonspark/pipeline.py
index bc7e37c6..d97fe39e 100755
--- a/tensorflowonspark/pipeline.py
+++ b/tensorflowonspark/pipeline.py
@@ -20,13 +20,15 @@
 from pyspark.ml.pipeline import Estimator, Model
 from pyspark.sql import Row, SparkSession
 
-import tensorflow as tf
-from . import TFCluster, util
-
 import argparse
 import copy
 import logging
 import sys
+
+import tensorflow as tf
+
+from . import TFCluster, util
+from packaging import version
 
 logger = logging.getLogger(__name__)
 
@@ -108,6 +110,22 @@ def getInputMapping(self):
     return self.getOrDefault(self.input_mapping)
 
 
+class HasInputMode(Params):
+  input_mode = Param(Params._dummy(), "input_mode", "Input data feeding mode (0=TENSORFLOW, 1=SPARK)", typeConverter=TypeConverters.toInt)
+
+  def __init__(self):
+    super(HasInputMode, self).__init__()
+
+  def setInputMode(self, value):
+    if value == TFCluster.InputMode.TENSORFLOW:
+      raise Exception("InputMode.TENSORFLOW is deprecated")
+
+    return self._set(input_mode=value)
+
+  def getInputMode(self):
+    return self.getOrDefault(self.input_mode)
+
+
 class HasMasterNode(Params):
   master_node = Param(Params._dummy(), "master_node", "Job name of master/chief worker node", typeConverter=TypeConverters.toString)
 
@@ -330,7 +348,7 @@ def merge_args_params(self):
 
 
 class TFEstimator(Estimator, TFParams, HasInputMapping,
-                  HasClusterSize, HasNumPS, HasMasterNode, HasProtocol, HasGraceSecs,
+                  HasClusterSize, HasNumPS, HasInputMode, HasMasterNode, HasProtocol, HasGraceSecs,
                   HasTensorboard, HasModelDir, HasExportDir, HasTFRecordDir,
                   HasBatchSize, HasEpochs, HasReaders, HasSteps):
   """Spark ML Estimator which launches a TensorFlowOnSpark cluster for distributed training.
@@ -341,20 +359,24 @@ class TFEstimator(Estimator, TFParams, HasInputMapping,
   Args:
     :train_fn: TensorFlow "main" function for training.
     :tf_args: Arguments specific to the TensorFlow "main" function.
+    :export_fn: TensorFlow function for exporting a saved_model. DEPRECATED for TF2.x.
   """
 
   train_fn = None
   export_fn = None
 
-  def __init__(self, train_fn, tf_args):
+  def __init__(self, train_fn, tf_args, export_fn=None):
     super(TFEstimator, self).__init__()
     self.train_fn = train_fn
     self.args = Namespace(tf_args)
+
+    master_node = 'chief' if version.parse(tf.__version__) >= version.parse("2.0.0") else None
     self._setDefault(input_mapping={},
                      cluster_size=1,
                      num_ps=0,
                      driver_ps_nodes=False,
-                     master_node='chief',
+                     input_mode=TFCluster.InputMode.SPARK,
+                     master_node=master_node,
                      protocol='grpc',
                      tensorboard=False,
                      model_dir=None,
@@ -390,6 +412,22 @@ def _fit(self, dataset):
     cluster.train(dataset.select(input_cols).rdd, local_args.epochs)
     cluster.shutdown(grace_secs=self.getGraceSecs())
 
+    if self.export_fn:
+      if version.parse(tf.__version__) < version.parse("2.0.0"):
+        # for TF1.x, run the export function, if provided
+        assert local_args.export_dir, "Export function requires --export_dir to be set"
+        logger.info("Exporting saved_model (via export_fn) to: {}".format(local_args.export_dir))
+
+        def _export(iterator, fn, args):
+          single_node_env(args)
+          fn(args)
+
+        # run on a single executor
+        sc.parallelize([1], 1).foreachPartition(lambda it: _export(it, self.export_fn, tf_args))
+      else:
+        # for TF2.x
+        raise Exception("Please use native TF2.x APIs to export a saved_model.")
+
     return self._copyValues(TFModel(self.args))
 
 
@@ -415,8 +453,8 @@ def __init__(self, tf_args):
                     batch_size=100,
                     model_dir=None,
                     export_dir=None,
-                    signature_def_key='serving_default',
-                    tag_set='serve')
+                    signature_def_key=None,
+                    tag_set=None)
 
   def _transform(self, dataset):
     """Transforms the input DataFrame by applying the _run_model() mapPartitions function.
@@ -441,6 +479,8 @@ def _transform(self, dataset):
     logger.info("===== 3. inference args + params: {0}".format(local_args))
     tf_args = self.args.argv if self.args.argv else local_args
+
+    _run_model = _run_model_tf1 if version.parse(tf.__version__) < version.parse("2.0.0") else _run_model_tf2
     rdd_out = dataset.select(input_cols).rdd.mapPartitions(lambda it: _run_model(it, local_args, tf_args))
 
     # convert to a DataFrame-friendly format
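The `input_mapping`/`output_mapping` parameters used by `_transform` and the `_run_model_*` functions below follow a convention worth spelling out: `input_mapping` maps DataFrame column → input tensor and is sorted by column name, while `output_mapping` maps output tensor → DataFrame column and is sorted by tensor name. For example (the mapping values are illustrative):

```python
input_mapping = {'image': 'x', 'label': 'y_'}   # DataFrame column -> input tensor
output_mapping = {'prediction': 'col_out'}      # output tensor -> DataFrame column

input_tensor_names = [tensor for col, tensor in sorted(input_mapping.items())]
output_tensor_names = [tensor for tensor, col in sorted(output_mapping.items())]

print(input_tensor_names)   # ['x', 'y_']
print(output_tensor_names)  # ['prediction']
```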
@@ -450,11 +490,13 @@
 
 # global on each python worker process on the executors
 pred_fn = None       # saved_model prediction function/signature.
-pred_args = None     # args provided to the _run_model() method.  Any change will invalidate the pred_fn.
+global_sess = None   # tf.Session cache (TF1.x)
+global_args = None   # args provided to the _run_model() method.  Any change will invalidate the pred_fn.
+global_model = None  # this needs to be global for TF2.1+
 
 
-def _run_model(iterator, args, tf_args):
-  """mapPartitions function to run single-node inferencing from a saved_model, using input/output mappings.
+def _run_model_tf1(iterator, args, tf_args):
+  """mapPartitions function (for TF1.x) to run single-node inferencing from a saved_model, using input/output mappings.
 
   Args:
     :iterator: input RDD partition iterator.
@@ -464,6 +506,82 @@
   Returns:
     An iterator of result data.
   """
+  from tensorflow.python.saved_model import loader
+
+  single_node_env(tf_args)
+
+  logger.info("===== input_mapping: {}".format(args.input_mapping))
+  logger.info("===== output_mapping: {}".format(args.output_mapping))
+  input_tensor_names = [tensor for col, tensor in sorted(args.input_mapping.items())]
+  output_tensor_names = [tensor for tensor, col in sorted(args.output_mapping.items())]
+
+  # if using a signature_def_key, get input/output tensor info from the requested signature
+  if version.parse(tf.__version__) < version.parse("2.0.0") and args.signature_def_key:
+    assert args.export_dir, "Inferencing with signature_def_key requires --export_dir argument"
+    logger.info("===== loading meta_graph_def for tag_set ({0}) from saved_model: {1}".format(args.tag_set, args.export_dir))
+    meta_graph_def = get_meta_graph_def(args.export_dir, args.tag_set)
+    signature = meta_graph_def.signature_def[args.signature_def_key]
+    logger.debug("signature: {}".format(signature))
+    inputs_tensor_info = signature.inputs
+    logger.debug("inputs_tensor_info: {0}".format(inputs_tensor_info))
+    outputs_tensor_info = signature.outputs
+    logger.debug("outputs_tensor_info: {0}".format(outputs_tensor_info))
+
+  result = []
+  global global_sess, global_args
+  if global_sess and global_args == args:
+    # if graph/session already loaded/started (and using same args), just reuse it
+    sess = global_sess
+  else:
+    # otherwise, create a new session and load the graph from disk
+    tf.reset_default_graph()
+    sess = tf.Session(graph=tf.get_default_graph())
+    if args.export_dir:
+      assert args.tag_set, "Inferencing from a saved_model requires --tag_set"
+      # load graph from a saved_model
+      logger.info("===== restoring from saved_model: {}".format(args.export_dir))
+      loader.load(sess, args.tag_set.split(','), args.export_dir)
+    elif args.model_dir:
+      # load graph from a checkpoint
+      ckpt = tf.train.latest_checkpoint(args.model_dir)
+      assert ckpt, "Invalid model checkpoint path: {}".format(args.model_dir)
+      logger.info("===== restoring from checkpoint: {}".format(ckpt + ".meta"))
+      saver = tf.train.import_meta_graph(ckpt + ".meta", clear_devices=True)
+      saver.restore(sess, ckpt)
+    else:
+      raise Exception("Inferencing requires either --model_dir or --export_dir argument")
+    global_sess = sess
+    global_args = args
+
+  # get list of input/output tensors (by name)
+  if args.signature_def_key:
+    input_tensors = [inputs_tensor_info[t].name for t in input_tensor_names]
+    output_tensors = [outputs_tensor_info[t].name for t in output_tensor_names]
+  else:
+    input_tensors = [t + ':0' for t in input_tensor_names]
+    output_tensors = [t + ':0' for t in output_tensor_names]
+
+  logger.info("input_tensors: {0}".format(input_tensors))
+  logger.info("output_tensors: {0}".format(output_tensors))
+
+  # feed data in batches and return output tensors
+  for tensors in yield_batch(iterator, args.batch_size, len(input_tensor_names)):
+    inputs_feed_dict = {}
+    for i in range(len(input_tensors)):
+      inputs_feed_dict[input_tensors[i]] = tensors[i]
+
+    outputs = sess.run(output_tensors, feed_dict=inputs_feed_dict)
+    lengths = [len(output) for output in outputs]
+    input_size = len(tensors[0])
+    assert all([length == input_size for length in lengths]), "Output array sizes {} must match input size: {}".format(lengths, input_size)
+    python_outputs = [output.tolist() for output in outputs]  # convert from numpy to standard python types
+    result.extend(zip(*python_outputs))                       # convert to an array of tuples of "output columns"
+
+  return result
+
+
+def _run_model_tf2(iterator, args, tf_args):
+  """mapPartitions function (for TF2.x) to run single-node inferencing from a saved_model, using input/output mappings."""
   single_node_env(tf_args)
 
   logger.info("===== input_mapping: {}".format(args.input_mapping))
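The module-level `global_sess`/`global_args`/`global_model` variables above rely on a Spark detail: `mapPartitions` functions run repeatedly inside the same long-lived Python worker process, so module globals survive across partitions. A distilled sketch of the load-once-per-process idiom (the names here are illustrative, not from the patch):

```python
_cached_model = None
_cached_key = None

def get_model(key):
    """Load-once-per-process cache, keyed on the inference args."""
    global _cached_model, _cached_key
    if _cached_model is None or _cached_key != key:
        _cached_model = "loaded-model-for-{}".format(key)  # stand-in for tf.saved_model.load(...)
        _cached_key = key
    return _cached_model

print(get_model("run1"))  # loads on first use
print(get_model("run1"))  # reuses the cached object for later partitions
```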
@@ -471,16 +589,16 @@
   input_tensor_names = [tensor for col, tensor in sorted(args.input_mapping.items())]
   output_tensor_names = [tensor for tensor, col in sorted(args.output_mapping.items())]
 
-  global pred_fn, pred_args
+  global pred_fn, global_args, global_model
 
-  # cache saved_model pred_fn to avoid reloading the model for each partition
-  if not pred_fn or args != pred_args:
+  if not pred_fn or args != global_args:
+    # cache pred_fn to avoid reloading the model for each partition
     assert args.export_dir, "Inferencing requires --export_dir argument"
     logger.info("===== loading saved_model from: {}".format(args.export_dir))
-    saved_model = tf.saved_model.load(args.export_dir, tags=args.tag_set)
+    global_model = tf.saved_model.load(args.export_dir, tags=args.tag_set)
     logger.info("===== signature_def_key: {}".format(args.signature_def_key))
-    pred_fn = saved_model.signatures[args.signature_def_key]
-    pred_args = args
+    pred_fn = global_model.signatures[args.signature_def_key]
+    global_args = args
 
   inputs_tensor_info = {i.name: i for i in pred_fn.inputs}
   logger.info("===== inputs_tensor_info: {0}".format(inputs_tensor_info))
@@ -539,6 +657,30 @@ def single_node_env(args):
   util.single_node_env(num_gpus)
 
 
+def get_meta_graph_def(saved_model_dir, tag_set):
+  """Utility function to read a meta_graph_def from disk.
+
+  Adapted from TensorFlow's ``saved_model_cli.py``.
+
+  DEPRECATED for TF2.0+.
+
+  Args:
+    :saved_model_dir: path to the saved_model.
+    :tag_set: comma-separated string of tags identifying the TensorFlow graph within the saved_model.
+
+  Returns:
+    A TensorFlow meta_graph_def, or raises an Exception otherwise.
+  """
+  from tensorflow.contrib.saved_model.python.saved_model import reader
+
+  saved_model = reader.read_saved_model(saved_model_dir)
+  set_of_tags = set(tag_set.split(','))
+  for meta_graph_def in saved_model.meta_graphs:
+    if set(meta_graph_def.meta_info_def.tags) == set_of_tags:
+      return meta_graph_def
+  raise RuntimeError("MetaGraphDef associated with tag-set {0} could not be found in SavedModel".format(tag_set))
+
+
 def yield_batch(iterable, batch_size, num_tensors=1):
   """Generator that yields batches of a DataFrame iterator.
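`yield_batch` (unchanged here, and used by `_run_model_tf1` above) groups an iterator of rows into per-tensor batches. A minimal sketch of that behavior, inferred from its call sites rather than copied from the source:

```python
def yield_batch_sketch(iterable, batch_size, num_tensors=1):
    tensors = [[] for _ in range(num_tensors)]
    for row in iterable:
        for i in range(num_tensors):
            tensors[i].append(row[i])
        if len(tensors[0]) >= batch_size:
            yield tensors
            tensors = [[] for _ in range(num_tensors)]
    if tensors[0]:
        yield tensors  # final partial batch

print(list(yield_batch_sketch([(1, 'a'), (2, 'b'), (3, 'c')], batch_size=2, num_tensors=2)))
# [[[1, 2], ['a', 'b']], [[3], ['c']]]
```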
+ """ + from tensorflow.contrib.saved_model.python.saved_model import reader + + saved_model = reader.read_saved_model(saved_model_dir) + set_of_tags = set(tag_set.split(',')) + for meta_graph_def in saved_model.meta_graphs: + if set(meta_graph_def.meta_info_def.tags) == set_of_tags: + return meta_graph_def + raise RuntimeError("MetaGraphDef associated with tag-set {0} could not be found in SavedModel".format(tag_set)) + + def yield_batch(iterable, batch_size, num_tensors=1): """Generator that yields batches of a DataFrame iterator. diff --git a/test/test_pipeline.py b/test/test_pipeline.py index 3fe1c19d..e848ae00 100644 --- a/test/test_pipeline.py +++ b/test/test_pipeline.py @@ -4,6 +4,7 @@ import test import unittest +from tensorflowonspark import compat from tensorflowonspark.pipeline import HasBatchSize, HasSteps, Namespace, TFEstimator, TFParams from tensorflow.keras import Sequential from tensorflow.keras.layers import Dense @@ -115,12 +116,11 @@ def rdd_generator(): return ds = tf.data.Dataset.from_generator(rdd_generator, (tf.float32, tf.float32), (tf.TensorShape([2]), tf.TensorShape([1]))) - ds = ds.batch(args.batch_size) - - # disable auto-sharding dataset + # disable auto-sharding since we're feeding from an RDD generator options = tf.data.Options() - options.experimental_distribute.auto_shard = False + compat.disable_auto_shard(options) ds = ds.with_options(options) + ds = ds.batch(args.batch_size) # only train 90% of each epoch to account for uneven RDD partition sizes steps_per_epoch = 1000 * 0.9 // (args.batch_size * ctx.num_workers) @@ -134,9 +134,9 @@ def rdd_generator(): # This fails with: "NotImplementedError: `fit_generator` is not supported for models compiled with tf.distribute.Strategy" # model.fit_generator(ds, epochs=args.epochs, steps_per_epoch=steps_per_epoch, callbacks=callbacks) - if ctx.job_name == 'chief' and args.export_dir: + if args.export_dir: print("exporting model to: {}".format(args.export_dir)) - tf.keras.experimental.export_saved_model(model, args.export_dir) + compat.export_saved_model(model, args.export_dir, ctx.job_name == 'chief') tf_feed.terminate() @@ -151,6 +151,7 @@ def rdd_generator(): .setModelDir(self.model_dir) \ .setExportDir(self.export_dir) \ .setClusterSize(self.num_workers) \ + .setMasterNode("chief") \ .setNumPS(0) \ .setBatchSize(1) \ .setEpochs(1)