diff --git a/backends/cadence/aot/tests/test_remove_ops_passes.py b/backends/cadence/aot/tests/test_remove_ops_passes.py index b5ae756c076..aed62089cea 100644 --- a/backends/cadence/aot/tests/test_remove_ops_passes.py +++ b/backends/cadence/aot/tests/test_remove_ops_passes.py @@ -13,8 +13,6 @@ import executorch.backends.cadence.aot.ops_registrations # noqa import torch import torch.nn as nn -import torch.nn.functional as F -from executorch.backends.cadence.aot import compiler from executorch.backends.cadence.aot.compiler import export_to_edge from executorch.backends.cadence.aot.fuse_ops import FuseQuantDequantToRequantizePass from executorch.backends.cadence.aot.graph_builder import GraphBuilder @@ -53,16 +51,15 @@ class TestRemoveOpsPasses(unittest.TestCase): ) @torch.no_grad() def test_remove_to_ops(self, shape: Tuple[int]): - class M(torch.nn.Module): - def forward(self, x: torch.Tensor): - return exir_ops.edge.aten.to(x, dtype=torch.float32) - - model = M() - x = torch.randn(shape) - graph_module = export_to_edge(model, (x,)).exported_program().graph_module - p = RemoveToOpsPass() - - graph_after_passes = cast(PassResult, p(graph_module)).graph_module + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32)) + x = builder.call_operator( + op=exir_ops.edge.aten.to.dtype, + args=(x, torch.float32), + ) + builder.output([x]) + original = builder.get_graph_module() + graph_after_passes = cast(PassResult, RemoveToOpsPass()(original)).graph_module self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.to.dtype), @@ -83,31 +80,24 @@ def forward(self, x: torch.Tensor): ) @torch.no_grad() def test_remove_nop_add_op_pass(self, shape: Tuple[int]): - class FullX(torch.nn.Module): - def forward(self, t: torch.Tensor): - return torch.add(torch.full(shape, 0), t) - - class FullY(torch.nn.Module): - def forward(self, t: torch.Tensor): - return torch.add(t, torch.full(shape, 0)) - - model = FullX() - t = torch.full(shape, 3) - graph_module = export_to_edge(model, (t,)).exported_program().graph_module - - p = RemoveNopAddOpPass() - - graph_after_passes = cast(PassResult, p(graph_module)).graph_module - self.assertEqual( - count_node(graph_after_passes, exir_ops.edge.aten.add.Tensor), - 0, - ) - - model = FullY() - graph_module = export_to_edge(model, (t,)).exported_program().graph_module - - graph_after_passes = cast(PassResult, p(graph_module)).graph_module - + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32)) + zeros = builder.call_operator( + op=exir_ops.edge.aten.full.default, args=(shape, 0) + ) + left_add = builder.call_operator( + op=exir_ops.edge.aten.add.Tensor, + args=(zeros, x), + ) + right_add = builder.call_operator( + op=exir_ops.edge.aten.add.Tensor, + args=(left_add, zeros), + ) + builder.output([right_add]) + original = builder.get_graph_module() + graph_after_passes = cast( + PassResult, RemoveNopAddOpPass()(original) + ).graph_module self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.add.Tensor), 0, @@ -122,31 +112,24 @@ def forward(self, t: torch.Tensor): ) @torch.no_grad() def test_remove_nop_mul_op_pass(self, shape: Tuple[int]): - class FullX(torch.nn.Module): - def forward(self, t: torch.Tensor): - return torch.mul(torch.full(shape, 0), t) - - class FullY(torch.nn.Module): - def forward(self, t: torch.Tensor): - return torch.mul(t, torch.full(shape, 0)) - - model = FullX() - t = torch.full(shape, 3) - graph_module = export_to_edge(model, (t,)).exported_program().graph_module - - p = RemoveNopMulOpPass() - - graph_after_passes = cast(PassResult, p(graph_module)).graph_module - self.assertEqual( - count_node(graph_after_passes, exir_ops.edge.aten.mul.Tensor), - 0, - ) - - model = FullY() - graph_module = export_to_edge(model, (t,)).exported_program().graph_module - - graph_after_passes = cast(PassResult, p(graph_module)).graph_module - + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32)) + zeros = builder.call_operator( + op=exir_ops.edge.aten.full.default, args=(shape, 0) + ) + left_mul = builder.call_operator( + op=exir_ops.edge.aten.mul.Tensor, + args=(zeros, x), + ) + right_mul = builder.call_operator( + op=exir_ops.edge.aten.mul.Tensor, + args=(left_mul, zeros), + ) + builder.output([right_mul]) + original = builder.get_graph_module() + graph_after_passes = cast( + PassResult, RemoveNopMulOpPass()(original) + ).graph_module self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.mul.Tensor), 0, @@ -159,18 +142,16 @@ def forward(self, t: torch.Tensor): ) @torch.no_grad() def test_remove_alias_copy(self, shape: Tuple[int]): - class M(torch.nn.Module): - def forward(self, x: torch.Tensor): - return exir_ops.edge.aten.alias_copy(x) - - model = M() - x = torch.randn(shape) - graph_module = export_to_edge(model, (x,)).exported_program().graph_module - - p = RemoveAliasCopyOpPass() - - graph_after_passes = cast(PassResult, p(graph_module)).graph_module - + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32)) + alias = builder.call_operator( + op=exir_ops.edge.aten.alias_copy.default, args=(x,) + ) + builder.output([alias]) + original = builder.get_graph_module() + graph_after_passes = cast( + PassResult, RemoveAliasCopyOpPass()(original) + ).graph_module self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.alias_copy.default), 0, @@ -183,19 +164,16 @@ def forward(self, x: torch.Tensor): ) @torch.no_grad() def test_remove_detach_copy(self, shape: Tuple[int]): - # aten::detach is converted to aten::alias_copy after functionalization & decomposition. - class M(torch.nn.Module): - def forward(self, x: torch.Tensor): - return exir_ops.edge.aten.detach_copy(x) - - model = M() - x = torch.randn(shape) - graph_module = export_to_edge(model, (x,)).exported_program().graph_module - - p = RemoveDetachCopyPass() - - graph_after_passes = cast(PassResult, p(graph_module)).graph_module - + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32)) + detach = builder.call_operator( + op=exir_ops.edge.aten.detach_copy.default, args=(x,) + ) + builder.output([detach]) + original = builder.get_graph_module() + graph_after_passes = cast( + PassResult, RemoveDetachCopyPass()(original) + ).graph_module self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.detach_copy.default), 0, @@ -210,95 +188,51 @@ def forward(self, x: torch.Tensor): def test_remove_zero_sized_constant_pad_nd( self, shape: Tuple[int], padding: Tuple[int] ): - # F.pad is converted to aten::constant_pad_nd after functionalization & decomposition. - class Padding(torch.nn.Module): - def __init__(self): - super().__init__() - self.padding = padding - - def forward(self, x: torch.Tensor): - return F.pad(x, self.padding) - - model = Padding() - x = torch.randn(shape) - graph_module = export_to_edge(model, (x,)).exported_program().graph_module - - p = RemoveZeroSizedConstantPadNd() - - graph_after_passes = cast(PassResult, p(graph_module)).graph_module - + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32)) + pad = builder.call_operator( + op=exir_ops.edge.aten.constant_pad_nd.default, args=(x, padding) + ) + builder.output([pad]) + original = builder.get_graph_module() + graph_after_passes = cast( + PassResult, RemoveZeroSizedConstantPadNd()(original) + ).graph_module self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.constant_pad_nd.default), 0, ) def test_remove_expand(self): - class Expand(torch.nn.Module): - def forward(self, x): - return torch.ops.aten.expand_copy(x, [2, 3, 5]) - - x = torch.ones(2, 3, 5) - p = RemoveNopExpandOpPass() - graph_module = export_to_edge(Expand(), (x,)).exported_program().graph_module - graph_module = p(graph_module).graph_module - # Assert that expand op is optimized away, since it is a nop + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn([2, 3, 5], dtype=torch.float32)) + expand = builder.call_operator( + op=exir_ops.edge.aten.expand_copy.default, args=(x, [2, 3, 5]) + ) + builder.output([expand]) + original = builder.get_graph_module() + graph_after_passes = cast( + PassResult, RemoveNopExpandOpPass()(original) + ).graph_module self.assertEqual( - count_node(graph_module, exir_ops.edge.aten.expand_copy.default), 0 + count_node(graph_after_passes, exir_ops.edge.aten.expand_copy.default), 0 ) def test_remove_zero_arg_cat(self): - class Cat(torch.nn.Module): - def forward(self, x, y): - return torch.ops.aten.cat((x, y), 0) - - x = torch.ones(1, 0, 3, 5) - y = torch.ones(2, 0, 3, 5) - graph_module = ( - compiler.export_to_cadence(Cat(), (x, y)).exported_program().graph_module - ) - # Assert that cat op is optimized away, since it concatenates - # two zero-sized tensors - self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 0) - - def test_remove_single_arg_cat(self): - class Cat(torch.nn.Module): - def forward(self, x, y): - z = torch.ones(0, 5) - # z is an empty tensor, and concatenation of x with z will - # be x. So we can safely eliminate the following cat op. - x1 = torch.ops.aten.cat((x, z)) - x2 = torch.add(x1, 2.4, 3.1) - y1 = torch.add(y, 1, 2) - return torch.add(x2, y1) - - x = torch.ones(3, 5) - y = torch.ones(3, 5) - graph_module = export_to_edge(Cat(), (x, y)).exported_program().graph_module - new_graph_module = RemoveZeroSizedCatArgsPass()(graph_module).graph_module - new_graph_module.graph.eliminate_dead_code() - # Assert that x1 is optimized away - self.assertEqual(count_node(new_graph_module, torch.ops.aten.cat.out), 0) - - def test_remove_zero_sized_cat(self): - class Cat(torch.nn.Module): - def __init__(self, dim: int): - super().__init__() - self.dim = dim - - def forward(self, tensors): - return torch.cat(tensors, self.dim) - - shapes, dim, dtype, _max = [(1, 0, 3), (2, 0, 3)], 0, torch.float32, 127 - - in_tensors = [(torch.rand(shape) * _max).to(dtype=dtype) for shape in shapes] - - model = Cat(dim) - graph_module = ( - export_to_edge(model, (in_tensors,)).exported_program().graph_module + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn([1, 0, 3, 5], dtype=torch.float32)) + y = builder.placeholder("y", torch.randn([2, 0, 3, 5], dtype=torch.float32)) + concat = builder.call_operator( + op=exir_ops.edge.aten.cat.default, args=([x, y], 0) + ) + builder.output([concat]) + original = builder.get_graph_module() + graph_after_passes = cast( + PassResult, RemoveZeroSizedCatArgsPass()(original) + ).graph_module + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.cat.default), 0 ) - new_graph_module = RemoveZeroSizedCatArgsPass()(graph_module).graph_module - new_graph_module.graph.eliminate_dead_code() - self.assertEqual(count_node(new_graph_module, torch.ops.aten.cat.out), 0) def test_remove_clone(self): class Clone(torch.nn.Module):