Introduce TrtGraphConverterV2 for TF-TRT conversion in V2, and enhance the V2 unit test.

PiperOrigin-RevId: 244267523
aaroey authored and tensorflower-gardener committed Apr 18, 2019
1 parent fb49f67 commit 9607281
Showing 7 changed files with 545 additions and 415 deletions.
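
For orientation, here is a minimal sketch of driving the new V2 converter on a SavedModel. The directory paths are placeholders, the TrtConversionParams fields are the ones visible in this diff, and the convert/save flow follows the converter's later public API, so details may differ in this exact revision:

    from tensorflow.python.compiler.tensorrt import trt_convert

    # Placeholder directories for illustration only.
    SAVED_MODEL_DIR = "/tmp/my_saved_model"
    TRT_SAVED_MODEL_DIR = "/tmp/my_trt_saved_model"

    # Field names match the TrtConversionParams constructor call in this diff.
    params = trt_convert.TrtConversionParams(
        rewriter_config_template=None,
        max_workspace_size_bytes=1 << 30,
        precision_mode="FP16",
        minimum_segment_size=3,
        is_dynamic_op=True,  # The V2 converter builds engines at runtime.
        maximum_cached_engines=1,
        use_calibration=False,
        use_function_backup=True,
        max_batch_size=8,
        cached_engine_batches=None)

    converter = trt_convert.TrtGraphConverterV2(
        input_saved_model_dir=SAVED_MODEL_DIR,
        conversion_params=params)
    converter.convert()
    converter.save(TRT_SAVED_MODEL_DIR)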
2 changes: 1 addition & 1 deletion tensorflow/python/compiler/tensorrt/test/base_test.py
@@ -144,7 +144,7 @@ def GetConversionParams(self, run_params):
     ).GetConversionParams(run_params)._replace(
         # Disable layout optimizer, since it'll add Transpose(Const, Const) to
         # the graph and break the conversion check.
-        rewriter_config=trt_test.OptimizerDisabledRewriterConfig())
+        rewriter_config_template=trt_test.OptimizerDisabledRewriterConfig())


class SimpleMultiEnginesTest2(trt_test.TfTrtIntegrationTestBase):
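The one-line change above, repeated in the files that follow, renames the rewriter_config conversion parameter to rewriter_config_template. As a sketch of what such a template contains, disabling the layout optimizer through the RewriterConfig proto looks roughly like this; the tests get theirs from trt_test.OptimizerDisabledRewriterConfig(), whose exact contents are not shown in this diff:

    from tensorflow.core.protobuf import rewriter_config_pb2

    # A rewriter template with the layout optimizer turned off, so Grappler does
    # not insert Transpose(Const, Const) nodes or rewrite NHWC BiasAdd to NCHW.
    rewriter_config = rewriter_config_pb2.RewriterConfig(
        layout_optimizer=rewriter_config_pb2.RewriterConfig.OFF)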
@@ -124,7 +124,7 @@ def GetConversionParams(self, run_params):
         maximum_cached_engines=1,
         # Disable layout optimizer, since it will convert BiasAdd with NHWC
         # format to NCHW format under four-dimensional input.
-        rewriter_config=trt_test.OptimizerDisabledRewriterConfig())
+        rewriter_config_template=trt_test.OptimizerDisabledRewriterConfig())

   def ExpectedEnginesToBuild(self, run_params):
     """Return the expected engines to build."""
@@ -85,7 +85,7 @@ def GetConversionParams(self, run_params):
         maximum_cached_engines=10,
         # Disable layout optimizer, since it will convert BiasAdd with NHWC
         # format to NCHW format under four-dimensional input.
-        rewriter_config=trt_test.OptimizerDisabledRewriterConfig())
+        rewriter_config_template=trt_test.OptimizerDisabledRewriterConfig())

   def ExpectedEnginesToBuild(self, run_params):
     return ["TRTEngineOp_0"]
2 changes: 1 addition & 1 deletion tensorflow/python/compiler/tensorrt/test/int32_test.py
@@ -65,7 +65,7 @@ def GetConversionParams(self, run_params):
         maximum_cached_engines=1,
         # Disable layout optimizer, since it will convert BiasAdd with NHWC
         # format to NCHW format under four-dimensional input.
-        rewriter_config=trt_test.OptimizerDisabledRewriterConfig())
+        rewriter_config_template=trt_test.OptimizerDisabledRewriterConfig())

   def ExpectedEnginesToBuild(self, run_params):
     """Return the expected engines to build."""
@@ -56,12 +56,6 @@
     "use_calibration"
 ])

-ConversionParams = namedtuple("ConversionParams", [
-    "max_batch_size", "max_workspace_size_bytes", "precision_mode",
-    "minimum_segment_size", "is_dynamic_op", "maximum_cached_engines",
-    "cached_engine_batches", "rewriter_config", "use_calibration"
-])

PRECISION_MODES = ["FP32", "FP16", "INT8"]
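
The ConversionParams namedtuple deleted above is superseded by trt_convert.TrtConversionParams. Reconstructed from the constructor call later in this diff, the new tuple carries roughly these fields; the exact field order inside trt_convert is an assumption:

    import collections

    # rewriter_config has become rewriter_config_template, and
    # use_function_backup is new; the remaining fields carry over from the
    # deleted namedtuple.
    TrtConversionParams = collections.namedtuple("TrtConversionParams", [
        "rewriter_config_template", "max_workspace_size_bytes", "precision_mode",
        "minimum_segment_size", "is_dynamic_op", "maximum_cached_engines",
        "use_calibration", "use_function_backup", "max_batch_size",
        "cached_engine_batches"
    ])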


@@ -163,27 +157,30 @@ def GetParams(self):
     raise NotImplementedError()

   def GetConversionParams(self, run_params):
-    """Return a ConversionParams for test."""
+    """Return a TrtConversionParams for test."""
     batch_list = []
     for dims_list in self._GetParamsCached().input_dims:
       assert dims_list
       # Each list of shapes should have the same batch size.
       input_batches = [dims[0] for dims in dims_list]
       assert max(input_batches) == min(input_batches)
       batch_list.append(input_batches[0])
-    return ConversionParams(
+    conversion_params = trt_convert.TrtConversionParams(
         # We use the minimum of all the batch sizes, so when multiple different
         # input shapes are provided it'll always create new engines in the
         # cache, and we can therefore test the cache behavior.
-        max_batch_size=min(batch_list),
+        rewriter_config_template=None,
         max_workspace_size_bytes=1 << 25,
         precision_mode=run_params.precision_mode,
         minimum_segment_size=2,
         is_dynamic_op=run_params.dynamic_engine,
         maximum_cached_engines=1,
-        cached_engine_batches=None,
-        rewriter_config=None,
-        use_calibration=run_params.use_calibration)
+        use_calibration=run_params.use_calibration,
+        use_function_backup=False,
+        max_batch_size=min(batch_list),
+        cached_engine_batches=None)
+    return conversion_params._replace(
+        use_function_backup=IsQuantizationWithCalibration(conversion_params))

   def ShouldRunTest(self, run_params):
     """Whether to run the test."""
@@ -218,24 +215,13 @@ def _GetConfigProto(self, run_params, graph_state):
     """Get config proto based on specific settings."""
     conversion_params = self.GetConversionParams(run_params)
     if graph_state == GraphState.INFERENCE and run_params.use_optimizer:
-      rewriter_cfg = trt_convert.TrtGraphConverter.get_tensorrt_rewriter_config(
-          conversion_params.rewriter_config,
-          conversion_params.max_batch_size,
-          conversion_params.max_workspace_size_bytes,
-          conversion_params.precision_mode,
-          conversion_params.minimum_segment_size,
-          conversion_params.is_dynamic_op,
-          conversion_params.maximum_cached_engines,
-          conversion_params.cached_engine_batches,
-          conversion_params.use_calibration,
-          use_function_backup=IsQuantizationWithCalibration(conversion_params))
-
+      rewriter_cfg = trt_convert.get_tensorrt_rewriter_config(conversion_params)
       graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_cfg)
     else:
       graph_options = config_pb2.GraphOptions()
-      if conversion_params.rewriter_config is not None:
+      if conversion_params.rewriter_config_template is not None:
         graph_options.rewrite_options.CopyFrom(
-            conversion_params.rewriter_config)
+            conversion_params.rewriter_config_template)

     config = config_pb2.ConfigProto(
         gpu_options=self._GetGPUOptions(), graph_options=graph_options)
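
In the use_optimizer branch above, the rewriter config now comes from the module-level trt_convert.get_tensorrt_rewriter_config(conversion_params) rather than the old static method with nine separate arguments. A sketch of consuming the resulting proto in an ordinary graph-mode session; DEFAULT_TRT_CONVERSION_PARAMS is assumed to be exported by trt_convert alongside TrtConversionParams, and a TensorRT-enabled TensorFlow build is required:

    import tensorflow.compat.v1 as tf
    from tensorflow.core.protobuf import config_pb2
    from tensorflow.python.compiler.tensorrt import trt_convert

    # Embed the TRT rewriter in a ConfigProto so Grappler can invoke the
    # TensorRT optimizer on qualifying segments when the graph is pruned.
    rewriter_cfg = trt_convert.get_tensorrt_rewriter_config(
        trt_convert.DEFAULT_TRT_CONVERSION_PARAMS)
    config = config_pb2.ConfigProto(
        graph_options=config_pb2.GraphOptions(rewrite_options=rewriter_cfg))

    with tf.Graph().as_default():
        x = tf.constant([1.0, 2.0])
        y = x * 2.0
        with tf.Session(config=config) as sess:
            print(sess.run(y))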
@@ -310,7 +296,7 @@ def _CreateConverter(self, gdef, session_config, conversion_params):
         maximum_cached_engines=conversion_params.maximum_cached_engines,
         cached_engine_batches=conversion_params.cached_engine_batches,
         use_calibration=conversion_params.use_calibration,
-        use_function_backup=IsQuantizationWithCalibration(conversion_params))
+        use_function_backup=conversion_params.use_function_backup)
     return converter

   def _GetCalibratedInferGraph(self, run_params, gdef, inputs_data):
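For comparison with the V2 path, _CreateConverter above feeds the namedtuple into the frozen-graph V1 converter field by field. A hedged sketch of a standalone V1 conversion; only the last four keyword names appear in this hunk, and the rest (input_graph_def, nodes_blacklist, and friends) follow the converter's public signature and may differ in this revision:

    import tensorflow.compat.v1 as tf
    from tensorflow.python.compiler.tensorrt import trt_convert

    # Build a tiny frozen graph to convert; requires a TensorRT-enabled build.
    with tf.Graph().as_default() as g:
        x = tf.placeholder(tf.float32, shape=[1, 4], name="input")
        y = tf.identity(tf.nn.relu(x), name="output")
    graph_def = g.as_graph_def()

    converter = trt_convert.TrtGraphConverter(
        input_graph_def=graph_def,
        nodes_blacklist=["output"],  # Keep the output node out of TRT segments.
        max_batch_size=1,
        max_workspace_size_bytes=1 << 25,
        precision_mode="FP32",
        minimum_segment_size=2,
        is_dynamic_op=False,
        maximum_cached_engines=1,
        cached_engine_batches=None,
        use_calibration=False,
        use_function_backup=False)
    trt_graph_def = converter.convert()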
