Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions backends/cadence/aot/ref_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,62 @@ def quantize_per_tensor(
return torch.round(input / scale + zero_point).to(dtype)


@impl(m, "dequantize_per_tensor")
def dequantize_per_tensor(
input_tensor: torch.Tensor,
scale: float,
zero_point: int,
quant_min: int,
quant_max: int,
dtype: torch.dtype,
) -> torch.Tensor:
"""
Dequantizes an integral tensor to a floating-point tensor.

Args:
- input (Tensor): input tensor
- scale (float): Quantization scale. Derived from the ratio
between the min/max of the floating-point tensor and the
min/max of the quantized range.
- zero_point (int): The point which represents 0 in the quantized
range. For example, consider the floating point range [-1., 2.] and
quantized integer range [-7, 7]. In this case, 0 is 1/3 of way from
-1. to 2. So, the point that represents 0 in the quantized range should
be 1/3 of the way from [-7, 7]. This ends up being -2 in the integer space.
- quant_min (int): The smallest value in the quantized domain. Unused since scale
is already provided.
- quant_max (int): The largest value in the quantized domain. Unused since scale
is already provided.
- dtype (torch.dtype): The type of the output tensor. Must be a floating point type.
"""
supported_quant_types = [
torch.int8,
torch.int16,
torch.int32,
torch.uint8,
torch.uint16,
]
if input_tensor.dtype not in supported_quant_types:
raise ValueError(f"Input dtype must be one of {supported_quant_types}")
supported_dequant_types = [
torch.float,
torch.float32,
torch.float16,
torch.bfloat16,
]
if dtype not in supported_dequant_types:
raise ValueError(
f"Unsupported dtype to dequantize to. Supported dtypes must be one of {supported_dequant_types}"
)

# Needed to prevent underflow in cases where the zero_point is larger than
# the quantized value.
if not input_tensor.dtype.is_signed:
input_tensor = input_tensor.to(torch.int32)

return (input_tensor - zero_point).to(dtype) * scale


@impl(m, "requantize")
def requantize(
input: torch.Tensor,
Expand Down
49 changes: 48 additions & 1 deletion backends/cadence/aot/tests/test_ref_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@

import torch

from executorch.backends.cadence.aot.ref_implementations import quantize_per_tensor
from executorch.backends.cadence.aot.ref_implementations import (
dequantize_per_tensor,
quantize_per_tensor,
)
from executorch.backends.cadence.aot.typing_stubs import expand


Expand Down Expand Up @@ -48,3 +51,47 @@ def test_quantize_per_tensor(
torch.equal(output, expected_output),
f"Values don't match in {name}: got {output}, expected {expected_output}",
)

@expand(
[
# Signed quantization ranges
("signed_range0_int8", 0, -1.0, 2.0, -7, 7, torch.int8, 0.428),
("signed_range0_int16", 0, -1.0, 2.0, -7, 7, torch.int16, 0.428),
("signed_range0_int32", 0, -1.0, 2.0, -7, 7, torch.int32, 0.428),
("signed_range1_int8", -3, -1.0, 5.0, -6, 7, torch.int8, 0.461),
("signed_range1_int16", -3, -1.0, 5.0, -6, 7, torch.int16, 0.461),
("signed_range1_int32", -3, -1.0, 5.0, -6, 7, torch.int32, 0.461),
# Unsigned quantization ranges
("unsigned_range0_uint8", 3, -1.0, 2.0, 0, 7, torch.uint8, 0.428),
("unsigned_range0_uint16", 3, -1.0, 2.0, 0, 7, torch.uint16, 0.428),
("unsigned_range1_uint8", 4, -1.0, 5.0, 3, 7, torch.uint8, 0.0),
("unsigned_range1_uint16", 4, -1.0, 5.0, 3, 7, torch.uint16, 0.0),
]
)
def test_dequantize_per_tensor(
self,
name: str,
input_value: int,
f_min: float,
f_max: float,
q_min: int,
q_max: int,
input_dtype: torch.dtype,
expected_value: int,
) -> None:
input_tensor = torch.tensor([input_value], dtype=input_dtype)
scale = (f_max - f_min) / (q_max - q_min)
zero_point = round(-f_min / scale) + q_min
expected_output = torch.tensor([expected_value], dtype=torch.float32)

output = dequantize_per_tensor(
input_tensor, scale, zero_point, q_min, q_max, torch.float32
)

self.assertEqual(
output.dtype, expected_output.dtype, f"Dtype mismatch in {name}"
)
self.assertTrue(
torch.allclose(output, expected_output, rtol=0.001, atol=0.001),
f"Values don't match in {name}: got {output}, expected {expected_output}",
)
Loading