diff --git a/backends/cadence/aot/tests/test_remove_ops_passes.py b/backends/cadence/aot/tests/test_remove_ops_passes.py index aed62089cea..0ec9d9c581b 100644 --- a/backends/cadence/aot/tests/test_remove_ops_passes.py +++ b/backends/cadence/aot/tests/test_remove_ops_passes.py @@ -235,36 +235,28 @@ def test_remove_zero_arg_cat(self): ) def test_remove_clone(self): - class Clone(torch.nn.Module): - def forward(self, x, y): - t1 = x.clone() - t2 = y.clone() - return t1 + t2 - - x = torch.ones(3, 5) - y = torch.ones(3, 5) - graph_module = export_to_edge(Clone(), (x, y)).exported_program().graph_module - new_graph_module = RemoveCloneOpPass()(graph_module).graph_module - new_graph_module.graph.eliminate_dead_code() - # Assert that t1 and t2 are optimized away - self.assertEqual(count_node(new_graph_module, torch.ops.aten.clone.out), 0) + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn([3, 5], dtype=torch.float32)) + clone = builder.call_operator(op=exir_ops.edge.aten.clone.default, args=(x,)) + builder.output([clone]) + original = builder.get_graph_module() + graph_after_passes = RemoveCloneOpPass()(original).graph_module + self.assertEqual( + count_node(graph_after_passes, torch.ops.aten.clone.default), 0 + ) def test_remove_contiguous(self): - class Contiguous(torch.nn.Module): - def forward(self, x, y): - t1 = x.contiguous() - t2 = y.contiguous() - return t1 + t2 - - x = torch.ones(3, 5) - y = torch.ones(3, 5) - graph_module = ( - export_to_edge(Contiguous(), (x, y)).exported_program().graph_module + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn([3, 5], dtype=torch.float32)) + contiguous = builder.call_operator( + op=exir_ops.edge.aten.contiguous.default, args=(x,) + ) + builder.output([contiguous]) + original = builder.get_graph_module() + graph_after_passes = RemoveContiguousOpPass()(original).graph_module + self.assertEqual( + count_node(graph_after_passes, torch.ops.aten.contiguous.default), 0 ) - new_graph_module = RemoveContiguousOpPass()(graph_module).graph_module - new_graph_module.graph.eliminate_dead_code() - # Assert that t1 and t2 are optimized away - self.assertEqual(count_node(new_graph_module, torch.ops.aten.contiguous.out), 0) @parameterized.expand( [ @@ -274,119 +266,129 @@ def forward(self, x, y): ) @torch.no_grad() def test_remove_nop_view(self, shape, new_shape): - class View(torch.nn.Module): - def __init__(self, new_shape): - super().__init__() - self.new_shape = new_shape - - def forward(self, x: torch.Tensor): - return x.view(self.new_shape) - - model = View(new_shape) - x = torch.randn(shape) - graph_module = export_to_edge(model, (x,)).exported_program().graph_module - p = RemoveNopSliceOrViewOpPass() - graph_after_passes = cast(PassResult, p(graph_module)).graph_module - graph_after_passes.graph.eliminate_dead_code() - # Assert that view op was removed + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32)) + view = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, args=(x, new_shape) + ) + builder.output([view]) + original = builder.get_graph_module() + graph_after_passes = cast( + PassResult, RemoveNopSliceOrViewOpPass()(original) + ).graph_module self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default), 0 ) def test_remove_nop_slice(self): - class Slice(torch.nn.Module): - def forward(self, x): - return torch.slice_copy(x, dim=0, start=0, step=1) - - x = torch.ones(3, 5) - model = Slice() - graph_module = export_to_edge(model, (x,)).exported_program().graph_module - p = RemoveNopSliceOrViewOpPass() - graph_after_passes = cast(PassResult, p(graph_module)).graph_module - graph_after_passes.graph.eliminate_dead_code() - # Assert that slice op was removed + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(3, 5, dtype=torch.float32)) + slice_ = builder.call_operator( + op=exir_ops.edge.aten.slice_copy.Tensor, + args=( + x, + 0, # dim + 0, # start + 3, # end + ), + ) + builder.output([slice_]) + original = builder.get_graph_module() + graph_after_passes = cast( + PassResult, RemoveNopSliceOrViewOpPass()(original) + ).graph_module self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.slice_copy.Tensor), 0 ) - def test_remove_nop_select(self): - class SelectFeasible1(torch.nn.Module): - def forward(self, x): - y = x.select(0, 0) - z = y.view([1, 5, 6]) - return z - - x = torch.ones(1, 5, 6) - graph_module = ( - export_to_edge(SelectFeasible1(), (x,)).exported_program().graph_module + def test_remove_nop_select_before_view(self): + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(1, 5, 6, dtype=torch.float32)) + select = builder.call_operator( + op=exir_ops.edge.aten.select_copy.int, + args=( + x, + 0, # dim + 0, # index + ), ) - self.assertEqual( - count_node(graph_module, exir_ops.edge.aten.select_copy.int), 1 + view = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, + args=(select, [1, 5, 6]), # new shape ) - graph_module = RemoveNopSelectOpPass()(graph_module).graph_module - # Assert that select op was removed + builder.output([view]) + original = builder.get_graph_module() + graph_after_passes = cast( + PassResult, RemoveNopSelectOpPass()(original) + ).graph_module self.assertEqual( - count_node(graph_module, exir_ops.edge.aten.select_copy.int), 0 + count_node(graph_after_passes, exir_ops.edge.aten.select_copy.int), 0 ) - class SelectFeasible2(torch.nn.Module): - def forward(self, x, y): - x = x.select(0, 0) - z = x + y - return z - - x = torch.ones(1, 5, 6) - y = torch.ones(1, 5, 6) - graph_module = ( - export_to_edge(SelectFeasible2(), (x, y)).exported_program().graph_module - ) - self.assertEqual( - count_node(graph_module, exir_ops.edge.aten.select_copy.int), 1 + def test_remove_nop_select_before_add(self): + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(1, 5, 6, dtype=torch.float32)) + y = builder.placeholder("y", torch.randn(1, 5, 6, dtype=torch.float32)) + select = builder.call_operator( + op=exir_ops.edge.aten.select_copy.int, + args=( + x, + 0, # dim + 0, # index + ), ) - graph_module = RemoveNopSelectOpPass()(graph_module).graph_module - # Assert that select op was removed + add = builder.call_operator(op=exir_ops.edge.aten.add.Tensor, args=(select, y)) + builder.output([add]) + original = builder.get_graph_module() + graph_after_passes = cast( + PassResult, RemoveNopSelectOpPass()(original) + ).graph_module self.assertEqual( - count_node(graph_module, exir_ops.edge.aten.select_copy.int), 0 + count_node(graph_after_passes, exir_ops.edge.aten.select_copy.int), 0 ) - class SelectFeasible3(torch.nn.Module): - def forward(self, x, y): - x = x.select(0, 0) - z = x * y - return z - - x = torch.ones(1, 5, 6) - y = torch.ones(1, 5, 6) - graph_module = ( - export_to_edge(SelectFeasible3(), (x, y)).exported_program().graph_module - ) - self.assertEqual( - count_node(graph_module, exir_ops.edge.aten.select_copy.int), 1 + def test_remove_nop_select_before_mul(self): + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(1, 5, 6, dtype=torch.float32)) + y = builder.placeholder("y", torch.randn(1, 5, 6, dtype=torch.float32)) + select = builder.call_operator( + op=exir_ops.edge.aten.select_copy.int, + args=( + x, + 0, # dim + 0, # index + ), ) - graph_module = RemoveNopSelectOpPass()(graph_module).graph_module - # Assert that select op was removed + mul = builder.call_operator(op=exir_ops.edge.aten.mul.Tensor, args=(select, y)) + builder.output([mul]) + original = builder.get_graph_module() + graph_after_passes = cast( + PassResult, RemoveNopSelectOpPass()(original) + ).graph_module self.assertEqual( - count_node(graph_module, exir_ops.edge.aten.select_copy.int), 0 + count_node(graph_after_passes, exir_ops.edge.aten.select_copy.int), 0 ) - class SelectFeasible4(torch.nn.Module): - def forward(self, x, y): - x = x.select(0, 0) - z = x / y - return z - - x = torch.ones(1, 5, 6) - y = torch.ones(1, 5, 6) - graph_module = ( - export_to_edge(SelectFeasible4(), (x, y)).exported_program().graph_module - ) - self.assertEqual( - count_node(graph_module, exir_ops.edge.aten.select_copy.int), 1 + def test_remove_nop_select_before_div(self): + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(1, 5, 6, dtype=torch.float32)) + y = builder.placeholder("y", torch.randn(1, 5, 6, dtype=torch.float32)) + select = builder.call_operator( + op=exir_ops.edge.aten.select_copy.int, + args=( + x, + 0, # dim + 0, # index + ), ) - graph_module = RemoveNopSelectOpPass()(graph_module).graph_module - # Assert that select op was removed + div = builder.call_operator(op=exir_ops.edge.aten.div.Tensor, args=(select, y)) + builder.output([div]) + original = builder.get_graph_module() + graph_after_passes = cast( + PassResult, RemoveNopSelectOpPass()(original) + ).graph_module self.assertEqual( - count_node(graph_module, exir_ops.edge.aten.select_copy.int), 0 + count_node(graph_after_passes, exir_ops.edge.aten.select_copy.int), 0 ) def test_remove_nop_quant_dequant(self):