50 changes: 50 additions & 0 deletions backends/cadence/aot/ref_implementations.py
@@ -6,6 +6,8 @@

# pyre-strict

from typing import Optional

import torch
from executorch.exir.scalar_type import ScalarType
from torch.library import impl, Library
@@ -177,6 +179,54 @@ def quantized_add(
)


@impl(m, "quantized_linear")
def quantized_linear(
src: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
in_zero_point: int,
weight_zero_point: torch.Tensor,
out_multiplier: torch.Tensor,
out_shift: torch.Tensor,
out_zero_point: int,
offset: Optional[torch.Tensor],
) -> torch.Tensor:
"""
Quantized linear (transposed matmul) operation.

Args:
- src (Tensor): The activations tensor
- weight (Tensor): The weight tensor
- bias (Tensor): The bias tensor
- in_zero_point (int): The quantized mapping of zero for the input
- weight_zero_point (Tensor): The quantized mapping of zero for the weight
- out_multiplier (Tensor): The multiplier used to scale the output
- out_shift (Tensor): The shift used to scale the output
- out_zero_point (int): The quantized mapping of zero for the output
- offset (Optional[Tensor]): Unused
"""
out_scale = -out_multiplier * (1 / (1 << 31)) * (2 ** out_shift[0])

N, K = weight.shape

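# Flatten all leading batch dims so the matmul sees a 2-D [batch, K] input.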
leading_dims = src.shape[:-1]
src = src.view(-1, K)

dtype = src.dtype
supported_dtypes = [torch.int8, torch.uint8, torch.int32]
if dtype not in supported_dtypes:
raise ValueError(
f"Unsupported dtype to quantize to. Supported dtypes must be one of {supported_dtypes}"
)

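# Shift activations and weights by their zero points, then apply the linear transform.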
out = torch.nn.functional.linear(
src - in_zero_point, weight - weight_zero_point, bias
)
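# Requantize the accumulated result back to the input dtype, clamped to
# [-128, 127], and restore the leading batch dims.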
return quantize_per_tensor(
out, out_scale, out_zero_point, -128, 127, dtype
).reshape(*leading_dims, N)


@impl(m, "requantize")
def requantize(
input: torch.Tensor,
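For intuition on the out_scale formula above: out_multiplier is a Q31 fixed-point value (the real multiplier scaled by 2^31) and out_shift is a power-of-two exponent. A minimal standalone sketch of the same arithmetic, reusing illustrative values from the tests below (not part of the PR itself):

import torch

out_multiplier = torch.tensor([1073741824], dtype=torch.int32)  # 0.5 * 2^31
out_shift = torch.tensor([0], dtype=torch.int8)

# Same formula as in quantized_linear above; note the negation, which the
# expected values in the tests below rely on.
out_scale = -out_multiplier * (1 / (1 << 31)) * (2 ** out_shift[0])
print(out_scale)  # tensor([-0.5000])

Assuming quantize_per_tensor follows the usual round(x / scale) + zero_point convention, an accumulator value x then requantizes as round(x / out_scale) + out_zero_point, clamped to [-128, 127].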
103 changes: 102 additions & 1 deletion backends/cadence/aot/tests/test_ref_implementations.py
@@ -5,15 +5,17 @@
# LICENSE file in the root directory of this source tree.

# pyre-strict

import typing
import unittest

import numpy as np
import torch

from executorch.backends.cadence.aot.ref_implementations import (
dequantize_per_tensor,
quantize_per_tensor,
quantized_add,
quantized_linear,
)
from executorch.backends.cadence.aot.typing_stubs import expand

@@ -138,3 +140,102 @@ def test_quantized_add(
torch.equal(output, expected_output),
f"Values don't match in {name}: got {output}, expected {expected_output}",
)

@expand(
[
# Test case 1: 1x2 input, 1x2 weight (1 output feature)
(
torch.Size([1, 2]), # src_shape: 1 sample, 2 input features
torch.Size([1, 2]), # weight_shape: 1 output feature, 2 input features
0, # in_zero_point
torch.tensor([0, 0], dtype=torch.int8), # weight_zero_point
torch.tensor(
[1073741824], dtype=torch.int32
), # out_multiplier (0.5 * 2^31)
torch.tensor([0], dtype=torch.int8), # out_shift
0, # out_zero_point
torch.tensor([[-2]], dtype=torch.int8), # expected_output
),
# Test case 2: 1x3 input, 2x3 weight (2 output features)
(
torch.Size([1, 3]), # src_shape: 1 sample, 3 input features
torch.Size([2, 3]), # weight_shape: 2 output features, 3 input features
0, # in_zero_point
torch.tensor([0, 0, 0], dtype=torch.int8), # weight_zero_point
torch.tensor(
[1073741824], dtype=torch.int32
), # out_multiplier (0.5 * 2^31)
torch.tensor([0], dtype=torch.int8), # out_shift
0, # out_zero_point
torch.tensor([[-10, -30]], dtype=torch.int8), # expected_output
),
# Test case 3: Batch case with different dimensions
(
torch.Size([1, 2, 2]), # src_shape: batch=1, seq=2, features=2
torch.Size([3, 2]), # weight_shape: 3 output features, 2 input features
0, # in_zero_point
torch.tensor([0, 0], dtype=torch.int8), # weight_zero_point
torch.tensor(
[1073741824], dtype=torch.int32
), # out_multiplier (0.5 * 2^31)
torch.tensor([0], dtype=torch.int8), # out_shift
0, # out_zero_point
torch.tensor(
[[[-2, -8, -14], [-6, -28, -50]]], dtype=torch.int8
), # expected_output
),
# Test case 4: Non-zero zero points
(
torch.Size([1, 2]), # src_shape: 1 sample, 2 input features
torch.Size([2, 2]), # weight_shape: 2 output features, 2 input features
2, # in_zero_point
torch.tensor([1, 1], dtype=torch.int8), # weight_zero_point
torch.tensor(
[268435456], dtype=torch.int32
), # out_multiplier (0.125 * 2^31)
torch.tensor([0], dtype=torch.int8), # out_shift
1, # out_zero_point
torch.tensor([[-15, 25]], dtype=torch.int8), # expected_output
),
]
)
def test_quantized_linear(
self,
src_shape: torch.Size,
weight_shape: torch.Size,
in_zero_point: int,
weight_zero_point: torch.Tensor,
out_multiplier: torch.Tensor,
out_shift: torch.Tensor,
out_zero_point: int,
expected_output: torch.Tensor,
) -> None:
src = (
torch.arange(np.prod(src_shape))
.reshape(src_shape)
.to(expected_output.dtype)
)
weight = (
torch.arange(np.prod(weight_shape))
.reshape(weight_shape)
.to(expected_output.dtype)
)
bias = torch.arange(weight_shape[0]).to(expected_output.dtype)
output = quantized_linear(
src,
weight,
bias,
in_zero_point,
weight_zero_point,
out_multiplier,
out_shift,
out_zero_point,
typing.cast(torch.Tensor, None),
)

self.assertEqual(output.dtype, expected_output.dtype, "Dtype mismatch")

self.assertTrue(
torch.equal(output, expected_output),
f"Values don't match: got {output}, expected {expected_output}",
)
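
As a hand check of the first parametrized case above: src is [[0, 1]], weight is [[0, 1]], and bias is [0], so the linear output is 0*0 + 1*1 + 0 = 1; with the effective scale -0.5 and zero point 0, this quantizes to round(1 / -0.5) = -2, matching expected_output. A standalone sketch of that computation (illustrative only, not part of the test file):

import torch

src = torch.tensor([[0, 1]], dtype=torch.int8)
weight = torch.tensor([[0, 1]], dtype=torch.int8)
bias = torch.tensor([0], dtype=torch.int8)

# All zero points are 0 in this case, so no shifting is needed; accumulate
# in float for this standalone check.
acc = torch.nn.functional.linear(src.float(), weight.float(), bias.float())

# Requantize with the effective scale -0.5 derived from out_multiplier.
out = torch.round(acc / -0.5).to(torch.int8)
print(out)  # tensor([[-2]], dtype=torch.int8)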