From a0f6b07436edbec1652165ed83ff00b83a9af6ef Mon Sep 17 00:00:00 2001
From: apbose
Date: Tue, 5 Dec 2023 14:08:07 -0800
Subject: [PATCH] select_scatter decomp

---
 .../dynamo/lowering/_decompositions.py        | 22 ++++++++
 .../py/dynamo/lowering/test_decompositions.py | 64 +++++++++++++++++++
 2 files changed, 86 insertions(+)

diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py
index 981c80f9fa..66d8bef523 100644
--- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py
+++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py
@@ -162,6 +162,28 @@ def var_decomposition(
     return variance
 
 
+@register_torch_trt_decomposition(
+    torch.ops.aten.select_scatter.default, registry=TORCH_TRT_DECOMPOSITIONS
+)
+def select_scatter_decomposition(
+    input_tensor: torch.Tensor,
+    src_tensor: torch.Tensor,
+    dim: int,
+    index: int,
+) -> torch.Tensor:
+    # the scatter index must lie within the selected dimension
+    assert index < input_tensor.shape[dim]
+    # unsqueeze src along dim and broadcast it to the shape of the input
+    src_tensor = torch.unsqueeze(src_tensor, dim).expand(input_tensor.shape)
+    # build a mask that is True only at position `index` along `dim`
+    dim_index = torch.arange(input_tensor.shape[dim], device=input_tensor.device)
+    mask_shape = [1] * len(input_tensor.shape)
+    mask_shape[dim] = input_tensor.shape[dim]
+    mask = torch.eq(dim_index.reshape(mask_shape), index)
+    # take src where the mask is set and the original input elsewhere
+    return torch.where(mask, src_tensor, input_tensor)
+
+
 def get_decompositions(
     enable_experimental_decompositions: bool = False,
 ) -> Dict[OpOverload, Callable[[Any], Any]]:
diff --git a/tests/py/dynamo/lowering/test_decompositions.py b/tests/py/dynamo/lowering/test_decompositions.py
index 84e8d11585..99ded414a7 100644
--- a/tests/py/dynamo/lowering/test_decompositions.py
+++ b/tests/py/dynamo/lowering/test_decompositions.py
@@ -420,6 +420,70 @@ def forward(self, x):
             f"MaxPool3d TRT outputs don't match with the original model.",
         )
 
+    def test_lowering_select_scatter_module(self):
+        class selectScatter(torch.nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+
+            def forward(self, x, src):
+                y = torch.select_scatter(x, src, 0, 0)
+                return y
+
+        # Operations expected to be removed in the traced graph after decompositions
+        expected_ops = {
+            torch.ops.aten.where.self,
+            torch.ops.aten.expand.default,
+            torch.ops.aten.unsqueeze.default,
+            torch.ops.aten.eq.Scalar,
+        }
+        unexpected_ops = {torch.ops.aten.select_scatter.default}
+
+        inputs = [torch.randn(2, 2).cuda(), torch.ones(2).cuda()]
+
+        fx_graph = torch.fx.symbolic_trace(selectScatter())
+        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
+            fx_graph,
+            inputs,
+            expected_ops=expected_ops,
+            unexpected_ops=unexpected_ops,
+            min_block_size=1,
+        )
+
+        self.assertEquals(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
+        self.assertEquals(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            f"Select_scatter TRT outputs don't match with the original model.",
+        )
+
 
 if __name__ == "__main__":
     run_tests()
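
For reference: select_scatter(input, src, dim, index) writes `src` into the
slice of `input` at position `index` along dimension `dim`. The decomposition
in this patch reproduces that with ops Torch-TensorRT already converts:
unsqueeze/expand broadcast `src` to the input shape, and an arange/eq mask
feeds torch.where to pick `src` only on the selected slice. The standalone
sketch below is not part of the patch; it uses only public torch APIs and a
hypothetical helper name (select_scatter_ref) to check the rewrite against
torch.select_scatter on CPU:

    import torch

    def select_scatter_ref(input_tensor, src_tensor, dim, index):
        # broadcast src across the scattered dimension
        src = torch.unsqueeze(src_tensor, dim).expand(input_tensor.shape)
        # mask that is True only at `index` along `dim`
        shape = [1] * input_tensor.dim()
        shape[dim] = input_tensor.shape[dim]
        mask = torch.eq(torch.arange(input_tensor.shape[dim]).reshape(shape), index)
        return torch.where(mask, src, input_tensor)

    x = torch.zeros(2, 2)
    src = torch.ones(2)
    assert torch.equal(select_scatter_ref(x, src, 0, 0), torch.select_scatter(x, src, 0, 0))
    assert torch.equal(select_scatter_ref(x, src, 1, 1), torch.select_scatter(x, src, 1, 1))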