From 544f4f65238342bcdec5e0c9d9f9d7a7bb5b857c Mon Sep 17 00:00:00 2001 From: Nitin Jain Date: Sat, 13 Sep 2025 01:41:56 -0700 Subject: [PATCH] Arm backend: Add INT16 support to rescale operation Differential Revision: D80513725 Pull Request resolved: https://github.com/pytorch/executorch/pull/13802 --- backends/arm/operators/op_rescale.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/backends/arm/operators/op_rescale.py b/backends/arm/operators/op_rescale.py index d7be2be737c..d331ebc80d5 100644 --- a/backends/arm/operators/op_rescale.py +++ b/backends/arm/operators/op_rescale.py @@ -46,13 +46,20 @@ def define_node( input_zp = cast(int, node.args[3]) output_zp = cast(int, node.args[4]) - if input_dtype != map_dtype(torch.int8, self.tosa_spec) and input_zp != 0: + if ( + input_dtype + not in [ + map_dtype(torch.int8, self.tosa_spec), + map_dtype(torch.int16, self.tosa_spec), + ] + and input_zp != 0 + ): raise ValueError( - f"If input dtype is not int8, input_zp must be 0. Got input_dtype{input_dtype=}, {input_zp=}" + f"If input dtype is not int8 or int16, input_zp must be 0. Got input_dtype{input_dtype=}, {input_zp=}" ) - if output_dtype != torch.int8 and output_zp != 0: + if output_dtype not in [torch.int8, torch.int16] and output_zp != 0: raise ValueError( - f"If output dtype is not int8, output_zp must be 0. Got {ts.DTypeNames[output_dtype]}, {output_zp=}" + f"If output dtype is not int8 or int16, output_zp must be 0. Got {ts.DTypeNames[output_dtype]}, {output_zp=}" ) build_rescale(