From b8320ed2ee7add86507fb1a803c8d7cc88ef266c Mon Sep 17 00:00:00 2001
From: Nitin Jain
Date: Wed, 10 Sep 2025 22:06:28 -0700
Subject: [PATCH 1/6] Add 16A8W support and test for sigmoid operation

Pull Request resolved: https://github.com/pytorch/executorch/pull/13796

Add 16A8W quantization support and a test for the sigmoid operation in
the ExecuTorch ARM backend. This follows the pattern established for the
linear and mul operations, extending int16 support to sigmoid.

Changes:
- Add INT16 dtype validation support in op_sigmoid.py
- Add test_sigmoid_16a8w_tosa_INT test function
- Enable test_sigmoid.py in the test targets configuration

The 16A8W configuration uses 16-bit activations with 8-bit weights,
enabling higher precision for activations while maintaining weight
efficiency.

ghstack-source-id: 308986667
@exported-using-ghexport

Differential Revision: [D80510729](https://our.internmc.facebook.com/intern/diff/D80510729/)
---
 backends/arm/test/ops/test_sigmoid.py | 111 +++++++++++++++++++++++++-
 1 file changed, 110 insertions(+), 1 deletion(-)

diff --git a/backends/arm/test/ops/test_sigmoid.py b/backends/arm/test/ops/test_sigmoid.py
index a29bbc84782..aac2ee1c9b1 100644
--- a/backends/arm/test/ops/test_sigmoid.py
+++ b/backends/arm/test/ops/test_sigmoid.py
@@ -8,8 +8,13 @@
 
 from typing import Tuple
 
+import pytest
 import torch
-from executorch.backends.arm.test import common
+from executorch.backends.arm.quantizer.arm_quantizer import (
+    get_symmetric_a16w8_quantization_config,
+    TOSAQuantizer,
+)
+from executorch.backends.arm.test import common, conftest
 from executorch.backends.arm.test.tester.test_pipeline import (
     EthosU55PipelineINT,
     EthosU85PipelineINT,
@@ -17,6 +22,8 @@
     TosaPipelineINT,
     VgfPipeline,
 )
+from executorch.backends.arm.tosa.specification import TosaSpecification
+from executorch.backends.xnnpack.test.tester import Quantize
 
 aten_op = "torch.ops.aten.sigmoid.default"  # Used for checking that we do not have softmax in the graph after decompose
 exir_op = "executorch_exir_dialects_edge__ops_aten_sigmoid_default"
@@ -253,3 +260,105 @@ def test_sigmoid_vgf_INT_add_3():
         tosa_version="TOSA-1.0+INT",
     )
     pipeline.run()
+
+
+def get_symmetric_a16w8_sigmoid_quantizer(per_channel_quantization=False):
+    tosa_version = conftest.get_option("tosa_version")
+    tosa_profiles = {
+        "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
+    }
+
+    quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
+    quantizer.set_global(
+        get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
+    )
+
+    return Quantize(
+        quantizer,
+        get_symmetric_a16w8_quantization_config(
+            is_per_channel=per_channel_quantization
+        ),
+    )
+
+
+@common.parametrize("test_data", test_data_suite)
+@pytest.mark.xfail(
+    reason="missing int16 sigmoid ops support; fails at TOSA reference model with Unsupported operation type or rank. See: https://github.com/pytorch/executorch/issues/13974"
+)
+def test_sigmoid_16a8w_tosa_INT(test_data: torch.Tensor):
+    """Test sigmoid operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
+    per_channel_quantization = False
+
+    pipeline = TosaPipelineINT[input_t1](
+        Sigmoid(),
+        (test_data(),),
+        aten_op,
+        exir_op=[],
+        per_channel_quantization=per_channel_quantization,
+        use_to_edge_transform_and_lower=True,
+        tosa_extensions=["int16"],
+    )
+
+    pipeline.change_args(
+        "quantize",
+        get_symmetric_a16w8_sigmoid_quantizer(
+            per_channel_quantization=per_channel_quantization
+        ),
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", test_data_suite)
+@common.XfailIfNoCorstone300
+@pytest.mark.xfail(
+    reason="Vela compilation fails with 'Invalid arguments' for int16 sigmoid operations"
+)
+def test_sigmoid_16a8w_u55_INT16(test_data: torch.Tensor):
+    """Test sigmoid operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
+    per_channel_quantization = False
+
+    pipeline = EthosU55PipelineINT[input_t1](
+        Sigmoid(),
+        (test_data(),),
+        aten_op,
+        exir_op,
+        per_channel_quantization=per_channel_quantization,
+        use_to_edge_transform_and_lower=True,
+        run_on_fvp=True,
+    )
+
+    pipeline.change_args(
+        "quantize",
+        get_symmetric_a16w8_sigmoid_quantizer(
+            per_channel_quantization=per_channel_quantization
+        ),
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", test_data_suite)
+@common.XfailIfNoCorstone320
+@pytest.mark.xfail(
+    reason="Vela compilation fails with 'Invalid arguments' for int16 sigmoid operations"
+)
+def test_sigmoid_16a8w_u85_INT16(test_data: torch.Tensor):
+    """Test sigmoid operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
+    per_channel_quantization = False
+
+    pipeline = EthosU85PipelineINT[input_t1](
+        Sigmoid(),
+        (test_data(),),
+        aten_op,
+        exir_op,
+        per_channel_quantization=per_channel_quantization,
+        use_to_edge_transform_and_lower=True,
+        run_on_fvp=True,
+    )
+
+    pipeline.change_args(
+        "quantize",
+        get_symmetric_a16w8_sigmoid_quantizer(
+            per_channel_quantization=per_channel_quantization
+        ),
+    )
+    pipeline.run()
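A note on the repeated helper: every patch in this stack defines the same
quantizer-construction function under an operator-specific name
(get_symmetric_a16w8_sigmoid_quantizer here, then _tanh_, _slice_, _view_,
_cat_, and _addmm_ below). Since the bodies are identical, a follow-up could
hoist a single shared helper into a common test utility. A minimal sketch,
using only the imports already present in these tests; the name
get_symmetric_a16w8_quantizer and its placement are hypothetical, not part of
this stack:

    def get_symmetric_a16w8_quantizer(per_channel_quantization=False):
        # Same body as the per-op helpers in this stack; only the name differs.
        tosa_version = conftest.get_option("tosa_version")
        tosa_profiles = {
            "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
        }
        quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
        config = get_symmetric_a16w8_quantization_config(
            is_per_channel=per_channel_quantization
        )
        quantizer.set_global(config)
        # Quantize pairs the configured TOSAQuantizer with its config so each
        # test can swap it in via pipeline.change_args("quantize", ...).
        return Quantize(quantizer, config)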
From 8b65cf092e6b4272533484573354f73001d53430 Mon Sep 17 00:00:00 2001
From: Nitin Jain
Date: Wed, 10 Sep 2025 22:06:31 -0700
Subject: [PATCH 2/6] Add 16A8W support and test for tanh operation

Pull Request resolved: https://github.com/pytorch/executorch/pull/13797

Add 16A8W quantization support and a test for the tanh operation in the
ExecuTorch ARM backend. This follows the pattern established for the
linear, mul, and sigmoid operations, extending int16 support to tanh.

Changes:
- Add INT16 dtype validation support in op_tanh.py
- Add test_tanh_16a8w_tosa_INT test function
- Enable test_tanh.py in the test targets configuration

The 16A8W configuration uses 16-bit activations with 8-bit weights,
enabling higher precision for activations while maintaining weight
efficiency.

ghstack-source-id: 308986671
@exported-using-ghexport

Differential Revision: [D80510815](https://our.internmc.facebook.com/intern/diff/D80510815/)
---
 backends/arm/test/ops/test_tanh.py | 111 ++++++++++++++++++++++++++++-
 1 file changed, 110 insertions(+), 1 deletion(-)

diff --git a/backends/arm/test/ops/test_tanh.py b/backends/arm/test/ops/test_tanh.py
index 098d878addc..0e74618fd2f 100644
--- a/backends/arm/test/ops/test_tanh.py
+++ b/backends/arm/test/ops/test_tanh.py
@@ -6,9 +6,14 @@
 
 from typing import Tuple
 
+import pytest
 import torch
+from executorch.backends.arm.quantizer.arm_quantizer import (
+    get_symmetric_a16w8_quantization_config,
+    TOSAQuantizer,
+)
 
-from executorch.backends.arm.test import common
+from executorch.backends.arm.test import common, conftest
 from executorch.backends.arm.test.tester.test_pipeline import (
     EthosU55PipelineINT,
     EthosU85PipelineINT,
@@ -16,6 +21,8 @@
     TosaPipelineINT,
     VgfPipeline,
 )
+from executorch.backends.arm.tosa.specification import TosaSpecification
+from executorch.backends.xnnpack.test.tester import Quantize
 
 aten_op = "torch.ops.aten.tanh.default"
 input_t1 = Tuple[torch.Tensor]  # Input x
@@ -105,3 +112,105 @@ def test_tanh_vgf_INT(test_data: Tuple):
         tosa_version="TOSA-1.0+INT",
     )
     pipeline.run()
+
+
+def get_symmetric_a16w8_tanh_quantizer(per_channel_quantization=False):
+    tosa_version = conftest.get_option("tosa_version")
+    tosa_profiles = {
+        "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
+    }
+
+    quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
+    quantizer.set_global(
+        get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
+    )
+
+    return Quantize(
+        quantizer,
+        get_symmetric_a16w8_quantization_config(
+            is_per_channel=per_channel_quantization
+        ),
+    )
+
+
+@common.parametrize("test_data", test_data_suite)
+@pytest.mark.xfail(
+    reason="missing int16 tanh ops support; fails at TOSA reference model with Unsupported operation type or rank. See: https://github.com/pytorch/executorch/issues/13975"
+)
+def test_tanh_16a8w_tosa_INT(test_data: torch.Tensor):
+    """Test tanh operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
+    per_channel_quantization = False
+
+    pipeline = TosaPipelineINT[input_t1](
+        Tanh(),
+        (test_data(),),
+        aten_op,
+        exir_op=[],
+        per_channel_quantization=per_channel_quantization,
+        use_to_edge_transform_and_lower=True,
+        tosa_extensions=["int16"],
+    )
+
+    pipeline.change_args(
+        "quantize",
+        get_symmetric_a16w8_tanh_quantizer(
+            per_channel_quantization=per_channel_quantization
+        ),
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", test_data_suite)
+@common.XfailIfNoCorstone300
+@pytest.mark.xfail(
+    reason="Vela compilation fails with 'Invalid arguments' for int16 tanh operations"
+)
+def test_tanh_16a8w_u55_INT16(test_data: torch.Tensor):
+    """Test tanh operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
+    per_channel_quantization = False
+
+    pipeline = EthosU55PipelineINT[input_t1](
+        Tanh(),
+        (test_data(),),
+        aten_op,
+        exir_ops=[],
+        per_channel_quantization=per_channel_quantization,
+        use_to_edge_transform_and_lower=True,
+        run_on_fvp=True,
+    )
+
+    pipeline.change_args(
+        "quantize",
+        get_symmetric_a16w8_tanh_quantizer(
+            per_channel_quantization=per_channel_quantization
+        ),
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", test_data_suite)
+@common.XfailIfNoCorstone320
+@pytest.mark.xfail(
+    reason="Vela compilation fails with 'Invalid arguments' for int16 tanh operations"
+)
+def test_tanh_16a8w_u85_INT16(test_data: torch.Tensor):
+    """Test tanh operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
+    per_channel_quantization = False
+
+    pipeline = EthosU85PipelineINT[input_t1](
+        Tanh(),
+        (test_data(),),
+        aten_op,
+        exir_ops=[],
+        per_channel_quantization=per_channel_quantization,
+        use_to_edge_transform_and_lower=True,
+        run_on_fvp=True,
+    )
+
+    pipeline.change_args(
+        "quantize",
+        get_symmetric_a16w8_tanh_quantizer(
+            per_channel_quantization=per_channel_quantization
+        ),
+    )
+    pipeline.run()
From 92ce104b709e9fe4cba11624a9daebff2fa0d3e6 Mon Sep 17 00:00:00 2001
From: Nitin Jain
Date: Wed, 10 Sep 2025 22:06:37 -0700
Subject: [PATCH 3/6] Add 16A8W support and test for slice operation

Pull Request resolved: https://github.com/pytorch/executorch/pull/13798

Add 16A8W quantization support and a test for the slice operation in the
ExecuTorch ARM backend. This follows the pattern established for the
linear, mul, sigmoid, and tanh operations, extending int16 support to
slice.

Changes:
- Add INT16 dtype validation support in op_slice.py
- Add test_slice_tensor_16a8w_tosa_INT test function
- Enable test_slice.py in the test targets configuration

The 16A8W configuration uses 16-bit activations with 8-bit weights,
enabling higher precision for activations while maintaining weight
efficiency.

ghstack-source-id: 308986668
@exported-using-ghexport

Differential Revision: [D80511095](https://our.internmc.facebook.com/intern/diff/D80511095/)
---
 backends/arm/operators/op_slice.py  |   2 +-
 backends/arm/test/ops/test_slice.py | 111 +++++++++++++++++++++++++++-
 2 files changed, 111 insertions(+), 2 deletions(-)

diff --git a/backends/arm/operators/op_slice.py b/backends/arm/operators/op_slice.py
index aad4599a4b5..12d38060aa6 100644
--- a/backends/arm/operators/op_slice.py
+++ b/backends/arm/operators/op_slice.py
@@ -57,7 +57,7 @@ def define_node(
         validate_valid_dtype(
             self.target,
             [inputs[0], output],
-            [ts.DType.INT8, ts.DType.INT32, ts.DType.FP32],
+            [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32],
             output.tosa_spec,
         )
 
diff --git a/backends/arm/test/ops/test_slice.py b/backends/arm/test/ops/test_slice.py
index 915aec2e522..eafeb04320e 100644
--- a/backends/arm/test/ops/test_slice.py
+++ b/backends/arm/test/ops/test_slice.py
@@ -7,9 +7,14 @@
 
 from typing import Tuple
 
+import pytest
 import torch
+from executorch.backends.arm.quantizer.arm_quantizer import (
+    get_symmetric_a16w8_quantization_config,
+    TOSAQuantizer,
+)
 
-from executorch.backends.arm.test import common
+from executorch.backends.arm.test import common, conftest
 
 from executorch.backends.arm.test.tester.test_pipeline import (
     EthosU55PipelineINT,
@@ -18,6 +23,8 @@
     TosaPipelineINT,
     VgfPipeline,
 )
+from executorch.backends.arm.tosa.specification import TosaSpecification
+from executorch.backends.xnnpack.test.tester import Quantize
 
 aten_op = "torch.ops.aten.slice.Tensor"
 exir_op = "executorch_exir_dialects_edge__ops_aten_slice_copy"
@@ -119,3 +126,105 @@ def test_slice_tensor_vgf_INT(test_data: torch.Tensor):
         tosa_version="TOSA-1.0+INT",
     )
     pipeline.run()
+
+
+def get_symmetric_a16w8_slice_quantizer(per_channel_quantization=False):
+    tosa_version = conftest.get_option("tosa_version")
+    tosa_profiles = {
+        "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
+    }
+
+    quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
+    quantizer.set_global(
+        get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
+    )
+
+    return Quantize(
+        quantizer,
+        get_symmetric_a16w8_quantization_config(
+            is_per_channel=per_channel_quantization
+        ),
+    )
+
+
+@common.parametrize("test_data", test_data_suite)
+@pytest.mark.xfail(
+    reason="missing int16 slice ops support; fails at TOSA reference model with Unsupported operation type or rank. See: https://github.com/pytorch/executorch/issues/13976"
+)
+def test_slice_tensor_16a8w_tosa_INT(test_data: torch.Tensor):
+    """Test slice operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
+    per_channel_quantization = False
+
+    pipeline = TosaPipelineINT[input_t1](
+        Slice(),
+        test_data(),
+        aten_op,
+        exir_op=[],
+        per_channel_quantization=per_channel_quantization,
+        use_to_edge_transform_and_lower=True,
+        tosa_extensions=["int16"],
+    )
+
+    pipeline.change_args(
+        "quantize",
+        get_symmetric_a16w8_slice_quantizer(
+            per_channel_quantization=per_channel_quantization
+        ),
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", test_data_suite)
+@common.XfailIfNoCorstone300
+@pytest.mark.xfail(
+    reason="Vela compilation fails with 'Invalid arguments' for int16 slice operations"
+)
+def test_slice_tensor_16a8w_u55_INT16(test_data: torch.Tensor):
+    """Test slice operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
+    per_channel_quantization = False
+
+    pipeline = EthosU55PipelineINT[input_t1](
+        Slice(),
+        test_data(),
+        aten_ops=[],
+        exir_ops=[],
+        per_channel_quantization=per_channel_quantization,
+        use_to_edge_transform_and_lower=True,
+        run_on_fvp=True,
+    )
+
+    pipeline.change_args(
+        "quantize",
+        get_symmetric_a16w8_slice_quantizer(
+            per_channel_quantization=per_channel_quantization
+        ),
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", test_data_suite)
+@common.XfailIfNoCorstone320
+@pytest.mark.xfail(
+    reason="Vela compilation fails with 'Invalid arguments' for int16 slice operations"
+)
+def test_slice_tensor_16a8w_u85_INT16(test_data: torch.Tensor):
+    """Test slice operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
+    per_channel_quantization = False
+
+    pipeline = EthosU85PipelineINT[input_t1](
+        Slice(),
+        test_data(),
+        aten_ops=[],
+        exir_ops=[],
+        per_channel_quantization=per_channel_quantization,
+        use_to_edge_transform_and_lower=True,
+        run_on_fvp=True,
+    )
+
+    pipeline.change_args(
+        "quantize",
+        get_symmetric_a16w8_slice_quantizer(
+            per_channel_quantization=per_channel_quantization
+        ),
+    )
+    pipeline.run()
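The one-line op_slice.py change above is the operator-side enablement:
validate_valid_dtype gates which TOSA dtypes the node visitor accepts, so
adding ts.DType.INT16 to the list is what lets 16-bit activation tensors reach
the slice lowering at all. Schematically, the check behaves like the sketch
below (a simplified stand-in, assuming the real helper raises on a dtype
outside the allow-list; the signature and exception type are illustrative, not
the actual implementation):

    def validate_valid_dtype(target, tensors, valid_dtypes, tosa_spec):
        # Illustrative only: reject any tensor whose dtype is not allow-listed
        # for this operator under the given TOSA specification.
        for tensor in tensors:
            if tensor.dtype not in valid_dtypes:
                raise ValueError(
                    f"{target}: dtype {tensor.dtype} is not supported for "
                    f"{tosa_spec}; expected one of {valid_dtypes}"
                )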
From 8ff1bf1214b3c86ceec93f10dc5089aa9890ecce Mon Sep 17 00:00:00 2001
From: Nitin Jain
Date: Wed, 10 Sep 2025 22:06:38 -0700
Subject: [PATCH 4/6] Add 16A8W support for view and transpose operations

Pull Request resolved: https://github.com/pytorch/executorch/pull/13799

Add 16A8W quantization support for the view and transpose operations in
the ExecuTorch ARM backend. This follows the pattern established for the
linear, mul, sigmoid, tanh, and slice operations, extending int16
support to view and transpose.

Changes:
- Add INT16 dtype validation support in op_transpose.py
- Add test_view_16a8w_tosa_INT test function
- Enable test_view.py in the test targets configuration

The 16A8W configuration uses 16-bit activations with 8-bit weights,
enabling higher precision for activations while maintaining weight
efficiency.

ghstack-source-id: 308986670
@exported-using-ghexport

Differential Revision: [D80511313](https://our.internmc.facebook.com/intern/diff/D80511313/)
---
 backends/arm/test/ops/test_view.py | 114 ++++++++++++++++++++++++++++-
 backends/arm/test/targets.bzl      |   1 +
 2 files changed, 114 insertions(+), 1 deletion(-)

diff --git a/backends/arm/test/ops/test_view.py b/backends/arm/test/ops/test_view.py
index 0aa6f9a0245..fb0ba54436e 100644
--- a/backends/arm/test/ops/test_view.py
+++ b/backends/arm/test/ops/test_view.py
@@ -9,9 +9,14 @@
 
 from typing import Tuple
 
+import pytest
 import torch
+from executorch.backends.arm.quantizer.arm_quantizer import (
+    get_symmetric_a16w8_quantization_config,
+    TOSAQuantizer,
+)
 
-from executorch.backends.arm.test import common
+from executorch.backends.arm.test import common, conftest
 from executorch.backends.arm.test.tester.test_pipeline import (
     EthosU55PipelineINT,
     EthosU85PipelineINT,
@@ -20,6 +25,8 @@
     TosaPipelineINT,
     VgfPipeline,
 )
+from executorch.backends.arm.tosa.specification import TosaSpecification
+from executorch.backends.xnnpack.test.tester import Quantize
 
 aten_op = "torch.ops.aten.view.default"
 
@@ -147,3 +154,108 @@ def test_view_u85_INT(test_data: Tuple):
         exir_ops=[],
     )
     pipeline.run()
+
+
+def get_symmetric_a16w8_view_quantizer(per_channel_quantization=False):
+    tosa_version = conftest.get_option("tosa_version")
+    tosa_profiles = {
+        "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
+    }
+
+    quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
+    quantizer.set_global(
+        get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
+    )
+
+    return Quantize(
+        quantizer,
+        get_symmetric_a16w8_quantization_config(
+            is_per_channel=per_channel_quantization
+        ),
+    )
+
+
+@common.parametrize("test_data", View.needs_transpose_tests)
+@pytest.mark.xfail(
+    reason="missing int16 view ops support; fails at TOSA reference model with Unsupported operation type or rank. See: https://github.com/pytorch/executorch/issues/13977"
+)
+def test_view_16a8w_tosa_INT(test_data: Tuple):
+    """Test view operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
+    per_channel_quantization = False
+    test_tensor, new_shape = test_data()
+
+    pipeline = TosaPipelineINT[input_t1](
+        View(new_shape),
+        (test_tensor,),
+        aten_op,
+        exir_op=[],
+        per_channel_quantization=per_channel_quantization,
+        use_to_edge_transform_and_lower=True,
+        tosa_extensions=["int16"],
+    )
+
+    pipeline.change_args(
+        "quantize",
+        get_symmetric_a16w8_view_quantizer(
+            per_channel_quantization=per_channel_quantization
+        ),
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", View.needs_transpose_tests)
+@common.XfailIfNoCorstone300
+@pytest.mark.xfail(
+    reason="Vela compilation fails with 'Invalid arguments' for int16 view operations"
+)
+def test_view_16a8w_u55_INT16(test_data: Tuple):
+    """Test view operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
+    per_channel_quantization = False
+    test_tensor, new_shape = test_data()
+
+    pipeline = EthosU55PipelineINT[input_t1](
+        View(new_shape),
+        (test_tensor,),
+        aten_op,
+        exir_ops=[],
+        per_channel_quantization=per_channel_quantization,
+        use_to_edge_transform_and_lower=True,
+        run_on_fvp=True,
+    )
+
+    pipeline.change_args(
+        "quantize",
+        get_symmetric_a16w8_view_quantizer(
+            per_channel_quantization=per_channel_quantization
+        ),
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", View.needs_transpose_tests)
+@common.XfailIfNoCorstone320
+@pytest.mark.xfail(
+    reason="Vela compilation fails with 'Invalid arguments' for int16 view operations"
+)
+def test_view_16a8w_u85_INT16(test_data: Tuple):
+    """Test view operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
+    per_channel_quantization = False
+    test_tensor, new_shape = test_data()
+
+    pipeline = EthosU85PipelineINT[input_t1](
+        View(new_shape),
+        (test_tensor,),
+        aten_op,
+        exir_ops=[],
+        per_channel_quantization=per_channel_quantization,
+        use_to_edge_transform_and_lower=True,
+        run_on_fvp=True,
+    )
+
+    pipeline.change_args(
+        "quantize",
+        get_symmetric_a16w8_view_quantizer(
+            per_channel_quantization=per_channel_quantization
+        ),
+    )
+    pipeline.run()
diff --git a/backends/arm/test/targets.bzl b/backends/arm/test/targets.bzl
index b438e556cca..5714039d134 100644
--- a/backends/arm/test/targets.bzl
+++ b/backends/arm/test/targets.bzl
@@ -20,6 +20,7 @@ def define_arm_tests():
         "ops/test_slice.py",
         "ops/test_sigmoid.py",
         "ops/test_tanh.py",
+        "ops/test_view.py",
         "ops/test_cos.py",
     ]
 
From 05b315539fbfe0b3b62cab198f07d02d3e050c7c Mon Sep 17 00:00:00 2001
From: Nitin Jain
Date: Wed, 10 Sep 2025 22:06:40 -0700
Subject: [PATCH 5/6] Add 16A8W support and test for cat operation

Pull Request resolved: https://github.com/pytorch/executorch/pull/13800

Add 16A8W quantization support and a test for the cat operation in the
ExecuTorch ARM backend. This follows the pattern established for the
linear, mul, sigmoid, tanh, slice, and view/transpose operations,
extending int16 support to cat.

Changes:
- Add test_cat_16a8w_tosa_INT test function
- Enable test_cat.py in the test targets configuration

The 16A8W configuration uses 16-bit activations with 8-bit weights,
enabling higher precision for activations while maintaining weight
efficiency.

ghstack-source-id: 308986669
@exported-using-ghexport

Differential Revision: [D80511455](https://our.internmc.facebook.com/intern/diff/D80511455/)
---
 backends/arm/test/ops/test_cat.py | 111 +++++++++++++++++++++++++++++-
 backends/arm/test/targets.bzl     |   1 +
 2 files changed, 111 insertions(+), 1 deletion(-)

diff --git a/backends/arm/test/ops/test_cat.py b/backends/arm/test/ops/test_cat.py
index 55578aa15c6..84ecd8641b5 100644
--- a/backends/arm/test/ops/test_cat.py
+++ b/backends/arm/test/ops/test_cat.py
@@ -8,8 +8,13 @@
 
 from typing import Tuple
 
+import pytest
 import torch
-from executorch.backends.arm.test import common
+from executorch.backends.arm.quantizer.arm_quantizer import (
+    get_symmetric_a16w8_quantization_config,
+    TOSAQuantizer,
+)
+from executorch.backends.arm.test import common, conftest
 
 from executorch.backends.arm.test.tester.test_pipeline import (
     EthosU55PipelineINT,
@@ -18,6 +23,8 @@
     TosaPipelineINT,
     VgfPipeline,
 )
+from executorch.backends.arm.tosa.specification import TosaSpecification
+from executorch.backends.xnnpack.test.tester import Quantize
 
 input_t1 = Tuple[torch.Tensor]  # Input x
 
@@ -151,3 +158,105 @@ def test_cat_vgf_INT(test_data: Tuple):
         tosa_version="TOSA-1.0+INT",
     )
     pipeline.run()
+
+
+def get_symmetric_a16w8_cat_quantizer(per_channel_quantization=False):
+    tosa_version = conftest.get_option("tosa_version")
+    tosa_profiles = {
+        "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
+    }
+
+    quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
+    quantizer.set_global(
+        get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
+    )
+
+    return Quantize(
+        quantizer,
+        get_symmetric_a16w8_quantization_config(
+            is_per_channel=per_channel_quantization
+        ),
+    )
+
+
+@common.parametrize("test_data", Cat.test_parameters)
+@pytest.mark.xfail(
+    reason="missing int16 cat ops support; fails at TOSA reference model with Unsupported operation type or rank. See: https://github.com/pytorch/executorch/issues/13978"
+)
+def test_cat_16a8w_tosa_INT(test_data: Tuple):
+    """Test cat operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
+    per_channel_quantization = False
+
+    pipeline = TosaPipelineINT[input_t1](
+        Cat(),
+        test_data(),
+        aten_op,
+        exir_op=[],
+        per_channel_quantization=per_channel_quantization,
+        use_to_edge_transform_and_lower=True,
+        tosa_extensions=["int16"],
+    )
+
+    pipeline.change_args(
+        "quantize",
+        get_symmetric_a16w8_cat_quantizer(
+            per_channel_quantization=per_channel_quantization
+        ),
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", Cat.test_parameters)
+@common.XfailIfNoCorstone300
+@pytest.mark.xfail(
+    reason="Vela compilation fails with 'Invalid arguments' for int16 cat operations"
+)
+def test_cat_16a8w_u55_INT16(test_data: Tuple):
+    """Test cat operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
+    per_channel_quantization = False
+
+    pipeline = EthosU55PipelineINT[input_t1](
+        Cat(),
+        test_data(),
+        aten_op,
+        exir_op,
+        per_channel_quantization=per_channel_quantization,
+        use_to_edge_transform_and_lower=True,
+        run_on_fvp=True,
+    )
+
+    pipeline.change_args(
+        "quantize",
+        get_symmetric_a16w8_cat_quantizer(
+            per_channel_quantization=per_channel_quantization
+        ),
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", Cat.test_parameters)
+@common.XfailIfNoCorstone320
+@pytest.mark.xfail(
+    reason="Vela compilation fails with 'Invalid arguments' for int16 cat operations"
+)
+def test_cat_16a8w_u85_INT16(test_data: Tuple):
+    """Test cat operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
+    per_channel_quantization = False
+
+    pipeline = EthosU85PipelineINT[input_t1](
+        Cat(),
+        test_data(),
+        aten_op,
+        exir_op,
+        per_channel_quantization=per_channel_quantization,
+        use_to_edge_transform_and_lower=True,
+        run_on_fvp=True,
+    )
+
+    pipeline.change_args(
+        "quantize",
+        get_symmetric_a16w8_cat_quantizer(
+            per_channel_quantization=per_channel_quantization
+        ),
+    )
+    pipeline.run()
diff --git a/backends/arm/test/targets.bzl b/backends/arm/test/targets.bzl
index 5714039d134..68223eff3ee 100644
--- a/backends/arm/test/targets.bzl
+++ b/backends/arm/test/targets.bzl
@@ -15,6 +15,7 @@ def define_arm_tests():
     test_files += [
         "ops/test_add.py",
         "ops/test_avg_pool2d.py",
+        "ops/test_cat.py",
         "ops/test_linear.py",
         "ops/test_mul.py",
         "ops/test_slice.py",
From 64191d439e03bdadbaed817001d025e17d1d829a Mon Sep 17 00:00:00 2001
From: Nitin Jain
Date: Wed, 10 Sep 2025 22:06:44 -0700
Subject: [PATCH 6/6] Add 16A8W FCNode support with BMM dependency fix

Pull Request resolved: https://github.com/pytorch/executorch/pull/13801

Add 16A8W quantization support for FCNode operations, together with a
BMM dependency fix, in the ExecuTorch ARM backend. This follows the
pattern established for the linear, mul, sigmoid, tanh, slice,
view/transpose, and cat operations, extending int16 support to FCNode
operations.

Changes:
- Add INT16 dtype validation support in op_bmm.py
- Add test_addmm_16a8w_tosa_INT test function
- Enable test_addmm.py in the test targets configuration
- Fix the BMM dependency for FCNode operations

The 16A8W configuration uses 16-bit activations with 8-bit weights,
enabling higher precision for activations while maintaining weight
efficiency.

ghstack-source-id: 308986674
@exported-using-ghexport

Differential Revision: [D80512504](https://our.internmc.facebook.com/intern/diff/D80512504/)
---
 backends/arm/operators/op_bmm.py    |   2 +-
 backends/arm/test/ops/test_addmm.py | 111 +++++++++++++++++++++++++++-
 backends/arm/test/targets.bzl       |   1 +
 3 files changed, 112 insertions(+), 2 deletions(-)

diff --git a/backends/arm/operators/op_bmm.py b/backends/arm/operators/op_bmm.py
index 81a4df808c3..382386ffa26 100644
--- a/backends/arm/operators/op_bmm.py
+++ b/backends/arm/operators/op_bmm.py
@@ -55,7 +55,7 @@ def define_node(
         validate_valid_dtype(
             self.target,
             [*inputs, output],
-            [ts.DType.INT8, ts.DType.FP32],
+            [ts.DType.INT8, ts.DType.INT16, ts.DType.FP32],
             output.tosa_spec,
         )
 
diff --git a/backends/arm/test/ops/test_addmm.py b/backends/arm/test/ops/test_addmm.py
index cfe324ab0af..753cb599b2b 100644
--- a/backends/arm/test/ops/test_addmm.py
+++ b/backends/arm/test/ops/test_addmm.py
@@ -5,9 +5,14 @@
 
 from typing import Tuple
 
+import pytest
 import torch
+from executorch.backends.arm.quantizer.arm_quantizer import (
+    get_symmetric_a16w8_quantization_config,
+    TOSAQuantizer,
+)
 
-from executorch.backends.arm.test import common
+from executorch.backends.arm.test import common, conftest
 from executorch.backends.arm.test.tester.test_pipeline import (
     EthosU55PipelineINT,
     EthosU85PipelineINT,
@@ -15,6 +20,8 @@
     TosaPipelineINT,
     VgfPipeline,
 )
+from executorch.backends.arm.tosa.specification import TosaSpecification
+from executorch.backends.xnnpack.test.tester import Quantize
 
 aten_op = "torch.ops.aten.addmm.default"
 
@@ -182,3 +189,105 @@ def test_addmm_vgf_INT(test_data: input_t1):
         tosa_version="TOSA-1.0+INT",
     )
     pipeline.run()
+
+
+def get_symmetric_a16w8_addmm_quantizer(per_channel_quantization=False):
+    tosa_version = conftest.get_option("tosa_version")
+    tosa_profiles = {
+        "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
+    }
+
+    quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
+    quantizer.set_global(
+        get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
+    )
+
+    return Quantize(
+        quantizer,
+        get_symmetric_a16w8_quantization_config(
+            is_per_channel=per_channel_quantization
+        ),
+    )
+
+
+@common.parametrize("test_data", test_data_suite)
+@pytest.mark.xfail(
+    reason="missing int16 addmm ops support; fails at TOSA reference model with Unsupported operation type or rank. See: https://github.com/pytorch/executorch/issues/13979"
+)
+def test_addmm_16a8w_tosa_INT(test_data: input_t1):
+    """Test addmm (FC layer) operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
+    per_channel_quantization = False
+
+    pipeline = TosaPipelineINT[input_t1](
+        Addmm(),
+        (*test_data,),
+        aten_op=[],
+        exir_op=[],
+        per_channel_quantization=per_channel_quantization,
+        use_to_edge_transform_and_lower=True,
+        tosa_extensions=["int16"],
+    )
+
+    pipeline.change_args(
+        "quantize",
+        get_symmetric_a16w8_addmm_quantizer(
+            per_channel_quantization=per_channel_quantization
+        ),
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", test_data_suite)
+@common.XfailIfNoCorstone300
+@pytest.mark.xfail(
+    reason="Vela compilation fails with 'Invalid arguments' for int16 addmm operations"
+)
+def test_addmm_16a8w_u55_INT16(test_data: input_t1):
+    """Test addmm (FC layer) operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
+    per_channel_quantization = False
+
+    pipeline = EthosU55PipelineINT[input_t1](
+        Addmm(),
+        (*test_data,),
+        aten_ops=[],
+        exir_ops=[],
+        per_channel_quantization=per_channel_quantization,
+        use_to_edge_transform_and_lower=True,
+        run_on_fvp=True,
+    )
+
+    pipeline.change_args(
+        "quantize",
+        get_symmetric_a16w8_addmm_quantizer(
+            per_channel_quantization=per_channel_quantization
+        ),
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", test_data_suite)
+@common.XfailIfNoCorstone320
+@pytest.mark.xfail(
+    reason="Vela compilation fails with 'Invalid arguments' for int16 addmm operations"
+)
+def test_addmm_16a8w_u85_INT16(test_data: input_t1):
+    """Test addmm (FC layer) operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
+    per_channel_quantization = False
+
+    pipeline = EthosU85PipelineINT[input_t1](
+        Addmm(),
+        (*test_data,),
+        aten_ops=[],
+        exir_ops=[],
+        per_channel_quantization=per_channel_quantization,
+        use_to_edge_transform_and_lower=True,
+        run_on_fvp=True,
+    )
+
+    pipeline.change_args(
+        "quantize",
+        get_symmetric_a16w8_addmm_quantizer(
+            per_channel_quantization=per_channel_quantization
+        ),
+    )
+    pipeline.run()
diff --git a/backends/arm/test/targets.bzl b/backends/arm/test/targets.bzl
index 68223eff3ee..a6181cf34ce 100644
--- a/backends/arm/test/targets.bzl
+++ b/backends/arm/test/targets.bzl
@@ -14,6 +14,7 @@ def define_arm_tests():
     # Operators
     test_files += [
         "ops/test_add.py",
+        "ops/test_addmm.py",
        "ops/test_avg_pool2d.py",
         "ops/test_cat.py",
         "ops/test_linear.py",
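For intuition on what the 16A8W scheme buys: symmetric int16 quantization
spends 65,536 levels on the activation range instead of int8's 256, so for the
same tensor range the worst-case rounding error shrinks by roughly a factor of
256, while weights stay at 8 bits. A back-of-the-envelope illustration in
plain PyTorch (not the backend's actual observer math):

    import torch

    x = torch.randn(1024) * 3.0  # stand-in activation tensor
    amax = x.abs().max()
    scale16 = amax / 32767.0  # symmetric int16 range [-32768, 32767]
    scale8 = amax / 127.0  # symmetric int8 range [-128, 127]
    q16 = (x / scale16).round().clamp(-32768, 32767) * scale16
    q8 = (x / scale8).round().clamp(-128, 127) * scale8
    # Worst-case reconstruction error is about half a quantization step:
    # ~scale16/2 at 16 bits vs ~scale8/2 at 8 bits, i.e. roughly 256x smaller.
    print((x - q16).abs().max(), (x - q8).abs().max())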