feat: Add validators for dynamic shapes in converter registration #2796

Merged · 13 commits · May 16, 2024
20 changes: 19 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py
@@ -79,6 +79,7 @@ class ConverterSupport:

    converter_implementation: ConverterImplSignature
    capability_validator: Callable[[Node], bool] = field(default=lambda node: True)
+    dynamic: bool = False


# Dictionary representing Dynamo aten-only converters
@@ -88,9 +89,11 @@ class ConverterSupport:

def dynamo_tensorrt_converter(
    key: Target,
+    *,
    enabled: bool = True,
    capability_validator: Optional[Callable[[Node], bool]] = None,
    priority: ConverterPriority = ConverterPriority.STANDARD,
+    dynamic: bool = False,
) -> Callable[[ConverterImplSignature], ConverterImplSignature]:
    """Decorator for Dynamo TensorRT Converter

@@ -116,14 +119,17 @@ def register_converter(converter: ConverterImplSignature) -> ConverterImplSignature:

        # If no capability_validator function is specified, use the default function - always return true
        if capability_validator is None:
-            converter_support = ConverterSupport(converter_implementation=converter)
+            converter_support = ConverterSupport(
+                converter_implementation=converter, dynamic=dynamic
+            )
        else:
            assert callable(
                capability_validator
            ), "Argument checking function must be callable"
            converter_support = ConverterSupport(
                converter_implementation=converter,
                capability_validator=capability_validator,
+                dynamic=dynamic,
            )

        # OpOverloadPackets are only valid if they have a single overload, or
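
To make the new field concrete, here is a minimal sketch of the ConverterSupport record the decorator now builds; `my_impl` is a hypothetical converter body, not something added in this diff:

import torch
from torch_tensorrt.dynamo.conversion._ConverterRegistry import ConverterSupport

def my_impl(ctx, target, args, kwargs, name):
    ...  # placeholder converter implementation, for illustration only

support = ConverterSupport(
    converter_implementation=my_impl,
    capability_validator=lambda node: True,  # the default: accept every node
    dynamic=True,  # new field from this PR; defaults to False
)
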
@@ -323,6 +329,18 @@ def __getitem__(

        if isinstance(converters, (list, tuple)):
            for candidate in converters:
+                # TODO: Importing this here avoids circular import issue. One potential fix is moving this function into _ConverterRegistry file.
+                from torch_tensorrt.dynamo.conversion.converter_utils import (
+                    dynamic_unsupported,
+                )
+
+                has_static_inputs = dynamic_unsupported(node)
+                # If there are dynamic inputs but the converter doesn't support it explicitly, throw a warning.
+                if not has_static_inputs and not candidate.dynamic:
+                    logger.warning(
+                        f"The converter for node {node.target} received dynamic shaped inputs but the static version of the converter is being used. Please report this issue at https://github.com/pytorch/TensorRT/issues"
+                    )
+
                if candidate.capability_validator(node):
                    return (
                        candidate.converter_implementation,
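The lookup-time check above, pulled out as a standalone sketch for clarity (the function name and simplified message are illustrative; the real logic lives inside ConverterRegistry.__getitem__):

import logging

from torch.fx.node import Node

logger = logging.getLogger(__name__)

def warn_if_static_converter(node: Node, candidate) -> None:
    # dynamic_unsupported(node) returns True when every input to the node is static.
    from torch_tensorrt.dynamo.conversion.converter_utils import dynamic_unsupported

    if not dynamic_unsupported(node) and not candidate.dynamic:
        logger.warning(
            f"The converter for node {node.target} received dynamic shaped inputs "
            "but the static version of the converter is being used."
        )
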
6 changes: 4 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -358,7 +358,7 @@ def aten_ops_grid(
)


-@dynamo_tensorrt_converter(torch.ops.aten.relu.default)
+@dynamo_tensorrt_converter(torch.ops.aten.relu.default, dynamic=True)
def aten_ops_relu(
    ctx: ConversionContext,
    target: Target,
@@ -2080,7 +2080,9 @@ def conv_param_validator(conv_node: Node) -> bool:


@dynamo_tensorrt_converter(
-    torch.ops.aten.convolution.default, capability_validator=conv_param_validator
+    torch.ops.aten.convolution.default,
+    capability_validator=conv_param_validator,
+    dynamic=True,
)
@enforce_tensor_types(
    {
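Since relu and convolution are now flagged dynamic=True, a model compiled with a dynamic batch dimension should take the dynamic converter path without triggering the new warning. A minimal sketch, assuming the torch_tensorrt.Input dynamic-shape API (the module and shapes here are illustrative):

import torch
import torch_tensorrt

class ReLUModule(torch.nn.Module):
    def forward(self, x):
        return torch.ops.aten.relu.default(x)

# Dynamic batch dimension: 1 up to 16, optimized for 8.
inputs = [
    torch_tensorrt.Input(
        min_shape=(1, 3, 224, 224),
        opt_shape=(8, 3, 224, 224),
        max_shape=(16, 3, 224, 224),
    )
]
trt_mod = torch_tensorrt.compile(ReLUModule(), ir="dynamo", inputs=inputs)
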
13 changes: 11 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -82,9 +82,18 @@ def _dynamic_unsupported(

    def _is_subnode_dynamic(subnode: torch.fx.Node) -> bool:
        """Checks if a node itself has Dynamic properties"""
-        return getattr(
+        _has_symbolic_sizes_strides = getattr(
            subnode.meta["val"], "_has_symbolic_sizes_strides", False
-        ) or isinstance(subnode.meta["val"], (SymFloat, SymInt, SymBool))
+        )
+
+        is_shape_dynamic = False
+        if "val" in subnode.meta:
+            shape = subnode.meta["val"].size()
+            is_shape_dynamic = any(
+                isinstance(dim, (SymFloat, SymInt, SymBool)) for dim in shape
+            )
+
+        return _has_symbolic_sizes_strides or is_shape_dynamic

    # Check node value itself
    if arg_positions_to_check is None and _is_subnode_dynamic(node):
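The second check added above catches values whose size() contains symbolic dimensions even when _has_symbolic_sizes_strides is unset. A standalone sketch of the same predicate (the function name is illustrative; real inputs would be fake tensors produced by Dynamo tracing):

import torch
from torch import SymBool, SymFloat, SymInt

def shape_is_dynamic(t: torch.Tensor) -> bool:
    # Mirrors the new check: any symbolic entry in size() marks the value dynamic.
    return any(isinstance(dim, (SymFloat, SymInt, SymBool)) for dim in t.size())

# A concretely shaped tensor has only int dims, so this is False; under
# torch.export/Dynamo fake tensors, dynamic dims appear as SymInt.
assert shape_is_dynamic(torch.randn(2, 3)) is False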