diff --git a/backends/arm/operators/op_add.py b/backends/arm/operators/op_add.py
index a571806caf5..a8f0c3fe14d 100644
--- a/backends/arm/operators/op_add.py
+++ b/backends/arm/operators/op_add.py
@@ -47,10 +47,16 @@ def define_node(
         validate_num_inputs(self.target, inputs, 2)
         validate_same_dtype(self.target, [*inputs, output], ts)
 
+        valid_dtypes = []
+        if self.tosa_spec.support_integer():
+            valid_dtypes.extend([ts.DType.INT8, ts.DType.INT16, ts.DType.INT32])
+        if self.tosa_spec.support_float():
+            valid_dtypes.extend([ts.DType.INT32])
+
         validate_valid_dtype(
             self.target,
             [*inputs, output],
-            [ts.DType.INT8, ts.DType.INT32],
+            valid_dtypes,
             output.tosa_spec,
         )
         scale_back = 1.0
@@ -59,7 +65,7 @@
                 tosa_graph, inputs, node, self.tosa_spec
             )
         else:
-            # input[0].dtype == ts.DType.INT32
+            # input[0].dtype == ts.DType.INT16 or ts.DType.INT32
             # Non quantized input, natively support by TOSA.ADD
             rescaled_inputs = inputs
 
@@ -67,7 +73,7 @@
             broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order)
             add_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
         else:
-            # output.dtype == ts.DType.INT32
+            # output.dtype == ts.DType.INT16 or ts.DType.INT32
             add_output = output
 
         input1, input2 = rescaled_inputs
@@ -117,7 +123,7 @@ def define_node(
         validate_num_inputs(self.target, inputs, 2)
         validate_same_dtype(self.target, [*inputs, output], ts)
 
-        if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
+        if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32]:
             # Call the inherited define_node for handling integers
             super().define_node(node, tosa_graph, inputs, output)
         else:
diff --git a/backends/arm/test/ops/test_add.py b/backends/arm/test/ops/test_add.py
index 8376df47b39..9d15cea815c 100644
--- a/backends/arm/test/ops/test_add.py
+++ b/backends/arm/test/ops/test_add.py
@@ -10,6 +10,12 @@
 import pytest
 import torch
 from executorch.backends.arm.quantizer import arm_quantizer
+from executorch.backends.arm.quantizer.arm_quantizer import (
+    get_symmetric_a16w8_quantization_config,
+    TOSAQuantizer,
+)
+from executorch.backends.arm.tosa_specification import TosaSpecification
+from executorch.backends.xnnpack.test.tester import Quantize
 from executorch.backends.arm.test import common, conftest
 from executorch.backends.arm.test.tester.test_pipeline import (
     EthosU55PipelineINT,
@@ -235,3 +239,105 @@ def test_add_tensor_vgf_INT(test_data: input_t1):
         pipeline.run()
     except FileNotFoundError as e:
         pytest.skip(f"VKML executor_runner not found - not built - skip {e}")
+
+
+def get_symmetric_a16w8_add_quantizer(per_channel_quantization=False):
+    tosa_version = conftest.get_option("tosa_version")
+    tosa_profiles = {
+        "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
+    }
+
+    quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
+    quantizer.set_global(
+        get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
+    )
+
+    return Quantize(
+        quantizer,
+        get_symmetric_a16w8_quantization_config(
+            is_per_channel=per_channel_quantization
+        ),
+    )
+
+
+@common.parametrize("test_data", Add.test_data)
+@pytest.mark.xfail(
+    reason="missing int16 add ops support; fails at TOSA reference model with Unsupported operation type or rank. See: https://github.com/pytorch/executorch/issues/13730"
+)
+def test_add_tensor_16a8w_tosa_INT(test_data: input_t1):
+    """Test add operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
+    per_channel_quantization = False
+
+    pipeline = TosaPipelineINT[input_t1](
+        Add(),
+        test_data(),
+        aten_op,
+        exir_op=[],
+        per_channel_quantization=per_channel_quantization,
+        use_to_edge_transform_and_lower=True,
+        tosa_extensions=["int16"],
+    )
+
+    pipeline.change_args(
+        "quantize",
+        get_symmetric_a16w8_add_quantizer(
+            per_channel_quantization=per_channel_quantization
+        ),
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", Add.test_data)
+@common.XfailIfNoCorstone300
+@pytest.mark.xfail(
+    reason="Vela compilation fails with 'Invalid arguments' for int16 add operations. See: https://github.com/pytorch/executorch/issues/13730"
+)
+def test_add_tensor_16a8w_u55_INT16(test_data: input_t1):
+    """Test add operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
+    per_channel_quantization = False
+
+    pipeline = EthosU55PipelineINT[input_t1](
+        Add(),
+        test_data(),
+        aten_op,
+        exir_op,
+        per_channel_quantization=per_channel_quantization,
+        use_to_edge_transform_and_lower=True,
+        run_on_fvp=True,
+    )
+
+    pipeline.change_args(
+        "quantize",
+        get_symmetric_a16w8_add_quantizer(
+            per_channel_quantization=per_channel_quantization
+        ),
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", Add.test_data)
+@common.XfailIfNoCorstone320
+@pytest.mark.xfail(
+    reason="Vela compilation fails with 'Invalid arguments' for int16 add operations. See: https://github.com/pytorch/executorch/issues/13730"
+)
+def test_add_tensor_16a8w_u85_INT16(test_data: input_t1):
+    """Test add operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
+    per_channel_quantization = False
+
+    pipeline = EthosU85PipelineINT[input_t1](
+        Add(),
+        test_data(),
+        aten_op,
+        exir_op,
+        per_channel_quantization=per_channel_quantization,
+        use_to_edge_transform_and_lower=True,
+        run_on_fvp=True,
+    )
+
+    pipeline.change_args(
+        "quantize",
+        get_symmetric_a16w8_add_quantizer(
+            per_channel_quantization=per_channel_quantization
+        ),
+    )
+    pipeline.run()
diff --git a/backends/arm/test/targets.bzl b/backends/arm/test/targets.bzl
index acb27f13798..405f1bbf081 100644
--- a/backends/arm/test/targets.bzl
+++ b/backends/arm/test/targets.bzl
@@ -13,6 +13,7 @@ def define_arm_tests():
 
     # Operators
     test_files += [
+        "ops/test_add.py",
        "ops/test_avg_pool2d.py",
        "ops/test_linear.py",
        "ops/test_slice.py",
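A note on the op_add.py change: the accepted dtypes are now derived from the active TOSA profile rather than hard-coded. Integer profiles admit INT8/INT16/INT32, and INT32 stays valid when only the FP profile is present, since TOSA 1.0 supports ADD on int32 under both profiles. A toy sketch of that gating follows; the function name and boolean parameters are hypothetical stand-ins for tosa_spec.support_integer() / tosa_spec.support_float():

    # Toy illustration of the profile-gated dtype validation in op_add.py.
    # Names here are hypothetical stand-ins, not the backend's real API.
    def valid_add_dtypes(support_integer: bool, support_float: bool) -> list[str]:
        valid: list[str] = []
        if support_integer:
            # INT profile: quantized int8/int16 activations plus native int32 adds.
            valid.extend(["INT8", "INT16", "INT32"])
        if support_float:
            # FP profile: int32 ADD remains natively supported.
            valid.extend(["INT32"])
        return valid

    print(valid_add_dtypes(True, False))  # ['INT8', 'INT16', 'INT32']
    print(valid_add_dtypes(False, True))  # ['INT32']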
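For readers unfamiliar with the test helper, here is a minimal sketch of how the 16A8W quantizer is assembled, mirroring get_symmetric_a16w8_add_quantizer in the patch. The import path for TosaSpecification is an assumption (executorch.backends.arm.tosa_specification); the spec string and calls are taken from the patch itself:

    # Sketch: building a 16A8W (16-bit activation, 8-bit weight) TOSA quantizer.
    # Import paths are assumptions based on the patch, not verified here.
    from executorch.backends.arm.quantizer.arm_quantizer import (
        get_symmetric_a16w8_quantization_config,
        TOSAQuantizer,
    )
    from executorch.backends.arm.tosa_specification import TosaSpecification

    # TOSA 1.0 INT profile with the int16 extension enabled, as in the tests.
    spec = TosaSpecification.create_from_string("TOSA-1.0+INT+int16")
    quantizer = TOSAQuantizer(spec)
    # Per-tensor symmetric quantization: int16 activations, int8 weights.
    quantizer.set_global(get_symmetric_a16w8_quantization_config(is_per_channel=False))

The new tests then substitute this quantizer into the standard pipeline via pipeline.change_args("quantize", ...), leaving the rest of the lowering flow unchanged.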