4 changes: 4 additions & 0 deletions backends/cadence/aot/ops_registrations.py
@@ -1771,6 +1771,7 @@ def quantized_fully_connected_meta(
# src comes in shape [leading_dims, in_dim]
# weight comes in shape [out_dim, in_dim]
# output comes in empty with shape [leading_dims, out_dim]
assert src.shape[0] == 1
out_size = list(src.size())
weight_size = list(weight.size())
assert len(weight_size) == 2
@@ -1793,6 +1794,7 @@ def quantized_fully_connected_per_tensor_meta(
# src comes in shape [leading_dims, in_dim]
# weight comes in shape [out_dim, in_dim]
# output comes in empty with shape [leading_dims, out_dim]
assert src.shape[0] == 1
out_size = list(src.size())
weight_size = list(weight.size())
assert len(weight_size) == 2
@@ -1815,6 +1817,7 @@ def quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor_meta(
# src comes in shape [leading_dims, in_dim]
# weight comes in shape [out_dim, in_dim]
# output comes in empty with shape [leading_dims, out_dim]
assert src.shape[0] == 1
out_size = list(src.size())
weight_size = list(weight.size())
assert len(weight_size) == 2
@@ -1837,6 +1840,7 @@ def quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor_meta(
# src comes in shape [leading_dims, in_dim]
# weight comes in shape [out_dim, in_dim]
# output comes in empty with shape [leading_dims, out_dim]
assert src.shape[0] == 1
out_size = list(src.size())
weight_size = list(weight.size())
assert len(weight_size) == 2
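Each of the four meta kernels above now asserts a batch size of 1 before computing the output shape. The sketch below illustrates that shared shape logic; the function name and the final `out_size[-1] = weight_size[0]` line are illustrative fill-ins for the tail the diff elides, not code copied from the file.

```python
# Illustrative sketch of the shape logic shared by the meta kernels above.
# The assert is the new constraint added in this change; the last assignment
# is an assumption about the elided tail of each kernel.
import torch


def fully_connected_out_shape(src: torch.Tensor, weight: torch.Tensor) -> list[int]:
    # src comes in shape [leading_dims, in_dim]; weight in [out_dim, in_dim].
    assert src.shape[0] == 1  # fully-connected variants only support batch size 1
    out_size = list(src.size())
    weight_size = list(weight.size())
    assert len(weight_size) == 2
    out_size[-1] = weight_size[0]  # output shape is [leading_dims, out_dim]
    return out_size
```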
33 changes: 29 additions & 4 deletions backends/cadence/aot/ref_implementations.py
@@ -249,6 +249,7 @@ def quantized_linear_common(

def quantized_linear_variant(
per_tensor: bool,
fully_connected: bool,
src_dtype: torch.dtype | None = None,
weight_dtype: torch.dtype | None = None,
) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]:
@@ -265,6 +266,10 @@ def variant(
out_zero_point: int,
offset: torch.Tensor | None = None,
) -> torch.Tensor:
if fully_connected and src.shape[0] != 1:
raise ValueError(
"Fully connected quantized linear only supports batch size of 1"
)
if src_dtype and src.dtype != src_dtype:
raise ValueError(
f"src dtype must be {src_dtype}. Got {src.dtype} instead"
@@ -317,25 +322,45 @@ def variant(


@impl(m, "quantized_linear")
@quantized_linear_variant(False)
@quantized_linear_variant(False, False)
def quantized_linear() -> torch.Tensor: ...


@impl(m, "quantized_linear.per_tensor")
@quantized_linear_variant(True)
@quantized_linear_variant(True, False)
def quantized_linear_per_tensor() -> torch.Tensor: ...


@impl(m, "quantized_linear_asym8sxasym8s_asym8s.per_tensor")
@quantized_linear_variant(True, torch.int8, torch.int8)
@quantized_linear_variant(True, False, torch.int8, torch.int8)
def quantized_linear_asym8sxasym8s_asym8s_per_tensor() -> torch.Tensor: ...


@impl(m, "quantized_linear_asym8uxasym8u_asym8u.per_tensor")
@quantized_linear_variant(True, torch.uint8, torch.uint8)
@quantized_linear_variant(True, False, torch.uint8, torch.uint8)
def quantized_linear_asym8uxasym8u_asym8u_per_tensor() -> torch.Tensor: ...


@impl(m, "quantized_fully_connected")
@quantized_linear_variant(False, True)
def quantized_fully_connected() -> torch.Tensor: ...


@impl(m, "quantized_fully_connected.per_tensor")
@quantized_linear_variant(True, True)
def quantized_fully_connected_per_tensor() -> torch.Tensor: ...


@impl(m, "quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor")
@quantized_linear_variant(True, True, torch.int8, torch.int8)
def quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor() -> torch.Tensor: ...


@impl(m, "quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor")
@quantized_linear_variant(True, True, torch.uint8, torch.uint8)
def quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor() -> torch.Tensor: ...


@impl(m, "quantized_layer_norm.per_tensor")
def quantized_layer_norm_per_tensor(
input_tensor: torch.Tensor,
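For context on the `quantized_linear_variant` changes above: one factory now produces both the `quantized_linear*` and the new `quantized_fully_connected*` reference ops, with the `fully_connected` flag adding the batch-size-1 guard. Below is a minimal, self-contained sketch of that pattern under simplified assumptions; it is not the real implementation (the real decorator also handles per-tensor vs. per-channel parameters and forwards to `quantized_linear_common`).

```python
# Minimal sketch of the variant-factory pattern, with a placeholder body.
from typing import Callable
import torch


def linear_variant_sketch(
    fully_connected: bool,
    src_dtype: torch.dtype | None = None,
    weight_dtype: torch.dtype | None = None,
) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]:
    def decorator(_stub: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
        def variant(
            src: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
        ) -> torch.Tensor:
            # New check: fully-connected variants only accept batch size 1.
            if fully_connected and src.shape[0] != 1:
                raise ValueError(
                    "Fully connected quantized linear only supports batch size of 1"
                )
            if src_dtype and src.dtype != src_dtype:
                raise ValueError(f"src dtype must be {src_dtype}. Got {src.dtype} instead")
            if weight_dtype and weight.dtype != weight_dtype:
                raise ValueError(f"weight dtype must be {weight_dtype}. Got {weight.dtype} instead")
            # The real variants run the shared quantized matmul here; the sketch
            # just returns a correctly shaped placeholder.
            return torch.zeros(src.shape[0], weight.shape[0], dtype=torch.int8)

        return variant

    return decorator


@linear_variant_sketch(fully_connected=True, src_dtype=torch.int8, weight_dtype=torch.int8)
def fully_connected_asym8s_sketch() -> torch.Tensor: ...
```

Calling the decorated sketch with a `(2, 3)` int8 input raises the same ValueError message that the reference implementation now emits, while a `(1, 3)` input passes the guard.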
53 changes: 31 additions & 22 deletions backends/cadence/aot/tests/test_ref_implementations.py
@@ -307,36 +307,45 @@ def test_quantized_linear(
if per_tensor:
match expected_output.dtype:
case torch.int8:
linear_op = (
torch.ops.cadence.quantized_linear_asym8sxasym8s_asym8s.per_tensor
linear_ops = (
torch.ops.cadence.quantized_linear_asym8sxasym8s_asym8s.per_tensor,
torch.ops.cadence.quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor,
)
case torch.uint8:
linear_op = (
torch.ops.cadence.quantized_linear_asym8uxasym8u_asym8u.per_tensor
linear_ops = (
torch.ops.cadence.quantized_linear_asym8uxasym8u_asym8u.per_tensor,
torch.ops.cadence.quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor,
)
case _:
linear_op = torch.ops.cadence.quantized_linear.per_tensor
linear_ops = (
torch.ops.cadence.quantized_linear.per_tensor,
torch.ops.cadence.quantized_fully_connected.per_tensor,
)
else:
linear_op = torch.ops.cadence.quantized_linear
linear_ops = (
torch.ops.cadence.quantized_linear,
torch.ops.cadence.quantized_fully_connected,
)

output = linear_op(
src,
weight,
bias,
in_zero_point,
weight_zero_point,
out_multiplier,
out_shift,
out_zero_point,
typing.cast(torch.Tensor, None),
)
for linear_op in linear_ops:
output = linear_op(
src,
weight,
bias,
in_zero_point,
weight_zero_point,
out_multiplier,
out_shift,
out_zero_point,
typing.cast(torch.Tensor, None),
)

self.assertTrue(output.dtype == expected_output.dtype, "Dtype mismatch")
self.assertTrue(output.dtype == expected_output.dtype, "Dtype mismatch")

self.assertTrue(
torch.equal(output, expected_output),
f"Values don't match: got {output}, expected {expected_output}",
)
self.assertTrue(
torch.equal(output, expected_output),
f"Values don't match: got {output}, expected {expected_output}",
)

@expand(
[
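The updated test loops both the `quantized_linear*` and `quantized_fully_connected*` ops over the same batch-size-1 inputs. A negative test for the new guard is not part of this diff; a hypothetical sketch of one is shown below, assuming the cadence reference ops are registered as in the test above (the quantization parameter values and their exact schema types are placeholders).

```python
# Hypothetical addition, not part of this PR: verify the batch-size-1 guard.
import typing
import unittest
import torch


class TestFullyConnectedBatchGuard(unittest.TestCase):
    def test_batch_size_greater_than_one_raises(self) -> None:
        src = torch.zeros(2, 3, dtype=torch.int8)  # batch size 2 should be rejected
        weight = torch.zeros(4, 3, dtype=torch.int8)
        bias = torch.zeros(4, dtype=torch.int32)
        with self.assertRaises(ValueError):
            torch.ops.cadence.quantized_fully_connected.per_tensor(
                src, weight, bias, 0, 0, 1, 0, 0, typing.cast(torch.Tensor, None)
            )
```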
6 changes: 3 additions & 3 deletions backends/cadence/aot/tests/test_type_dispatch_passes.py
@@ -20,7 +20,7 @@
class TestTypeDispatchPasses(unittest.TestCase):
def test_int8_dispatch_quantized_fully_connected(self) -> None:
"""Test int8 x int8 inputs should dispatch to asym8sxasym8s_asym8s variant"""
x = torch.randint(-128, 127, (2, 3), dtype=torch.int8)
x = torch.randint(-128, 127, (1, 3), dtype=torch.int8)
w = torch.randint(-128, 127, (4, 3), dtype=torch.int8)
b = torch.randint(-2147483648, 2147483647, (4,), dtype=torch.int32)
gm = single_op_builder(
@@ -46,7 +46,7 @@ def test_int8_dispatch_quantized_fully_connected(self) -> None:

def test_uint8_dispatch_quantized_fully_connected(self) -> None:
"""Test uint8 x uint8 inputs should dispatch to asym8uxasym8u_asym8u variant"""
x = torch.randint(0, 255, (2, 3), dtype=torch.uint8)
x = torch.randint(0, 255, (1, 3), dtype=torch.uint8)
w = torch.randint(0, 255, (4, 3), dtype=torch.uint8)
b = torch.randint(-2147483648, 2147483647, (4,), dtype=torch.int32)
gm = single_op_builder(
@@ -124,7 +124,7 @@ def test_uint8_quantized_linear_dispatch(self) -> None:

def test_mixed_types_error(self) -> None:
"""Test mixed int8/uint8 inputs should raise RuntimeError"""
x = torch.randint(-128, 127, (2, 3), dtype=torch.int8)
x = torch.randint(-128, 127, (1, 3), dtype=torch.int8)
w = torch.randint(0, 255, (4, 3), dtype=torch.uint8)
b = torch.randint(-2147483648, 2147483647, (4,), dtype=torch.int32)
gm = single_op_builder(