diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index 32b0ca65d7..c3cca50f65 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -47,6 +47,7 @@ def compile(
     *,
     device: Optional[Union[Device, torch.device, str]] = _defaults.DEVICE,
     disable_tf32: bool = _defaults.DISABLE_TF32,
+    assume_dynamic_shape_support: bool = _defaults.ASSUME_DYNAMIC_SHAPE_SUPPORT,
     sparse_weights: bool = _defaults.SPARSE_WEIGHTS,
     enabled_precisions: (
         Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype]
@@ -106,6 +107,7 @@ def compile(
                 device=torch_tensorrt.Device("dla:1", allow_gpu_fallback=True)
 
         disable_tf32 (bool): Force FP32 layers to use traditional as FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas
+        assume_dynamic_shape_support (bool): Setting this to true enables the converters to work for both dynamic and static shapes. Default: False
         sparse_weights (bool): Enable sparsity for convolution and fully connected layers.
         enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
         refit (bool): Enable refitting
@@ -189,6 +191,7 @@ def compile(
         ),
         "debug": debug,
         "device": device,
+        "assume_dynamic_shape_support": assume_dynamic_shape_support,
         "workspace_size": workspace_size,
         "min_block_size": min_block_size,
         "torch_executed_ops": (
@@ -239,6 +242,9 @@ def compile_module(
     """
 
     dryrun_tracker = DryRunTracker()
 
+    # Assume converters support dynamic shapes and disable validation
+    CONVERTERS.set_dynamic_shape_support(settings.assume_dynamic_shape_support)
+
     # Set torch-executed ops
     CONVERTERS.set_disallowed_targets(settings.torch_executed_ops)
@@ -443,6 +449,7 @@ def convert_module_to_trt_engine(
         Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype]
     ) = _defaults.ENABLED_PRECISIONS,
     debug: bool = _defaults.DEBUG,
+    assume_dynamic_shape_support: bool = _defaults.ASSUME_DYNAMIC_SHAPE_SUPPORT,
     workspace_size: int = _defaults.WORKSPACE_SIZE,
     min_block_size: int = _defaults.MIN_BLOCK_SIZE,
     torch_executed_ops: Optional[Set[str]] = None,
@@ -550,6 +557,7 @@ def convert_module_to_trt_engine(
     enabled_precisions = {dtype._from(e) for e in enabled_precisions}
 
     compilation_options = {
+        "assume_dynamic_shape_support": assume_dynamic_shape_support,
         "enabled_precisions": enabled_precisions,
         "debug": debug,
         "workspace_size": workspace_size,
@@ -589,6 +597,10 @@ def convert_module_to_trt_engine(
     settings = CompilationSettings(**compilation_options)
     logger.info("Compilation Settings: %s\n", settings)
+
+    # Assume converters support dynamic shapes and disable validation
+    CONVERTERS.set_dynamic_shape_support(settings.assume_dynamic_shape_support)
+
     try:
         interpreter_result = interpret_module_to_result(gm, input_list, settings)
     except UnsupportedOperatorException:
diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py
index 97430137c0..a621efcc16 100644
--- a/py/torch_tensorrt/dynamo/_defaults.py
+++ b/py/torch_tensorrt/dynamo/_defaults.py
@@ -6,6 +6,7 @@
 DEBUG = False
 DEVICE = None
 DISABLE_TF32 = False
+ASSUME_DYNAMIC_SHAPE_SUPPORT = False
 DLA_LOCAL_DRAM_SIZE = 1073741824
 DLA_GLOBAL_DRAM_SIZE = 536870912
 DLA_SRAM_SIZE = 1048576
diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py
index 9592bc1fd5..e13d4b5e22 100644
--- a/py/torch_tensorrt/dynamo/_settings.py
+++ b/py/torch_tensorrt/dynamo/_settings.py
@@ -5,6 +5,7 @@
 from torch_tensorrt._Device import Device
 from torch_tensorrt._enums import EngineCapability, dtype
 from torch_tensorrt.dynamo._defaults import (
+    ASSUME_DYNAMIC_SHAPE_SUPPORT,
     DEBUG,
     DISABLE_TF32,
     DLA_GLOBAL_DRAM_SIZE,
@@ -57,6 +58,7 @@ class CompilationSettings:
         device (Device): GPU to compile the model on
         require_full_compilation (bool): Whether to require the graph is fully compiled in TensorRT.
             Only applicable for `ir="dynamo"`; has no effect for `torch.compile` path
+        assume_dynamic_shape_support (bool): Setting this to true enables the converters to work for both dynamic and static shapes. Default: False
         disable_tf32 (bool): Whether to disable TF32 computation for TRT layers
         sparse_weights (bool): Whether to allow the builder to use sparse weights
         refit (bool): Whether to build a refittable engine
@@ -87,6 +89,7 @@ class CompilationSettings:
     device: Device = field(default_factory=default_device)
     require_full_compilation: bool = REQUIRE_FULL_COMPILATION
     disable_tf32: bool = DISABLE_TF32
+    assume_dynamic_shape_support: bool = ASSUME_DYNAMIC_SHAPE_SUPPORT
    sparse_weights: bool = SPARSE_WEIGHTS
     refit: bool = REFIT
     engine_capability: EngineCapability = field(
diff --git a/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py b/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py
index 6a8a7e3d41..1afb9749c6 100644
--- a/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py
+++ b/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import functools
 import logging
 from dataclasses import dataclass, field
 from enum import Enum, auto
@@ -17,13 +18,14 @@
     cast,
 )
 
+import tensorrt as trt
+import torch
+from torch import SymBool, SymFloat, SymInt
 from torch._ops import OpOverloadPacket
 from torch.fx.node import Argument, Node, Target, _get_qualified_name
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.fx.converter_registry import CONVERTERS as FX_CONVERTERS
 
-import tensorrt as trt
-
 logger = logging.getLogger(__name__)
 
 LegacyConverterImplSignature = Callable[
@@ -76,10 +78,12 @@ class ConverterSupport:
         capability_validator: Function which takes in a Node and returns a bool indicating
             whether that node can be supported by its companion converter. Note that
             this function must not modify the node or its graph
+        supports_dynamic_shapes: Boolean flag indicating if the converter has support for dynamic inputs.
""" converter_implementation: ConverterImplSignature capability_validator: Callable[[Node], bool] = field(default=lambda node: True) + supports_dynamic_shapes: bool = False # Dictionary representing Dynamo aten-only converters @@ -87,11 +91,106 @@ class ConverterSupport: DYNAMO_ATEN_CONVERTERS: Dict[Target, Sequence[ConverterSupport]] = {} +def has_static_shapes(node: torch.fx.Node) -> bool: + """Returns True if a node has static args, kwargs, or outputs""" + return not _has_dynamic_shapes(node=node) + + +def has_dynamic_shapes(node: torch.fx.Node) -> bool: + """Returns True if a node has dynamic args, kwargs, or outputs""" + return _has_dynamic_shapes(node=node) + + +def has_dynamic_shapes_in_args( + arg_positions_to_check: Optional[List[int]] = None, +) -> Callable[[torch.fx.Node], bool]: + """Returns True if a node has dynamic inputs in node.args at specified positions""" + return functools.partial( + _has_dynamic_shapes, arg_positions_to_check=arg_positions_to_check + ) + + +def has_static_shapes_in_args( + arg_positions_to_check: Optional[List[int]] = None, +) -> Callable[[torch.fx.Node], bool]: + """Returns True if a node has static inputs in node.args at specified positions""" + _has_static_shapes = lambda node, arg_positions_to_check: not _has_dynamic_shapes( + node, arg_positions_to_check + ) + return functools.partial( + _has_static_shapes, arg_positions_to_check=arg_positions_to_check + ) + + +def _has_dynamic_shapes( + node: torch.fx.Node, arg_positions_to_check: Optional[List[int]] = None +) -> bool: + # Validate that none of the inputs to the node have Dynamic shapes + assert isinstance( + node, torch.fx.Node + ), "Inputs to validator functions must be FX Nodes" + + def _is_subnode_dynamic(subnode: torch.fx.Node) -> bool: + """Checks if a node itself has Dynamic properties""" + _has_symbolic_sizes_strides, is_shape_dynamic = False, False + if "val" in subnode.meta: + _has_symbolic_sizes_strides = getattr( + subnode.meta["val"], "_has_symbolic_sizes_strides", False + ) + meta_val = subnode.meta["val"] + if isinstance(meta_val, (list, tuple)): + for val in meta_val: + shape = val.size() + if any( + isinstance(dim, (SymFloat, SymInt, SymBool)) for dim in shape + ): + is_shape_dynamic = True + break + elif isinstance(meta_val, (SymFloat, SymInt, SymBool)): + is_shape_dynamic = True + else: + shape = subnode.meta["val"].size() + is_shape_dynamic = any( + isinstance(dim, (SymFloat, SymInt, SymBool)) for dim in shape + ) + + return _has_symbolic_sizes_strides or is_shape_dynamic + + # Check node value itself + if arg_positions_to_check is None and _is_subnode_dynamic(node): + return True + + # Check node arguments individually + if arg_positions_to_check is None and any( + _is_subnode_dynamic(arg) for arg in node.args if isinstance(arg, torch.fx.Node) + ): + return True + # Check specific arg positions if the caller has specified positions to check + elif arg_positions_to_check is not None and any( + _is_subnode_dynamic(node.args[i]) + for i in arg_positions_to_check + if isinstance(node.args[i], torch.fx.Node) + ): + return True + + # Check node keyword arguments individually + if arg_positions_to_check is None and any( + _is_subnode_dynamic(kwarg) + for kwarg in node.kwargs.values() + if isinstance(kwarg, torch.fx.Node) + ): + return True + + return False + + def dynamo_tensorrt_converter( key: Target, + *, enabled: bool = True, capability_validator: Optional[Callable[[Node], bool]] = None, priority: ConverterPriority = ConverterPriority.STANDARD, + supports_dynamic_shapes: bool = 
False, ) -> Callable[[ConverterImplSignature], ConverterImplSignature]: """Decorator for Dynamo TensorRT Converter @@ -117,7 +216,10 @@ def register_converter(converter: ConverterImplSignature) -> ConverterImplSignat # If no capability_validator function is specified, use the default function - always return true if capability_validator is None: - converter_support = ConverterSupport(converter_implementation=converter) + converter_support = ConverterSupport( + converter_implementation=converter, + supports_dynamic_shapes=supports_dynamic_shapes, + ) else: assert callable( capability_validator @@ -125,6 +227,7 @@ def register_converter(converter: ConverterImplSignature) -> ConverterImplSignat converter_support = ConverterSupport( converter_implementation=converter, capability_validator=capability_validator, + supports_dynamic_shapes=supports_dynamic_shapes, ) # OpOverloadPackets are only valid if they have a single overload, or @@ -194,6 +297,7 @@ def __init__( ], registry_names: Optional[Sequence[str]] = None, registry_calling_conventions: Optional[Sequence[CallingConvention]] = None, + assume_dynamic_shape_support: bool = False, ): # Copy reference to each dictionary object into attribute list self.registries = list(registries) @@ -215,9 +319,12 @@ def __init__( ] self.disallowed_targets: Collection[Target] = set() - + self.assume_dynamic_shape_support = assume_dynamic_shape_support self.validate_invariants() + def set_dynamic_shape_support(self, assume_dynamic_shape_support: bool) -> None: + self.assume_dynamic_shape_support = assume_dynamic_shape_support + def set_disallowed_targets(self, torch_executed_ops: Collection[Target]) -> None: self.disallowed_targets = torch_executed_ops @@ -324,13 +431,24 @@ def __getitem__( if isinstance(converters, (list, tuple)): for candidate in converters: - if candidate.capability_validator(node): + # We enable the converter under 4 conditions + # 1) capability validator is True + # 2) Assume dynamic_shape support is True + # 3) Node only has static shaped inputs + # 4) Node has dynamic inputs and the converter has supports_dynamic_shapes=True + if candidate.capability_validator(node) and ( + self.assume_dynamic_shape_support + or not has_dynamic_shapes(node) + or candidate.supports_dynamic_shapes + ): return ( candidate.converter_implementation, calling_convention, ) else: - return converters, calling_convention + # Assuming FX converters don't have dynamic shapes supported + if not has_dynamic_shapes(node): + return converters, calling_convention raise KeyError( f"None of the converter registries have a validated entry for {key}, with node {node}" diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 8163814491..23483986a5 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -12,9 +12,9 @@ from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( dynamo_tensorrt_converter, + has_static_shapes_in_args, ) from torch_tensorrt.dynamo.conversion.converter_utils import ( - dynamic_unsupported_with_args, enforce_tensor_types, is_only_operator_on_placeholder, ) @@ -57,10 +57,14 @@ def one_user_validator(node: Node) -> bool: @dynamo_tensorrt_converter( - torch.ops.aten.native_batch_norm.default, capability_validator=one_user_validator + torch.ops.aten.native_batch_norm.default, + 
capability_validator=one_user_validator, + supports_dynamic_shapes=True, +) +@dynamo_tensorrt_converter( + torch.ops.aten.batch_norm.default, supports_dynamic_shapes=True ) -@dynamo_tensorrt_converter(torch.ops.aten.batch_norm.default) -@dynamo_tensorrt_converter(torch.ops.aten.batch_norm) +@dynamo_tensorrt_converter(torch.ops.aten.batch_norm, supports_dynamic_shapes=True) @enforce_tensor_types( { 0: (TRTTensor,), @@ -94,6 +98,7 @@ def aten_ops_batch_norm( @dynamo_tensorrt_converter( torch.ops.aten._native_batch_norm_legit_no_training.default, capability_validator=one_user_validator, + supports_dynamic_shapes=True, ) def aten_ops_batch_norm_legit_no_training( ctx: ConversionContext, @@ -336,7 +341,7 @@ def aten_ops_grid( ) -@dynamo_tensorrt_converter(torch.ops.aten.relu.default) +@dynamo_tensorrt_converter(torch.ops.aten.relu.default, supports_dynamic_shapes=True) def aten_ops_relu( ctx: ConversionContext, target: Target, @@ -353,7 +358,7 @@ def aten_ops_relu( ) -@dynamo_tensorrt_converter(torch.ops.aten.sigmoid.default) +@dynamo_tensorrt_converter(torch.ops.aten.sigmoid.default, supports_dynamic_shapes=True) def aten_ops_sigmoid( ctx: ConversionContext, target: Target, @@ -375,7 +380,7 @@ def aten_ops_sigmoid( 0: (TRTTensor,), } ) -@dynamo_tensorrt_converter(torch.ops.aten.sym_size.int) +@dynamo_tensorrt_converter(torch.ops.aten.sym_size.int, supports_dynamic_shapes=True) def aten_ops_symsize_int( ctx: ConversionContext, target: Target, @@ -623,14 +628,14 @@ def aten_ops_softmax( @dynamo_tensorrt_converter( - torch.ops.aten.split.Tensor, capability_validator=dynamic_unsupported_with_args([1]) + torch.ops.aten.split.Tensor, capability_validator=has_static_shapes_in_args([1]) ) @dynamo_tensorrt_converter( - torch.ops.aten.split.sizes, capability_validator=dynamic_unsupported_with_args([1]) + torch.ops.aten.split.sizes, capability_validator=has_static_shapes_in_args([1]) ) @dynamo_tensorrt_converter( torch.ops.aten.split_with_sizes.default, - capability_validator=dynamic_unsupported_with_args([1]), + capability_validator=has_static_shapes_in_args([1]), ) def aten_ops_split( ctx: ConversionContext, @@ -1094,8 +1099,8 @@ def aten_ops_min( ) -@dynamo_tensorrt_converter(torch.ops.aten.mean.default) -@dynamo_tensorrt_converter(torch.ops.aten.mean.dim) +@dynamo_tensorrt_converter(torch.ops.aten.mean.default, supports_dynamic_shapes=True) +@dynamo_tensorrt_converter(torch.ops.aten.mean.dim, supports_dynamic_shapes=True) def aten_ops_mean( ctx: ConversionContext, target: Target, @@ -1199,7 +1204,7 @@ def aten_ops_recip( ) -@dynamo_tensorrt_converter(torch.ops.aten.abs.default) +@dynamo_tensorrt_converter(torch.ops.aten.abs.default, supports_dynamic_shapes=True) def aten_ops_abs( ctx: ConversionContext, target: Target, @@ -1546,8 +1551,8 @@ def aten_ops_isnan( ) -@dynamo_tensorrt_converter(torch.ops.aten.add.Tensor) -@dynamo_tensorrt_converter(torch.ops.aten.add.Scalar) +@dynamo_tensorrt_converter(torch.ops.aten.add.Tensor, supports_dynamic_shapes=True) +@dynamo_tensorrt_converter(torch.ops.aten.add.Scalar, supports_dynamic_shapes=True) def aten_ops_add( ctx: ConversionContext, target: Target, @@ -2117,7 +2122,9 @@ def conv_param_validator(conv_node: Node) -> bool: @dynamo_tensorrt_converter( - torch.ops.aten.convolution.default, capability_validator=conv_param_validator + torch.ops.aten.convolution.default, + capability_validator=conv_param_validator, + supports_dynamic_shapes=True, ) @enforce_tensor_types( { @@ -2305,13 +2312,19 @@ def max_pool_param_validator(pool_node: Node) -> bool: # Note: 
MaxPool1d uses max_pool2d as it converts to 2D first. @dynamo_tensorrt_converter( - torch.ops.aten.max_pool1d.default, capability_validator=max_pool_param_validator + torch.ops.aten.max_pool1d.default, + capability_validator=max_pool_param_validator, + supports_dynamic_shapes=True, ) @dynamo_tensorrt_converter( - torch.ops.aten.max_pool2d.default, capability_validator=max_pool_param_validator + torch.ops.aten.max_pool2d.default, + capability_validator=max_pool_param_validator, + supports_dynamic_shapes=True, ) @dynamo_tensorrt_converter( - torch.ops.aten.max_pool3d.default, capability_validator=max_pool_param_validator + torch.ops.aten.max_pool3d.default, + capability_validator=max_pool_param_validator, + supports_dynamic_shapes=True, ) def aten_ops_max_pool( ctx: ConversionContext, @@ -2356,8 +2369,8 @@ def tensorrt_scaled_dot_product_attention( ) -@dynamo_tensorrt_converter(torch.ops.aten.reshape.default) -@dynamo_tensorrt_converter(torch.ops.aten.view.default) +@dynamo_tensorrt_converter(torch.ops.aten.reshape.default, supports_dynamic_shapes=True) +@dynamo_tensorrt_converter(torch.ops.aten.view.default, supports_dynamic_shapes=True) @enforce_tensor_types( { 0: (TRTTensor,), @@ -2466,7 +2479,7 @@ def aten_ops_argmin( ) -@dynamo_tensorrt_converter(torch.ops.aten.addmm.default) +@dynamo_tensorrt_converter(torch.ops.aten.addmm.default, supports_dynamic_shapes=True) @enforce_tensor_types( { 0: (TRTTensor,), diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index a263440128..135309443e 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -4,9 +4,9 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, overload import numpy as np +import tensorrt as trt import torch import torch_tensorrt.dynamo.conversion.impl as impl -from torch import SymBool, SymFloat, SymInt from torch.fx.node import Argument, Target from torch_tensorrt import _enums from torch_tensorrt.dynamo._SourceIR import SourceIR @@ -18,8 +18,6 @@ from torch_tensorrt.fx.converters.converter_utils import get_axes_for_reduce_op from torch_tensorrt.fx.types import TRTDataType, TRTTensor -import tensorrt as trt - _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -60,62 +58,6 @@ def is_only_operator_on_placeholder(node: torch.fx.Node) -> bool: ) -def dynamic_unsupported(node: torch.fx.Node) -> bool: - """Validates that a node has no dynamic args, kwargs, or outputs""" - return _dynamic_unsupported(node=node) - - -def dynamic_unsupported_with_args( - arg_positions_to_check: Optional[List[int]] = None, -) -> Callable[[torch.fx.Node], bool]: - """Returns a validator that a node has no dynamic args at specific positions""" - return functools.partial( - _dynamic_unsupported, arg_positions_to_check=arg_positions_to_check - ) - - -def _dynamic_unsupported( - node: torch.fx.Node, arg_positions_to_check: Optional[List[int]] = None -) -> bool: - # Validate that none of the inputs to the node have Dynamic shapes - assert isinstance( - node, torch.fx.Node - ), "Inputs to validator functions must be FX Nodes" - - def _is_subnode_dynamic(subnode: torch.fx.Node) -> bool: - """Checks if a node itself has Dynamic properties""" - return getattr( - subnode.meta["val"], "_has_symbolic_sizes_strides", False - ) or isinstance(subnode.meta["val"], (SymFloat, SymInt, SymBool)) - - # Check node value itself - if arg_positions_to_check is None and _is_subnode_dynamic(node): - 
return False - - # Check node arguments individually - if arg_positions_to_check is None and any( - _is_subnode_dynamic(arg) for arg in node.args if isinstance(arg, torch.fx.Node) - ): - return False - # Check specific arg positions if the caller has specified positions to check - elif arg_positions_to_check is not None and any( - _is_subnode_dynamic(node.args[i]) - for i in arg_positions_to_check - if isinstance(node.args[i], torch.fx.Node) - ): - return False - - # Check node keyword arguments individually - if arg_positions_to_check is None and any( - _is_subnode_dynamic(kwarg) - for kwarg in node.kwargs.values() - if isinstance(kwarg, torch.fx.Node) - ): - return False - - return True - - def cast_trt_tensor( ctx: ConversionContext, input_val: TRTTensor, diff --git a/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py b/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py index f83e0e5008..ea2f1c4d89 100644 --- a/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py +++ b/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py @@ -23,7 +23,11 @@ def getitem_validator(getitem_node: Node) -> bool: # TODO: Subsequent evaluators should be registered here with their own validators -@dynamo_tensorrt_converter(operator.getitem, capability_validator=getitem_validator) +@dynamo_tensorrt_converter( + operator.getitem, + capability_validator=getitem_validator, + supports_dynamic_shapes=True, +) @dynamo_tensorrt_converter(torch.ops.aten.detach.default) def generic_evaluator( ctx: ConversionContext, diff --git a/tests/py/dynamo/partitioning/test_dynamic_partitioning.py b/tests/py/dynamo/partitioning/test_dynamic_partitioning.py new file mode 100644 index 0000000000..9b18c1fc2f --- /dev/null +++ b/tests/py/dynamo/partitioning/test_dynamic_partitioning.py @@ -0,0 +1,109 @@ +from copy import deepcopy + +import numpy as np +import torch +import torch_tensorrt +from torch.testing._internal.common_utils import TestCase, run_tests +from torch_tensorrt.dynamo import partitioning + +from ..testing_utilities import lower_graph_testing + +# This testcase assumes that torch.ops.aten.clamp.default converter doesn't support +# dynamic shapes. One should remove this testcase when the support is added. +# This testcase tests if the graph is partitioned correctly into a TRT segment +# and a Pytorch segment when the torch.ops.aten.clamp.default converter gets disabled +# due to lack of dynamic shape support. 
+
+
+class TestDynamicPartitioning(TestCase):
+    def test_partition_dynamic_clamp(self):
+        class Clamp(torch.nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+                self.relu = torch.nn.ReLU()
+
+            def forward(self, x):
+                x = self.relu(x)
+                return torch.ops.aten.clamp.default(x, min=2.5, max=6.5)
+
+        model = Clamp().eval().cuda()
+        trt_model = torch_tensorrt.compile(
+            model,
+            inputs=[
+                torch_tensorrt.Input(
+                    min_shape=(1, 3, 8, 8),
+                    opt_shape=(4, 3, 8, 8),
+                    max_shape=(8, 3, 8, 8),
+                    dtype=torch.float32,
+                    name="x",
+                )
+            ],
+            dryrun=True,
+            min_block_size=1,
+        )
+        trt_segments, pyt_segments = 0, 0
+        for submod in list(trt_model.named_children()):
+            if "_run_on_acc" in submod[0]:
+                trt_segments += 1
+            elif "_run_on_gpu" in submod[0]:
+                pyt_segments += 1
+
+        self.assertEqual(
+            trt_segments,
+            1,
+            f"Number of TRT segments should be 1 but got {trt_segments}",
+        )
+        self.assertEqual(
+            pyt_segments,
+            1,
+            f"Number of PyTorch segments should be 1 but got {pyt_segments}",
+        )
+
+    def test_assume_dynamic_shape_support_converters(self):
+        class Clamp(torch.nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+                self.relu = torch.nn.ReLU()
+
+            def forward(self, x):
+                x = self.relu(x)
+                return torch.ops.aten.clamp.default(x, min=2.5, max=6.5)
+
+        model = Clamp().eval().cuda()
+        trt_model = torch_tensorrt.compile(
+            model,
+            inputs=[
+                torch_tensorrt.Input(
+                    min_shape=(1, 3, 8, 8),
+                    opt_shape=(4, 3, 8, 8),
+                    max_shape=(8, 3, 8, 8),
+                    dtype=torch.float32,
+                    name="x",
+                )
+            ],
+            dryrun=True,
+            assume_dynamic_shape_support=True,
+            min_block_size=1,
+        )
+
+        trt_segments, pyt_segments = 0, 0
+        for submod in list(trt_model.named_children()):
+            if "_run_on_acc" in submod[0]:
+                trt_segments += 1
+            elif "_run_on_gpu" in submod[0]:
+                pyt_segments += 1
+
+        self.assertEqual(
+            trt_segments,
+            1,
+            f"Number of TRT segments should be 1 but got {trt_segments}",
+        )
+        self.assertEqual(
+            pyt_segments,
+            0,
+            f"Number of PyTorch segments should be 0 but got {pyt_segments}",
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
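
A minimal usage sketch mirroring the new test above, showing the user-facing side of this change. The toy module and shape ranges are illustrative only, and actually running it requires a CUDA GPU with TensorRT installed.

import torch
import torch_tensorrt


class Clamp(torch.nn.Module):
    def forward(self, x):
        # clamp is used here only because its converter does not (yet) declare
        # dynamic-shape support, so it exercises the new flag.
        return torch.ops.aten.clamp.default(torch.relu(x), min=2.5, max=6.5)


model = Clamp().eval().cuda()

# Per-compilation escape hatch: treat every converter as dynamic-shape capable,
# skipping the registry's dynamic-shape validation for this model.
trt_model = torch_tensorrt.compile(
    model,
    inputs=[
        torch_tensorrt.Input(
            min_shape=(1, 3, 8, 8),
            opt_shape=(4, 3, 8, 8),
            max_shape=(8, 3, 8, 8),
            dtype=torch.float32,
            name="x",
        )
    ],
    assume_dynamic_shape_support=True,
    min_block_size=1,
)

# Per-converter opt-in: converter authors instead declare support explicitly, e.g.
# @dynamo_tensorrt_converter(torch.ops.aten.relu.default, supports_dynamic_shapes=True)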