1 change: 1 addition & 0 deletions backends/cadence/aot/TARGETS
@@ -130,6 +130,7 @@ runtime.python_library(
deps = [
"fbcode//caffe2:torch",
"fbcode//executorch/exir:scalar_type",
"fbcode//executorch/kernels/quantized:custom_ops_generated_lib",
],
)

88 changes: 52 additions & 36 deletions backends/cadence/aot/ref_implementations.py
@@ -6,16 +6,17 @@

# pyre-strict


from typing import Callable

import torch
import torch.nn as nn
import torch.nn.functional as F

from executorch.exir.scalar_type import ScalarType
from torch.library import impl, Library


m = Library("cadence", "IMPL", "CompositeExplicitAutograd")
torch.ops.load_library("//executorch/kernels/quantized:custom_ops_generated_lib")

qdtype_map: dict[ScalarType, torch.dtype] = {
ScalarType.QINT8: torch.qint8,
@@ -38,7 +39,7 @@ def quantize_per_tensor(

Args:
- input_tensor (Tensor): input tensor
- scale (float): Inverse of quantization scale. Derived from the ratio
- scale (float): Quantization scale. Derived from the ratio
between the min/max of the floating-point tensor and the
min/max of the quantized range, and then inverted.
- zero_point (int): The point which represents 0 in the quantized
@@ -64,10 +65,13 @@
f"Unsupported dtype to quantize to. Supported dtypes must be one of {supported_quant_types}"
)

quantized = torch.round(input_tensor * scale + zero_point).to(dtype)
return torch.max(
torch.min(quantized, torch.tensor(quant_max)),
torch.tensor(quant_min),
return torch.ops.quantized_decomposed.quantize_per_tensor(
input_tensor,
scale,
zero_point,
quant_min,
quant_max,
dtype,
)
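# For reference (not part of this diff): the decomposed op that quantize_per_tensor
# now delegates to performs standard affine quantization, where `scale` is the true
# quantization step rather than its inverse. A minimal sketch of the intended
# semantics, assuming the usual round-then-clamp formulation (the kernel's exact
# rounding/overflow behavior may differ):

def quantize_sketch(x, scale, zero_point, quant_min, quant_max, dtype):
    # q = clamp(round(x / scale) + zero_point, quant_min, quant_max)
    q = torch.round(x / scale) + zero_point
    return torch.clamp(q, quant_min, quant_max).to(dtype)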


@@ -97,7 +101,7 @@ def dequantize_per_tensor(
is already provided.
- quant_max (int): The largest value in the quantized domain. Unused since scale
is already provided.
- dtype (torch.dtype): The type of the output tensor. Must be a floating point type.
- dtype (torch.dtype): The type of the input tensor.
"""
supported_quant_types = [
torch.int8,
Expand All @@ -108,23 +112,15 @@ def dequantize_per_tensor(
]
if input_tensor.dtype not in supported_quant_types:
raise ValueError(f"Input dtype must be one of {supported_quant_types}")
supported_dequant_types = [
torch.float,
torch.float32,
torch.float16,
torch.bfloat16,
]
if dtype not in supported_dequant_types:
raise ValueError(
f"Unsupported dtype to dequantize to. Supported dtypes must be one of {supported_dequant_types}"
)

# Needed to prevent underflow in cases where the zero_point is larger than
# the quantized value.
if not input_tensor.dtype.is_signed:
input_tensor = input_tensor.to(torch.int32)

return (input_tensor - zero_point).to(dtype) * scale
if input_tensor.dtype != dtype:
raise ValueError("Input dtype must match dtype")

# Use the reference implementation from torch quantized_decomposed library
# Unlike quantize_per_tensor, dequantize_per_tensor doesn't have a behavior
# difference, since there's no rounding algorithm (just arithmetic).
return torch.ops.quantized_decomposed.dequantize_per_tensor(
input_tensor, scale, zero_point, quant_min, quant_max, dtype
)
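# For reference (not part of this diff): dequantization is the straight affine
# inverse, so delegating it is purely a refactor. A minimal sketch, assuming the
# usual formulation; note that the `dtype` argument now names the *input* quantized
# dtype, per the updated docstring above.

def dequantize_sketch(q, scale, zero_point):
    # e.g. scale=0.1, zero_point=10, q=25  ->  (25 - 10) * 0.1 = 1.5
    return (q.to(torch.float32) - zero_point) * scale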


@impl(m, "quantized_add.per_tensor")
@@ -180,12 +176,10 @@ def quantized_add_per_tensor(
dequant_X = X_scale * (X - X_zero_point)
dequant_Y = Y_scale * (Y - Y_zero_point)

out_scale_inv = 1 / out_scale

# q_min/q_max are unused args
return quantize_per_tensor(
dequant_X + dequant_Y,
out_scale_inv,
out_scale,
out_zero_point,
torch.iinfo(dtype).min,
torch.iinfo(dtype).max,
@@ -259,8 +253,7 @@ def quantized_linear_common(
- out_zero_point (int): The quantized mapping of zero for the output
- offset (Tensor): Unused
"""
out_scale = -out_multiplier * (1 / (1 << 31)) * (2**out_shift)
out_scale_inv = 1 / out_scale
out_scale = 1.0 / (-out_multiplier * (1 / (1 << 31)) * (2**out_shift))

N, K = weight.shape

@@ -281,7 +274,7 @@
)
return quantize_per_tensor(
out,
out_scale_inv,
out_scale,
out_zero_point,
torch.iinfo(dtype).min,
torch.iinfo(dtype).max,
@@ -399,6 +392,17 @@ def quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor() -> torch.Tensor:
def quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor() -> torch.Tensor: ...


@impl(m, "fully_connected")
def fully_connected(
input_tensor: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
) -> torch.Tensor:
if input_tensor.shape[0] != 1:
raise ValueError("Fully connected linear only supports batch size of 1")
return F.linear(input_tensor, weight, bias)
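# For reference (not part of this diff): a hedged usage sketch of the new op. It
# assumes the `fully_connected` schema is registered under the `cadence` namespace,
# as the tests do for the other ops in this file; the batch dimension must be 1.

x = torch.randn(1, 16)   # single sample, 16 input features
w = torch.randn(8, 16)   # 8 output features, 16 input features
b = torch.zeros(8)
out = torch.ops.cadence.fully_connected(x, w, b)  # matches F.linear(x, w, b)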


@impl(m, "quantized_matmul")
def quantized_matmul(
X: torch.Tensor,
@@ -538,15 +542,15 @@ def quantized_layer_norm_per_tensor(
)

float_input_tensor = dequantize_per_tensor(
input_tensor, X_scale, X_zero_point, -128, 127, torch.float32
input_tensor, X_scale, X_zero_point, -128, 127, input_tensor.dtype
)
out = torch.nn.functional.layer_norm(
float_input_tensor, normalized_shape, weight, bias, eps=eps
)

return quantize_per_tensor(
out,
1 / output_scale,
output_scale,
output_zero_point,
torch.iinfo(input_tensor.dtype).min,
torch.iinfo(input_tensor.dtype).max,
@@ -615,7 +619,7 @@ def quantized_conv_per_tensor(

return quantize_per_tensor(
float_out,
1.0 / output_scale,
output_scale,
output_zero_point,
torch.iinfo(input_tensor.dtype).min,
torch.iinfo(input_tensor.dtype).max,
@@ -942,8 +946,10 @@ def quantized_relu_common(
if X.dtype not in supported_dtypes:
raise ValueError(f"X dtype must be one of {supported_dtypes}. Got {X.dtype}")

out_scale = -out_multiplier * (1 / (1 << 31)) * (2**out_shift)
dequantized_X = torch.where(X > X_zero_point, X - X_zero_point, torch.zeros_like(X))
out_scale = 1.0 / (-out_multiplier * (1 / (1 << 31)) * (2**out_shift))
dequantized_X = torch.where(
X > X_zero_point, X - X_zero_point, torch.zeros_like(X)
).to(torch.float32)
return quantize_per_tensor(
dequantized_X,
out_scale,
@@ -1068,3 +1074,13 @@ def requantize(
out_quant_max,
dtype,
)


@impl(m, "rms_norm")
def rms_norm(
X: torch.Tensor,
normalized_shape: tuple[int],
W: torch.Tensor,
eps: float,
) -> torch.Tensor:
return W * nn.RMSNorm(list(normalized_shape), eps=eps, dtype=X.dtype)(X)
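# For reference (not part of this diff): a freshly constructed nn.RMSNorm has its
# internal weight initialized to ones, so the implementation above is equivalent to
# normalizing by the root-mean-square over the normalized dimensions and then
# applying W. A minimal sketch, assuming X's trailing dimensions match
# normalized_shape:

def rms_norm_sketch(X, normalized_shape, W, eps):
    dims = tuple(range(-len(normalized_shape), 0))
    return W * X * torch.rsqrt(X.pow(2).mean(dim=dims, keepdim=True) + eps)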
65 changes: 49 additions & 16 deletions backends/cadence/aot/tests/test_ref_implementations.py
@@ -36,12 +36,11 @@ def test_quantize_per_tensor(
) -> None:
input_tensor = torch.tensor([input_value])
scale = (f_max - f_min) / (q_max - q_min)
inv_scale = 1.0 / scale
zero_point = round(-f_min * inv_scale) + q_min
zero_point = round(-f_min * 1 / scale) + q_min
expected_output = torch.tensor([expected_value], dtype=target_dtype)

output = torch.ops.cadence.quantize_per_tensor(
input_tensor, inv_scale, zero_point, q_min, q_max, target_dtype
input_tensor, scale, zero_point, q_min, q_max, target_dtype
)

self.assertEqual(
@@ -85,7 +84,7 @@ def test_dequantize_per_tensor(
expected_output = torch.tensor([expected_value], dtype=torch.float32)

output = torch.ops.cadence.dequantize_per_tensor(
input_tensor, scale, zero_point, q_min, q_max, torch.float32
input_tensor, scale, zero_point, q_min, q_max, input_tensor.dtype
)

self.assertEqual(
@@ -175,7 +174,7 @@ def test_quantized_add(
), # out_multiplier (0.5 * 2^31)
torch.tensor([0], dtype=torch.int64), # out_shift
0, # out_zero_point
torch.tensor([[-2]], dtype=dtype), # expected_output
torch.tensor([[0]], dtype=dtype), # expected_output
per_tensor,
False,
False,
Expand All @@ -200,14 +199,36 @@ def test_quantized_add(
), # out_multiplier (0.5 * 2^31)
torch.tensor([0], dtype=torch.int64), # out_shift
0, # out_zero_point
torch.tensor([[-10, -30]], dtype=dtype), # expected_output
torch.tensor([[-2, -8]], dtype=dtype), # expected_output
per_tensor,
False,
False,
)
for (per_tensor, dtype) in (
(False, torch.int8),
(True, torch.int8),
)
],
*[
(
torch.Size([1, 3]), # src_shape: 1 sample, 3 input features
torch.Size(
[2, 3]
), # weight_shape: 2 output features, 3 input features
0, # in_zero_point
torch.tensor([0, 0, 0], dtype=dtype), # weight_zero_point
torch.tensor(
[1073741824], dtype=torch.int32
), # out_multiplier (0.5 * 2^31)
torch.tensor([0], dtype=torch.int64), # out_shift
0, # out_zero_point
torch.tensor([[0, 0]], dtype=dtype), # expected_output
per_tensor,
False,
False,
)
for (per_tensor, dtype) in (
(False, torch.uint8),
(True, torch.uint8),
)
],
@@ -226,7 +247,7 @@ def test_quantized_add(
torch.tensor([0], dtype=torch.int64), # out_shift
0, # out_zero_point
torch.tensor(
[[[-2, -8, -14], [-6, -28, -50]]], dtype=dtype
[[[0, -2, -4], [-2, -7, -12]]], dtype=dtype
), # expected_output
per_tensor,
False,
Expand All @@ -235,7 +256,6 @@ def test_quantized_add(
for (per_tensor, dtype) in (
(False, torch.int8),
(True, torch.int8),
(True, torch.uint8),
)
],
# Test case 4: Non-zero zero points
@@ -252,15 +272,15 @@
), # out_multiplier (1.0 * 2^31)
torch.tensor([0], dtype=torch.int64), # out_shift
1, # out_zero_point
torch.tensor([[-15, 25]], dtype=dtype), # expected_output
torch.tensor([[1, 1]], dtype=dtype), # expected_output
per_tensor,
False,
False,
)
for (per_tensor, dtype) in (
(False, torch.int8),
(True, torch.int8),
(True, torch.uint8),
# (True, torch.uint8),
)
],
# Test case 5: Non-uniform weight zero points
@@ -277,12 +297,12 @@
), # out_multiplier (1.0 * 2^31)
torch.tensor([0], dtype=torch.int64), # out_shift
1, # out_zero_point
torch.tensor([[-23, 17]], dtype=dtype), # expected_output
torch.tensor([[1, 1]], dtype=dtype), # expected_output
False,
False,
False,
)
for dtype in (torch.int8, torch.uint8)
for dtype in (torch.int8,)
],
# Test case 6: Non-zero out_shift (shift=1)
*[
Expand All @@ -300,7 +320,7 @@ def test_quantized_add(
[1], dtype=torch.int64
), # out_shift (shift=1, doubles the scale)
1, # out_zero_point
torch.tensor([[-7, 13]], dtype=dtype), # expected_output
torch.tensor([[1, 2]], dtype=dtype), # expected_output
per_tensor,
False,
False,
@@ -322,13 +342,13 @@
[1], dtype=torch.int64
), # out_shift (shift=1, doubles the scale)
1, # out_zero_point
torch.tensor([[-7, 17]], dtype=dtype), # expected_output
torch.tensor([[1, 2]], dtype=dtype), # expected_output
per_tensor,
matmul,
transposed_matmul,
)
for (matmul, transposed_matmul) in ((True, False), (True, True))
for (per_tensor, dtype) in ((True, torch.int8), (True, torch.uint8))
for (per_tensor, dtype) in ((True, torch.int8),)
],
]
)
@@ -1045,7 +1065,20 @@ def test_quantized_conv_per_tensor(
[4, 2, 0, -2], dtype=dtype
), # expected: relu(1,3,5,7) = (1,3,5,7) * (-1.0) + 5 = (4,2,0,-2)
)
for dtype in [torch.int8, torch.uint8]
for dtype in [torch.int8]
],
*[
(
"positive_with_shift_unsigned",
torch.tensor([2, 4, 6, 8], dtype=dtype), # input
1, # X_zero_point
5, # out_zero_point
1073741824, # out_multiplier (0.5 * 2^31)
1, # out_shift (multiply by 2^1 = 2)
dtype, # dtype
torch.tensor([4, 2, 0, 0], dtype=dtype),
)
for dtype in [torch.uint8]
],
# Test case 4: Non-per-tensor
*[