diff --git a/backends/arm/operators/op_rescale.py b/backends/arm/operators/op_rescale.py index d7be2be737c..d331ebc80d5 100644 --- a/backends/arm/operators/op_rescale.py +++ b/backends/arm/operators/op_rescale.py @@ -46,13 +46,20 @@ def define_node( input_zp = cast(int, node.args[3]) output_zp = cast(int, node.args[4]) - if input_dtype != map_dtype(torch.int8, self.tosa_spec) and input_zp != 0: + if ( + input_dtype + not in [ + map_dtype(torch.int8, self.tosa_spec), + map_dtype(torch.int16, self.tosa_spec), + ] + and input_zp != 0 + ): raise ValueError( - f"If input dtype is not int8, input_zp must be 0. Got input_dtype{input_dtype=}, {input_zp=}" + f"If input dtype is not int8 or int16, input_zp must be 0. Got input_dtype{input_dtype=}, {input_zp=}" ) - if output_dtype != torch.int8 and output_zp != 0: + if output_dtype not in [torch.int8, torch.int16] and output_zp != 0: raise ValueError( - f"If output dtype is not int8, output_zp must be 0. Got {ts.DTypeNames[output_dtype]}, {output_zp=}" + f"If output dtype is not int8 or int16, output_zp must be 0. Got {ts.DTypeNames[output_dtype]}, {output_zp=}" ) build_rescale(