84 changes: 78 additions & 6 deletions backends/cadence/aot/ref_implementations.py
@@ -748,13 +748,12 @@ def quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor() -> torch.Tens
def quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...


@impl(m, "quantized_relu")
def quantized_relu(
def quantized_relu_common(
X: torch.Tensor,
X_zero_point: torch.Tensor,
X_zero_point: torch.Tensor | int,
out_zero_point: int,
out_multiplier: torch.Tensor,
out_shift: torch.Tensor,
out_multiplier: int,
out_shift: int,
) -> torch.Tensor:
"""
Quantized ReLU operation followed by requantization.
@@ -770,7 +769,7 @@ def quantized_relu(
if X.dtype not in supported_dtypes:
raise ValueError(f"X dtype must be one of {supported_dtypes}. Got {X.dtype}")

out_scale = -out_multiplier * (1 / (1 << 31)) * (2 ** out_shift[0])
out_scale = -out_multiplier * (1 / (1 << 31)) * (2**out_shift)
dequantized_X = torch.where(X > X_zero_point, X - X_zero_point, torch.zeros_like(X))
return quantize_per_tensor(
dequantized_X,
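
Note on the scale computation above: per the new tests, the effective behaviour is a ReLU relative to X_zero_point, multiplication by -out_multiplier / 2**31 * 2**out_shift, adding out_zero_point, then rounding and clamping. The following is a minimal plain-PyTorch sketch of that math, not the kernel's quantize_per_tensor path; the round-half-to-even and clamp-to-dtype-range behaviour shown here is an assumption.

import torch

def relu_requantize_reference(
    X: torch.Tensor,
    X_zero_point: torch.Tensor | int,
    out_zero_point: int,
    out_multiplier: int,  # Q31 fixed point, e.g. 1073741824 ~= 0.5 * 2**31
    out_shift: int,
) -> torch.Tensor:
    # Effective (negative) requantization scale, matching the line above.
    out_scale = -out_multiplier * (1 / (1 << 31)) * (2**out_shift)
    # ReLU relative to the input zero point.
    relu_X = torch.where(X > X_zero_point, X - X_zero_point, torch.zeros_like(X))
    # Scale, shift by the output zero point, round, and clamp to the dtype range.
    info = torch.iinfo(X.dtype)
    out = torch.round(relu_X.float() * out_scale + out_zero_point)
    return out.clamp(info.min, info.max).to(X.dtype)

# Mirrors the "basic_int8" test case below:
# relu_requantize_reference(torch.tensor([-1, 0, 1, 3], dtype=torch.int8), 0, 0, 1073741824, 0)
# -> tensor([0, 0, 0, -2], dtype=torch.int8)
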
@@ -782,6 +781,79 @@ def quantized_relu(
)


def quantized_relu_variant(
per_tensor: bool,
dtype: torch.dtype | None = None,
) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]:
"""Create a quantized relu variant with type checking."""

def decorator(_: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
def variant(
X: torch.Tensor,
X_zero_point: torch.Tensor | int,
out_zero_point: int,
out_multiplier: torch.Tensor | int,
out_shift: torch.Tensor | int,
) -> torch.Tensor:
if per_tensor:
if dtype and X.dtype != dtype:
raise ValueError(f"X dtype must be {dtype}. Got {X.dtype}")

assert isinstance(out_shift, int)
assert isinstance(out_multiplier, int)
_out_shift = out_shift
_out_multiplier = out_multiplier
else:
assert isinstance(out_multiplier, torch.Tensor)
if out_multiplier.numel() > 1:
raise ValueError("Only scalar out_multiplier is supported")

assert isinstance(out_shift, torch.Tensor)
if out_shift.numel() > 1:
raise ValueError("Only scalar out_shift is supported")

assert isinstance(X_zero_point, torch.Tensor)
if X_zero_point.shape != X.shape:
raise ValueError(
f"X_zero_point shape must be {X.shape}. Got {X_zero_point.shape}"
)

_out_multiplier = int(out_multiplier.item())
_out_shift = int(out_shift.item())

return quantized_relu_common(
X,
X_zero_point,
out_zero_point,
_out_multiplier,
_out_shift,
)

return variant

return decorator


@impl(m, "quantized_relu")
@quantized_relu_variant(False)
def quantized_relu() -> torch.Tensor: ...


@impl(m, "quantized_relu.per_tensor")
@quantized_relu_variant(True)
def quantized_relu_per_tensor() -> torch.Tensor: ...


@impl(m, "quantized_relu_asym8s_asym8s.per_tensor")
@quantized_relu_variant(True, torch.int8)
def quantized_relu_asym8s_asym8s_per_tensor() -> torch.Tensor: ...


@impl(m, "quantized_relu_asym8u_asym8u.per_tensor")
@quantized_relu_variant(True, torch.uint8)
def quantized_relu_asym8u_asym8u_per_tensor() -> torch.Tensor: ...


@impl(m, "requantize")
def requantize(
input: torch.Tensor,
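
For reference, the per-tensor variants registered above take plain ints for X_zero_point, out_multiplier, and out_shift. A hypothetical call to the int8 variant, reusing the values of the basic_int8 test case below, would look like the following sketch; it assumes the Cadence AoT op registrations are loaded so that torch.ops.cadence.* resolves.

import torch

X = torch.tensor([-1, 0, 1, 3], dtype=torch.int8)
out = torch.ops.cadence.quantized_relu_asym8s_asym8s.per_tensor(
    X,
    0,           # X_zero_point (plain int for per-tensor variants)
    0,           # out_zero_point
    1073741824,  # out_multiplier (Q31 fixed point, ~0.5)
    0,           # out_shift
)
# Expected per the new tests: tensor([0, 0, 0, -2], dtype=torch.int8)
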
159 changes: 105 additions & 54 deletions backends/cadence/aot/tests/test_ref_implementations.py
@@ -884,73 +884,124 @@ def test_quantized_conv_per_tensor(
@expand(
[
# Test case 1: Basic int8 case with negative scale
(
"basic_int8",
torch.tensor([-1, 0, 1, 3], dtype=torch.int8), # input
torch.tensor([0], dtype=torch.int8), # X_zero_point (scalar broadcast)
0, # out_zero_point
torch.tensor([1073741824]), # out_multiplier (0.5 * 2^31)
torch.tensor([0]), # out_shift
torch.int8, # dtype
torch.tensor(
[0, 0, 0, -2], dtype=torch.int8
), # expected: relu(-1,0,1,3) = (0,0,1,3) * (-0.5) + 0 = (0,0,-0.5,-1.5) -> (0,0,0,-2)
),
*[
(
"basic_int8",
torch.tensor([-1, 0, 1, 3], dtype=dtype), # input
0, # X_zero_point (scalar broadcast)
0, # out_zero_point
1073741824, # out_multiplier (0.5 * 2^31)
0, # out_shift
dtype, # dtype
torch.tensor(
[0, 0, 0, -2], dtype=dtype
), # expected: relu(-1,0,1,3) = (0,0,1,3) * (-0.5) + 0 = (0,0,-0.5,-1.5) -> (0,0,0,-2)
)
for dtype in [torch.int8]
],
# Test case 2: uint8 with non-zero zero point
(
"uint8_with_zp",
torch.tensor([126, 128, 130, 132], dtype=torch.uint8), # input
torch.tensor([128], dtype=torch.uint8), # X_zero_point
64, # out_zero_point
torch.tensor([536870912]), # out_multiplier (0.25 * 2^31)
torch.tensor([0]), # out_shift
torch.uint8, # dtype
torch.tensor(
[64, 64, 64, 63], dtype=torch.uint8
), # expected: relu(-2,0,2,4) = (0,0,2,4) * (-0.25) + 64 = (64,64,63.5,63) -> (64,64,64,63)
),
*[
(
"uint8_with_zp",
torch.tensor([126, 128, 130, 132], dtype=dtype), # input
128, # X_zero_point
64, # out_zero_point
536870912, # out_multiplier (0.25 * 2^31)
0, # out_shift
dtype, # dtype
torch.tensor(
[64, 64, 64, 63], dtype=dtype
), # expected: relu(-2,0,2,4) = (0,0,2,4) * (-0.25) + 64 = (64,64,63.5,63) -> (64,64,64,63)
)
for dtype in [torch.uint8]
],
# Test case 3: All negative values (should all become zero after ReLU)
(
"all_negative_int8",
torch.tensor([-5, -3, -1], dtype=torch.int8), # input
torch.tensor([0], dtype=torch.int8), # X_zero_point
10, # out_zero_point
torch.tensor([1073741824]), # out_multiplier (0.5 * 2^31)
torch.tensor([0]), # out_shift
torch.int8, # dtype
torch.tensor(
[10, 10, 10], dtype=torch.int8
), # expected: relu(-5,-3,-1) = (0,0,0) * (-0.5) + 10 = (10,10,10)
),
*[
(
"all_negative_int8",
torch.tensor([-5, -3, -1], dtype=dtype), # input
0, # X_zero_point
10, # out_zero_point
1073741824, # out_multiplier (0.5 * 2^31)
0, # out_shift
dtype, # dtype
torch.tensor(
[10, 10, 10], dtype=dtype
), # expected: relu(-5,-3,-1) = (0,0,0) * (-0.5) + 10 = (10,10,10)
)
for dtype in [torch.int8]
],
# Test case 4: All positive values with shift (scale becomes -0.25)
(
"positive_with_shift",
torch.tensor([2, 4, 6, 8], dtype=torch.int8), # input
torch.tensor([1], dtype=torch.int8), # X_zero_point
5, # out_zero_point
torch.tensor([1073741824]), # out_multiplier (0.5 * 2^31)
torch.tensor([1]), # out_shift (multiply by 2^1 = 2)
torch.int8, # dtype
torch.tensor(
[4, 2, 0, -2], dtype=torch.int8
), # expected: relu(1,3,5,7) = (1,3,5,7) * (-1.0) + 5 = (4,2,0,-2)
),
*[
(
"positive_with_shift",
torch.tensor([2, 4, 6, 8], dtype=dtype), # input
1, # X_zero_point
5, # out_zero_point
1073741824, # out_multiplier (0.5 * 2^31)
1, # out_shift (multiply by 2^1 = 2)
dtype, # dtype
torch.tensor(
[4, 2, 0, -2], dtype=dtype
), # expected: relu(1,3,5,7) = (1,3,5,7) * (-1.0) + 5 = (4,2,0,-2)
)
for dtype in [torch.int8, torch.uint8]
],
# Test case 5: Non-per-tensor (tensor zero point, multiplier, and shift)
*[
(
"non_per_tensor",
torch.tensor([-1, -2, -3, 1, 2, 3], dtype=dtype), # input
torch.tensor([0, 0, 0, 1, 1, 1]), # X_zero_point
5, # out_zero_point
torch.tensor([1073741824]), # out_multiplier (0.5 * 2^31)
torch.tensor([1]), # out_shift (multiply by 2^1 = 2)
dtype, # dtype
torch.tensor([5, 5, 5, 5, 4, 3], dtype=dtype),
)
for dtype in [torch.int8]
],
]
)
def test_quantized_relu(
self,
name: str,
X: torch.Tensor,
X_zero_point: torch.Tensor,
X_zero_point: torch.Tensor | int,
out_zero_point: int,
out_multiplier: torch.Tensor,
out_shift: torch.Tensor,
out_multiplier: torch.Tensor | int,
out_shift: torch.Tensor | int,
dtype: torch.dtype,
expected_output: torch.Tensor,
) -> None:
output = torch.ops.cadence.quantized_relu(
X, X_zero_point, out_zero_point, out_multiplier, out_shift
)

if isinstance(X_zero_point, int):
assert isinstance(out_multiplier, int)
assert isinstance(out_shift, int)

match dtype:
case torch.int8:
quantized_relu = (
torch.ops.cadence.quantized_relu_asym8s_asym8s.per_tensor
)
case torch.uint8:
quantized_relu = (
torch.ops.cadence.quantized_relu_asym8u_asym8u.per_tensor
)
case _:
quantized_relu = torch.ops.cadence.quantized_relu_per_tensor

output = quantized_relu(
X,
X_zero_point,
out_zero_point,
out_multiplier,
out_shift,
)
else:
output = torch.ops.cadence.quantized_relu(
X, X_zero_point, out_zero_point, out_multiplier, out_shift
)

# Verify output properties
self.assertEqual(output.dtype, dtype, f"Output dtype should be {dtype}")
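
The default quantized_relu overload keeps the tensor arguments and supports an elementwise X_zero_point. Mirroring the non_per_tensor test case above, a call would look like the following sketch (again assuming the Cadence op registrations are loaded).

import torch

X = torch.tensor([-1, -2, -3, 1, 2, 3], dtype=torch.int8)
out = torch.ops.cadence.quantized_relu(
    X,
    torch.tensor([0, 0, 0, 1, 1, 1]),  # X_zero_point, one per element
    5,                                 # out_zero_point
    torch.tensor([1073741824]),        # out_multiplier (scalar tensor, ~0.5 in Q31)
    torch.tensor([1]),                 # out_shift (doubles the scale: 2**1)
)
# Expected per the new tests: tensor([5, 5, 5, 5, 4, 3], dtype=torch.int8)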