diff --git a/backends/arm/test/conftest.py b/backends/arm/test/conftest.py index 6fc9e7e5adc..0060bf0ea63 100644 --- a/backends/arm/test/conftest.py +++ b/backends/arm/test/conftest.py @@ -118,7 +118,7 @@ def is_option_enabled(option: str, fail_if_not_enabled: bool = False) -> bool: a RuntimeError instead of returning False. """ - if option in pytest._test_options and pytest._test_options[option]: # type: ignore[attr-defined] + if hasattr(pytest, "_test_options") and option in pytest._test_options and pytest._test_options[option]: # type: ignore[attr-defined] return True else: if fail_if_not_enabled: diff --git a/backends/cortex_m/test/ops/__init__.py b/backends/cortex_m/test/ops/__init__.py new file mode 100644 index 00000000000..c8d1c683da3 --- /dev/null +++ b/backends/cortex_m/test/ops/__init__.py @@ -0,0 +1,4 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/backends/cortex_m/test/ops/test_add.py b/backends/cortex_m/test/ops/test_add.py new file mode 100644 index 00000000000..10edacb5a11 --- /dev/null +++ b/backends/cortex_m/test/ops/test_add.py @@ -0,0 +1,175 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +from executorch.backends.arm.test.common import parametrize +from executorch.backends.cortex_m.test.tester import CortexMTester, McuTestCase +from executorch.backends.test.suite.operators.test_add import Model, ModelAlpha + + +class CortexMSelfAdd(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_add_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def forward(self, x): + return x + x + + +class CortexMScalarAdd(Model): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_add_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + +class CortexMTensorAdd(Model): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 3, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_add_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + +class CortexMAlphaAdd(ModelAlpha): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 3, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_add_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + +test_cases = { + "self_scalar": McuTestCase( + CortexMSelfAdd(), + (10.0,), + ), + "self_rank_1": McuTestCase( + CortexMSelfAdd(), + (torch.linspace(-5, 5, 10),), + ), + "self_rank_2_pos": McuTestCase( + CortexMSelfAdd(), + (torch.linspace(0, 1000, 10).reshape((10, 1)),), + ), + "self_rank_3_neg": McuTestCase( + CortexMSelfAdd(), + (torch.linspace(-100, 0, 8).reshape((2, 2, 2)),), + ), + "self_rank_4_small": McuTestCase( + CortexMSelfAdd(), + (torch.linspace(-0.1, 0.1, 16).reshape(2, 2, 2, 2),), + ), + "self_rank_5": McuTestCase( + CortexMSelfAdd(), + (torch.linspace(-5, 5, 32).reshape(2, 2, 2, 2, 2),), + ), + "scalar_scalar": McuTestCase( + CortexMScalarAdd(), + (-0.5, 1.0), + ), + "tensor_scalar": McuTestCase( + CortexMScalarAdd(), + (torch.ones(2, 2), 1.0), + ), + "scalar_tensor": McuTestCase( + CortexMScalarAdd(), + (1000.0, torch.ones(2, 2)), + ), + "broadcast_1": McuTestCase( + CortexMTensorAdd(), + (torch.ones(1), torch.ones(2, 2, 2, 2)), + ), + "broadcast_2": McuTestCase( + CortexMTensorAdd(), + (torch.ones((2, 1, 1, 1)), torch.ones(1)), + ), + "broadcast_3": McuTestCase( + CortexMTensorAdd(), + ( + torch.linspace(-2, 2, 4).reshape(2, 1, 2, 1), + torch.linspace(-5, 5, 4).reshape(1, 2, 1, 2), + ), + ), + "alpha": McuTestCase( + CortexMAlphaAdd(0.5), + ( + torch.linspace(-10, 10, 20).reshape(4, 5), + torch.linspace(-20, 20, 20).reshape(4, 5), + ), + ), +} + + +dialect_xfails = { + "self_scalar": ("'float' object has no attribute 'fake_mode'", AttributeError), + "self_rank_1": ("Output 0 does not match reference output", AssertionError), + "self_rank_2_pos": ("Output 0 does not match reference output", AssertionError), + "self_rank_3_neg": ("Output 0 does not match reference output", AssertionError), + "self_rank_4_small": ("Output 0 does not match reference output", AssertionError), + "self_rank_5": ("Output 0 does not match reference output", AssertionError), + "scalar_scalar": ("'float' object has no attribute 'fake_mode'", AttributeError), + "broadcast_3": ("Output 0 does not match reference output", AssertionError), + "alpha": ("Expecting kwargs for aten op IR to be empty", AssertionError), +} + + +@parametrize("test_case", test_cases, xfails=dialect_xfails) +def test_dialect_add(test_case): + tester = CortexMTester(test_case.model, test_case.example_inputs) + tester.test_dialect( + test_case.model.ops_before_transforms, test_case.model.ops_after_transforms + ) + + +implementation_xfails = { + "self_scalar": ("'float' object has no attribute 'fake_mode'", AttributeError), + "self_rank_1": ("Output 0 does not match reference output", AssertionError), + "self_rank_2_pos": ("Output 0 does not match reference output", AssertionError), + "self_rank_3_neg": ("Output 0 does not match reference output", AssertionError), + "self_rank_4_small": ("Output 0 does not match reference output", AssertionError), + "self_rank_5": ("Output 0 does not match reference output", AssertionError), + "scalar_scalar": ("'float' object has no attribute 'fake_mode'", AttributeError), + "tensor_scalar": ("Output 0 does not match reference output", AssertionError), + "scalar_tensor": ("Output 0 does not match reference output", AssertionError), + "broadcast_1": ("Output 0 does not match reference output", AssertionError), + "broadcast_2": ("Output 0 does not match reference output", AssertionError), + "broadcast_3": ("Output 0 does not match reference output", AssertionError), + "alpha": ("Expecting kwargs for aten op IR to be empty", AssertionError), +} + + +@parametrize("test_case", test_cases, xfails=implementation_xfails) +def test_implementation_add(test_case): + tester = CortexMTester(test_case.model, test_case.example_inputs) + tester.test_implementation() diff --git a/backends/cortex_m/test/tester.py b/backends/cortex_m/test/tester.py new file mode 100644 index 00000000000..8af31e58cd7 --- /dev/null +++ b/backends/cortex_m/test/tester.py @@ -0,0 +1,100 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from dataclasses import dataclass +from typing import Any + +import torch + +from backends.xnnpack.quantizer.xnnpack_quantizer import ( + get_symmetric_quantization_config, + XNNPACKQuantizer, +) +from executorch.backends.arm.test.common import get_u55_compile_spec +from executorch.backends.arm.test.tester.arm_tester import Serialize +from executorch.backends.cortex_m.passes.quantized_op_fusion_pass import ( + QuantizedOpFusionPass, +) + +from executorch.backends.cortex_m.passes.replace_quant_nodes_pass import ( + ReplaceQuantNodesPass, +) +from executorch.backends.test.harness import Tester as TesterBase +from executorch.backends.test.harness.stages import ( + Export, + Quantize, + RunPasses, + StageType, + ToEdgeTransformAndLower, + ToExecutorch, +) +from executorch.backends.xnnpack._passes import XNNPACKPassManager + + +class CortexMQuantize(Quantize): + def __init__(self): + quantizer = XNNPACKQuantizer() + config = get_symmetric_quantization_config() + super().__init__(quantizer, config) + + +class CortexMRunPasses(RunPasses): + def __init__(self): + super().__init__( + XNNPACKPassManager, pass_list=[QuantizedOpFusionPass, ReplaceQuantNodesPass] + ) + + +class CortexMSerialize(Serialize): + def __init__(self): + compile_spec = get_u55_compile_spec() + super().__init__(compile_spec, 1024) + + +cortex_m_stage_classes = { + StageType.EXPORT: Export, + StageType.QUANTIZE: CortexMQuantize, + StageType.RUN_PASSES: CortexMRunPasses, + StageType.SERIALIZE: Serialize, + StageType.TO_EDGE_TRANSFORM_AND_LOWER: ToEdgeTransformAndLower, + StageType.TO_EXECUTORCH: ToExecutorch, + StageType.SERIALIZE: CortexMSerialize, +} + + +class CortexMTester(TesterBase): + def __init__(self, module, example_inputs): + super().__init__(module, example_inputs, cortex_m_stage_classes) + + def test_dialect(self, ops_before_transforms, ops_after_transforms, qtol=0): + """ + Test the python dialect op implementation. + """ + self.quantize() + self.export() + self.to_edge_transform_and_lower() + self.check_count(ops_before_transforms) + self.run_passes() + self.check_count(ops_after_transforms) + self.run_method_and_compare_outputs(inputs=self.example_inputs, qtol=qtol) + + def test_implementation(self, qtol=0): + """ + Test the optimized op implementation in simulation + """ + self.quantize() + self.export() + self.to_edge_transform_and_lower() + self.run_passes() + self.to_executorch() + self.serialize() + self.run_method_and_compare_outputs(inputs=self.example_inputs, qtol=qtol) + + +@dataclass +class McuTestCase: + model: torch.nn.Module + example_inputs: tuple[Any]