From 175c8e61e77581bc7204ebbdf49be244b370ab5c Mon Sep 17 00:00:00 2001
From: Hardik Sharma
Date: Mon, 25 Aug 2025 21:19:40 -0700
Subject: [PATCH] Disable mm + add -> addmm fusion if added tensor rank >2
 (#13632)

Summary: Addmm meta kernel allows the added tensor rank to be >2 but the
implementation does not. This diff disables fusion of mm + add in such cases.

Reviewed By: zonglinpeng

Differential Revision: D80906791
---
 backends/cadence/aot/fuse_ops.py          |  7 +++--
 .../aot/tests/test_fusion_ops_passes.py   | 26 ++++++++++++++++++-
 2 files changed, 30 insertions(+), 3 deletions(-)

diff --git a/backends/cadence/aot/fuse_ops.py b/backends/cadence/aot/fuse_ops.py
index 16d4dbde32b..dbd19e1d3af 100644
--- a/backends/cadence/aot/fuse_ops.py
+++ b/backends/cadence/aot/fuse_ops.py
@@ -72,11 +72,13 @@ def fuse_mm_with_add(self, graph_module: torch.fx.GraphModule):
         fuse it with mm.
         """
         graph = graph_module.graph
-        for node in graph.nodes:
+        for node in graph.find_nodes(
+            op="call_function", target=exir_ops.edge.aten.mm.default
+        ):
             # We want to discover a chain of mm -> add, or mm -> view -> add.
             # Only proceed if the current node is an mm node, and has only one
             # user/successor.
-            if node.target != exir_ops.edge.aten.mm.default or len(node.users) != 1:
+            if len(node.users) != 1:
                 continue
 
             # Our addmm implementation computes (mat1 * mat2 + bias). So the
@@ -128,6 +130,7 @@ def fuse_mm_with_add(self, graph_module: torch.fx.GraphModule):
                 mm_arg_shape is None
                 or bias_arg_shape is None
                 or not broadcastable(mm_arg_shape, bias_arg_shape)
+                or len(bias_arg_shape) > 2
             ):
                 continue
 
diff --git a/backends/cadence/aot/tests/test_fusion_ops_passes.py b/backends/cadence/aot/tests/test_fusion_ops_passes.py
index 556c227b38d..d160a02721a 100644
--- a/backends/cadence/aot/tests/test_fusion_ops_passes.py
+++ b/backends/cadence/aot/tests/test_fusion_ops_passes.py
@@ -40,7 +40,29 @@ def check_op_counts(
         self.assertTrue(op_counts_match(graph_module, expected_op_counts))
 
 
-class TestFusionPasses(TestFusionPassesBase):
+class TestFuseMMWithAddPass(TestFusionPassesBase):
+    def test_no_fuse_for_3d_bias(self) -> None:
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(4, 3, dtype=torch.float32))
+        y = builder.placeholder("y", torch.randn(3, 5, dtype=torch.float32))
+        z = builder.placeholder("z", torch.randn(1, 4, 5, dtype=torch.float32))
+        mm = builder.call_operator(
+            op=exir_ops.edge.aten.mm.default,
+            args=(x, y),
+        )
+        output = builder.call_operator(op=exir_ops.edge.aten.add.Tensor, args=(mm, z))
+        builder.output([output])
+        original_graph = builder.get_graph_module()
+
+        p = FuseMMWithAdd()
+        converted_graph = cast(PassResult, p(original_graph)).graph_module
+        converted_graph.graph.eliminate_dead_code()
+        self.assertEqual(
+            count_node(converted_graph, exir_ops.edge.aten.addmm.default), 0
+        )
+        self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 1)
+        self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 1)
+
     def test_fuse_mm_with_add(self) -> None:
         builder = GraphBuilder()
         x = builder.placeholder("x", torch.randn(3, 5, dtype=torch.float32))
@@ -176,6 +198,8 @@ def test_keep_mm_add_with_multiple_users(self) -> None:
         self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 1)
         self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 3)
 
+
+class TestFusionPasses(TestFusionPassesBase):
     def test_permute_transpose_fusion(self) -> None:
         builder = GraphBuilder()
         x = builder.placeholder("x", torch.randn(3, 1, 3, 1, 4, dtype=torch.float32))
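
For context, a minimal eager-mode sketch of the shape pattern the new len(bias_arg_shape) > 2
guard targets, using the shapes from test_no_fuse_for_3d_bias. This is illustrative only and
is not part of the patch; it simply shows why the unfused mm + add chain is kept.

# Illustrative sketch only, not part of the patch: shapes taken from
# test_no_fuse_for_3d_bias above.
import torch

x = torch.randn(4, 3)      # mm lhs
y = torch.randn(3, 5)      # mm rhs
z = torch.randn(1, 4, 5)   # rank-3 bias

# The unfused mm -> add chain works: the (4, 5) mm result broadcasts
# with the (1, 4, 5) bias to a rank-3 output.
out = torch.mm(x, y) + z
assert out.shape == (1, 4, 5)

# Fusing into addmm would hand this rank-3 bias to the Cadence addmm
# implementation, which per the summary only supports bias rank <= 2,
# so FuseMMWithAdd now skips fusion when len(bias_arg_shape) > 2.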