144 changes: 123 additions & 21 deletions backends/cadence/aot/ref_implementations.py
@@ -241,7 +241,7 @@ def quantized_linear_common(
bias: torch.Tensor,
in_zero_point: int,
weight_zero_point: torch.Tensor | int,
out_multiplier: torch.Tensor | int,
out_multiplier: int,
out_shift: int,
out_zero_point: int,
) -> torch.Tensor:
@@ -329,34 +329,30 @@ def variant(
assert isinstance(weight_zero_point, int)
assert isinstance(out_multiplier, int)
assert isinstance(out_shift, int)
return quantized_linear_common(
src,
weight,
bias,
in_zero_point,
weight_zero_point,
out_multiplier,
out_shift,
out_zero_point,
)
_out_shift = out_shift
_out_multiplier = out_multiplier
else:
assert isinstance(out_shift, torch.Tensor)
assert isinstance(out_multiplier, torch.Tensor)
if out_shift.numel() != 1:
raise ValueError("out_shift must be a scalar")

if out_shift.dtype != torch.int64:
raise ValueError("out_shift must be an int64")

return quantized_linear_common(
src,
weight,
bias,
in_zero_point,
weight_zero_point,
out_multiplier,
int(out_shift.item()),
out_zero_point,
)
_out_shift = int(out_shift.item())
_out_multiplier = int(out_multiplier[0].item())

return quantized_linear_common(
src,
weight,
bias,
in_zero_point,
weight_zero_point,
_out_multiplier,
_out_shift,
out_zero_point,
)

return variant

@@ -403,6 +399,112 @@ def quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor() -> torch.Tensor:
def quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor() -> torch.Tensor: ...


@impl(m, "quantized_matmul")
def quantized_matmul(
X: torch.Tensor,
X_zero_point: int,
Y: torch.Tensor,
Y_zero_point: int,
bias: torch.Tensor | None,
out_multiplier: int,
out_shift: int,
out_zero_point: int,
transposed: bool = False,
) -> torch.Tensor:
"""
Quantized matmul operation.

Args:
- X (Tensor): The activations tensor
- X_zero_point (int): The quantized mapping of zero for the input
- Y (Tensor): The weight tensor
- Y_zero_point (int): The quantized mapping of zero for the weight
- bias (Tensor): The bias tensor
- out_multiplier (int): The multiplier used to scale the output
- out_shift (int): The shift used to scale the output
- out_zero_point (int): The quantized mapping of zero for the output
- transposed (bool): Whether to transpose the weight tensor
"""
if bias is not None and not torch.all(bias == 0):
raise ValueError("bias must be None or all zeros since unused in out variant")

# Looks weird, but quantized linear assumes weights are pre-transposed,
# hence we transpose only if `transposed` is False.
if not transposed:
Y = Y.T

return quantized_linear_common(
X,
Y,
        bias if bias is not None else torch.zeros(1, dtype=torch.int32),
X_zero_point,
Y_zero_point,
out_multiplier,
out_shift,
out_zero_point,
)
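

# Illustrative sketch of calling the matmul variant with the pre-transposed
# weight convention described above. Shapes, zero points, and scaling values
# are hypothetical; the op name and argument order mirror the registration in
# this file and the calls in test_ref_implementations.py.
def _example_quantized_matmul_call() -> torch.Tensor:
    X = torch.randint(-128, 128, (2, 4), dtype=torch.int8)
    # Y is laid out [K, N] (not pre-transposed); with transposed=False the op
    # transposes it internally before the shared linear kernel runs.
    Y = torch.randint(-128, 128, (4, 3), dtype=torch.int8)
    return torch.ops.cadence.quantized_matmul_asym8sxasym8s_asym8s(
        X,
        0,  # X_zero_point
        Y,
        0,  # Y_zero_point
        None,  # bias: must be None or all zeros for this op
        1073741824,  # out_multiplier, roughly 0.5 * 2^31
        0,  # out_shift
        0,  # out_zero_point
        False,  # transposed
    )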


@impl(m, "quantized_matmul_asym8sxasym8s_asym8s")
def quantized_matmul_asym8sxasym8s_asym8s(
X: torch.Tensor,
X_zero_point: int,
Y: torch.Tensor,
Y_zero_point: int,
bias: torch.Tensor | None,
out_multiplier: int,
out_shift: int,
out_zero_point: int,
transposed: bool = False,
) -> torch.Tensor:
if X.dtype != torch.int8:
raise ValueError("X dtype must be torch.int8")
if Y.dtype != torch.int8:
raise ValueError("Y dtype must be torch.int8")

return quantized_matmul(
X,
X_zero_point,
Y,
Y_zero_point,
bias,
out_multiplier,
out_shift,
out_zero_point,
transposed,
)


@impl(m, "quantized_matmul_asym8uxasym8u_asym8u")
def quantized_matmul_asym8uxasym8u_asym8u(
X: torch.Tensor,
X_zero_point: int,
Y: torch.Tensor,
Y_zero_point: int,
bias: torch.Tensor | None,
out_multiplier: int,
out_shift: int,
out_zero_point: int,
transposed: bool = False,
) -> torch.Tensor:
if X.dtype != torch.uint8:
raise ValueError("X dtype must be torch.uint8")
if Y.dtype != torch.uint8:
raise ValueError("Y dtype must be torch.uint8")

return quantized_matmul(
X,
X_zero_point,
Y,
Y_zero_point,
bias,
out_multiplier,
out_shift,
out_zero_point,
transposed,
)
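

# Rough sketch of the output scaling implied by `out_multiplier` and `out_shift`,
# inferred from the parameter comments in the tests ("0.125 * 2^31", "shift=1,
# doubles the scale"). The actual quantized_linear_common code may differ in
# rounding and saturation details, which are omitted here.
def _example_requantize(
    acc: torch.Tensor, out_multiplier: int, out_shift: int, out_zero_point: int
) -> torch.Tensor:
    # out_multiplier is a Q31 fixed-point fraction; out_shift applies an extra
    # power-of-two scale on top of it.
    scale = (out_multiplier / (1 << 31)) * (2.0 ** out_shift)
    return (torch.round(acc.to(torch.float64) * scale) + out_zero_point).to(torch.int32)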


@impl(m, "quantized_layer_norm.per_tensor")
def quantized_layer_norm_per_tensor(
input_tensor: torch.Tensor,
125 changes: 102 additions & 23 deletions backends/cadence/aot/tests/test_ref_implementations.py
@@ -177,6 +177,8 @@ def test_quantized_add(
0, # out_zero_point
torch.tensor([[-2]], dtype=dtype), # expected_output
per_tensor,
False,
False,
)
for (per_tensor, dtype) in (
(False, torch.int8),
@@ -200,6 +202,8 @@ def test_quantized_add(
0, # out_zero_point
torch.tensor([[-10, -30]], dtype=dtype), # expected_output
per_tensor,
False,
False,
)
for (per_tensor, dtype) in (
(False, torch.int8),
@@ -225,6 +229,8 @@ def test_quantized_add(
[[[-2, -8, -14], [-6, -28, -50]]], dtype=dtype
), # expected_output
per_tensor,
False,
False,
)
for (per_tensor, dtype) in (
(False, torch.int8),
@@ -248,6 +254,8 @@ def test_quantized_add(
1, # out_zero_point
torch.tensor([[-15, 25]], dtype=dtype), # expected_output
per_tensor,
False,
False,
)
for (per_tensor, dtype) in (
(False, torch.int8),
@@ -271,6 +279,8 @@ def test_quantized_add(
1, # out_zero_point
torch.tensor([[-23, 17]], dtype=dtype), # expected_output
False,
False,
False,
)
for dtype in (torch.int8, torch.uint8)
],
@@ -292,9 +302,34 @@ def test_quantized_add(
1, # out_zero_point
torch.tensor([[-7, 13]], dtype=dtype), # expected_output
per_tensor,
False,
False,
)
for (per_tensor, dtype) in ((False, torch.int8), (True, torch.int8))
],
*[
(
torch.Size([1, 2]), # src_shape: 1 sample, 2 input features
torch.Size(
[2, 2]
), # weight_shape: 2 output features, 2 input features
2, # in_zero_point
torch.tensor([1, 1], dtype=dtype), # weight_zero_point
torch.tensor(
[268435456], dtype=torch.int32
), # out_multiplier (0.125 * 2^31)
torch.tensor(
[1], dtype=torch.int64
), # out_shift (shift=1, doubles the scale)
1, # out_zero_point
torch.tensor([[-7, 17]], dtype=dtype), # expected_output
per_tensor,
matmul,
transposed_matmul,
)
for (matmul, transposed_matmul) in ((True, False), (True, True))
for (per_tensor, dtype) in ((True, torch.int8), (True, torch.uint8))
],
]
)
def test_quantized_linear(
@@ -308,7 +343,12 @@ def test_quantized_linear(
out_zero_point: int,
expected_output: torch.Tensor,
per_tensor: bool,
matmul: bool,
transposed_matmul: bool,
) -> None:
if not per_tensor and matmul:
self.skipTest("Only per_tensor supported for matmul")

src = (
torch.arange(np.prod(src_shape))
.reshape(src_shape)
@@ -319,7 +359,9 @@ def test_quantized_linear(
.reshape(weight_shape)
.to(expected_output.dtype)
)
bias = torch.arange(weight_shape[0]).to(torch.int32)
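        # quantized_matmul transposes Y internally when `transposed` is False, so
        # pre-transpose the test weight here to keep the math equivalent to the
        # pre-transposed convention assumed by quantized_linear.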
if matmul and not transposed_matmul:
weight = weight.T

if per_tensor:
weight_zero_point = weight_zero_point[0]
out_multiplier = out_multiplier[0]
@@ -328,38 +370,75 @@ def test_quantized_linear(
if per_tensor:
match expected_output.dtype:
case torch.int8:
linear_ops = (
torch.ops.cadence.quantized_linear_asym8sxasym8s_asym8s.per_tensor,
torch.ops.cadence.quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor,
)
if matmul:
linear_ops = (
# Doesn't have per tensor name, but it is per tensor
torch.ops.cadence.quantized_matmul_asym8sxasym8s_asym8s,
)
else:
linear_ops = (
torch.ops.cadence.quantized_linear_asym8sxasym8s_asym8s.per_tensor,
torch.ops.cadence.quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor,
)
case torch.uint8:
linear_ops = (
torch.ops.cadence.quantized_linear_asym8uxasym8u_asym8u.per_tensor,
torch.ops.cadence.quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor,
)
if matmul:
linear_ops = (
torch.ops.cadence.quantized_matmul_asym8uxasym8u_asym8u,
)
else:
linear_ops = (
torch.ops.cadence.quantized_linear_asym8uxasym8u_asym8u.per_tensor,
torch.ops.cadence.quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor,
)
case _:
linear_ops = (
torch.ops.cadence.quantized_linear.per_tensor,
torch.ops.cadence.quantized_fully_connected.per_tensor,
)
if matmul:
linear_ops = (torch.ops.cadence.quantized_matmul,)
else:
linear_ops = (
torch.ops.cadence.quantized_linear.per_tensor,
torch.ops.cadence.quantized_fully_connected.per_tensor,
)
else:
linear_ops = (
torch.ops.cadence.quantized_linear,
torch.ops.cadence.quantized_fully_connected,
)

for linear_op in linear_ops:
output = linear_op(
src,
weight,
bias,
in_zero_point,
weight_zero_point,
out_multiplier,
out_shift,
out_zero_point,
typing.cast(torch.Tensor, None),
# Get the function name for linear_op for debugging
op_name = (
linear_op.__name__ if hasattr(linear_op, "__name__") else str(linear_op)
)
if matmul:
assert "quantized_matmul" in op_name
output = linear_op(
src,
in_zero_point,
weight,
weight_zero_point,
None,
out_multiplier,
out_shift,
out_zero_point,
transposed_matmul,
)
else:
assert (
"quantized_linear" in op_name
or "quantized_fully_connected" in op_name
)
bias = torch.arange(weight_shape[0]).to(torch.int32)
output = linear_op(
src,
weight,
bias,
in_zero_point,
weight_zero_point,
out_multiplier,
out_shift,
out_zero_point,
typing.cast(torch.Tensor, None),
)

self.assertTrue(output.dtype == expected_output.dtype, "Dtype mismatch")
