From 6569ebd3166b76746b4ce8749fd8d066e73185c5 Mon Sep 17 00:00:00 2001
From: Andrew Grebenisan
Date: Fri, 17 Oct 2025 08:37:46 -0700
Subject: [PATCH] All type-specific quantize/dequantize (#15165)

Summary: As titled.

Reviewed By: skrtskrtfb

Differential Revision: D84675269
---
 backends/cadence/aot/ops_registrations.py   |  10 --
 backends/cadence/aot/ref_implementations.py | 131 ++++++++++++++++++--
 2 files changed, 124 insertions(+), 17 deletions(-)

diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py
index f1c1549c46a..551b23a90be 100644
--- a/backends/cadence/aot/ops_registrations.py
+++ b/backends/cadence/aot/ops_registrations.py
@@ -58,20 +58,10 @@ def _validate_ref_impl_exists() -> None:
         "cadence::_softmax_f32_f32",
         "cadence::requantize",  # We should only support per_tensor variant, should remove
         "cadence::quantized_softmax.per_tensor",
-        "cadence::quantize_per_tensor_asym8u",
-        "cadence::quantize_per_tensor_asym8s",
-        "cadence::dequantize_per_tensor_asym8u",
-        "cadence::dequantize_per_tensor_asym32s",
-        "cadence::dequantize_per_tensor_asym16u",
         "cadence::quantized_conv2d_nchw",  # We should only support per_tensor variant, should remove
-        "cadence::quantize_per_tensor_asym32s",
         "cadence::quantized_relu",  # We should only support per_tensor variant, should remove
         "cadence::linalg_svd",
         "cadence::quantized_conv2d_nhwc",  # We should only support per_tensor variant, should remove
-        "cadence::quantize_per_tensor_asym16u",
-        "cadence::dequantize_per_tensor_asym8s",
-        "cadence::quantize_per_tensor_asym16s",
-        "cadence::dequantize_per_tensor_asym16s",
         "cadence::quantized_softmax",
         "cadence::quantized_w8a32_gru",
         "cadence::quantized_layer_norm",  # We should only support per_tensor variant, should remove
diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py
index a43d9cabc4f..53cb0845f42 100644
--- a/backends/cadence/aot/ref_implementations.py
+++ b/backends/cadence/aot/ref_implementations.py
@@ -43,8 +43,7 @@ def get_registered_ref_implementations() -> set[str]:
     }


-@impl_tracked(m, "quantize_per_tensor")
-def quantize_per_tensor(
+def quantize_per_tensor_common(
     input_tensor: torch.Tensor,
     scale: float,
     zero_point: int,
@@ -93,8 +92,68 @@ def quantize_per_tensor(
     )


-@impl_tracked(m, "dequantize_per_tensor")
-def dequantize_per_tensor(
+def quantize_per_tensor_variant(
+    dtype: torch.dtype | None = None,
+) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]:
+    """Create a quantize_per_tensor variant with type checking."""
+
+    def decorator(_: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
+        def variant(
+            input_tensor: torch.Tensor,
+            scale: float,
+            zero_point: int,
+            quant_min: int,
+            quant_max: int,
+            out_dtype: torch.dtype,
+        ) -> torch.Tensor:
+            if dtype and out_dtype != dtype:
+                raise ValueError(f"dtype must be {dtype}. Got {out_dtype}")
+
+            return quantize_per_tensor_common(
+                input_tensor,
+                scale,
+                zero_point,
+                quant_min,
+                quant_max,
+                out_dtype,
+            )
+
+        return variant
+
+    return decorator
+
+
+@impl_tracked(m, "quantize_per_tensor")
+@quantize_per_tensor_variant()
+def quantize_per_tensor() -> torch.Tensor: ...
+
+
+@impl_tracked(m, "quantize_per_tensor_asym8u")
+@quantize_per_tensor_variant(torch.uint8)
+def quantize_per_tensor_asym8u() -> torch.Tensor: ...
+
+
+@impl_tracked(m, "quantize_per_tensor_asym8s")
+@quantize_per_tensor_variant(torch.int8)
+def quantize_per_tensor_asym8s() -> torch.Tensor: ...
+
+
+@impl_tracked(m, "quantize_per_tensor_asym16u")
+@quantize_per_tensor_variant(torch.uint16)
+def quantize_per_tensor_asym16u() -> torch.Tensor: ...
+
+
+@impl_tracked(m, "quantize_per_tensor_asym16s")
+@quantize_per_tensor_variant(torch.int16)
+def quantize_per_tensor_asym16s() -> torch.Tensor: ...
+
+
+@impl_tracked(m, "quantize_per_tensor_asym32s")
+@quantize_per_tensor_variant(torch.int32)
+def quantize_per_tensor_asym32s() -> torch.Tensor: ...
+
+
+def dequantize_per_tensor_common(
     input_tensor: torch.Tensor,
     scale: float,
     zero_point: int,
@@ -133,14 +192,72 @@ def dequantize_per_tensor(
     if input_tensor.dtype != dtype:
         raise ValueError("Input dtype must match dtype")

-    # Use the reference implementation from torch quantized_decomposed library
-    # Unlike quantize_per_tensor, dequantize_per_tensor doesn't have a behavior
-    # difference, since there's no rounding algorithm (just arithmetic).
     return torch.ops.quantized_decomposed.dequantize_per_tensor(
         input_tensor, scale, zero_point, quant_min, quant_max, dtype
     )


+def dequantize_per_tensor_variant(
+    dtype: torch.dtype | None = None,
+) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]:
+    """Create a dequantize_per_tensor variant with type checking."""
+
+    def decorator(_: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
+        def variant(
+            input_tensor: torch.Tensor,
+            scale: float,
+            zero_point: int,
+            quant_min: int,
+            quant_max: int,
+            in_dtype: torch.dtype,
+        ) -> torch.Tensor:
+            if dtype and in_dtype != dtype:
+                raise ValueError(f"dtype must be {dtype}. Got {in_dtype}")
+
+            return dequantize_per_tensor_common(
+                input_tensor,
+                scale,
+                zero_point,
+                quant_min,
+                quant_max,
+                in_dtype,
+            )
+
+        return variant
+
+    return decorator
+
+
+@impl_tracked(m, "dequantize_per_tensor")
+@dequantize_per_tensor_variant()
+def dequantize_per_tensor() -> torch.Tensor: ...
+
+
+@impl_tracked(m, "dequantize_per_tensor_asym8u")
+@dequantize_per_tensor_variant(torch.uint8)
+def dequantize_per_tensor_asym8u() -> torch.Tensor: ...
+
+
+@impl_tracked(m, "dequantize_per_tensor_asym32s")
+@dequantize_per_tensor_variant(torch.int32)
+def dequantize_per_tensor_asym32s() -> torch.Tensor: ...
+
+
+@impl_tracked(m, "dequantize_per_tensor_asym16u")
+@dequantize_per_tensor_variant(torch.uint16)
+def dequantize_per_tensor_asym16u() -> torch.Tensor: ...
+
+
+@impl_tracked(m, "dequantize_per_tensor_asym8s")
+@dequantize_per_tensor_variant(torch.int8)
+def dequantize_per_tensor_asym8s() -> torch.Tensor: ...
+
+
+@impl_tracked(m, "dequantize_per_tensor_asym16s")
+@dequantize_per_tensor_variant(torch.int16)
+def dequantize_per_tensor_asym16s() -> torch.Tensor: ...
+
+
 @impl_tracked(m, "quantized_add.per_tensor")
 def quantized_add_per_tensor(
     X: torch.Tensor,
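Note for reviewers: the core trick in this patch is a decorator factory. Each `@quantize_per_tensor_variant(dtype)` (and its dequantize counterpart) discards the stub function body it decorates and returns a wrapper that validates the requested dtype before delegating to the shared `*_common` implementation, so the stub defs exist only to carry the operator names registered via `impl_tracked`. Below is a minimal, self-contained sketch of the same pattern under assumed names: `quantize_common`, `quantize_variant`, and `quantize_asym8s` are illustrative stand-ins, not code from the repo, and the quantization arithmetic is a simplified placeholder.

from typing import Callable

import torch


def quantize_common(
    input_tensor: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    out_dtype: torch.dtype,
) -> torch.Tensor:
    # Placeholder shared implementation: scale, shift, clamp, cast.
    q = torch.round(input_tensor / scale) + zero_point
    return q.clamp(quant_min, quant_max).to(out_dtype)


def quantize_variant(
    dtype: torch.dtype | None = None,
) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]:
    # Factory: returns a decorator that ignores the stub it receives and
    # substitutes a wrapper enforcing the variant's dtype before delegating.
    def decorator(_: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
        def variant(
            input_tensor: torch.Tensor,
            scale: float,
            zero_point: int,
            quant_min: int,
            quant_max: int,
            out_dtype: torch.dtype,
        ) -> torch.Tensor:
            if dtype is not None and out_dtype != dtype:
                raise ValueError(f"dtype must be {dtype}. Got {out_dtype}")
            return quantize_common(
                input_tensor, scale, zero_point, quant_min, quant_max, out_dtype
            )

        return variant

    return decorator


@quantize_variant(torch.int8)
def quantize_asym8s() -> torch.Tensor: ...  # stub body is never executed


x = torch.tensor([0.24, -0.51])
print(quantize_asym8s(x, 0.01, 0, -128, 127, torch.int8))
# -> tensor([ 24, -51], dtype=torch.int8)
# quantize_asym8s(x, 0.01, 0, -128, 127, torch.uint8)  # would raise ValueError

With this structure, adding a new typed variant costs three lines (two decorators and a stub), and the untyped `quantize_per_tensor` / `dequantize_per_tensor` are just the `dtype=None` case of the same factory, which is why the type-specific entries could be dropped from the skip list in ops_registrations.py.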