Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions backends/cadence/aot/ref_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,48 @@ def quantized_add_per_tensor(
)


@impl(m, "quantized_add_asym8sxasym8s_asym8s.per_tensor")
def quantized_add_asym8sxasym8s_asym8s_per_tensor(
    X: torch.Tensor,
    X_scale: float,
    X_zero_point: int,
    Y: torch.Tensor,
    Y_scale: float,
    Y_zero_point: int,
    out_scale: float,
    out_zero_point: int,
) -> torch.Tensor:
    """Add two asym8s (signed int8) per-tensor-quantized tensors.

    Validates that both operands are ``torch.int8`` and then delegates the
    actual quantized addition to the generic ``quantized_add_per_tensor``
    reference implementation.

    Raises:
        ValueError: if either operand is not ``torch.int8``.
    """
    # Enforce the asym8s operand contract of this typed variant before
    # handing off to the dtype-agnostic reference op.
    for operand, label in ((X, "X"), (Y, "Y")):
        if operand.dtype != torch.int8:
            raise ValueError(f"{label} dtype must be torch.int8")

    return quantized_add_per_tensor(
        X, X_scale, X_zero_point, Y, Y_scale, Y_zero_point, out_scale, out_zero_point
    )


@impl(m, "quantized_add_asym8uxasym8u_asym8u.per_tensor")
def quantized_add_asym8uxasym8u_asym8u_per_tensor(
    X: torch.Tensor,
    X_scale: float,
    X_zero_point: int,
    Y: torch.Tensor,
    Y_scale: float,
    Y_zero_point: int,
    out_scale: float,
    out_zero_point: int,
) -> torch.Tensor:
    """Add two asym8u (unsigned uint8) per-tensor-quantized tensors.

    Validates that both operands are ``torch.uint8`` and then delegates the
    actual quantized addition to the generic ``quantized_add_per_tensor``
    reference implementation.

    Raises:
        ValueError: if either operand is not ``torch.uint8``.
    """
    # Bug fix: the error messages previously said "torch.int8" (copy-pasted
    # from the signed variant) even though the checks require torch.uint8.
    if X.dtype != torch.uint8:
        raise ValueError("X dtype must be torch.uint8")
    if Y.dtype != torch.uint8:
        raise ValueError("Y dtype must be torch.uint8")

    return quantized_add_per_tensor(
        X, X_scale, X_zero_point, Y, Y_scale, Y_zero_point, out_scale, out_zero_point
    )


def quantized_linear_common(
src: torch.Tensor,
weight: torch.Tensor,
Expand Down
23 changes: 22 additions & 1 deletion backends/cadence/aot/tests/test_ref_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def test_dequantize_per_tensor(
[
# Only these types need to be tested as per ET_FORALL_JARVIS_QUANTIZED_TYPES in
# on_device_ai/Assistant/Jarvis/min_runtime/operators/generic/operators.h
("int16", 5, 0.8, 4, 5, 0.8, 4, 0.8, 4, 6, torch.int8),
("int8", 5, 0.8, 4, 5, 0.8, 4, 0.8, 4, 6, torch.int8),
("uint8", 5, 0.8, 4, 5, 0.8, 4, 0.8, 4, 6, torch.uint8),
]
)
Expand All @@ -122,6 +122,27 @@ def test_quantized_add(
Y_tensor = torch.tensor([Y], dtype=dtype)
expected_output = torch.tensor([expected_value], dtype=dtype)

quantized_add = (
torch.ops.cadence.quantized_add_asym8sxasym8s_asym8s.per_tensor
if dtype == torch.int8
else torch.ops.cadence.quantized_add_asym8uxasym8u_asym8u.per_tensor
)
output = quantized_add(
X_tensor,
X_scale,
X_zero_point,
Y_tensor,
Y_scale,
Y_zero_point,
out_scale,
out_zero_point,
)

self.assertTrue(
torch.equal(output, expected_output),
f"Values don't match in {name}: got {output}, expected {expected_output}",
)

output = torch.ops.cadence.quantized_add(
X_tensor,
X_scale,
Expand Down
Loading