diff --git a/backends/arm/_passes/annotate_decomposed_matmul.py b/backends/arm/_passes/annotate_decomposed_matmul.py index 72c42f0f829..4043a9d7070 100644 --- a/backends/arm/_passes/annotate_decomposed_matmul.py +++ b/backends/arm/_passes/annotate_decomposed_matmul.py @@ -1,5 +1,4 @@ # Copyright 2024-2025 Arm Limited and/or its affiliates. -# All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -7,7 +6,7 @@ # pyre-unsafe import itertools - +import operator from typing import List import torch @@ -22,7 +21,7 @@ class AnnotateDecomposedMatmulPass(ExportPass): """ - torch.matmul can be decomposed in many ways, for instance: + torch.matmul and it's equivalent operator @ can be decomposed in many ways, for instance: dq -> matmul -> q can become dq -> repeat -> view -> bmm -> view -> dq which makes quantization folding difficult. This helper function find all matmul partitions and annotate its @@ -50,6 +49,7 @@ def call(self, graph_module: GraphModule) -> PassResult: graph_module.graph, [ torch.matmul, + operator.matmul, ], None, ) diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index 26457259a93..89a87b2637a 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -335,6 +335,7 @@ def _is_matmul_node_supported( graph_module.graph, [ torch.matmul, + operator.matmul, ], None, ) @@ -385,7 +386,7 @@ def is_node_supported( ): source_fn_stack: tuple[typing.Any] = node.meta.get("source_fn_stack", []) if len(source_fn_stack) > 0: - if source_fn_stack[-1][1] in (torch.matmul,): + if source_fn_stack[-1][1] in (torch.matmul, operator.matmul): return self._is_matmul_node_supported(submodules, node) elif node.target in (exir_ops.edge.aten.max_pool2d_with_indices.default,): diff --git a/backends/arm/test/ops/test_at.py b/backends/arm/test/ops/test_at.py new file mode 100644 index 00000000000..3d2f5ef7cf2 --- /dev/null +++ b/backends/arm/test/ops/test_at.py @@ -0,0 +1,149 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple + +import torch +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import ( + TosaPipelineBI, + TosaPipelineMI, +) + +aten_op_mm = "torch.ops.aten.matmul.default" +exir_op_mm = "executorch_exir_dialects_edge__ops_aten_matmul_default" +input_t1 = Tuple[torch.Tensor, torch.Tensor] # Input x + + +class AtMatMulSingleInput(torch.nn.Module): + test_data_generators = { + "rand_3d": lambda: (torch.rand(2, 5, 5),), + "rand_4d": lambda: (torch.rand(1, 2, 5, 5),), + } + + def forward(self, x: torch.Tensor): + return x @ x + + +class AtMatMulDoubleInput(torch.nn.Module): + test_data_generators = { + "rand_rand_3d": lambda: (torch.rand(2, 3, 5), torch.rand(2, 5, 2)), + "rand_rand_4d": lambda: (torch.rand(1, 2, 3, 5), torch.rand(1, 2, 5, 2)), + } + + def forward(self, x: torch.Tensor, y: torch.Tensor): + return x @ y + + +class AtMatMulMixedPattern1(torch.nn.Module): + test_data_generators = { + "rand_rand_rand_3d": lambda: ( + torch.rand(2, 5, 5), + torch.rand(2, 5, 2), + torch.rand(2, 2, 5), + ), + "rand_rand_rand_4d": lambda: ( + torch.rand(1, 2, 5, 5), + torch.rand(1, 2, 5, 2), + torch.rand(1, 2, 2, 5), + ), + } + + def forward(self, x1: torch.Tensor, x2: torch.Tensor, x3: torch.Tensor): + y1 = torch.matmul(x1, x1) + y2 = torch.matmul(x2, x3) + return y1 + y2 + + +class AtMatMulMixedPattern2(torch.nn.Module): + test_data_generators = { + "rand_rand_rand_3d": lambda: ( + torch.rand(2, 5, 5), + torch.rand(2, 5, 2), + torch.rand(2, 2, 5), + ), + "rand_rand_rand_4d": lambda: ( + torch.rand(1, 2, 5, 5), + torch.rand(1, 2, 5, 2), + torch.rand(1, 2, 2, 5), + ), + } + + def forward(self, x1: torch.Tensor, x2: torch.Tensor, x3: torch.Tensor): + y1 = torch.matmul(x1, x1) + y2 = torch.matmul(x2, x3) + return y1 @ y2 + + +@common.parametrize("test_data", AtMatMulSingleInput.test_data_generators) +def test_atmatmul_single_input_tosa_MI(test_data: input_t1): + pipeline = TosaPipelineMI[input_t1]( + AtMatMulSingleInput(), test_data(), aten_op_mm, exir_op_mm + ) + pipeline.run() + + +@common.parametrize("test_data", AtMatMulDoubleInput.test_data_generators) +def test_atmatmul_double_input_tosa_MI(test_data: input_t1): + pipeline = TosaPipelineMI[input_t1]( + AtMatMulDoubleInput(), test_data(), aten_op_mm, exir_op_mm + ) + pipeline.run() + + +@common.parametrize("test_data", AtMatMulMixedPattern1.test_data_generators) +def test_atmatmul_mixed_pattern1_tosa_MI(test_data: input_t1): + pipeline = TosaPipelineMI[input_t1]( + AtMatMulMixedPattern1(), test_data(), aten_op_mm, exir_op_mm + ) + pipeline.run() + + +@common.parametrize("test_data", AtMatMulMixedPattern2.test_data_generators) +def test_atmatmul_mixed_pattern2_tosa_MI(test_data: input_t1): + pipeline = TosaPipelineMI[input_t1]( + AtMatMulMixedPattern2(), test_data(), aten_op_mm, exir_op_mm + ) + pipeline.run() + + +@common.parametrize("test_data", AtMatMulSingleInput.test_data_generators) +def test_atmatmul_single_input_tosa_BI(test_data: input_t1): + pipeline = TosaPipelineBI[input_t1]( + AtMatMulSingleInput(), test_data(), aten_op_mm, exir_op_mm + ) + pipeline.run() + + +@common.parametrize("test_data", AtMatMulDoubleInput.test_data_generators) +def test_atmatmul_double_input_tosa_BI(test_data: input_t1): + pipeline = TosaPipelineBI[input_t1]( + AtMatMulDoubleInput(), test_data(), aten_op_mm, exir_op_mm + ) + pipeline.run() + + +@common.parametrize("test_data", AtMatMulMixedPattern1.test_data_generators) +def test_atmatmul_mixed_pattern1_tosa_BI(test_data: input_t1): + pipeline = TosaPipelineBI[input_t1]( + AtMatMulMixedPattern1(), + test_data(), + aten_op_mm, + exir_op_mm, + qtol=1, + ) + pipeline.run() + + +@common.parametrize("test_data", AtMatMulMixedPattern2.test_data_generators) +def test_atmatmul_mixed_pattern2_tosa_BI(test_data: input_t1): + pipeline = TosaPipelineBI[input_t1]( + AtMatMulMixedPattern2(), + test_data(), + aten_op_mm, + exir_op_mm, + qtol=1, + ) + pipeline.run()