From f7deecc321e6c921a076f056c4897b4813e53f22 Mon Sep 17 00:00:00 2001
From: Adrian Lundell
Date: Fri, 5 Sep 2025 14:43:50 +0200
Subject: [PATCH 1/3] Cortex_m backend: Add cortex_m tester + test_add

Note that the tests are currently failing: comparing, for example, the call to
arm_elementwise_add_s8 in op_quantized_add.cpp
https://github.com/pytorch/executorch/blob/ab3100715afd21e5f0cee48675d9187152775d86/backends/cortex_m/ops/op_quantized_add.cpp#L88
with its definition in CMSIS-NN
https://github.com/ARM-software/CMSIS-NN/blob/88f1982a69c00ed13dd633a63da1009c48abbb4d/Include/arm_nnfunctions.h#L1923
the arguments appear to be passed in the wrong order. This will be fixed in a
future patch (the CMSIS-NN prototype is reproduced after this series for
reference).

Minor fixes needed to get this to work:
- Add an __init__.py file so that test names are unique
- Update conftest.py so that is_option_enabled does not crash for tests run
  from an external folder

Signed-off-by: Adrian Lundell
Change-Id: I7962fea42994d51f871c8789b0d58b98d60a2739
---
 backends/arm/test/conftest.py             |   2 +-
 backends/cortex_m/test/cortex_m_tester.py | 102 +++++++++++++
 backends/cortex_m/test/ops/__init__.py    |   4 +
 backends/cortex_m/test/ops/test_add.py    | 175 ++++++++++++++++++++++
 4 files changed, 282 insertions(+), 1 deletion(-)
 create mode 100644 backends/cortex_m/test/cortex_m_tester.py
 create mode 100644 backends/cortex_m/test/ops/__init__.py
 create mode 100644 backends/cortex_m/test/ops/test_add.py

diff --git a/backends/arm/test/conftest.py b/backends/arm/test/conftest.py
index 6fc9e7e5adc..0060bf0ea63 100644
--- a/backends/arm/test/conftest.py
+++ b/backends/arm/test/conftest.py
@@ -118,7 +118,7 @@ def is_option_enabled(option: str, fail_if_not_enabled: bool = False) -> bool:
         a RuntimeError instead of returning False.
     """
 
-    if option in pytest._test_options and pytest._test_options[option]:  # type: ignore[attr-defined]
+    if hasattr(pytest, "_test_options") and option in pytest._test_options and pytest._test_options[option]:  # type: ignore[attr-defined]
         return True
     else:
         if fail_if_not_enabled:
diff --git a/backends/cortex_m/test/cortex_m_tester.py b/backends/cortex_m/test/cortex_m_tester.py
new file mode 100644
index 00000000000..2f434d83627
--- /dev/null
+++ b/backends/cortex_m/test/cortex_m_tester.py
@@ -0,0 +1,102 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from dataclasses import dataclass
+from typing import Any
+
+import torch
+
+from executorch.backends.arm.quantizer.arm_quantizer import (
+    get_symmetric_quantization_config,
+    TOSAQuantizer,
+)
+from executorch.backends.arm.test.common import get_u55_compile_spec
+from executorch.backends.arm.test.tester.arm_tester import Serialize
+from executorch.backends.cortex_m.passes.quantized_op_fusion_pass import (
+    QuantizedOpFusionPass,
+)
+
+from executorch.backends.cortex_m.passes.replace_quant_nodes_pass import (
+    ReplaceQuantNodesPass,
+)
+from executorch.backends.test.harness import Tester as TesterBase
+from executorch.backends.test.harness.stages import (
+    Export,
+    Quantize,
+    RunPasses,
+    StageType,
+    ToEdgeTransformAndLower,
+    ToExecutorch,
+)
+from executorch.backends.xnnpack._passes import XNNPACKPassManager
+
+
+class CortexMQuantize(Quantize):
+    def __init__(self):
+        compile_spec = get_u55_compile_spec()
+        quantizer = TOSAQuantizer(compile_spec)
+        config = get_symmetric_quantization_config()
+
+        super().__init__(quantizer, config)
+
+
+class CortexMRunPasses(RunPasses):
+    def __init__(self):
+        super().__init__(
+            XNNPACKPassManager, pass_list=[QuantizedOpFusionPass, ReplaceQuantNodesPass]
+        )
+
+
+class CortexMSerialize(Serialize):
+    def __init__(self):
+        compile_spec = get_u55_compile_spec()
+        super().__init__(compile_spec, 1024)
+
+
+cortex_m_stage_classes = {
+    StageType.EXPORT: Export,
+    StageType.QUANTIZE: CortexMQuantize,
+    StageType.RUN_PASSES: CortexMRunPasses,
+    StageType.SERIALIZE: Serialize,
+    StageType.TO_EDGE_TRANSFORM_AND_LOWER: ToEdgeTransformAndLower,
+    StageType.TO_EXECUTORCH: ToExecutorch,
+    StageType.SERIALIZE: CortexMSerialize,
+}
+
+
+class CortexMTester(TesterBase):
+    def __init__(self, module, example_inputs):
+        super().__init__(module, example_inputs, cortex_m_stage_classes)
+
+    def test_dialect(self, ops_before_transforms, ops_after_transforms, qtol=0):
+        """
+        Test the python dialect op implementation.
+        """
+        self.quantize()
+        self.export()
+        self.to_edge_transform_and_lower()
+        self.check_count(ops_before_transforms)
+        self.run_passes()
+        self.check_count(ops_after_transforms)
+        self.run_method_and_compare_outputs(inputs=self.example_inputs, qtol=qtol)
+
+    def test_implementation(self, qtol=0):
+        """
+        Test the optimized op implementation in simulation
+        """
+        self.quantize()
+        self.export()
+        self.to_edge_transform_and_lower()
+        self.run_passes()
+        self.to_executorch()
+        self.serialize()
+        self.run_method_and_compare_outputs(inputs=self.example_inputs, qtol=qtol)
+
+
+@dataclass
+class McuTestCase:
+    model: torch.nn.Module
+    example_inputs: tuple[Any]
diff --git a/backends/cortex_m/test/ops/__init__.py b/backends/cortex_m/test/ops/__init__.py
new file mode 100644
index 00000000000..c8d1c683da3
--- /dev/null
+++ b/backends/cortex_m/test/ops/__init__.py
@@ -0,0 +1,4 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/backends/cortex_m/test/ops/test_add.py b/backends/cortex_m/test/ops/test_add.py
new file mode 100644
index 00000000000..0db520481ad
--- /dev/null
+++ b/backends/cortex_m/test/ops/test_add.py
@@ -0,0 +1,175 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch
+from executorch.backends.arm.test.common import parametrize
+from executorch.backends.cortex_m.test.cortex_m_tester import CortexMTester, McuTestCase
+from executorch.backends.test.suite.operators.test_add import Model, ModelAlpha
+
+
+class SelfAdd(torch.nn.Module):
+    ops_before_transforms = {
+        "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1,
+        "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2,
+        "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2,
+    }
+
+    ops_after_transforms = {
+        "executorch_exir_dialects_edge__ops_cortex_m_quantized_add_default": 1,
+        "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1,
+        "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1,
+    }
+
+    def forward(self, x):
+        return x + x
+
+
+class ScalarAdd(Model):
+    ops_before_transforms = {
+        "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1,
+        "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2,
+        "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3,
+    }
+
+    ops_after_transforms = {
+        "executorch_exir_dialects_edge__ops_cortex_m_quantized_add_default": 1,
+        "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1,
+        "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1,
+    }
+
+
+class TensorAdd(Model):
+    ops_before_transforms = {
+        "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1,
+        "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 3,
+        "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3,
+    }
+
+    ops_after_transforms = {
+        "executorch_exir_dialects_edge__ops_cortex_m_quantized_add_default": 1,
+        "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 2,
+        "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1,
+    }
+
+
+class AlphaAdd(ModelAlpha):
+    ops_before_transforms = {
+        "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1,
+        "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 3,
+        "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3,
+    }
+
+    ops_after_transforms = {
+        "executorch_exir_dialects_edge__ops_cortex_m_quantized_add_default": 1,
+        "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 2,
+        "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1,
+    }
+
+
+test_cases = {
+    "self_scalar": McuTestCase(
+        SelfAdd(),
+        (10.0,),
+    ),
+    "self_rank_1": McuTestCase(
+        SelfAdd(),
+        (torch.linspace(-5, 5, 10),),
+    ),
+    "self_rank_2_pos": McuTestCase(
+        SelfAdd(),
+        (torch.linspace(0, 1000, 10).reshape((10, 1)),),
+    ),
+    "self_rank_3_neg": McuTestCase(
+        SelfAdd(),
+        (torch.linspace(-100, 0, 8).reshape((2, 2, 2)),),
+    ),
+    "self_rank_4_small": McuTestCase(
+        SelfAdd(),
+        (torch.linspace(-0.1, 0.1, 16).reshape(2, 2, 2, 2),),
+    ),
+    "self_rank_5": McuTestCase(
+        SelfAdd(),
+        (torch.linspace(-5, 5, 32).reshape(2, 2, 2, 2, 2),),
+    ),
+    "scalar_scalar": McuTestCase(
+        ScalarAdd(),
+        (-0.5, 1.0),
+    ),
+    "tensor_scalar": McuTestCase(
+        ScalarAdd(),
+        (torch.ones(2, 2), 1.0),
+    ),
+    "scalar_tensor": McuTestCase(
+        ScalarAdd(),
+        (1000.0, torch.ones(2, 2)),
+    ),
+    "broadcast_1": McuTestCase(
+        TensorAdd(),
+        (torch.ones(1), torch.ones(2, 2, 2, 2)),
+    ),
+    "broadcast_2": McuTestCase(
+        TensorAdd(),
+        (torch.ones((2, 1, 1, 1)), torch.ones(1)),
+    ),
+    "broadcast_3": McuTestCase(
+        TensorAdd(),
+        (
+            torch.linspace(-2, 2, 4).reshape(2, 1, 2, 1),
+            torch.linspace(-5, 5, 4).reshape(1, 2, 1, 2),
+        ),
+    ),
+    "alpha": McuTestCase(
+        AlphaAdd(0.5),
+        (
+            torch.linspace(-10, 10, 20).reshape(4, 5),
+            torch.linspace(-20, 20, 20).reshape(4, 5),
+        ),
+    ),
+}
+
+
+dialect_xfails = {
+    "self_scalar": ("'float' object has no attribute 'fake_mode'", AttributeError),
+    "self_rank_1": ("Output 0 does not match reference output", AssertionError),
+    "self_rank_2_pos": ("Output 0 does not match reference output", AssertionError),
+    "self_rank_3_neg": ("Output 0 does not match reference output", AssertionError),
+    "self_rank_4_small": ("Output 0 does not match reference output", AssertionError),
+    "self_rank_5": ("Output 0 does not match reference output", AssertionError),
+    "scalar_scalar": ("'float' object has no attribute 'fake_mode'", AttributeError),
+    "broadcast_3": ("Output 0 does not match reference output", AssertionError),
+    "alpha": ("Expecting kwargs for aten op IR to be empty", AssertionError),
+}
+
+
+@parametrize("test_case", test_cases, xfails=dialect_xfails)
+def test_dialect_add(test_case):
+    tester = CortexMTester(test_case.model, test_case.example_inputs)
+    tester.test_dialect(
+        test_case.model.ops_before_transforms, test_case.model.ops_after_transforms
+    )
+
+
+implementation_xfails = {
+    "self_scalar": ("'float' object has no attribute 'fake_mode'", AttributeError),
+    "self_rank_1": ("Output 0 does not match reference output", AssertionError),
+    "self_rank_2_pos": ("Output 0 does not match reference output", AssertionError),
+    "self_rank_3_neg": ("Output 0 does not match reference output", AssertionError),
+    "self_rank_4_small": ("Output 0 does not match reference output", AssertionError),
+    "self_rank_5": ("Output 0 does not match reference output", AssertionError),
+    "scalar_scalar": ("'float' object has no attribute 'fake_mode'", AttributeError),
+    "tensor_scalar": ("Output 0 does not match reference output", AssertionError),
+    "scalar_tensor": ("Output 0 does not match reference output", AssertionError),
+    "broadcast_1": ("Output 0 does not match reference output", AssertionError),
+    "broadcast_2": ("Output 0 does not match reference output", AssertionError),
+    "broadcast_3": ("Output 0 does not match reference output", AssertionError),
+    "alpha": ("Expecting kwargs for aten op IR to be empty", AssertionError),
+}
+
+
+@parametrize("test_case", test_cases, xfails=implementation_xfails)
+def test_implementation_add(test_case):
+    tester = CortexMTester(test_case.model, test_case.example_inputs)
+    tester.test_implementation()

From 69079221e94663a4ea982fc467ec448351360288 Mon Sep 17 00:00:00 2001
From: Adrian Lundell
Date: Wed, 24 Sep 2025 09:34:25 +0200
Subject: [PATCH 2/3] Fix upstream review comments

Signed-off-by: Adrian Lundell
Change-Id: Ib98d19bf53f26c1dec103b9e875080b495c6fbaf
---
 backends/cortex_m/test/ops/test_add.py      | 36 +++++++++----------
 .../test/{cortex_m_tester.py => tester.py}  |  9 ++---
 2 files changed, 20 insertions(+), 25 deletions(-)
 rename backends/cortex_m/test/{cortex_m_tester.py => tester.py} (92%)

diff --git a/backends/cortex_m/test/ops/test_add.py b/backends/cortex_m/test/ops/test_add.py
index 0db520481ad..10edacb5a11 100644
--- a/backends/cortex_m/test/ops/test_add.py
+++ b/backends/cortex_m/test/ops/test_add.py
@@ -6,11 +6,11 @@
 
 import torch
 from executorch.backends.arm.test.common import parametrize
-from executorch.backends.cortex_m.test.cortex_m_tester import CortexMTester, McuTestCase
+from executorch.backends.cortex_m.test.tester import CortexMTester, McuTestCase
 from executorch.backends.test.suite.operators.test_add import Model, ModelAlpha
 
 
-class SelfAdd(torch.nn.Module):
+class CortexMSelfAdd(torch.nn.Module):
     ops_before_transforms = {
         "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1,
         "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2,
@@ -27,7 +27,7 @@ def forward(self, x):
         return x + x
 
 
-class ScalarAdd(Model):
+class CortexMScalarAdd(Model):
     ops_before_transforms = {
         "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1,
         "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2,
@@ -41,7 +41,7 @@ class ScalarAdd(Model):
     }
 
 
-class TensorAdd(Model):
+class CortexMTensorAdd(Model):
     ops_before_transforms = {
         "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1,
         "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 3,
@@ -55,7 +55,7 @@ class TensorAdd(Model):
     }
 
 
-class AlphaAdd(ModelAlpha):
+class CortexMAlphaAdd(ModelAlpha):
     ops_before_transforms = {
         "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1,
         "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 3,
@@ -71,58 +71,58 @@ class AlphaAdd(ModelAlpha):
 
 test_cases = {
     "self_scalar": McuTestCase(
-        SelfAdd(),
+        CortexMSelfAdd(),
         (10.0,),
     ),
     "self_rank_1": McuTestCase(
-        SelfAdd(),
+        CortexMSelfAdd(),
         (torch.linspace(-5, 5, 10),),
     ),
     "self_rank_2_pos": McuTestCase(
-        SelfAdd(),
+        CortexMSelfAdd(),
         (torch.linspace(0, 1000, 10).reshape((10, 1)),),
     ),
     "self_rank_3_neg": McuTestCase(
-        SelfAdd(),
+        CortexMSelfAdd(),
         (torch.linspace(-100, 0, 8).reshape((2, 2, 2)),),
     ),
     "self_rank_4_small": McuTestCase(
-        SelfAdd(),
+        CortexMSelfAdd(),
         (torch.linspace(-0.1, 0.1, 16).reshape(2, 2, 2, 2),),
     ),
     "self_rank_5": McuTestCase(
-        SelfAdd(),
+        CortexMSelfAdd(),
         (torch.linspace(-5, 5, 32).reshape(2, 2, 2, 2, 2),),
     ),
     "scalar_scalar": McuTestCase(
-        ScalarAdd(),
+        CortexMScalarAdd(),
         (-0.5, 1.0),
     ),
     "tensor_scalar": McuTestCase(
-        ScalarAdd(),
+        CortexMScalarAdd(),
         (torch.ones(2, 2), 1.0),
     ),
     "scalar_tensor": McuTestCase(
-        ScalarAdd(),
+        CortexMScalarAdd(),
         (1000.0, torch.ones(2, 2)),
     ),
     "broadcast_1": McuTestCase(
-        TensorAdd(),
+        CortexMTensorAdd(),
         (torch.ones(1), torch.ones(2, 2, 2, 2)),
     ),
     "broadcast_2": McuTestCase(
-        TensorAdd(),
+        CortexMTensorAdd(),
         (torch.ones((2, 1, 1, 1)), torch.ones(1)),
     ),
     "broadcast_3": McuTestCase(
-        TensorAdd(),
+        CortexMTensorAdd(),
         (
             torch.linspace(-2, 2, 4).reshape(2, 1, 2, 1),
             torch.linspace(-5, 5, 4).reshape(1, 2, 1, 2),
         ),
     ),
     "alpha": McuTestCase(
-        AlphaAdd(0.5),
+        CortexMAlphaAdd(0.5),
         (
             torch.linspace(-10, 10, 20).reshape(4, 5),
             torch.linspace(-20, 20, 20).reshape(4, 5),
diff --git a/backends/cortex_m/test/cortex_m_tester.py b/backends/cortex_m/test/tester.py
similarity index 92%
rename from backends/cortex_m/test/cortex_m_tester.py
rename to backends/cortex_m/test/tester.py
index 2f434d83627..5921f67ef4a 100644
--- a/backends/cortex_m/test/cortex_m_tester.py
+++ b/backends/cortex_m/test/tester.py
@@ -9,10 +9,7 @@
 
 import torch
 
-from executorch.backends.arm.quantizer.arm_quantizer import (
-    get_symmetric_quantization_config,
-    TOSAQuantizer,
-)
+from backends.xnnpack.quantizer.xnnpack_quantizer import XNNPACKQuantizer, get_symmetric_quantization_config
 from executorch.backends.arm.test.common import get_u55_compile_spec
 from executorch.backends.arm.test.tester.arm_tester import Serialize
 from executorch.backends.cortex_m.passes.quantized_op_fusion_pass import (
@@ -36,10 +33,8 @@
 
 class CortexMQuantize(Quantize):
     def __init__(self):
-        compile_spec = get_u55_compile_spec()
-        quantizer = TOSAQuantizer(compile_spec)
+        quantizer = XNNPACKQuantizer()
         config = get_symmetric_quantization_config()
-
         super().__init__(quantizer, config)
 
 

From 364ef50c1f4b64996a2daee9f19ab29addf540a9 Mon Sep 17 00:00:00 2001
From: Adrian Lundell
Date: Wed, 24 Sep 2025 11:10:49 +0200
Subject: [PATCH 3/3] Fix lint issue

Signed-off-by: Adrian Lundell
---
 backends/cortex_m/test/tester.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/backends/cortex_m/test/tester.py b/backends/cortex_m/test/tester.py
index 5921f67ef4a..8af31e58cd7 100644
--- a/backends/cortex_m/test/tester.py
+++ b/backends/cortex_m/test/tester.py
@@ -9,7 +9,10 @@
 
 import torch
 
-from backends.xnnpack.quantizer.xnnpack_quantizer import XNNPACKQuantizer, get_symmetric_quantization_config
+from backends.xnnpack.quantizer.xnnpack_quantizer import (
+    get_symmetric_quantization_config,
+    XNNPACKQuantizer,
+)
 from executorch.backends.arm.test.common import get_u55_compile_spec
 from executorch.backends.arm.test.tester.arm_tester import Serialize
 from executorch.backends.cortex_m.passes.quantized_op_fusion_pass import (
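
For reference, below is the arm_elementwise_add_s8 prototype that the first commit message compares the op_quantized_add.cpp call against. It is paraphrased from memory of the public CMSIS-NN API in Include/arm_nnfunctions.h, not copied from the exact revision linked above, so the linked header remains authoritative for parameter names and order.

/*
 * Sketch of the CMSIS-NN kernel declaration referenced by the commit message,
 * paraphrased from the public arm_nnfunctions.h API. The per-input
 * requantization parameters (offset, multiplier, shift) are grouped per input
 * and must be passed in exactly this order by the caller.
 */
arm_cmsis_nn_status arm_elementwise_add_s8(const int8_t *input_1_vect,
                                           const int8_t *input_2_vect,
                                           const int32_t input_1_offset,
                                           const int32_t input_1_mult,
                                           const int32_t input_1_shift,
                                           const int32_t input_2_offset,
                                           const int32_t input_2_mult,
                                           const int32_t input_2_shift,
                                           const int32_t left_shift,
                                           int8_t *output,
                                           const int32_t out_offset,
                                           const int32_t out_mult,
                                           const int32_t out_shift,
                                           const int32_t out_activation_min,
                                           const int32_t out_activation_max,
                                           const int32_t block_size);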