diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py
index bda8f9425ea..abbfc41889b 100644
--- a/backends/cadence/aot/ref_implementations.py
+++ b/backends/cadence/aot/ref_implementations.py
@@ -24,7 +24,7 @@
 
 @impl(m, "quantize_per_tensor")
 def quantize_per_tensor(
-    input: torch.Tensor,
+    input_tensor: torch.Tensor,
     scale: float,
     zero_point: int,
     quant_min: int,
@@ -35,10 +35,10 @@ def quantize_per_tensor(
     Quantizes a floating-point tensor to an integral tensor.
 
     Args:
-    - input (Tensor): input tensor
-    - scale (float): Quantization scale. Derived from the ratio
+    - input_tensor (Tensor): input tensor
+    - scale (float): Inverse of quantization scale. Derived from the ratio
       between the min/max of the floating-point tensor and the
-      min/max of the quantized range.
+      min/max of the quantized range, and then inverted.
     - zero_point (int): The point which represents 0 in the quantized
       range. For example, consider the floating point range [-1., 2.] and
       quantized integer range [-7, 7]. In this case, 0 is 1/3 of way from
@@ -61,7 +61,12 @@
         raise ValueError(
             f"Unsupported dtype to quantize to. Supported dtypes must be one of {supported_quant_types}"
         )
-    return torch.round(input / scale + zero_point).to(dtype)
+
+    dequantized = torch.round(input_tensor * scale + zero_point).to(dtype)
+    return torch.max(
+        torch.min(dequantized, torch.tensor(quant_max)),
+        torch.tensor(quant_min),
+    )
 
 
 @impl(m, "dequantize_per_tensor")
@@ -173,9 +178,16 @@ def quantized_add(
     dequant_X = X_scale * (X - X_zero_point)
     dequant_Y = Y_scale * (Y - Y_zero_point)
 
+    out_scale_inv = 1 / out_scale
+    # q_min/q_max are unused args
     return quantize_per_tensor(
-        dequant_X + dequant_Y, out_scale, out_zero_point, -128, 127, dtype
+        dequant_X + dequant_Y,
+        out_scale_inv,
+        out_zero_point,
+        torch.iinfo(dtype).min,
+        torch.iinfo(dtype).max,
+        dtype,
     )
 
 
@@ -206,6 +218,7 @@ def quantized_linear(
     - offset (Tensor): Unused
     """
     out_scale = -out_multiplier * (1 / (1 << 31)) * (2 ** out_shift[0])
+    out_scale_inv = 1 / out_scale
 
     N, K = weight.shape
 
@@ -223,7 +236,12 @@
         src - in_zero_point, weight - weight_zero_point, bias
     )
     return quantize_per_tensor(
-        out, out_scale, out_zero_point, -128, 127, dtype
+        out,
+        out_scale_inv,
+        out_zero_point,
+        torch.iinfo(dtype).min,
+        torch.iinfo(dtype).max,
+        dtype,
     ).reshape(*leading_dims, N)
 
 
diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py
index 1b02926b3f8..95e1a374463 100644
--- a/backends/cadence/aot/tests/test_ref_implementations.py
+++ b/backends/cadence/aot/tests/test_ref_implementations.py
@@ -40,11 +40,12 @@ def test_quantize_per_tensor(
     ) -> None:
         input_tensor = torch.tensor([input_value])
         scale = (f_max - f_min) / (q_max - q_min)
-        zero_point = round(-f_min / scale) + q_min
+        inv_scale = 1.0 / scale
+        zero_point = round(-f_min * inv_scale) + q_min
         expected_output = torch.tensor([expected_value], dtype=target_dtype)
 
         output = quantize_per_tensor(
-            input_tensor, scale, zero_point, q_min, q_max, target_dtype
+            input_tensor, inv_scale, zero_point, q_min, q_max, target_dtype
         )
 
         self.assertEqual(
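
For reference, here is a minimal standalone sketch of the inverse-scale convention this patch adopts: `quantize_per_tensor` now receives `1 / scale` and multiplies by it rather than dividing by `scale`, and the result is clamped to `[quant_min, quant_max]` instead of hardcoding `-128`/`127`. The helper name `quantize_with_inv_scale` and the sample ranges are illustrative, not part of the patch, and this sketch clamps before the integer cast, so values far outside the quantized range saturate rather than wrap.

```python
import torch

def quantize_with_inv_scale(
    x: torch.Tensor,
    inv_scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    # Multiply by the precomputed inverse scale instead of dividing by the
    # scale, then clamp to the quantized range before the final cast.
    q = torch.round(x * inv_scale + zero_point)
    return torch.clamp(q, quant_min, quant_max).to(dtype)

# Float range [-1., 2.] mapped onto int8, mirroring the updated test:
f_min, f_max = -1.0, 2.0
q_min, q_max = torch.iinfo(torch.int8).min, torch.iinfo(torch.int8).max
scale = (f_max - f_min) / (q_max - q_min)       # 3 / 255
inv_scale = 1.0 / scale                         # 85.0
zero_point = round(-f_min * inv_scale) + q_min  # 85 - 128 = -43

x = torch.tensor([-1.0, 0.0, 2.0, 5.0])  # 5.0 is out of range and saturates
print(quantize_with_inv_scale(x, inv_scale, zero_point, q_min, q_max, torch.int8))
# tensor([-128,  -43,  127,  127], dtype=torch.int8)
```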