Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions backends/arm/_passes/annotate_decomposed_matmul.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import itertools

import operator
from typing import List

import torch
Expand All @@ -22,7 +21,7 @@

class AnnotateDecomposedMatmulPass(ExportPass):
"""
torch.matmul can be decomposed in many ways, for instance:
torch.matmul and its equivalent operator @ can be decomposed in many ways, for instance:
dq -> matmul -> q can become
dq -> repeat -> view -> bmm -> view -> dq which makes quantization folding
difficult. This helper function finds all matmul partitions and annotates its
Expand Down Expand Up @@ -50,6 +49,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
graph_module.graph,
[
torch.matmul,
operator.matmul,
],
None,
)
Expand Down
3 changes: 2 additions & 1 deletion backends/arm/operator_support/tosa_supported_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,7 @@ def _is_matmul_node_supported(
graph_module.graph,
[
torch.matmul,
operator.matmul,
],
None,
)
Expand Down Expand Up @@ -385,7 +386,7 @@ def is_node_supported(
):
source_fn_stack: tuple[typing.Any] = node.meta.get("source_fn_stack", [])
if len(source_fn_stack) > 0:
if source_fn_stack[-1][1] in (torch.matmul,):
if source_fn_stack[-1][1] in (torch.matmul, operator.matmul):
return self._is_matmul_node_supported(submodules, node)

elif node.target in (exir_ops.edge.aten.max_pool2d_with_indices.default,):
Expand Down
149 changes: 149 additions & 0 deletions backends/arm/test/ops/test_at.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, Tuple

import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import (
    TosaPipelineBI,
    TosaPipelineMI,
)

# Operator-name strings passed to every TOSA pipeline in this file.
aten_op_mm = "torch.ops.aten.matmul.default"
exir_op_mm = "executorch_exir_dialects_edge__ops_aten_matmul_default"
# The modules in this file take one, two or three tensors, so the input alias
# is a variable-length tuple of tensors rather than a fixed pair.
input_t1 = Tuple[torch.Tensor, ...]


class AtMatMulSingleInput(torch.nn.Module):
    """Module that multiplies its single input with itself using ``@``."""

    # Named zero-argument factories. Each call builds a fresh one-element
    # tuple holding a random tensor whose last two dimensions are 5x5.
    test_data_generators = {
        "rand_3d": lambda: (torch.rand(2, 5, 5),),
        "rand_4d": lambda: (torch.rand(1, 2, 5, 5),),
    }

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x @ x``."""
        return x @ x


class AtMatMulDoubleInput(torch.nn.Module):
    """Module that multiplies two distinct inputs using ``@``."""

    # Named zero-argument factories. Each call builds a fresh pair of random
    # tensors whose trailing dimensions are (3, 5) and (5, 2).
    test_data_generators = {
        "rand_rand_3d": lambda: (torch.rand(2, 3, 5), torch.rand(2, 5, 2)),
        "rand_rand_4d": lambda: (torch.rand(1, 2, 3, 5), torch.rand(1, 2, 5, 2)),
    }

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return ``x @ y``."""
        return x @ y


class AtMatMulMixedPattern1(torch.nn.Module):
    """Module that adds the results of two ``torch.matmul`` calls."""

    # Named zero-argument factories. Each call builds a fresh triple of random
    # tensors with trailing dimensions (5, 5), (5, 2) and (2, 5).
    test_data_generators = {
        "rand_rand_rand_3d": lambda: (
            torch.rand(2, 5, 5),
            torch.rand(2, 5, 2),
            torch.rand(2, 2, 5),
        ),
        "rand_rand_rand_4d": lambda: (
            torch.rand(1, 2, 5, 5),
            torch.rand(1, 2, 5, 2),
            torch.rand(1, 2, 2, 5),
        ),
    }

    def forward(
        self, x1: torch.Tensor, x2: torch.Tensor, x3: torch.Tensor
    ) -> torch.Tensor:
        """Return ``matmul(x1, x1) + matmul(x2, x3)``."""
        # NOTE(review): this pattern never uses the `@` operator, although the
        # other modules in this file do — confirm that is intentional.
        y1 = torch.matmul(x1, x1)
        y2 = torch.matmul(x2, x3)
        return y1 + y2


class AtMatMulMixedPattern2(torch.nn.Module):
    """Module that combines two ``torch.matmul`` results with the ``@`` operator."""

    # Named zero-argument factories. Each call builds a fresh triple of random
    # tensors with trailing dimensions (5, 5), (5, 2) and (2, 5).
    test_data_generators = {
        "rand_rand_rand_3d": lambda: (
            torch.rand(2, 5, 5),
            torch.rand(2, 5, 2),
            torch.rand(2, 2, 5),
        ),
        "rand_rand_rand_4d": lambda: (
            torch.rand(1, 2, 5, 5),
            torch.rand(1, 2, 5, 2),
            torch.rand(1, 2, 2, 5),
        ),
    }

    def forward(
        self, x1: torch.Tensor, x2: torch.Tensor, x3: torch.Tensor
    ) -> torch.Tensor:
        """Return ``matmul(x1, x1) @ matmul(x2, x3)``."""
        y1 = torch.matmul(x1, x1)
        y2 = torch.matmul(x2, x3)
        return y1 @ y2


@common.parametrize("test_data", AtMatMulSingleInput.test_data_generators)
def test_atmatmul_single_input_tosa_MI(test_data: Callable[[], input_t1]) -> None:
    """Run ``AtMatMulSingleInput`` through ``TosaPipelineMI``.

    Args:
        test_data: Zero-argument factory; it is called once here to obtain
            the tuple of input tensors handed to the pipeline.
    """
    pipeline = TosaPipelineMI[input_t1](
        AtMatMulSingleInput(), test_data(), aten_op_mm, exir_op_mm
    )
    pipeline.run()


@common.parametrize("test_data", AtMatMulDoubleInput.test_data_generators)
def test_atmatmul_double_input_tosa_MI(test_data: Callable[[], input_t1]) -> None:
    """Run ``AtMatMulDoubleInput`` through ``TosaPipelineMI``.

    Args:
        test_data: Zero-argument factory; it is called once here to obtain
            the tuple of input tensors handed to the pipeline.
    """
    pipeline = TosaPipelineMI[input_t1](
        AtMatMulDoubleInput(), test_data(), aten_op_mm, exir_op_mm
    )
    pipeline.run()


@common.parametrize("test_data", AtMatMulMixedPattern1.test_data_generators)
def test_atmatmul_mixed_pattern1_tosa_MI(test_data: Callable[[], input_t1]) -> None:
    """Run ``AtMatMulMixedPattern1`` through ``TosaPipelineMI``.

    Args:
        test_data: Zero-argument factory; it is called once here to obtain
            the tuple of input tensors handed to the pipeline.
    """
    pipeline = TosaPipelineMI[input_t1](
        AtMatMulMixedPattern1(), test_data(), aten_op_mm, exir_op_mm
    )
    pipeline.run()


@common.parametrize("test_data", AtMatMulMixedPattern2.test_data_generators)
def test_atmatmul_mixed_pattern2_tosa_MI(test_data: Callable[[], input_t1]) -> None:
    """Run ``AtMatMulMixedPattern2`` through ``TosaPipelineMI``.

    Args:
        test_data: Zero-argument factory; it is called once here to obtain
            the tuple of input tensors handed to the pipeline.
    """
    pipeline = TosaPipelineMI[input_t1](
        AtMatMulMixedPattern2(), test_data(), aten_op_mm, exir_op_mm
    )
    pipeline.run()


@common.parametrize("test_data", AtMatMulSingleInput.test_data_generators)
def test_atmatmul_single_input_tosa_BI(test_data: Callable[[], input_t1]) -> None:
    """Run ``AtMatMulSingleInput`` through ``TosaPipelineBI``.

    Args:
        test_data: Zero-argument factory; it is called once here to obtain
            the tuple of input tensors handed to the pipeline.
    """
    pipeline = TosaPipelineBI[input_t1](
        AtMatMulSingleInput(), test_data(), aten_op_mm, exir_op_mm
    )
    pipeline.run()


@common.parametrize("test_data", AtMatMulDoubleInput.test_data_generators)
def test_atmatmul_double_input_tosa_BI(test_data: Callable[[], input_t1]) -> None:
    """Run ``AtMatMulDoubleInput`` through ``TosaPipelineBI``.

    Args:
        test_data: Zero-argument factory; it is called once here to obtain
            the tuple of input tensors handed to the pipeline.
    """
    pipeline = TosaPipelineBI[input_t1](
        AtMatMulDoubleInput(), test_data(), aten_op_mm, exir_op_mm
    )
    pipeline.run()


@common.parametrize("test_data", AtMatMulMixedPattern1.test_data_generators)
def test_atmatmul_mixed_pattern1_tosa_BI(test_data: Callable[[], input_t1]) -> None:
    """Run ``AtMatMulMixedPattern1`` through ``TosaPipelineBI`` with ``qtol=1``.

    Args:
        test_data: Zero-argument factory; it is called once here to obtain
            the tuple of input tensors handed to the pipeline.
    """
    pipeline = TosaPipelineBI[input_t1](
        AtMatMulMixedPattern1(),
        test_data(),
        aten_op_mm,
        exir_op_mm,
        qtol=1,
    )
    pipeline.run()


@common.parametrize("test_data", AtMatMulMixedPattern2.test_data_generators)
def test_atmatmul_mixed_pattern2_tosa_BI(test_data: Callable[[], input_t1]) -> None:
    """Run ``AtMatMulMixedPattern2`` through ``TosaPipelineBI`` with ``qtol=1``.

    Args:
        test_data: Zero-argument factory; it is called once here to obtain
            the tuple of input tensors handed to the pipeline.
    """
    pipeline = TosaPipelineBI[input_t1](
        AtMatMulMixedPattern2(),
        test_data(),
        aten_op_mm,
        exir_op_mm,
        qtol=1,
    )
    pipeline.run()
Loading