feat: Add validators for dynamic shapes in converter registration #2796

Merged
merged 13 commits on May 16, 2024
12 changes: 12 additions & 0 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -47,6 +47,7 @@ def compile(
*,
device: Optional[Union[Device, torch.device, str]] = _defaults.DEVICE,
disable_tf32: bool = _defaults.DISABLE_TF32,
assume_dynamic_shape_support: bool = _defaults.ASSUME_DYNAMIC_SHAPE_SUPPORT,
sparse_weights: bool = _defaults.SPARSE_WEIGHTS,
enabled_precisions: (
Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype]
@@ -106,6 +107,7 @@ def compile(
device=torch_tensorrt.Device("dla:1", allow_gpu_fallback=True)

disable_tf32 (bool): Force FP32 layers to use the traditional FP32 format instead of the default TF32 behavior of rounding the inputs to 10-bit mantissas before multiplying while accumulating the sum using 23-bit mantissas
assume_dynamic_shape_support (bool): Setting this to true enables converters to work for both dynamic and static shapes. Default: False
sparse_weights (bool): Enable sparsity for convolution and fully connected layers.
enabled_precisions (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
refit (bool): Enable refitting
@@ -189,6 +191,7 @@ def compile(
),
"debug": debug,
"device": device,
"assume_dynamic_shape_support": assume_dynamic_shape_support,
"workspace_size": workspace_size,
"min_block_size": min_block_size,
"torch_executed_ops": (
@@ -239,6 +242,9 @@ def compile_module(
"""
dryrun_tracker = DryRunTracker()

# If requested via settings, assume converters support dynamic shapes and skip dynamic shape validation
CONVERTERS.set_dynamic_shape_support(settings.assume_dynamic_shape_support)

# Set torch-executed ops
CONVERTERS.set_disallowed_targets(settings.torch_executed_ops)

@@ -443,6 +449,7 @@ def convert_module_to_trt_engine(
Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype]
) = _defaults.ENABLED_PRECISIONS,
debug: bool = _defaults.DEBUG,
assume_dynamic_shape_support: bool = _defaults.ASSUME_DYNAMIC_SHAPE_SUPPORT,
workspace_size: int = _defaults.WORKSPACE_SIZE,
min_block_size: int = _defaults.MIN_BLOCK_SIZE,
torch_executed_ops: Optional[Set[str]] = None,
@@ -550,6 +557,7 @@ def convert_module_to_trt_engine(
enabled_precisions = {dtype._from(e) for e in enabled_precisions}

compilation_options = {
"assume_dynamic_shape_support": assume_dynamic_shape_support,
"enabled_precisions": enabled_precisions,
"debug": debug,
"workspace_size": workspace_size,
@@ -589,6 +597,10 @@ def convert_module_to_trt_engine(

settings = CompilationSettings(**compilation_options)
logger.info("Compilation Settings: %s\n", settings)

# If requested via settings, assume converters support dynamic shapes and skip dynamic shape validation
CONVERTERS.set_dynamic_shape_support(settings.assume_dynamic_shape_support)

try:
interpreter_result = interpret_module_to_result(gm, input_list, settings)
except UnsupportedOperatorException:
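Taken together, these changes expose assume_dynamic_shape_support as a top-level compile option. A minimal usage sketch (the model and shapes below are illustrative, not from this PR; torch_tensorrt.Input with min/opt/max shapes is the existing way to declare dynamic inputs):

import torch
import torch_tensorrt

model = torch.nn.Linear(64, 32).eval().cuda()

# Dynamic batch dimension expressed via min/opt/max shapes
inputs = [
    torch_tensorrt.Input(
        min_shape=(1, 64),
        opt_shape=(8, 64),
        max_shape=(32, 64),
        dtype=torch.float32,
    )
]

# Opt out of per-converter dynamic shape validation and assume all
# converters can handle dynamic inputs
trt_model = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=inputs,
    assume_dynamic_shape_support=True,
)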
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/_defaults.py
@@ -6,6 +6,7 @@
DEBUG = False
DEVICE = None
DISABLE_TF32 = False
ASSUME_DYNAMIC_SHAPE_SUPPORT = False
DLA_LOCAL_DRAM_SIZE = 1073741824
DLA_GLOBAL_DRAM_SIZE = 536870912
DLA_SRAM_SIZE = 1048576
3 changes: 3 additions & 0 deletions py/torch_tensorrt/dynamo/_settings.py
@@ -5,6 +5,7 @@
from torch_tensorrt._Device import Device
from torch_tensorrt._enums import EngineCapability, dtype
from torch_tensorrt.dynamo._defaults import (
ASSUME_DYNAMIC_SHAPE_SUPPORT,
DEBUG,
DISABLE_TF32,
DLA_GLOBAL_DRAM_SIZE,
@@ -57,6 +58,7 @@ class CompilationSettings:
device (Device): GPU to compile the model on
require_full_compilation (bool): Whether to require that the graph be fully compiled in TensorRT.
Only applicable for `ir="dynamo"`; has no effect for the `torch.compile` path
assume_dynamic_shape_support (bool): Setting this to true enables converters to work for both dynamic and static shapes. Default: False
disable_tf32 (bool): Whether to disable TF32 computation for TRT layers
sparse_weights (bool): Whether to allow the builder to use sparse weights
refit (bool): Whether to build a refittable engine
@@ -87,6 +89,7 @@ class CompilationSettings:
device: Device = field(default_factory=default_device)
require_full_compilation: bool = REQUIRE_FULL_COMPILATION
disable_tf32: bool = DISABLE_TF32
assume_dynamic_shape_support: bool = ASSUME_DYNAMIC_SHAPE_SUPPORT
sparse_weights: bool = SPARSE_WEIGHTS
refit: bool = REFIT
engine_capability: EngineCapability = field(
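For illustration, a minimal sketch of how the new field is consumed (assuming the aggregated registry is exposed as DYNAMO_CONVERTERS in _ConverterRegistry.py, mirroring the compile_module change above):

from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
    DYNAMO_CONVERTERS as CONVERTERS,
)

# Defaults to ASSUME_DYNAMIC_SHAPE_SUPPORT (False) unless overridden
settings = CompilationSettings(assume_dynamic_shape_support=True)

# Propagate the flag into the registry before conversion begins
CONVERTERS.set_dynamic_shape_support(settings.assume_dynamic_shape_support)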
130 changes: 124 additions & 6 deletions py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import functools
import logging
from dataclasses import dataclass, field
from enum import Enum, auto
@@ -17,13 +18,14 @@
cast,
)

import tensorrt as trt
import torch
from torch import SymBool, SymFloat, SymInt
from torch._ops import OpOverloadPacket
from torch.fx.node import Argument, Node, Target, _get_qualified_name
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.fx.converter_registry import CONVERTERS as FX_CONVERTERS

import tensorrt as trt

logger = logging.getLogger(__name__)

LegacyConverterImplSignature = Callable[
@@ -76,22 +78,119 @@ class ConverterSupport:
capability_validator: Function which takes in a Node and returns a bool indicating
whether that node can be supported by its companion converter. Note that
this function must not modify the node or its graph
supports_dynamic_shapes: Boolean flag indicating whether the converter supports dynamic input shapes.
"""

converter_implementation: ConverterImplSignature
capability_validator: Callable[[Node], bool] = field(default=lambda node: True)
supports_dynamic_shapes: bool = False


# Dictionary representing Dynamo aten-only converters
# Each converter maps to a sequence of one or more ConverterSupport objects
DYNAMO_ATEN_CONVERTERS: Dict[Target, Sequence[ConverterSupport]] = {}


def has_static_shapes(node: torch.fx.Node) -> bool:
"""Returns True if a node has static args, kwargs, or outputs"""
return not _has_dynamic_shapes(node=node)


def has_dynamic_shapes(node: torch.fx.Node) -> bool:
"""Returns True if a node has dynamic args, kwargs, or outputs"""
return _has_dynamic_shapes(node=node)


def has_dynamic_shapes_in_args(
arg_positions_to_check: Optional[List[int]] = None,
) -> Callable[[torch.fx.Node], bool]:
"""Returns True if a node has dynamic inputs in node.args at specified positions"""
return functools.partial(
_has_dynamic_shapes, arg_positions_to_check=arg_positions_to_check
)


def has_static_shapes_in_args(
arg_positions_to_check: Optional[List[int]] = None,
) -> Callable[[torch.fx.Node], bool]:
"""Returns True if a node has static inputs in node.args at specified positions"""
_has_static_shapes = lambda node, arg_positions_to_check: not _has_dynamic_shapes(
node, arg_positions_to_check
)
return functools.partial(
_has_static_shapes, arg_positions_to_check=arg_positions_to_check
)


def _has_dynamic_shapes(
node: torch.fx.Node, arg_positions_to_check: Optional[List[int]] = None
) -> bool:
# Determine whether the node or any of its inputs have dynamic shapes
assert isinstance(
node, torch.fx.Node
), "Inputs to validator functions must be FX Nodes"

def _is_subnode_dynamic(subnode: torch.fx.Node) -> bool:
"""Checks if a node itself has Dynamic properties"""
_has_symbolic_sizes_strides, is_shape_dynamic = False, False
if "val" in subnode.meta:
_has_symbolic_sizes_strides = getattr(
subnode.meta["val"], "_has_symbolic_sizes_strides", False
)
meta_val = subnode.meta["val"]
if isinstance(meta_val, (list, tuple)):
for val in meta_val:
shape = val.size()
if any(
isinstance(dim, (SymFloat, SymInt, SymBool)) for dim in shape
):
is_shape_dynamic = True
break
elif isinstance(meta_val, (SymFloat, SymInt, SymBool)):
is_shape_dynamic = True
else:
shape = subnode.meta["val"].size()
is_shape_dynamic = any(
isinstance(dim, (SymFloat, SymInt, SymBool)) for dim in shape
)

return _has_symbolic_sizes_strides or is_shape_dynamic

# Check node value itself
if arg_positions_to_check is None and _is_subnode_dynamic(node):
return True

# Check node arguments individually
if arg_positions_to_check is None and any(
_is_subnode_dynamic(arg) for arg in node.args if isinstance(arg, torch.fx.Node)
):
return True
# Check specific arg positions if the caller has specified positions to check
elif arg_positions_to_check is not None and any(
_is_subnode_dynamic(node.args[i])
for i in arg_positions_to_check
if isinstance(node.args[i], torch.fx.Node)
):
return True

# Check node keyword arguments individually
if arg_positions_to_check is None and any(
_is_subnode_dynamic(kwarg)
for kwarg in node.kwargs.values()
if isinstance(kwarg, torch.fx.Node)
):
return True

return False
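# For example (illustrative, not part of the diff): in a graph exported with
# a dynamic batch dimension, node.meta["val"] holds a FakeTensor whose size()
# contains SymInt entries such as torch.Size([s0, 3, 224, 224]), so
# _is_subnode_dynamic, and hence _has_dynamic_shapes, returns True for that node.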


def dynamo_tensorrt_converter(
key: Target,
*,
enabled: bool = True,
capability_validator: Optional[Callable[[Node], bool]] = None,
priority: ConverterPriority = ConverterPriority.STANDARD,
supports_dynamic_shapes: bool = False,
) -> Callable[[ConverterImplSignature], ConverterImplSignature]:
"""Decorator for Dynamo TensorRT Converter

@@ -117,14 +216,18 @@ def register_converter(converter: ConverterImplSignature) -> ConverterImplSignat

# If no capability_validator function is specified, use the default function, which always returns True
if capability_validator is None:
converter_support = ConverterSupport(converter_implementation=converter)
converter_support = ConverterSupport(
converter_implementation=converter,
supports_dynamic_shapes=supports_dynamic_shapes,
)
else:
assert callable(
capability_validator
), "Argument checking function must be callable"
converter_support = ConverterSupport(
converter_implementation=converter,
capability_validator=capability_validator,
supports_dynamic_shapes=supports_dynamic_shapes,
)

# OpOverloadPackets are only valid if they have a single overload, or
@@ -194,6 +297,7 @@ def __init__(
],
registry_names: Optional[Sequence[str]] = None,
registry_calling_conventions: Optional[Sequence[CallingConvention]] = None,
assume_dynamic_shape_support: bool = False,
):
# Copy reference to each dictionary object into attribute list
self.registries = list(registries)
@@ -215,9 +319,12 @@ def __init__(
]

self.disallowed_targets: Collection[Target] = set()

self.assume_dynamic_shape_support = assume_dynamic_shape_support
self.validate_invariants()

def set_dynamic_shape_support(self, assume_dynamic_shape_support: bool) -> None:
self.assume_dynamic_shape_support = assume_dynamic_shape_support

def set_disallowed_targets(self, torch_executed_ops: Collection[Target]) -> None:
self.disallowed_targets = torch_executed_ops

@@ -324,13 +431,24 @@ def __getitem__(

if isinstance(converters, (list, tuple)):
for candidate in converters:
if candidate.capability_validator(node):
# A converter is selected when its capability validator passes
# and at least one of the following holds:
# 1) assume_dynamic_shape_support is enabled globally
# 2) the node has only static shaped inputs
# 3) the node has dynamic inputs and the converter declares supports_dynamic_shapes=True
if candidate.capability_validator(node) and (
self.assume_dynamic_shape_support
or not has_dynamic_shapes(node)
or candidate.supports_dynamic_shapes
):
return (
candidate.converter_implementation,
calling_convention,
)
else:
return converters, calling_convention
# FX converters are assumed not to support dynamic shapes
if not has_dynamic_shapes(node):
return converters, calling_convention

raise KeyError(
f"None of the converter registries have a validated entry for {key}, with node {node}"
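Finally, a hedged sketch of how a converter author opts into the new validation (the targets and converter bodies are illustrative placeholders; the decorator arguments are the ones added in this PR):

import torch
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
    dynamo_tensorrt_converter,
    has_static_shapes_in_args,
)


# Converter that can handle dynamic shapes in any input position
@dynamo_tensorrt_converter(
    torch.ops.aten.relu.default,  # illustrative target
    supports_dynamic_shapes=True,
)
def relu_converter(ctx, target, args, kwargs, name):
    ...  # build and return TensorRT layers here


# Converter that tolerates dynamic inputs generally but requires the
# argument at position 1 to be statically shaped
@dynamo_tensorrt_converter(
    torch.ops.aten.expand.default,  # illustrative target
    capability_validator=has_static_shapes_in_args([1]),
    supports_dynamic_shapes=True,
)
def expand_converter(ctx, target, args, kwargs, name):
    ...

With assume_dynamic_shape_support left at its default (False), a node with dynamic inputs only matches converters that declare supports_dynamic_shapes=True and pass their validators; a node with no matching converter is left to run in PyTorch under partial compilation.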