32 changes: 25 additions & 7 deletions backends/cadence/aot/ref_implementations.py
@@ -24,7 +24,7 @@

 @impl(m, "quantize_per_tensor")
 def quantize_per_tensor(
-    input: torch.Tensor,
+    input_tensor: torch.Tensor,
     scale: float,
     zero_point: int,
     quant_min: int,
@@ -35,10 +35,10 @@ def quantize_per_tensor(
     Quantizes a floating-point tensor to an integral tensor.
 
     Args:
-        - input (Tensor): input tensor
-        - scale (float): Quantization scale. Derived from the ratio
+        - input_tensor (Tensor): input tensor
+        - scale (float): Inverse of quantization scale. Derived from the ratio
             between the min/max of the floating-point tensor and the
-            min/max of the quantized range.
+            min/max of the quantized range, and then inverted.
         - zero_point (int): The point which represents 0 in the quantized
             range. For example, consider the floating point range [-1., 2.] and
             quantized integer range [-7, 7]. In this case, 0 is 1/3 of the way from
@@ -61,7 +61,12 @@ def quantize_per_tensor(
         raise ValueError(
             f"Unsupported dtype to quantize to. Supported dtypes must be one of {supported_quant_types}"
         )
-    return torch.round(input / scale + zero_point).to(dtype)
+
+    dequantized = torch.round(input_tensor * scale + zero_point).to(dtype)
+    return torch.max(
+        torch.min(dequantized, torch.tensor(quant_max)),
+        torch.tensor(quant_min),
+    )
 
 
 @impl(m, "dequantize_per_tensor")
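
To make the new contract concrete, here is a minimal sketch of the round, scale-by-inverse, clamp flow the op now implements, using the range from the docstring's example (float [-1., 2.] into integer [-7, 7]). The values and helper names are illustrative only, not part of the backend, and for simplicity the sketch clamps before casting:

import torch

f_min, f_max, q_min, q_max = -1.0, 2.0, -7, 7
scale = (f_max - f_min) / (q_max - q_min)       # 3/14
inv_scale = 1.0 / scale                         # the op now takes this inverse
zero_point = round(-f_min * inv_scale) + q_min  # -2, about 1/3 into [-7, 7]

x = torch.tensor([0.0, 2.0, 3.0])
q = torch.round(x * inv_scale + zero_point).clamp(q_min, q_max).to(torch.int8)
print(q)  # tensor([-2, 7, 7], dtype=torch.int8); 3.0 is out of range, so clamped
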
@@ -173,9 +178,16 @@ def quantized_add(
     dequant_X = X_scale * (X - X_zero_point)
     dequant_Y = Y_scale * (Y - Y_zero_point)
 
+    out_scale_inv = 1 / out_scale
+
+    # q_min/q_max are unused args
     return quantize_per_tensor(
-        dequant_X + dequant_Y, out_scale, out_zero_point, -128, 127, dtype
+        dequant_X + dequant_Y,
+        out_scale_inv,
+        out_zero_point,
+        torch.iinfo(dtype).min,
+        torch.iinfo(dtype).max,
+        dtype,
     )


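The same inverse-scale convention now threads through quantized_add. Below is a hedged sketch of the dequantize, add, requantize flow; the scales and zero points are made up for illustration and are not the backend's test data:

import torch

X = torch.tensor([10], dtype=torch.int8)
Y = torch.tensor([20], dtype=torch.int8)
X_scale, X_zero_point = 0.1, 0
Y_scale, Y_zero_point = 0.1, 0
out_scale, out_zero_point = 0.2, 0

dequant = X_scale * (X.float() - X_zero_point) + Y_scale * (Y.float() - Y_zero_point)
out_scale_inv = 1 / out_scale  # pass the inverse, matching the new convention
q = torch.round(dequant * out_scale_inv + out_zero_point)
q = q.clamp(torch.iinfo(torch.int8).min, torch.iinfo(torch.int8).max).to(torch.int8)
print(q)  # tensor([15], dtype=torch.int8): (1.0 + 2.0) / 0.2 = 15

Replacing the hard-coded -128/127 bounds with torch.iinfo(dtype).min/.max also lets the same path serve any supported output dtype, not just int8.
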
@@ -206,6 +218,7 @@ def quantized_linear(
         - offset (Tensor): Unused
     """
     out_scale = -out_multiplier * (1 / (1 << 31)) * (2 ** out_shift[0])
+    out_scale_inv = 1 / out_scale
 
     N, K = weight.shape
 
@@ -223,7 +236,12 @@
         src - in_zero_point, weight - weight_zero_point, bias
     )
     return quantize_per_tensor(
-        out, out_scale, out_zero_point, -128, 127, dtype
+        out,
+        out_scale_inv,
+        out_zero_point,
+        torch.iinfo(dtype).min,
+        torch.iinfo(dtype).max,
+        dtype,
     ).reshape(*leading_dims, N)


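For intuition on the fixed-point parameters feeding out_scale, a small worked example; the multiplier and shift values are assumed, chosen so the arithmetic comes out round, and a plain list stands in for the shift tensor:

out_multiplier = -1073741824  # assumed Q31 value, i.e. -0.5 * 2**31
out_shift = [1]               # assumed
out_scale = -out_multiplier * (1 / (1 << 31)) * (2 ** out_shift[0])  # 1.0
out_scale_inv = 1 / out_scale                                        # 1.0
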
5 changes: 3 additions & 2 deletions backends/cadence/aot/tests/test_ref_implementations.py
@@ -40,11 +40,12 @@ def test_quantize_per_tensor(
     ) -> None:
         input_tensor = torch.tensor([input_value])
         scale = (f_max - f_min) / (q_max - q_min)
-        zero_point = round(-f_min / scale) + q_min
+        inv_scale = 1.0 / scale
+        zero_point = round(-f_min * inv_scale) + q_min
         expected_output = torch.tensor([expected_value], dtype=target_dtype)
 
         output = quantize_per_tensor(
-            input_tensor, scale, zero_point, q_min, q_max, target_dtype
+            input_tensor, inv_scale, zero_point, q_min, q_max, target_dtype
         )
 
         self.assertEqual(
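
Finally, a hedged end-to-end usage sketch of the updated call convention; the import path and the parameter values are assumptions for illustration, not the test's actual parametrization:

import torch

from executorch.backends.cadence.aot.ref_implementations import quantize_per_tensor

f_min, f_max, q_min, q_max = -1.0, 1.0, -128, 127
scale = (f_max - f_min) / (q_max - q_min)
inv_scale = 1.0 / scale                         # 127.5; the op takes the inverse
zero_point = round(-f_min * inv_scale) + q_min  # 0

output = quantize_per_tensor(
    torch.tensor([0.5]), inv_scale, zero_point, q_min, q_max, torch.int8
)
print(output)  # tensor([64], dtype=torch.int8): round(0.5 * 127.5) = 64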