From 5d9e1083854850ac8208ba5c44d224eaee67442b Mon Sep 17 00:00:00 2001
From: Saoirse Stewart
Date: Thu, 9 Oct 2025 13:09:11 +0100
Subject: [PATCH] Arm backend: Add support for sigmoid and tanh int16x8

---
 .../arm/operator_support/ethos_u55_support.py |   4 +-
 backends/arm/quantizer/arm_quantizer.py       |   4 +-
 backends/arm/test/ops/test_sigmoid.py         |  17 +-
 backends/arm/test/ops/test_sigmoid_16bit.py   | 190 ------------------
 backends/arm/test/ops/test_tanh.py            |  16 +-
 5 files changed, 20 insertions(+), 211 deletions(-)
 delete mode 100644 backends/arm/test/ops/test_sigmoid_16bit.py

diff --git a/backends/arm/operator_support/ethos_u55_support.py b/backends/arm/operator_support/ethos_u55_support.py
index 27ddb95637b..2403cfffa7e 100644
--- a/backends/arm/operator_support/ethos_u55_support.py
+++ b/backends/arm/operator_support/ethos_u55_support.py
@@ -114,9 +114,9 @@ def is_node_supported(  # noqa: C901
             return False
 
         if node.target in self.target_ops_i8:
-            if dtype not in (torch.int8,):
+            if dtype not in (torch.int8, torch.int16):
                 self.reporter.report_reject(
-                    node, f"Unsupported dtype {dtype} (Supports i8)."
+                    node, f"Unsupported dtype {dtype} (Supports i8, i16)."
                 )
                 return False
 
diff --git a/backends/arm/quantizer/arm_quantizer.py b/backends/arm/quantizer/arm_quantizer.py
index 3f03d1a3d70..2b0b028c5e4 100644
--- a/backends/arm/quantizer/arm_quantizer.py
+++ b/backends/arm/quantizer/arm_quantizer.py
@@ -161,6 +161,7 @@ def get_symmetric_a16w8_quantization_config(
     is_dynamic: bool = False,
     weight_qmin: int = -127,
     weight_qmax: int = 127,
+    epsilon: float = 2**-12,
 ):
     """
     16A8W quantization config: 16-bit activations, 8-bit weights.
@@ -174,11 +175,12 @@ def get_symmetric_a16w8_quantization_config(
         is_dynamic: Whether to use dynamic quantization
         weight_qmin: Minimum quantization value for weights
         weight_qmax: Maximum quantization value for weights
+        epsilon: Value used to pad the observed [qmin, qmax] range before the initial zero point and scale calculation
 
     Returns:
         QuantizationConfig with 16-bit activations and 8-bit weights
     """
-    extra_args: Dict[str, Any] = {"eps": 2**-12}
+    extra_args: Dict[str, Any] = {"eps": epsilon}
 
     # Setup observer/fake-quant for 16-bit activations
     if is_qat:
diff --git a/backends/arm/test/ops/test_sigmoid.py b/backends/arm/test/ops/test_sigmoid.py
index a9b9ef11b48..a9e3802f75b 100644
--- a/backends/arm/test/ops/test_sigmoid.py
+++ b/backends/arm/test/ops/test_sigmoid.py
@@ -34,6 +34,7 @@
     "zeros": lambda: torch.zeros(10, 10, 10, 10),
     "ones": lambda: torch.ones(10, 10, 10),
     "rand": lambda: torch.rand(10, 10) - 0.5,
+    "rand_4d": lambda: torch.rand(1, 1, 5, 10),
     "randn_pos": lambda: torch.randn(10) + 10,
     "randn_neg": lambda: torch.randn(10) - 10,
     "ramp": lambda: torch.arange(-16, 16, 0.2),
@@ -269,22 +270,23 @@ def get_symmetric_a16w8_sigmoid_quantizer(per_channel_quantization=False):
     }
 
     quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
+
+    # Use a smaller epsilon value to avoid greatly inflating [qmin, qmax]
     quantizer.set_global(
-        get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
+        get_symmetric_a16w8_quantization_config(
+            is_per_channel=per_channel_quantization, epsilon=2**-16
+        )
     )
 
     return Quantize(
         quantizer,
         get_symmetric_a16w8_quantization_config(
-            is_per_channel=per_channel_quantization
+            is_per_channel=per_channel_quantization, epsilon=2**-16
         ),
     )
 
 
 @common.parametrize("test_data", test_data_suite)
-@pytest.mark.xfail(
-    reason="missing int16 sigmoid ops support; fails at TOSA reference model with Unsupported operation type or rank. See: https://github.com/pytorch/executorch/issues/13974"
-)
 def test_sigmoid_16a8w_tosa_INT(test_data: torch.Tensor):
     """Test sigmoid operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
     per_channel_quantization = False
@@ -311,7 +313,7 @@ def test_sigmoid_16a8w_tosa_INT(test_data: torch.Tensor):
 @common.parametrize("test_data", test_data_suite)
 @common.XfailIfNoCorstone300
 @pytest.mark.xfail(
-    reason="Vela compilation fails with 'Invalid arguments' for int16 sigmoid operations"
+    reason="MLETORCH-707: AssertionError: Output 0 does not match reference output."
 )
 def test_sigmoid_16a8w_u55_INT16(test_data: torch.Tensor):
     """Test sigmoid operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
@@ -337,9 +339,6 @@ def test_sigmoid_16a8w_u55_INT16(test_data: torch.Tensor):
 
 @common.parametrize("test_data", test_data_suite)
 @common.XfailIfNoCorstone320
-@pytest.mark.xfail(
-    reason="Vela compilation fails with 'Invalid arguments' for int16 sigmoid operations"
-)
 def test_sigmoid_16a8w_u85_INT16(test_data: torch.Tensor):
     """Test sigmoid operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
     per_channel_quantization = False
diff --git a/backends/arm/test/ops/test_sigmoid_16bit.py b/backends/arm/test/ops/test_sigmoid_16bit.py
deleted file mode 100644
index 587ba99222a..00000000000
--- a/backends/arm/test/ops/test_sigmoid_16bit.py
+++ /dev/null
@@ -1,190 +0,0 @@
-# Copyright 2025 Arm Limited and/or its affiliates.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-import pytest
-import torch
-from executorch.backends.arm.quantizer import (
-    get_symmetric_quantization_config,
-    TOSAQuantizer,
-)
-from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
-from executorch.backends.arm.test import common, conftest
-from executorch.backends.arm.test.tester.test_pipeline import (
-    EthosU85PipelineINT,
-    OpNotSupportedPipeline,
-    TosaPipelineINT,
-)
-from executorch.backends.arm.tosa import TosaSpecification
-from executorch.backends.xnnpack.test.tester import Quantize
-from torchao.quantization.pt2e import HistogramObserver
-from torchao.quantization.pt2e.quantizer import QuantizationSpec
-
-
-def _get_16_bit_quant_config():
-    int16_spec = QuantizationSpec(
-        dtype=torch.int16,
-        observer_or_fake_quant_ctr=HistogramObserver,
-        qscheme=torch.per_tensor_symmetric,
-    )
-    qconfig = QuantizationConfig(
-        input_activation=int16_spec,
-        output_activation=int16_spec,
-        weight=None,
-        bias=None,
-    )
-    return qconfig
-
-
-def get_16bit_sigmoid_quantizer(u55_config=False):
-    tosa_version = conftest.get_option("tosa_version")
-    tosa_profiles = {
-        "1.0": TosaSpecification.create_from_string(
-            "TOSA-1.0+INT+int16" + ("+u55" if u55_config else "")
-        ),
-    }
-
-    quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
-    quantizer.set_global(get_symmetric_quantization_config())
-    quantizer.set_module_type(
-        torch.nn.modules.activation.Sigmoid, _get_16_bit_quant_config()
-    )
-
-    return Quantize(quantizer, get_symmetric_quantization_config())
-
-
-input_t = tuple[torch.Tensor]
-test_data_suite = {
-    "ones": lambda: torch.ones(10, 10, 10),
-    "rand": lambda: torch.rand(10, 10) - 0.5,
-    "rand_4d": lambda: torch.rand(1, 1, 5, 10),
-    "randn_pos": lambda: torch.randn(10) + 10,
-    "randn_neg": lambda: torch.randn(10) - 10,
-    "ramp": lambda: torch.arange(-16, 16, 0.02),
-}
-
-
-class Sigmoid(torch.nn.Module):
-    aten_op = "torch.ops.aten.sigmoid.default"
"torch.ops.aten.sigmoid.default" - exir_op = "executorch_exir_dialects_edge__ops_aten_sigmoid_default" - - def __init__(self): - super().__init__() - self.sigmoid = torch.nn.Sigmoid() - - def forward(self, x): - return self.sigmoid(x) - - -class SigmoidAddSigmoid(torch.nn.Module): - def __init__(self): - super().__init__() - self.sigmoid = torch.nn.Sigmoid() - - def forward(self, x): - return self.sigmoid((self.sigmoid(x) + self.sigmoid(x))) - - -@common.parametrize("test_data", test_data_suite) -def test_sigmoid_tosa_INT(test_data): - pipeline = TosaPipelineINT( - Sigmoid(), - (test_data(),), - Sigmoid.aten_op, - Sigmoid.exir_op, - qtol=1, - tosa_extensions=["int16"], - ) - pipeline.change_args("quantize", get_16bit_sigmoid_quantizer()) - pipeline.run() - - -@common.parametrize( - "test_data", - test_data_suite, - xfails={ - "ramp": "AssertionError: Output 0 does not match reference output. MLETORCH-787" - }, - strict=False, -) -def test_sigmoid_tosa_INT_add_sigmoid(test_data): - pipeline = TosaPipelineINT( - SigmoidAddSigmoid(), - (test_data(),), - Sigmoid.aten_op, - Sigmoid.exir_op, - qtol=1, - tosa_extensions=["int16"], - ) - pipeline.change_args("quantize", get_16bit_sigmoid_quantizer()) - pipeline.run() - - -@common.parametrize( - "test_data", - test_data_suite, -) -@common.XfailIfNoCorstone300 -def test_sigmoid_u55_INT(test_data): - pipeline = OpNotSupportedPipeline( - Sigmoid(), - (test_data(),), - {Sigmoid.exir_op: 1}, - quantize=True, - u55_subset=True, - ) - pipeline.change_args("quantize", get_16bit_sigmoid_quantizer(True)) - pipeline.run() - - -@common.parametrize( - "test_data", - test_data_suite, -) -@common.XfailIfNoCorstone300 -def test_sigmoid_u55_INT_add_sigmoid(test_data): - pipeline = OpNotSupportedPipeline( - SigmoidAddSigmoid(), - (test_data(),), - {Sigmoid.exir_op: 3}, - n_expected_delegates=1, - quantize=True, - u55_subset=True, - tosa_extensions=["int16"], - ) - pipeline.change_args("quantize", get_16bit_sigmoid_quantizer(True)) - pipeline.run() - - -@common.parametrize("test_data", test_data_suite) -@common.XfailIfNoCorstone320 -def test_sigmoid_u85_INT(test_data): - pipeline = EthosU85PipelineINT( - Sigmoid(), - (test_data(),), - Sigmoid.aten_op, - Sigmoid.exir_op, - ) - pipeline.change_args("quantize", get_16bit_sigmoid_quantizer()) - pipeline.run() - - -@common.parametrize( - "test_data", - test_data_suite, - xfails={ - "ramp": "AssertionError: Output 0 does not match reference output. 
MLETORCH-787" - }, -) -@pytest.mark.xfail # MLETORCH-787: Investigate int16-int8 rescaling precision -@common.XfailIfNoCorstone320 -def test_sigmoid_u85_INT_add_sigmoid(test_data): - pipeline = EthosU85PipelineINT( - SigmoidAddSigmoid(), - (test_data(),), - Sigmoid.aten_op, - Sigmoid.exir_op, - ) - pipeline.change_args("quantize", get_16bit_sigmoid_quantizer()) - pipeline.run() diff --git a/backends/arm/test/ops/test_tanh.py b/backends/arm/test/ops/test_tanh.py index 8dc967c01d7..d863d13a5c0 100644 --- a/backends/arm/test/ops/test_tanh.py +++ b/backends/arm/test/ops/test_tanh.py @@ -121,22 +121,23 @@ def get_symmetric_a16w8_tanh_quantizer(per_channel_quantization=False): } quantizer = TOSAQuantizer(tosa_profiles[tosa_version]) + + # Use a smaller episilon value to not greatly inflate [qmin, qmax] quantizer.set_global( - get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization) + get_symmetric_a16w8_quantization_config( + is_per_channel=per_channel_quantization, epsilon=2**-16 + ) ) return Quantize( quantizer, get_symmetric_a16w8_quantization_config( - is_per_channel=per_channel_quantization + is_per_channel=per_channel_quantization, epsilon=2**-16 ), ) @common.parametrize("test_data", test_data_suite) -@pytest.mark.xfail( - reason="missing int16 tanh ops support; fails at TOSA reference model with Unsupported operation type or rank. See: https://github.com/pytorch/executorch/issues/13975" -) def test_tanh_16a8w_tosa_INT(test_data: torch.Tensor): """Test tanh operation with 16A8W quantization (16-bit activations, 8-bit weights)""" per_channel_quantization = False @@ -163,7 +164,7 @@ def test_tanh_16a8w_tosa_INT(test_data: torch.Tensor): @common.parametrize("test_data", test_data_suite) @common.XfailIfNoCorstone300 @pytest.mark.xfail( - reason="Vela compilation fails with 'Invalid arguments' for int16 tanh operations" + reason="MLETORCH-707: AssertionError: Output 0 does not match reference output." ) def test_tanh_16a8w_u55_INT16(test_data: torch.Tensor): """Test tanh operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)""" @@ -189,9 +190,6 @@ def test_tanh_16a8w_u55_INT16(test_data: torch.Tensor): @common.parametrize("test_data", test_data_suite) @common.XfailIfNoCorstone320 -@pytest.mark.xfail( - reason="Vela compilation fails with 'Invalid arguments' for int16 tanh operations" -) def test_tanh_16a8w_u85_INT16(test_data: torch.Tensor): """Test tanh operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" per_channel_quantization = False