Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

select_scatter decomp #2515

Merged
merged 1 commit into from
May 31, 2024
Merged

select_scatter decomp #2515

merged 1 commit into from
May 31, 2024

Conversation

apbose
Copy link
Collaborator

@apbose apbose commented Dec 5, 2023

Fixes #2436
This PR would be dependant on #2519, #2664 and #2669. Major changes

2519- Decomposition of aten::slice_scatter
2664- Implementation makes use of aten::scatter.src
2669- Constants getting converted to fake tensors in get_attr call due to which different device location meta and cpu in torch

@apbose apbose self-assigned this Dec 5, 2023
@github-actions github-actions bot added component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: lowering Issues re: The lowering / preprocessing passes component: tests Issues re: Tests labels Dec 5, 2023
@apbose apbose marked this pull request as draft December 5, 2023 22:04
Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to Python style guidelines

@apbose apbose force-pushed the select_scatter_decomposition branch from e4c56cd to 037fbcf Compare December 29, 2023 06:21
Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to Python style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/lowering/test_decompositions.py	2024-01-02 18:24:49.853008+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/lowering/test_decompositions.py	2024-01-02 18:27:03.483949+00:00
@@ -2,10 +2,11 @@
from torch.testing._internal.common_utils import TestCase, run_tests

import torch_tensorrt

from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
+

class TestLowering(TestCase):
    def test_lowering_inplace_op(self):
        class InPlace(torch.nn.Module):
            def __init__(self, *args, **kwargs) -> None:

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to Python style guidelines

@apbose apbose requested a review from gs-olive January 5, 2024 18:13
Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to Python style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/lowering/test_decompositions.py	2024-01-05 18:29:23.300495+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/lowering/test_decompositions.py	2024-01-05 18:31:16.737105+00:00
@@ -481,11 +481,10 @@
            0,
            DECIMALS_OF_AGREEMENT,
            f"Select_scatter TRT outputs don't match with the original model.",
        )

-
    def test_lowering_select_scatter_dimOne_module(self):
        class selectScatter(torch.nn.Module):
            def __init__(self, *args, **kwargs) -> None:
                super().__init__(*args, **kwargs)

@@ -544,7 +543,9 @@
            max_diff,
            0,
            DECIMALS_OF_AGREEMENT,
            f"Select_scatter TRT outputs don't match with the original model.",
        )
+
+
if __name__ == "__main__":
    run_tests()

@apbose apbose force-pushed the select_scatter_decomposition branch from 81a2715 to d5cec9f Compare January 12, 2024 17:08
@apbose apbose marked this pull request as ready for review January 12, 2024 18:53
@apbose apbose force-pushed the select_scatter_decomposition branch from d5cec9f to 689105e Compare January 16, 2024 17:57
Comment on lines 174 to 186
# input_tensor.shape[dim] = torch.le(index, input_tensor.shape[dim])
# check if the dim is less than shape
if input_tensor.shape[dim] < index:
raise AssertionError("The index should not be greater than dim")

# expanding the src_tensor to have the same dimension as input_tensor
# check if the dimension of the src tensor is same as slice tensor
select_tensor = torch.select(input_tensor, dim, index)

if select_tensor.shape != src_tensor.shape:
raise AssertionError(
"The slice tensor shape should be equal to the src tensor shape"
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are the AssertionError cases invalid in Torch, or just cases we can't support? Having AssertionErrors in lowering passes can cause models to inexplicably fail for users, so it is not preferable.

If these are invalid cases in Torch itself, then we do not need these assertions. If they are not supported by TRT, then we can instead return the original op and not lower.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes it is invalid test case in torch itself. So I think the right thing would be to do away with the assertion.
When you say return the original op, that would mean in those cases, we just do

if(condition == True):
     return <unlowered_original_op>

Copy link
Collaborator

@gs-olive gs-olive Jan 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok thanks for that clarification. I meant that if it would be invalid in Torch, you can assume it will not be the case that the condition would ever be encountered - otherwise the model should have failed earlier. Specifically, if select_tensor.shape == src_tensor.shape is a requirement of select_scatter, then it is safe to assume the inputs are valid inputs to that function, otherwise we can let Torch throw the error.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok got it. Thanks for the clarification!

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/examples/int8/training/vgg16/vgg16.py	2024-02-16 00:01:27.167252+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/int8/training/vgg16/vgg16.py	2024-02-16 00:03:16.025977+00:00
@@ -1,10 +1,11 @@
"""
# Reference
- [Very Deep Convolutional Networks for Large-Scale Image Recognition](
    https://arxiv.org/abs/1409.1556) (ICLR 2015)
"""
+
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import reduce

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Device.py	2024-02-16 00:01:27.175252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Device.py	2024-02-16 00:03:16.122237+00:00
@@ -30,16 +30,18 @@
        gpu_id (int): Device ID for target GPU
        dla_core (int): Core ID for target DLA core
        allow_gpu_fallback (bool): Whether falling back to GPU if DLA cannot support an op should be allowed
    """

-    device_type: Optional[
-        trt.DeviceType
-    ] = None  #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
+    device_type: Optional[trt.DeviceType] = (
+        None  #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
+    )
    gpu_id: int = -1  #: Device ID for target GPU
    dla_core: int = -1  #: Core ID for target DLA core
-    allow_gpu_fallback: bool = False  #: Whether falling back to GPU if DLA cannot support an op should be allowed
+    allow_gpu_fallback: bool = (
+        False  #: Whether falling back to GPU if DLA cannot support an op should be allowed
+    )

    def __init__(self, *args: Any, **kwargs: Any):
        """__init__ Method for torch_tensorrt.Device

        Device accepts one of a few construction patterns
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Input.py	2024-02-16 00:01:27.175252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Input.py	2024-02-16 00:03:16.328118+00:00
@@ -26,16 +26,16 @@

    class _ShapeMode(Enum):
        STATIC = 0
        DYNAMIC = 1

-    shape_mode: Optional[
-        _ShapeMode
-    ] = None  #: Is input statically or dynamically shaped
-    shape: Optional[
-        Tuple[int, ...] | Dict[str, Tuple[int, ...]]
-    ] = None  #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
+    shape_mode: Optional[_ShapeMode] = (
+        None  #: Is input statically or dynamically shaped
+    )
+    shape: Optional[Tuple[int, ...] | Dict[str, Tuple[int, ...]]] = (
+        None  #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
+    )
    dtype: _enums.dtype = (
        _enums.dtype.unknown
    )  #: The expected data type of the input tensor (default: torch_tensorrt.dtype.float32)
    _explicit_set_dtype: bool = False
    format: _enums.TensorFormat = (
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_compiler.py	2024-02-16 00:01:27.175252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_compiler.py	2024-02-16 00:03:16.375728+00:00
@@ -212,13 +212,13 @@
        "precision": precision,
        "debug": debug,
        "device": device,
        "workspace_size": workspace_size,
        "min_block_size": min_block_size,
-        "torch_executed_ops": torch_executed_ops
-        if torch_executed_ops is not None
-        else set(),
+        "torch_executed_ops": (
+            torch_executed_ops if torch_executed_ops is not None else set()
+        ),
        "pass_through_build_failures": pass_through_build_failures,
        "max_aux_streams": max_aux_streams,
        "version_compatible": version_compatible,
        "optimization_level": optimization_level,
        "use_python_runtime": use_python_runtime,
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py	2024-02-16 00:01:27.175252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py	2024-02-16 00:03:16.569143+00:00
@@ -26,13 +26,13 @@

from packaging import version

_LOGGER: logging.Logger = logging.getLogger(__name__)

-TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[
-    Callable[[torch.fx.GraphModule], None]
-] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[Callable[[torch.fx.GraphModule], None]] = (
+    Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+)


class UnsupportedOperatorException(RuntimeError):
    pass

@@ -90,13 +90,13 @@
        self.input_specs_iter = 0
        self._cur_node_name: Optional[str] = None
        self._cur_node: Optional[torch.fx.Node] = None
        self._input_names: List[str] = []
        self._output_names: List[str] = []
-        self._itensor_to_tensor_meta: Dict[
-            trt.tensorrt.ITensor, TensorMetadata
-        ] = dict()
+        self._itensor_to_tensor_meta: Dict[trt.tensorrt.ITensor, TensorMetadata] = (
+            dict()
+        )
        self.compilation_settings = compilation_settings

        # Data types for TRT Module output Tensors
        self.output_dtypes = output_dtypes

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py	2024-02-16 00:01:27.179252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py	2024-02-16 00:03:16.647465+00:00
@@ -322,17 +322,15 @@
    else:
        raise AssertionError(f"Cannot convert {input_val} to TRT constant")


@overload
-def get_positive_dim(dim: int, dim_size: int) -> int:
-    ...
+def get_positive_dim(dim: int, dim_size: int) -> int: ...


@overload
-def get_positive_dim(dim: Sequence[int], dim_size: int) -> Tuple[int, ...]:
-    ...
+def get_positive_dim(dim: Sequence[int], dim_size: int) -> Tuple[int, ...]: ...


def get_positive_dim(
    dim: Union[int, Sequence[int]], dim_size: int
) -> Union[int, Tuple[int, ...]]:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py	2024-02-16 00:01:27.179252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py	2024-02-16 00:03:17.010073+00:00
@@ -5,13 +5,13 @@
from torch._decomp import get_decompositions as get_torch_decompositions
from torch._ops import OpOverload, OpOverloadPacket

aten = torch.ops.aten

-_core_aten_decompositions: Dict[
-    OpOverload, Callable[[Any], Any]
-] = core_aten_decompositions()
+_core_aten_decompositions: Dict[OpOverload, Callable[[Any], Any]] = (
+    core_aten_decompositions()
+)
torch_enabled_decompositions: Set[Union[OpOverload, OpOverloadPacket]] = {
    aten._adaptive_avg_pool2d_backward,
    aten.addcdiv,
    aten.addcdiv_,
    aten.addcmul,
@@ -179,13 +179,13 @@
torch_disabled_decompositions: Set[Union[OpOverload, OpOverloadPacket]] = {
    aten._softmax.default,
}


-ENABLED_TORCH_DECOMPOSITIONS: Dict[
-    OpOverload, Callable[[Any], Any]
-] = get_torch_decompositions(torch_enabled_decompositions)
+ENABLED_TORCH_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = (
+    get_torch_decompositions(torch_enabled_decompositions)
+)
TORCH_TRT_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = {}


def check_decomp_set_invariants() -> None:
    """Validates no overlap between enabled and disabled decomposition sets"""
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py	2024-02-16 00:01:27.179252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py	2024-02-16 00:03:17.018926+00:00
@@ -20,16 +20,14 @@
        logger.debug(f"Graph after lowering linear:\n{gm.graph}")

    return gm


-def linear_replacement() -> (
-    Tuple[
-        torch.fx.GraphModule,
-        Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
-    ]
-):
+def linear_replacement() -> Tuple[
+    torch.fx.GraphModule,
+    Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
+]:
    """Constructs the original and replacement functions for linear"""

    # Original graph
    def orig(
        input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py	2024-02-16 00:01:27.179252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py	2024-02-16 00:03:17.052927+00:00
@@ -20,16 +20,14 @@
        logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")

    return gm


-def view_replacement() -> (
-    Tuple[
-        torch.fx.GraphModule,
-        Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
-    ]
-):
+def view_replacement() -> Tuple[
+    torch.fx.GraphModule,
+    Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
+]:
    """Constructs the original and replacement functions for view"""

    # Original graph
    def orig(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
        return torch.ops.aten.view.default(input, shape)
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py	2024-02-16 00:01:27.179252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py	2024-02-16 00:03:17.057189+00:00
@@ -58,16 +58,14 @@
        logger.debug(f"Graph after lowering scaled dot product attention:\n{gm.graph}")

    return gm


-def scaled_dot_product_attention_replacement() -> (
-    Tuple[
-        Sequence[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]],
-        Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
-    ]
-):
+def scaled_dot_product_attention_replacement() -> Tuple[
+    Sequence[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]],
+    Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
+]:
    """Constructs the original and replacement functions for efficient attention"""

    # Efficient Attention original graph
    def efficient(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default(
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py	2024-02-16 00:01:27.179252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py	2024-02-16 00:03:17.277499+00:00
@@ -99,25 +99,29 @@
                self.engine.get_binding_dtype(idx), Frameworks.TORCH
            )
            for idx in self.output_binding_indices_in_order
        ]
        self.output_shapes = [
-            tuple(self.engine.get_binding_shape(idx))
-            if self.engine.has_implicit_batch_dimension
-            else tuple()
+            (
+                tuple(self.engine.get_binding_shape(idx))
+                if self.engine.has_implicit_batch_dimension
+                else tuple()
+            )
            for idx in self.output_binding_indices_in_order
        ]
        self.hidden_output_dtypes = [
            unified_dtype_converter(
                self.engine.get_binding_dtype(idx), Frameworks.TORCH
            )
            for idx in self.hidden_output_binding_indices_in_order
        ]
        self.hidden_output_shapes = [
-            tuple(self.engine.get_binding_shape(idx))
-            if self.engine.has_implicit_batch_dimension
-            else tuple()
+            (
+                tuple(self.engine.get_binding_shape(idx))
+                if self.engine.has_implicit_batch_dimension
+                else tuple()
+            )
            for idx in self.hidden_output_binding_indices_in_order
        ]

    def _check_initialized(self) -> None:
        if not self.initialized:
@@ -165,13 +169,15 @@
        self.__dict__.update(state)
        if self.engine:
            self.context = self.engine.create_execution_context()

    def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
-        with torch.autograd.profiler.record_function(
-            "PythonTorchTensorRTModule:Forward"
-        ) if self.profiling_enabled else nullcontext():
+        with (
+            torch.autograd.profiler.record_function("PythonTorchTensorRTModule:Forward")
+            if self.profiling_enabled
+            else nullcontext()
+        ):
            self._check_initialized()

            # If in safe mode, check at each iteration for for whether a switch is required
            if (
                torch_tensorrt.runtime.multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE
@@ -198,13 +204,17 @@
                    torch.cuda.set_device(device_id)

                    inputs = tuple([tensor.to(device) for tensor in inputs])
                    logger.warning(f"Moved all input Tensors to cuda:{device_id}")

-            with torch.autograd.profiler.record_function(
-                "PythonTorchTensorRTModule:ProcessInputs"
-            ) if self.profiling_enabled else nullcontext():
+            with (
+                torch.autograd.profiler.record_function(
+                    "PythonTorchTensorRTModule:ProcessInputs"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
                assert len(inputs) == len(
                    self.input_names
                ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}."

                contiguous_inputs: List[torch.Tensor] = [i.contiguous() for i in inputs]
@@ -237,13 +247,17 @@

                    self.context.set_binding_shape(
                        idx, tuple(contiguous_inputs[i].shape)
                    )

-            with torch.autograd.profiler.record_function(
-                "PythonTorchTensorRTModule:ProcessOutputs"
-            ) if self.profiling_enabled else nullcontext():
+            with (
+                torch.autograd.profiler.record_function(
+                    "PythonTorchTensorRTModule:ProcessOutputs"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
                # create output tensors
                outputs: List[torch.Tensor] = []

                for i, idx in enumerate(self.output_binding_indices_in_order):
                    shape = tuple(self.context.get_binding_shape(idx))
@@ -264,13 +278,17 @@
                        dtype=self.hidden_output_dtypes[i],
                        device=torch.cuda.current_device(),
                    )
                    bindings[idx] = output.data_ptr()

-            with torch.autograd.profiler.record_function(
-                "PythonTorchTensorRTModule:TensorRTRuntime"
-            ) if self.profiling_enabled else nullcontext():
+            with (
+                torch.autograd.profiler.record_function(
+                    "PythonTorchTensorRTModule:TensorRTRuntime"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
                self.context.execute_async_v2(
                    bindings, torch.cuda.current_stream().cuda_stream
                )

            if len(outputs) == 1:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/converters/aten_ops_converters.py	2024-02-16 00:01:27.183252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/converters/aten_ops_converters.py	2024-02-16 00:03:17.622507+00:00
@@ -315,25 +315,21 @@
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    kwargs_new = {
        "input": args[0],
        "kernel_size": args[1],
-        "stride": args[2]
-        if len(args) > 2
-        else (None, None)
-        if len(args[1]) == 2
-        else (None, None, None),
-        "padding": args[3]
-        if len(args) > 3
-        else (0, 0)
-        if len(args[1]) == 2
-        else (0, 0, 0),
-        "dilation": args[4]
-        if len(args) > 4
-        else (1, 1)
-        if len(args[1]) == 2
-        else (1, 1, 1),
+        "stride": (
+            args[2]
+            if len(args) > 2
+            else (None, None) if len(args[1]) == 2 else (None, None, None)
+        ),
+        "padding": (
+            args[3] if len(args) > 3 else (0, 0) if len(args[1]) == 2 else (0, 0, 0)
+        ),
+        "dilation": (
+            args[4] if len(args) > 4 else (1, 1) if len(args[1]) == 2 else (1, 1, 1)
+        ),
        "ceil_mode": args[5] if len(args) > 5 else False,
    }
    return acc_ops_converters.acc_ops_max_poolnd(
        network, target, None, kwargs_new, name
    )
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/fx2trt.py	2024-02-16 00:01:27.183252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/fx2trt.py	2024-02-16 00:03:17.675873+00:00
@@ -19,13 +19,13 @@
from .observer import Observer
from .utils import get_dynamic_dims, LowerPrecision, unified_dtype_converter, Frameworks

_LOGGER: logging.Logger = logging.getLogger(__name__)

-TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[
-    Callable[[torch.fx.GraphModule], None]
-] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[Callable[[torch.fx.GraphModule], None]] = (
+    Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+)


class TRTInterpreterResult(NamedTuple):
    engine: Any
    input_names: Sequence[str]
@@ -73,13 +73,13 @@
        self.input_specs_iter = 0
        self.validate_input_specs()
        self._cur_node_name: Optional[str] = None
        self._input_names: List[str] = []
        self._output_names: List[str] = []
-        self._itensor_to_tensor_meta: Dict[
-            trt.tensorrt.ITensor, TensorMetadata
-        ] = dict()
+        self._itensor_to_tensor_meta: Dict[trt.tensorrt.ITensor, TensorMetadata] = (
+            dict()
+        )

    def validate_input_specs(self):
        for shape, _, _, shape_ranges, has_batch_dim in self.input_specs:
            if not self.network.has_implicit_batch_dimension:
                assert (
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/lower.py	2024-02-16 00:01:27.183252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/lower.py	2024-02-16 00:03:17.684340+00:00
@@ -124,25 +124,29 @@
        interpreter = TRTInterpreter(
            mod,
            input_specs=self.lower_setting.input_specs,
            explicit_batch_dimension=self.lower_setting.explicit_batch_dimension,
            explicit_precision=self.lower_setting.explicit_precision,
-            logger_level=trt.Logger.VERBOSE
-            if self.lower_setting.verbose_log
-            else trt.Logger.WARNING,
+            logger_level=(
+                trt.Logger.VERBOSE
+                if self.lower_setting.verbose_log
+                else trt.Logger.WARNING
+            ),
        )

        interp_result: TRTInterpreterResult = interpreter.run(
            max_batch_size=self.lower_setting.max_batch_size,
            max_workspace_size=self.lower_setting.max_workspace_size,
            lower_precision=self.lower_setting.lower_precision,
            strict_type_constraints=self.lower_setting.strict_type_constraints,
            algorithm_selector=algo_selector,
            timing_cache=cache_data,
-            profiling_verbosity=trt.ProfilingVerbosity.DETAILED
-            if self.lower_setting.verbose_profile
-            else trt.ProfilingVerbosity.LAYER_NAMES_ONLY,
+            profiling_verbosity=(
+                trt.ProfilingVerbosity.DETAILED
+                if self.lower_setting.verbose_profile
+                else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
+            ),
            tactic_sources=self.lower_setting.tactic_sources,
        )

        # Update timing cache file if needed
        timing_cache = interp_result.serialized_cache
@@ -295,14 +299,12 @@
                module.half()
                # A custom conversion function can be passed to the lowerer to
                # handle inputs with custom types. By default, just handle
                # tensors and NoneType.
                if fp16_conversion_fn is None:
-                    conversion_fn = (
-                        lambda x: x.half()
-                        if x is not None and x.dtype == torch.float32
-                        else x
+                    conversion_fn = lambda x: (
+                        x.half() if x is not None and x.dtype == torch.float32 else x
                    )
                else:
                    conversion_fn = fp16_conversion_fn

                inputs = tuple(conversion_fn(x) for x in inputs)
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py	2024-02-16 00:01:27.183252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py	2024-02-16 00:03:17.896029+00:00
@@ -194,13 +194,15 @@
                    lowering_start_time = datetime.datetime.now()

                    self.lower_setting.input_specs = generate_input_specs(
                        submod_inputs,
                        self.lower_setting,
-                        additional_submodule_inputs[submod_name]
-                        if additional_submodule_inputs
-                        else None,
+                        (
+                            additional_submodule_inputs[submod_name]
+                            if additional_submodule_inputs
+                            else None
+                        ),
                    )
                    lowered_module = self._lower_func(
                        submod, submod_inputs, self.lower_setting, submod_name
                    )
                    setattr(split_result.split_module, submod_name, lowered_module)
@@ -234,13 +236,15 @@
                if not submod_name.startswith(split_result.non_acc_submodule_prefix):
                    _LOGGER.info(f"ACC submodule graph: {submod.graph}")
                    lowering_start_time = datetime.datetime.now()

                    self.lower_setting.additional_inputs = (
-                        additional_submodule_inputs[submod_name]
-                        if additional_submodule_inputs
-                        else None,
+                        (
+                            additional_submodule_inputs[submod_name]
+                            if additional_submodule_inputs
+                            else None
+                        ),
                    )

                    lowered_module = self._lower_func(
                        submod, submod_inputs, self.lower_setting, submod_name
                    )
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/pass_utils.py	2024-02-16 00:01:27.183252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/pass_utils.py	2024-02-16 00:03:18.123580+00:00
@@ -193,13 +193,11 @@
                kwargs2 = {"equal_nan": True}
                if rtol:
                    kwargs2["rtol"] = rtol
                if atol:
                    kwargs2["atol"] = atol
-                kwargs2[
-                    "msg"
-                ] = (
+                kwargs2["msg"] = (
                    lambda msg: f"Pass {pass_} failed correctness check due at output {kk}:\n{msg}"
                )
                # If tensors are on different devices, make sure to compare
                # their copies that are on the same device.
                if x.get_device() != y.get_device():
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_basic_pass.py	2024-02-16 00:01:27.183252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_basic_pass.py	2024-02-16 00:03:18.166433+00:00
@@ -536,13 +536,13 @@
        reshape_batch_size: Optional[fx.Node] = get_reshape_batch_size_as_node(
            maybe_reshape
        )
        if not reshape_batch_size:
            continue
-        reshape_batch_size_inferred_source: Optional[
-            fx.Node
-        ] = get_reshape_batch_size_inferred_source(reshape_batch_size)
+        reshape_batch_size_inferred_source: Optional[fx.Node] = (
+            get_reshape_batch_size_inferred_source(reshape_batch_size)
+        )
        if not reshape_batch_size_inferred_source:
            continue

        reshape_input: fx.Node = maybe_reshape.kwargs["input"]
        if reshape_input == reshape_batch_size_inferred_source:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/test/converters/acc_op/test_split.py	2024-02-16 00:01:27.187252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/test/converters/acc_op/test_split.py	2024-02-16 00:03:18.592970+00:00
@@ -21,13 +21,15 @@
        inputs = [torch.randn(1, 10)]
        self.run_test(
            Split(),
            inputs,
            expected_ops={
-                acc_ops.split
-                if isinstance(split_size_or_sections, int)
-                else acc_ops.slice_tensor
+                (
+                    acc_ops.split
+                    if isinstance(split_size_or_sections, int)
+                    else acc_ops.slice_tensor
+                )
            },
            test_explicit_batch_dim=False,
        )

    @parameterized.expand(
@@ -68,13 +70,15 @@
        ]
        self.run_test_with_dynamic_shape(
            Split(),
            input_specs,
            expected_ops={
-                acc_ops.split
-                if isinstance(split_size_or_sections, int)
-                else acc_ops.slice_tensor
+                (
+                    acc_ops.split
+                    if isinstance(split_size_or_sections, int)
+                    else acc_ops.slice_tensor
+                )
            },
        )

    # Testing with (-1, -1, -1) results into following error:
    # AssertionError: Can't chunk on dynamic shape dimension!
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/tools/common_fx2trt.py	2024-02-16 00:01:27.191252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/tools/common_fx2trt.py	2024-02-16 00:03:19.259154+00:00
@@ -152,13 +152,13 @@
            mod.eval()
            if len(expected_ops):
                self.assert_has_op(mod, expected_ops)

            interpreter_result = interpreter.run(
-                lower_precision=LowerPrecision.FP16
-                if fp16_mode
-                else LowerPrecision.FP32
+                lower_precision=(
+                    LowerPrecision.FP16 if fp16_mode else LowerPrecision.FP32
+                )
            )
            trt_mod = TRTModule(
                interpreter_result.engine,
                interpreter_result.input_names,
                interpreter_result.output_names,
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/trt_module.py	2024-02-16 00:01:27.191252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/trt_module.py	2024-02-16 00:03:19.609670+00:00
@@ -67,25 +67,29 @@
                self.engine.get_binding_dtype(idx), Frameworks.TORCH
            )
            for idx in self.output_binding_indices_in_order
        ]
        self.output_shapes = [
-            tuple(self.engine.get_binding_shape(idx))
-            if self.engine.has_implicit_batch_dimension
-            else tuple()
+            (
+                tuple(self.engine.get_binding_shape(idx))
+                if self.engine.has_implicit_batch_dimension
+                else tuple()
+            )
            for idx in self.output_binding_indices_in_order
        ]
        self.hidden_output_dtypes: Sequence[torch.dtype] = [
            unified_dtype_converter(
                self.engine.get_binding_dtype(idx), Frameworks.TORCH
            )
            for idx in self.hidden_output_binding_indices_in_order
        ]
        self.hidden_output_shapes = [
-            tuple(self.engine.get_binding_shape(idx))
-            if self.engine.has_implicit_batch_dimension
-            else tuple()
+            (
+                tuple(self.engine.get_binding_shape(idx))
+                if self.engine.has_implicit_batch_dimension
+                else tuple()
+            )
            for idx in self.hidden_output_binding_indices_in_order
        ]

    def _check_initialized(self):
        if not self.initialized:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/ts/_compile_spec.py	2024-02-16 00:01:27.191252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/ts/_compile_spec.py	2024-02-16 00:03:19.911816+00:00
@@ -404,13 +404,13 @@
        "inputs": inputs if inputs is not None else [],
        # "input_signature": input_signature,
        "device": device,
        "disable_tf32": disable_tf32,  # Force FP32 layers to use traditional as FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas
        "sparse_weights": sparse_weights,  # Enable sparsity for convolution and fully connected layers.
-        "enabled_precisions": enabled_precisions
-        if enabled_precisions is not None
-        else set(),  # Enabling FP16 kernels
+        "enabled_precisions": (
+            enabled_precisions if enabled_precisions is not None else set()
+        ),  # Enabling FP16 kernels
        "refit": refit,  # enable refit
        "debug": debug,  # enable debuggable engine
        "capability": capability,  # Restrict kernel selection to safe gpu kernels or safe dla kernels
        "num_avg_timing_iters": num_avg_timing_iters,  # Number of averaging timing iterations used to select kernels
        "workspace_size": workspace_size,  # Maximum size of workspace given to TensorRT

@apbose apbose force-pushed the select_scatter_decomposition branch from 150f055 to f1ff596 Compare February 16, 2024 00:03
@apbose apbose requested a review from gs-olive February 20, 2024 20:02
@apbose apbose force-pushed the select_scatter_decomposition branch 2 times, most recently from 2eaae77 to b5b45a1 Compare February 26, 2024 22:38
Copy link
Collaborator

@gs-olive gs-olive left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall, looks good - just had one question, added below

Comment on lines 174 to 177
unbind_tensors = torch.unbind(input_tensor, dim)
unbind_tensors_list = list(unbind_tensors)
unbind_tensors_list[index] = src_tensor
return torch.stack(tuple(unbind_tensors_list), dim)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What operators in the graph does this generate after tracing? Is there a before/after sample that could be shared

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gs-olive, these were the graphs-
Pre-AOT Autograd graph:

graph():
    %l_x_ : torch.Tensor [num_users=1] = placeholder[target=L_x_]
    %l_src_ : torch.Tensor [num_users=1] = placeholder[target=L_src_]
    %select_scatter_default : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%l_x_, %l_src_, 0, 0), kwargs = {})
    return (select_scatter_default,)

Post-AOT Autograd graph:

graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%arg1_1,), kwargs = {})
    %clone_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%arg0_1,), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%clone_1, 0, 1, 2), kwargs = {})
    %squeeze_1 : [num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%slice_2, 0), kwargs = {})
    %cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%clone, %squeeze_1],), kwargs = {})
    %view : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%cat, [2, 2]), kwargs = {})
    return (view,)
Graph after constant folding:
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%arg0_1, 0, 1, 2), kwargs = {})
    %squeeze_1 : [num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%slice_2, 0), kwargs = {})
    %cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%arg1_1, %squeeze_1],), kwargs = {})
    %view : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%cat, [2, 2]), kwargs = {})
    return (view,)

Post-lowering passes Autograd graph:

graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%arg0_1, 0, 1, 2), kwargs = {})
    %squeeze_1 : [num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%slice_2, 0), kwargs = {})
    %cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%arg1_1, %squeeze_1],), kwargs = {})
    %reshape_default : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%cat, [2, 2]), kwargs = {})
    return (reshape_default,)

However I have changed the implementation now to make use of slice_scatter implementation which I have updated in the description.

@apbose apbose requested a review from gs-olive March 19, 2024 08:54
Copy link
Collaborator

@gs-olive gs-olive left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the old implementation was valid, but the new one does not seem to work in some cases, for example:

>>> import torch
>>> a = torch.zeros(2, 2)
>>> b = torch.ones(2)
>>> torch.select_scatter(a, b, 0, 0)
tensor([[1., 1.],
        [0., 0.]])
>>> torch.slice_scatter(a, b.unsqueeze(0), 0, 1, 1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: expected src to have a size equal to the slice of self. src size = [1, 2], slice size = [0, 2]

@gs-olive
Copy link
Collaborator

See this decomposition for an alternative approach.

@apbose
Copy link
Collaborator Author

apbose commented Mar 26, 2024

Thanks @gs-olive for pointing the above.
But I think the implementation using slice_scatter decomposition should also work in our case.
For eg: in the above case the unsqueeze dimension with src would lead to src_tensor being [1,2] (torch.slice_scatter would expect it to be [0,2]. But since in slice_scatter decomposition, we do away with dimension at dim (dim=0) in this case the above error would not come there. Also in more than two dimensions this case should never be encountered

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/lowering/test_decompositions.py	2024-03-26 20:46:17.748006+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/lowering/test_decompositions.py	2024-03-26 20:53:01.432160+00:00
@@ -607,7 +607,8 @@
            0,
            DECIMALS_OF_AGREEMENT,
            f"Select_scatter TRT outputs don't match with the original model.",
        )

+
if __name__ == "__main__":
    run_tests()

@gs-olive
Copy link
Collaborator

gs-olive commented Mar 27, 2024

So, in this case would the implementation not be functional without the slice_scatter decomposition?

Additionally, if the slice_scatter decomposition changes the behavior of torch.slice_scatter, in the sense that the example here (#2515 (review)) passes with the decomposition but fails without it, then how does the slice_scatter decomposition change the operator? I thought the decomposition would be 1:1 with the operator meaning any inputs to the operator are valid inputs to the decomposition and vice versa.

@apbose
Copy link
Collaborator Author

apbose commented Apr 2, 2024

I misread the case pointed by you.

>>> import torch
>>> a = torch.zeros(2, 2)
>>> b = torch.ones(2)
>>> torch.select_scatter(a, b, 0, 0)
tensor([[1., 1.],
        [0., 0.]])
>>> torch.slice_scatter(a, b.unsqueeze(0), 0, 1, 1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: expected src to have a size equal to the slice of self. src size = [1, 2], slice size = [0, 2]

In the above according to the implementation above, the slice_scatter op would be
torch.slice_scatter(a, b.unsqueeze(0), 0, 0, 1, 1)
Which would lead to the slice dimension being [1,2]
The difference between slice_scatter and select_scatter is that since select_scatter inserts a single dimensional tensor, while inputting the src tensor we generally just provide single dimension along the dim.
Example: for
torch.select_scatter=>
for input tensor torch.zeros((2,2)) with shape = [2,2] ,the src tensor should be torch.ones(2) with shape [2] instead of [1,2] for dim = 1
for input tensor torch.zeros((2,3,4)) with shape = [2,3,4], the src tensor should be torch.ones(2,4) with shape [2,4] for dim = 1
The op would be torch.select_scatter(input, src, dim, index)

torch.slice_scatter=>
for input tensor torch.zeros((2,2)) with shape = [2,2], the src tensor should be torch.ones(1,2) with shape [1,2] for dim = 1
for input tensor torch.zeros((2,3,4)) with shape = [2,3,4], the src tensor should be torch.ones(2, 1, 4) with shape [2,4] for dim = 1
The op would be torch.slice_scatter(input, src, dim, index, index+1)

To answer the above question-

  1. No slice_scatter decomposition does not alter the torch slice_scatter behavior. It is 1:1 behavior
  2. My earlier comment was because I misunderstood. The torch.slice_scatter would also expect it to be [0,1], since the slice_scatter is for torch.slice_scatter(input_tensor, src_tensor, dim, index, index + 1, 1) (index to index+1). But as mentioned in the above comment, since we do away with the unsqueezed dimension in slice_scatter, I partially misunderstood the example of the slice_scatter op and thought that would work (basically the slice_scatter op would not come in the first place).

Copy link
Collaborator

@gs-olive gs-olive left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me, pending rebase + CI passing

@apbose apbose force-pushed the select_scatter_decomposition branch from dc9670f to 694befd Compare April 5, 2024 00:20
@github-actions github-actions bot requested a review from gs-olive April 19, 2024 00:55
@apbose apbose force-pushed the select_scatter_decomposition branch 2 times, most recently from da4d71d to f174eb1 Compare May 30, 2024 17:33
Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/lowering/test_decompositions.py	2024-05-30 17:34:01.396563+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/lowering/test_decompositions.py	2024-05-30 17:35:56.410082+00:00
@@ -669,10 +669,10 @@
        self.assertAlmostEqual(
            max_diff,
            0,
            DECIMALS_OF_AGREEMENT,
            f"Select_scatter TRT outputs don't match with the original model.",
-        )    
+        )


if __name__ == "__main__":
    run_tests()

@apbose apbose force-pushed the select_scatter_decomposition branch 4 times, most recently from ee8330a to 004c56f Compare May 31, 2024 00:21
@apbose apbose merged commit 6583300 into main May 31, 2024
12 checks passed
Changing lowering of select_scatter

select_scatter changes

select_scatter changes

Test case for select_scatter

removing assertion

adding select_scatter decomp lowering ops in test

implement select_scatter using slice_scatter

adding test case

linting commit fix
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cla signed component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: lowering Issues re: The lowering / preprocessing passes component: tests Issues re: Tests
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Add support for aten.select_scatter
4 participants