diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py index 312bed89315..ca15e825ff0 100644 --- a/backends/cadence/aot/ref_implementations.py +++ b/backends/cadence/aot/ref_implementations.py @@ -960,6 +960,7 @@ def convolution( _stride: tuple[int, int] | int = stride _padding: tuple[int, int] | int = padding _dilation: tuple[int, int] | int = dilation + if conv_is_1d: conv = torch.nn.functional.conv1d _stride = stride[0] @@ -978,6 +979,64 @@ def convolution( return conv_out +@impl(m, "transposed_convolution") +def transposed_convolution( + input_tensor: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: tuple[int, int], + padding: tuple[int, int], + dilation: tuple[int, int], + output_padding: tuple[int, int], + groups: int, + channel_last: bool = False, +) -> torch.Tensor: + + conv_is_1d = len(input_tensor.shape) == 3 + if channel_last: + if conv_is_1d: + input_tensor = input_tensor.movedim(-1, 1).contiguous() + if len(weight.shape) != 3: + raise ValueError("Weight tensor must be 3D if input is 3D") + weight = weight.movedim(-1, 1).contiguous() + else: + input_tensor = input_tensor.movedim(-1, -3) + if len(weight.shape) != 4: + raise ValueError("Weight tensor must be 4D if input is nd > 3") + weight = torch.permute(weight, (0, -1, 1, 2)).contiguous() + + _stride: tuple[int, int] | int = stride + _padding: tuple[int, int] | int = padding + _dilation: tuple[int, int] | int = dilation + _output_padding: tuple[int, int] | int = output_padding + if conv_is_1d: + conv = torch.nn.functional.conv_transpose1d + _stride = stride[0] + _padding = padding[0] + _dilation = dilation[0] + _output_padding = output_padding[0] + else: + conv = torch.nn.functional.conv_transpose2d + + conv_out = conv( + input_tensor, + weight, + bias, + _stride, + _padding, + _output_padding, + groups, + _dilation, + ) + if channel_last: + if conv_is_1d: + conv_out = conv_out.movedim(1, -1).contiguous() + else: + conv_out = conv_out.movedim(-3, -1).contiguous() + + return conv_out + + @impl(m, "avg_pool2d") def avg_pool2d( input_tensor: torch.Tensor, diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py index 32e9b43e68e..8d02c5c2963 100644 --- a/backends/cadence/aot/tests/test_ref_implementations.py +++ b/backends/cadence/aot/tests/test_ref_implementations.py @@ -1534,6 +1534,143 @@ def test_convolution( f"Output values don't match expected in {name}. Got {output}, expected {expected_output}", ) + @expand( + [ + # Basic 2D transposed convolution with stride=1 (current test case - corrected name) + ( + "basic_2d_stride1", + torch.tensor( + [[[[1.0, 2.0], [3.0, 4.0]]]], dtype=torch.float32 + ), # input: 1x1x2x2 + torch.tensor( + [[[[1.0, 1.0], [1.0, 1.0]]]], dtype=torch.float32 + ), # weight: 1x1x2x2 + torch.tensor([0.0], dtype=torch.float32), # bias + (1, 1), # stride + (0, 0), # padding + (1, 1), # dilation + 1, # groups + (0, 0), # output_padding + False, # channel_last + torch.tensor( + [[[[1.0, 3.0, 2.0], [4.0, 10.0, 6.0], [3.0, 7.0, 4.0]]]], + dtype=torch.float32, + ), + ), + # 2D transposed convolution with channel_last=True (NHWC format) + ( + "channel_last_nhwc", + torch.tensor( + [[[[1.0], [2.0]], [[3.0], [4.0]]]], dtype=torch.float32 + ), # input: 1x2x2x1 (NHWC) + torch.tensor( + [[[[1.0], [1.0]], [[1.0], [1.0]]]], dtype=torch.float32 + ), # weight: 1x2x2x1 (NHWC) + torch.tensor([0.0], dtype=torch.float32), # bias + (1, 1), # stride + (0, 0), # padding + (1, 1), # dilation + 1, # groups + (0, 0), # output_padding + True, # channel_last=True + torch.tensor( + [ + [ + [[1.0], [3.0], [2.0]], + [[4.0], [10.0], [6.0]], + [[3.0], [7.0], [4.0]], + ] + ], + dtype=torch.float32, + ), + ), + # 2D transposed convolution with non-zero bias + ( + "with_bias", + torch.tensor( + [[[[1.0, 2.0], [3.0, 4.0]]]], dtype=torch.float32 + ), # input: 1x1x2x2 + torch.tensor( + [[[[1.0, 0.0], [0.0, 1.0]]]], dtype=torch.float32 + ), # weight: 1x1x2x2 + torch.tensor([5.0], dtype=torch.float32), # bias=5.0 + (1, 1), # stride + (0, 0), # padding + (1, 1), # dilation + 1, # groups + (0, 0), # output_padding + False, # channel_last + torch.tensor( + [[[[6.0, 7.0, 5.0], [8.0, 10.0, 7.0], [5.0, 8.0, 9.0]]]], + dtype=torch.float32, + ), + ), + # 1D transposed convolution (3D tensor, NLC format) + ( + "conv1d_nlc", + torch.tensor( + [[[1.0], [2.0], [3.0]]], dtype=torch.float32 + ), # input: 1x3x1 (NLC) + torch.tensor( + [[[1.0], [0.5]]], dtype=torch.float32 + ), # weight: 1x2x1 (NLC) + torch.tensor([0.0], dtype=torch.float32), # bias + (2, 0), # stride + (0, 0), # padding + (1, 1), # dilation + 1, # groups + (0, 0), # output_padding + True, # channel_last=True + torch.tensor( + [[[1.0], [0.5], [2.0], [1.0], [3.0], [1.5]]], dtype=torch.float32 + ), + ), + ] + ) + def test_transposed_convolution( + self, + name: str, + input_tensor: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: tuple[int, int], + padding: tuple[int, int], + dilation: tuple[int, int], + groups: int, + output_padding: tuple[int, int], + channel_last: bool, + expected_output: torch.Tensor, + ) -> None: + output = torch.ops.cadence.transposed_convolution( + input_tensor, + weight, + bias, + stride, + padding, + dilation, + output_padding, + groups, + channel_last, + ) + + # Verify output properties + self.assertEqual( + output.dtype, + input_tensor.dtype, + f"Output dtype should match input dtype in {name}", + ) + self.assertEqual( + output.shape, + expected_output.shape, + f"Output shape should match expected shape in {name}", + ) + + # Verify output matches expected values + self.assertTrue( + torch.equal(output, expected_output), + f"Output values don't match expected in {name}. Got {output}, expected {expected_output}", + ) + @expand( [ # Basic non-quantized average pooling