From 556e0e67d39bd64af8f99fe55cb059e7746a0335 Mon Sep 17 00:00:00 2001 From: Andrew Grebenisan Date: Mon, 20 Oct 2025 12:30:34 -0700 Subject: [PATCH] Cadence ops: Support quantized_w8a32_linear and update shapes for w8a32_conv (#15171) Summary: w8a32 conv has an interesting data layout. This diff corrects that, and additionally implements quantized_w8a32_linear (weights are not pre-transposed). Reviewed By: hsharma35, mcremon-meta Differential Revision: D84745967 --- backends/cadence/aot/ops_registrations.py | 10 +- backends/cadence/aot/quantizer/fusion_pass.py | 2 +- backends/cadence/aot/ref_implementations.py | 45 ++++++++- .../aot/tests/test_ref_implementations.py | 97 ++++++++++++++++++- 4 files changed, 143 insertions(+), 11 deletions(-) diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py index d7ec5bf05b3..030d10438fb 100644 --- a/backends/cadence/aot/ops_registrations.py +++ b/backends/cadence/aot/ops_registrations.py @@ -53,7 +53,6 @@ def _validate_ref_impl_exists() -> None: # 1. be removed # 2. have a reference implementation added to ref_implementations.py _WARN_ONLY = { - "cadence::quantized_w8a32_linear", "cadence::quantized_add", # We should only support per_tensor variant, should remove "cadence::_softmax_f32_f32", "cadence::requantize", # We should only support per_tensor variant, should remove @@ -2706,6 +2705,9 @@ def quantized_w8a32_linear_meta( # output comes in empty with shape [leading_dims, out_dim] src_shape = list(src.shape) weight_shape = weight.shape + assert (src_shape[-1] % 4) == 0 + if len(src_shape) >= 2: + assert src_shape[-2] == 1 assert len(weight_shape) == 2 assert src_shape[-1] == weight_shape[-1] src_shape[-1] = weight_shape[0] @@ -2720,12 +2722,12 @@ def quantized_w8a32_conv_meta( bias: torch.Tensor, b_scale: float, ) -> torch.Tensor: - # src comes in shape [batch, in_channel, in_length] - # weight comes in shape [out_ch, in_ch, kernel_dim] + # src comes in shape [batch, in_length, in_channels] + # weight comes in shape [kernel_dim, out_ch, in_ch] # output comes in empty with shape [batch, out_ch, in_length - kernel_dim + 1] assert len(src.shape) == 3 - out_channels, in_channels, kernel_size = weight.shape + kernel_size, out_channels, in_channels = weight.shape assert kernel_size == 3 assert (out_channels % 4) == 0 assert (in_channels % 4) == 0 diff --git a/backends/cadence/aot/quantizer/fusion_pass.py b/backends/cadence/aot/quantizer/fusion_pass.py index a6929bd9a39..e2818f725ef 100644 --- a/backends/cadence/aot/quantizer/fusion_pass.py +++ b/backends/cadence/aot/quantizer/fusion_pass.py @@ -397,7 +397,7 @@ def get_args_and_kwargs_mixed_w8a32_conv( ) transposed_weights = graph_module.graph.call_function( torch.ops.aten.permute.default, - (weights_inputs[0], [2, 0, 1]), # NCL -> NLC + (weights_inputs[0], [2, 0, 1]), # NCL -> LNC ) args = ( diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py index f9f7249b249..b5523427a69 100644 --- a/backends/cadence/aot/ref_implementations.py +++ b/backends/cadence/aot/ref_implementations.py @@ -854,7 +854,7 @@ def quantized_w8a32_conv( if len(weight.shape) != 3: raise ValueError("Weight tensor must be 3D") - out_channels, in_channels, kernel_size = weight.shape + kernel_size, out_channels, in_channels = weight.shape if kernel_size != 3: raise ValueError("Kernel size must be 3") if (out_channels % 4) != 0: @@ -862,10 +862,15 @@ def quantized_w8a32_conv( if (in_channels % 4) != 0: raise ValueError("In channels must be a multiple of 4") - # src comes in shape [batch, in_channel, in_length] - # weight comes in shape [out_ch, in_ch, kernel_dim] - # output comes in empty with shape [batch, out_ch, in_length - kernel_dim + 1] - # Dequantize weight using scale + assert weight.dtype == torch.int8 + assert bias.dtype == torch.int8 + + # To make compliant with torch (LCN -> NCL format) + weight = weight.permute(1, 2, 0).contiguous() + + # channels last to channels first + src = src.permute(0, 2, 1).contiguous() + dequant_weight = weight.float() * w_scale # Dequantize bias using scale @@ -884,6 +889,36 @@ def quantized_w8a32_conv( return output +@impl_tracked(m, "quantized_w8a32_linear") +def quantized_w8a32_linear( + src: torch.Tensor, + weight: torch.Tensor, + w_scale: float, + bias: torch.Tensor, + b_scale: float, +) -> torch.Tensor: + # src comes in shape [leading_dims, in_dim] + # weight comes in shape [in_dim, out_dim] + # output comes in empty with shape [leading_dims, out_dim] + assert weight.dtype == torch.int8 + assert bias.dtype == torch.int8 + if len(src.shape) >= 2: + assert src.shape[-2] == 1, "Only supporting vector-matrix multiplication" + + # need to transpose to make compliant with torch linear (in, out -> out, in) + weight = weight.transpose(1, 0).contiguous() + dequant_weight = weight.float() * w_scale + dequant_bias = bias.float() * b_scale + + output = torch.nn.functional.linear( + src.float(), + dequant_weight, + dequant_bias, + ) + + return output + + @impl_tracked(m, "quantized_conv2d_nhwc.per_tensor") def quantized_conv2d_nhwc_per_tensor( input_tensor: torch.Tensor, diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py index b3886f453f5..cb4b26c59e1 100644 --- a/backends/cadence/aot/tests/test_ref_implementations.py +++ b/backends/cadence/aot/tests/test_ref_implementations.py @@ -1188,7 +1188,7 @@ def test_quantized_conv_per_tensor( dtype=torch.int8, ), # weight: 4x4x3 0.5, # w_scale - torch.tensor([2, 2, 2, 2], dtype=torch.float32), # bias: 4 + torch.tensor([2, 2, 2, 2], dtype=torch.int8), # bias: 4 1.0, # b_scale torch.tensor( [ @@ -1214,6 +1214,12 @@ def test_quantized_w8a32_conv( b_scale: float, expected_output: torch.Tensor, ) -> None: + + # This op takes in channels last src + src = src.permute(0, 2, 1) + + # This op takes in LNC format for weights + weight = weight.permute(2, 0, 1) output = torch.ops.cadence.quantized_w8a32_conv( src, weight, w_scale, bias, b_scale ) @@ -1236,6 +1242,95 @@ def test_quantized_w8a32_conv( f"Output values don't match expected in {name}. Got {output}, expected {expected_output}", ) + @expand( + [ + ( + "multi_input_features", + torch.tensor([[1.0, 2.0, 3.0]], dtype=torch.float32), # src: 1x3 + torch.tensor([[2, 1], [1, 2], [1, 1]], dtype=torch.int8), # weight: 3x2 + 0.5, # w_scale + torch.tensor([0, 1], dtype=torch.int8), # bias: 2 + 1.0, # b_scale + torch.tensor([[3.5, 5.0]], dtype=torch.float32), # expected + ), + ( + "batch_size_2", + torch.tensor( + [[[1.0, 2.0]], [[3.0, 4.0]]], dtype=torch.float32 + ), # src: 2x2 + torch.tensor([[1, 2], [1, -1]], dtype=torch.int8), # weight: 2x2 + 1.0, # w_scale + torch.tensor([0, 0], dtype=torch.int8), # bias: 2 + 1.0, # b_scale + torch.tensor( + [[[3.0, 0.0]], [[7.0, 2.0]]], dtype=torch.float32 + ), # expected + ), + ( + "shape_assertion_error", + torch.tensor( + [[[1.0, 2.0], [3.0, 4.0]]], dtype=torch.float32 + ), # src: 1x2x2 + torch.tensor([[1, 2], [1, -1]], dtype=torch.int8), # weight: 2x2 + 1.0, # w_scale + torch.tensor([0, 1], dtype=torch.int8), # bias: 2 + 1.0, # b_scale + torch.tensor( + [[[3.0, 1.0], [7.0, 3.0]]], dtype=torch.float32 + ), # expected + ), + ( + "negative_weights", + torch.tensor([[2.0, 4.0]], dtype=torch.float32), # src: 1x2 + torch.tensor([[-2, -3], [-1, -2]], dtype=torch.int8), # weight: 2x2 + 0.5, # w_scale + torch.tensor([2, 1], dtype=torch.int8), # bias: 2 + 1.0, # b_scale + torch.tensor([[-2.0, -6.0]], dtype=torch.float32), # expected + ), + ] + ) + def test_quantized_w8a32_linear( + self, + name: str, + src: torch.Tensor, + weight: torch.Tensor, + w_scale: float, + bias: torch.Tensor, + b_scale: float, + expected_output: torch.Tensor, + ) -> None: + if name == "shape_assertion_error": + with self.assertRaisesRegex( + AssertionError, "Only supporting vector-matrix multiplication" + ): + torch.ops.cadence.quantized_w8a32_linear( + src, weight, w_scale, bias, b_scale + ) + return + + output = torch.ops.cadence.quantized_w8a32_linear( + src, weight, w_scale, bias, b_scale + ) + + # Verify output properties + self.assertEqual( + output.dtype, + torch.float32, + f"Output dtype should be float32 in {name}", + ) + self.assertEqual( + output.shape, + expected_output.shape, + f"Output shape should match expected shape in {name}", + ) + + # Verify output matches expected values + self.assertTrue( + torch.allclose(output, expected_output, rtol=1e-4, atol=1e-4), + f"Output values don't match expected in {name}. Got {output}, expected {expected_output}", + ) + @expand( [ # Test case 1: Basic int8 case with negative scale