diff --git a/tensorflow/python/compiler/tensorrt/test/base_test.py b/tensorflow/python/compiler/tensorrt/test/base_test.py index 6aa32f73676e34..f42bbf272293df 100644 --- a/tensorflow/python/compiler/tensorrt/test/base_test.py +++ b/tensorflow/python/compiler/tensorrt/test/base_test.py @@ -144,7 +144,7 @@ def GetConversionParams(self, run_params): ).GetConversionParams(run_params)._replace( # Disable layout optimizer, since it'll add Transpose(Const, Const) to # the graph and break the conversion check. - rewriter_config=trt_test.OptimizerDisabledRewriterConfig()) + rewriter_config_template=trt_test.OptimizerDisabledRewriterConfig()) class SimpleMultiEnginesTest2(trt_test.TfTrtIntegrationTestBase): diff --git a/tensorflow/python/compiler/tensorrt/test/biasadd_matmul_test.py b/tensorflow/python/compiler/tensorrt/test/biasadd_matmul_test.py index 2b7bbbc960558a..69d4ab0e297acf 100644 --- a/tensorflow/python/compiler/tensorrt/test/biasadd_matmul_test.py +++ b/tensorflow/python/compiler/tensorrt/test/biasadd_matmul_test.py @@ -124,7 +124,7 @@ def GetConversionParams(self, run_params): maximum_cached_engines=1, # Disable layout optimizer, since it will convert BiasAdd with NHWC # format to NCHW format under four dimensional input. - rewriter_config=trt_test.OptimizerDisabledRewriterConfig()) + rewriter_config_template=trt_test.OptimizerDisabledRewriterConfig()) def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" diff --git a/tensorflow/python/compiler/tensorrt/test/dynamic_input_shapes_test.py b/tensorflow/python/compiler/tensorrt/test/dynamic_input_shapes_test.py index cb358d4f9bd91d..a906071b2c7425 100644 --- a/tensorflow/python/compiler/tensorrt/test/dynamic_input_shapes_test.py +++ b/tensorflow/python/compiler/tensorrt/test/dynamic_input_shapes_test.py @@ -85,7 +85,7 @@ def GetConversionParams(self, run_params): maximum_cached_engines=10, # Disable layout optimizer, since it will convert BiasAdd with NHWC # format to NCHW format under four dimensional input. - rewriter_config=trt_test.OptimizerDisabledRewriterConfig()) + rewriter_config_template=trt_test.OptimizerDisabledRewriterConfig()) def ExpectedEnginesToBuild(self, run_params): return ["TRTEngineOp_0"] diff --git a/tensorflow/python/compiler/tensorrt/test/int32_test.py b/tensorflow/python/compiler/tensorrt/test/int32_test.py index 6d4446940aadf2..41a5a27addc944 100644 --- a/tensorflow/python/compiler/tensorrt/test/int32_test.py +++ b/tensorflow/python/compiler/tensorrt/test/int32_test.py @@ -65,7 +65,7 @@ def GetConversionParams(self, run_params): maximum_cached_engines=1, # Disable layout optimizer, since it will convert BiasAdd with NHWC # format to NCHW format under four dimensional input.
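The four hunks above (and the one continuing below) apply the same mechanical rename. For reference, a minimal sketch of the override pattern these tests converge on; the test class name is hypothetical, and `OptimizerDisabledRewriterConfig()` is assumed, per the comments above, to return a `RewriterConfig` proto with the layout optimizer turned off:

```python
from tensorflow.python.compiler.tensorrt.test import tf_trt_integration_test_base as trt_test


class MyConversionTest(trt_test.TfTrtIntegrationTestBase):  # hypothetical name

  def GetConversionParams(self, run_params):
    conversion_params = super(MyConversionTest,
                              self).GetConversionParams(run_params)
    # The proto now travels in the renamed `rewriter_config_template` field;
    # `rewriter_config` no longer exists on the params namedtuple.
    return conversion_params._replace(
        rewriter_config_template=trt_test.OptimizerDisabledRewriterConfig())
```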
- rewriter_config=trt_test.OptimizerDisabledRewriterConfig()) + rewriter_config_template=trt_test.OptimizerDisabledRewriterConfig()) def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" diff --git a/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py b/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py index 345f8bbf8eb2e6..f499d83189c77f 100644 --- a/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py +++ b/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py @@ -56,12 +56,6 @@ "use_calibration" ]) -ConversionParams = namedtuple("ConversionParams", [ - "max_batch_size", "max_workspace_size_bytes", "precision_mode", - "minimum_segment_size", "is_dynamic_op", "maximum_cached_engines", - "cached_engine_batches", "rewriter_config", "use_calibration" -]) - PRECISION_MODES = ["FP32", "FP16", "INT8"] @@ -163,7 +157,7 @@ def GetParams(self): raise NotImplementedError() def GetConversionParams(self, run_params): - """Return a ConversionParams for test.""" + """Return a TrtConversionParams for test.""" batch_list = [] for dims_list in self._GetParamsCached().input_dims: assert dims_list @@ -171,19 +165,22 @@ def GetConversionParams(self, run_params): input_batches = [dims[0] for dims in dims_list] assert max(input_batches) == min(input_batches) batch_list.append(input_batches[0]) - return ConversionParams( + conversion_params = trt_convert.TrtConversionParams( # We use the minimum of all the batch sizes, so when multiple different # input shapes are provided it'll always create new engines in the # cache, and we can therefore test the cache behavior. - max_batch_size=min(batch_list), + rewriter_config_template=None, max_workspace_size_bytes=1 << 25, precision_mode=run_params.precision_mode, minimum_segment_size=2, is_dynamic_op=run_params.dynamic_engine, maximum_cached_engines=1, - cached_engine_batches=None, - rewriter_config=None, - use_calibration=run_params.use_calibration) + use_calibration=run_params.use_calibration, + use_function_backup=False, + max_batch_size=min(batch_list), + cached_engine_batches=None) + return conversion_params._replace( + use_function_backup=IsQuantizationWithCalibration(conversion_params)) def ShouldRunTest(self, run_params): """Whether to run the test.""" @@ -218,24 +215,13 @@ def _GetConfigProto(self, run_params, graph_state): """Get config proto based on specific settings.""" conversion_params = self.GetConversionParams(run_params) if graph_state == GraphState.INFERENCE and run_params.use_optimizer: - rewriter_cfg = trt_convert.TrtGraphConverter.get_tensorrt_rewriter_config( - conversion_params.rewriter_config, - conversion_params.max_batch_size, - conversion_params.max_workspace_size_bytes, - conversion_params.precision_mode, - conversion_params.minimum_segment_size, - conversion_params.is_dynamic_op, - conversion_params.maximum_cached_engines, - conversion_params.cached_engine_batches, - conversion_params.use_calibration, - use_function_backup=IsQuantizationWithCalibration(conversion_params)) - + rewriter_cfg = trt_convert.get_tensorrt_rewriter_config(conversion_params) graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_cfg) else: graph_options = config_pb2.GraphOptions() - if conversion_params.rewriter_config is not None: + if conversion_params.rewriter_config_template is not None: graph_options.rewrite_options.CopyFrom( - conversion_params.rewriter_config) + conversion_params.rewriter_config_template) config = 
config_pb2.ConfigProto( gpu_options=self._GetGPUOptions(), graph_options=graph_options) @@ -310,7 +296,7 @@ def _CreateConverter(self, gdef, session_config, conversion_params): maximum_cached_engines=conversion_params.maximum_cached_engines, cached_engine_batches=conversion_params.cached_engine_batches, use_calibration=conversion_params.use_calibration, - use_function_backup=IsQuantizationWithCalibration(conversion_params)) + use_function_backup=conversion_params.use_function_backup) return converter def _GetCalibratedInferGraph(self, run_params, gdef, inputs_data): diff --git a/tensorflow/python/compiler/tensorrt/trt_convert.py b/tensorflow/python/compiler/tensorrt/trt_convert.py index 326a7d048c9036..181fa3621504e0 100644 --- a/tensorflow/python/compiler/tensorrt/trt_convert.py +++ b/tensorflow/python/compiler/tensorrt/trt_convert.py @@ -18,7 +18,10 @@ from __future__ import division from __future__ import print_function +import collections + import six as _six + from tensorflow.compiler.tf2tensorrt.python.ops import trt_ops from tensorflow.compiler.tf2tensorrt.wrap_py_utils import get_linked_tensorrt_version from tensorflow.compiler.tf2tensorrt.wrap_py_utils import get_loaded_tensorrt_version @@ -80,7 +83,7 @@ class GraphConverter(object): class MyGraphConverter(GraphConverter): ... - def get_rewriter_config(self, rewriter_config_template=None): + def get_rewriter_config(self): my_rewriter_config = ... return my_rewriter_config ``` @@ -129,7 +132,7 @@ def __init__(self, If set to None, the graph will be read from the SavedModel loaded from input_saved_model_dir. nodes_blacklist: list of node names to prevent the converter from - touching. Only used when input_graph_def is not None. + touching. session_config: the ConfigProto used to create a Session. It's also used as a template to create a RewriterConfig for conversion. If not specified, a default ConfigProto will be used. @@ -137,21 +140,15 @@ def __init__(self, Raises: ValueError: if the combination of the parameters is invalid. """ - if context.executing_eagerly(): - if input_graph_def or not input_saved_model_dir: - raise ValueError( - "TF 2.0 only supports conversion of SavedModel, please specify " - "input_saved_model_dir as input.") - else: - if input_graph_def and input_saved_model_dir: - raise ValueError( - "Can only specify one of input_graph_def and input_saved_model_dir") - if not input_graph_def and not input_saved_model_dir: - raise ValueError("Must specify one of input_graph_def and " - "input_saved_model_dir") + if input_graph_def and input_saved_model_dir: + raise ValueError( + "Can only specify one of input_graph_def and input_saved_model_dir") + if not input_graph_def and not input_saved_model_dir: + raise ValueError("Must specify one of input_graph_def and " + "input_saved_model_dir") - self._input_graph_def = input_graph_def - self._nodes_blacklist = nodes_blacklist + self._input_graph_def = input_graph_def + self._nodes_blacklist = nodes_blacklist self._input_saved_model_dir = input_saved_model_dir self._converted = False @@ -169,14 +166,9 @@ def __init__(self, self._calibration_sess = None self._calibration_data_collected = False - def get_rewriter_config(self, rewriter_config_template=None): + def get_rewriter_config(self): """Returns a RewriterConfig proto for TRT transformation. - Args: - rewriter_config_template: a template RewriterConfig proto used to create a - RewriterConfig for the conversion. The implementation should not modify - the template. If None, it will use a default one. 
- Returns: A RewriterConfig proto which will be used to run the conversion using Grappler. @@ -188,11 +180,7 @@ def _run_conversion(self): # Create custom ConfigProto for Grappler. grappler_session_config = config_pb2.ConfigProto() grappler_session_config.CopyFrom(self._session_config) - rewriter_config = None - if (grappler_session_config.HasField("graph_options") and - grappler_session_config.graph_options.HasField("rewrite_options")): - rewriter_config = grappler_session_config.graph_options.rewrite_options - custom_rewriter_config = self.get_rewriter_config(rewriter_config) + custom_rewriter_config = self.get_rewriter_config() grappler_session_config.graph_options.rewrite_options.CopyFrom( custom_rewriter_config) @@ -285,33 +273,6 @@ def _gather_names(tensor_info): self._run_conversion() - # TODO(laigd): provide a utility function to optimize a ConcreteFunction and - # use it here (b/124792963). - def _convert_saved_model_v2(self): - """Convert the input SavedModel in 2.0 format.""" - assert context.executing_eagerly() - - self._saved_model = load.load(self._input_saved_model_dir, - self._input_saved_model_tags) - func = self._saved_model.signatures[self._input_saved_model_signature_key] - frozen_func = convert_to_constants.convert_variables_to_constants_v2(func) - self._grappler_meta_graph_def = saver.export_meta_graph( - graph_def=frozen_func.graph.as_graph_def(), graph=frozen_func.graph) - - # Add a collection 'train_op' so that Grappler knows the outputs. - fetch_collection = meta_graph_pb2.CollectionDef() - for array in frozen_func.inputs + frozen_func.outputs: - fetch_collection.node_list.value.append(array.name) - self._grappler_meta_graph_def.collection_def["train_op"].CopyFrom( - fetch_collection) - - # Run TRT optimizer in Grappler to convert the graph. - self._run_conversion() - self._converted_func = wrap_function.function_from_graph_def( - self._converted_graph_def, - [tensor.name for tensor in frozen_func.inputs], - [tensor.name for tensor in frozen_func.outputs]) - def convert(self): """Run the conversion. @@ -320,16 +281,11 @@ def convert(self): 2.0+. """ assert not self._converted - - if context.executing_eagerly(): - self._convert_saved_model_v2() - return self._converted_func + if self._input_graph_def: + self._convert_graph_def() else: - if self._input_graph_def: - self._convert_graph_def() - else: - self._convert_saved_model() - return self._converted_graph_def + self._convert_saved_model() + return self._converted_graph_def def calibrate(self, fetch_names, @@ -408,80 +364,71 @@ def save(self, output_saved_model_dir): SavedModel. """ assert self._converted - - if context.executing_eagerly(): - # Rewrite the signature map using the optimized ConcreteFunction. - signatures = { - key: value for key, value in self._saved_model.signatures.items() - } - signatures[self._input_saved_model_signature_key] = self._converted_func - save.save(self._saved_model, output_saved_model_dir, signatures) - else: - if self._input_graph_def: - raise ValueError( - "Not able to save to a SavedModel since input is a GraphDef") - - def _restore_collections(dest_graph, src_meta_graph_def, collections): - """Restores collections that we need to keep.""" - scope = "" - for key in collections: - collection_def = src_meta_graph_def.collection_def[key] - kind = collection_def.WhichOneof("kind") - if kind is None: - tf_logging.error( - "Cannot identify data type for collection %s. 
Skipping.", key) - continue - from_proto = ops.get_from_proto_function(key) - if from_proto and kind == "bytes_list": - proto_type = ops.get_collection_proto_type(key) - # It is assumed that there are no Variables Keys in collections - for value in collection_def.bytes_list.value: - proto = proto_type() - proto.ParseFromString(value) + if self._input_graph_def: + raise ValueError( + "Not able to save to a SavedModel since input is a GraphDef") + + def _restore_collections(dest_graph, src_meta_graph_def, collection_keys): + """Restores collections that we need to keep.""" + scope = "" + for key in collection_keys: + collection_def = src_meta_graph_def.collection_def[key] + kind = collection_def.WhichOneof("kind") + if kind is None: + tf_logging.error( + "Cannot identify data type for collection %s. Skipping.", key) + continue + from_proto = ops.get_from_proto_function(key) + if from_proto and kind == "bytes_list": + proto_type = ops.get_collection_proto_type(key) + # It is assumed that there are no Variables Keys in collections + for value in collection_def.bytes_list.value: + proto = proto_type() + proto.ParseFromString(value) + try: + new_value = from_proto(proto, import_scope=scope) + except: + continue + dest_graph.add_to_collection(key, new_value) + else: + field = getattr(collection_def, kind) + if kind == "node_list": + for value in field.value: + name = ops.prepend_name_scope(value, scope) + # Since the graph has been optimized, the node may no longer + # exists try: - new_value = from_proto(proto, import_scope=scope) - except: + col_op = dest_graph.as_graph_element(name) + except (TypeError, ValueError, KeyError) as e: continue - dest_graph.add_to_collection(key, new_value) + dest_graph.add_to_collection(key, col_op) + elif kind == "int64_list": + # NOTE(opensource): This force conversion is to work around the + # fact that Python2 distinguishes between int and long, while + # Python3 has only int. + for value in field.value: + dest_graph.add_to_collection(key, int(value)) else: - field = getattr(collection_def, kind) - if kind == "node_list": - for value in field.value: - name = ops.prepend_name_scope(value, scope) - # Since the graph has been optimized, the node may no longer - # exists - try: - col_op = dest_graph.as_graph_element(name) - except (TypeError, ValueError, KeyError) as e: - continue - dest_graph.add_to_collection(key, col_op) - elif kind == "int64_list": - # NOTE(opensource): This force conversion is to work around the - # fact that Python2 distinguishes between int and long, while - # Python3 has only int. - for value in field.value: - dest_graph.add_to_collection(key, int(value)) - else: - for value in field.value: - dest_graph.add_to_collection( - key, ops.prepend_name_scope(value, scope)) - - # Write the transformed graphdef as SavedModel. - saved_model_builder = builder.SavedModelBuilder(output_saved_model_dir) - with ops.Graph().as_default(): - importer.import_graph_def(self._converted_graph_def, name="") - _restore_collections( - ops.get_default_graph(), self._grappler_meta_graph_def, - self._collections_to_keep( - self._grappler_meta_graph_def.collection_def)) - # We don't use any specific converter here. - with session.Session(config=self._session_config) as sess: - saved_model_builder.add_meta_graph_and_variables( - sess, - self._input_saved_model_tags, - signature_def_map=self._grappler_meta_graph_def.signature_def) - # Ignore other meta graphs from the input SavedModel. 
- saved_model_builder.save() + for value in field.value: + dest_graph.add_to_collection(key, + ops.prepend_name_scope(value, scope)) + + # Write the transformed graphdef as SavedModel. + saved_model_builder = builder.SavedModelBuilder(output_saved_model_dir) + with ops.Graph().as_default(): + importer.import_graph_def(self._converted_graph_def, name="") + _restore_collections( + ops.get_default_graph(), self._grappler_meta_graph_def, + self._collections_to_keep( + self._grappler_meta_graph_def.collection_def)) + # We don't use any specific converter here. + with session.Session(config=self._session_config) as sess: + saved_model_builder.add_meta_graph_and_variables( + sess, + self._input_saved_model_tags, + signature_def_map=self._grappler_meta_graph_def.signature_def) + # Ignore other meta graphs from the input SavedModel. + saved_model_builder.save() class TrtPrecisionMode(object): @@ -498,101 +445,202 @@ def supported_precision_modes(): # so it can produce reasonable performance results with the default. DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES = 1 << 30 +# TrtConversionParams encapsulates the parameters that are used for TF-TRT +# conversion. +TrtConversionParams = collections.namedtuple( + "TrtConversionParams", + [ + + # A template RewriterConfig proto used to create a TRT-enabled + # RewriterConfig. If None, it will use a default one. + "rewriter_config_template", + + # The maximum GPU temporary memory which the TRT engine can use at + # execution time. This corresponds to the 'workspaceSize' parameter of + # nvinfer1::IBuilder::setMaxWorkspaceSize(). + "max_workspace_size_bytes", + + # One of TrtPrecisionMode.supported_precision_modes(). + "precision_mode", + + # The minimum number of nodes required for a subgraph to be replaced by + # TRTEngineOp. + "minimum_segment_size", + + # Whether to generate dynamic TRT ops which will build the TRT network + # and engine at run time. + "is_dynamic_op", + + # Max number of cached TRT engines in dynamic TRT ops. If the number of + # cached engines is already at max but none of them can serve the input, + # the TRTEngineOp will fall back to run the TF function based on which + # the TRTEngineOp is created. + "maximum_cached_engines", + + # This argument is ignored if precision_mode is not INT8. If set to + # True, a calibration graph will be created to calibrate the missing + # ranges. The calibration graph must be converted to an inference graph + # by running calibration with calibrate(). If set to False, quantization + # nodes will be expected for every tensor in the graph (excluding those + # which will be fused). If a range is missing, an error will occur. + # Please note that accuracy may be negatively affected if there is a + # mismatch between which tensors TRT quantizes and which tensors were + # trained with fake quantization. + "use_calibration", + + # If set to True, it will create a FunctionDef for each subgraph that is + # converted to TRT op, and if TRT ops fail to execute at runtime, it'll + # invoke that function as a fallback. + "use_function_backup", + + # Max size for the input batch. + # This option is deprecated in TF 2.0. + "max_batch_size", + + # A list of batch sizes used to create cached engines, only used when + # is_dynamic_op is True. The length of the list should be <= + # maximum_cached_engines, and the dynamic TRT op will use this list to + # determine the batch sizes of the cached engines, instead of making the + # decision on the fly.
This is useful when we know the most common batch + # size(s) the application is going to generate. + # This option is deprecated in TF 2.0. + "cached_engine_batches", + ]) + +DEFAULT_TRT_CONVERSION_PARAMS = TrtConversionParams( + rewriter_config_template=None, + max_workspace_size_bytes=DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES, + precision_mode=TrtPrecisionMode.FP32, + minimum_segment_size=3, + is_dynamic_op=False, + maximum_cached_engines=1, + use_calibration=True, + use_function_backup=True, + max_batch_size=1, + cached_engine_batches=None) -class TrtGraphConverter(GraphConverter): - """A GraphConverter for TRT transformation.""" +_TRT_CALIBRATION_RESOURCE_CONTAINER_NAME = "TF-TRT-Calibration" +_TRT_ENGINE_CACHE_CONTAINER_NAME = "TF-TRT-Engine-Cache" +_TRT_ENGINE_OP_NAME = "TRTEngineOp" - _TRT_CALIBRATION_RESOURCE_CONTAINER_NAME = "TF-TRT-Calibration" - - @classmethod - def get_tensorrt_rewriter_config( - cls, - rewriter_config_template=None, - max_batch_size=1, - max_workspace_size_bytes=DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES, - precision_mode=TrtPrecisionMode.FP32, - minimum_segment_size=3, - is_dynamic_op=False, - maximum_cached_engines=1, - cached_engine_batches=None, - use_calibration=True, - use_function_backup=True): - """Returns a RewriterConfig proto for TRT transformation. - Args: - rewriter_config_template: a template RewriterConfig proto used to create a - TRT-enabled RewriterConfig. If None, it will use a default one. - max_batch_size: max size for the input batch - max_workspace_size_bytes: the maximum GPU temporary memory which the TRT - engine can use at execution time. This corresponds to the - 'workspaceSize' parameter of nvinfer1::IBuilder::setMaxWorkspaceSize(). - precision_mode: one of TrtPrecisionMode.supported_precision_modes(). - minimum_segment_size: the minimum number of nodes required for a subgraph - to be replaced by TRTEngineOp. - is_dynamic_op: whether to generate dynamic TRT ops which will build the - TRT network and engine at run time. - maximum_cached_engines: max number of cached TRT engines in dynamic TRT - ops. If the number of cached engines is already at max but none of them - can serve the input, the TRTEngineOp will fall back to run the TF - function based on which the TRTEngineOp is created. - cached_engine_batches: a list of batch sizes used to create cached - engines, only used when is_dynamic_op is True. The length of the list - should be <= maximum_cached_engines, and the dynamic TRT op will use - this list to determine the batch sizes of the cached engines, instead of - making the decision on the fly. This is useful when we know the most - common batch size(s) the application is going to generate. - use_calibration: this argument is ignored if precision_mode is not INT8. - If set to True, a calibration graph will be created to calibrate the - missing ranges. The calibration graph must be converted to an inference - graph by running calibration with calibrate(). If set to False, - quantization nodes will be expected for every tensor in the graph - (exlcuding those which will be fused). If a range is missing, an error - will occur. Please note that accuracy may be negatively affected if - there is a mismatch between which tensors TRT quantizes and which - tensors were trained with fake quantization. - use_function_backup: if set to True, it will create a FunctionDef for each - subgraph that is converted to TRT op, and if TRT ops fail to execute at - runtime, it'll invoke that function as a fallback. 
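Because `TrtConversionParams` is a plain `collections.namedtuple`, callers customize `DEFAULT_TRT_CONVERSION_PARAMS` with `_replace()`, which returns a new tuple and never mutates the default. A standalone sketch of the idiom, with the field set abbreviated to three of the ten fields for brevity:

```python
import collections

# Abbreviated stand-in for TrtConversionParams; the real tuple has ten fields.
Params = collections.namedtuple(
    "Params", ["precision_mode", "is_dynamic_op", "maximum_cached_engines"])

DEFAULT = Params(
    precision_mode="FP32", is_dynamic_op=False, maximum_cached_engines=1)

# _replace builds a new namedtuple; DEFAULT itself is left untouched.
custom = DEFAULT._replace(precision_mode="INT8", is_dynamic_op=True)
assert DEFAULT.precision_mode == "FP32"
assert custom == Params("INT8", True, 1)
```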
+def _check_conversion_params(conversion_params): + """Validate the provided TrtConversionParams. - Returns: - A RewriterConfig proto which sets a TensorRTOptimizer to run Grappler. + Args: + conversion_params: a TrtConversionParams instance. - Raises: - TypeError: if any of the parameters are of unexpected type. - ValueError: if any of the parameters are of unexpected value. - """ - if rewriter_config_template is not None and not isinstance( - rewriter_config_template, rewriter_config_pb2.RewriterConfig): - raise TypeError( - "rewriter_config_template should be a RewriterConfig proto.") - - rewriter_config_with_trt = rewriter_config_pb2.RewriterConfig() - if rewriter_config_template is None: - # Layout optimizer may add Const nodes followed by Reshape nodes, thus we - # need to run constant folding again. - rewriter_config_with_trt.optimizers.extend( - ["constfold", "layout", "constfold"]) - rewriter_config_with_trt.meta_optimizer_iterations = ( - rewriter_config_pb2.RewriterConfig.ONE) - else: - rewriter_config_with_trt.CopyFrom(rewriter_config_template) - - optimizer = rewriter_config_with_trt.custom_optimizers.add() - optimizer.name = "TensorRTOptimizer" - optimizer.parameter_map["minimum_segment_size"].i = minimum_segment_size - optimizer.parameter_map["max_batch_size"].i = max_batch_size - optimizer.parameter_map["is_dynamic_op"].b = is_dynamic_op - optimizer.parameter_map[ - "max_workspace_size_bytes"].i = max_workspace_size_bytes - optimizer.parameter_map["precision_mode"].s = _to_bytes(precision_mode) - optimizer.parameter_map["maximum_cached_engines"].i = maximum_cached_engines - if cached_engine_batches: - optimizer.parameter_map["cached_engine_batches"].list.i.extend( - cached_engine_batches) - optimizer.parameter_map["use_calibration"].b = use_calibration - optimizer.parameter_map["use_function_backup"].b = use_function_backup - return rewriter_config_with_trt + Raises: + TypeError: if any of the parameters are of unexpected type. + ValueError: if any of the parameters are of unexpected value. + """ + supported_precision_modes = TrtPrecisionMode.supported_precision_modes() + if conversion_params.precision_mode not in supported_precision_modes: + raise ValueError( + ("precision mode '{}' is not supported. " + "It should be one of {}").format(conversion_params.precision_mode, + supported_precision_modes)) + if conversion_params.cached_engine_batches: + if not isinstance(conversion_params.cached_engine_batches, list): + raise TypeError("cached_engine_batches should be a list.") + if len(conversion_params.cached_engine_batches + ) > conversion_params.maximum_cached_engines: + raise ValueError("cached_engine_batches should not contain more than " + "maximum_cached_engines items.") + + +def _check_trt_version_compatibility(): + """Check compatibility of TensorRT version. + + Raises: + RuntimeError: if the TensorRT library version is incompatible. + """ + compiled_version = get_linked_tensorrt_version() + loaded_version = get_loaded_tensorrt_version() + tf_logging.info("Linked TensorRT version: %s" % str(compiled_version)) + tf_logging.info("Loaded TensorRT version: %s" % str(loaded_version)) + version_mismatch = False + if loaded_version[0] < compiled_version[0]: + tf_logging.error( + "TensorRT version mismatch. Tensorflow was compiled against " + + "TensorRT %s but library loaded from environment is TensorRT %s" % + (".".join([str(x) for x in compiled_version]), + ".".join([str(x) for x in loaded_version])) + + ". Please make sure that correct version of TensorRT
Please make sure that correct version of TensorRT " + + "is available in the system and added to ldconfig or LD_LIBRARY_PATH") + raise RuntimeError("Incompatible TensorRT library version") + for i in zip(loaded_version, compiled_version): + if i[0] != i[1]: + tf_logging.warn("TensorRT mismatch. Compiled against version " + + "%s, but loaded %s. Things may not work" % + (".".join([str(x) for x in compiled_version]), + ".".join([str(x) for x in loaded_version]))) + version_mismatch = True + break + if not version_mismatch: + tf_logging.info("Running against TensorRT version %s" % + ".".join([str(x) for x in loaded_version])) + + +def get_tensorrt_rewriter_config( + conversion_params=DEFAULT_TRT_CONVERSION_PARAMS): + """Returns a RewriterConfig proto for TRT transformation. + + Args: + conversion_params: a TrtConversionParams instance. + + Returns: + A RewriterConfig proto which sets a TensorRTOptimizer to run Grappler. + + Raises: + TypeError: if any of the parameters are of unexpected type. + ValueError: if any of the parameters are of unexpected value. + """ + if conversion_params.rewriter_config_template is not None and not isinstance( + conversion_params.rewriter_config_template, + rewriter_config_pb2.RewriterConfig): + raise TypeError( + "rewriter_config_template should be a RewriterConfig proto.") + _check_conversion_params(conversion_params) + + rewriter_config_with_trt = rewriter_config_pb2.RewriterConfig() + if conversion_params.rewriter_config_template is None: + # Layout optimizer may add Const nodes followed by Reshape nodes, thus we + # need to run constant folding again. + rewriter_config_with_trt.optimizers.extend( + ["constfold", "layout", "constfold"]) + rewriter_config_with_trt.meta_optimizer_iterations = ( + rewriter_config_pb2.RewriterConfig.ONE) + else: + rewriter_config_with_trt.CopyFrom( + conversion_params.rewriter_config_template) + + optimizer = rewriter_config_with_trt.custom_optimizers.add() + optimizer.name = "TensorRTOptimizer" + optimizer.parameter_map[ + "minimum_segment_size"].i = conversion_params.minimum_segment_size + optimizer.parameter_map["max_batch_size"].i = conversion_params.max_batch_size + optimizer.parameter_map["is_dynamic_op"].b = conversion_params.is_dynamic_op + optimizer.parameter_map[ + "max_workspace_size_bytes"].i = conversion_params.max_workspace_size_bytes + optimizer.parameter_map["precision_mode"].s = _to_bytes( + conversion_params.precision_mode) + optimizer.parameter_map[ + "maximum_cached_engines"].i = conversion_params.maximum_cached_engines + if conversion_params.cached_engine_batches: + optimizer.parameter_map["cached_engine_batches"].list.i.extend( + conversion_params.cached_engine_batches) + optimizer.parameter_map[ + "use_calibration"].b = conversion_params.use_calibration + optimizer.parameter_map[ + "use_function_backup"].b = conversion_params.use_function_backup + return rewriter_config_with_trt + +class TrtGraphConverter(GraphConverter): + """A GraphConverter for TRT transformation.""" + + # TODO(laigd): use TrtConversionParams here. def __init__(self, input_saved_model_dir=None, input_saved_model_tags=None, @@ -621,7 +669,7 @@ def __init__(self, If set to None, the graph will be read from the SavedModel loaded from input_saved_model_dir. nodes_blacklist: list of node names to prevent the converter from - touching. Only used when input_graph_def is not None. + touching. session_config: the ConfigProto used to create a Session. It's also used as a template to create a TRT-enabled ConfigProto for conversion. 
If not specified, a default ConfigProto will be used. @@ -659,7 +707,6 @@ def __init__(self, Raises: ValueError: if the combination of the parameters is invalid. - RuntimeError: if the TensorRT library version is incompatible. """ super(TrtGraphConverter, self).__init__( input_saved_model_dir=input_saved_model_dir, @@ -668,54 +715,10 @@ def __init__(self, input_graph_def=input_graph_def, nodes_blacklist=nodes_blacklist, session_config=session_config) - - # TODO(laigd): move all the validations below to - # get_tensorrt_rewriter_config(). - # Check compatibility of TensorRT version. - compiled_version = get_linked_tensorrt_version() - loaded_version = get_loaded_tensorrt_version() - tf_logging.info("Linked TensorRT version: %s" % str(compiled_version)) - tf_logging.info("Loaded TensorRT version: %s" % str(loaded_version)) - version_mismatch = False - if loaded_version[0] < compiled_version[0]: - tf_logging.error( - "TensorRT version mismatch. Tensorflow was compiled against " + - "TensorRT %s but library loaded from environment is TensorRT %s" % - (".".join([str(x) for x in compiled_version]), - ".".join([str(x) for x in loaded_version])) + - ". Please make sure that correct version of TensorRT " + - "is available in the system and added to ldconfig or LD_LIBRARY_PATH") - raise RuntimeError("Incompatible TensorRT library version") - for i in zip(loaded_version, compiled_version): - if i[0] != i[1]: - tf_logging.warn("TensorRT mismatch. Compiled against version " + - "%s, but loaded %s. Things may not work" % - (".".join([str(x) for x in compiled_version]), - ".".join([str(x) for x in loaded_version]))) - version_mismatch = True - break - if not version_mismatch: - tf_logging.info("Running against TensorRT version %s" % - ".".join([str(x) for x in loaded_version])) - - # Check input arguments. - supported_precision_modes = TrtPrecisionMode.supported_precision_modes() - if precision_mode not in supported_precision_modes: - raise ValueError( - ("precision mode '{}' is not supported." - "It should be one of {}").format(precision_mode, - supported_precision_modes)) - - if cached_engine_batches: - if not isinstance(cached_engine_batches, list): - raise TypeError("cached_engine_batches should be a list.") - if len(cached_engine_batches) > maximum_cached_engines: - raise ValueError("cached_engine_batches should not contain more than " - "maximum_cached_engines items.") + _check_trt_version_compatibility() self._need_calibration = ( precision_mode == TrtPrecisionMode.INT8 and use_calibration) - self._use_function_backup = use_function_backup # TODO(laigd): consider provide a mechanism to remove the fallback path # after calibration is done. @@ -724,31 +727,30 @@ def __init__(self, "Calibration requires enabling fallback to TF function execution.") # TODO(laigd): - # - Get rid of is_dynamic_op option, it should always be True, and it should - # accept N shapes as input. # - Verify in int8 mode that maximum_cached_engines and # cached_engine_batches are set appropriately. # - If it fails to build the int8 engine it should return error. 
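With `get_tensorrt_rewriter_config()` promoted to a module-level function that takes a single `TrtConversionParams`, embedding the TRT pass in a session config (the `use_optimizer` path exercised by `_GetConfigProto` in the test base above) reduces to roughly the sketch below; the field values are illustrative, and any field not listed keeps its `DEFAULT_TRT_CONVERSION_PARAMS` value:

```python
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.compiler.tensorrt import trt_convert

conversion_params = trt_convert.DEFAULT_TRT_CONVERSION_PARAMS._replace(
    precision_mode=trt_convert.TrtPrecisionMode.FP16,
    is_dynamic_op=True,
    maximum_cached_engines=2)
# Build the TRT-enabled RewriterConfig and embed it in a session config so
# Grappler runs the TensorRTOptimizer when the graph is first executed.
rewriter_cfg = trt_convert.get_tensorrt_rewriter_config(
    conversion_params=conversion_params)
graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_cfg)
session_config = config_pb2.ConfigProto(graph_options=graph_options)
```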
- self._max_batch_size = max_batch_size - self._max_workspace_size_bytes = max_workspace_size_bytes - self._precision_mode = precision_mode - self._minimum_segment_size = minimum_segment_size - self._is_dynamic_op = is_dynamic_op - self._maximum_cached_engines = maximum_cached_engines - self._cached_engine_batches = cached_engine_batches - - def get_rewriter_config(self, rewriter_config_template=None): - return TrtGraphConverter.get_tensorrt_rewriter_config( - rewriter_config_template, - max_batch_size=self._max_batch_size, - max_workspace_size_bytes=self._max_workspace_size_bytes, - precision_mode=self._precision_mode, - minimum_segment_size=self._minimum_segment_size, - is_dynamic_op=self._is_dynamic_op, - maximum_cached_engines=self._maximum_cached_engines, - cached_engine_batches=self._cached_engine_batches, - use_calibration=self._need_calibration, - use_function_backup=self._use_function_backup) + rewriter_config_template = None + if (session_config and session_config.HasField("graph_options") and + session_config.graph_options.HasField("rewrite_options")): + rewriter_config_template = session_config.graph_options.rewrite_options + + self._conversion_params = TrtConversionParams( + rewriter_config_template=rewriter_config_template, + max_workspace_size_bytes=max_workspace_size_bytes, + precision_mode=precision_mode, + minimum_segment_size=minimum_segment_size, + is_dynamic_op=is_dynamic_op, + maximum_cached_engines=maximum_cached_engines, + use_calibration=use_calibration, + use_function_backup=use_function_backup, + max_batch_size=max_batch_size, + cached_engine_batches=cached_engine_batches) + _check_conversion_params(self._conversion_params) + + def get_rewriter_config(self): + return get_tensorrt_rewriter_config( + conversion_params=self._conversion_params) def finalize_calibration(self): assert self._need_calibration @@ -775,7 +777,7 @@ def finalize_calibration(self): resource_name_input = array_ops.placeholder(dtypes.string) for node in self._converted_graph_def.node: - if node.op == "TRTEngineOp": + if node.op == _TRT_ENGINE_OP_NAME: # Adds the get_serialized_resource_op for the device if not done # before. We only add one such op for each device. # TODO(laigd): What if the device is empty????? @@ -791,11 +793,8 @@ def finalize_calibration(self): calibration_result = self._calibration_sess.run( device_to_get_resource_op_map[node.device], feed_dict={ - container_input: - TrtGraphConverter - ._TRT_CALIBRATION_RESOURCE_CONTAINER_NAME, - resource_name_input: - node.name + container_input: _TRT_CALIBRATION_RESOURCE_CONTAINER_NAME, + resource_name_input: node.name }) node.attr["calibration_data"].s = calibration_result @@ -806,9 +805,106 @@ def save(self, output_saved_model_dir): """Save the converted graph as a SavedModel.""" if self._need_calibration: assert self._calibration_data_collected + super(TrtGraphConverter, self).save(output_saved_model_dir) +class TrtGraphConverterV2(object): + """A converter for TF-TRT transformation for SavedModel in TF 2.0.""" + + def __init__(self, + input_saved_model_dir=None, + input_saved_model_tags=None, + input_saved_model_signature_key=None, + conversion_params=DEFAULT_TRT_CONVERSION_PARAMS): + """Initialize the converter. + + Args: + input_saved_model_dir: the directory to load the SavedModel which contains + the input graph to transform. + input_saved_model_tags: list of tags to load the SavedModel. + input_saved_model_signature_key: the key of the signature to optimize the + graph for.
+ conversion_params: a TrtConversionParams instance. + """ + assert context.executing_eagerly() + _check_trt_version_compatibility() + + self._input_saved_model_dir = input_saved_model_dir + self._input_saved_model_tags = ( + input_saved_model_tags or [tag_constants.SERVING]) + self._input_saved_model_signature_key = ( + input_saved_model_signature_key or + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY) + + self._need_calibration = ( + conversion_params.precision_mode == TrtPrecisionMode.INT8 and + conversion_params.use_calibration) + self._conversion_params = conversion_params + _check_conversion_params(self._conversion_params) + self._converted = False + + def _run_conversion(self, meta_graph_def): + """Run Grappler's OptimizeGraph() tool to convert the graph. + + Args: + meta_graph_def: the MetaGraphDef instance to run the optimizations on. + + Returns: + The optimized GraphDef. + """ + rewriter_config = get_tensorrt_rewriter_config( + conversion_params=self._conversion_params) + grappler_session_config = config_pb2.ConfigProto() + grappler_session_config.graph_options.rewrite_options.CopyFrom( + rewriter_config) + return tf_optimizer.OptimizeGraph( + grappler_session_config, meta_graph_def, graph_id=b"tf_graph") + + # TODO(laigd): provide a utility function to optimize a ConcreteFunction and + # use it here (b/124792963). + def convert(self): + """Convert the input SavedModel in 2.0 format.""" + assert not self._converted + self._saved_model = load.load(self._input_saved_model_dir, + self._input_saved_model_tags) + func = self._saved_model.signatures[self._input_saved_model_signature_key] + frozen_func = convert_to_constants.convert_variables_to_constants_v2(func) + grappler_meta_graph_def = saver.export_meta_graph( + graph_def=frozen_func.graph.as_graph_def(), graph=frozen_func.graph) + + # Add a collection 'train_op' so that Grappler knows the outputs. + fetch_collection = meta_graph_pb2.CollectionDef() + for array in frozen_func.inputs + frozen_func.outputs: + fetch_collection.node_list.value.append(array.name) + grappler_meta_graph_def.collection_def["train_op"].CopyFrom( + fetch_collection) + + # Run TRT optimizer in Grappler to convert the graph. + converted_graph_def = self._run_conversion(grappler_meta_graph_def) + self._converted_func = wrap_function.function_from_graph_def( + converted_graph_def, [tensor.name for tensor in frozen_func.inputs], + [tensor.name for tensor in frozen_func.outputs]) + + self._converted = True + return self._converted_func + + def save(self, output_saved_model_dir): + """Save the converted SavedModel. + + Args: + output_saved_model_dir: directory to save the converted SavedModel. + """ + assert self._converted + # Rewrite the signature map using the optimized ConcreteFunction. + signatures = { + key: value for key, value in self._saved_model.signatures.items() + } + signatures[self._input_saved_model_signature_key] = self._converted_func + save.save(self._saved_model, output_saved_model_dir, signatures) + + +# TODO(laigd): use TrtConversionParams here.
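End to end, the new V2 converter is driven as in the sketch below, which condenses the `testTrtGraphConverter_BasicConversion_v2` flow added further down; it assumes TF 2.0 eager execution, and the SavedModel paths are placeholders:

```python
from tensorflow.python.compiler.tensorrt import trt_convert

params = trt_convert.DEFAULT_TRT_CONVERSION_PARAMS._replace(
    precision_mode=trt_convert.TrtPrecisionMode.FP32,
    is_dynamic_op=True,
    use_function_backup=False)
converter = trt_convert.TrtGraphConverterV2(
    input_saved_model_dir="/tmp/model_in",  # placeholder path
    conversion_params=params)
# convert() freezes the SavedModel's signature function, runs the Grappler
# TensorRTOptimizer over it, and returns the rewritten ConcreteFunction.
converted_func = converter.convert()
converter.save("/tmp/model_out")  # placeholder path
```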
def create_inference_graph( input_graph_def, outputs, diff --git a/tensorflow/python/compiler/tensorrt/trt_convert_test.py b/tensorflow/python/compiler/tensorrt/trt_convert_test.py index 3d7ebfae8354bc..673c3d66fa2730 100644 --- a/tensorflow/python/compiler/tensorrt/trt_convert_test.py +++ b/tensorflow/python/compiler/tensorrt/trt_convert_test.py @@ -19,14 +19,17 @@ from __future__ import print_function import os +import tempfile + +import numpy as np from tensorflow.compiler.tf2tensorrt.wrap_py_utils import is_tensorrt_enabled from tensorflow.core.framework import graph_pb2 from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.compiler.tensorrt import trt_convert -from tensorflow.python.eager import context from tensorflow.python.eager import def_function +from tensorflow.python.eager import function from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import graph_util @@ -35,7 +38,6 @@ from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops -from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.saved_model import builder @@ -44,10 +46,11 @@ from tensorflow.python.saved_model import signature_def_utils from tensorflow.python.saved_model import tag_constants from tensorflow.python.saved_model import utils -from tensorflow.python.tools import saved_model_utils from tensorflow.python.saved_model import load from tensorflow.python.saved_model import save +from tensorflow.python.tools import saved_model_utils from tensorflow.python.training.tracking import tracking +from tensorflow.python.util import nest _SAVED_MODEL_SIGNATURE_KEY = "mypredict" @@ -63,8 +66,7 @@ def testGetTensorrtRewriterConfig(self): """Test case for TrtGraphConverter.get_tensorrt_rewriter_config().""" if not is_tensorrt_enabled(): return - rewriter_cfg = trt_convert.TrtGraphConverter.get_tensorrt_rewriter_config( - rewriter_config_template=None, + conversion_params = trt_convert.DEFAULT_TRT_CONVERSION_PARAMS._replace( max_batch_size=128, max_workspace_size_bytes=1234, precision_mode="INT8", @@ -72,6 +74,8 @@ def testGetTensorrtRewriterConfig(self): is_dynamic_op=True, maximum_cached_engines=2, cached_engine_batches=[1, 128]) + rewriter_cfg = trt_convert.get_tensorrt_rewriter_config( + conversion_params=conversion_params) self.assertEqual(["constfold", "layout", "constfold"], rewriter_cfg.optimizers) self.assertEqual(rewriter_config_pb2.RewriterConfig.ONE, @@ -106,7 +110,8 @@ def _GetConfigProto(self): gpu_options=config_pb2.GPUOptions(allow_growth=True)) return config - def _GetGraph(self): + @classmethod + def _GetGraph(cls, inp, var): """Get the graph for testing.""" # The graph computes (input+1)^2, it looks like: # @@ -119,24 +124,42 @@ def _GetGraph(self): # + # | # output (Identity) + add = inp + var + mul = inp * add + add = mul + add + out = array_ops.identity(add, name="output") + return out + + def _GetModelForV2(self): + + class SimpleModel(tracking.AutoTrackable): + + def __init__(self): + self.v = None + + @def_function.function(input_signature=[ + tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32) + ]) + def run(self, inp): + if self.v is None: + self.v = variables.Variable([[[1.0]]], dtype=dtypes.float32) + return TrtConvertTest._GetGraph(inp, self.v) + + 
return SimpleModel() + + def _GetGraphForV1(self): g = ops.Graph() with g.as_default(): with g.device("/GPU:0"): inp = array_ops.placeholder( dtype=dtypes.float32, shape=[None, 1, 1], name="input") - var = variables.VariableV1([[[1.0]]], - dtype=dtypes.float32, - name="v1", - use_resource=False) - add = inp + var.value() - mul = inp * add - add = mul + add - out = array_ops.identity(add, name="output") - return g, var, inp, out + var = variables.Variable([[[1.0]]], dtype=dtypes.float32, name="v1") + out = TrtConvertTest._GetGraph(inp, var) + return g, var, inp, out def _GetGraphDef(self): """Get the graph def for testing.""" - g, var, _, _ = self._GetGraph() + g, var, _, _ = self._GetGraphForV1() with self.session(graph=g, config=self._GetConfigProto()) as sess: sess.run(var.initializer) graph_def = graph_util.convert_variables_to_constants( @@ -145,7 +168,7 @@ def _GetGraphDef(self): self.assertEqual( { "v1": "Const", - "v1/read": "Identity", + "add/ReadVariableOp": "Identity", "input": "Placeholder", "add": "Add", "mul": "Mul", @@ -156,7 +179,7 @@ def _GetGraphDef(self): def _WriteInputSavedModel(self, input_saved_model_dir): """Write the saved model as an input for testing.""" - g, var, inp, out = self._GetGraph() + g, var, inp, out = self._GetGraphForV1() signature_def = signature_def_utils.build_signature_def( inputs={"myinput": utils.build_tensor_info(inp)}, outputs={"myoutput": utils.build_tensor_info(out)}, @@ -183,7 +206,7 @@ def _ConvertGraph(self, input_saved_model_dir=input_saved_model_dir, input_saved_model_signature_key=_SAVED_MODEL_SIGNATURE_KEY, input_graph_def=None if input_saved_model_dir else self._GetGraphDef(), - nodes_blacklist=["output"], + nodes_blacklist=None if input_saved_model_dir else ["output"], session_config=self._GetConfigProto(), max_batch_size=max_batch_size, max_workspace_size_bytes=TrtConvertTest._TRT_MAX_WORKSPACE_SIZE_BYTES, @@ -193,28 +216,23 @@ def _ConvertGraph(self, is_dynamic_op=is_dynamic_op, maximum_cached_engines=maximum_cached_engines, use_function_backup=use_function_backup) - conversion_result = converter.convert() + output_graph_def = converter.convert() - if context.executing_eagerly(): - output_graph_def = conversion_result.graph.as_graph_def() - else: - output_graph_def = conversion_result + if need_calibration: - if need_calibration: - - class CalibrationData(object): + class CalibrationData(object): - def __init__(self): - self._data = 0 + def __init__(self): + self._data = 0 - def next(self): - self._data += 1 - return {"input:0": [[[self._data]]]} + def next(self): + self._data += 1 + return {"input:0": [[[self._data]]]} - output_graph_def = converter.calibrate( - fetch_names=["output:0"], - num_runs=10, - feed_dict_fn=CalibrationData().next) + output_graph_def = converter.calibrate( + fetch_names=["output:0"], + num_runs=10, + feed_dict_fn=CalibrationData().next) if output_saved_model_dir is not None: converter.save(output_saved_model_dir=output_saved_model_dir) @@ -235,31 +253,19 @@ def _TestTrtGraphConverter(self, graph_defs_to_verify = [output_graph_def] if output_saved_model_dir: - if context.executing_eagerly(): - root = load.load(output_saved_model_dir) - saved_model_graph_def = root.signatures[ - _SAVED_MODEL_SIGNATURE_KEY].graph.as_graph_def() - else: - saved_model_graph_def = saved_model_utils.get_meta_graph_def( - output_saved_model_dir, tag_constants.SERVING).graph_def - self.assertTrue(isinstance(saved_model_graph_def, graph_pb2.GraphDef)) + saved_model_graph_def = saved_model_utils.get_meta_graph_def( + 
output_saved_model_dir, tag_constants.SERVING).graph_def + self.assertIsInstance(saved_model_graph_def, graph_pb2.GraphDef) graph_defs_to_verify.append(saved_model_graph_def) for graph_def in graph_defs_to_verify: node_name_to_op = {node.name: node.op for node in graph_def.node} - if context.executing_eagerly(): - # In V2 the actual graph could be inside a function. - for func in graph_def.library.function: - node_name_to_op.update({node.name: node.op for node in func.node_def}) - self.assertIn("TRTEngineOp_0", node_name_to_op) - self.assertEqual("TRTEngineOp", node_name_to_op["TRTEngineOp_0"]) - else: - self.assertEqual( - { - "input": "Placeholder", - "TRTEngineOp_0": "TRTEngineOp", - "output": "Identity" - }, node_name_to_op) + self.assertEqual( + { + "input": "Placeholder", + "TRTEngineOp_0": "TRTEngineOp", + "output": "Identity" + }, node_name_to_op) if need_calibration: trt_engine_nodes = [ @@ -306,39 +312,81 @@ def testTrtGraphConverter_BasicConversion_v2(self): if not is_tensorrt_enabled(): return - # TODO(laigd): we need to use ops like conv2d so Grappler can infer the - # shapes (at least rank) of the tensors, so we're able to build an TRT - # engine in dynamic mode. Currently shape information is not propagate from - # ConcreteFunction to GraphDef, need to investigate and fix it. - class SimpleModel(tracking.AutoTrackable): - - def __init__(self): - self.v = None + np_input = np.random.random_sample([4, 1, 1]).astype(np.float32) - @def_function.function(input_signature=[ - tensor_spec.TensorSpec(shape=[None, 24, 24, 2], dtype=dtypes.float32) - ]) - def run(self, inp): - if self.v is None: - self.v = variables.Variable([[[[1., 0.5, 4., 6., 0.5, 1.], - [1., 0.5, 1., 1., 0.5, 1.]]]]) - conv = gen_nn_ops.conv2d( - input=inp, filter=self.v, strides=[1, 2, 2, 1], padding="SAME") - identity = array_ops.identity(conv) - return identity - - tmp_dir = self.get_temp_dir() - input_saved_model_dir = os.path.join(tmp_dir, "in_dir1_v2") - root = SimpleModel() + # Create a model and save it. + input_saved_model_dir = tempfile.mkdtemp(dir=self.get_temp_dir()) + root = self._GetModelForV2() + expected_output = root.run(np_input) save.save(root, input_saved_model_dir, {_SAVED_MODEL_SIGNATURE_KEY: root.run}) - # Convert the SavedModel and verify the result. - output_saved_model_dir = os.path.join(tmp_dir, "out_dir1_v2") - self._TestTrtGraphConverter( + # Run TRT conversion. + converter = trt_convert.TrtGraphConverterV2( input_saved_model_dir=input_saved_model_dir, - output_saved_model_dir=output_saved_model_dir, - is_dynamic_op=True) + input_saved_model_signature_key=_SAVED_MODEL_SIGNATURE_KEY, + conversion_params=trt_convert.DEFAULT_TRT_CONVERSION_PARAMS._replace( + precision_mode=trt_convert.TrtPrecisionMode.FP32, + is_dynamic_op=True, + maximum_cached_engines=2, + use_function_backup=False)) + converted_concrete_func = converter.convert() + + def _check_trt_ops(graph_def): + trt_op_names = [ + node.name for node in graph_def.node if node.op == "TRTEngineOp" + ] + for func in graph_def.library.function: + for node in func.node_def: + if node.op == "TRTEngineOp": + trt_op_names.append(node.name) + self.assertEqual(1, len(trt_op_names)) + self.assertIn("TRTEngineOp_0", trt_op_names[0]) + + # Verify the converted GraphDef and ConcreteFunction. 
+ self.assertIsInstance(converted_concrete_func, function.ConcreteFunction) + converted_graph_def = converted_concrete_func.graph.as_graph_def() + _check_trt_ops(converted_graph_def) + output_with_trt = converted_concrete_func(ops.convert_to_tensor(np_input)) + self.assertEqual(1, len(output_with_trt)) + self.assertAllClose( + expected_output, output_with_trt[0].numpy(), atol=1e-6, rtol=1e-6) + + # Run the converted ConcreteFunction as a function and make sure it works. + @def_function.function + def wrapper_func(*args, **kwargs): + return nest.flatten(converted_concrete_func(*args, **kwargs)) + + _check_trt_ops( + wrapper_func.get_concrete_function( + tensor_spec.TensorSpec(shape=[None, 1, 1], + dtype=dtypes.float32)).graph.as_graph_def()) + output_with_trt = wrapper_func(np_input) + self.assertEqual(1, len(output_with_trt)) + self.assertAllClose( + expected_output, output_with_trt[0], atol=1e-6, rtol=1e-6) + + # Save the converted model. + output_saved_model_dir = tempfile.mkdtemp(dir=self.get_temp_dir()) + converter.save(output_saved_model_dir) + + # Load and verify the converted model. + # + # TODO(laigd): the name of the new input_signature of the + # `root_with_trt.run` function is empty string (originally was None), + # investigate why. + root_with_trt = load.load(output_saved_model_dir) + # TODO(laigd): `root_with_trt.run` is still using the original graph without + # trt. Consider changing that. + # _check_trt_ops( + # root_with_trt.run.get_concrete_function().graph.as_graph_def()) + converted_signature = root_with_trt.signatures[_SAVED_MODEL_SIGNATURE_KEY] + _check_trt_ops(converted_signature.graph.as_graph_def()) + output_with_trt = converted_signature(ops.convert_to_tensor(np_input)) + # The output of running the converted signature is a dict due to + # compatibility reasons with V1 SavedModel signature mechanism. + output_with_trt = output_with_trt[list(output_with_trt.keys())[0]] + self.assertAllClose(expected_output, output_with_trt, atol=1e-6, rtol=1e-6) def _TestRun(self, sess, @@ -363,7 +411,7 @@ def testTrtGraphConverter_MinimumSegmentSize(self): node_name_to_op = {node.name: node.op for node in output_graph_def.node} self.assertEqual( { - "v1/read": "Const", + "add/ReadVariableOp": "Const", "input": "Placeholder", "add": "Add", "mul": "Mul",