chore: rebase
Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com>
peri044 committed Aug 17, 2023
2 parents 1fbdd0c + 6a69c6a commit 1ff46b6
Showing 15 changed files with 357 additions and 77 deletions.
10 changes: 7 additions & 3 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -27,6 +27,10 @@
] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")


class UnsupportedOperatorException(RuntimeError):
pass


class TRTInterpreterResult(NamedTuple):
engine: Any
input_names: Sequence[str]
@@ -301,7 +305,7 @@ def call_module(
converter = CONVERTERS.get(self._cur_node)

if not converter:
raise RuntimeError(
raise UnsupportedOperatorException(
f"Conversion of module of type {submod_type} not currently supported!"
)

@@ -312,7 +316,7 @@ def call_function(self, target: str, args: Any, kwargs: Any) -> Any:
# TODO: Why is this stateful? We should be able to take in the inputs
converter = CONVERTERS.get(self._cur_node)
if not converter:
raise RuntimeError(
raise UnsupportedOperatorException(
f"Conversion of function {torch.typename(target)} not currently supported!"
)

@@ -324,7 +328,7 @@ def call_method(self, target: str, args: Any, kwargs: Any) -> Any:
converter = CONVERTERS.get(self._cur_node)

if not converter:
raise RuntimeError(
raise UnsupportedOperatorException(
f"Conversion of method {target} not currently supported!"
)

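For context, a minimal caller-side sketch of how the new exception type can be used (illustrative only: the interpreter.run() entry point, the TRTInterpreter class name, and the import path are assumptions drawn from the file above, not part of this diff):

import logging

from torch_tensorrt.dynamo.conversion._TRTInterpreter import (
    TRTInterpreter,  # assumed class name / import path
    UnsupportedOperatorException,
)

logger = logging.getLogger(__name__)

def convert_or_fallback(interpreter: TRTInterpreter):
    # Because UnsupportedOperatorException subclasses RuntimeError, existing
    # handlers keep working, while new code can single out missing-converter
    # failures and fall back (e.g. run the offending submodule in PyTorch).
    try:
        return interpreter.run()  # assumed entry point returning TRTInterpreterResult
    except UnsupportedOperatorException as exc:
        logger.warning("Skipping TRT conversion, unsupported operator: %s", exc)
        return None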
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/__init__.py
@@ -1,4 +1,5 @@
from ._TRTInterpreter import * # noqa: F403
from .aten_ops_converters import * # noqa: F403
from .conversion import * # noqa: F403
from .op_evaluators import * # noqa: F403
from .truncate_long_and_double import repair_long_or_double_inputs
88 changes: 64 additions & 24 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -1,6 +1,7 @@
import logging
from typing import Any, Dict, Optional, Sequence, Tuple, Union

import tensorrt as trt
import torch
from torch.fx.node import Argument, Node, Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
@@ -12,8 +13,6 @@
from torch_tensorrt.fx.converters import acc_ops_converters
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor

import tensorrt as trt

from .converter_registry import dynamo_tensorrt_converter

_LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -76,13 +75,13 @@ def aten_ops_div(
kwargs_new["input"].dtype == trt.int8 or kwargs_new["input"].dtype == trt.int32
):
kwargs_new["input"] = cast_trt_tensor(
network, kwargs_new["input"], trt.float32, name
network, kwargs_new["input"], trt.float32, name, target
)
elif isinstance(args[1], TRTTensor) and (
kwargs_new["other"].dtype == trt.int8 or kwargs_new["other"].dtype == trt.int32
):
kwargs_new["other"] = cast_trt_tensor(
network, kwargs_new["other"], trt.float32, name
network, kwargs_new["other"], trt.float32, name, target
)
rounding_mode = kwargs.get("rounding_mode")
if rounding_mode is None:
@@ -102,22 +101,8 @@ def aten_ops_div(


def embedding_param_validator(embedding_node: Node) -> bool:
max_norm = args_bounds_check(embedding_node.args, 2)
norm_type = args_bounds_check(embedding_node.args, 3)
scale_grad_by_freq = args_bounds_check(embedding_node.args, 4)
sparse = args_bounds_check(embedding_node.args, 5)

if max_norm is not None:
_LOGGER.debug(
f"Currently we don't support specifying max_norm, got {max_norm}."
)
return False

if norm_type is not None and norm_type != 2.0:
_LOGGER.debug(
f"Currently we don't support specifying norm_type, got {norm_type}."
)
return False
scale_grad_by_freq = args_bounds_check(embedding_node.args, 3)
sparse = args_bounds_check(embedding_node.args, 4)

if scale_grad_by_freq is not None:
_LOGGER.debug(
@@ -149,10 +134,9 @@ def aten_ops_embedding(
name,
input=args[1],
weight=args[0],
max_norm=args_bounds_check(args, 2),
norm_type=args_bounds_check(args, 3),
scale_grad_by_freq=args_bounds_check(args, 4),
sparse=args_bounds_check(args, 5),
# args[2] is the padding index, which is useful for training only
scale_grad_by_freq=args_bounds_check(args, 3),
sparse=args_bounds_check(args, 4),
)


@@ -380,3 +364,59 @@ def aten_ops_permute(
args[0],
args[1],
)


def to_copy_dtype_validator(to_copy_node: Node) -> bool:
allowed_casts = {torch.float, torch.int32, torch.bool, torch.int8, torch.float16}

# Validate input node has convertible kwargs
if "dtype" in to_copy_node.kwargs:
if to_copy_node.kwargs["dtype"] in allowed_casts:
return True
else:
_LOGGER.debug(
f"_to_copy converter rejected node {to_copy_node} with dtype {to_copy_node.kwargs['dtype']}"
)
return False
else:
_LOGGER.debug(
f"_to_copy converter rejected node {to_copy_node} with kwargs {to_copy_node.kwargs}"
)
return False


@dynamo_tensorrt_converter(
torch.ops.aten._to_copy.default, capability_validator=to_copy_dtype_validator
)
def aten_ops_to_copy_dtype(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.cast.to_copy(
network,
target,
SourceIR.ATEN,
name,
args[0],
kwargs["dtype"],
)


@dynamo_tensorrt_converter(torch.ops.aten.clone.default)
def aten_ops_clone(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.cast.clone(
network,
target,
SourceIR.ATEN,
name,
args[0],
)
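For reference, the new argument indices in aten_ops_embedding line up with the ATen schema embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse); max_norm and norm_type are handled by the functional wrapper (via embedding_renorm_) rather than the aten op, so the converter no longer expects them. A minimal sketch of the args_bounds_check semantics the converter relies on (an illustrative reimplementation, not the project's definition):

from typing import Any, Optional, Sequence

def args_bounds_check(
    args: Sequence[Any], i: int, replacement: Optional[Any] = None
) -> Any:
    # Return args[i] when the traced call supplied it, otherwise a default, so
    # trailing optional schema arguments (padding_idx, scale_grad_by_freq,
    # sparse) can be omitted by the frontend without breaking the converter.
    return args[i] if len(args) > i else replacement

# With the new indexing: args[0] = weight, args[1] = indices,
# args[2] = padding_idx (training-only, ignored), args[3] = scale_grad_by_freq,
# args[4] = sparse.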
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/conversion/converter_registry.py
@@ -66,7 +66,7 @@ def dynamo_tensorrt_converter(
enabled: bool = True,
capability_validator: Optional[Callable[[Node], bool]] = None,
priority: ConverterPriority = ConverterPriority.STANDARD,
) -> Callable[[Any], Any]:
) -> Callable[[Any], Union[TRTTensor, Sequence[TRTTensor]]]:
"""Decorator for Dynamo TensorRT Converter
Registers the decorated function in the DYNAMO_ATEN_CONVERTERS registry
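A minimal usage sketch of the decorator, in the same style as the converters registered in aten_ops_converters.py above (the relu target and the impl.activation.relu helper signature are assumptions for illustration; imports are as in that file):

@dynamo_tensorrt_converter(torch.ops.aten.relu.default)  # no validator: always deemed capable
def aten_ops_relu(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    # The decorator registers this function in the DYNAMO_ATEN_CONVERTERS
    # registry for the given target; the body just delegates to an impl helper.
    return impl.activation.relu(network, target, SourceIR.ATEN, name, args[0])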
18 changes: 15 additions & 3 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -1,15 +1,19 @@
import logging
import re
from typing import List
from typing import List, Optional

import tensorrt as trt
import torch
from torch.fx.node import Target
from torch_tensorrt.fx.converters.converter_utils import (
Frameworks,
unified_dtype_converter,
)
from torch_tensorrt.fx.types import TRTDataType, TRTNetwork, TRTTensor

from .._SourceIR import SourceIR
from .converter_registry import ConverterRegistry

_LOGGER: logging.Logger = logging.getLogger(__name__)


@@ -71,24 +75,32 @@ def cast_trt_tensor(
input_val: TRTTensor,
dtype: TRTDataType,
name: str,
target: Target = "",
source_ir: Optional[SourceIR] = None,
) -> TRTTensor:
"""
Given a TRT Tensor, convert that Tensor to the specified dtype
Adds an Identity layer to the network which performs the conversion
Args:
network (TRTNetwork): A TensorRT network
input_val (TRTTensor): A TRT Tensor to cast to a new data type
dtype (TRTDataType): The TRTDataType to cast the input Tensor to
dtype (TRTDataType, torch.dtype, np.dtype): The data type to cast the input Tensor to
name (str): Name of the calling layer
target (Target): Target of calling node
source_ir (SourceIR): SourceIR of calling converter
Returns:
A TensorRT ITensor which has been casted to the specified dtype
"""
trt_dtype = unified_dtype_converter(dtype, Frameworks.TRT)

if input_val.dtype != trt_dtype:
source_ir = source_ir if source_ir is not None else SourceIR.UNKNOWN
target_str = ConverterRegistry.qualified_name_or_str(target)
target_name = f"{source_ir}_ops{('.' + target_str) if target_str else ''}"

identity_layer = network.add_identity(input_val)
identity_layer.set_output_type(0, trt_dtype)
identity_layer.name = f"Cast ITensor {input_val.name} from {input_val.dtype} to {trt_dtype} - {name}"
identity_layer.name = f"Cast ITensor {input_val.name} from {input_val.dtype} to {trt_dtype} - [{target_name}]-[{name}]"
return identity_layer.get_output(0)
else:
return input_val
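For reference, the dtype normalization step implied by the updated docstring (an illustrative check; unified_dtype_converter is assumed to accept torch, numpy, or TRT dtypes and return the TRT enum, as the docstring now states):

import tensorrt as trt
import torch
from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter

# Both a torch dtype and an already-TRT dtype normalize to the TRT enum,
# so callers of cast_trt_tensor can pass either form.
assert unified_dtype_converter(torch.float32, Frameworks.TRT) == trt.float32
assert unified_dtype_converter(trt.float32, Frameworks.TRT) == trt.float32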
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/__init__.py
@@ -2,6 +2,7 @@

from . import (
activation,
cast,
condition,
elementwise,
embedding,
43 changes: 43 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/cast.py
@@ -0,0 +1,43 @@
import logging
from typing import Optional

from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor
from torch_tensorrt.fx.types import TRTDataType, TRTNetwork, TRTTensor

LOGGER: logging.Logger = logging.getLogger(__name__)


def to_copy(
network: TRTNetwork,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
dtype: TRTDataType,
) -> TRTTensor:
if not isinstance(input, TRTTensor):
raise RuntimeError(
f"to_copy received input {input} that is not a TensorRT ITensor"
)

casted_tensor = cast_trt_tensor(network, input, dtype, name, target, source_ir)
return casted_tensor


def clone(
network: TRTNetwork,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
) -> TRTTensor:
if not isinstance(input, TRTTensor):
raise RuntimeError(
f"clone received input {input} that is not a TensorRT ITensor"
)

LOGGER.debug(f"Evaluating clone on object with name: {name}")

return input
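For context, a hypothetical module whose dynamo-traced graph would exercise these two helpers (the compile call in the comment is an assumption about the dynamo path, not taken from this commit):

import torch

class CastClone(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Tensor.to(dtype) lowers to aten._to_copy.default (handled by to_copy,
        # which inserts a casting Identity layer); Tensor.clone() lowers to
        # aten.clone.default (handled by clone, a pass-through for TRT graphs).
        return x.to(torch.float16).clone()

# e.g. torch_tensorrt.compile(CastClone().eval().cuda(), ir="dynamo",
#                             inputs=[torch.randn(8, 8).cuda()])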
11 changes: 7 additions & 4 deletions py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py
@@ -2,6 +2,7 @@
import warnings
from typing import Any, Callable, Optional, Union

import tensorrt as trt
import torch
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
@@ -15,8 +16,6 @@
from torch_tensorrt.fx.types import TRTElementWiseOp, TRTNetwork, TRTTensor
from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter

import tensorrt as trt


def get_python_op_from_trt_elementwise_op(
trt_op: TRTElementWiseOp,
@@ -132,9 +131,13 @@ def convert_binary_elementwise(
trt_promoted_type = unified_dtype_converter(promoted_type, Frameworks.TRT)

if trt_promoted_type != lhs_val.dtype:
lhs_val = cast_trt_tensor(network, lhs_val, trt_promoted_type, name)
lhs_val = cast_trt_tensor(
network, lhs_val, trt_promoted_type, name, target, source_ir
)
if trt_promoted_type != rhs_val.dtype:
rhs_val = cast_trt_tensor(network, rhs_val, trt_promoted_type, name)
rhs_val = cast_trt_tensor(
network, rhs_val, trt_promoted_type, name, target, source_ir
)

# Check the limitation in the doc string.
if network.has_implicit_batch_dimension:
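For context, the promotion-then-cast flow these updated calls participate in (a rough illustrative sketch; the actual promotion logic lives earlier in convert_binary_elementwise):

import torch

# Operand dtypes are promoted first; any side that does not already match the
# promoted type is cast via cast_trt_tensor, now with target/source_ir so the
# inserted Identity layers carry the originating op in their names.
promoted = torch.promote_types(torch.int32, torch.float32)
assert promoted == torch.float32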
19 changes: 1 addition & 18 deletions py/torch_tensorrt/dynamo/conversion/impl/embedding.py
@@ -14,16 +14,9 @@ def embedding(
name: str,
input: TRTTensor,
weight: TRTTensor,
max_norm: None,
norm_type: None,
scale_grad_by_freq: bool,
sparse: bool,
) -> TRTTensor:
if network.has_implicit_batch_dimension:
raise RuntimeError(
"The `embedding` function should be called with explicit batch dimension."
)

indices_tensor = input
embedding_tensor = weight
if isinstance(indices_tensor, torch.Tensor) and indices_tensor.dtype == torch.int64:
@@ -37,16 +30,6 @@
# unsupported parameters
# ignore padding_idx since it is meaningful for training only

if max_norm is not None:
raise RuntimeError(
f"Currently we don't support specifying max_norm, got {max_norm}."
)

if norm_type is not None and norm_type != 2.0:
raise RuntimeError(
f"Currently we don't support specifying max_norm, got {norm_type} for norm_type."
)

if scale_grad_by_freq:
raise RuntimeError(
"Currently we don't support scale gradient by word frequency."
@@ -57,5 +40,5 @@

# Implement embedding lookup with gather layer
gather_layer = network.add_gather(embedding_tensor, indices_tensor, axis=0)
set_layer_name(gather_layer, target, name + "_gather")
set_layer_name(gather_layer, target, name + "_gather", source_ir)
return gather_layer.get_output(0)
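For reference, the PyTorch-level equivalence the gather layer implements (a quick illustrative check, not part of the converter):

import torch

weight = torch.randn(10, 4)                  # embedding table: 10 rows of dim 4
indices = torch.tensor([[1, 3], [5, 7]])     # int64 indices (cast to int32 above)
lookup = torch.nn.functional.embedding(indices, weight)
assert torch.equal(lookup, weight[indices])  # gather along axis 0, as in add_gather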
32 changes: 32 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/op_evaluators.py
@@ -0,0 +1,32 @@
import logging
import operator
from typing import Dict, Sequence, Tuple, Union

from torch.fx.node import Argument, Node, Target
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor

from .converter_registry import ConverterRegistry, dynamo_tensorrt_converter

_LOGGER: logging.Logger = logging.getLogger(__name__)


def getitem_validator(getitem_node: Node) -> bool:
from torch_tensorrt.dynamo.conversion.converter_registry import DYNAMO_CONVERTERS

# Getitem nodes can only be converted if their parent node also can
return getitem_node.args[0] in DYNAMO_CONVERTERS


# TODO: Subsequent evaluators should be registered here with their own validators
@dynamo_tensorrt_converter(operator.getitem, capability_validator=getitem_validator)
def generic_evaluator(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
_LOGGER.debug(
f"Evaluating {ConverterRegistry.qualified_name_or_str(target)} on object with name: {name}"
)
return target(*args)
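For context, an illustrative graph pattern this evaluator targets (the max_pool example is hypothetical; any multi-output op followed by operator.getitem fits). getitem_validator only admits the node when its producer (args[0]) is itself convertible, so unsupported producers and their getitems stay together in the PyTorch partition.

import torch

def f(x: torch.Tensor) -> torch.Tensor:
    # max_pool2d with return_indices=True produces a tuple-valued node; the
    # traced graph selects element 0 via operator.getitem, which
    # generic_evaluator simply executes (target(*args)) rather than building
    # TRT layers for it.
    out, _indices = torch.nn.functional.max_pool2d(x, 2, return_indices=True)
    return out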
