Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backends/arm/test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def is_option_enabled(option: str, fail_if_not_enabled: bool = False) -> bool:
a RuntimeError instead of returning False.
"""

if option in pytest._test_options and pytest._test_options[option]: # type: ignore[attr-defined]
if hasattr(pytest, "_test_options") and option in pytest._test_options and pytest._test_options[option]: # type: ignore[attr-defined]
return True
else:
if fail_if_not_enabled:
Expand Down
4 changes: 4 additions & 0 deletions backends/cortex_m/test/ops/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
175 changes: 175 additions & 0 deletions backends/cortex_m/test/ops/test_add.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import torch
from executorch.backends.arm.test.common import parametrize
from executorch.backends.cortex_m.test.tester import CortexMTester, McuTestCase
from executorch.backends.test.suite.operators.test_add import Model, ModelAlpha


class CortexMSelfAdd(torch.nn.Module):
ops_before_transforms = {
"executorch_exir_dialects_edge__ops_aten_add_Tensor": 1,
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2,
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2,
}

ops_after_transforms = {
"executorch_exir_dialects_edge__ops_cortex_m_quantized_add_default": 1,
"executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1,
"executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1,
}

def forward(self, x):
return x + x


class CortexMScalarAdd(Model):
ops_before_transforms = {
"executorch_exir_dialects_edge__ops_aten_add_Tensor": 1,
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2,
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3,
}

ops_after_transforms = {
"executorch_exir_dialects_edge__ops_cortex_m_quantized_add_default": 1,
"executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1,
"executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1,
}


class CortexMTensorAdd(Model):
ops_before_transforms = {
"executorch_exir_dialects_edge__ops_aten_add_Tensor": 1,
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 3,
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3,
}

ops_after_transforms = {
"executorch_exir_dialects_edge__ops_cortex_m_quantized_add_default": 1,
"executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 2,
"executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1,
}


class CortexMAlphaAdd(ModelAlpha):
ops_before_transforms = {
"executorch_exir_dialects_edge__ops_aten_add_Tensor": 1,
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 3,
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3,
}

ops_after_transforms = {
"executorch_exir_dialects_edge__ops_cortex_m_quantized_add_default": 1,
"executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 2,
"executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1,
}


test_cases = {
"self_scalar": McuTestCase(
CortexMSelfAdd(),
(10.0,),
),
"self_rank_1": McuTestCase(
CortexMSelfAdd(),
(torch.linspace(-5, 5, 10),),
),
"self_rank_2_pos": McuTestCase(
CortexMSelfAdd(),
(torch.linspace(0, 1000, 10).reshape((10, 1)),),
),
"self_rank_3_neg": McuTestCase(
CortexMSelfAdd(),
(torch.linspace(-100, 0, 8).reshape((2, 2, 2)),),
),
"self_rank_4_small": McuTestCase(
CortexMSelfAdd(),
(torch.linspace(-0.1, 0.1, 16).reshape(2, 2, 2, 2),),
),
"self_rank_5": McuTestCase(
CortexMSelfAdd(),
(torch.linspace(-5, 5, 32).reshape(2, 2, 2, 2, 2),),
),
"scalar_scalar": McuTestCase(
CortexMScalarAdd(),
(-0.5, 1.0),
),
"tensor_scalar": McuTestCase(
CortexMScalarAdd(),
(torch.ones(2, 2), 1.0),
),
"scalar_tensor": McuTestCase(
CortexMScalarAdd(),
(1000.0, torch.ones(2, 2)),
),
"broadcast_1": McuTestCase(
CortexMTensorAdd(),
(torch.ones(1), torch.ones(2, 2, 2, 2)),
),
"broadcast_2": McuTestCase(
CortexMTensorAdd(),
(torch.ones((2, 1, 1, 1)), torch.ones(1)),
),
"broadcast_3": McuTestCase(
CortexMTensorAdd(),
(
torch.linspace(-2, 2, 4).reshape(2, 1, 2, 1),
torch.linspace(-5, 5, 4).reshape(1, 2, 1, 2),
),
),
"alpha": McuTestCase(
CortexMAlphaAdd(0.5),
(
torch.linspace(-10, 10, 20).reshape(4, 5),
torch.linspace(-20, 20, 20).reshape(4, 5),
),
),
}


dialect_xfails = {
"self_scalar": ("'float' object has no attribute 'fake_mode'", AttributeError),
"self_rank_1": ("Output 0 does not match reference output", AssertionError),
"self_rank_2_pos": ("Output 0 does not match reference output", AssertionError),
"self_rank_3_neg": ("Output 0 does not match reference output", AssertionError),
"self_rank_4_small": ("Output 0 does not match reference output", AssertionError),
"self_rank_5": ("Output 0 does not match reference output", AssertionError),
"scalar_scalar": ("'float' object has no attribute 'fake_mode'", AttributeError),
"broadcast_3": ("Output 0 does not match reference output", AssertionError),
"alpha": ("Expecting kwargs for aten op IR to be empty", AssertionError),
}


@parametrize("test_case", test_cases, xfails=dialect_xfails)
def test_dialect_add(test_case):
tester = CortexMTester(test_case.model, test_case.example_inputs)
tester.test_dialect(
test_case.model.ops_before_transforms, test_case.model.ops_after_transforms
)


implementation_xfails = {
"self_scalar": ("'float' object has no attribute 'fake_mode'", AttributeError),
"self_rank_1": ("Output 0 does not match reference output", AssertionError),
"self_rank_2_pos": ("Output 0 does not match reference output", AssertionError),
"self_rank_3_neg": ("Output 0 does not match reference output", AssertionError),
"self_rank_4_small": ("Output 0 does not match reference output", AssertionError),
"self_rank_5": ("Output 0 does not match reference output", AssertionError),
"scalar_scalar": ("'float' object has no attribute 'fake_mode'", AttributeError),
"tensor_scalar": ("Output 0 does not match reference output", AssertionError),
"scalar_tensor": ("Output 0 does not match reference output", AssertionError),
"broadcast_1": ("Output 0 does not match reference output", AssertionError),
"broadcast_2": ("Output 0 does not match reference output", AssertionError),
"broadcast_3": ("Output 0 does not match reference output", AssertionError),
"alpha": ("Expecting kwargs for aten op IR to be empty", AssertionError),
}


@parametrize("test_case", test_cases, xfails=implementation_xfails)
def test_implementation_add(test_case):
tester = CortexMTester(test_case.model, test_case.example_inputs)
tester.test_implementation()
100 changes: 100 additions & 0 deletions backends/cortex_m/test/tester.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from dataclasses import dataclass
from typing import Any

import torch

from backends.xnnpack.quantizer.xnnpack_quantizer import (
get_symmetric_quantization_config,
XNNPACKQuantizer,
)
from executorch.backends.arm.test.common import get_u55_compile_spec
from executorch.backends.arm.test.tester.arm_tester import Serialize
from executorch.backends.cortex_m.passes.quantized_op_fusion_pass import (
QuantizedOpFusionPass,
)

from executorch.backends.cortex_m.passes.replace_quant_nodes_pass import (
ReplaceQuantNodesPass,
)
from executorch.backends.test.harness import Tester as TesterBase
from executorch.backends.test.harness.stages import (
Export,
Quantize,
RunPasses,
StageType,
ToEdgeTransformAndLower,
ToExecutorch,
)
from executorch.backends.xnnpack._passes import XNNPACKPassManager


class CortexMQuantize(Quantize):
def __init__(self):
quantizer = XNNPACKQuantizer()
config = get_symmetric_quantization_config()
super().__init__(quantizer, config)


class CortexMRunPasses(RunPasses):
def __init__(self):
super().__init__(
XNNPACKPassManager, pass_list=[QuantizedOpFusionPass, ReplaceQuantNodesPass]
)


class CortexMSerialize(Serialize):
def __init__(self):
compile_spec = get_u55_compile_spec()
super().__init__(compile_spec, 1024)


cortex_m_stage_classes = {
StageType.EXPORT: Export,
StageType.QUANTIZE: CortexMQuantize,
StageType.RUN_PASSES: CortexMRunPasses,
StageType.SERIALIZE: Serialize,
StageType.TO_EDGE_TRANSFORM_AND_LOWER: ToEdgeTransformAndLower,
StageType.TO_EXECUTORCH: ToExecutorch,
StageType.SERIALIZE: CortexMSerialize,
}


class CortexMTester(TesterBase):
def __init__(self, module, example_inputs):
super().__init__(module, example_inputs, cortex_m_stage_classes)

def test_dialect(self, ops_before_transforms, ops_after_transforms, qtol=0):
"""
Test the python dialect op implementation.
"""
self.quantize()
self.export()
self.to_edge_transform_and_lower()
self.check_count(ops_before_transforms)
self.run_passes()
self.check_count(ops_after_transforms)
self.run_method_and_compare_outputs(inputs=self.example_inputs, qtol=qtol)

def test_implementation(self, qtol=0):
"""
Test the optimized op implementation in simulation
"""
self.quantize()
self.export()
self.to_edge_transform_and_lower()
self.run_passes()
self.to_executorch()
self.serialize()
self.run_method_and_compare_outputs(inputs=self.example_inputs, qtol=qtol)


@dataclass
class McuTestCase:
model: torch.nn.Module
example_inputs: tuple[Any]
Loading