From 6bae6d5f8d8855a9296aebbf924be2f9fd665b75 Mon Sep 17 00:00:00 2001 From: Ekaterina Ignasheva Date: Wed, 28 May 2025 15:04:05 -0700 Subject: [PATCH] Use single_op_builder in simplify unit tests. (#11158) Summary: Use single_op_builder in simplify unit tests. Reviewed By: hsharma35 Differential Revision: D75309572 --- .../aot/tests/test_simplify_ops_passes.py | 110 +++++------------- 1 file changed, 30 insertions(+), 80 deletions(-) diff --git a/backends/cadence/aot/tests/test_simplify_ops_passes.py b/backends/cadence/aot/tests/test_simplify_ops_passes.py index 00229757764..195c0ff00ab 100644 --- a/backends/cadence/aot/tests/test_simplify_ops_passes.py +++ b/backends/cadence/aot/tests/test_simplify_ops_passes.py @@ -12,7 +12,6 @@ import executorch.backends.cadence.aot.ops_registrations # noqa import torch -from executorch.backends.cadence.aot.compiler import export_to_edge from executorch.backends.cadence.aot.graph_builder import single_op_builder from executorch.backends.cadence.aot.pass_utils import count_node from executorch.backends.cadence.aot.simplify_ops import ( @@ -40,82 +39,47 @@ def test_simplify_slice_scatter_op( end: Optional[int] = None, step: int = 1, ): - class SliceScatter(torch.nn.Module): - def __init__( - self, dim: int, start: Optional[int], end: Optional[int], step: int - ): - super().__init__() - self.dim = dim - self.start = start - self.end = end - self.step = step - - def forward(self, x: torch.Tensor, y: torch.Tensor): - return torch.slice_scatter( - x, y, self.dim, self.start, self.end, self.step - ) - - model = SliceScatter(dim, start, end, step) - x = torch.randn(in_shape) - y = torch.randn(src_shape) - graph_module = export_to_edge(model, (x, y)).exported_program().graph_module - - p = SimplifySliceOpPass() - - graph_after_passes = cast(PassResult, p(graph_module)).graph_module - - self.assertEqual( - count_node(graph_after_passes, exir_ops.edge.aten.slice_scatter.default), 0 + x = torch.randn(*in_shape) + y = torch.randn(*src_shape) + gm = single_op_builder( + placeholders=(x, y), + op=exir_ops.edge.aten.slice_scatter.default, + args=(x, y, dim, start, end, step), ) + p = SimplifySliceOpPass() + gm = cast(PassResult, p(gm)).graph_module + self.assertEqual(count_node(gm, exir_ops.edge.aten.slice_scatter.default), 0) @parameterized.expand( [ - [(3, 16, 5), (3, 0, 5), 1, 15, 3, 3], + [(3, 16, 5), 1, 15, 3, 3], ] ) @torch.no_grad() def test_simplify_slice_op( self, in_shape: Tuple[int], - src_shape: Tuple[int], dim: int, start: Optional[int] = None, end: Optional[int] = None, step: int = 1, ): - class SliceCopy(torch.nn.Module): - def __init__( - self, dim: int, start: Optional[int], end: Optional[int], step: int - ): - super().__init__() - self.dim = dim - self.start = start - self.end = end - self.step = step - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return torch.slice_copy( - x, dim=self.dim, start=self.start, end=self.end, step=self.step - ) - - # Create a model with single slice copy op. - model = SliceCopy(dim, start, end, step) - x = torch.randn(in_shape) - graph_module = export_to_edge(model, (x,)).exported_program().graph_module - self.assertEqual( - count_node(graph_module, exir_ops.edge.aten.slice_copy.Tensor), 1 + x = torch.randn(*in_shape) + gm = single_op_builder( + placeholders=(x,), + op=exir_ops.edge.aten.slice_copy.Tensor, + args=( + x, + dim, + start, + end, + step, + ), ) - p = SimplifySliceOpPass() - - graph_after_passes = cast(PassResult, p(graph_module)).graph_module - - self.assertEqual( - count_node(graph_after_passes, exir_ops.edge.aten.slice_copy.Tensor), 0 - ) - self.assertEqual( - count_node(graph_after_passes, exir_ops.edge.aten.full.default), 1 - ) + gm = cast(PassResult, p(gm)).graph_module + self.assertEqual(count_node(gm, exir_ops.edge.aten.slice_copy.Tensor), 0) + self.assertEqual(count_node(gm, exir_ops.edge.aten.full.default), 1) def test_simplify_slice_op_args(self) -> None: x = torch.rand(4, 5) @@ -125,24 +89,10 @@ def test_simplify_slice_op_args(self) -> None: args=(x, 1), kwargs={"end": 3}, ) - self.assertEqual( - [ - (n.args[1:], n.kwargs) - for n in gm.graph.find_nodes( - op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor - ) - ], - [((1,), {"end": 3})], - ) - + original_slice_copy = list(gm.graph.nodes)[1] + self.assertEqual(original_slice_copy.args[1:], (1,)) + self.assertEqual(original_slice_copy.kwargs, {"end": 3}) gm = BindOptionalArgsPass().call(gm).graph_module - - self.assertEqual( - [ - (n.args[1:], n.kwargs) - for n in gm.graph.find_nodes( - op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor - ) - ], - [((1, None, 3, 1), {})], - ) + modified_slice_copy = list(gm.graph.nodes)[1] + self.assertEqual(modified_slice_copy.args[1:], (1, None, 3, 1)) + self.assertEqual(modified_slice_copy.kwargs, {})