32 changes: 17 additions & 15 deletions backends/cadence/aot/ref_implementations.py
@@ -62,7 +62,7 @@ def quantize_per_tensor(
     ]
     if dtype not in supported_quant_types:
         raise ValueError(
-            f"Unsupported dtype to quantize to. Supported dtypes must be one of {supported_quant_types}"
+            f"Unsupported dtype to quantize to {dtype}. Supported dtypes must be one of {supported_quant_types}"
         )

     return torch.ops.quantized_decomposed.quantize_per_tensor(
@@ -264,7 +264,7 @@ def quantized_linear_common(
     supported_dtypes = [torch.int8, torch.uint8, torch.int32]
     if dtype not in supported_dtypes:
         raise ValueError(
-            f"Unsupported dtype to quantize to. Supported dtypes must be one of {supported_dtypes}"
+            f"Unsupported dtype to quantize to {dtype}. Supported dtypes must be one of {supported_dtypes}"
         )

     out = torch.nn.functional.linear(
@@ -427,25 +427,27 @@ def quantized_matmul(
         - out_multiplier (int): The multiplier used to scale the output
         - out_shift (int): The shift used to scale the output
         - out_zero_point (int): The quantized mapping of zero for the output
-        - transposed (bool): Whether to transpose the weight tensor
+        - transposed (bool): Whether Y is transposed.
     """
     if bias is not None and not torch.all(bias == 0):
         raise ValueError("bias must be None or all zeros since unused in out variant")

-    # Looks weird, but quantized linear assumes weights are pre-transposed,
-    # hence we transpose only if `transposed` is False.
-    if not transposed:
-        Y = Y.T
+    if transposed:
+        Y = Y.transpose(-1, -2)

-    return quantized_linear_common(
-        X,
-        Y,
-        bias or torch.zeros(1, dtype=torch.int32),
-        X_zero_point,
-        Y_zero_point,
-        out_multiplier,
-        out_shift,
+    out_scale = 1.0 / (-out_multiplier * (1 / (1 << 31)) * (2**out_shift))
+
+    out = torch.matmul(
+        (X - X_zero_point).float(),
+        (Y - Y_zero_point).float(),
+    )
+    return quantize_per_tensor(
+        out,
+        out_scale,
         out_zero_point,
+        torch.iinfo(X.dtype).min,
+        torch.iinfo(X.dtype).max,
+        X.dtype,
     )


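The move from `Y.T` to `Y.transpose(-1, -2)` is what makes the batched case work: `Tensor.T` reverses every dimension (and is deprecated for tensors whose ndim is not 2), while `transpose(-1, -2)` swaps only the trailing matrix dimensions. A quick sketch with made-up shapes:

```python
import torch

Y = torch.zeros(2, 3, 4)  # batched weights: (batch, N, K)

# transpose(-1, -2) swaps only the last two dims -> torch.Size([2, 4, 3])
print(Y.transpose(-1, -2).shape)

# Y.T would reverse *all* dims to (4, 3, 2), and PyTorch warns that .T is
# deprecated for ndim != 2, which is why the batched path avoids it.
```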
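Taken at face value, the replacement body dequantizes both operands (zero-point shift only, with the scales folded into `out_scale`), matmuls in float, and requantizes via `quantize_per_tensor`, which computes round(x / scale) + zero_point clamped to the dtype range. With the constants the new test uses, out_multiplier = 268435456 (0.125 * 2^31) and out_shift = 1, the magnitude of `-out_multiplier * (1 / (1 << 31)) * (2**out_shift)` is 0.25, so out_scale has magnitude 4.0; the negated multiplier presumably mirrors the existing convention in `quantized_linear_common`. Below is a self-contained sketch of the flow; the helper name and tensor values are illustrative, not from this PR:

```python
import torch

def matmul_requantize_sketch(
    X: torch.Tensor,     # quantized activations, e.g. int8
    X_zero_point: int,
    Y: torch.Tensor,     # quantized weights, already laid out as (..., K, N)
    Y_zero_point: int,
    out_scale: float,    # quantization scale of the output
    out_zero_point: int,
) -> torch.Tensor:
    # Shift out the zero points and accumulate in float...
    out = torch.matmul((X - X_zero_point).float(), (Y - Y_zero_point).float())
    # ...then requantize: round(x / scale) + zero_point, clamped to the dtype range.
    info = torch.iinfo(X.dtype)
    q = torch.round(out / out_scale) + out_zero_point
    return q.clamp(info.min, info.max).to(X.dtype)

# Batched shapes matching the new test case: (2, 1, 2) @ (2, 2, 2) -> (2, 1, 2)
X = torch.tensor([[[2, 3]], [[4, 5]]], dtype=torch.int8)
Y = torch.tensor([[[1, 0], [0, 1]], [[1, 1], [1, 1]]], dtype=torch.int8)
print(matmul_requantize_sketch(X, 0, Y, 0, out_scale=1.0, out_zero_point=0))
# tensor([[[2, 3]], [[9, 9]]], dtype=torch.int8)
```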
25 changes: 24 additions & 1 deletion backends/cadence/aot/tests/test_ref_implementations.py
@@ -350,6 +350,29 @@ def test_quantized_add(
             for (matmul, transposed_matmul) in ((True, False), (True, True))
             for (per_tensor, dtype) in ((True, torch.int8),)
         ],
+        *[
+            (
+                torch.Size([2, 1, 2]),  # src_shape: batch of 2, 1 sample, 2 input features
+                torch.Size(
+                    [2, 2, 2]
+                ),  # weight_shape: batch of 2, 2 output features, 2 input features
+                2,  # in_zero_point
+                torch.tensor([1, 1], dtype=dtype),  # weight_zero_point
+                torch.tensor(
+                    [268435456], dtype=torch.int32
+                ),  # out_multiplier (0.125 * 2^31)
+                torch.tensor(
+                    [1], dtype=torch.int32
+                ),  # out_shift (shift=1, doubles the scale)
+                1,  # out_zero_point
+                torch.tensor([[[1, 2]], [[0, -1]]], dtype=dtype),  # expected_output
+                per_tensor,
+                matmul,
+                transposed_matmul,
+            )
+            for (matmul, transposed_matmul) in ((True, False), (True, True))
+            for (per_tensor, dtype) in ((True, torch.int8),)
+        ],
     ]
 )
 def test_quantized_linear(
@@ -380,7 +403,7 @@ def test_quantized_linear(
             .to(expected_output.dtype)
         )
         if matmul and not transposed_matmul:
-            weight = weight.T
+            weight = weight.transpose(-1, -2)

         if per_tensor:
             weight_zero_point = weight_zero_point[0]
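The added parametrization exercises the batched path end to end: a source of shape (2, 1, 2) against weights stored as (2, 2, 2) in (batch, out_features, in_features) layout, with an effective requantization multiplier of 0.125 * 2 = 0.25 per the out_multiplier/out_shift comments. When `transposed_matmul` is False, the test now pre-transposes the weight with `transpose(-1, -2)` (where `.T` would also have reversed the batch dimension) so that `quantized_matmul` receives a (batch, K, N) operand. A shape check under those layout assumptions:

```python
import torch

X = torch.empty(2, 1, 2)  # (batch, M, K)
W = torch.empty(2, 2, 2)  # (batch, N, K): linear-style weight layout

# Pre-transposing the last two dims yields (batch, K, N), so matmul works:
print(torch.matmul(X, W.transpose(-1, -2)).shape)  # torch.Size([2, 1, 2])
```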