@impl_tracked(m, "quantized_w8a32_conv")
def quantized_w8a32_conv(
    src: torch.Tensor,
    weight: torch.Tensor,
    w_scale: float,
    bias: torch.Tensor,
    b_scale: float,
) -> torch.Tensor:
    """Reference 1D convolution with scaled weight/bias and float activations.

    The weight and bias are dequantized by casting to float32 and multiplying
    by their respective scales; ``src`` is cast to float32 and convolved with
    ``torch.nn.functional.conv1d`` using its default stride, padding,
    dilation and groups.

    Args:
        src: Activations laid out as [batch, in_channels, in_length].
        weight: Weight laid out as [out_channels, in_channels, kernel_size].
            ``kernel_size`` must be 3 and both channel counts must be
            multiples of 4.
        w_scale: Scale multiplied into ``weight`` to dequantize it.
        bias: Bias with one entry per output channel.
        b_scale: Scale multiplied into ``bias`` to dequantize it.

    Returns:
        A float32 tensor of shape
        [batch, out_channels, in_length - kernel_size + 1].

    Raises:
        ValueError: If ``weight`` or ``src`` is not 3D, if the kernel size is
            not 3, if either channel count is not a multiple of 4, or if the
            channel dimension of ``src`` does not match ``weight``.
    """
    if len(weight.shape) != 3:
        raise ValueError("Weight tensor must be 3D")

    out_channels, in_channels, kernel_size = weight.shape
    if kernel_size != 3:
        raise ValueError("Kernel size must be 3")
    if (out_channels % 4) != 0:
        raise ValueError("Out channels must be a multiple of 4")
    if (in_channels % 4) != 0:
        raise ValueError("In channels must be a multiple of 4")

    # The op's shape-inference function asserts a 3D src; enforce the same
    # contract here instead of handing an arbitrary rank to conv1d.
    # These checks come after the weight checks so that pre-existing error
    # messages keep their precedence.
    if len(src.shape) != 3:
        raise ValueError("Source tensor must be 3D")
    if src.shape[1] != in_channels:
        raise ValueError(
            "Source channel dimension must match weight in channels"
        )

    # Dequantize weight and bias with their scales.
    dequant_weight = weight.float() * w_scale
    dequant_bias = bias.float() * b_scale

    # src:    [batch, in_channels, in_length]
    # weight: [out_channels, in_channels, kernel_size]
    # bias:   [out_channels]
    return torch.nn.functional.conv1d(
        src.float(),
        dequant_weight,
        dequant_bias,
    )
@expand(
    [
        (
            "basic_int8_weights",
            # src: 1x4x5, every channel is the ramp 1..5
            torch.tensor(
                [[[1.0, 2.0, 3.0, 4.0, 5.0]] * 4],
                dtype=torch.float32,
            ),
            # weight: 4x4x3, every kernel is [1, -1, 2]
            torch.tensor([[[1, -1, 2]] * 4] * 4, dtype=torch.int8),
            0.1,  # w_scale
            torch.tensor([1] * 4, dtype=torch.int8),  # bias: 4
            0.2,  # b_scale
            # expected: 1x4x3 conv1d result
            torch.tensor([[[2.2, 3.0, 3.8]] * 4], dtype=torch.float32),
        ),
        (
            "batch_size_2",
            # src: 2x4x5, second batch entry is the first shifted by one
            torch.tensor(
                [
                    [[1.0, 2.0, 3.0, 4.0, 5.0]] * 4,
                    [[2.0, 3.0, 4.0, 5.0, 6.0]] * 4,
                ],
                dtype=torch.float32,
            ),
            # weight: 4x4x3, all ones
            torch.tensor([[[1, 1, 1]] * 4] * 4, dtype=torch.int8),
            1.0,  # w_scale
            torch.tensor([0] * 4, dtype=torch.int8),  # bias: 4
            1.0,  # b_scale
            # expected: 2x4x3
            torch.tensor(
                [
                    [[24.0, 36.0, 48.0]] * 4,
                    [[36.0, 48.0, 60.0]] * 4,
                ],
                dtype=torch.float32,
            ),
        ),
        (
            "zero_weights_bias",
            # src: 1x4x5
            torch.tensor(
                [[[1.0, 2.0, 3.0, 4.0, 5.0]] * 4],
                dtype=torch.float32,
            ),
            # weight: 4x4x3, all zeros
            torch.tensor([[[0, 0, 0]] * 4] * 4, dtype=torch.int8),
            0.1,  # w_scale
            torch.tensor([0] * 4, dtype=torch.int8),  # bias: 4
            1.0,  # b_scale
            # expected: 1x4x3, all zeros
            torch.tensor([[[0.0, 0.0, 0.0]] * 4], dtype=torch.float32),
        ),
        (
            "negative_weights",
            # src: 1x4x5
            torch.tensor(
                [[[2.0, 4.0, 6.0, 8.0, 10.0]] * 4],
                dtype=torch.float32,
            ),
            # weight: 4x4x3, every kernel is [-2, -1, 0]
            torch.tensor([[[-2, -1, 0]] * 4] * 4, dtype=torch.int8),
            0.5,  # w_scale
            torch.tensor([2] * 4, dtype=torch.float32),  # bias: 4
            1.0,  # b_scale
            # expected: 1x4x3
            torch.tensor(
                [[[-14.0, -26.0, -38.0]] * 4],
                dtype=torch.float32,
            ),
        ),
    ]
)
def test_quantized_w8a32_conv(
    self,
    name: str,
    src: torch.Tensor,
    weight: torch.Tensor,
    w_scale: float,
    bias: torch.Tensor,
    b_scale: float,
    expected_output: torch.Tensor,
) -> None:
    """Check dtype, shape and values of the quantized_w8a32_conv op output."""
    output = torch.ops.cadence.quantized_w8a32_conv(
        src, weight, w_scale, bias, b_scale
    )

    # Output properties: dtype first, then shape.
    dtype_msg = f"Output dtype should be float32 in {name}"
    self.assertEqual(output.dtype, torch.float32, dtype_msg)

    shape_msg = f"Output shape should match expected shape in {name}"
    self.assertEqual(output.shape, expected_output.shape, shape_msg)

    # Output values, compared with a small absolute/relative tolerance.
    values_match = torch.allclose(
        output, expected_output, rtol=1e-4, atol=1e-4
    )
    self.assertTrue(
        values_match,
        f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
    )