diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py
index 38ab41fae8..ea22483a36 100644
--- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py
+++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py
@@ -213,6 +213,28 @@ def slice_scatter_decomposition(
     return output_tensor
 
 
+@register_torch_trt_decomposition(
+    torch.ops.aten.select_scatter.default, registry=TORCH_TRT_DECOMPOSITIONS
+)
+def select_scatter_decomposition(
+    input_tensor: torch.Tensor,
+    src_tensor: torch.Tensor,
+    dim: int,
+    index: int,
+) -> torch.Tensor:
+    """Decompose ``aten.select_scatter`` into ``unsqueeze`` + ``slice_scatter``.
+
+    ``src_tensor`` gains a size-1 axis at ``dim`` and is embedded at the
+    one-element slice ``[index, index + 1)`` of ``input_tensor`` along ``dim``.
+    """
+    # Wrap a negative index first: for index == -1 the naive slice [-1, 0)
+    # is empty and no longer matches the unsqueezed source.
+    if index < 0:
+        index += input_tensor.shape[dim]
+    src_tensor = torch.unsqueeze(src_tensor, dim)
+    return torch.slice_scatter(input_tensor, src_tensor, dim, index, index + 1, 1)
+
+
 def get_decompositions(
     enable_experimental_decompositions: bool = False,
 ) -> Dict[OpOverload, Callable[[Any], Any]]:
diff --git a/tests/py/dynamo/lowering/test_decompositions.py b/tests/py/dynamo/lowering/test_decompositions.py
index e248f1528a..8a963e24ef 100644
--- a/tests/py/dynamo/lowering/test_decompositions.py
+++ b/tests/py/dynamo/lowering/test_decompositions.py
@@ -530,7 +530,7 @@ def forward(self, x, src, dim, start=None, end=None, step=1):
             "torch_compile",
             inputs,
             min_block_size=1,
-            truncate_long_and_double=True,
+            truncate_double=True,
             pass_through_build_failures=True,
         )
         optimized_model_results = optimized_model(*inputs).detach().cpu()
@@ -593,7 +593,7 @@ def forward(self, x, src, dim, start, end, step):
             "torch_compile",
             inputs,
             min_block_size=1,
-            truncate_long_and_double=True,
+            truncate_double=True,
             pass_through_build_failures=True,
         )
         optimized_model_results = optimized_model(*inputs).detach().cpu()
@@ -663,7 +663,7 @@ def forward(self, x, src, dim, start, end, step):
             "torch_compile",
             inputs,
             min_block_size=1,
-            truncate_long_and_double=True,
+            truncate_double=True,
             pass_through_build_failures=True,
         )
         optimized_model_results = optimized_model(*inputs).detach().cpu()
@@ -679,6 +679,195 @@ def forward(self, x, src, dim, start, end, step):
             f"Slice_scatter TRT outputs don't match with the original model.",
         )
 
+    def test_lowering_select_scatter_dimZero_module(self):
+        class selectScatter(torch.nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+
+            def forward(self, x, src, dim, index):
+                y = torch.ops.aten.select_scatter.default(x, src, dim, index)
+                return y
+
+        # Ops that must be present / absent in the traced graph after decompositions
+        expected_ops = {torch.ops.aten.scatter.src, torch.ops.aten.unsqueeze.default}
+        unexpected_ops = {
+            torch.ops.aten.select_scatter.default,
+            torch.ops.aten.slice_scatter.default,
+        }
+
+        inputs = [torch.zeros(2, 2).cuda(), torch.ones(2).cuda(), 0, 0]
+
+        fx_graph = torch.fx.symbolic_trace(selectScatter())
+        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
+            fx_graph,
+            inputs,
+            expected_ops=expected_ops,
+            unexpected_ops=unexpected_ops,
+            min_block_size=1,
+        )
+
+        self.assertEqual(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
+        self.assertEqual(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            truncate_double=True,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            f"Select_scatter TRT outputs don't match with the original model.",
+        )
+
+    def test_lowering_select_scatter_dimOne_module(self):
+        class selectScatter(torch.nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+
+            def forward(self, x, src, dim, index):
+                y = torch.ops.aten.select_scatter.default(x, src, dim, index)
+                return y
+
+        # Ops that must be present / absent in the traced graph after decompositions
+        expected_ops = {torch.ops.aten.scatter.src, torch.ops.aten.unsqueeze.default}
+        unexpected_ops = {
+            torch.ops.aten.select_scatter.default,
+            torch.ops.aten.slice_scatter.default,
+        }
+
+        inputs = [torch.zeros(2, 2).cuda(), torch.ones(2).cuda(), 1, 0]
+
+        fx_graph = torch.fx.symbolic_trace(selectScatter())
+        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
+            fx_graph,
+            inputs,
+            expected_ops=expected_ops,
+            unexpected_ops=unexpected_ops,
+            min_block_size=1,
+        )
+
+        self.assertEqual(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
+        self.assertEqual(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            truncate_double=True,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            f"Select_scatter TRT outputs don't match with the original model.",
+        )
+
+    def test_lowering_select_scatter_multidimension_module(self):
+        class selectScatter(torch.nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+
+            def forward(self, x, src, dim, index):
+                y = torch.ops.aten.select_scatter.default(x, src, dim, index)
+                return y
+
+        # Ops that must be present / absent in the traced graph after decompositions
+        expected_ops = {torch.ops.aten.scatter.src, torch.ops.aten.unsqueeze.default}
+        unexpected_ops = {
+            torch.ops.aten.select_scatter.default,
+            torch.ops.aten.slice_scatter.default,
+        }
+
+        inputs = [torch.zeros(2, 3, 4).cuda(), torch.ones(2, 4).cuda(), 1, 0]
+
+        fx_graph = torch.fx.symbolic_trace(selectScatter())
+        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
+            fx_graph,
+            inputs,
+            expected_ops=expected_ops,
+            unexpected_ops=unexpected_ops,
+            min_block_size=1,
+        )
+
+        self.assertEqual(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
+        self.assertEqual(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            truncate_double=True,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            f"Select_scatter TRT outputs don't match with the original model.",
+        )
+
 
 if __name__ == "__main__":
     run_tests()