From 919c918e21f16b314207fadc0c0223817d8f2f59 Mon Sep 17 00:00:00 2001
From: Eli Amesefe
Date: Tue, 30 Sep 2025 11:46:25 -0700
Subject: [PATCH] Rescale sub int16 correctly (#14650)

Summary:
For int16 inputs, we rescale to int32 using a different common scale,
2 * max(left scale, right scale). The rescale(s) -> sub[int32] -> rescale
sequence from TOSA is converted into a single SUB command stream instruction
by Vela/Regor. (A minimal sketch of the rescale arithmetic follows the patch.)

bypass-github-export-checks
bypass-github-pytorch-ci-checks
bypass-github-executorch-ci-checks

Reviewed By: digantdesai

Differential Revision: D83437623
---
 backends/arm/operators/op_sub.py  |  19 +++++-
 backends/arm/test/ops/test_sub.py | 101 +++++++++++++++++++++++++++++-
 backends/arm/test/targets.bzl     |   1 +
 3 files changed, 118 insertions(+), 3 deletions(-)

diff --git a/backends/arm/operators/op_sub.py b/backends/arm/operators/op_sub.py
index 9c27fddf68a..5f037dc3d1c 100644
--- a/backends/arm/operators/op_sub.py
+++ b/backends/arm/operators/op_sub.py
@@ -50,7 +50,7 @@ def define_node(
         validate_valid_dtype(
             self.target,
             [*inputs, output],
-            [ts.DType.INT8, ts.DType.INT32],
+            [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32],
             output.tosa_spec,
         )
 
@@ -59,12 +59,18 @@
             rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32_maxscale(
                 tosa_graph, inputs, node, self.tosa_spec
             )
+        elif inputs[0].dtype == ts.DType.INT16:
+            rescaled_inputs, scale_back = (
+                tqutils.insert_rescale_ops_int16_to_int32_maxscale(
+                    tosa_graph, inputs, node, self.tosa_spec
+                )
+            )
         else:
             # input[0].dtype == ts.DType.INT32
             # Non quantized input, natively supported by TOSA.SUB
             rescaled_inputs = inputs
 
-        if output.dtype == ts.DType.INT8:
+        if output.dtype in [ts.DType.INT8, ts.DType.INT16]:
             broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order)
             sub_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
         else:
@@ -95,6 +101,15 @@
             compute_rescale=False,
             tosa_spec=self.tosa_spec,
         )  # type: ignore[possibly-undefined]
+        elif output.dtype == ts.DType.INT16:
+            tqutils.insert_rescale_op_to_int16(
+                tosa_graph,
+                sub_output,
+                scale_back,
+                node,
+                compute_rescale=False,
+                tosa_spec=self.tosa_spec,
+            )  # type: ignore[possibly-undefined]
 
 
 @register_node_visitor
diff --git a/backends/arm/test/ops/test_sub.py b/backends/arm/test/ops/test_sub.py
index c691506beb2..7a010f0daf2 100644
--- a/backends/arm/test/ops/test_sub.py
+++ b/backends/arm/test/ops/test_sub.py
@@ -10,8 +10,12 @@
 from typing import Tuple
 
 import torch
+from executorch.backends.arm.quantizer.arm_quantizer import (
+    get_symmetric_a16w8_quantization_config,
+    TOSAQuantizer,
+)
 
-from executorch.backends.arm.test import common
+from executorch.backends.arm.test import common, conftest
 from executorch.backends.arm.test.tester.test_pipeline import (
     EthosU55PipelineINT,
     EthosU85PipelineINT,
@@ -19,6 +23,8 @@
     TosaPipelineINT,
     VgfPipeline,
 )
+from executorch.backends.arm.tosa import TosaSpecification
+from executorch.backends.xnnpack.test.tester import Quantize
 
 aten_op = "torch.ops.aten.sub.Tensor"
 exir_op = "executorch_exir_dialects_edge__ops_aten_sub_Tensor"
@@ -242,3 +248,96 @@
         tosa_version="TOSA-1.0+INT",
     )
     pipeline.run()
+
+
+def get_symmetric_a16w8_sub_quantizer(per_channel_quantization=False):
+    tosa_version = conftest.get_option("tosa_version")
+    tosa_profiles = {
+        "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
+    }
+
+    quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
+    quantizer.set_global(
+        get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
+    )
+
+    return Quantize(
+        quantizer,
+        get_symmetric_a16w8_quantization_config(
+            is_per_channel=per_channel_quantization
+        ),
+    )
+
+
+@common.parametrize("test_data", sub_test_data)
+def test_sub_tensor_16a8w_tosa_INT(test_data: input_t1):
+    """Test sub operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
+    per_channel_quantization = False
+
+    pipeline = TosaPipelineINT[input_t1](
+        Sub(),
+        test_data(),
+        aten_op,
+        exir_op=[],
+        per_channel_quantization=per_channel_quantization,
+        use_to_edge_transform_and_lower=True,
+        tosa_extensions=["int16"],
+    )
+
+    pipeline.change_args(
+        "quantize",
+        get_symmetric_a16w8_sub_quantizer(
+            per_channel_quantization=per_channel_quantization
+        ),
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", sub_test_data)
+@common.XfailIfNoCorstone300
+def test_sub_tensor_16a8w_u55_INT16(test_data: input_t1):
+    """Test sub operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
+    per_channel_quantization = False
+
+    pipeline = EthosU55PipelineINT[input_t1](
+        Sub(),
+        test_data(),
+        aten_op,
+        exir_op,
+        per_channel_quantization=per_channel_quantization,
+        use_to_edge_transform_and_lower=True,
+        run_on_fvp=True,
+    )
+
+    pipeline.change_args(
+        "quantize",
+        get_symmetric_a16w8_sub_quantizer(
+            per_channel_quantization=per_channel_quantization
+        ),
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", sub_test_data)
+@common.XfailIfNoCorstone320
+def test_sub_tensor_16a8w_u85_INT16(test_data: input_t1):
+    """Test sub operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
+    per_channel_quantization = False
+
+    pipeline = EthosU85PipelineINT[input_t1](
+        Sub(),
+        test_data(),
+        aten_op,
+        exir_op,
+        per_channel_quantization=per_channel_quantization,
+        use_to_edge_transform_and_lower=True,
+        run_on_fvp=True,
+    )
+
+    pipeline.change_args(
+        "quantize",
+        get_symmetric_a16w8_sub_quantizer(
+            per_channel_quantization=per_channel_quantization
+        ),
+    )
+    pipeline.run()
diff --git a/backends/arm/test/targets.bzl b/backends/arm/test/targets.bzl
index 093268edef7..5fdd1c3d827 100644
--- a/backends/arm/test/targets.bzl
+++ b/backends/arm/test/targets.bzl
@@ -22,6 +22,7 @@ def define_arm_tests():
         "ops/test_mul.py",
         "ops/test_slice.py",
         "ops/test_sigmoid.py",
+        "ops/test_sub.py",
         "ops/test_tanh.py",
         "ops/test_view.py",
         "ops/test_cos.py",
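
Note (illustrative sketch, not part of the patch): the NumPy snippet below
simulates the rescale arithmetic described in the summary, assuming symmetric
per-tensor quantization (real value = q * scale). The function name, the
floating-point simulation, and the SHIFT amount are assumptions of this
sketch; the actual tqutils helpers emit TOSA RESCALE operators with integer
multiplier/shift pairs instead.

import numpy as np

SHIFT = 12  # hypothetical extra precision bits kept in the int32 intermediate


def sub_int16_via_int32(x_q, x_scale, y_q, y_scale, out_scale):
    # Common scale: 2 * max(left scale, right scale). The factor of 2 halves
    # each rescaled operand's magnitude, so their difference still fits the
    # same range and the int32 subtraction cannot overflow.
    common_scale = 2.0 * max(x_scale, y_scale)

    # First RESCALE pair: bring each int16 operand into the shared int32
    # domain, keeping SHIFT extra fractional bits of precision.
    x_i32 = np.round(
        x_q.astype(np.int64) * (x_scale / common_scale) * (1 << SHIFT)
    ).astype(np.int64)
    y_i32 = np.round(
        y_q.astype(np.int64) * (y_scale / common_scale) * (1 << SHIFT)
    ).astype(np.int64)

    # sub[int32]: TOSA.SUB supports int32 natively.
    diff_i32 = x_i32 - y_i32

    # Final RESCALE: back to the int16 output scale, with saturation.
    out = np.round(diff_i32 * common_scale / ((1 << SHIFT) * out_scale))
    return np.clip(out, -32768, 32767).astype(np.int16)


# Example: 1.0 - 0.25 with scale 1/1024 on both inputs and the output.
s = 1.0 / 1024
x_q = np.array([1024], dtype=np.int16)
y_q = np.array([256], dtype=np.int16)
print(sub_int16_via_int32(x_q, s, y_q, s, out_scale=s))  # [768], i.e. 0.75

On hardware this chain never materializes as separate ops: as noted in the
summary, Vela/Regor converts rescale(s) -> sub[int32] -> rescale into a
single SUB command stream instruction.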