Skip to content

Commit

Permalink
[TF-TRT] Various Cleanups & Python Debugging Assertion Improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
DEKHTIARJonathan committed Aug 12, 2022
1 parent cd17810 commit a4d15ef
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 24 deletions.
Expand Up @@ -51,7 +51,6 @@
from tensorflow.python.saved_model import utils
from tensorflow.python.tools import saved_model_utils
from tensorflow.python.trackable import autotrackable
from tensorflow.python.trackable import resource
from tensorflow.python.util import nest

logging.get_logger().propagate = False
Expand Down Expand Up @@ -169,7 +168,10 @@ def tearDown(self):

def _GetTensorSpec(self, shape, mask, dtype, name):
# Set dimension i to None if mask[i] == False
assert len(shape) == len(mask)
assert len(shape) == len(mask), (
f"len(shape): {len(shape)} == len(mask): {len(mask)}"
)

new_shape = [s if m else None for s, m in zip(shape, mask)]
return tensor_spec.TensorSpec(new_shape, dtype, name)

Expand All @@ -178,7 +180,7 @@ def BuildParams(self, graph_fn, dtype, input_shapes, output_shapes):
The input_shapes and output_shapes arguments are known (static) shapes that
can be used to generate test data. To define the model, we also specify
corresponding input/output TensoSpecs. These are defined using the shape
corresponding input/output TensorSpecs. These are defined using the shape
arguments. For each input tensor we define:
input_spec = [None] + input_shape[1:]
Expand Down Expand Up @@ -234,16 +236,28 @@ def BuildParamsWithMask(self, graph_fn, dtype, input_shapes, output_shapes,
def _ValidateShapes(shapes):
# Make sure all the shapes are fully specified.
for shape in shapes:
assert all(shape)
assert all(shape), f"Shape unspecified: {shape}"

_ValidateShapes(input_shapes)
_ValidateShapes(output_shapes)

assert len(input_mask) == len(input_shapes)
assert len(output_mask) == len(output_shapes)
assert len(input_mask) == len(input_shapes), (
f"Inconsistent input_mask and input_shapes: len({input_mask}) != "
f"len({input_shapes})."
)
assert len(output_mask) == len(output_shapes), (
f"Inconsistent output_mask and output_shapes: len({output_mask}) != "
f"len({output_shapes})."
)
for extra_in_shape, extra_out_shape in zip(extra_inputs, extra_outputs):
assert len(input_shapes) == len(extra_in_shape)
assert len(output_shapes) == len(extra_out_shape)
assert len(input_shapes) == len(extra_in_shape), (
f"Inconsistent input_shapes and extra_in_shape: len({input_shapes}) "
f"!= len({extra_in_shape})."
)
assert len(output_shapes) == len(extra_out_shape), (
f"Inconsistent output_shapes and extra_out_shape: "
f"len({output_shapes}) != len({extra_out_shape})."
)

return TfTrtIntegrationTestParams(
graph_fn=graph_fn,
Expand Down Expand Up @@ -284,10 +298,13 @@ def GetMaxBatchSize(self, run_params):
return None
batch_list = []
for dims_list in self._GetParamsCached().input_dims:
assert dims_list
assert dims_list, f"Expect non-empty `dim_list` but got: {dims_list}"
# Each list of shapes should have same batch size.
input_batches = [dims[0] for dims in dims_list]
assert max(input_batches) == min(input_batches)
assert max(input_batches) == min(input_batches), (
f"Inconsistent batch_size: max({input_batches}) != "
f"min({input_batches})."
)
batch_list.append(input_batches[0])
return max(batch_list)

Expand Down Expand Up @@ -403,7 +420,9 @@ def _RunGraphV2(self, saved_model_dir, inputs_data, graph_state, num_runs=2):
for i in range(len(params.input_specs))
}
new_val = func(**feed_dict)
assert isinstance(new_val, dict)
assert isinstance(new_val, dict), (
f"Invalid type for `new_val`, expected `dict`. Got: {type(new_val)}."
)
# The key of the output map is always like output_i.
new_val = [new_val[key] for key in sorted(new_val)]
# Each element is an eager Tensor, and accessing individual elements is
Expand All @@ -429,7 +448,10 @@ def _RunGraph(self,
num_runs=2):
params = self._GetParamsCached()
for data in inputs_data:
assert len(params.input_specs) == len(data)
assert len(params.input_specs) == len(data), (
f"Inconsistent params.input_specs and data: "
f"len({params.input_specs}) != len({data})."
)

if run_params.is_v2:
results = self._RunGraphV2(saved_model_dir, inputs_data, graph_state,
Expand Down Expand Up @@ -475,14 +497,21 @@ def _GetCalibratedInferGraph(self, run_params, saved_model_dir, inputs_data):
"""Return trt converted graphdef in INT8 mode."""
conversion_params = self.GetConversionParams(run_params)
logging.info(conversion_params)
assert conversion_params.precision_mode == "INT8"
assert run_params.dynamic_engine
assert conversion_params.maximum_cached_engines == 1
assert conversion_params.use_calibration
assert conversion_params.precision_mode == "INT8", (
f"Incorrect precision mode, expected INT8 but got: "
f"{conversion_params.precision_mode}."
)
assert run_params.dynamic_engine, "dynamic_engine parameter must be True."
assert conversion_params.maximum_cached_engines == 1, (
f"maximum_cached_engines: {conversion_params.maximum_cached_engines} == 1"
)
assert conversion_params.use_calibration, "use_calibration must be True."

# We only support calibrating single engine.
# TODO(aaroey): fix this.
assert len(inputs_data) == 1
assert len(inputs_data) == 1, (
f"len(inputs_data): {len(inputs_data)} == 1"
)

converter = self._CreateConverter(run_params, saved_model_dir,
conversion_params)
Expand Down Expand Up @@ -598,10 +627,15 @@ def _RemoveGraphSequenceNumberImpl(self, value, expecting_prefix):
if isinstance(value, str):
match = re.search(r"TRTEngineOp_\d{3,}_", value)
has_prefix = match and value.startswith(match.group(0))
assert (not expecting_prefix) or has_prefix
assert (not expecting_prefix) or has_prefix, (
f"Expect (not expecting_prefix) or has_prefix but got: "
f"- expecting_prefix = {expecting_prefix}\n- has_prefix = {has_prefix}"
)
if has_prefix:
parts = value.split("_", maxsplit=2)
assert len(parts) == 3
assert len(parts) == 3, (
f"Incorrect `parts` of length == 3, but got: len({parts})."
)
return parts[0] + "_" + parts[2]
return value
elif isinstance(value, collections.abc.Iterable):
Expand Down Expand Up @@ -812,7 +846,10 @@ def _GetGraphDef(self, run_params, gdef_or_saved_model_dir):
return gdef
return saved_model_utils.get_meta_graph_def(
gdef_or_saved_model_dir, tag_constants.SERVING).graph_def
assert isinstance(gdef_or_saved_model_dir, graph_pb2.GraphDef)
assert isinstance(gdef_or_saved_model_dir, graph_pb2.GraphDef), (
f"Incorrect `gdef_or_saved_model_dir` type, expected "
f"`graph_pb2.GraphDef`, but got: {type(gdef_or_saved_model_dir)}."
)
return gdef_or_saved_model_dir

def _VerifyGraphDefV1(self, run_params, original_gdef, gdef_to_verify,
Expand Down Expand Up @@ -995,7 +1032,10 @@ def RunTest(self, run_params):
inputs_data = []
input_specs = self._GetParamsCached().input_specs
for dim_list in self._GetParamsCached().input_dims:
assert len(input_specs) == len(dim_list)
assert len(input_specs) == len(dim_list), (
f"Inconsistent input_specs and dim_list: len({input_specs}) != "
f"len({dim_list})."
)
current_input_data = []
for spec, np_shape in zip(input_specs, dim_list):
np_dtype = spec.dtype.as_numpy_dtype()
Expand Down
6 changes: 3 additions & 3 deletions tensorflow/python/compiler/tensorrt/utils.py
Expand Up @@ -18,7 +18,7 @@
import os
import re

from distutils import version
from packaging import version

from tensorflow.compiler.tf2tensorrt import _pywrap_py_utils
from tensorflow.core.protobuf import rewriter_config_pb2
Expand Down Expand Up @@ -60,8 +60,8 @@ def version_tuple_to_string(ver_tuple):


def _is_tensorrt_version_greater_equal(trt_ver, target_ver):
trt_ver = version.LooseVersion(version_tuple_to_string(trt_ver))
target_ver = version.LooseVersion(version_tuple_to_string(target_ver))
trt_ver = version.Version(version_tuple_to_string(trt_ver))
target_ver = version.Version(version_tuple_to_string(target_ver))

return trt_ver >= target_ver

Expand Down

0 comments on commit a4d15ef

Please sign in to comment.