From 2149436469152bee408a80cbb9977d5ad96db213 Mon Sep 17 00:00:00 2001 From: Nitin Jain Date: Thu, 28 Aug 2025 21:20:07 -0700 Subject: [PATCH] Add 16A8W support and test for add operation Add 16A8W quantization support and test for the add operation in ExecutorTorch ARM backend. This follows the pattern established for linear operations, extending int16 support to add operations. Changes: - Add INT16 dtype validation support in op_add.py - Add test_add_tensor_16a8w_tosa_INT test function - Enable test_add.py in test targets configuration The 16A8W configuration uses 16-bit activations with 8-bit weights, enabling higher precision for activations while maintaining weight efficiency. Differential Revision: [D80510463](https://our.internmc.facebook.com/intern/diff/D80510463/) [ghstack-poisoned] --- backends/arm/test/ops/test_add.py | 104 ++++++++++++++++++++++++++++++ 1 file changed, 104 insertions(+) diff --git a/backends/arm/test/ops/test_add.py b/backends/arm/test/ops/test_add.py index 24fdfbb5457..b99c12c2244 100644 --- a/backends/arm/test/ops/test_add.py +++ b/backends/arm/test/ops/test_add.py @@ -334,3 +334,107 @@ def test_add_tensor_16a8w_u85_INT16(test_data: input_t1): ), ) pipeline.run() + + +def get_symmetric_a16w8_add_quantizer(u55_config=False, per_channel_quantization=False): + tosa_version = conftest.get_option("tosa_version") + tosa_profiles = { + "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"), + } + + quantizer = TOSAQuantizer(tosa_profiles[tosa_version]) + quantizer.set_global( + get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization) + ) + + return Quantize( + quantizer, + get_symmetric_a16w8_quantization_config( + is_per_channel=per_channel_quantization + ), + ) + + +@common.parametrize("test_data", Add.test_data) +@pytest.mark.xfail( + reason="missing int16 add ops support; fails at TOSA reference model with Unsupported operation type or rank" +) +def test_add_tensor_16a8w_tosa_INT(test_data: input_t1): + """Test add operation with 16A8W quantization (16-bit activations, 8-bit weights)""" + per_channel_quantization = False + + pipeline = TosaPipelineINT[input_t1]( + Add(), + test_data(), + aten_op, + exir_op=[], + per_channel_quantization=per_channel_quantization, + use_to_edge_transform_and_lower=True, + tosa_extensions=["int16"], + ) + + pipeline.change_args( + "quantize", + get_symmetric_a16w8_add_quantizer( + per_channel_quantization=per_channel_quantization + ), + ) + pipeline.run() + + +@common.parametrize("test_data", Add.test_data) +@common.XfailIfNoCorstone300 +@pytest.mark.xfail( + reason="missing int16 add ops support; fails at TOSA reference model with Unsupported operation type or rank" +) +def test_add_tensor_16a8w_u55_INT16(test_data: input_t1): + """Test add operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)""" + per_channel_quantization = False + + pipeline = EthosU55PipelineINT[input_t1]( + Add(), + test_data(), + aten_op, + exir_op, + per_channel_quantization=per_channel_quantization, + use_to_edge_transform_and_lower=True, + tosa_extensions=["int16"], + run_on_fvp=True, + ) + + pipeline.change_args( + "quantize", + get_symmetric_a16w8_add_quantizer( + u55_config=True, per_channel_quantization=per_channel_quantization + ), + ) + pipeline.run() + + +@common.parametrize("test_data", Add.test_data) +@common.XfailIfNoCorstone320 +@pytest.mark.xfail( + reason="missing int16 add ops support; fails at TOSA reference model with Unsupported operation type or rank" +) +def test_add_tensor_16a8w_u85_INT16(test_data: input_t1): + """Test add operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" + per_channel_quantization = False + + pipeline = EthosU85PipelineINT[input_t1]( + Add(), + test_data(), + aten_op, + exir_op=[], + per_channel_quantization=per_channel_quantization, + use_to_edge_transform_and_lower=True, + tosa_extensions=["int16"], + run_on_fvp=True, + ) + + pipeline.change_args( + "quantize", + get_symmetric_a16w8_add_quantizer( + u55_config=False, per_channel_quantization=per_channel_quantization + ), + ) + pipeline.run()