From 20eb5b98b580b7cabd7c0ae22e83d5f3fb3ebe76 Mon Sep 17 00:00:00 2001
From: Yufeng Shi
Date: Wed, 24 Sep 2025 17:15:38 +0100
Subject: [PATCH] Arm backend: Fix torch.matmul() failures for 2D tensor inputs

- ConvertMmToBmmPass converts an MM node to a BMM node, turns the input
  and output tensors from rank 2 to rank 3 via unsqueeze/squeeze, and
  inserts q-dq pairs before and after the BMM node when necessary.
- After ConvertMmToBmmPass:

    x -> q -> dq -> unsqueeze -> q_2 -> dq_2 -> \
                                                  bmm -> q_4 -> dq_4
                                                 /
    y -> q_1 -> dq_1 -> unsqueeze -> q_3 -> dq_3 ->

- Therefore, if the original matmul was 2D, the bmm already has DQ nodes
  on its inputs and a Q node on its output. If AnnotateDecomposedMatmulPass
  (#10654) is still applied in this case, it produces illegal sequences
  such as:

    x -> q -> unsqueeze -> q_2 (invalid)

- Fix by checking whether the BMM is already surrounded by DQ nodes on
  its inputs and Q nodes on its output.

Change-Id: I9949d59b0b4a96fa34a88b0734014567ea6f24cc
Signed-off-by: Yufeng Shi
Co-authored-by: Oscar Andersson
---
 backends/arm/_passes/annotate_decomposed_matmul.py | 9 +++++++--
 backends/arm/test/ops/test_matmul.py               | 7 +++++++
 2 files changed, 14 insertions(+), 2 deletions(-)

diff --git a/backends/arm/_passes/annotate_decomposed_matmul.py b/backends/arm/_passes/annotate_decomposed_matmul.py
index 81b7b36cc0b..6a1acb8d71d 100644
--- a/backends/arm/_passes/annotate_decomposed_matmul.py
+++ b/backends/arm/_passes/annotate_decomposed_matmul.py
@@ -70,7 +70,10 @@ def call(self, graph_module: GraphModule) -> PassResult:
                 node for node in partition.nodes if node.target in matmul_targets
             ][0]
 
-            if quantized_input:
+            if quantized_input and not all(
+                input_node.target in DQ_OPS
+                for input_node in matmul_node.all_input_nodes
+            ):
                 matmul_args = matmul_node.all_input_nodes
                 for node in matmul_args:
                     # Find the dq-node connected to this mm/bmm arg
@@ -96,7 +99,9 @@ def call(self, graph_module: GraphModule) -> PassResult:
 
             partition_output = list(partition.output_nodes[0].users)[0]
             quantized_output = partition_output.target in Q_OPS
-            if quantized_output:
+            if quantized_output and not all(
+                user.target in Q_OPS for user in matmul_node.users
+            ):
                 with graph_module.graph.inserting_after(matmul_node):
                     # Create q-node after matmul
                     q_node = create_node(
diff --git a/backends/arm/test/ops/test_matmul.py b/backends/arm/test/ops/test_matmul.py
index d1a21684325..1486dc97f4c 100644
--- a/backends/arm/test/ops/test_matmul.py
+++ b/backends/arm/test/ops/test_matmul.py
@@ -22,6 +22,7 @@
 
 class MatMul(torch.nn.Module):
     test_data_generators = {
+        "rand_rand_2d": lambda: (torch.rand(5, 5), torch.rand(5, 2)),
         "rand_rand_3d": lambda: (torch.rand(2, 3, 5), torch.rand(2, 5, 2)),
         "rand_rand_4d": lambda: (torch.rand(1, 2, 3, 5), torch.rand(1, 2, 5, 2)),
     }
@@ -32,6 +33,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
 
 class MatMulSingleInput(torch.nn.Module):
     test_data_generators = {
+        "rand_2d": lambda: (torch.rand(5, 5),),
         "rand_3d": lambda: (torch.rand(2, 5, 5),),
         "rand_4d": lambda: (torch.rand(1, 2, 5, 5),),
     }
@@ -42,6 +44,11 @@ def forward(self, x: torch.Tensor):
 
 class MatMulCombo(torch.nn.Module):
     test_data_generators = {
+        "rand_rand_rand_2d": lambda: (
+            torch.rand(5, 5),
+            torch.rand(5, 2),
+            torch.rand(2, 5),
+        ),
         "rand_rand_rand_3d": lambda: (
             torch.rand(2, 5, 5),
             torch.rand(2, 5, 2),
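
For quick reference, a minimal standalone sketch (not part of the patch) of the 2D case that the new "rand_rand_2d" test data exercises. It only runs the module eagerly in plain PyTorch to confirm the shapes; the regression coverage for the fixed pass lives in the Arm backend test pipelines in backends/arm/test/ops/test_matmul.py.

    # Eager-mode sketch: the MatMul module and the 2D shapes mirror the new
    # "rand_rand_2d" test data added in the patch above.
    import torch


    class MatMul(torch.nn.Module):
        def forward(self, x: torch.Tensor, y: torch.Tensor):
            return torch.matmul(x, y)


    if __name__ == "__main__":
        x, y = torch.rand(5, 5), torch.rand(5, 2)  # 2D inputs, as in rand_rand_2d
        out = MatMul()(x, y)
        print(out.shape)  # torch.Size([5, 2])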