
Gather Implementation #2457

Closed · wants to merge 11 commits into from

Conversation

@apbose (Collaborator) commented Nov 13, 2023

Fixes #2202

@github-actions bot added labels on Nov 13, 2023: component: api [Python], component: conversion, component: converters, component: dynamo
@apbose mentioned this pull request Nov 13, 2023
@github-actions bot commented: Code conforms to C++ style guidelines

@github-actions bot commented: There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py	2023-11-13 20:18:38.564805+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py	2023-11-13 20:20:27.102418+00:00
@@ -202,11 +202,11 @@
        SourceIR.ATEN,
        name,
        input=args[0],
        dim=args[1],
        index=args[2],
-        sparse_grad = args_bounds_check(args, 4, False),
+        sparse_grad=args_bounds_check(args, 4, False),
    )


@dynamo_tensorrt_converter(torch.ops.aten.group_norm.default)
@dynamo_tensorrt_converter(torch.ops.aten.group_norm)

@apbose changed the title from "Expose IGridSampleLayer" to "Gather Implementation" on Nov 13, 2023
@github-actions bot commented: Code conforms to Python style guidelines

@github-actions bot commented: Code conforms to C++ style guidelines

@github-actions bot commented: Code conforms to Python style guidelines

@gs-olive (Collaborator) left a review: Added a few comments, and just needs some test cases for the new gather converter.

input=args[0],
dim=args[1],
index=args[2],
sparse_grad=args_bounds_check(args, 4, False),
@gs-olive (Collaborator): The impl.select.gather function cannot take this argument as input, and since it seems to not be used, it can be omitted here entirely.
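A minimal sketch of that suggestion, assuming the converter call otherwise stays as in the diff above (with ctx and target preceding SourceIR.ATEN, as in other converters):

# Hedged sketch: the same converter body with the unused sparse_grad
# argument dropped, since impl.select.gather does not accept it.
return impl.select.gather(
    ctx,
    target,
    SourceIR.ATEN,
    name,
    input=args[0],
    dim=args[1],
    index=args[2],
)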

{
0: (TRTTensor,),
}
) # type: ignore[misc]
@gs-olive (Collaborator): The # type: ignore can be removed.

@@ -68,13 +68,26 @@ def select(
indices_tensor = ctx.net.add_constant(
index_value.shape, to_numpy(index_value)
).get_output(0)
layer = ctx.net.add_gather(input, indices_tensor, dim)
out = layer.get_output(0)
out = gather(input, indices_tensor, dim)
if len(out.shape) != 1:
layer = ctx.net.add_shuffle(out)
return layer.get_output(0)
@gs-olive (Collaborator): This line (return layer.get_output(0)) will fail, since layer is no longer defined. It should now return out, and line 73 above it would need to be changed to something like out = layer.get_output(0).
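A sketch of that repair, assuming the new gather helper takes (ctx, target, source_ir, name, input, dim, index) as at its other call sites in this PR:

# Hedged sketch: keep `out` as the running result so the final return
# no longer references an undefined `layer`.
out = gather(ctx, target, source_ir, name, input, dim, indices_tensor)
if len(out.shape) != 1:
    layer = ctx.net.add_shuffle(out)
    out = layer.get_output(0)
return out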

gather_layer = ctx.net.add_gather(input, indices_tensor, index)
set_layer_name(gather_layer, target, name + "_index_gather", source_ir)
return gather_layer.get_output(0)
return gather(input, index, indices_tensor)
@gs-olive (Collaborator): This call is missing inputs (ctx, target, etc.)
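For example (hedged; the argument order follows the corrected call that appears later in this conversation):

# Sketch: thread the conversion context through to the helper.
return gather(ctx, target, source_ir, name, input, index, indices_tensor)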

gather_layer_element, target, name + "_index_gather_element", source_ir
)
gather_out = gather_layer_element.get_output(0)
gather_out = gather(flatten_tensor, cum_adv_index, 0)
@gs-olive (Collaborator): This call is missing inputs (ctx, target, etc.)

@@ -68,13 +68,26 @@ def select(
indices_tensor = ctx.net.add_constant(
index_value.shape, to_numpy(index_value)
).get_output(0)
layer = ctx.net.add_gather(input, indices_tensor, dim)
out = layer.get_output(0)
out = gather(input, indices_tensor, dim)
@gs-olive (Collaborator): This call is missing inputs (ctx, target, etc.)

@apbose marked this pull request as draft on November 17, 2023
@github-actions bot commented: Code conforms to C++ style guidelines

@github-actions bot commented: There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/select.py	2023-11-28 09:07:07.522513+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/select.py	2023-11-28 09:08:52.705467+00:00
@@ -251,11 +251,13 @@
                    trt.ElementWiseOperation.PROD,
                    multiplier,
                    dim_tensor_list[adv_indx_indices[i]],
                )

-        gather_out = gather(ctx, target, source_ir, name, flatten_tensor, 0, cum_adv_index)
+        gather_out = gather(
+            ctx, target, source_ir, name, flatten_tensor, 0, cum_adv_index
+        )
        _LOGGER.debug(f"The shape after cumultative gather is {gather_out.shape}")
        _LOGGER.debug(f"The shape for cumulative adv index is {cum_adv_index}")

        cum_adv_index_shape_layer = ctx.net.add_shape(cum_adv_index)
        set_layer_name(

@github-actions bot added the component: tests label on Nov 28, 2023
@github-actions bot commented: Code conforms to C++ style guidelines

@github-actions bot commented: There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_gather_aten.py	2023-11-28 22:03:18.634373+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_gather_aten.py	2023-11-28 22:05:04.769261+00:00
@@ -24,7 +24,8 @@
        self.run_test(
            TestModule(),
            input,
        )

+
if __name__ == "__main__":
-    run_tests()
\ No newline at end of file
+    run_tests()

@apbose marked this pull request as ready for review on November 28, 2023
@github-actions bot commented: Code conforms to C++ style guidelines

@github-actions bot commented: There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_gather_aten.py	2023-11-28 22:43:50.947029+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_gather_aten.py	2023-11-28 22:45:39.549529+00:00
@@ -24,7 +24,8 @@
        self.run_test(
            TestModule(),
            input,
        )

+
if __name__ == "__main__":
-    run_tests()
\ No newline at end of file
+    run_tests()

@github-actions bot commented: Code conforms to C++ style guidelines

@github-actions bot commented: There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_gather_aten.py	2023-12-12 22:13:26.437196+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_gather_aten.py	2023-12-12 22:15:21.040957+00:00
@@ -24,7 +24,8 @@
        self.run_test(
            TestModule(),
            input,
        )

+
if __name__ == "__main__":
-    run_tests()
\ No newline at end of file
+    run_tests()

@github-actions bot commented: Code conforms to C++ style guidelines
@github-actions bot commented: Code conforms to Python style guidelines

@gs-olive (Collaborator) commented Jan 5, 2024

The Dynamo paths both use a truncation mechanism for int64 inputs which takes effect prior to the TRTInterpreter to avoid that issue. Does Int32 not work for the test cases?

@apbose (Collaborator, Author) commented Jan 8, 2024

@gs-olive int32 in the test case leads to this error: RuntimeError: gather(): Expected dtype int64 for index

@gs-olive (Collaborator) commented Jan 8, 2024

Ok, got it. Could the test then mimic the scheme used here:

class Test64BitInput(TestCase):

This uses torch.compile directly to allow the 64-bit repair code to run.
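A minimal sketch of that scheme, assuming a standard torch_tensorrt install (which registers the "torch_tensorrt" backend for torch.compile) and reusing the hypothetical TestModule from the test under discussion:

import torch
import torch._dynamo
import torch_tensorrt  # noqa: F401  (registers the "torch_tensorrt" backend)

# Mirrors the test case: gather indices must be int64 for PyTorch.
model = TestModule().cuda()
inputs = [
    torch.randn(2, 2).cuda(),
    torch.randint(0, 2, (1, 1), dtype=torch.int64).cuda(),
]

# Going through torch.compile lets the 64-bit repair/truncation lowering
# run before the TRTInterpreter sees the graph.
optimized_model = torch.compile(
    model,
    backend="torch_tensorrt",
    options={"min_block_size": 1, "truncate_long_and_double": True},
)
out = optimized_model(*inputs)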

@apbose (Collaborator, Author) commented Jan 12, 2024

Hi @gs-olive
In the test case below:

import torch
import torch._dynamo
import torch.nn as nn
import torch_tensorrt
from harness import DispatchTestCase  # repo test harness (tests/py/dynamo/conversion)


class TestGatherConverter(DispatchTestCase):
    def test_gather_zero_two_dim(self):
        class TestModule(nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, indices):
                out = torch.ops.aten.gather.default(x, 0, indices)
                return out

        index0 = torch.randint(0, 1, (1, 1), dtype=torch.int64)
        inputs = [torch.randn(2, 2), index0]
        fx_graph = torch.fx.symbolic_trace(TestModule())
        torch._dynamo.reset()

        optimized_model = torch_tensorrt.compile(
            fx_graph,
            "torch_compile",
            inputs,
            min_block_size=1,
            pass_through_build_failures=True,
            truncate_long_and_double=True,
            debug=True,
        )
        optimized_model_results = optimized_model(*inputs).detach().cpu()
        torch_model_results = fx_graph(*inputs).detach().cpu()
        max_diff = float(
            torch.max(torch.abs(optimized_model_results - torch_model_results))
        )
        self.assertAlmostEqual(
            max_diff,
            0,
            4,
            "TRT outputs don't match with the original model.",
        )

The two cases:

  1. If truncate_long_and_double is kept as True, the error is gather(): Expected dtype int64 for index, but got torch.int32. That is because truncation converts the int64 index to int32 before the input reaches optimized_model_results = optimized_model(*inputs).detach().cpu().
  2. I don't think truncate_long_and_double=False would work either, since TensorRT would not accept an int64 input.

Am I missing something in the workaround? Thanks.

@gs-olive (Collaborator):

2 test cases [ITensor + constant input types]:

  • Indices as input [differing semantics between PyTorch and TensorRT: PyTorch gets int64, TRT gets int32]
    • Use a known operator to assemble the indices (e.g. torch.arange, or another constant-producing layer); see the sketch at the end of this comment
  • Inlined constant indices test case [int64 input type]
"""
graph(x, indices):
    --> inserted cast int64 to int32
    # If running in PyTorch cast back to int64
    out = aten.gather(x, 0, indices)

    return out
"""

For a Later PR/Issue:

  • Avoid running PyTorch graphs with invalid casts
  • Refactor repair_trunca... to consume output of type inference
  • Reorder truncation to come as late as possible in the compilation process
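A hedged sketch of the first case, with the indices assembled in-graph by a constant-producing op (the module name here is illustrative, not from this PR):

import torch

class GatherArangeIndices(torch.nn.Module):
    def forward(self, x):
        # torch.arange produces int64 by default, which torch.gather
        # requires; TRT would instead see the int32-cast constant.
        indices = (torch.arange(x.shape[1]) % x.shape[0]).unsqueeze(0)
        return torch.ops.aten.gather.default(x, 0, indices)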

@github-actions bot commented: There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/examples/int8/training/vgg16/vgg16.py	2024-02-20 20:41:10.621581+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/int8/training/vgg16/vgg16.py	2024-02-20 20:42:59.737772+00:00
@@ -1,10 +1,11 @@
"""
# Reference
- [Very Deep Convolutional Networks for Large-Scale Image Recognition](
    https://arxiv.org/abs/1409.1556) (ICLR 2015)
"""
+
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import reduce

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Device.py	2024-02-20 20:41:10.625581+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Device.py	2024-02-20 20:42:59.835060+00:00
@@ -30,16 +30,18 @@
        gpu_id (int): Device ID for target GPU
        dla_core (int): Core ID for target DLA core
        allow_gpu_fallback (bool): Whether falling back to GPU if DLA cannot support an op should be allowed
    """

-    device_type: Optional[
-        trt.DeviceType
-    ] = None  #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
+    device_type: Optional[trt.DeviceType] = (
+        None  #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
+    )
    gpu_id: int = -1  #: Device ID for target GPU
    dla_core: int = -1  #: Core ID for target DLA core
-    allow_gpu_fallback: bool = False  #: Whether falling back to GPU if DLA cannot support an op should be allowed
+    allow_gpu_fallback: bool = (
+        False  #: Whether falling back to GPU if DLA cannot support an op should be allowed
+    )

    def __init__(self, *args: Any, **kwargs: Any):
        """__init__ Method for torch_tensorrt.Device

        Device accepts one of a few construction patterns
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Input.py	2024-02-20 20:41:10.625581+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Input.py	2024-02-20 20:43:00.050910+00:00
@@ -26,16 +26,16 @@

    class _ShapeMode(Enum):
        STATIC = 0
        DYNAMIC = 1

-    shape_mode: Optional[
-        _ShapeMode
-    ] = None  #: Is input statically or dynamically shaped
-    shape: Optional[
-        Tuple[int, ...] | Dict[str, Tuple[int, ...]]
-    ] = None  #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
+    shape_mode: Optional[_ShapeMode] = (
+        None  #: Is input statically or dynamically shaped
+    )
+    shape: Optional[Tuple[int, ...] | Dict[str, Tuple[int, ...]]] = (
+        None  #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
+    )
    dtype: _enums.dtype = (
        _enums.dtype.unknown
    )  #: The expected data type of the input tensor (default: torch_tensorrt.dtype.float32)
    _explicit_set_dtype: bool = False
    format: _enums.TensorFormat = (
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_compiler.py	2024-02-20 20:41:10.629581+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_compiler.py	2024-02-20 20:43:00.091593+00:00
@@ -208,13 +208,13 @@
        "precision": precision,
        "debug": debug,
        "device": device,
        "workspace_size": workspace_size,
        "min_block_size": min_block_size,
-        "torch_executed_ops": torch_executed_ops
-        if torch_executed_ops is not None
-        else [],
+        "torch_executed_ops": (
+            torch_executed_ops if torch_executed_ops is not None else []
+        ),
        "pass_through_build_failures": pass_through_build_failures,
        "max_aux_streams": max_aux_streams,
        "version_compatible": version_compatible,
        "optimization_level": optimization_level,
        "use_python_runtime": use_python_runtime,
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py	2024-02-20 20:41:10.629581+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py	2024-02-20 20:43:00.291320+00:00
@@ -26,13 +26,13 @@

from packaging import version

_LOGGER: logging.Logger = logging.getLogger(__name__)

-TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[
-    Callable[[torch.fx.GraphModule], None]
-] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[Callable[[torch.fx.GraphModule], None]] = (
+    Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+)


class UnsupportedOperatorException(RuntimeError):
    pass

@@ -90,13 +90,13 @@
        self.input_specs_iter = 0
        self._cur_node_name: Optional[str] = None
        self._cur_node: Optional[torch.fx.Node] = None
        self._input_names: List[str] = []
        self._output_names: List[str] = []
-        self._itensor_to_tensor_meta: Dict[
-            trt.tensorrt.ITensor, TensorMetadata
-        ] = dict()
+        self._itensor_to_tensor_meta: Dict[trt.tensorrt.ITensor, TensorMetadata] = (
+            dict()
+        )
        self.compilation_settings = compilation_settings

        # Data types for TRT Module output Tensors
        self.output_dtypes = output_dtypes

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py	2024-02-20 20:41:10.629581+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py	2024-02-20 20:43:00.369509+00:00
@@ -322,17 +322,15 @@
    else:
        raise AssertionError(f"Cannot convert {input_val} to TRT constant")


@overload
-def get_positive_dim(dim: int, dim_size: int) -> int:
-    ...
+def get_positive_dim(dim: int, dim_size: int) -> int: ...


@overload
-def get_positive_dim(dim: Sequence[int], dim_size: int) -> Tuple[int, ...]:
-    ...
+def get_positive_dim(dim: Sequence[int], dim_size: int) -> Tuple[int, ...]: ...


def get_positive_dim(
    dim: Union[int, Sequence[int]], dim_size: int
) -> Union[int, Tuple[int, ...]]:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py	2024-02-20 20:41:10.629581+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py	2024-02-20 20:43:00.709296+00:00
@@ -20,16 +20,14 @@
        logger.debug(f"Graph after lowering linear:\n{gm.graph}")

    return gm


-def linear_replacement() -> (
-    Tuple[
-        torch.fx.GraphModule,
-        Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
-    ]
-):
+def linear_replacement() -> Tuple[
+    torch.fx.GraphModule,
+    Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
+]:
    """Constructs the original and replacement functions for linear"""

    # Original graph
    def orig(
        input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py	2024-02-20 20:41:10.629581+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py	2024-02-20 20:43:00.711596+00:00
@@ -5,13 +5,13 @@
from torch._decomp import get_decompositions as get_torch_decompositions
from torch._ops import OpOverload, OpOverloadPacket

aten = torch.ops.aten

-_core_aten_decompositions: Dict[
-    OpOverload, Callable[[Any], Any]
-] = core_aten_decompositions()
+_core_aten_decompositions: Dict[OpOverload, Callable[[Any], Any]] = (
+    core_aten_decompositions()
+)
torch_enabled_decompositions: Set[Union[OpOverload, OpOverloadPacket]] = {
    aten._adaptive_avg_pool2d_backward,
    aten.addcdiv,
    aten.addcdiv_,
    aten.addcmul,
@@ -178,13 +178,13 @@
torch_disabled_decompositions: Set[Union[OpOverload, OpOverloadPacket]] = {
    aten._softmax.default,
}


-ENABLED_TORCH_DECOMPOSITIONS: Dict[
-    OpOverload, Callable[[Any], Any]
-] = get_torch_decompositions(torch_enabled_decompositions)
+ENABLED_TORCH_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = (
+    get_torch_decompositions(torch_enabled_decompositions)
+)
TORCH_TRT_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = {}


def check_decomp_set_invariants() -> None:
    """Validates no overlap between enabled and disabled decomposition sets"""
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_efficient_attention.py	2024-02-20 20:41:10.629581+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_efficient_attention.py	2024-02-20 20:43:00.720628+00:00
@@ -25,16 +25,14 @@
        )

    return gm


-def efficient_attention_replacement() -> (
-    Tuple[
-        torch.fx.GraphModule,
-        Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
-    ]
-):
+def efficient_attention_replacement() -> Tuple[
+    torch.fx.GraphModule,
+    Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
+]:
    """Constructs the original and replacement functions for efficient attention"""

    # Original graph
    def orig(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default(
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py	2024-02-20 20:41:10.633581+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py	2024-02-20 20:43:00.746806+00:00
@@ -20,16 +20,14 @@
        logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")

    return gm


-def view_replacement() -> (
-    Tuple[
-        torch.fx.GraphModule,
-        Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
-    ]
-):
+def view_replacement() -> Tuple[
+    torch.fx.GraphModule,
+    Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
+]:
    """Constructs the original and replacement functions for view"""

    # Original graph
    def orig(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
        return torch.ops.aten.view.default(input, shape)
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py	2024-02-20 20:41:10.633581+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py	2024-02-20 20:43:00.858775+00:00
@@ -215,13 +215,13 @@
        require_full_compilation: Whether to require that all operators be run in TRT
    Returns:
        torch.fx.GraphModule, TorchTensorRTOperatorSupport
    """
    supported_ops = TorchTensorRTOperatorSupport(
-        torch_executed_ops=torch_executed_ops
-        if torch_executed_ops is not None
-        else set()
+        torch_executed_ops=(
+            torch_executed_ops if torch_executed_ops is not None else set()
+        )
    )
    partitioner = TRTPartitioner(
        gm,
        supported_ops,
        min_block_size=min_block_size,
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py	2024-02-20 20:41:10.633581+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py	2024-02-20 20:43:00.970857+00:00
@@ -99,25 +99,29 @@
                self.engine.get_binding_dtype(idx), Frameworks.TORCH
            )
            for idx in self.output_binding_indices_in_order
        ]
        self.output_shapes = [
-            tuple(self.engine.get_binding_shape(idx))
-            if self.engine.has_implicit_batch_dimension
-            else tuple()
+            (
+                tuple(self.engine.get_binding_shape(idx))
+                if self.engine.has_implicit_batch_dimension
+                else tuple()
+            )
            for idx in self.output_binding_indices_in_order
        ]
        self.hidden_output_dtypes = [
            unified_dtype_converter(
                self.engine.get_binding_dtype(idx), Frameworks.TORCH
            )
            for idx in self.hidden_output_binding_indices_in_order
        ]
        self.hidden_output_shapes = [
-            tuple(self.engine.get_binding_shape(idx))
-            if self.engine.has_implicit_batch_dimension
-            else tuple()
+            (
+                tuple(self.engine.get_binding_shape(idx))
+                if self.engine.has_implicit_batch_dimension
+                else tuple()
+            )
            for idx in self.hidden_output_binding_indices_in_order
        ]

    def _check_initialized(self) -> None:
        if not self.initialized:
@@ -165,13 +169,15 @@
        self.__dict__.update(state)
        if self.engine:
            self.context = self.engine.create_execution_context()

    def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
-        with torch.autograd.profiler.record_function(
-            "PythonTorchTensorRTModule:Forward"
-        ) if self.profiling_enabled else nullcontext():
+        with (
+            torch.autograd.profiler.record_function("PythonTorchTensorRTModule:Forward")
+            if self.profiling_enabled
+            else nullcontext()
+        ):
            self._check_initialized()

            # If in safe mode, check at each iteration for for whether a switch is required
            if (
                torch_tensorrt.runtime.multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE
@@ -198,13 +204,17 @@
                    torch.cuda.set_device(device_id)

                    inputs = tuple([tensor.to(device) for tensor in inputs])
                    logger.warning(f"Moved all input Tensors to cuda:{device_id}")

-            with torch.autograd.profiler.record_function(
-                "PythonTorchTensorRTModule:ProcessInputs"
-            ) if self.profiling_enabled else nullcontext():
+            with (
+                torch.autograd.profiler.record_function(
+                    "PythonTorchTensorRTModule:ProcessInputs"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
                assert len(inputs) == len(
                    self.input_names
                ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}."

                contiguous_inputs: List[torch.Tensor] = [i.contiguous() for i in inputs]
@@ -237,13 +247,17 @@

                    self.context.set_binding_shape(
                        idx, tuple(contiguous_inputs[i].shape)
                    )

-            with torch.autograd.profiler.record_function(
-                "PythonTorchTensorRTModule:ProcessOutputs"
-            ) if self.profiling_enabled else nullcontext():
+            with (
+                torch.autograd.profiler.record_function(
+                    "PythonTorchTensorRTModule:ProcessOutputs"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
                # create output tensors
                outputs: List[torch.Tensor] = []

                for i, idx in enumerate(self.output_binding_indices_in_order):
                    shape = tuple(self.context.get_binding_shape(idx))
@@ -264,13 +278,17 @@
                        dtype=self.hidden_output_dtypes[i],
                        device=torch.cuda.current_device(),
                    )
                    bindings[idx] = output.data_ptr()

-            with torch.autograd.profiler.record_function(
-                "PythonTorchTensorRTModule:TensorRTRuntime"
-            ) if self.profiling_enabled else nullcontext():
+            with (
+                torch.autograd.profiler.record_function(
+                    "PythonTorchTensorRTModule:TensorRTRuntime"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
                self.context.execute_async_v2(
                    bindings, torch.cuda.current_stream().cuda_stream
                )

            if len(outputs) == 1:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/converters/aten_ops_converters.py	2024-02-20 20:41:10.633581+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/converters/aten_ops_converters.py	2024-02-20 20:43:01.339768+00:00
@@ -315,25 +315,21 @@
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    kwargs_new = {
        "input": args[0],
        "kernel_size": args[1],
-        "stride": args[2]
-        if len(args) > 2
-        else (None, None)
-        if len(args[1]) == 2
-        else (None, None, None),
-        "padding": args[3]
-        if len(args) > 3
-        else (0, 0)
-        if len(args[1]) == 2
-        else (0, 0, 0),
-        "dilation": args[4]
-        if len(args) > 4
-        else (1, 1)
-        if len(args[1]) == 2
-        else (1, 1, 1),
+        "stride": (
+            args[2]
+            if len(args) > 2
+            else (None, None) if len(args[1]) == 2 else (None, None, None)
+        ),
+        "padding": (
+            args[3] if len(args) > 3 else (0, 0) if len(args[1]) == 2 else (0, 0, 0)
+        ),
+        "dilation": (
+            args[4] if len(args) > 4 else (1, 1) if len(args[1]) == 2 else (1, 1, 1)
+        ),
        "ceil_mode": args[5] if len(args) > 5 else False,
    }
    return acc_ops_converters.acc_ops_max_poolnd(
        network, target, None, kwargs_new, name
    )
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/lower.py	2024-02-20 20:41:10.633581+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/lower.py	2024-02-20 20:43:01.398524+00:00
@@ -124,25 +124,29 @@
        interpreter = TRTInterpreter(
            mod,
            input_specs=self.lower_setting.input_specs,
            explicit_batch_dimension=self.lower_setting.explicit_batch_dimension,
            explicit_precision=self.lower_setting.explicit_precision,
-            logger_level=trt.Logger.VERBOSE
-            if self.lower_setting.verbose_log
-            else trt.Logger.WARNING,
+            logger_level=(
+                trt.Logger.VERBOSE
+                if self.lower_setting.verbose_log
+                else trt.Logger.WARNING
+            ),
        )

        interp_result: TRTInterpreterResult = interpreter.run(
            max_batch_size=self.lower_setting.max_batch_size,
            max_workspace_size=self.lower_setting.max_workspace_size,
            lower_precision=self.lower_setting.lower_precision,
            strict_type_constraints=self.lower_setting.strict_type_constraints,
            algorithm_selector=algo_selector,
            timing_cache=cache_data,
-            profiling_verbosity=trt.ProfilingVerbosity.DETAILED
-            if self.lower_setting.verbose_profile
-            else trt.ProfilingVerbosity.LAYER_NAMES_ONLY,
+            profiling_verbosity=(
+                trt.ProfilingVerbosity.DETAILED
+                if self.lower_setting.verbose_profile
+                else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
+            ),
            tactic_sources=self.lower_setting.tactic_sources,
        )

        # Update timing cache file if needed
        timing_cache = interp_result.serialized_cache
@@ -295,14 +299,12 @@
                module.half()
                # A custom conversion function can be passed to the lowerer to
                # handle inputs with custom types. By default, just handle
                # tensors and NoneType.
                if fp16_conversion_fn is None:
-                    conversion_fn = (
-                        lambda x: x.half()
-                        if x is not None and x.dtype == torch.float32
-                        else x
+                    conversion_fn = lambda x: (
+                        x.half() if x is not None and x.dtype == torch.float32 else x
                    )
                else:
                    conversion_fn = fp16_conversion_fn

                inputs = tuple(conversion_fn(x) for x in inputs)
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/fx2trt.py	2024-02-20 20:41:10.633581+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/fx2trt.py	2024-02-20 20:43:01.398550+00:00
@@ -19,13 +19,13 @@
from .observer import Observer
from .utils import get_dynamic_dims, LowerPrecision, unified_dtype_converter, Frameworks

_LOGGER: logging.Logger = logging.getLogger(__name__)

-TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[
-    Callable[[torch.fx.GraphModule], None]
-] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[Callable[[torch.fx.GraphModule], None]] = (
+    Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+)


class TRTInterpreterResult(NamedTuple):
    engine: Any
    input_names: Sequence[str]
@@ -73,13 +73,13 @@
        self.input_specs_iter = 0
        self.validate_input_specs()
        self._cur_node_name: Optional[str] = None
        self._input_names: List[str] = []
        self._output_names: List[str] = []
-        self._itensor_to_tensor_meta: Dict[
-            trt.tensorrt.ITensor, TensorMetadata
-        ] = dict()
+        self._itensor_to_tensor_meta: Dict[trt.tensorrt.ITensor, TensorMetadata] = (
+            dict()
+        )

    def validate_input_specs(self):
        for shape, _, _, shape_ranges, has_batch_dim in self.input_specs:
            if not self.network.has_implicit_batch_dimension:
                assert (
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py	2024-02-20 20:41:10.633581+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py	2024-02-20 20:43:01.618752+00:00
@@ -194,13 +194,15 @@
                    lowering_start_time = datetime.datetime.now()

                    self.lower_setting.input_specs = generate_input_specs(
                        submod_inputs,
                        self.lower_setting,
-                        additional_submodule_inputs[submod_name]
-                        if additional_submodule_inputs
-                        else None,
+                        (
+                            additional_submodule_inputs[submod_name]
+                            if additional_submodule_inputs
+                            else None
+                        ),
                    )
                    lowered_module = self._lower_func(
                        submod, submod_inputs, self.lower_setting, submod_name
                    )
                    setattr(split_result.split_module, submod_name, lowered_module)
@@ -234,13 +236,15 @@
                if not submod_name.startswith(split_result.non_acc_submodule_prefix):
                    _LOGGER.info(f"ACC submodule graph: {submod.graph}")
                    lowering_start_time = datetime.datetime.now()

                    self.lower_setting.additional_inputs = (
-                        additional_submodule_inputs[submod_name]
-                        if additional_submodule_inputs
-                        else None,
+                        (
+                            additional_submodule_inputs[submod_name]
+                            if additional_submodule_inputs
+                            else None
+                        ),
                    )

                    lowered_module = self._lower_func(
                        submod, submod_inputs, self.lower_setting, submod_name
                    )
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/pass_utils.py	2024-02-20 20:41:10.633581+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/pass_utils.py	2024-02-20 20:43:01.835317+00:00
@@ -193,13 +193,11 @@
                kwargs2 = {"equal_nan": True}
                if rtol:
                    kwargs2["rtol"] = rtol
                if atol:
                    kwargs2["atol"] = atol
-                kwargs2[
-                    "msg"
-                ] = (
+                kwargs2["msg"] = (
                    lambda msg: f"Pass {pass_} failed correctness check due at output {kk}:\n{msg}"
                )
                # If tensors are on different devices, make sure to compare
                # their copies that are on the same device.
                if x.get_device() != y.get_device():
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_basic_pass.py	2024-02-20 20:41:10.633581+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_basic_pass.py	2024-02-20 20:43:01.900721+00:00
@@ -536,13 +536,13 @@
        reshape_batch_size: Optional[fx.Node] = get_reshape_batch_size_as_node(
            maybe_reshape
        )
        if not reshape_batch_size:
            continue
-        reshape_batch_size_inferred_source: Optional[
-            fx.Node
-        ] = get_reshape_batch_size_inferred_source(reshape_batch_size)
+        reshape_batch_size_inferred_source: Optional[fx.Node] = (
+            get_reshape_batch_size_inferred_source(reshape_batch_size)
+        )
        if not reshape_batch_size_inferred_source:
            continue

        reshape_input: fx.Node = maybe_reshape.kwargs["input"]
        if reshape_input == reshape_batch_size_inferred_source:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/test/converters/acc_op/test_split.py	2024-02-20 20:41:10.637581+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/test/converters/acc_op/test_split.py	2024-02-20 20:43:02.295121+00:00
@@ -21,13 +21,15 @@
        inputs = [torch.randn(1, 10)]
        self.run_test(
            Split(),
            inputs,
            expected_ops={
-                acc_ops.split
-                if isinstance(split_size_or_sections, int)
-                else acc_ops.slice_tensor
+                (
+                    acc_ops.split
+                    if isinstance(split_size_or_sections, int)
+                    else acc_ops.slice_tensor
+                )
            },
            test_explicit_batch_dim=False,
        )

    @parameterized.expand(
@@ -68,13 +70,15 @@
        ]
        self.run_test_with_dynamic_shape(
            Split(),
            input_specs,
            expected_ops={
-                acc_ops.split
-                if isinstance(split_size_or_sections, int)
-                else acc_ops.slice_tensor
+                (
+                    acc_ops.split
+                    if isinstance(split_size_or_sections, int)
+                    else acc_ops.slice_tensor
+                )
            },
        )

    # Testing with (-1, -1, -1) results into following error:
    # AssertionError: Can't chunk on dynamic shape dimension!
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/tools/common_fx2trt.py	2024-02-20 20:41:10.641581+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/tools/common_fx2trt.py	2024-02-20 20:43:02.961258+00:00
@@ -152,13 +152,13 @@
            mod.eval()
            if len(expected_ops):
                self.assert_has_op(mod, expected_ops)

            interpreter_result = interpreter.run(
-                lower_precision=LowerPrecision.FP16
-                if fp16_mode
-                else LowerPrecision.FP32
+                lower_precision=(
+                    LowerPrecision.FP16 if fp16_mode else LowerPrecision.FP32
+                )
            )
            trt_mod = TRTModule(
                interpreter_result.engine,
                interpreter_result.input_names,
                interpreter_result.output_names,
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/trt_module.py	2024-02-20 20:41:10.641581+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/trt_module.py	2024-02-20 20:43:03.288540+00:00
@@ -67,25 +67,29 @@
                self.engine.get_binding_dtype(idx), Frameworks.TORCH
            )
            for idx in self.output_binding_indices_in_order
        ]
        self.output_shapes = [
-            tuple(self.engine.get_binding_shape(idx))
-            if self.engine.has_implicit_batch_dimension
-            else tuple()
+            (
+                tuple(self.engine.get_binding_shape(idx))
+                if self.engine.has_implicit_batch_dimension
+                else tuple()
+            )
            for idx in self.output_binding_indices_in_order
        ]
        self.hidden_output_dtypes: Sequence[torch.dtype] = [
            unified_dtype_converter(
                self.engine.get_binding_dtype(idx), Frameworks.TORCH
            )
            for idx in self.hidden_output_binding_indices_in_order
        ]
        self.hidden_output_shapes = [
-            tuple(self.engine.get_binding_shape(idx))
-            if self.engine.has_implicit_batch_dimension
-            else tuple()
+            (
+                tuple(self.engine.get_binding_shape(idx))
+                if self.engine.has_implicit_batch_dimension
+                else tuple()
+            )
            for idx in self.hidden_output_binding_indices_in_order
        ]

    def _check_initialized(self):
        if not self.initialized:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/ts/_compile_spec.py	2024-02-20 20:41:10.641581+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/ts/_compile_spec.py	2024-02-20 20:43:03.610719+00:00
@@ -404,13 +404,13 @@
        "inputs": inputs if inputs is not None else [],
        # "input_signature": input_signature,
        "device": device,
        "disable_tf32": disable_tf32,  # Force FP32 layers to use traditional as FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas
        "sparse_weights": sparse_weights,  # Enable sparsity for convolution and fully connected layers.
-        "enabled_precisions": enabled_precisions
-        if enabled_precisions is not None
-        else set(),  # Enabling FP16 kernels
+        "enabled_precisions": (
+            enabled_precisions if enabled_precisions is not None else set()
+        ),  # Enabling FP16 kernels
        "refit": refit,  # enable refit
        "debug": debug,  # enable debuggable engine
        "capability": capability,  # Restrict kernel selection to safe gpu kernels or safe dla kernels
        "num_avg_timing_iters": num_avg_timing_iters,  # Number of averaging timing iterations used to select kernels
        "workspace_size": workspace_size,  # Maximum size of workspace given to TensorRT

@apbose (Collaborator, Author) commented Feb 20, 2024

I removed the gather test here, since the task above was to expose the gather layer. The TensorRT outputs for torch.ops.aten.gather and ctx.net.add_gather differ. Example:

import torch

class gather(torch.nn.Module):
    def __init__(self, indices):
        super().__init__()
        self.indices = indices

    def forward(self, x):
        out = torch.ops.aten.gather.default(x, 0, self.indices)
        return out

index0 = torch.randint(0, 1, (1, 1), dtype=torch.int64).cuda()
inputs = torch.randn(2, 2).cuda()
gatherOut = gather(index0)
outTest = gatherOut(inputs)

With index0 = [[0]] (shape (1, 1)) and dim = 0, torch.ops.aten.gather returns inputs[0,0], whereas the gather layer would return [inputs[0,0], inputs[0,1]], i.e. the zeroth row along the zeroth dimension.
Exposing the gather layer is exercised by the aten_index and aten_select tests.
Implementing a gather converter proper would be a separate task, based on https://pytorch.org/docs/stable/generated/torch.gather.html
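The difference is visible directly in eager PyTorch (an illustration; TensorRT's default IGatherLayer mode behaves like index_select along the axis):

import torch

x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
idx = torch.tensor([[0]])

torch.gather(x, 0, idx)           # tensor([[1.]])     -> only x[0, 0]
x.index_select(0, idx.flatten())  # tensor([[1., 2.]]) -> the whole row 0,
                                  # matching ctx.net.add_gather's default mode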

input=args[0],
dim=args[1],
index=args[2],
sparse_grad=args_bounds_check(args, 3, False),
@gs-olive (Collaborator): Based on the schema here, sparse_grad would be in kwargs. Additionally, since it seems to have no effect in the gather converter below, it can be removed/ignored, or a validator can be used to ensure it is False.
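A hedged sketch of the validator option, assuming the decorator's capability_validator hook; the import path and converter signature are illustrative, not from this PR:

import torch
from torch_tensorrt.dynamo.conversion import dynamo_tensorrt_converter  # assumed path

def gather_validator(node: torch.fx.Node) -> bool:
    # Only claim aten.gather for TRT conversion when sparse_grad is
    # absent or False, since the converter below ignores it.
    return not node.kwargs.get("sparse_grad", False)

@dynamo_tensorrt_converter(
    torch.ops.aten.gather.default, capability_validator=gather_validator
)
def aten_ops_gather(ctx, target, args, kwargs, name):  # illustrative signature
    ...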

gather_layer = ctx.net.add_gather(input, indices_tensor, index)
set_layer_name(gather_layer, target, name + "_index_gather", source_ir)
return gather_layer.get_output(0)
return gather(ctx, target, source_ir, name, input, index, indices_tensor)
@gs-olive (Collaborator): Line 126/147 needs to be renamed, since it uses name + f"_parameter_to_fp32_tensor", which also appears in the gather function. This could cause a duplicate-name error in edge cases.

input: TRTTensor,
dim: int,
index: Union[TRTTensor, np.ndarray, torch.Tensor],
sparse_grad: bool = False,
@gs-olive (Collaborator): See above, can remove if never used.

Comment on lines +87 to +90
# This is for the case where torch.ops.aten.gather requires torch.int64
# However TRTInterpreter complains that torch.int64 is not a supported type
# So the below cast does not help
# index = cast_trt_tensor(ctx, input, trt.int32, name, target, source_ir)
@gs-olive (Collaborator): Does this issue still occur in the test cases?

@apbose (Author): Yes it does. aten.scatter has similar use cases, so I am working on that. The casting of nodes in the TRT test infrastructure can be used here to get past it. This is the PR: #2664.

@apbose marked this pull request as draft on March 8, 2024
@narendasan closed this on June 7, 2024
@narendasan deleted the gather_impl_new branch on June 7, 2024
Labels: cla signed, component: api [Python], component: conversion, component: converters, component: dynamo

Successfully merging this pull request may close these issues: Expose IGridSampleLayer in dynamo.conversion.impl