feat: Add validators for dynamic shapes in converter registration (#2796)
peri044 committed May 17, 2024
1 parent db24b3b commit b2dab1b
Showing 8 changed files with 290 additions and 88 deletions.
12 changes: 12 additions & 0 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -47,6 +47,7 @@ def compile(
*,
device: Optional[Union[Device, torch.device, str]] = _defaults.DEVICE,
disable_tf32: bool = _defaults.DISABLE_TF32,
assume_dynamic_shape_support: bool = _defaults.ASSUME_DYNAMIC_SHAPE_SUPPORT,
sparse_weights: bool = _defaults.SPARSE_WEIGHTS,
enabled_precisions: (
Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype]
@@ -106,6 +107,7 @@ def compile(
device=torch_tensorrt.Device("dla:1", allow_gpu_fallback=True)
disable_tf32 (bool): Force FP32 layers to use traditional as FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas
assume_dynamic_shape_support (bool): Setting this to true enables the converters to work for both dynamic and static shapes. Default: False
sparse_weights (bool): Enable sparsity for convolution and fully connected layers.
enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
refit (bool): Enable refitting
@@ -189,6 +191,7 @@ def compile(
),
"debug": debug,
"device": device,
"assume_dynamic_shape_support": assume_dynamic_shape_support,
"workspace_size": workspace_size,
"min_block_size": min_block_size,
"torch_executed_ops": (
@@ -239,6 +242,9 @@ def compile_module(
"""
dryrun_tracker = DryRunTracker()

# Assume converters support dynamic shapes and disable validation
CONVERTERS.set_dynamic_shape_support(settings.assume_dynamic_shape_support)

# Set torch-executed ops
CONVERTERS.set_disallowed_targets(settings.torch_executed_ops)

@@ -443,6 +449,7 @@ def convert_module_to_trt_engine(
Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype]
) = _defaults.ENABLED_PRECISIONS,
debug: bool = _defaults.DEBUG,
assume_dynamic_shape_support: bool = _defaults.ASSUME_DYNAMIC_SHAPE_SUPPORT,
workspace_size: int = _defaults.WORKSPACE_SIZE,
min_block_size: int = _defaults.MIN_BLOCK_SIZE,
torch_executed_ops: Optional[Set[str]] = None,
@@ -550,6 +557,7 @@ def convert_module_to_trt_engine(
enabled_precisions = {dtype._from(e) for e in enabled_precisions}

compilation_options = {
"assume_dynamic_shape_support": assume_dynamic_shape_support,
"enabled_precisions": enabled_precisions,
"debug": debug,
"workspace_size": workspace_size,
@@ -589,6 +597,10 @@ def convert_module_to_trt_engine(

settings = CompilationSettings(**compilation_options)
logger.info("Compilation Settings: %s\n", settings)

# Assume converters support dynamic shapes and disable validation
CONVERTERS.set_dynamic_shape_support(settings.assume_dynamic_shape_support)

try:
interpreter_result = interpret_module_to_result(gm, input_list, settings)
except UnsupportedOperatorException:
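
For context, a minimal usage sketch (not part of this commit) of the new flag in the dynamo frontend; the module and input shapes below are hypothetical placeholders:

import torch
import torch_tensorrt

model = MyModel().eval().cuda()  # hypothetical user module
inputs = [
    torch_tensorrt.Input(
        min_shape=(1, 3, 224, 224),
        opt_shape=(8, 3, 224, 224),
        max_shape=(16, 3, 224, 224),
        dtype=torch.float32,
    )
]

trt_module = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=inputs,
    # Treat every registered converter as dynamic-shape capable, skipping the
    # per-converter validation introduced by this commit.
    assume_dynamic_shape_support=True,
)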
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/_defaults.py
@@ -6,6 +6,7 @@
DEBUG = False
DEVICE = None
DISABLE_TF32 = False
ASSUME_DYNAMIC_SHAPE_SUPPORT = False
DLA_LOCAL_DRAM_SIZE = 1073741824
DLA_GLOBAL_DRAM_SIZE = 536870912
DLA_SRAM_SIZE = 1048576
3 changes: 3 additions & 0 deletions py/torch_tensorrt/dynamo/_settings.py
@@ -5,6 +5,7 @@
from torch_tensorrt._Device import Device
from torch_tensorrt._enums import EngineCapability, dtype
from torch_tensorrt.dynamo._defaults import (
ASSUME_DYNAMIC_SHAPE_SUPPORT,
DEBUG,
DISABLE_TF32,
DLA_GLOBAL_DRAM_SIZE,
@@ -57,6 +58,7 @@ class CompilationSettings:
device (Device): GPU to compile the model on
require_full_compilation (bool): Whether to require the graph is fully compiled in TensorRT.
Only applicable for `ir="dynamo"`; has no effect for `torch.compile` path
assume_dynamic_shape_support (bool): Setting this to true enables the converters to work for both dynamic and static shapes. Default: False
disable_tf32 (bool): Whether to disable TF32 computation for TRT layers
sparse_weights (bool): Whether to allow the builder to use sparse weights
refit (bool): Whether to build a refittable engine
@@ -87,6 +89,7 @@ class CompilationSettings:
device: Device = field(default_factory=default_device)
require_full_compilation: bool = REQUIRE_FULL_COMPILATION
disable_tf32: bool = DISABLE_TF32
assume_dynamic_shape_support: bool = ASSUME_DYNAMIC_SHAPE_SUPPORT
sparse_weights: bool = SPARSE_WEIGHTS
refit: bool = REFIT
engine_capability: EngineCapability = field(
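
As a quick illustration (assuming nothing beyond the field added above), the new setting can be toggled directly on CompilationSettings:

from torch_tensorrt.dynamo._settings import CompilationSettings

settings = CompilationSettings(assume_dynamic_shape_support=True)
assert settings.assume_dynamic_shape_support  # defaults to False otherwise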
130 changes: 124 additions & 6 deletions py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import functools
import logging
from dataclasses import dataclass, field
from enum import Enum, auto
@@ -17,13 +18,14 @@
cast,
)

import tensorrt as trt
import torch
from torch import SymBool, SymFloat, SymInt
from torch._ops import OpOverloadPacket
from torch.fx.node import Argument, Node, Target, _get_qualified_name
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.fx.converter_registry import CONVERTERS as FX_CONVERTERS

import tensorrt as trt

logger = logging.getLogger(__name__)

LegacyConverterImplSignature = Callable[
@@ -76,22 +78,119 @@ class ConverterSupport:
capability_validator: Function which takes in a Node and returns a bool indicating
whether that node can be supported by its companion converter. Note that
this function must not modify the node or its graph
supports_dynamic_shapes: Boolean flag indicating if the converter has support for dynamic inputs.
"""

converter_implementation: ConverterImplSignature
capability_validator: Callable[[Node], bool] = field(default=lambda node: True)
supports_dynamic_shapes: bool = False


# Dictionary representing Dynamo aten-only converters
# Each converter maps to a sequence of at least one ConverterSupport object(s)
DYNAMO_ATEN_CONVERTERS: Dict[Target, Sequence[ConverterSupport]] = {}


def has_static_shapes(node: torch.fx.Node) -> bool:
"""Returns True if a node has static args, kwargs, or outputs"""
return not _has_dynamic_shapes(node=node)


def has_dynamic_shapes(node: torch.fx.Node) -> bool:
"""Returns True if a node has dynamic args, kwargs, or outputs"""
return _has_dynamic_shapes(node=node)


def has_dynamic_shapes_in_args(
arg_positions_to_check: Optional[List[int]] = None,
) -> Callable[[torch.fx.Node], bool]:
"""Returns True if a node has dynamic inputs in node.args at specified positions"""
return functools.partial(
_has_dynamic_shapes, arg_positions_to_check=arg_positions_to_check
)


def has_static_shapes_in_args(
arg_positions_to_check: Optional[List[int]] = None,
) -> Callable[[torch.fx.Node], bool]:
"""Returns True if a node has static inputs in node.args at specified positions"""
_has_static_shapes = lambda node, arg_positions_to_check: not _has_dynamic_shapes(
node, arg_positions_to_check
)
return functools.partial(
_has_static_shapes, arg_positions_to_check=arg_positions_to_check
)


def _has_dynamic_shapes(
node: torch.fx.Node, arg_positions_to_check: Optional[List[int]] = None
) -> bool:
# Validate that none of the inputs to the node have Dynamic shapes
assert isinstance(
node, torch.fx.Node
), "Inputs to validator functions must be FX Nodes"

def _is_subnode_dynamic(subnode: torch.fx.Node) -> bool:
"""Checks if a node itself has Dynamic properties"""
_has_symbolic_sizes_strides, is_shape_dynamic = False, False
if "val" in subnode.meta:
_has_symbolic_sizes_strides = getattr(
subnode.meta["val"], "_has_symbolic_sizes_strides", False
)
meta_val = subnode.meta["val"]
if isinstance(meta_val, (list, tuple)):
for val in meta_val:
shape = val.size()
if any(
isinstance(dim, (SymFloat, SymInt, SymBool)) for dim in shape
):
is_shape_dynamic = True
break
elif isinstance(meta_val, (SymFloat, SymInt, SymBool)):
is_shape_dynamic = True
else:
shape = subnode.meta["val"].size()
is_shape_dynamic = any(
isinstance(dim, (SymFloat, SymInt, SymBool)) for dim in shape
)

return _has_symbolic_sizes_strides or is_shape_dynamic

# Check node value itself
if arg_positions_to_check is None and _is_subnode_dynamic(node):
return True

# Check node arguments individually
if arg_positions_to_check is None and any(
_is_subnode_dynamic(arg) for arg in node.args if isinstance(arg, torch.fx.Node)
):
return True
# Check specific arg positions if the caller has specified positions to check
elif arg_positions_to_check is not None and any(
_is_subnode_dynamic(node.args[i])
for i in arg_positions_to_check
if isinstance(node.args[i], torch.fx.Node)
):
return True

# Check node keyword arguments individually
if arg_positions_to_check is None and any(
_is_subnode_dynamic(kwarg)
for kwarg in node.kwargs.values()
if isinstance(kwarg, torch.fx.Node)
):
return True

return False


def dynamo_tensorrt_converter(
key: Target,
*,
enabled: bool = True,
capability_validator: Optional[Callable[[Node], bool]] = None,
priority: ConverterPriority = ConverterPriority.STANDARD,
supports_dynamic_shapes: bool = False,
) -> Callable[[ConverterImplSignature], ConverterImplSignature]:
"""Decorator for Dynamo TensorRT Converter
@@ -117,14 +216,18 @@ def register_converter(converter: ConverterImplSignature) -> ConverterImplSignat

# If no capability_validator function is specified, use the default function - always return true
if capability_validator is None:
converter_support = ConverterSupport(converter_implementation=converter)
converter_support = ConverterSupport(
converter_implementation=converter,
supports_dynamic_shapes=supports_dynamic_shapes,
)
else:
assert callable(
capability_validator
), "Argument checking function must be callable"
converter_support = ConverterSupport(
converter_implementation=converter,
capability_validator=capability_validator,
supports_dynamic_shapes=supports_dynamic_shapes,
)

# OpOverloadPackets are only valid if they have a single overload, or
Expand Down Expand Up @@ -194,6 +297,7 @@ def __init__(
],
registry_names: Optional[Sequence[str]] = None,
registry_calling_conventions: Optional[Sequence[CallingConvention]] = None,
assume_dynamic_shape_support: bool = False,
):
# Copy reference to each dictionary object into attribute list
self.registries = list(registries)
@@ -215,9 +319,12 @@ def __init__(
]

self.disallowed_targets: Collection[Target] = set()

self.assume_dynamic_shape_support = assume_dynamic_shape_support
self.validate_invariants()

def set_dynamic_shape_support(self, assume_dynamic_shape_support: bool) -> None:
self.assume_dynamic_shape_support = assume_dynamic_shape_support

def set_disallowed_targets(self, torch_executed_ops: Collection[Target]) -> None:
self.disallowed_targets = torch_executed_ops

@@ -324,13 +431,24 @@ def __getitem__(

if isinstance(converters, (list, tuple)):
for candidate in converters:
if candidate.capability_validator(node):
# We enable the converter under 4 conditions
# 1) capability validator is True
# 2) Assume dynamic_shape support is True
# 3) Node only has static shaped inputs
# 4) Node has dynamic inputs and the converter has supports_dynamic_shapes=True
if candidate.capability_validator(node) and (
self.assume_dynamic_shape_support
or not has_dynamic_shapes(node)
or candidate.supports_dynamic_shapes
):
return (
candidate.converter_implementation,
calling_convention,
)
else:
return converters, calling_convention
# Assuming FX converters don't have dynamic shapes supported
if not has_dynamic_shapes(node):
return converters, calling_convention

raise KeyError(
f"None of the converter registries have a validated entry for {key}, with node {node}"
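
To illustrate the registration side, a hedged sketch (not from the commit) of how a converter author might use the new decorator argument and validators; the target ops and the empty bodies are placeholders only:

from typing import Any, Dict, Sequence, Tuple, Union

import torch
from torch.fx.node import Argument, Target
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
    dynamo_tensorrt_converter,
    has_static_shapes_in_args,
)


# Converter that declares full dynamic-shape support: it is selected even when
# the node's inputs or outputs carry symbolic dimensions.
@dynamo_tensorrt_converter(torch.ops.aten.relu.default, supports_dynamic_shapes=True)
def relu_dynamic_converter(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[Any, Sequence[Any]]:
    ...  # placeholder: a real converter builds TensorRT layers via ctx


# Alternative: keep the default supports_dynamic_shapes=False, but gate the
# converter with a validator that requires static shapes at argument position 0.
@dynamo_tensorrt_converter(
    torch.ops.aten.slice.Tensor,
    capability_validator=has_static_shapes_in_args([0]),
)
def slice_static_arg_converter(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[Any, Sequence[Any]]:
    ...  # placeholder implementation

With assume_dynamic_shape_support left at its default of False, the __getitem__ logic above only returns a converter for a dynamically shaped node when that converter's supports_dynamic_shapes flag is set; enabling the global flag restores the previous behavior of trusting every converter for dynamic inputs.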