Merged
10 changes: 6 additions & 4 deletions backends/cadence/aot/ops_registrations.py
@@ -53,7 +53,6 @@ def _validate_ref_impl_exists() -> None:
# 1. be removed
# 2. have a reference implementation added to ref_implementations.py
_WARN_ONLY = {
"cadence::quantized_w8a32_linear",
"cadence::quantized_add", # We should only support per_tensor variant, should remove
"cadence::_softmax_f32_f32",
"cadence::requantize", # We should only support per_tensor variant, should remove
@@ -2706,6 +2705,9 @@ def quantized_w8a32_linear_meta(
# output comes in empty with shape [leading_dims, out_dim]
src_shape = list(src.shape)
weight_shape = weight.shape
assert (src_shape[-1] % 4) == 0
if len(src_shape) >= 2:
assert src_shape[-2] == 1
assert len(weight_shape) == 2
assert src_shape[-1] == weight_shape[-1]
src_shape[-1] = weight_shape[0]
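
Note on the new shape checks: following this meta function's own conventions (weight read as [out_dim, in_dim], activation as a one-row vector), a minimal illustrative sketch of the shape rule the asserts enforce. The names and sizes below are invented for illustration, not taken from the PR.

import torch

# Shapes chosen to satisfy the asserts above: src's last dim is a
# multiple of 4, src.shape[-2] == 1, and weight is 2D with
# weight.shape[-1] == src.shape[-1].
src = torch.empty(2, 1, 8)
weight = torch.empty(16, 8)

out_shape = list(src.shape)
out_shape[-1] = weight.shape[0]  # [2, 1, 8] -> [2, 1, 16]
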
@@ -2720,12 +2722,12 @@ def quantized_w8a32_conv_meta(
bias: torch.Tensor,
b_scale: float,
) -> torch.Tensor:
# src comes in shape [batch, in_channel, in_length]
# weight comes in shape [out_ch, in_ch, kernel_dim]
# src comes in shape [batch, in_length, in_channels]
# weight comes in shape [kernel_dim, out_ch, in_ch]
# output comes in empty with shape [batch, out_ch, in_length - kernel_dim + 1]
assert len(src.shape) == 3

out_channels, in_channels, kernel_size = weight.shape
kernel_size, out_channels, in_channels = weight.shape
assert kernel_size == 3
assert (out_channels % 4) == 0
assert (in_channels % 4) == 0
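Note: with the corrected unpacking, the conv meta now reads weight as [kernel_dim, out_ch, in_ch] and src as [batch, in_length, in_channels]. A small illustrative sketch of the resulting shapes (sizes invented here, chosen only to satisfy the asserts: kernel_size == 3, channel counts multiples of 4):

import torch

src = torch.empty(1, 10, 8)    # [batch, in_length, in_channels]
weight = torch.empty(3, 4, 8)  # [kernel_dim, out_ch, in_ch]

kernel_size, out_channels, in_channels = weight.shape
out_length = src.shape[1] - kernel_size + 1  # 10 - 3 + 1 = 8
# output shape: [batch, out_channels, out_length] == [1, 4, 8]
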
2 changes: 1 addition & 1 deletion backends/cadence/aot/quantizer/fusion_pass.py
@@ -397,7 +397,7 @@ def get_args_and_kwargs_mixed_w8a32_conv(
)
transposed_weights = graph_module.graph.call_function(
torch.ops.aten.permute.default,
(weights_inputs[0], [2, 0, 1]), # NCL -> NLC
(weights_inputs[0], [2, 0, 1]), # NCL -> LNC
)

args = (
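Note: the comment fix matches what permute([2, 0, 1]) actually does. It moves the last axis to the front, so an [out_ch, in_ch, kernel] (NCL-style) weight becomes [kernel, out_ch, in_ch] (LNC), not NLC. A quick standalone check with an illustrative tensor:

import torch

w = torch.randn(4, 8, 3)    # [out_ch, in_ch, kernel], NCL-style
w_lnc = w.permute(2, 0, 1)  # [kernel, out_ch, in_ch], LNC
assert w_lnc.shape == (3, 4, 8)
assert torch.equal(w_lnc[0], w[:, :, 0])  # kernel tap 0 now leads
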
45 changes: 40 additions & 5 deletions backends/cadence/aot/ref_implementations.py
@@ -854,18 +854,23 @@ def quantized_w8a32_conv(
if len(weight.shape) != 3:
raise ValueError("Weight tensor must be 3D")

out_channels, in_channels, kernel_size = weight.shape
kernel_size, out_channels, in_channels = weight.shape
if kernel_size != 3:
raise ValueError("Kernel size must be 3")
if (out_channels % 4) != 0:
raise ValueError("Out channels must be a multiple of 4")
if (in_channels % 4) != 0:
raise ValueError("In channels must be a multiple of 4")

# src comes in shape [batch, in_channel, in_length]
# weight comes in shape [out_ch, in_ch, kernel_dim]
# output comes in empty with shape [batch, out_ch, in_length - kernel_dim + 1]
# Dequantize weight using scale
assert weight.dtype == torch.int8
assert bias.dtype == torch.int8

# To make compliant with torch (LCN -> NCL format)
weight = weight.permute(1, 2, 0).contiguous()

# channels last to channels first
src = src.permute(0, 2, 1).contiguous()

dequant_weight = weight.float() * w_scale

# Dequantize bias using scale
@@ -884,6 +889,36 @@
return output
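
Note: putting the conv layout handling together, a hedged usage sketch. Shapes and scales below are invented, and this assumes the cadence op library is registered (e.g. by importing the Cadence AOT modules) so that torch.ops.cadence resolves:

import torch

batch, in_len, in_ch, out_ch, k = 1, 10, 8, 4, 3
src = torch.randn(batch, in_len, in_ch)  # channels-last activation
weight = torch.randint(-128, 127, (k, out_ch, in_ch), dtype=torch.int8)  # LNC
bias = torch.randint(-128, 127, (out_ch,), dtype=torch.int8)

out = torch.ops.cadence.quantized_w8a32_conv(src, weight, 0.5, bias, 1.0)
# out is float32 with shape [batch, out_ch, in_len - k + 1] == [1, 4, 8]
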


@impl_tracked(m, "quantized_w8a32_linear")
def quantized_w8a32_linear(
src: torch.Tensor,
weight: torch.Tensor,
w_scale: float,
bias: torch.Tensor,
b_scale: float,
) -> torch.Tensor:
# src comes in shape [leading_dims, in_dim]
# weight comes in shape [in_dim, out_dim]
# output comes in empty with shape [leading_dims, out_dim]
assert weight.dtype == torch.int8
assert bias.dtype == torch.int8
if len(src.shape) >= 2:
assert src.shape[-2] == 1, "Only supporting vector-matrix multiplication"

# need to transpose to make compliant with torch linear (in, out -> out, in)
weight = weight.transpose(1, 0).contiguous()
dequant_weight = weight.float() * w_scale
dequant_bias = bias.float() * b_scale

output = torch.nn.functional.linear(
src.float(),
dequant_weight,
dequant_bias,
)

return output
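
Note: a matching sketch for the new linear reference (again with invented values, and assuming the cadence ops are registered): src is a row vector [..., 1, in_dim], weight is stored as [in_dim, out_dim] and transposed internally before F.linear, and the int8 weight and bias are dequantized by their scales.

import torch

src = torch.tensor([[1.0, 2.0, 3.0, 4.0]])               # [1, in_dim]
weight = torch.randint(-8, 8, (4, 6), dtype=torch.int8)  # [in_dim, out_dim]
bias = torch.zeros(6, dtype=torch.int8)

out = torch.ops.cadence.quantized_w8a32_linear(src, weight, 0.25, bias, 1.0)
# out is float32 with shape [1, out_dim] == [1, 6]
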


@impl_tracked(m, "quantized_conv2d_nhwc.per_tensor")
def quantized_conv2d_nhwc_per_tensor(
input_tensor: torch.Tensor,
97 changes: 96 additions & 1 deletion backends/cadence/aot/tests/test_ref_implementations.py
@@ -1188,7 +1188,7 @@ def test_quantized_conv_per_tensor(
dtype=torch.int8,
), # weight: 4x4x3
0.5, # w_scale
torch.tensor([2, 2, 2, 2], dtype=torch.float32), # bias: 4
torch.tensor([2, 2, 2, 2], dtype=torch.int8), # bias: 4
1.0, # b_scale
torch.tensor(
[
@@ -1214,6 +1220,12 @@ def test_quantized_w8a32_conv(
b_scale: float,
expected_output: torch.Tensor,
) -> None:

# This op takes in channels last src
src = src.permute(0, 2, 1)

# This op takes in LNC format for weights
weight = weight.permute(2, 0, 1)
output = torch.ops.cadence.quantized_w8a32_conv(
src, weight, w_scale, bias, b_scale
)
@@ -1236,6 +1242,95 @@ def test_quantized_w8a32_conv(
f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
)

@expand(
[
(
"multi_input_features",
torch.tensor([[1.0, 2.0, 3.0]], dtype=torch.float32), # src: 1x3
torch.tensor([[2, 1], [1, 2], [1, 1]], dtype=torch.int8), # weight: 3x2
0.5, # w_scale
torch.tensor([0, 1], dtype=torch.int8), # bias: 2
1.0, # b_scale
torch.tensor([[3.5, 5.0]], dtype=torch.float32), # expected
),
(
"batch_size_2",
torch.tensor(
[[[1.0, 2.0]], [[3.0, 4.0]]], dtype=torch.float32
), # src: 2x1x2
torch.tensor([[1, 2], [1, -1]], dtype=torch.int8), # weight: 2x2
1.0, # w_scale
torch.tensor([0, 0], dtype=torch.int8), # bias: 2
1.0, # b_scale
torch.tensor(
[[[3.0, 0.0]], [[7.0, 2.0]]], dtype=torch.float32
), # expected
),
(
"shape_assertion_error",
torch.tensor(
[[[1.0, 2.0], [3.0, 4.0]]], dtype=torch.float32
), # src: 1x2x2
torch.tensor([[1, 2], [1, -1]], dtype=torch.int8), # weight: 2x2
1.0, # w_scale
torch.tensor([0, 1], dtype=torch.int8), # bias: 2
1.0, # b_scale
torch.tensor(
[[[3.0, 1.0], [7.0, 3.0]]], dtype=torch.float32
), # expected
),
(
"negative_weights",
torch.tensor([[2.0, 4.0]], dtype=torch.float32), # src: 1x2
torch.tensor([[-2, -3], [-1, -2]], dtype=torch.int8), # weight: 2x2
0.5, # w_scale
torch.tensor([2, 1], dtype=torch.int8), # bias: 2
1.0, # b_scale
torch.tensor([[-2.0, -6.0]], dtype=torch.float32), # expected
),
]
)
def test_quantized_w8a32_linear(
self,
name: str,
src: torch.Tensor,
weight: torch.Tensor,
w_scale: float,
bias: torch.Tensor,
b_scale: float,
expected_output: torch.Tensor,
) -> None:
if name == "shape_assertion_error":
with self.assertRaisesRegex(
AssertionError, "Only supporting vector-matrix multiplication"
):
torch.ops.cadence.quantized_w8a32_linear(
src, weight, w_scale, bias, b_scale
)
return

output = torch.ops.cadence.quantized_w8a32_linear(
src, weight, w_scale, bias, b_scale
)

# Verify output properties
self.assertEqual(
output.dtype,
torch.float32,
f"Output dtype should be float32 in {name}",
)
self.assertEqual(
output.shape,
expected_output.shape,
f"Output shape should match expected shape in {name}",
)

# Verify output matches expected values
self.assertTrue(
torch.allclose(output, expected_output, rtol=1e-4, atol=1e-4),
f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
)

@expand(
[
# Test case 1: Basic int8 case with negative scale