diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py
index 52776c55c54..b45023c2808 100644
--- a/backends/cadence/aot/ref_implementations.py
+++ b/backends/cadence/aot/ref_implementations.py
@@ -933,6 +933,57 @@ def quantized_conv1d_nlc_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
 def quantized_conv1d_nlc_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
 
 
+@impl(m, "convolution")
+def convolution(
+    input_tensor: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    stride: tuple[int, int],
+    padding: tuple[int, int],
+    dilation: tuple[int, int],
+    groups: int,
+    channel_last: bool = False,
+) -> torch.Tensor:
+    # A 3D input is a 1D convolution (NCL/NLC); anything else is treated as 2D.
+    conv_is_1d = len(input_tensor.shape) == 3
+    # Convert channel-last layouts (NLC/NHWC) to the channel-first layouts
+    # (NCL/NCHW) that torch.nn.functional.conv1d/conv2d expect.
+    if channel_last:
+        if conv_is_1d:
+            input_tensor = input_tensor.movedim(-1, 1).contiguous()
+            if len(weight.shape) != 3:
+                raise ValueError("Weight tensor must be 3D if input is 3D")
+            weight = weight.movedim(-1, 1).contiguous()
+        else:
+            input_tensor = input_tensor.movedim(-1, -3)
+            if len(weight.shape) != 4:
+                raise ValueError("Weight tensor must be 4D if input is not 3D")
+            weight = torch.permute(weight, (0, -1, 1, 2)).contiguous()
+
+    # conv1d takes scalar stride/padding/dilation, so only index 0 of each
+    # pair is used in the 1D case.
+    _stride: tuple[int, int] | int = stride
+    _padding: tuple[int, int] | int = padding
+    _dilation: tuple[int, int] | int = dilation
+    if conv_is_1d:
+        conv = torch.nn.functional.conv1d
+        _stride = stride[0]
+        _padding = padding[0]
+        _dilation = dilation[0]
+    else:
+        conv = torch.nn.functional.conv2d
+
+    conv_out = conv(input_tensor, weight, bias, _stride, _padding, _dilation, groups)
+    # Convert the result back to channel-last if the input was channel-last.
+    if channel_last:
+        if conv_is_1d:
+            conv_out = conv_out.movedim(1, -1).contiguous()
+        else:
+            conv_out = conv_out.movedim(-3, -1).contiguous()
+
+    return conv_out
+
+
 def quantized_relu_common(
     X: torch.Tensor,
     X_zero_point: torch.Tensor | int,
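For intuition, the channel_last branch of the new convolution() is just layout
bookkeeping around a plain ATen call: movedim(-1, -3) turns an NHWC input into
NCHW, and permute(0, -1, 1, 2) maps an (O, H, W, I) weight to (O, I, H, W). A
minimal standalone sketch (the shapes here are illustrative, not taken from the
patch):

    import torch
    import torch.nn.functional as F

    x_nhwc = torch.randn(1, 5, 5, 3)  # N, H, W, C
    w_ohwi = torch.randn(8, 2, 2, 3)  # O, H, W, I (channel-last weight)
    bias = torch.zeros(8)

    # Same conversions the channel_last branch performs:
    x_nchw = x_nhwc.movedim(-1, -3)                # -> N, C, H, W
    w_oihw = torch.permute(w_ohwi, (0, -1, 1, 2))  # -> O, I, H, W

    out_nhwc = F.conv2d(x_nchw, w_oihw, bias).movedim(-3, -1)
    print(out_nhwc.shape)  # torch.Size([1, 4, 4, 8])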
diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py
index 2858f9781e5..606be9098d6 100644
--- a/backends/cadence/aot/tests/test_ref_implementations.py
+++ b/backends/cadence/aot/tests/test_ref_implementations.py
@@ -1256,3 +1256,280 @@ def test_rope(
             torch.allclose(output, expected_output, rtol=1e-4, atol=1e-4),
             f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
         )
+
+    @expand(
+        [
+            # Test case 1: Basic 2D convolution (NCHW format)
+            (
+                "basic_2d_nchw",
+                torch.tensor(
+                    [[[[1.0, 2.0], [3.0, 4.0]]]], dtype=torch.float32
+                ),  # input: 1x1x2x2
+                torch.tensor(
+                    [[[[1.0, 0.0], [0.0, 1.0]]]], dtype=torch.float32
+                ),  # weight: 1x1x2x2 (identity-like filter)
+                torch.tensor([0.0], dtype=torch.float32),  # bias
+                (1, 1),  # stride
+                (0, 0),  # padding
+                (1, 1),  # dilation
+                1,  # groups
+                False,  # channel_last
+                torch.tensor(
+                    [[[[5.0]]]], dtype=torch.float32
+                ),  # expected: 1*1 + 4*1 = 5
+            ),
+            # Test case 2: Basic 2D convolution (NHWC format)
+            (
+                "basic_2d_nhwc",
+                torch.tensor(
+                    [[[[1.0], [2.0]], [[3.0], [4.0]]]], dtype=torch.float32
+                ),  # input: 1x2x2x1 (NHWC)
+                torch.tensor(
+                    [[[[1.0], [0.0]], [[0.0], [1.0]]]], dtype=torch.float32
+                ),  # weight: 1x2x2x1 (NHWC format)
+                torch.tensor([0.0], dtype=torch.float32),  # bias
+                (1, 1),  # stride
+                (0, 0),  # padding
+                (1, 1),  # dilation
+                1,  # groups
+                True,  # channel_last
+                torch.tensor(
+                    [[[[5.0]]]], dtype=torch.float32
+                ),  # expected: 1*1 + 4*1 = 5
+            ),
+            # Test case 3: 2D convolution with stride=2
+            (
+                "conv2d_stride2",
+                torch.tensor(
+                    [
+                        [
+                            [
+                                [1.0, 2.0, 3.0, 4.0],
+                                [5.0, 6.0, 7.0, 8.0],
+                                [9.0, 10.0, 11.0, 12.0],
+                                [13.0, 14.0, 15.0, 16.0],
+                            ]
+                        ]
+                    ],
+                    dtype=torch.float32,
+                ),  # input: 1x1x4x4
+                torch.tensor(
+                    [[[[1.0, 1.0], [1.0, 1.0]]]], dtype=torch.float32
+                ),  # weight: 1x1x2x2 (sum filter)
+                torch.tensor([0.0], dtype=torch.float32),  # bias
+                (2, 2),  # stride=2
+                (0, 0),  # padding
+                (1, 1),  # dilation
+                1,  # groups
+                False,  # channel_last
+                torch.tensor([[[[14.0, 22.0], [46.0, 54.0]]]], dtype=torch.float32),
+            ),
+            # Test case 4: 2D convolution with padding=1
+            (
+                "conv2d_padding1",
+                torch.tensor(
+                    [[[[1.0, 2.0], [3.0, 4.0]]]], dtype=torch.float32
+                ),  # input: 1x1x2x2
+                torch.tensor(
+                    [[[[1.0, 0.0], [0.0, 1.0]]]], dtype=torch.float32
+                ),  # weight: 1x1x2x2
+                torch.tensor([0.0], dtype=torch.float32),  # bias
+                (1, 1),  # stride
+                (1, 1),  # padding=1
+                (1, 1),  # dilation
+                1,  # groups
+                False,  # channel_last
+                torch.tensor(
+                    [[[[1.0, 2.0, 0.0], [3.0, 5.0, 2.0], [0.0, 3.0, 4.0]]]],
+                    dtype=torch.float32,
+                ),  # expected with padding
+            ),
+            # Test case 5: 2D convolution with dilation=2
+            (
+                "conv2d_dilation2",
+                torch.tensor(
+                    [
+                        [
+                            [
+                                [1.0, 2.0, 3.0, 4.0],
+                                [5.0, 6.0, 7.0, 8.0],
+                                [9.0, 10.0, 11.0, 12.0],
+                                [13.0, 14.0, 15.0, 16.0],
+                            ]
+                        ]
+                    ],
+                    dtype=torch.float32,
+                ),  # input: 1x1x4x4
+                torch.tensor(
+                    [[[[1.0, 1.0], [1.0, 1.0]]]], dtype=torch.float32
+                ),  # weight: 1x1x2x2
+                torch.tensor([0.0], dtype=torch.float32),  # bias
+                (1, 1),  # stride
+                (0, 0),  # padding
+                (2, 2),  # dilation=2
+                1,  # groups
+                False,  # channel_last
+                torch.tensor([[[[24.0, 28.0], [40.0, 44.0]]]], dtype=torch.float32),
+            ),
+            # Test case 6: 2D grouped convolution (groups=2)
+            (
+                "conv2d_groups2",
+                torch.tensor(
+                    [
+                        [
+                            [[1.0, 2.0], [3.0, 4.0]],  # first input channel
+                            [[5.0, 6.0], [7.0, 8.0]],  # second input channel
+                        ]
+                    ],
+                    dtype=torch.float32,
+                ),  # input: 1x2x2x2
+                torch.tensor(
+                    [
+                        [[[1.0, 1.0], [1.0, 1.0]]],  # first group weight
+                        [[[0.5, 0.5], [0.5, 0.5]]],  # second group weight
+                    ],
+                    dtype=torch.float32,
+                ),  # weight: 2x1x2x2
+                torch.tensor([0.0, 1.0], dtype=torch.float32),  # bias
+                (1, 1),  # stride
+                (0, 0),  # padding
+                (1, 1),  # dilation
+                2,  # groups=2
+                False,  # channel_last
+                torch.tensor([[[[10.0]], [[14.0]]]], dtype=torch.float32),
+            ),
+            # Test case 7: 1D convolution (NCL format)
+            (
+                "conv1d_ncl",
+                torch.tensor(
+                    [[[1.0, 2.0, 3.0, 4.0]]], dtype=torch.float32
+                ),  # input: 1x1x4
+                torch.tensor([[[1.0, 1.0]]], dtype=torch.float32),  # weight: 1x1x2
+                torch.tensor([0.0], dtype=torch.float32),  # bias
+                (1, 1),  # stride (only stride[0] is used for 1D)
+                (0, 0),  # padding (only padding[0] is used for 1D)
+                (1, 1),  # dilation (only dilation[0] is used for 1D)
+                1,  # groups
+                False,  # channel_last
+                torch.tensor(
+                    [[[3.0, 5.0, 7.0]]], dtype=torch.float32
+                ),  # expected: [1+2, 2+3, 3+4]
+            ),
+            # Test case 8: 1D convolution (NLC format)
+            (
+                "conv1d_nlc",
+                torch.tensor(
+                    [[[1.0], [2.0], [3.0], [4.0]]], dtype=torch.float32
+                ),  # input: 1x4x1 (NLC)
+                torch.tensor(
+                    [[[1.0], [1.0]]], dtype=torch.float32
+                ),  # weight: 1x2x1 (NLC)
+                torch.tensor([0.0], dtype=torch.float32),  # bias
+                (1, 1),  # stride
+                (0, 0),  # padding
+                (1, 1),  # dilation
+                1,  # groups
+                True,  # channel_last
+                torch.tensor([[[3.0], [5.0], [7.0]]], dtype=torch.float32),
+            ),
+            # Test case 9: Multi-channel input and output
+            (
+                "multi_channel",
+                torch.tensor(
+                    [
+                        [
+                            [[1.0, 2.0], [3.0, 4.0]],  # first input channel
+                            [[0.5, 1.0], [1.5, 2.0]],  # second input channel
+                        ]
+                    ],
+                    dtype=torch.float32,
+                ),  # input: 1x2x2x2
+                torch.tensor(
+                    [
+                        [  # first output channel
+                            [[1.0, 0.0], [0.0, 1.0]],  # weights for first input channel
+                            [
+                                [2.0, 0.0],
+                                [0.0, 2.0],
+                            ],  # weights for second input channel
+                        ],
+                        [  # second output channel
+                            [[0.5, 0.5], [0.5, 0.5]],  # weights for first input channel
+                            [
+                                [1.0, 1.0],
+                                [1.0, 1.0],
+                            ],  # weights for second input channel
+                        ],
+                    ],
+                    dtype=torch.float32,
+                ),  # weight: 2x2x2x2
+                torch.tensor([0.0, 1.0], dtype=torch.float32),  # bias
+                (1, 1),  # stride
+                (0, 0),  # padding
+                (1, 1),  # dilation
+                1,  # groups
+                False,  # channel_last
+                torch.tensor([[[[10.0]], [[11.0]]]], dtype=torch.float32),
+            ),
+            # Test case 10: Convolution with non-zero bias
+            (
+                "conv2d_with_bias",
+                torch.tensor(
+                    [[[[1.0, 2.0], [3.0, 4.0]]]], dtype=torch.float32
+                ),  # input: 1x1x2x2
+                torch.tensor(
+                    [[[[1.0, 0.0], [0.0, 1.0]]]], dtype=torch.float32
+                ),  # weight: 1x1x2x2
+                torch.tensor([10.0], dtype=torch.float32),  # bias=10
+                (1, 1),  # stride
+                (0, 0),  # padding
+                (1, 1),  # dilation
+                1,  # groups
+                False,  # channel_last
+                torch.tensor(
+                    [[[[15.0]]]], dtype=torch.float32
+                ),  # expected: 5 + 10 = 15
+            ),
+        ]
+    )
+    def test_convolution(
+        self,
+        name: str,
+        input_tensor: torch.Tensor,
+        weight: torch.Tensor,
+        bias: torch.Tensor,
+        stride: tuple[int, int],
+        padding: tuple[int, int],
+        dilation: tuple[int, int],
+        groups: int,
+        channel_last: bool,
+        expected_output: torch.Tensor,
+    ) -> None:
+        output = torch.ops.cadence.convolution(
+            input_tensor,
+            weight,
+            bias,
+            stride,
+            padding,
+            dilation,
+            groups,
+            channel_last,
+        )
+
+        # Verify output properties
+        self.assertEqual(
+            output.dtype,
+            input_tensor.dtype,
+            f"Output dtype should match input dtype in {name}",
+        )
+        self.assertEqual(
+            output.shape,
+            expected_output.shape,
+            f"Output shape should match expected shape in {name}",
+        )
+
+        # Verify output matches expected values
+        self.assertTrue(
+            torch.equal(output, expected_output),
+            f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
+        )
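As a quick usage sketch (assuming the op is registered and reachable as
torch.ops.cadence.convolution, which is how the tests above invoke it), the
reference implementation should agree with torch.nn.functional.conv2d on
random inputs; the shapes below are arbitrary:

    import torch

    x = torch.randn(2, 3, 8, 8)  # NCHW input
    w = torch.randn(4, 3, 3, 3)  # OIHW weight
    b = torch.randn(4)

    ref = torch.ops.cadence.convolution(
        x, w, b, (2, 2), (1, 1), (1, 1), 1, False
    )
    expected = torch.nn.functional.conv2d(
        x, w, b, stride=(2, 2), padding=(1, 1), dilation=(1, 1), groups=1
    )
    assert torch.allclose(ref, expected)

    # The channel_last path should produce the same values after layout moves.
    ref_nhwc = torch.ops.cadence.convolution(
        x.movedim(1, -1).contiguous(),
        w.movedim(1, -1).contiguous(),
        b, (2, 2), (1, 1), (1, 1), 1, True,
    )
    assert torch.allclose(ref_nhwc.movedim(-1, 1), expected)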