@impl(m, "quantized_linear")
def quantized_linear(
    src: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    in_zero_point: int,
    weight_zero_point: torch.Tensor,
    out_multiplier: torch.Tensor,
    out_shift: torch.Tensor,
    out_zero_point: int,
    offset: Optional[torch.Tensor],
) -> torch.Tensor:
    """
    Quantized linear (transposed matmul) operation.

    Computes ``(src - in_zero_point) @ (weight - weight_zero_point).T + bias``
    in floating point, then requantizes the result to the dtype of ``src``.

    Args:
        - src (Tensor): The activations tensor; its last dimension must match
          the second dimension of ``weight``
        - weight (Tensor): The 2-D weight tensor, (out_features, in_features)
        - bias (Tensor): The bias tensor
        - in_zero_point (int): The quantized mapping of zero for the input
        - weight_zero_point (Tensor): The quantized mapping of zero for the
          weight; broadcast against ``weight``
        - out_multiplier (Tensor): The multiplier used to scale the output
        - out_shift (Tensor): The shift used to scale the output; only the
          first element is read
        - out_zero_point (int): The quantized mapping of zero for the output
        - offset (Tensor): Unused

    Returns:
        A tensor of shape ``(*src.shape[:-1], out_features)`` with the same
        dtype as ``src``.

    Raises:
        ValueError: If ``src.dtype`` is not int8, uint8 or int32.
    """
    # Validate first so an unsupported dtype is reported before any shape work.
    dtype = src.dtype
    supported_dtypes = [torch.int8, torch.uint8, torch.int32]
    if dtype not in supported_dtypes:
        raise ValueError(
            f"Unsupported dtype to quantize to. Supported dtypes must be one of {supported_dtypes}"
        )

    # The sign convention (negated multiplier) is kept exactly as before.
    # The power of two is evaluated in floating point: an integer power cannot
    # represent the fractional scale of a negative shift and would overflow in
    # the shift's own narrow dtype for large positive shifts.
    out_scale = (
        -out_multiplier.to(torch.float32)
        * (1 / (1 << 31))
        * torch.pow(2.0, out_shift[0].to(torch.float32))
    )

    N, K = weight.shape

    leading_dims = src.shape[:-1]
    # reshape (rather than view) also accepts non-contiguous activations.
    src = src.reshape(-1, K)

    # Widen to float before removing the zero points and accumulating, so the
    # subtraction and the matmul cannot wrap around in int8/uint8.
    # NOTE(review): float32 is exact only up to 2**24; very large int32
    # accumulations may lose precision.
    out = torch.nn.functional.linear(
        src.to(torch.float32) - in_zero_point,
        weight.to(torch.float32) - weight_zero_point.to(torch.float32),
        bias.to(torch.float32),
    )

    # Use the representable range of the actual output dtype instead of a
    # hard-coded int8 range, since uint8 and int32 are accepted as well.
    dtype_info = torch.iinfo(dtype)
    return quantize_per_tensor(
        out, out_scale, out_zero_point, dtype_info.min, dtype_info.max, dtype
    ).reshape(*leading_dims, N)
@expand(
    [
        # Test case 1: 1x2 input, 1x2 weight (1 output feature)
        (
            torch.Size([1, 2]),  # src_shape: 1 sample, 2 input features
            torch.Size([1, 2]),  # weight_shape: 1 output feature, 2 input features
            0,  # in_zero_point
            torch.tensor([0, 0], dtype=torch.int8),  # weight_zero_point
            torch.tensor(
                [1073741824], dtype=torch.int32
            ),  # out_multiplier (0.5 * 2^31)
            torch.tensor([0], dtype=torch.int8),  # out_shift
            0,  # out_zero_point
            torch.tensor([[-2]], dtype=torch.int8),  # expected_output
        ),
        # Test case 2: 1x3 input, 2x3 weight (2 output features)
        (
            torch.Size([1, 3]),  # src_shape: 1 sample, 3 input features
            torch.Size([2, 3]),  # weight_shape: 2 output features, 3 input features
            0,  # in_zero_point
            torch.tensor([0, 0, 0], dtype=torch.int8),  # weight_zero_point
            torch.tensor(
                [1073741824], dtype=torch.int32
            ),  # out_multiplier (0.5 * 2^31)
            torch.tensor([0], dtype=torch.int8),  # out_shift
            0,  # out_zero_point
            torch.tensor([[-10, -30]], dtype=torch.int8),  # expected_output
        ),
        # Test case 3: Batch case with different dimensions
        (
            torch.Size([1, 2, 2]),  # src_shape: batch=1, seq=2, features=2
            torch.Size([3, 2]),  # weight_shape: 3 output features, 2 input features
            0,  # in_zero_point
            torch.tensor([0, 0], dtype=torch.int8),  # weight_zero_point
            torch.tensor(
                [1073741824], dtype=torch.int32
            ),  # out_multiplier (0.5 * 2^31)
            torch.tensor([0], dtype=torch.int8),  # out_shift
            0,  # out_zero_point
            torch.tensor(
                [[[-2, -8, -14], [-6, -28, -50]]], dtype=torch.int8
            ),  # expected_output
        ),
        # Test case 4: Non-zero zero points
        (
            torch.Size([1, 2]),  # src_shape: 1 sample, 2 input features
            torch.Size([2, 2]),  # weight_shape: 2 output features, 2 input features
            2,  # in_zero_point
            torch.tensor([1, 1], dtype=torch.int8),  # weight_zero_point
            torch.tensor(
                [268435456], dtype=torch.int32
            ),  # out_multiplier (0.125 * 2^31)
            torch.tensor([0]),  # out_shift
            1,  # out_zero_point
            torch.tensor([[-15, 25]], dtype=torch.int8),  # expected_output
        ),
    ]
)
def test_quantized_linear(
    self,
    src_shape: torch.Size,
    weight_shape: torch.Size,
    in_zero_point: int,
    weight_zero_point: torch.Tensor,
    out_multiplier: torch.Tensor,
    out_shift: torch.Tensor,
    out_zero_point: int,
    expected_output: torch.Tensor,
) -> None:
    """Check quantized_linear against hand-computed outputs.

    Inputs are deterministic: ``src`` and ``weight`` are filled with
    0, 1, 2, ... in row-major order and ``bias`` with 0..out_features-1,
    all cast to the dtype of ``expected_output``.
    """
    # np.prod replaces np.product, which was removed in NumPy 2.0.
    src = (
        torch.arange(int(np.prod(src_shape)))
        .reshape(src_shape)
        .to(expected_output.dtype)
    )
    weight = (
        torch.arange(int(np.prod(weight_shape)))
        .reshape(weight_shape)
        .to(expected_output.dtype)
    )
    bias = torch.arange(weight_shape[0]).to(expected_output.dtype)
    output = quantized_linear(
        src,
        weight,
        bias,
        in_zero_point,
        weight_zero_point,
        out_multiplier,
        out_shift,
        out_zero_point,
        None,  # offset is Optional and unused by the reference implementation
    )

    # assertEqual reports both dtypes on failure, unlike assertTrue(a == b).
    self.assertEqual(output.dtype, expected_output.dtype, "Dtype mismatch")

    self.assertTrue(
        torch.equal(output, expected_output),
        f"Values don't match: got {output}, expected {expected_output}",
    )