From e8418d4a99cbe4edaaa27f043e59fb9668ba5416 Mon Sep 17 00:00:00 2001 From: Nitin Jain Date: Thu, 28 Aug 2025 23:43:02 -0700 Subject: [PATCH] Add 16A8W FCNode support with BMM dependency fix Add 16A8W quantization support for FCNode operations with BMM dependency fix in ExecutorTorch ARM backend. This follows the pattern established for linear, mul, sigmoid, tanh, slice, view/transpose, and cat operations, extending int16 support to FCNode operations. Changes: - Add INT16 dtype validation support in op_bmm.py - Add test_addmm_tensor_16a8w_tosa_INT test function - Enable test_addmm.py in test targets configuration - Fix BMM dependency for FCNode operations The 16A8W configuration uses 16-bit activations with 8-bit weights, enabling higher precision for activations while maintaining weight efficiency. Differential Revision: [D80512504](https://our.internmc.facebook.com/intern/diff/D80512504/) [ghstack-poisoned] --- backends/arm/operators/op_bmm.py | 5 +- backends/arm/test/ops/test_addmm.py | 108 +++++++++++++++++++++++++++- backends/arm/test/targets.bzl | 1 + 3 files changed, 111 insertions(+), 3 deletions(-) diff --git a/backends/arm/operators/op_bmm.py b/backends/arm/operators/op_bmm.py index c9bb0b003ee..62c795653aa 100644 --- a/backends/arm/operators/op_bmm.py +++ b/backends/arm/operators/op_bmm.py @@ -55,7 +55,7 @@ def define_node( validate_valid_dtype( self.target, [*inputs, output], - [ts.DType.INT8, ts.DType.FP32], + [ts.DType.INT8, ts.DType.INT16, ts.DType.FP32], output.tosa_spec, ) @@ -93,7 +93,8 @@ def define_node( if output.dtype == ts.DType.INT8: output_qparams = get_output_qparams(node)[0] final_output_scale = ( - input_qparams[0].get_scale_per_tensor() * input_qparams[1].get_scale_per_tensor() # type: ignore[possibly-undefined] # pyre-ignore[61] + input_qparams[0].get_scale_per_tensor() + * input_qparams[1].get_scale_per_tensor() # type: ignore[possibly-undefined] # pyre-ignore[61] ) / output_qparams.get_scale_per_tensor() build_rescale( diff --git a/backends/arm/test/ops/test_addmm.py b/backends/arm/test/ops/test_addmm.py index cfe324ab0af..2f99ba14205 100644 --- a/backends/arm/test/ops/test_addmm.py +++ b/backends/arm/test/ops/test_addmm.py @@ -5,9 +5,14 @@ from typing import Tuple +import pytest import torch +from executorch.backends.arm.quantizer.arm_quantizer import ( + get_symmetric_a16w8_quantization_config, + TOSAQuantizer, +) -from executorch.backends.arm.test import common +from executorch.backends.arm.test import common, conftest from executorch.backends.arm.test.tester.test_pipeline import ( EthosU55PipelineINT, EthosU85PipelineINT, @@ -15,6 +20,8 @@ TosaPipelineINT, VgfPipeline, ) +from executorch.backends.arm.tosa_specification import TosaSpecification +from executorch.backends.xnnpack.test.tester import Quantize aten_op = "torch.ops.aten.addmm.default" @@ -182,3 +189,102 @@ def test_addmm_vgf_INT(test_data: input_t1): tosa_version="TOSA-1.0+INT", ) pipeline.run() + + +def get_symmetric_a16w8_addmm_quantizer(per_channel_quantization=False): + tosa_version = conftest.get_option("tosa_version") + tosa_profiles = { + "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"), + } + + quantizer = TOSAQuantizer(tosa_profiles[tosa_version]) + quantizer.set_global( + get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization) + ) + + return Quantize( + quantizer, + get_symmetric_a16w8_quantization_config( + is_per_channel=per_channel_quantization + ), + ) + + +@common.parametrize("test_data", test_data_suite) +def test_addmm_16a8w_tosa_INT(test_data: input_t1): + """Test addmm (FC layer) operation with 16A8W quantization (16-bit activations, 8-bit weights)""" + per_channel_quantization = False + + pipeline = TosaPipelineINT[input_t1]( + Addmm(), + (*test_data,), + aten_op=[], + exir_op=[], + per_channel_quantization=per_channel_quantization, + use_to_edge_transform_and_lower=True, + tosa_extensions=["int16"], + ) + + pipeline.change_args( + "quantize", + get_symmetric_a16w8_addmm_quantizer( + per_channel_quantization=per_channel_quantization + ), + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +@common.XfailIfNoCorstone300 +@pytest.mark.xfail( + reason="Vela compilation fails with 'Invalid arguments' for int16 addmm operations" +) +def test_addmm_16a8w_u55_INT16(test_data: input_t1): + """Test addmm (FC layer) operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)""" + per_channel_quantization = False + + pipeline = EthosU55PipelineINT[input_t1]( + Addmm(), + (*test_data,), + aten_ops=[], + exir_ops=[], + per_channel_quantization=per_channel_quantization, + use_to_edge_transform_and_lower=True, + run_on_fvp=True, + ) + + pipeline.change_args( + "quantize", + get_symmetric_a16w8_addmm_quantizer( + per_channel_quantization=per_channel_quantization + ), + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +@common.XfailIfNoCorstone320 +@pytest.mark.xfail( + reason="Vela compilation fails with 'Invalid arguments' for int16 addmm operations" +) +def test_addmm_16a8w_u85_INT16(test_data: input_t1): + """Test addmm (FC layer) operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" + per_channel_quantization = False + + pipeline = EthosU85PipelineINT[input_t1]( + Addmm(), + (*test_data,), + aten_ops=[], + exir_ops=[], + per_channel_quantization=per_channel_quantization, + use_to_edge_transform_and_lower=True, + run_on_fvp=True, + ) + + pipeline.change_args( + "quantize", + get_symmetric_a16w8_addmm_quantizer( + per_channel_quantization=per_channel_quantization + ), + ) + pipeline.run() diff --git a/backends/arm/test/targets.bzl b/backends/arm/test/targets.bzl index 68223eff3ee..a6181cf34ce 100644 --- a/backends/arm/test/targets.bzl +++ b/backends/arm/test/targets.bzl @@ -14,6 +14,7 @@ def define_arm_tests(): # Operators test_files += [ "ops/test_add.py", + "ops/test_addmm.py", "ops/test_avg_pool2d.py", "ops/test_cat.py", "ops/test_linear.py",