Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions backends/cadence/aot/ops_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ def _validate_ref_impl_exists() -> None:
"cadence::dequantize_per_tensor_asym16u",
"cadence::linalg_vector_norm",
"cadence::quantized_conv2d_nchw", # We should only support per_tensor variant, should remove
"cadence::quantized_w8a32_conv",
"cadence::quantize_per_tensor_asym32s",
"cadence::quantized_relu", # We should only support per_tensor variant, should remove
"cadence::linalg_svd",
Expand Down Expand Up @@ -2753,7 +2752,10 @@ def quantized_w8a32_conv_meta(
# output comes in empty with shape [batch, out_ch, in_length - kernel_dim + 1]
assert len(src.shape) == 3

kernel_size, out_channels, in_channels = weight.shape
out_channels, in_channels, kernel_size = weight.shape
assert kernel_size == 3
assert (out_channels % 4) == 0
assert (in_channels % 4) == 0
assert in_channels == src.shape[-1]

# Compute the output tensor size
Expand Down
42 changes: 42 additions & 0 deletions backends/cadence/aot/ref_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,48 @@ def quantized_conv2d_nchw_per_tensor(
)


@impl_tracked(m, "quantized_w8a32_conv")
def quantized_w8a32_conv(
    src: torch.Tensor,
    weight: torch.Tensor,
    w_scale: float,
    bias: torch.Tensor,
    b_scale: float,
) -> torch.Tensor:
    """Reference implementation of the w8a32 1D convolution.

    "w8a32" = int8-quantized weights with float32 activations: the weight and
    bias tensors are dequantized with their scalar scales and the convolution
    itself runs in float32.

    Args:
        src: Activation tensor, shape [batch, in_channels, in_length].
        weight: Quantized weight tensor, shape [out_channels, in_channels, 3].
        w_scale: Dequantization scale applied to ``weight``.
        bias: Quantized bias tensor, shape [out_channels].
        b_scale: Dequantization scale applied to ``bias``.

    Returns:
        float32 tensor of shape [batch, out_channels, in_length - 2]
        (valid convolution, stride 1, no padding).

    Raises:
        ValueError: If ``src``/``weight`` are not 3D, the kernel size is not 3,
            or the channel counts are not multiples of 4 (kernel constraints).
    """
    # Mirror the shape assertions of quantized_w8a32_conv_meta so eager and
    # meta execution reject the same inputs, but raise ValueError per this
    # file's convention instead of asserting.
    if len(src.shape) != 3:
        raise ValueError("Source tensor must be 3D [batch, in_channels, in_length]")
    if len(weight.shape) != 3:
        raise ValueError("Weight tensor must be 3D")

    out_channels, in_channels, kernel_size = weight.shape
    if kernel_size != 3:
        raise ValueError("Kernel size must be 3")
    if (out_channels % 4) != 0:
        raise ValueError("Out channels must be a multiple of 4")
    if (in_channels % 4) != 0:
        raise ValueError("In channels must be a multiple of 4")
    # NOTE(review): the meta function checks in_channels against src.shape[-1],
    # but conv1d consumes src as [batch, in_channels, length] (channels at
    # dim 1) — confirm the intended activation layout with the kernel authors.

    # Dequantize the int8 weight and bias into float32 using their scales.
    dequant_weight = weight.float() * w_scale
    dequant_bias = bias.float() * b_scale

    # Valid (no padding) stride-1 1D convolution in float32:
    #   src: [batch, in_channels, in_length]
    #   dequant_weight: [out_channels, in_channels, kernel_size]
    #   dequant_bias: [out_channels]
    return torch.nn.functional.conv1d(
        src.float(),
        dequant_weight,
        dequant_bias,
    )


@impl_tracked(m, "quantized_conv2d_nhwc.per_tensor")
def quantized_conv2d_nhwc_per_tensor(
input_tensor: torch.Tensor,
Expand Down
196 changes: 196 additions & 0 deletions backends/cadence/aot/tests/test_ref_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1040,6 +1040,202 @@ def test_quantized_conv_per_tensor(
f"Output values don't match expected. Got {output}, expected {expected_output}",
)

    # Each case: (name, src, weight, w_scale, bias, b_scale, expected_output).
    # Shapes follow the kernel contract: src [batch, in_ch, in_len],
    # weight [out_ch, in_ch, 3] with in_ch/out_ch multiples of 4,
    # bias [out_ch]; expected is the float32 conv1d result of the
    # scale-dequantized weight/bias (valid convolution, stride 1).
    @expand(
        [
            (
                "basic_int8_weights",
                torch.tensor(
                    [
                        [
                            [1.0, 2.0, 3.0, 4.0, 5.0],
                            [1.0, 2.0, 3.0, 4.0, 5.0],
                            [1.0, 2.0, 3.0, 4.0, 5.0],
                            [1.0, 2.0, 3.0, 4.0, 5.0],
                        ]
                    ],
                    dtype=torch.float32,
                ),  # src: 1x4x5
                torch.tensor(
                    [
                        [[1, -1, 2], [1, -1, 2], [1, -1, 2], [1, -1, 2]],
                        [[1, -1, 2], [1, -1, 2], [1, -1, 2], [1, -1, 2]],
                        [[1, -1, 2], [1, -1, 2], [1, -1, 2], [1, -1, 2]],
                        [[1, -1, 2], [1, -1, 2], [1, -1, 2], [1, -1, 2]],
                    ],
                    dtype=torch.int8,
                ),  # weight: 4x4x3
                0.1,  # w_scale
                torch.tensor([1, 1, 1, 1], dtype=torch.int8),  # bias: 4
                0.2,  # b_scale
                # e.g. position 0: 4 channels * (1*0.1 + 2*(-0.1) + 3*0.2)
                # = 2.0, plus bias 1*0.2 = 2.2
                torch.tensor(
                    [
                        [
                            [2.2, 3.0, 3.8],
                            [2.2, 3.0, 3.8],
                            [2.2, 3.0, 3.8],
                            [2.2, 3.0, 3.8],
                        ]
                    ],
                    dtype=torch.float32,
                ),  # expected: conv1d result
            ),
            (
                "batch_size_2",
                torch.tensor(
                    [
                        [
                            [1.0, 2.0, 3.0, 4.0, 5.0],
                            [1.0, 2.0, 3.0, 4.0, 5.0],
                            [1.0, 2.0, 3.0, 4.0, 5.0],
                            [1.0, 2.0, 3.0, 4.0, 5.0],
                        ],
                        [
                            [2.0, 3.0, 4.0, 5.0, 6.0],
                            [2.0, 3.0, 4.0, 5.0, 6.0],
                            [2.0, 3.0, 4.0, 5.0, 6.0],
                            [2.0, 3.0, 4.0, 5.0, 6.0],
                        ],
                    ],
                    dtype=torch.float32,
                ),  # src: 2x4x5
                torch.tensor(
                    [
                        [[1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1]],
                        [[1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1]],
                        [[1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1]],
                        [[1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1]],
                    ],
                    dtype=torch.int8,
                ),  # weight: 4x4x3
                1.0,  # w_scale
                torch.tensor([0, 0, 0, 0], dtype=torch.int8),  # bias: 4
                1.0,  # b_scale
                # All-ones kernel sums a 3-wide window over 4 channels,
                # e.g. batch 0 position 0: (1+2+3)*4 = 24
                torch.tensor(
                    [
                        [
                            [24.0, 36.0, 48.0],
                            [24.0, 36.0, 48.0],
                            [24.0, 36.0, 48.0],
                            [24.0, 36.0, 48.0],
                        ],
                        [
                            [36.0, 48.0, 60.0],
                            [36.0, 48.0, 60.0],
                            [36.0, 48.0, 60.0],
                            [36.0, 48.0, 60.0],
                        ],
                    ],
                    dtype=torch.float32,
                ),  # expected
            ),
            (
                "zero_weights_bias",
                torch.tensor(
                    [
                        [
                            [1.0, 2.0, 3.0, 4.0, 5.0],
                            [1.0, 2.0, 3.0, 4.0, 5.0],
                            [1.0, 2.0, 3.0, 4.0, 5.0],
                            [1.0, 2.0, 3.0, 4.0, 5.0],
                        ]
                    ],
                    dtype=torch.float32,
                ),  # src: 1x4x5
                torch.tensor(
                    [
                        [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
                        [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
                        [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
                        [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
                    ],
                    dtype=torch.int8,
                ),  # weight: 4x4x3
                0.1,  # w_scale
                torch.tensor([0, 0, 0, 0], dtype=torch.int8),  # bias: 4
                1.0,  # b_scale
                # Zero weights and zero bias must produce an all-zero output.
                torch.tensor(
                    [
                        [
                            [0.0, 0.0, 0.0],
                            [0.0, 0.0, 0.0],
                            [0.0, 0.0, 0.0],
                            [0.0, 0.0, 0.0],
                        ]
                    ],
                    dtype=torch.float32,
                ),  # expected
            ),
            (
                "negative_weights",
                torch.tensor(
                    [
                        [
                            [2.0, 4.0, 6.0, 8.0, 10.0],
                            [2.0, 4.0, 6.0, 8.0, 10.0],
                            [2.0, 4.0, 6.0, 8.0, 10.0],
                            [2.0, 4.0, 6.0, 8.0, 10.0],
                        ]
                    ],
                    dtype=torch.float32,
                ),  # src: 1x4x5
                torch.tensor(
                    [
                        [[-2, -1, 0], [-2, -1, 0], [-2, -1, 0], [-2, -1, 0]],
                        [[-2, -1, 0], [-2, -1, 0], [-2, -1, 0], [-2, -1, 0]],
                        [[-2, -1, 0], [-2, -1, 0], [-2, -1, 0], [-2, -1, 0]],
                        [[-2, -1, 0], [-2, -1, 0], [-2, -1, 0], [-2, -1, 0]],
                    ],
                    dtype=torch.int8,
                ),  # weight: 4x4x3
                0.5,  # w_scale
                # NOTE: float32 bias here (other cases use int8); both paths
                # go through bias.float() * b_scale in the reference impl.
                torch.tensor([2, 2, 2, 2], dtype=torch.float32),  # bias: 4
                1.0,  # b_scale
                # e.g. position 0: 4 channels * (2*(-1) + 4*(-0.5) + 6*0)
                # = -16, plus bias 2 = -14
                torch.tensor(
                    [
                        [
                            [-14.0, -26.0, -38.0],
                            [-14.0, -26.0, -38.0],
                            [-14.0, -26.0, -38.0],
                            [-14.0, -26.0, -38.0],
                        ]
                    ],
                    dtype=torch.float32,
                ),  # expected
            ),
        ]
    )
    def test_quantized_w8a32_conv(
        self,
        name: str,
        src: torch.Tensor,
        weight: torch.Tensor,
        w_scale: float,
        bias: torch.Tensor,
        b_scale: float,
        expected_output: torch.Tensor,
    ) -> None:
        """Check quantized_w8a32_conv against hand-computed conv1d results.

        Verifies dtype (float32), output shape, and numerical values for each
        parameterized case (single batch, batch of 2, zero weights, negative
        weights with a float32 bias).
        """
        output = torch.ops.cadence.quantized_w8a32_conv(
            src, weight, w_scale, bias, b_scale
        )

        # Verify output properties
        self.assertEqual(
            output.dtype,
            torch.float32,
            f"Output dtype should be float32 in {name}",
        )
        self.assertEqual(
            output.shape,
            expected_output.shape,
            f"Output shape should match expected shape in {name}",
        )

        # Verify output matches expected values
        self.assertTrue(
            torch.allclose(output, expected_output, rtol=1e-4, atol=1e-4),
            f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
        )

@expand(
[
# Test case 1: Basic int8 case with negative scale
Expand Down
Loading