feat: Add validators for dynamic shapes in converter registration #2796

Merged · 13 commits · May 16, 2024
87 changes: 86 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import functools
import logging
from dataclasses import dataclass, field
from enum import Enum, auto
@@ -17,6 +18,8 @@
cast,
)

import torch
from torch import SymBool, SymFloat, SymInt
from torch._ops import OpOverloadPacket
from torch.fx.node import Argument, Node, Target, _get_qualified_name
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
@@ -75,22 +78,91 @@ class ConverterSupport:
capability_validator: Function which takes in a Node and returns a bool indicating
whether that node can be supported by its companion converter. Note that
this function must not modify the node or its graph
supports_dynamic_shapes: Boolean flag indicating whether the converter supports dynamic input shapes.
"""

converter_implementation: ConverterImplSignature
capability_validator: Callable[[Node], bool] = field(default=lambda node: True)
supports_dynamic_shapes: bool = False


# Dictionary representing Dynamo aten-only converters
# Each converter maps to a sequence of one or more ConverterSupport objects
DYNAMO_ATEN_CONVERTERS: Dict[Target, Sequence[ConverterSupport]] = {}
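
For orientation, here is a minimal sketch (not part of the diff) of what a registry entry conceptually holds once a converter has been registered with the new flag; the converter function below is hypothetical:

# Hypothetical converter function, for illustration only.
def my_relu_converter(ctx, target, args, kwargs, name):
    ...

# After registration, the target maps to one or more ConverterSupport records
# carrying the implementation, its validator, and the new dynamic-shape flag.
DYNAMO_ATEN_CONVERTERS[torch.ops.aten.relu.default] = [
    ConverterSupport(
        converter_implementation=my_relu_converter,
        capability_validator=lambda node: True,
        supports_dynamic_shapes=True,
    )
]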


def has_dynamic_shapes(node: torch.fx.Node) -> bool:
"""Returns True if a node has dynamic args, kwargs, or outputs"""
return _has_dynamic_shapes(node=node)


def has_dynamic_shapes_in_args(
arg_positions_to_check: Optional[List[int]] = None,
) -> Callable[[torch.fx.Node], bool]:
"""Returns True if a node has dynamic inputs in node.args at specified positions"""
return functools.partial(
_has_dynamic_shapes, arg_positions_to_check=arg_positions_to_check
)


def _has_dynamic_shapes(
node: torch.fx.Node, arg_positions_to_check: Optional[List[int]] = None
) -> bool:
# Check whether the node itself, or any of its args/kwargs, carries dynamic shapes
assert isinstance(
node, torch.fx.Node
), "Inputs to validator functions must be FX Nodes"

def _is_subnode_dynamic(subnode: torch.fx.Node) -> bool:
"""Checks if a node itself has Dynamic properties"""
_has_symbolic_sizes_strides, is_shape_dynamic = False, False
if "val" in subnode.meta:
_has_symbolic_sizes_strides = getattr(
subnode.meta["val"], "_has_symbolic_sizes_strides", False
)

shape = subnode.meta["val"].size()
is_shape_dynamic = any(
isinstance(dim, (SymFloat, SymInt, SymBool)) for dim in shape
)

return _has_symbolic_sizes_strides or is_shape_dynamic

# Check node value itself
if arg_positions_to_check is None and _is_subnode_dynamic(node):
return True

# Check node arguments individually
if arg_positions_to_check is None and any(
_is_subnode_dynamic(arg) for arg in node.args if isinstance(arg, torch.fx.Node)
):
return True
# Check specific arg positions if the caller has specified positions to check
elif arg_positions_to_check is not None and any(
_is_subnode_dynamic(node.args[i])
for i in arg_positions_to_check
if isinstance(node.args[i], torch.fx.Node)
):
return True

# Check node keyword arguments individually
if arg_positions_to_check is None and any(
_is_subnode_dynamic(kwarg)
for kwarg in node.kwargs.values()
if isinstance(kwarg, torch.fx.Node)
):
return True

return False
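
As a hedged usage sketch (not part of the diff; `node` stands for any torch.fx.Node taken from a traced graph): has_dynamic_shapes is applied to a node directly, while has_dynamic_shapes_in_args first builds a validator that is then applied to the node.

# `node` is assumed to be a torch.fx.Node from a traced FX graph.
node_is_dynamic = has_dynamic_shapes(node)

# has_dynamic_shapes_in_args returns a callable; calling it with the node
# checks only the listed argument positions (position 1 here).
arg1_validator = has_dynamic_shapes_in_args([1])
arg1_is_dynamic = arg1_validator(node)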


def dynamo_tensorrt_converter(
key: Target,
*,
enabled: bool = True,
capability_validator: Optional[Callable[[Node], bool]] = None,
priority: ConverterPriority = ConverterPriority.STANDARD,
supports_dynamic_shapes: bool = False,
) -> Callable[[ConverterImplSignature], ConverterImplSignature]:
"""Decorator for Dynamo TensorRT Converter

@@ -116,14 +188,18 @@ def register_converter(converter: ConverterImplSignature) -> ConverterImplSignat

# If no capability_validator function is specified, use the default validator, which always returns True
if capability_validator is None:
converter_support = ConverterSupport(converter_implementation=converter)
converter_support = ConverterSupport(
converter_implementation=converter,
supports_dynamic_shapes=supports_dynamic_shapes,
)
else:
assert callable(
capability_validator
), "Argument checking function must be callable"
converter_support = ConverterSupport(
converter_implementation=converter,
capability_validator=capability_validator,
supports_dynamic_shapes=supports_dynamic_shapes,
)

# OpOverloadPackets are only valid if they have a single overload, or
@@ -323,6 +399,15 @@ def __getitem__(

if isinstance(converters, (list, tuple)):
for candidate in converters:
# If the node has dynamic inputs but the converter does not explicitly support them, log a warning.
if (
not candidate.supports_dynamic_shapes
and has_dynamic_shapes(node)
):
logger.warning(
f"The converter for node {node.target} received dynamic shaped inputs although it was designed for static inputs. This shouldn't likely cause issues unless there are some dimensions which are dynamic (excluding the batch). If you encounter any issues, please post at https://github.com/pytorch/TensorRT/issues"
)

if candidate.capability_validator(node):
return (
candidate.converter_implementation,
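To summarize the lookup behavior above with a sketch (the registrations below are hypothetical and not part of this PR): a converter registered with supports_dynamic_shapes=True accepts dynamic-shaped nodes silently, while one left at the default is still returned for such nodes but triggers the warning added above.

# Opted in: dynamic-shaped inputs are accepted without a warning.
@dynamo_tensorrt_converter(torch.ops.aten.tanh.default, supports_dynamic_shapes=True)
def my_tanh_converter(ctx, target, args, kwargs, name):
    ...

# Not opted in (flag defaults to False): the converter is still returned for a
# dynamic-shaped node, but the warning is logged.
@dynamo_tensorrt_converter(torch.ops.aten.sigmoid.default)
def my_sigmoid_converter(ctx, target, args, kwargs, name):
    ...
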
14 changes: 6 additions & 8 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -12,7 +12,6 @@
dynamo_tensorrt_converter,
)
from torch_tensorrt.dynamo.conversion.converter_utils import (
dynamic_unsupported_with_args,
enforce_tensor_types,
is_only_operator_on_placeholder,
)
@@ -358,7 +357,7 @@ def aten_ops_grid(
)


@dynamo_tensorrt_converter(torch.ops.aten.relu.default)
@dynamo_tensorrt_converter(torch.ops.aten.relu.default, supports_dynamic_shapes=True)
def aten_ops_relu(
ctx: ConversionContext,
target: Target,
@@ -645,14 +644,11 @@ def aten_ops_softmax(


@dynamo_tensorrt_converter(
torch.ops.aten.split.Tensor, capability_validator=dynamic_unsupported_with_args([1])
)
@dynamo_tensorrt_converter(
torch.ops.aten.split.sizes, capability_validator=dynamic_unsupported_with_args([1])
torch.ops.aten.split.Tensor,
)
@dynamo_tensorrt_converter(torch.ops.aten.split.sizes)
@dynamo_tensorrt_converter(
torch.ops.aten.split_with_sizes.default,
capability_validator=dynamic_unsupported_with_args([1]),
)
def aten_ops_split(
ctx: ConversionContext,
@@ -2080,7 +2076,9 @@ def conv_param_validator(conv_node: Node) -> bool:


@dynamo_tensorrt_converter(
torch.ops.aten.convolution.default, capability_validator=conv_param_validator
torch.ops.aten.convolution.default,
capability_validator=conv_param_validator,
supports_dynamic_shapes=True,
)
@enforce_tensor_types(
{
57 changes: 0 additions & 57 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -6,7 +6,6 @@
import numpy as np
import tensorrt as trt
import torch
from torch import SymBool, SymFloat, SymInt
from torch.fx.node import Argument, Target
from torch_tensorrt import _enums
from torch_tensorrt.dynamo._SourceIR import SourceIR
@@ -58,62 +57,6 @@ def is_only_operator_on_placeholder(node: torch.fx.Node) -> bool:
)


def dynamic_unsupported(node: torch.fx.Node) -> bool:
"""Validates that a node has no dynamic args, kwargs, or outputs"""
return _dynamic_unsupported(node=node)


def dynamic_unsupported_with_args(
arg_positions_to_check: Optional[List[int]] = None,
) -> Callable[[torch.fx.Node], bool]:
"""Returns a validator that a node has no dynamic args at specific positions"""
return functools.partial(
_dynamic_unsupported, arg_positions_to_check=arg_positions_to_check
)


def _dynamic_unsupported(
node: torch.fx.Node, arg_positions_to_check: Optional[List[int]] = None
) -> bool:
# Validate that none of the inputs to the node have Dynamic shapes
assert isinstance(
node, torch.fx.Node
), "Inputs to validator functions must be FX Nodes"

def _is_subnode_dynamic(subnode: torch.fx.Node) -> bool:
"""Checks if a node itself has Dynamic properties"""
return getattr(
subnode.meta["val"], "_has_symbolic_sizes_strides", False
) or isinstance(subnode.meta["val"], (SymFloat, SymInt, SymBool))

# Check node value itself
if arg_positions_to_check is None and _is_subnode_dynamic(node):
return False

# Check node arguments individually
if arg_positions_to_check is None and any(
_is_subnode_dynamic(arg) for arg in node.args if isinstance(arg, torch.fx.Node)
):
return False
# Check specific arg positions if the caller has specified positions to check
elif arg_positions_to_check is not None and any(
_is_subnode_dynamic(node.args[i])
for i in arg_positions_to_check
if isinstance(node.args[i], torch.fx.Node)
):
return False

# Check node keyword arguments individually
if arg_positions_to_check is None and any(
_is_subnode_dynamic(kwarg)
for kwarg in node.kwargs.values()
if isinstance(kwarg, torch.fx.Node)
):
return False

return True


def cast_trt_tensor(
ctx: ConversionContext,
input_val: TRTTensor,
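
One migration sketch worth noting (my illustration, not taken from the diff): the removed dynamic_unsupported* helpers returned True when a node carried no dynamic shapes, whereas the new has_dynamic_shapes* helpers return True when it does. Code that still needs the old polarity can negate the new validator, for example:

# Returns True only when argument position 1 carries no dynamic shapes,
# mirroring the removed dynamic_unsupported_with_args([1]) behavior.
def static_in_arg1(node: torch.fx.Node) -> bool:
    return not has_dynamic_shapes_in_args([1])(node)

# e.g. @dynamo_tensorrt_converter(torch.ops.aten.split.Tensor,
#                                 capability_validator=static_in_arg1)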