From a84b613b33ae6afe63f3f72678f7c7ef97bad9c4 Mon Sep 17 00:00:00 2001 From: Nitin Jain Date: Thu, 28 Aug 2025 23:43:05 -0700 Subject: [PATCH] Add INT16 support to rescale operation Add INT16 support for RequantizeNode rescale operations in ExecutorTorch ARM backend. This follows the pattern established for linear, mul, sigmoid, tanh, slice, view/transpose, cat, and FCNode operations, extending int16 support to RequantizeNode rescale operations. Changes: - Add INT16 dtype validation support in op_rescale.py - Enable rescale operations for 16A8W quantization configuration The 16A8W configuration uses 16-bit activations with 8-bit weights, enabling higher precision for activations while maintaining weight efficiency. RequantizeNode rescale operations are essential for proper quantization scaling in the 16A8W pipeline. Differential Revision: [D80513725](https://our.internmc.facebook.com/intern/diff/D80513725/) [ghstack-poisoned] --- backends/arm/operators/op_rescale.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/backends/arm/operators/op_rescale.py b/backends/arm/operators/op_rescale.py index 3f86c439995..4c5eaec6b0f 100644 --- a/backends/arm/operators/op_rescale.py +++ b/backends/arm/operators/op_rescale.py @@ -46,13 +46,20 @@ def define_node( input_zp = cast(int, node.args[3]) output_zp = cast(int, node.args[4]) - if input_dtype != map_dtype(torch.int8, self.tosa_spec) and input_zp != 0: + if ( + input_dtype + not in [ + map_dtype(torch.int8, self.tosa_spec), + map_dtype(torch.int16, self.tosa_spec), + ] + and input_zp != 0 + ): raise ValueError( - f"If input dtype is not int8, input_zp must be 0. Got input_dtype{input_dtype=}, {input_zp=}" + f"If input dtype is not int8 or int16, input_zp must be 0. Got input_dtype{input_dtype=}, {input_zp=}" ) - if output_dtype != torch.int8 and output_zp != 0: + if output_dtype not in [torch.int8, torch.int16] and output_zp != 0: raise ValueError( - f"If output dtype is not int8, output_zp must be 0. Got {ts.DTypeNames[output_dtype]}, {output_zp=}" + f"If output dtype is not int8 or int16, output_zp must be 0. Got {ts.DTypeNames[output_dtype]}, {output_zp=}" ) build_rescale(