diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py
index 21d1efe6398..7ab1959cc9f 100644
--- a/backends/cadence/aot/ref_implementations.py
+++ b/backends/cadence/aot/ref_implementations.py
@@ -48,7 +48,13 @@ def quantize_per_tensor(
         is already provided.
     - dtype (torch.dtype): The type of the output tensor
     """
-    supported_quant_types = [torch.int8, torch.int16, torch.int32]
+    supported_quant_types = [
+        torch.int8,
+        torch.int16,
+        torch.int32,
+        torch.uint8,
+        torch.uint16,
+    ]
     if dtype not in supported_quant_types:
         raise ValueError(
             f"Unsupported dtype to quantize to. Supported dtypes must be one of {supported_quant_types}"
@@ -112,6 +118,65 @@ def dequantize_per_tensor(
     return (input_tensor - zero_point).to(dtype) * scale
 
 
+@impl(m, "quantized_add")
+def quantized_add(
+    X: torch.Tensor,
+    X_scale: torch.Tensor,
+    X_zero_point: torch.Tensor,
+    Y: torch.Tensor,
+    Y_scale: torch.Tensor,
+    Y_zero_point: torch.Tensor,
+    out_scale: float,
+    out_zero_point: int,
+) -> torch.Tensor:
+    """
+    Sums up two quantized tensors and returns another quantized tensor. The intuition
+    is that we want dequant(out) ~= dequant(X) + dequant(Y)
+
+    If we do that math, we get
+    out_scale(out - out_zero_point) = X_scale(X - X_zero_point) + Y_scale(Y - Y_zero_point)
+
+    Rearranging, we get
+    out = (X_scale(X - X_zero_point) + Y_scale(Y - Y_zero_point)) / out_scale + out_zero_point
+
+    Args:
+    - X (Tensor): The first operand
+    - X_scale (Tensor): The ratio between the sizes of X's floating point and quantized
+        ranges
+    - X_zero_point (Tensor): The quantized mapping of zero for X
+    - Y (Tensor): The second operand
+    - Y_scale (Tensor): The ratio between the sizes of Y's floating point and quantized
+        ranges
+    - Y_zero_point (Tensor): The quantized mapping of zero for Y
+    - out_scale (float): The ratio between the sizes of the output's floating point and
+        quantized ranges
+    - out_zero_point (int): The quantized mapping of zero for the output
+    """
+    supported_dtypes = [torch.int8, torch.uint8]
+    if X.dtype != Y.dtype:
+        raise ValueError("X and Y dtypes need to match")
+
+    dtype = X.dtype
+    if dtype not in supported_dtypes:
+        raise ValueError(
+            f"X and Y dtypes need to be in {supported_dtypes}. Got {dtype}"
+        )
+
+    if dtype == torch.uint8:
+        X = X.to(torch.int8)
+        Y = Y.to(torch.int8)
+
+    # TODO(agrebenisan): This should be done in fixed point arithmetic, but to match the quantized_add_out.cpp
+    # reference implementation, we'll do it in floating point.
+    dequant_X = X_scale * (X - X_zero_point)
+    dequant_Y = Y_scale * (Y - Y_zero_point)
+
+    # q_min/q_max are unused args
+    return quantize_per_tensor(
+        dequant_X + dequant_Y, out_scale, out_zero_point, -128, 127, dtype
+    )
+
+
 @impl(m, "requantize")
 def requantize(
     input: torch.Tensor,
diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py
index 538798711f5..9e15169c5b1 100644
--- a/backends/cadence/aot/tests/test_ref_implementations.py
+++ b/backends/cadence/aot/tests/test_ref_implementations.py
@@ -13,6 +13,7 @@
 from executorch.backends.cadence.aot.ref_implementations import (
     dequantize_per_tensor,
     quantize_per_tensor,
+    quantized_add,
 )
 from executorch.backends.cadence.aot.typing_stubs import expand
 
@@ -95,3 +96,45 @@ def test_dequantize_per_tensor(
             torch.allclose(output, expected_output, rtol=0.001, atol=0.001),
             f"Values don't match in {name}: got {output}, expected {expected_output}",
         )
+
+    @expand(
+        [
+            # Only these types need to be tested as per ET_FORALL_JARVIS_QUANTIZED_TYPES in
+            # on_device_ai/Assistant/Jarvis/min_runtime/operators/generic/operators.h
+            ("int8", 5, 0.8, 4, 5, 0.8, 4, 0.8, 4, 6, torch.int8),
+            ("uint8", 5, 0.8, 4, 5, 0.8, 4, 0.8, 4, 6, torch.uint8),
+        ]
+    )
+    def test_quantized_add(
+        self,
+        name: str,
+        X: int,
+        X_scale: float,
+        X_zero_point: int,
+        Y: int,
+        Y_scale: float,
+        Y_zero_point: int,
+        out_scale: float,
+        out_zero_point: int,
+        expected_value: int,
+        dtype: torch.dtype,
+    ) -> None:
+        X_tensor = torch.tensor([X], dtype=dtype)
+        Y_tensor = torch.tensor([Y], dtype=dtype)
+        expected_output = torch.tensor([expected_value], dtype=dtype)
+
+        output = quantized_add(
+            X_tensor,
+            torch.tensor(X_scale),
+            torch.tensor(X_zero_point, dtype=dtype),
+            Y_tensor,
+            torch.tensor(Y_scale),
+            torch.tensor(Y_zero_point, dtype=dtype),
+            out_scale,
+            out_zero_point,
+        )
+
+        self.assertTrue(
+            torch.equal(output, expected_output),
+            f"Values don't match in {name}: got {output}, expected {expected_output}",
+        )
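As a sanity check on the derivation in the quantized_add docstring, the shared test vector works out by hand: dequant(X) = dequant(Y) = 0.8 * (5 - 4) = 0.8, so the float sum is 1.6, and requantizing gives round(1.6 / 0.8) + 4 = 6, which is the expected_value in both test cases. The sketch below replays that arithmetic in plain PyTorch outside the Cadence ops registry; quantized_add_demo is a hypothetical helper for illustration only, not part of this patch, and it assumes round-to-nearest plus clamping to the dtype range, which the registered quantize_per_tensor may implement differently.

import torch

def quantized_add_demo(
    X: torch.Tensor, X_scale: float, X_zero_point: int,
    Y: torch.Tensor, Y_scale: float, Y_zero_point: int,
    out_scale: float, out_zero_point: int,
) -> torch.Tensor:
    # out = (X_scale*(X - X_zp) + Y_scale*(Y - Y_zp)) / out_scale + out_zp,
    # computed in float to mirror the reference implementation above.
    dtype = X.dtype
    dequant_sum = X_scale * (X.to(torch.float32) - X_zero_point) + Y_scale * (
        Y.to(torch.float32) - Y_zero_point
    )
    info = torch.iinfo(dtype)
    out = torch.round(dequant_sum / out_scale) + out_zero_point
    return out.clamp(info.min, info.max).to(dtype)

# Replays the int8 test case; expects tensor([6], dtype=torch.int8).
result = quantized_add_demo(
    torch.tensor([5], dtype=torch.int8), 0.8, 4,
    torch.tensor([5], dtype=torch.int8), 0.8, 4,
    0.8, 4,
)
assert torch.equal(result, torch.tensor([6], dtype=torch.int8))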