From b8320ed2ee7add86507fb1a803c8d7cc88ef266c Mon Sep 17 00:00:00 2001 From: Nitin Jain Date: Wed, 10 Sep 2025 22:06:28 -0700 Subject: [PATCH] Add 16A8W support and test for sigmoid operation Pull Request resolved: https://github.com/pytorch/executorch/pull/13796 Add 16A8W quantization support and test for the sigmoid operation in ExecutorTorch ARM backend. This follows the pattern established for linear and mul operations, extending int16 support to sigmoid operations. Changes: - Add INT16 dtype validation support in op_sigmoid.py - Add test_sigmoid_tensor_16a8w_tosa_INT test function - Enable test_sigmoid.py in test targets configuration The 16A8W configuration uses 16-bit activations with 8-bit weights, enabling higher precision for activations while maintaining weight efficiency. ghstack-source-id: 308986667 @exported-using-ghexport Differential Revision: [D80510729](https://our.internmc.facebook.com/intern/diff/D80510729/) --- backends/arm/test/ops/test_sigmoid.py | 111 +++++++++++++++++++++++++- 1 file changed, 110 insertions(+), 1 deletion(-) diff --git a/backends/arm/test/ops/test_sigmoid.py b/backends/arm/test/ops/test_sigmoid.py index a29bbc84782..aac2ee1c9b1 100644 --- a/backends/arm/test/ops/test_sigmoid.py +++ b/backends/arm/test/ops/test_sigmoid.py @@ -8,8 +8,13 @@ from typing import Tuple +import pytest import torch -from executorch.backends.arm.test import common +from executorch.backends.arm.quantizer.arm_quantizer import ( + get_symmetric_a16w8_quantization_config, + TOSAQuantizer, +) +from executorch.backends.arm.test import common, conftest from executorch.backends.arm.test.tester.test_pipeline import ( EthosU55PipelineINT, EthosU85PipelineINT, @@ -17,6 +22,8 @@ TosaPipelineINT, VgfPipeline, ) +from executorch.backends.arm.tosa.specification import TosaSpecification +from executorch.backends.xnnpack.test.tester import Quantize aten_op = "torch.ops.aten.sigmoid.default" # Used for checking that we do not have softmax in the graph after decompose exir_op = "executorch_exir_dialects_edge__ops_aten_sigmoid_default" @@ -253,3 +260,105 @@ def test_sigmoid_vgf_INT_add_3(): tosa_version="TOSA-1.0+INT", ) pipeline.run() + + +def get_symmetric_a16w8_sigmoid_quantizer(per_channel_quantization=False): + tosa_version = conftest.get_option("tosa_version") + tosa_profiles = { + "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"), + } + + quantizer = TOSAQuantizer(tosa_profiles[tosa_version]) + quantizer.set_global( + get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization) + ) + + return Quantize( + quantizer, + get_symmetric_a16w8_quantization_config( + is_per_channel=per_channel_quantization + ), + ) + + +@common.parametrize("test_data", test_data_suite) +@pytest.mark.xfail( + reason="missing int16 sigmoid ops support; fails at TOSA reference model with Unsupported operation type or rank. See: https://github.com/pytorch/executorch/issues/13974" +) +def test_sigmoid_16a8w_tosa_INT(test_data: torch.Tensor): + """Test sigmoid operation with 16A8W quantization (16-bit activations, 8-bit weights)""" + per_channel_quantization = False + + pipeline = TosaPipelineINT[input_t1]( + Sigmoid(), + (test_data(),), + aten_op, + exir_op=[], + per_channel_quantization=per_channel_quantization, + use_to_edge_transform_and_lower=True, + tosa_extensions=["int16"], + ) + + pipeline.change_args( + "quantize", + get_symmetric_a16w8_sigmoid_quantizer( + per_channel_quantization=per_channel_quantization + ), + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +@common.XfailIfNoCorstone300 +@pytest.mark.xfail( + reason="Vela compilation fails with 'Invalid arguments' for int16 sigmoid operations" +) +def test_sigmoid_16a8w_u55_INT16(test_data: torch.Tensor): + """Test sigmoid operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)""" + per_channel_quantization = False + + pipeline = EthosU55PipelineINT[input_t1]( + Sigmoid(), + (test_data(),), + aten_op, + exir_op, + per_channel_quantization=per_channel_quantization, + use_to_edge_transform_and_lower=True, + run_on_fvp=True, + ) + + pipeline.change_args( + "quantize", + get_symmetric_a16w8_sigmoid_quantizer( + per_channel_quantization=per_channel_quantization + ), + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +@common.XfailIfNoCorstone320 +@pytest.mark.xfail( + reason="Vela compilation fails with 'Invalid arguments' for int16 sigmoid operations" +) +def test_sigmoid_16a8w_u85_INT16(test_data: torch.Tensor): + """Test sigmoid operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" + per_channel_quantization = False + + pipeline = EthosU85PipelineINT[input_t1]( + Sigmoid(), + (test_data(),), + aten_op, + exir_op, + per_channel_quantization=per_channel_quantization, + use_to_edge_transform_and_lower=True, + run_on_fvp=True, + ) + + pipeline.change_args( + "quantize", + get_symmetric_a16w8_sigmoid_quantizer( + per_channel_quantization=per_channel_quantization + ), + ) + pipeline.run()