Skip to content
15 changes: 11 additions & 4 deletions backends/arm/operators/op_rescale.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,20 @@ def define_node(
input_zp = cast(int, node.args[3])
output_zp = cast(int, node.args[4])

if input_dtype != map_dtype(torch.int8, self.tosa_spec) and input_zp != 0:
if (
input_dtype
not in [
map_dtype(torch.int8, self.tosa_spec),
map_dtype(torch.int16, self.tosa_spec),
]
and input_zp != 0
):
raise ValueError(
f"If input dtype is not int8, input_zp must be 0. Got input_dtype{input_dtype=}, {input_zp=}"
f"If input dtype is not int8 or int16, input_zp must be 0. Got input_dtype{input_dtype=}, {input_zp=}"
)
if output_dtype != torch.int8 and output_zp != 0:
if output_dtype not in [torch.int8, torch.int16] and output_zp != 0:
raise ValueError(
f"If output dtype is not int8, output_zp must be 0. Got {ts.DTypeNames[output_dtype]}, {output_zp=}"
f"If output dtype is not int8 or int16, output_zp must be 0. Got {ts.DTypeNames[output_dtype]}, {output_zp=}"
)

build_rescale(
Expand Down
Loading