From 7a59e639addf18875716ab3c7223d1d5d20e11de Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Mon, 29 Apr 2024 15:45:43 -0700 Subject: [PATCH 01/10] feat: Add validators for dynamic shapes in converter registration --- .../dynamo/conversion/_ConverterRegistry.py | 20 ++++++++++++++++++- .../dynamo/conversion/aten_ops_converters.py | 6 ++++-- .../dynamo/conversion/converter_utils.py | 13 ++++++++++-- 3 files changed, 34 insertions(+), 5 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py b/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py index 050a62ef3e..9967198772 100644 --- a/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py +++ b/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py @@ -79,6 +79,7 @@ class ConverterSupport: converter_implementation: ConverterImplSignature capability_validator: Callable[[Node], bool] = field(default=lambda node: True) + dynamic: bool = False # Dictionary representing Dynamo aten-only converters @@ -88,9 +89,11 @@ class ConverterSupport: def dynamo_tensorrt_converter( key: Target, + *, enabled: bool = True, capability_validator: Optional[Callable[[Node], bool]] = None, priority: ConverterPriority = ConverterPriority.STANDARD, + dynamic: bool = False, ) -> Callable[[ConverterImplSignature], ConverterImplSignature]: """Decorator for Dynamo TensorRT Converter @@ -116,7 +119,9 @@ def register_converter(converter: ConverterImplSignature) -> ConverterImplSignat # If no capability_validator function is specified, use the default function - always return true if capability_validator is None: - converter_support = ConverterSupport(converter_implementation=converter) + converter_support = ConverterSupport( + converter_implementation=converter, dynamic=dynamic + ) else: assert callable( capability_validator @@ -124,6 +129,7 @@ def register_converter(converter: ConverterImplSignature) -> ConverterImplSignat converter_support = ConverterSupport( converter_implementation=converter, capability_validator=capability_validator, + dynamic=dynamic, ) # OpOverloadPackets are only valid if they have a single overload, or @@ -323,6 +329,18 @@ def __getitem__( if isinstance(converters, (list, tuple)): for candidate in converters: + # TODO: Importing this here avoids circular import issue. One potential fix is moving this function into _ConverterRegistry file. + from torch_tensorrt.dynamo.conversion.converter_utils import ( + dynamic_unsupported, + ) + + has_static_inputs = dynamic_unsupported(node) + # If there are dynamic inputs but the converter doesn't support it explicitly, throw a warning. + if not has_static_inputs and not candidate.dynamic: + logger.warning( + f"The converter for node {node.target} received dynamic shaped inputs but the static version of the converter is being used. 
Please report this issue at https://github.com/pytorch/TensorRT/issues" + ) + if candidate.capability_validator(node): return ( candidate.converter_implementation, diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 72998e1917..e6f14c34e7 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -358,7 +358,7 @@ def aten_ops_grid( ) -@dynamo_tensorrt_converter(torch.ops.aten.relu.default) +@dynamo_tensorrt_converter(torch.ops.aten.relu.default, dynamic=True) def aten_ops_relu( ctx: ConversionContext, target: Target, @@ -2080,7 +2080,9 @@ def conv_param_validator(conv_node: Node) -> bool: @dynamo_tensorrt_converter( - torch.ops.aten.convolution.default, capability_validator=conv_param_validator + torch.ops.aten.convolution.default, + capability_validator=conv_param_validator, + dynamic=True, ) @enforce_tensor_types( { diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index 7e55110459..44a08fa8c8 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -82,9 +82,18 @@ def _dynamic_unsupported( def _is_subnode_dynamic(subnode: torch.fx.Node) -> bool: """Checks if a node itself has Dynamic properties""" - return getattr( + _has_symbolic_sizes_strides = getattr( subnode.meta["val"], "_has_symbolic_sizes_strides", False - ) or isinstance(subnode.meta["val"], (SymFloat, SymInt, SymBool)) + ) + + is_shape_dynamic = False + if "val" in subnode.meta: + shape = subnode.meta["val"].size() + is_shape_dynamic = any( + isinstance(dim, (SymFloat, SymInt, SymBool)) for dim in shape + ) + + return _has_symbolic_sizes_strides or is_shape_dynamic # Check node value itself if arg_positions_to_check is None and _is_subnode_dynamic(node): From f55d41ae02913cf76477bbb550f573aa365597b2 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Mon, 29 Apr 2024 19:41:10 -0700 Subject: [PATCH 02/10] chore: updates --- .../dynamo/conversion/_ConverterRegistry.py | 91 ++++++++++++++++--- .../dynamo/conversion/aten_ops_converters.py | 12 +-- .../dynamo/conversion/converter_utils.py | 66 -------------- 3 files changed, 83 insertions(+), 86 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py b/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py index 9967198772..9deee3d250 100644 --- a/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py +++ b/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py @@ -1,5 +1,6 @@ from __future__ import annotations +import functools import logging from dataclasses import dataclass, field from enum import Enum, auto @@ -17,6 +18,8 @@ cast, ) +import torch +from torch import SymBool, SymFloat, SymInt from torch._ops import OpOverloadPacket from torch.fx.node import Argument, Node, Target, _get_qualified_name from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext @@ -75,11 +78,12 @@ class ConverterSupport: capability_validator: Function which takes in a Node and returns a bool indicating whether that node can be supported by its companion converter. Note that this function must not modify the node or its graph + supports_dynamic_shapes: Boolean flag indicating if the converter has support for dynamic inputs. 
""" converter_implementation: ConverterImplSignature capability_validator: Callable[[Node], bool] = field(default=lambda node: True) - dynamic: bool = False + supports_dynamic_shapes: bool = False # Dictionary representing Dynamo aten-only converters @@ -87,13 +91,78 @@ class ConverterSupport: DYNAMO_ATEN_CONVERTERS: Dict[Target, Sequence[ConverterSupport]] = {} +def has_dynamic_shapes(node: torch.fx.Node) -> bool: + """Returns True if a node has dynamic args, kwargs, or outputs""" + return _has_dynamic_shapes(node=node) + + +def has_dynamic_shapes_in_args( + arg_positions_to_check: Optional[List[int]] = None, +) -> Callable[[torch.fx.Node], bool]: + """Returns True if a node has dynamic inputs in node.args at specified positions""" + return functools.partial( + _has_dynamic_shapes, arg_positions_to_check=arg_positions_to_check + ) + + +def _has_dynamic_shapes( + node: torch.fx.Node, arg_positions_to_check: Optional[List[int]] = None +) -> bool: + # Validate that none of the inputs to the node have Dynamic shapes + assert isinstance( + node, torch.fx.Node + ), "Inputs to validator functions must be FX Nodes" + + def _is_subnode_dynamic(subnode: torch.fx.Node) -> bool: + """Checks if a node itself has Dynamic properties""" + _has_symbolic_sizes_strides, is_shape_dynamic = False, False + if "val" in subnode.meta: + _has_symbolic_sizes_strides = getattr( + subnode.meta["val"], "_has_symbolic_sizes_strides", False + ) + + shape = subnode.meta["val"].size() + is_shape_dynamic = any( + isinstance(dim, (SymFloat, SymInt, SymBool)) for dim in shape + ) + + return _has_symbolic_sizes_strides or is_shape_dynamic + + # Check node value itself + if arg_positions_to_check is None and _is_subnode_dynamic(node): + return True + + # Check node arguments individually + if arg_positions_to_check is None and any( + _is_subnode_dynamic(arg) for arg in node.args if isinstance(arg, torch.fx.Node) + ): + return True + # Check specific arg positions if the caller has specified positions to check + elif arg_positions_to_check is not None and any( + _is_subnode_dynamic(node.args[i]) + for i in arg_positions_to_check + if isinstance(node.args[i], torch.fx.Node) + ): + return True + + # Check node keyword arguments individually + if arg_positions_to_check is None and any( + _is_subnode_dynamic(kwarg) + for kwarg in node.kwargs.values() + if isinstance(kwarg, torch.fx.Node) + ): + return True + + return False + + def dynamo_tensorrt_converter( key: Target, *, enabled: bool = True, capability_validator: Optional[Callable[[Node], bool]] = None, priority: ConverterPriority = ConverterPriority.STANDARD, - dynamic: bool = False, + supports_dynamic_shapes: bool = False, ) -> Callable[[ConverterImplSignature], ConverterImplSignature]: """Decorator for Dynamo TensorRT Converter @@ -120,7 +189,8 @@ def register_converter(converter: ConverterImplSignature) -> ConverterImplSignat # If no capability_validator function is specified, use the default function - always return true if capability_validator is None: converter_support = ConverterSupport( - converter_implementation=converter, dynamic=dynamic + converter_implementation=converter, + supports_dynamic_shapes=supports_dynamic_shapes, ) else: assert callable( @@ -129,7 +199,7 @@ def register_converter(converter: ConverterImplSignature) -> ConverterImplSignat converter_support = ConverterSupport( converter_implementation=converter, capability_validator=capability_validator, - dynamic=dynamic, + supports_dynamic_shapes=supports_dynamic_shapes, ) # OpOverloadPackets are only valid 
if they have a single overload, or @@ -329,16 +399,13 @@ def __getitem__( if isinstance(converters, (list, tuple)): for candidate in converters: - # TODO: Importing this here avoids circular import issue. One potential fix is moving this function into _ConverterRegistry file. - from torch_tensorrt.dynamo.conversion.converter_utils import ( - dynamic_unsupported, - ) - - has_static_inputs = dynamic_unsupported(node) # If there are dynamic inputs but the converter doesn't support it explicitly, throw a warning. - if not has_static_inputs and not candidate.dynamic: + if ( + not candidate.supports_dynamic_shapes + and has_dynamic_shapes(node) + ): logger.warning( - f"The converter for node {node.target} received dynamic shaped inputs but the static version of the converter is being used. Please report this issue at https://github.com/pytorch/TensorRT/issues" + f"The converter for node {node.target} received dynamic shaped inputs although it was designed for static inputs. This shouldn't likely cause issues unless there are some dimensions which are dynamic (excluding the batch). If you encounter any issues, please post at https://github.com/pytorch/TensorRT/issues" ) if candidate.capability_validator(node): diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index e6f14c34e7..f35d1ec444 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -12,7 +12,6 @@ dynamo_tensorrt_converter, ) from torch_tensorrt.dynamo.conversion.converter_utils import ( - dynamic_unsupported_with_args, enforce_tensor_types, is_only_operator_on_placeholder, ) @@ -358,7 +357,7 @@ def aten_ops_grid( ) -@dynamo_tensorrt_converter(torch.ops.aten.relu.default, dynamic=True) +@dynamo_tensorrt_converter(torch.ops.aten.relu.default, supports_dynamic_shapes=True) def aten_ops_relu( ctx: ConversionContext, target: Target, @@ -645,14 +644,11 @@ def aten_ops_softmax( @dynamo_tensorrt_converter( - torch.ops.aten.split.Tensor, capability_validator=dynamic_unsupported_with_args([1]) -) -@dynamo_tensorrt_converter( - torch.ops.aten.split.sizes, capability_validator=dynamic_unsupported_with_args([1]) + torch.ops.aten.split.Tensor, ) +@dynamo_tensorrt_converter(torch.ops.aten.split.sizes) @dynamo_tensorrt_converter( torch.ops.aten.split_with_sizes.default, - capability_validator=dynamic_unsupported_with_args([1]), ) def aten_ops_split( ctx: ConversionContext, @@ -2082,7 +2078,7 @@ def conv_param_validator(conv_node: Node) -> bool: @dynamo_tensorrt_converter( torch.ops.aten.convolution.default, capability_validator=conv_param_validator, - dynamic=True, + supports_dynamic_shapes=True, ) @enforce_tensor_types( { diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index 44a08fa8c8..949e047e38 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -6,7 +6,6 @@ import numpy as np import tensorrt as trt import torch -from torch import SymBool, SymFloat, SymInt from torch.fx.node import Argument, Target from torch_tensorrt import _enums from torch_tensorrt.dynamo._SourceIR import SourceIR @@ -58,71 +57,6 @@ def is_only_operator_on_placeholder(node: torch.fx.Node) -> bool: ) -def dynamic_unsupported(node: torch.fx.Node) -> bool: - """Validates that a node has no dynamic args, kwargs, or outputs""" - return 
_dynamic_unsupported(node=node) - - -def dynamic_unsupported_with_args( - arg_positions_to_check: Optional[List[int]] = None, -) -> Callable[[torch.fx.Node], bool]: - """Returns a validator that a node has no dynamic args at specific positions""" - return functools.partial( - _dynamic_unsupported, arg_positions_to_check=arg_positions_to_check - ) - - -def _dynamic_unsupported( - node: torch.fx.Node, arg_positions_to_check: Optional[List[int]] = None -) -> bool: - # Validate that none of the inputs to the node have Dynamic shapes - assert isinstance( - node, torch.fx.Node - ), "Inputs to validator functions must be FX Nodes" - - def _is_subnode_dynamic(subnode: torch.fx.Node) -> bool: - """Checks if a node itself has Dynamic properties""" - _has_symbolic_sizes_strides = getattr( - subnode.meta["val"], "_has_symbolic_sizes_strides", False - ) - - is_shape_dynamic = False - if "val" in subnode.meta: - shape = subnode.meta["val"].size() - is_shape_dynamic = any( - isinstance(dim, (SymFloat, SymInt, SymBool)) for dim in shape - ) - - return _has_symbolic_sizes_strides or is_shape_dynamic - - # Check node value itself - if arg_positions_to_check is None and _is_subnode_dynamic(node): - return False - - # Check node arguments individually - if arg_positions_to_check is None and any( - _is_subnode_dynamic(arg) for arg in node.args if isinstance(arg, torch.fx.Node) - ): - return False - # Check specific arg positions if the caller has specified positions to check - elif arg_positions_to_check is not None and any( - _is_subnode_dynamic(node.args[i]) - for i in arg_positions_to_check - if isinstance(node.args[i], torch.fx.Node) - ): - return False - - # Check node keyword arguments individually - if arg_positions_to_check is None and any( - _is_subnode_dynamic(kwarg) - for kwarg in node.kwargs.values() - if isinstance(kwarg, torch.fx.Node) - ): - return False - - return True - - def cast_trt_tensor( ctx: ConversionContext, input_val: TRTTensor, From 87da1c183c61f972f535f9549d6a42df62847c89 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Tue, 30 Apr 2024 10:43:45 -0700 Subject: [PATCH 03/10] chore: updates --- .../dynamo/conversion/_ConverterRegistry.py | 21 ++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py b/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py index 9deee3d250..81a079145a 100644 --- a/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py +++ b/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py @@ -120,11 +120,22 @@ def _is_subnode_dynamic(subnode: torch.fx.Node) -> bool: _has_symbolic_sizes_strides = getattr( subnode.meta["val"], "_has_symbolic_sizes_strides", False ) - - shape = subnode.meta["val"].size() - is_shape_dynamic = any( - isinstance(dim, (SymFloat, SymInt, SymBool)) for dim in shape - ) + meta_val = subnode.meta["val"] + if isinstance(meta_val, (list, tuple)): + for val in meta_val: + shape = val.size() + if any( + isinstance(dim, (SymFloat, SymInt, SymBool)) for dim in shape + ): + is_shape_dynamic = True + break + elif isinstance(meta_val, (SymFloat, SymInt, SymBool)): + is_shape_dynamic = True + else: + shape = subnode.meta["val"].size() + is_shape_dynamic = any( + isinstance(dim, (SymFloat, SymInt, SymBool)) for dim in shape + ) return _has_symbolic_sizes_strides or is_shape_dynamic From e3e7927de406ee38c89d36245f2dfa2185b53fed Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Wed, 1 May 2024 15:57:49 -0700 Subject: [PATCH 04/10] chore: updates --- 
.../dynamo/conversion/_ConverterRegistry.py | 23 +++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py b/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py index 4bd4706f2b..0484d04442 100644 --- a/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py +++ b/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py @@ -18,6 +18,7 @@ cast, ) +import tensorrt as trt import torch from torch import SymBool, SymFloat, SymInt from torch._ops import OpOverloadPacket @@ -25,8 +26,6 @@ from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.fx.converter_registry import CONVERTERS as FX_CONVERTERS -import tensorrt as trt - logger = logging.getLogger(__name__) LegacyConverterImplSignature = Callable[ @@ -411,22 +410,28 @@ def __getitem__( if isinstance(converters, (list, tuple)): for candidate in converters: - # If there are dynamic inputs but the converter doesn't support it explicitly, throw a warning. if ( - not candidate.supports_dynamic_shapes + candidate.capability_validator(node) and has_dynamic_shapes(node) + and candidate.supports_dynamic_shapes ): - logger.warning( - f"The converter for node {node.target} received dynamic shaped inputs although it was designed for static inputs. This shouldn't likely cause issues unless there are some dimensions which are dynamic (excluding the batch). If you encounter any issues, please post at https://github.com/pytorch/TensorRT/issues" + # If node has dynamic inputs and the converter supports dynamic shapes, it is enabled + return ( + candidate.converter_implementation, + calling_convention, ) - - if candidate.capability_validator(node): + elif candidate.capability_validator( + node + ) and not has_dynamic_shapes(node): + # For static shapes all converters are turned on based on capability_validator check return ( candidate.converter_implementation, calling_convention, ) else: - return converters, calling_convention + # Assuming FX converters don't have dynamic shapes supported + if not has_dynamic_shapes(node): + return converters, calling_convention raise KeyError( f"None of the converter registries have a validated entry for {key}, with node {node}" From 8ec68dae49635bfdc10f9c0d484fd6abfd6513a1 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Wed, 1 May 2024 19:22:04 -0700 Subject: [PATCH 05/10] chore: address failures and implement flag to enable all converters --- py/torch_tensorrt/dynamo/_compiler.py | 5 +++ py/torch_tensorrt/dynamo/_defaults.py | 1 + py/torch_tensorrt/dynamo/_settings.py | 3 ++ .../dynamo/conversion/_ConverterRegistry.py | 16 +++++--- .../dynamo/conversion/aten_ops_converters.py | 39 ++++++++++++------- 5 files changed, 44 insertions(+), 20 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 32b0ca65d7..1a2297f84c 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -47,6 +47,7 @@ def compile( *, device: Optional[Union[Device, torch.device, str]] = _defaults.DEVICE, disable_tf32: bool = _defaults.DISABLE_TF32, + disable_dynamic_converter_checks: bool = _defaults.DISABLE_DYNAMIC_CONVERTER_CHECKS, sparse_weights: bool = _defaults.SPARSE_WEIGHTS, enabled_precisions: ( Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype] @@ -189,6 +190,7 @@ def compile( ), "debug": debug, "device": device, + "disable_dynamic_converter_checks": disable_dynamic_converter_checks, "workspace_size": 
workspace_size, "min_block_size": min_block_size, "torch_executed_ops": ( @@ -239,6 +241,9 @@ def compile_module( """ dryrun_tracker = DryRunTracker() + # Disable dynamic_shapes support checks for converters + CONVERTERS.disable_dynamic_checks(settings.disable_dynamic_converter_checks) + # Set torch-executed ops CONVERTERS.set_disallowed_targets(settings.torch_executed_ops) diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index 97430137c0..a57c20dcd6 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -6,6 +6,7 @@ DEBUG = False DEVICE = None DISABLE_TF32 = False +DISABLE_DYNAMIC_CONVERTER_CHECKS = False DLA_LOCAL_DRAM_SIZE = 1073741824 DLA_GLOBAL_DRAM_SIZE = 536870912 DLA_SRAM_SIZE = 1048576 diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index 9592bc1fd5..1c29344d49 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -6,6 +6,7 @@ from torch_tensorrt._enums import EngineCapability, dtype from torch_tensorrt.dynamo._defaults import ( DEBUG, + DISABLE_DYNAMIC_CONVERTER_CHECKS, DISABLE_TF32, DLA_GLOBAL_DRAM_SIZE, DLA_LOCAL_DRAM_SIZE, @@ -57,6 +58,7 @@ class CompilationSettings: device (Device): GPU to compile the model on require_full_compilation (bool): Whether to require the graph is fully compiled in TensorRT. Only applicable for `ir="dynamo"`; has no effect for `torch.compile` path + disable_dynamic_converter_checks (bool): Setting this to true enables the converters work for both dynamic and static shapes. disable_tf32 (bool): Whether to disable TF32 computation for TRT layers sparse_weights (bool): Whether to allow the builder to use sparse weights refit (bool): Whether to build a refittable engine @@ -87,6 +89,7 @@ class CompilationSettings: device: Device = field(default_factory=default_device) require_full_compilation: bool = REQUIRE_FULL_COMPILATION disable_tf32: bool = DISABLE_TF32 + disable_dynamic_converter_checks: bool = DISABLE_DYNAMIC_CONVERTER_CHECKS sparse_weights: bool = SPARSE_WEIGHTS refit: bool = REFIT engine_capability: EngineCapability = field( diff --git a/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py b/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py index 0484d04442..7af25ae552 100644 --- a/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py +++ b/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py @@ -280,6 +280,7 @@ def __init__( ], registry_names: Optional[Sequence[str]] = None, registry_calling_conventions: Optional[Sequence[CallingConvention]] = None, + disable_dynamic_converter_checks: bool = False, ): # Copy reference to each dictionary object into attribute list self.registries = list(registries) @@ -301,9 +302,12 @@ def __init__( ] self.disallowed_targets: Collection[Target] = set() - + self.disable_dynamic_converter_checks = disable_dynamic_converter_checks self.validate_invariants() + def disable_dynamic_checks(self, disable_dynamic_converter_checks: bool) -> None: + self.disable_dynamic_converter_checks = disable_dynamic_converter_checks + def set_disallowed_targets(self, torch_executed_ops: Collection[Target]) -> None: self.disallowed_targets = torch_executed_ops @@ -410,10 +414,12 @@ def __getitem__( if isinstance(converters, (list, tuple)): for candidate in converters: - if ( - candidate.capability_validator(node) - and has_dynamic_shapes(node) - and candidate.supports_dynamic_shapes + if candidate.capability_validator(node) and ( + 
self.disable_dynamic_converter_checks + or ( + has_dynamic_shapes(node) + and candidate.supports_dynamic_shapes + ) ): # If node has dynamic inputs and the converter supports dynamic shapes, it is enabled return ( diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 5c851d75b5..c17649540d 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -58,8 +58,10 @@ def one_user_validator(node: Node) -> bool: @dynamo_tensorrt_converter( torch.ops.aten.native_batch_norm.default, capability_validator=one_user_validator ) -@dynamo_tensorrt_converter(torch.ops.aten.batch_norm.default) -@dynamo_tensorrt_converter(torch.ops.aten.batch_norm) +@dynamo_tensorrt_converter( + torch.ops.aten.batch_norm.default, supports_dynamic_shapes=True +) +@dynamo_tensorrt_converter(torch.ops.aten.batch_norm, supports_dynamic_shapes=True) @enforce_tensor_types( { 0: (TRTTensor,), @@ -93,6 +95,7 @@ def aten_ops_batch_norm( @dynamo_tensorrt_converter( torch.ops.aten._native_batch_norm_legit_no_training.default, capability_validator=one_user_validator, + supports_dynamic_shapes=True, ) def aten_ops_batch_norm_legit_no_training( ctx: ConversionContext, @@ -378,7 +381,7 @@ def aten_ops_relu( ) -@dynamo_tensorrt_converter(torch.ops.aten.sigmoid.default) +@dynamo_tensorrt_converter(torch.ops.aten.sigmoid.default, supports_dynamic_shapes=True) def aten_ops_sigmoid( ctx: ConversionContext, target: Target, @@ -400,7 +403,7 @@ def aten_ops_sigmoid( 0: (TRTTensor,), } ) -@dynamo_tensorrt_converter(torch.ops.aten.sym_size.int) +@dynamo_tensorrt_converter(torch.ops.aten.sym_size.int, supports_dynamic_shapes=True) def aten_ops_symsize_int( ctx: ConversionContext, target: Target, @@ -1116,8 +1119,8 @@ def aten_ops_min( ) -@dynamo_tensorrt_converter(torch.ops.aten.mean.default) -@dynamo_tensorrt_converter(torch.ops.aten.mean.dim) +@dynamo_tensorrt_converter(torch.ops.aten.mean.default, supports_dynamic_shapes=True) +@dynamo_tensorrt_converter(torch.ops.aten.mean.dim, supports_dynamic_shapes=True) def aten_ops_mean( ctx: ConversionContext, target: Target, @@ -1221,7 +1224,7 @@ def aten_ops_recip( ) -@dynamo_tensorrt_converter(torch.ops.aten.abs.default) +@dynamo_tensorrt_converter(torch.ops.aten.abs.default, supports_dynamic_shapes=True) def aten_ops_abs( ctx: ConversionContext, target: Target, @@ -1568,8 +1571,8 @@ def aten_ops_isnan( ) -@dynamo_tensorrt_converter(torch.ops.aten.add.Tensor) -@dynamo_tensorrt_converter(torch.ops.aten.add.Scalar) +@dynamo_tensorrt_converter(torch.ops.aten.add.Tensor, supports_dynamic_shapes=True) +@dynamo_tensorrt_converter(torch.ops.aten.add.Scalar, supports_dynamic_shapes=True) def aten_ops_add( ctx: ConversionContext, target: Target, @@ -2329,13 +2332,19 @@ def max_pool_param_validator(pool_node: Node) -> bool: # Note: MaxPool1d uses max_pool2d as it converts to 2D first. 
@dynamo_tensorrt_converter( - torch.ops.aten.max_pool1d.default, capability_validator=max_pool_param_validator + torch.ops.aten.max_pool1d.default, + capability_validator=max_pool_param_validator, + supports_dynamic_shapes=True, ) @dynamo_tensorrt_converter( - torch.ops.aten.max_pool2d.default, capability_validator=max_pool_param_validator + torch.ops.aten.max_pool2d.default, + capability_validator=max_pool_param_validator, + supports_dynamic_shapes=True, ) @dynamo_tensorrt_converter( - torch.ops.aten.max_pool3d.default, capability_validator=max_pool_param_validator + torch.ops.aten.max_pool3d.default, + capability_validator=max_pool_param_validator, + supports_dynamic_shapes=True, ) def aten_ops_max_pool( ctx: ConversionContext, @@ -2380,8 +2389,8 @@ def tensorrt_scaled_dot_product_attention( ) -@dynamo_tensorrt_converter(torch.ops.aten.reshape.default) -@dynamo_tensorrt_converter(torch.ops.aten.view.default) +@dynamo_tensorrt_converter(torch.ops.aten.reshape.default, supports_dynamic_shapes=True) +@dynamo_tensorrt_converter(torch.ops.aten.view.default, supports_dynamic_shapes=True) @enforce_tensor_types( { 0: (TRTTensor,), @@ -2490,7 +2499,7 @@ def aten_ops_argmin( ) -@dynamo_tensorrt_converter(torch.ops.aten.addmm.default) +@dynamo_tensorrt_converter(torch.ops.aten.addmm.default, supports_dynamic_shapes=True) @enforce_tensor_types( { 0: (TRTTensor,), From 151fc40b824615e8adb1b4e6907d9cfe6352e0ea Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Wed, 1 May 2024 19:27:16 -0700 Subject: [PATCH 06/10] chore: update docstring --- py/torch_tensorrt/dynamo/_compiler.py | 1 + py/torch_tensorrt/dynamo/_settings.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 1a2297f84c..d189865366 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -107,6 +107,7 @@ def compile( device=torch_tensorrt.Device("dla:1", allow_gpu_fallback=True) disable_tf32 (bool): Force FP32 layers to use traditional as FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas + disable_dynamic_converter_checks (bool): Setting this to true enables the converters work for both dynamic and static shapes. Default: False sparse_weights (bool): Enable sparsity for convolution and fully connected layers. enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels refit (bool): Enable refitting diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index 1c29344d49..f3b6d89dcd 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -58,7 +58,7 @@ class CompilationSettings: device (Device): GPU to compile the model on require_full_compilation (bool): Whether to require the graph is fully compiled in TensorRT. Only applicable for `ir="dynamo"`; has no effect for `torch.compile` path - disable_dynamic_converter_checks (bool): Setting this to true enables the converters work for both dynamic and static shapes. + disable_dynamic_converter_checks (bool): Setting this to true enables the converters work for both dynamic and static shapes. 
Default: False
         disable_tf32 (bool): Whether to disable TF32 computation for TRT layers
         sparse_weights (bool): Whether to allow the builder to use sparse weights
         refit (bool): Whether to build a refittable engine

From a2ed092e597b25bee95b12560279e6e70038a3d9 Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Wed, 1 May 2024 20:07:07 -0700
Subject: [PATCH 07/10] chore: add testcase

---
 .../partitioning/test_dynamic_partitioning.py | 109 ++++++++++++++++++
 1 file changed, 109 insertions(+)
 create mode 100644 tests/py/dynamo/partitioning/test_dynamic_partitioning.py

diff --git a/tests/py/dynamo/partitioning/test_dynamic_partitioning.py b/tests/py/dynamo/partitioning/test_dynamic_partitioning.py
new file mode 100644
index 0000000000..08283ebffb
--- /dev/null
+++ b/tests/py/dynamo/partitioning/test_dynamic_partitioning.py
@@ -0,0 +1,109 @@
+from copy import deepcopy
+
+import numpy as np
+import torch
+import torch_tensorrt
+from torch.testing._internal.common_utils import TestCase, run_tests
+from torch_tensorrt.dynamo import partitioning
+
+from ..testing_utilities import lower_graph_testing
+
+# This testcase assumes that the torch.ops.aten.clamp.default converter doesn't support
+# dynamic shapes. One should remove this testcase when the support is added.
+# This testcase tests if the graph is partitioned correctly into a TRT segment
+# and a PyTorch segment when the torch.ops.aten.clamp.default converter gets disabled
+# due to lack of dynamic shape support.
+
+
+class TestDynamicPartitioning(TestCase):
+    def test_partition_dynamic_clamp(self):
+        class Clamp(torch.nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+                self.relu = torch.nn.ReLU()
+
+            def forward(self, x):
+                x = self.relu(x)
+                return torch.ops.aten.clamp.default(x, min=2.5, max=6.5)
+
+        model = Clamp().eval().cuda()
+        trt_model = torch_tensorrt.compile(
+            model,
+            inputs=[
+                torch_tensorrt.Input(
+                    min_shape=(1, 3, 8, 8),
+                    opt_shape=(4, 3, 8, 8),
+                    max_shape=(8, 3, 8, 8),
+                    dtype=torch.float32,
+                    name="x",
+                )
+            ],
+            dryrun=True,
+            min_block_size=1,
+        )
+        trt_segments, pyt_segments = 0, 0
+        for submod in list(trt_model.named_children()):
+            if "_run_on_acc" in submod[0]:
+                trt_segments += 1
+            elif "_run_on_gpu" in submod[0]:
+                pyt_segments += 1
+
+        self.assertEquals(
+            trt_segments,
+            1,
+            f"Number of TRT segments should be 1 but got {trt_segments}",
+        )
+        self.assertEquals(
+            pyt_segments,
+            1,
+            f"Number of PyTorch segments should be 1 but got {pyt_segments}",
+        )
+
+    def test_disable_dynamic_converter_checks(self):
+        class Clamp(torch.nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+                self.relu = torch.nn.ReLU()
+
+            def forward(self, x):
+                x = self.relu(x)
+                return torch.ops.aten.clamp.default(x, min=2.5, max=6.5)
+
+        model = Clamp().eval().cuda()
+        trt_model = torch_tensorrt.compile(
+            model,
+            inputs=[
+                torch_tensorrt.Input(
+                    min_shape=(1, 3, 8, 8),
+                    opt_shape=(4, 3, 8, 8),
+                    max_shape=(8, 3, 8, 8),
+                    dtype=torch.float32,
+                    name="x",
+                )
+            ],
+            dryrun=True,
+            disable_dynamic_converter_checks=True,
+            min_block_size=1,
+        )
+
+        trt_segments, pyt_segments = 0, 0
+        for submod in list(trt_model.named_children()):
+            if "_run_on_acc" in submod[0]:
+                trt_segments += 1
+            elif "_run_on_gpu" in submod[0]:
+                pyt_segments += 1
+
+        self.assertEquals(
+            trt_segments,
+            1,
+            f"Number of TRT segments should be 1 but got {trt_segments}",
+        )
+        self.assertEquals(
+            pyt_segments,
+            0,
+            f"Number of PyTorch segments should be 0 but got {pyt_segments}",
+        )
+
+
+if __name__ == "__main__":
+    run_tests()

From c1f5d15209cef0e1bff1f4ce23b5456cc4a593af Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Thu, 2 May 2024 10:26:30 -0700
Subject: [PATCH 08/10] chore: updates

---
 py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py | 4 +++-
 py/torch_tensorrt/dynamo/conversion/ops_evaluators.py      | 6 +++++-
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index c17649540d..6b83c37eeb 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -56,7 +56,9 @@ def one_user_validator(node: Node) -> bool:
 
 
 @dynamo_tensorrt_converter(
-    torch.ops.aten.native_batch_norm.default, capability_validator=one_user_validator
+    torch.ops.aten.native_batch_norm.default,
+    capability_validator=one_user_validator,
+    supports_dynamic_shapes=True,
 )
 @dynamo_tensorrt_converter(
     torch.ops.aten.batch_norm.default, supports_dynamic_shapes=True
diff --git a/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py b/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py
index f83e0e5008..ea2f1c4d89 100644
--- a/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py
+++ b/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py
@@ -23,7 +23,11 @@ def getitem_validator(getitem_node: Node) -> bool:
 
 
 # TODO: Subsequent evaluators should be registered here with their own validators
-@dynamo_tensorrt_converter(operator.getitem, capability_validator=getitem_validator)
+@dynamo_tensorrt_converter(
+    operator.getitem,
+    capability_validator=getitem_validator,
+    supports_dynamic_shapes=True,
+)
 @dynamo_tensorrt_converter(torch.ops.aten.detach.default)
 def generic_evaluator(
     ctx: ConversionContext,

From 649b79da38b24f9981372730210ed2ee5efe9fe9 Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Thu, 2 May 2024 12:08:27 -0700
Subject: [PATCH 09/10] chore: rename disable_dynamic_converter_checks to assume_dynamic_shape_support

---
 py/torch_tensorrt/dynamo/_compiler.py         | 16 +++++++++++-----
 py/torch_tensorrt/dynamo/_defaults.py         |  2 +-
 py/torch_tensorrt/dynamo/_settings.py         |  6 +++---
 .../dynamo/conversion/_ConverterRegistry.py   | 10 +++++-----
 .../partitioning/test_dynamic_partitioning.py |  4 ++--
 5 files changed, 22 insertions(+), 16 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index d189865366..c3cca50f65 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -47,7 +47,7 @@ def compile(
     *,
     device: Optional[Union[Device, torch.device, str]] = _defaults.DEVICE,
     disable_tf32: bool = _defaults.DISABLE_TF32,
-    disable_dynamic_converter_checks: bool = _defaults.DISABLE_DYNAMIC_CONVERTER_CHECKS,
+    assume_dynamic_shape_support: bool = _defaults.ASSUME_DYNAMIC_SHAPE_SUPPORT,
     sparse_weights: bool = _defaults.SPARSE_WEIGHTS,
     enabled_precisions: (
         Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype]
@@ -107,7 +107,7 @@ def compile(
             device=torch_tensorrt.Device("dla:1", allow_gpu_fallback=True)
 
         disable_tf32 (bool): Force FP32 layers to use traditional as FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas
-        disable_dynamic_converter_checks (bool): Setting this to true enables the converters work for both dynamic and static shapes. 
Default: False + assume_dynamic_shape_support (bool): Setting this to true enables the converters work for both dynamic and static shapes. Default: False sparse_weights (bool): Enable sparsity for convolution and fully connected layers. enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels refit (bool): Enable refitting @@ -191,7 +191,7 @@ def compile( ), "debug": debug, "device": device, - "disable_dynamic_converter_checks": disable_dynamic_converter_checks, + "assume_dynamic_shape_support": assume_dynamic_shape_support, "workspace_size": workspace_size, "min_block_size": min_block_size, "torch_executed_ops": ( @@ -242,8 +242,8 @@ def compile_module( """ dryrun_tracker = DryRunTracker() - # Disable dynamic_shapes support checks for converters - CONVERTERS.disable_dynamic_checks(settings.disable_dynamic_converter_checks) + # Assume converters support dynamic shapes and disable validation + CONVERTERS.set_dynamic_shape_support(settings.assume_dynamic_shape_support) # Set torch-executed ops CONVERTERS.set_disallowed_targets(settings.torch_executed_ops) @@ -449,6 +449,7 @@ def convert_module_to_trt_engine( Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype] ) = _defaults.ENABLED_PRECISIONS, debug: bool = _defaults.DEBUG, + assume_dynamic_shape_support: bool = _defaults.ASSUME_DYNAMIC_SHAPE_SUPPORT, workspace_size: int = _defaults.WORKSPACE_SIZE, min_block_size: int = _defaults.MIN_BLOCK_SIZE, torch_executed_ops: Optional[Set[str]] = None, @@ -556,6 +557,7 @@ def convert_module_to_trt_engine( enabled_precisions = {dtype._from(e) for e in enabled_precisions} compilation_options = { + "assume_dynamic_shape_support": assume_dynamic_shape_support, "enabled_precisions": enabled_precisions, "debug": debug, "workspace_size": workspace_size, @@ -595,6 +597,10 @@ def convert_module_to_trt_engine( settings = CompilationSettings(**compilation_options) logger.info("Compilation Settings: %s\n", settings) + + # Assume converters support dynamic shapes and disable validation + CONVERTERS.set_dynamic_shape_support(settings.assume_dynamic_shape_support) + try: interpreter_result = interpret_module_to_result(gm, input_list, settings) except UnsupportedOperatorException: diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index a57c20dcd6..a621efcc16 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -6,7 +6,7 @@ DEBUG = False DEVICE = None DISABLE_TF32 = False -DISABLE_DYNAMIC_CONVERTER_CHECKS = False +ASSUME_DYNAMIC_SHAPE_SUPPORT = False DLA_LOCAL_DRAM_SIZE = 1073741824 DLA_GLOBAL_DRAM_SIZE = 536870912 DLA_SRAM_SIZE = 1048576 diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index f3b6d89dcd..e13d4b5e22 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -5,8 +5,8 @@ from torch_tensorrt._Device import Device from torch_tensorrt._enums import EngineCapability, dtype from torch_tensorrt.dynamo._defaults import ( + ASSUME_DYNAMIC_SHAPE_SUPPORT, DEBUG, - DISABLE_DYNAMIC_CONVERTER_CHECKS, DISABLE_TF32, DLA_GLOBAL_DRAM_SIZE, DLA_LOCAL_DRAM_SIZE, @@ -58,7 +58,7 @@ class CompilationSettings: device (Device): GPU to compile the model on require_full_compilation (bool): Whether to require the graph is fully compiled in TensorRT. 
Only applicable for `ir="dynamo"`; has no effect for `torch.compile` path - disable_dynamic_converter_checks (bool): Setting this to true enables the converters work for both dynamic and static shapes. Default: False + assume_dynamic_shape_support (bool): Setting this to true enables the converters work for both dynamic and static shapes. Default: False disable_tf32 (bool): Whether to disable TF32 computation for TRT layers sparse_weights (bool): Whether to allow the builder to use sparse weights refit (bool): Whether to build a refittable engine @@ -89,7 +89,7 @@ class CompilationSettings: device: Device = field(default_factory=default_device) require_full_compilation: bool = REQUIRE_FULL_COMPILATION disable_tf32: bool = DISABLE_TF32 - disable_dynamic_converter_checks: bool = DISABLE_DYNAMIC_CONVERTER_CHECKS + assume_dynamic_shape_support: bool = ASSUME_DYNAMIC_SHAPE_SUPPORT sparse_weights: bool = SPARSE_WEIGHTS refit: bool = REFIT engine_capability: EngineCapability = field( diff --git a/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py b/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py index 7af25ae552..8069b9b9c0 100644 --- a/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py +++ b/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py @@ -280,7 +280,7 @@ def __init__( ], registry_names: Optional[Sequence[str]] = None, registry_calling_conventions: Optional[Sequence[CallingConvention]] = None, - disable_dynamic_converter_checks: bool = False, + assume_dynamic_shape_support: bool = False, ): # Copy reference to each dictionary object into attribute list self.registries = list(registries) @@ -302,11 +302,11 @@ def __init__( ] self.disallowed_targets: Collection[Target] = set() - self.disable_dynamic_converter_checks = disable_dynamic_converter_checks + self.assume_dynamic_shape_support = assume_dynamic_shape_support self.validate_invariants() - def disable_dynamic_checks(self, disable_dynamic_converter_checks: bool) -> None: - self.disable_dynamic_converter_checks = disable_dynamic_converter_checks + def set_dynamic_shape_support(self, assume_dynamic_shape_support: bool) -> None: + self.assume_dynamic_shape_support = assume_dynamic_shape_support def set_disallowed_targets(self, torch_executed_ops: Collection[Target]) -> None: self.disallowed_targets = torch_executed_ops @@ -415,7 +415,7 @@ def __getitem__( if isinstance(converters, (list, tuple)): for candidate in converters: if candidate.capability_validator(node) and ( - self.disable_dynamic_converter_checks + self.assume_dynamic_shape_support or ( has_dynamic_shapes(node) and candidate.supports_dynamic_shapes diff --git a/tests/py/dynamo/partitioning/test_dynamic_partitioning.py b/tests/py/dynamo/partitioning/test_dynamic_partitioning.py index 08283ebffb..9b18c1fc2f 100644 --- a/tests/py/dynamo/partitioning/test_dynamic_partitioning.py +++ b/tests/py/dynamo/partitioning/test_dynamic_partitioning.py @@ -59,7 +59,7 @@ def forward(self, x): f"Number of PyTorch segments should be 1 but got {pyt_segments}", ) - def test_disable_dynamic_converter_checks(self): + def test_assume_dynamic_shape_support_converters(self): class Clamp(torch.nn.Module): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) @@ -82,7 +82,7 @@ def forward(self, x): ) ], dryrun=True, - disable_dynamic_converter_checks=True, + assume_dynamic_shape_support=True, min_block_size=1, ) From 382ea09d998179195313295c5b8a271bcf93efeb Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Thu, 16 May 2024 11:22:39 -0700 Subject: 
[PATCH 10/10] chore: updates --- .../dynamo/conversion/_ConverterRegistry.py | 37 ++++++++++++------- .../dynamo/conversion/aten_ops_converters.py | 8 +++- 2 files changed, 30 insertions(+), 15 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py b/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py index 8069b9b9c0..1afb9749c6 100644 --- a/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py +++ b/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py @@ -91,6 +91,11 @@ class ConverterSupport: DYNAMO_ATEN_CONVERTERS: Dict[Target, Sequence[ConverterSupport]] = {} +def has_static_shapes(node: torch.fx.Node) -> bool: + """Returns True if a node has static args, kwargs, or outputs""" + return not _has_dynamic_shapes(node=node) + + def has_dynamic_shapes(node: torch.fx.Node) -> bool: """Returns True if a node has dynamic args, kwargs, or outputs""" return _has_dynamic_shapes(node=node) @@ -105,6 +110,18 @@ def has_dynamic_shapes_in_args( ) +def has_static_shapes_in_args( + arg_positions_to_check: Optional[List[int]] = None, +) -> Callable[[torch.fx.Node], bool]: + """Returns True if a node has static inputs in node.args at specified positions""" + _has_static_shapes = lambda node, arg_positions_to_check: not _has_dynamic_shapes( + node, arg_positions_to_check + ) + return functools.partial( + _has_static_shapes, arg_positions_to_check=arg_positions_to_check + ) + + def _has_dynamic_shapes( node: torch.fx.Node, arg_positions_to_check: Optional[List[int]] = None ) -> bool: @@ -414,22 +431,16 @@ def __getitem__( if isinstance(converters, (list, tuple)): for candidate in converters: + # We enable the converter under 4 conditions + # 1) capability validator is True + # 2) Assume dynamic_shape support is True + # 3) Node only has static shaped inputs + # 4) Node has dynamic inputs and the converter has supports_dynamic_shapes=True if candidate.capability_validator(node) and ( self.assume_dynamic_shape_support - or ( - has_dynamic_shapes(node) - and candidate.supports_dynamic_shapes - ) + or not has_dynamic_shapes(node) + or candidate.supports_dynamic_shapes ): - # If node has dynamic inputs and the converter supports dynamic shapes, it is enabled - return ( - candidate.converter_implementation, - calling_convention, - ) - elif candidate.capability_validator( - node - ) and not has_dynamic_shapes(node): - # For static shapes all converters are turned on based on capability_validator check return ( candidate.converter_implementation, calling_convention, diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index b31617247a..23483986a5 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -12,6 +12,7 @@ from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( dynamo_tensorrt_converter, + has_static_shapes_in_args, ) from torch_tensorrt.dynamo.conversion.converter_utils import ( enforce_tensor_types, @@ -627,11 +628,14 @@ def aten_ops_softmax( @dynamo_tensorrt_converter( - torch.ops.aten.split.Tensor, + torch.ops.aten.split.Tensor, capability_validator=has_static_shapes_in_args([1]) +) +@dynamo_tensorrt_converter( + torch.ops.aten.split.sizes, capability_validator=has_static_shapes_in_args([1]) ) -@dynamo_tensorrt_converter(torch.ops.aten.split.sizes) @dynamo_tensorrt_converter( 
torch.ops.aten.split_with_sizes.default, + capability_validator=has_static_shapes_in_args([1]), ) def aten_ops_split( ctx: ConversionContext,
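Usage sketch (illustrative, not part of the patch series above): after these changes, a converter is selected for a node only if its capability_validator passes and either the node's inputs are statically shaped, the converter was registered with supports_dynamic_shapes=True, or the global assume_dynamic_shape_support override is enabled. A minimal compile call exercising the override, modeled on the new test case; the module here is a hypothetical stand-in for the test's Clamp model:

import torch
import torch_tensorrt

model = MyModule().eval().cuda()  # hypothetical user module

trt_model = torch_tensorrt.compile(
    model,
    inputs=[
        torch_tensorrt.Input(
            min_shape=(1, 3, 8, 8),  # smallest shape expected at runtime
            opt_shape=(4, 3, 8, 8),  # shape TensorRT optimizes for
            max_shape=(8, 3, 8, 8),  # largest shape expected at runtime
            dtype=torch.float32,
        )
    ],
    # Treat every converter as dynamic-shape capable, bypassing the per-converter
    # supports_dynamic_shapes check introduced in this series (defaults to False).
    assume_dynamic_shape_support=True,
    min_block_size=1,
)

Without the override, a node with dynamic inputs whose converter was not registered with supports_dynamic_shapes=True is left out of the TensorRT engine and falls back to a PyTorch segment, which is what the first test case asserts.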
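And a hedged sketch of how a converter registration opts into dynamic shapes using the decorator arguments and helpers defined in the diffs above (dynamo_tensorrt_converter, supports_dynamic_shapes, capability_validator, has_static_shapes_in_args); the target ops are for illustration only, and the converter bodies and full signatures are abbreviated placeholders:

import torch
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
    dynamo_tensorrt_converter,
    has_static_shapes_in_args,
)

# Registered as safe for both static and dynamic input shapes.
@dynamo_tensorrt_converter(torch.ops.aten.relu.default, supports_dynamic_shapes=True)
def example_relu_converter(ctx, target, args, kwargs, name):
    ...  # conversion logic elided

# Enabled only when argument 1 (the split sizes) is statically shaped;
# otherwise the node is left for a PyTorch segment.
@dynamo_tensorrt_converter(
    torch.ops.aten.split.Tensor,
    capability_validator=has_static_shapes_in_args([1]),
)
def example_split_converter(ctx, target, args, kwargs, name):
    ...  # conversion logic elided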