From 300cc667027fafbdd47e3dbeb4fc8e279a3edc11 Mon Sep 17 00:00:00 2001
From: Andrew Grebenisan
Date: Fri, 3 Oct 2025 14:45:59 -0700
Subject: [PATCH] Add transposed im2row (#14738)

Summary: Continued support for custom cadence ops.
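
The reference implementation places each input pixel's dilated kernel
footprint at its upsampled output position, using the standard
transposed-convolution output-size formula

    out = (in - 1) * stride - 2 * padding + dilation * (kernel - 1)
          + output_padding + 1

so a 2x2 input with a 2x2 kernel, stride 1, and no padding (the "basic_2x2"
test below) produces 3 * 3 = 9 patch rows per batch.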

Reviewed By: hsharma35, eigen-k

Differential Revision: D83709868
---
 backends/cadence/aot/ref_implementations.py  | 145 +++++++++++++++
 .../aot/tests/test_ref_implementations.py    | 170 ++++++++++++++++++
 2 files changed, 315 insertions(+)

diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py
index 886cb14d0d6..2642340679e 100644
--- a/backends/cadence/aot/ref_implementations.py
+++ b/backends/cadence/aot/ref_implementations.py
@@ -1416,3 +1416,148 @@ def im2row_per_tensor(
         torch.tensor(in_zero_point, dtype=torch.int32),
         channel_last,
     )
+
+
+@impl(m, "transposed_im2row")
+def transposed_im2row(
+    input_tensor: torch.Tensor,
+    kernel_size: tuple[int, int],
+    dilation: tuple[int, int],
+    padding: tuple[int, int],
+    stride: tuple[int, int],
+    output_padding: tuple[int, int],
+    in_zero_point: torch.Tensor | None,
+    channel_last: bool = False,
+) -> torch.Tensor:
+    """
+    Converts an input tensor into im2row format for transposed convolution:
+    each row holds the input values that contribute to one output position.
+
+    Args:
+    - input_tensor: Input spatial tensor, NCHW or NHWC format (3D or 4D).
+    - kernel_size: Size of the convolution kernel.
+    - dilation: Dilation of the convolution kernel.
+    - padding: Padding to apply to the input.
+    - stride: Stride of the convolution.
+    - output_padding: Additional output padding for transposed convolution.
+    - in_zero_point: Zero point for input quantization; a scalar tensor or a
+      per-batch tensor of shape (N,).
+    - channel_last: If True, input is in NHWC format, else NCHW.
+
+    Returns:
+    - 3D tensor of shape (N, output_h * output_w, in_c * in_h * in_w)
+    """
+    # Handle 1D convolution case by adding height dimension
+    if len(input_tensor.shape) == 3:
+        height_dim = 1 if channel_last else 2
+        input_tensor = input_tensor.unsqueeze(height_dim)
+
+    if in_zero_point is not None:
+        if in_zero_point.dtype != torch.int32:
+            raise ValueError("Input zero point must be an int32 tensor")
+
+    # Move to NCHW for processing if needed
+    if channel_last:
+        input_tensor = input_tensor.movedim(-1, -3).contiguous()  # NHWC -> NCHW
+
+    N, C, H_in, W_in = input_tensor.shape
+
+    # For each input pixel, create a channel where the upsampled (transposed
+    # conv) patch is placed. Intermediate shape: (N, C*H_in*W_in, H_out, W_out)
+    inp_flat = input_tensor.reshape(N, C * H_in * W_in)
+
+    # Calculate output spatial size (standard transposed-convolution formula)
+    H_out = (
+        (H_in - 1) * stride[0]
+        - 2 * padding[0]
+        + dilation[0] * (kernel_size[0] - 1)
+        + output_padding[0]
+        + 1
+    )
+    W_out = (
+        (W_in - 1) * stride[1]
+        - 2 * padding[1]
+        + dilation[1] * (kernel_size[1] - 1)
+        + output_padding[1]
+        + 1
+    )
+
+    # Compute the upsampled (top-left) position for each input pixel
+    h_idx = torch.arange(H_in, device=input_tensor.device)
+    w_idx = torch.arange(W_in, device=input_tensor.device)
+    grid_h, grid_w = torch.meshgrid(h_idx, w_idx, indexing="ij")
+    out_h_idx = grid_h * stride[0] - padding[0]
+    out_w_idx = grid_w * stride[1] - padding[1]
+
+    # Compute all input pixel positions (flattened)
+    ch_idx = torch.arange(C * H_in * W_in, device=input_tensor.device)
+    ij_idx = ch_idx % (H_in * W_in)
+    i_idx = ij_idx // W_in
+    j_idx = ij_idx % W_in
+
+    # For each input pixel, compute the output positions for the kernel window
+    kh_idx = torch.arange(kernel_size[0], device=input_tensor.device)
+    kw_idx = torch.arange(kernel_size[1], device=input_tensor.device)
+    kh_grid, kw_grid = torch.meshgrid(kh_idx, kw_idx, indexing="ij")
+    kh_grid = kh_grid.reshape(-1)
+    kw_grid = kw_grid.reshape(-1)
+    num_kernel = kernel_size[0] * kernel_size[1]
+
+    # Broadcast to all channels and kernel positions
+    ch_idx_b = ch_idx.repeat_interleave(num_kernel)
+    n_kernel = ch_idx.shape[0] * num_kernel
+
+    i_idx_b = i_idx.repeat_interleave(num_kernel)
+    j_idx_b = j_idx.repeat_interleave(num_kernel)
+    kh_b = kh_grid.repeat(ch_idx.shape[0])
+    kw_b = kw_grid.repeat(ch_idx.shape[0])
+
+    h_out = out_h_idx[i_idx_b, j_idx_b] + kh_b * dilation[0]
+    w_out = out_w_idx[i_idx_b, j_idx_b] + kw_b * dilation[1]
+
+    # Mask for valid output positions
+    valid = (h_out >= 0) & (h_out < H_out) & (w_out >= 0) & (w_out < W_out)
+
+    # Prepare indices for advanced indexing
+    n_idx = (
+        torch.arange(N, device=input_tensor.device)
+        .view(-1, 1)
+        .expand(N, n_kernel)
+        .reshape(-1)
+    )
+    ch_idx_full = ch_idx_b.expand(N, n_kernel).reshape(-1)
+    h_out_full = h_out.expand(N, n_kernel).reshape(-1)
+    w_out_full = w_out.expand(N, n_kernel).reshape(-1)
+    valid_full = valid.expand(N, n_kernel).reshape(-1)
+
+    # Gather input values for each channel
+    inp_vals = inp_flat[:, ch_idx_b].reshape(-1)
+
+    # Create output tensor
+    patches = torch.zeros(
+        (N, C * H_in * W_in, H_out, W_out),
+        dtype=input_tensor.dtype,
+        device=input_tensor.device,
+    )
+
+    # If in_zero_point is provided, fill patches with it
+    if in_zero_point is not None:
+        if in_zero_point.numel() == 1:
+            patches.fill_(in_zero_point.item())
+        else:
+            # Broadcast per-batch zero points over all channels and positions
+            assert in_zero_point.shape == (N,)
+            in_zero_point = in_zero_point.view(N, 1, 1, 1)
+            patches = patches + in_zero_point
+
+    # Scatter input values to output positions (only valid positions)
+    patches[
+        n_idx[valid_full],
+        ch_idx_full[valid_full],
+        h_out_full[valid_full],
+        w_out_full[valid_full],
+    ] = inp_vals[valid_full]
+
+    # Flatten to (N, H_out * W_out, C * H_in * W_in)
+    patches = patches.view(N, C * H_in * W_in, -1).transpose(1, 2).contiguous()
+    return patches
diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py
index 0aa1f0a243a..f78d2292e7b 100644
--- a/backends/cadence/aot/tests/test_ref_implementations.py
+++ b/backends/cadence/aot/tests/test_ref_implementations.py
@@ -2136,3 +2136,173 @@ def test_im2row(
             torch.equal(output, expected_output),
             f"im2row output mismatch in {name}: got {output}, expected {expected_output}",
         )
+
+    @expand(
+        [
+            (
+                "basic_2x2",
+                torch.tensor([[[[1, 2], [3, 4]]]], dtype=torch.int32),
+                (2, 2),
+                (1, 1),
+                (0, 0),
+                (1, 1),
+                (0, 0),
+                None,
+                False,
+                torch.tensor(
+                    [
+                        [
+                            [1, 0, 0, 0],
+                            [1, 2, 0, 0],
+                            [0, 2, 0, 0],
+                            [1, 0, 3, 0],
+                            [1, 2, 3, 4],
+                            [0, 2, 0, 4],
+                            [0, 0, 3, 0],
+                            [0, 0, 3, 4],
+                            [0, 0, 0, 4],
+                        ]
+                    ],
+                    dtype=torch.int32,
+                ),
+            ),
+            (
+                "basic_2x2_with_zero_point",
+                torch.tensor([[[[1, 2], [3, 4]]]], dtype=torch.int32),
+                (2, 2),
+                (1, 1),
+                (0, 0),
+                (1, 1),
+                (0, 0),
+                torch.tensor(100, dtype=torch.int32),
+                False,
+                torch.tensor(
+                    [
+                        [
+                            [1, 100, 100, 100],
+                            [1, 2, 100, 100],
+                            [100, 2, 100, 100],
+                            [1, 100, 3, 100],
+                            [1, 2, 3, 4],
+                            [100, 2, 100, 4],
+                            [100, 100, 3, 100],
+                            [100, 100, 3, 4],
+                            [100, 100, 100, 4],
+                        ]
+                    ],
+                    dtype=torch.int32,
+                ),
+            ),
+            (
+                "basic_2x2_with_stride_2",
+                torch.tensor([[[[1, 2], [3, 4]]]], dtype=torch.int32),
+                (2, 2),  # kernel size
+                (1, 1),  # dilation
+                (0, 0),  # padding
+                (2, 2),  # stride
+                (0, 0),  # output padding
+                None,
+                False,
+                torch.tensor(
+                    [
+                        [
+                            [1, 0, 0, 0],
+                            [1, 0, 0, 0],
+                            [0, 2, 0, 0],
+                            [0, 2, 0, 0],
+                            [1, 0, 0, 0],
+                            [1, 0, 0, 0],
+                            [0, 2, 0, 0],
+                            [0, 2, 0, 0],
+                            [0, 0, 3, 0],
+                            [0, 0, 3, 0],
+                            [0, 0, 0, 4],
+                            [0, 0, 0, 4],
+                            [0, 0, 3, 0],
+                            [0, 0, 3, 0],
+                            [0, 0, 0, 4],
+                            [0, 0, 0, 4],
+                        ]
+                    ],
+                    dtype=torch.int32,
+                ),
+            ),
+            (
+                "batch2_with_batch2_zero_point",
+                torch.tensor(
+                    [
+                        [[[1, 2], [3, 4]]],
+                        [[[5, 6], [7, 8]]],
+                    ],
+                    dtype=torch.int32,
+                ),  # input: (2,1,2,2)
+                (2, 2),  # kernel_size
+                (1, 1),  # dilation
+                (0, 0),  # padding
+                (1, 1),  # stride
+                (0, 0),  # output_padding
+                torch.tensor([100, 200], dtype=torch.int32),  # in_zero_point per batch
+                False,  # channel_last
+                torch.tensor(
+                    [
+                        [
+                            [1, 100, 100, 100],
+                            [1, 2, 100, 100],
+                            [100, 2, 100, 100],
+                            [1, 100, 3, 100],
+                            [1, 2, 3, 4],
+                            [100, 2, 100, 4],
+                            [100, 100, 3, 100],
+                            [100, 100, 3, 4],
+                            [100, 100, 100, 4],
+                        ],
+                        [
+                            [5, 200, 200, 200],
+                            [5, 6, 200, 200],
+                            [200, 6, 200, 200],
+                            [5, 200, 7, 200],
+                            [5, 6, 7, 8],
+                            [200, 6, 200, 8],
+                            [200, 200, 7, 200],
+                            [200, 200, 7, 8],
+                            [200, 200, 200, 8],
+                        ],
+                    ],
+                    dtype=torch.int32,
+                ),
+            ),
+        ]
+    )
+    def test_transposed_im2row(
+        self,
+        name: str,
+        input_tensor: torch.Tensor,
+        kernel_size: tuple[int, int],
+        dilation: tuple[int, int],
+        padding: tuple[int, int],
+        stride: tuple[int, int],
+        output_padding: tuple[int, int],
+        in_zero_point: torch.Tensor | int | None,
+        channel_last: bool,
+        expected_output: torch.Tensor,
+    ) -> None:
+        output = torch.ops.cadence.transposed_im2row(
+            input_tensor,
+            kernel_size,
+            dilation,
+            padding,
+            stride,
+            output_padding,
+            in_zero_point,
+            channel_last,
+        )
+
+        self.assertEqual(
+            output.shape,
+            expected_output.shape,
+            f"transposed_im2row output shape mismatch in {name}: got {output.shape}, expected {expected_output.shape}",
+        )
+        self.assertTrue(
+            torch.equal(output, expected_output),
+            f"transposed_im2row output mismatch in {name}: got {output}, expected {expected_output}",
+        )
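
A quick sanity check of the new op's semantics, assuming it is registered and
invoked exactly as in the tests above (torch.ops.cadence.transposed_im2row):
with in_zero_point unset, every contributing input pixel appears exactly once
in the row of the output position it covers, so summing each patch row
reproduces a transposed convolution with an all-ones kernel. A minimal sketch:

    import torch
    import torch.nn.functional as F

    x = torch.tensor([[[[1, 2], [3, 4]]]], dtype=torch.int32)
    # Arguments mirror the "basic_2x2" test case: kernel_size, dilation,
    # padding, stride, output_padding, in_zero_point, channel_last.
    patches = torch.ops.cadence.transposed_im2row(
        x, (2, 2), (1, 1), (0, 0), (1, 1), (0, 0), None, False
    )  # shape: (1, 9, 4)

    # Row sums, reshaped to the 3x3 output grid, should match
    # conv_transpose2d with a 2x2 all-ones kernel.
    ref = F.conv_transpose2d(x.float(), torch.ones(1, 1, 2, 2))
    assert torch.equal(patches.sum(-1).view(1, 1, 3, 3).float(), ref)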