From 5f2d7dc5ba0e529a694b65ddf0e059ad824cf257 Mon Sep 17 00:00:00 2001 From: Ekaterina Ignasheva Date: Wed, 28 May 2025 10:32:03 -0700 Subject: [PATCH] Use GraphBuilder in test fusion ops. (#11078) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/11078 Reviewed By: hsharma35 Differential Revision: D75183327 --- .../aot/tests/test_fusion_ops_passes.py | 295 ++++++++++-------- 1 file changed, 162 insertions(+), 133 deletions(-) diff --git a/backends/cadence/aot/tests/test_fusion_ops_passes.py b/backends/cadence/aot/tests/test_fusion_ops_passes.py index fff2963df29..d01e2e57859 100644 --- a/backends/cadence/aot/tests/test_fusion_ops_passes.py +++ b/backends/cadence/aot/tests/test_fusion_ops_passes.py @@ -14,7 +14,10 @@ import torch from executorch.backends.cadence.aot import compiler from executorch.backends.cadence.aot.fuse_ops import ( + FuseCascadedTransposeOrPermuteOps, + FuseCascadedViewOps, FuseFullThenReshapePass, + FuseMMWithAdd, FuseMulScalarIntoDequantPass, FuseMulTensorIntoDequantPass, FuseQuantDequantToRequantizePass, @@ -39,113 +42,133 @@ def check_op_counts( class TestFusionPasses(TestFusionPassesBase): - def test_addmm_fusion(self): - class AddmmFeasible1(torch.nn.Module): - def forward(self, x, y, z): - t1 = torch.mm(x, y) - return torch.add(t1, z) - - x = torch.randn(3, 5) - y = torch.randn(5, 6) - z = torch.randn(6) - - graph_module = ( - compiler.export_to_cadence(AddmmFeasible1(), (x, y, z)) - .exported_program() - .graph_module + def test_fuse_mm_with_add(self): + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(3, 5, dtype=torch.float32)) + y = builder.placeholder("y", torch.randn(5, 6, dtype=torch.float32)) + z = builder.placeholder("z", torch.randn(6, dtype=torch.float32)) + mm = builder.call_operator( + op=exir_ops.edge.aten.mm.default, + args=(x, y), + ) + output = builder.call_operator(op=exir_ops.edge.aten.add.Tensor, args=(mm, z)) + builder.output([output]) + original_graph = 
builder.get_graph_module() + converted_graph = FuseMMWithAdd()(original_graph).graph_module + converted_graph.graph.eliminate_dead_code() + self.assertEqual( + count_node(converted_graph, exir_ops.edge.aten.addmm.default), 1 ) - graph_module.graph.eliminate_dead_code() - - # Assert that mm and add were fused to addmm - self.assertEqual(count_node(graph_module, exir_ops.edge.aten.addmm.default), 1) - self.assertEqual(count_node(graph_module, exir_ops.edge.aten.mm.default), 0) - self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 0) - - class AddmmFeasible2(torch.nn.Module): - def forward(self, x, y, z): - t1 = y.view((8, 6)) - t2 = torch.mm(x, t1) - t3 = t2.view((2, 2, 6)) - return torch.add(t3, z) - - x = torch.randn(4, 8) - y = torch.randn(2, 4, 6) - z = torch.randn(6) + self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 0) + self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 0) - graph_module = ( - compiler.export_to_cadence(AddmmFeasible2(), (x, y, z)) - .exported_program() - .graph_module + def test_fuse_view_mm_view_add(self): + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(4, 8, dtype=torch.float32)) + y = builder.placeholder("y", torch.randn(2, 4, 6, dtype=torch.float32)) + z = builder.placeholder("z", torch.randn(6, dtype=torch.float32)) + y_view = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, args=(y, [8, 6]) ) - graph_module.graph.eliminate_dead_code() - # Assert that mm and add were fused to addmm - self.assertEqual(count_node(graph_module, exir_ops.edge.aten.addmm.default), 1) - self.assertEqual(count_node(graph_module, exir_ops.edge.aten.mm.default), 0) - self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 0) - - # Bias is a singleton value, broadcastable to output of mm - class AddmmFeasible3(torch.nn.Module): - def forward(self, x, y): - t1 = torch.mm(x, y) - return torch.add(t1, torch.ones(1)) - - x = 
torch.randn(3, 5) - y = torch.randn(5, 6) - - graph_module = ( - compiler.export_to_cadence(AddmmFeasible3(), (x, y)) - .exported_program() - .graph_module + mm = builder.call_operator( + op=exir_ops.edge.aten.mm.default, + args=(x, y_view), + ) + mm_view = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, args=(mm, [2, 2, 6]) ) - graph_module.graph.eliminate_dead_code() - # Assert that mm and add were fused to addmm - self.assertEqual(count_node(graph_module, exir_ops.edge.aten.addmm.default), 1) - self.assertEqual(count_node(graph_module, exir_ops.edge.aten.mm.default), 0) - self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 0) + output = builder.call_operator( + op=exir_ops.edge.aten.add.Tensor, args=(mm_view, z) + ) + builder.output([output]) + original_graph = builder.get_graph_module() + converted_graph = FuseMMWithAdd()(original_graph).graph_module + converted_graph.graph.eliminate_dead_code() + self.assertEqual( + count_node(converted_graph, exir_ops.edge.aten.addmm.default), 1 + ) + self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 0) + self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 0) + def test_keep_view_mm_view_add(self): + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(4, 8, dtype=torch.float32)) + y = builder.placeholder("y", torch.randn(2, 4, 6, dtype=torch.float32)) # Bias is not broadcastable to output of mm - class AddmmInfeasible1(torch.nn.Module): - def forward(self, x, y, z): - t1 = y.view((8, 6)) - t2 = torch.mm(x, t1) - t3 = t2.view((2, 2, 6)) - return torch.add(t3, z) - - x = torch.randn(4, 8) - y = torch.randn(2, 4, 6) - z = torch.randn(2, 2, 1) - - graph_module = ( - compiler.export_to_cadence(AddmmInfeasible1(), (x, y, z)) - .exported_program() - .graph_module + z = builder.placeholder("z", torch.randn(2, 2, 1, dtype=torch.float32)) + y_view = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, args=(y, [8, 
6]) + ) + mm = builder.call_operator( + op=exir_ops.edge.aten.mm.default, + args=(x, y_view), ) - graph_module.graph.eliminate_dead_code() + mm_view = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, args=(mm, [2, 2, 6]) + ) + output = builder.call_operator( + op=exir_ops.edge.aten.add.Tensor, args=(mm_view, z) + ) + builder.output([output]) + original_graph = builder.get_graph_module() + converted_graph = FuseMMWithAdd()(original_graph).graph_module + converted_graph.graph.eliminate_dead_code() # Assert that mm and add were not fused to addmm, since z cannot be # broadcasted to the out of mm. - self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 1) - - # The add consuming the output of mm has more than one users. - class AddmmInfeasible2(torch.nn.Module): - def forward(self, x, y, z): - t1 = torch.mm(x, y) - t2 = torch.add(t1, z) - t3 = torch.add(t2, z) - return torch.add(t2, t3) + self.assertEqual( + count_node(converted_graph, exir_ops.edge.aten.addmm.default), 0 + ) + self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 1) + self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 1) - x = torch.randn(3, 5) - y = torch.randn(5, 6) - z = torch.randn(6) + def test_fuse_mm_add_with_bias(self): + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(3, 5, dtype=torch.float32)) + y = builder.placeholder("y", torch.randn(5, 6, dtype=torch.float32)) + mm = builder.call_operator( + op=exir_ops.edge.aten.mm.default, + args=(x, y), + ) + bias = builder.call_operator(op=exir_ops.edge.aten.full.default, args=([1], 1)) + output = builder.call_operator( + op=exir_ops.edge.aten.add.Tensor, args=(mm, bias) + ) + builder.output([output]) + original_graph = builder.get_graph_module() + converted_graph = FuseMMWithAdd()(original_graph).graph_module + converted_graph.graph.eliminate_dead_code() + self.assertEqual( + count_node(converted_graph, exir_ops.edge.aten.addmm.default), 1 + ) + 
self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 0) + self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 0) - graph_module = ( - compiler.export_to_cadence(AddmmInfeasible2(), (x, y, z)) - .exported_program() - .graph_module + def test_keep_mm_add_with_multiple_users(self): + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(3, 5, dtype=torch.float32)) + y = builder.placeholder("y", torch.randn(5, 6, dtype=torch.float32)) + z = builder.placeholder("z", torch.randn(6, dtype=torch.float32)) + mm = builder.call_operator( + op=exir_ops.edge.aten.mm.default, + args=(x, y), + ) + # The add consuming the output of mm has more than one user. + add1 = builder.call_operator(op=exir_ops.edge.aten.add.Tensor, args=(mm, z)) + add2 = builder.call_operator(op=exir_ops.edge.aten.add.Tensor, args=(add1, z)) + output = builder.call_operator( + op=exir_ops.edge.aten.add.Tensor, args=(add1, add2) ) - graph_module.graph.eliminate_dead_code() + builder.output([output]) + original_graph = builder.get_graph_module() + converted_graph = FuseMMWithAdd()(original_graph).graph_module + converted_graph.graph.eliminate_dead_code() # Assert that mm and add were not fused to addmm, since add has multiple # users. 
- self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 3) + self.assertEqual( + count_node(converted_graph, exir_ops.edge.aten.addmm.default), 0 + ) + self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 1) + self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 3) # TODO(matthiascremon): enable that pass with new flow @torch.no_grad() @@ -184,63 +207,69 @@ def forward(self, x): ) def test_permute_transpose_fusion(self): - class PermuteTranspose(torch.nn.Module): - def forward(self, x): - y = x.permute((0, 2, 4, 1, 3)) - return y.transpose(0, 1) - - x = torch.randn(3, 1, 3, 1, 4) - graph_module = ( - compiler.export_to_cadence(PermuteTranspose(), (x,)) - .exported_program() - .graph_module + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(3, 1, 3, 1, 4, dtype=torch.float32)) + permute = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 2, 4, 1, 3]) + ) + output = builder.call_operator( + op=exir_ops.edge.aten.transpose_copy.int, + args=(permute, 1, 0), ) - graph_module.graph.eliminate_dead_code() + builder.output(output) + original_graph = builder.get_graph_module() + converted_graph = FuseCascadedTransposeOrPermuteOps()( + original_graph + ).graph_module + converted_graph.graph.eliminate_dead_code() # Assert that permute op was fused with transpose op self.assertEqual( - count_node(graph_module, exir_ops.edge.aten.permute_copy.default), 1 + count_node(converted_graph, exir_ops.edge.aten.permute_copy.default), 1 ) self.assertEqual( - count_node(graph_module, exir_ops.edge.aten.transpose_copy.int), 0 + count_node(converted_graph, exir_ops.edge.aten.transpose_copy.int), 0 ) def test_view_fusion(self): - class ViewFusion(torch.nn.Module): - def forward(self, x): - x = x.view([1, 8, 15]) - x = x.view([1, 1, 120]) - return x.view([1, 12, 10]) - - x = torch.randn(8, 5, 3) - graph_module = ( - compiler.export_to_cadence(ViewFusion(), (x,)) - 
.exported_program() - .graph_module + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(8, 5, 3, dtype=torch.float32)) + view1 = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, args=(x, [1, 8, 15]) + ) + view2 = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, args=(view1, [1, 1, 120]) + ) + output = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, args=(view2, [1, 12, 10]) ) - graph_module.graph.eliminate_dead_code() + builder.output(output) + original_graph = builder.get_graph_module() + converted_graph = FuseCascadedViewOps()(original_graph).graph_module + converted_graph.graph.eliminate_dead_code() # Assert that only one view op remains self.assertEqual( - count_node(graph_module, exir_ops.edge.aten.view_copy.default), 1 + count_node(converted_graph, exir_ops.edge.aten.view_copy.default), 1 ) def test_view_fusion_branched(self): - class ViewFusion(torch.nn.Module): - def forward(self, x): - y = x.view([1, 8, 15]) - z = y.view([1, 1, 120]) - t = y.view([120, 1, 1]) - return z, t - - x = torch.randn(8, 5, 3) - graph_module = ( - compiler.export_to_cadence(ViewFusion(), (x,)) - .exported_program() - .graph_module + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(8, 5, 3, dtype=torch.float32)) + y = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, args=(x, [1, 8, 15]) + ) + z = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, args=(y, [1, 1, 120]) ) - graph_module.graph.eliminate_dead_code() + t = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, args=(y, [120, 1, 1]) + ) + builder.output([z, t]) + original_graph = builder.get_graph_module() + converted_graph = FuseCascadedViewOps()(original_graph).graph_module + converted_graph.graph.eliminate_dead_code() # z and t should be fused and y should be eliminated. 
self.assertEqual( - count_node(graph_module, exir_ops.edge.aten.view_copy.default), 2 + count_node(converted_graph, exir_ops.edge.aten.view_copy.default), 2 ) def test_force_quant_dequant_fusion(self):