diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py index ca15e825ff0..886cb14d0d6 100644 --- a/backends/cadence/aot/ref_implementations.py +++ b/backends/cadence/aot/ref_implementations.py @@ -1303,3 +1303,116 @@ def rope( [x0 * cos_tensor - x1 * sin_tensor, x0 * sin_tensor + x1 * cos_tensor], dim=-1 ) return rotated.view(original_shape) + + +@impl(m, "im2row") +def im2row( + input_tensor: torch.Tensor, + kernel_size: tuple[int, int], + dilation: tuple[int, int], + padding: tuple[int, int], + stride: tuple[int, int], + in_zero_point: torch.Tensor, + channel_last: bool = False, +) -> torch.Tensor: + """ + Converts an input tensor into a 2D matrix where each row is a flattened sliding window (patch) + from the input, suitable for use in convolution as a matrix multiplication (im2row). + + Args: + - input_tensor: Input tensor of shape (N, C, H, W) or (N, H, W, C) if channel_last. + - kernel_size: Size of the convolution kernel. + - dilation: Dilation of the convolution kernel. + - padding: Padding to apply to the input. + - stride: Stride of the convolution. + - in_zero_point : Zero point for input quantization (broadcastable to input). + - channel_last: If True, input is in NHWC format, else NCHW. + + Returns: + - Tensor of shape (N, num_patches, patch_size) + """ + if len(input_tensor.shape) == 3: + height_dim = 1 if channel_last else 2 + input_tensor = input_tensor.unsqueeze(height_dim) + + if in_zero_point is not None: + if in_zero_point.numel() != 1 and in_zero_point.shape != ( + input_tensor.shape[0], + ): + raise ValueError( + f"Input zero point must be a scalar or broadcastable to input shape {input_tensor.shape}" + ) + if in_zero_point.dtype != torch.int32: + raise ValueError("Input zero point must be an int32 tensor") + + if channel_last: + input_tensor = input_tensor.movedim(-1, -3).contiguous() # NHWC -> NCHW + + N, C, H, W = input_tensor.shape + kH, kW = kernel_size + dH, dW = dilation + pH, pW = padding + sH, sW = stride + + # Handle padding with zero point values + if in_zero_point is not None and (pH > 0 or pW > 0): + # Expand zero point to (N, 1, 1, 1) for broadcasting + in_zero_point = in_zero_point.expand(N) + + # Pad input with the per-batch zero point values + input_tensor = torch.stack( + [ + torch.nn.functional.pad( + input_tensor[i], + (pW, pW, pH, pH), + mode="constant", + value=in_zero_point[i].item(), + ) + for i in range(len(input_tensor)) + ] + ) + + padding = (0, 0) # Already padded manually + + # Use unfold to extract sliding local blocks + # Unfold: (N, C, H, W) -> (N, C, L, kH, kW), where L = number of sliding windows + # torch.nn.functional.unfold returns (N, C*kH*kW, L) + patches = torch.nn.functional.unfold( + input_tensor.float(), # unfold not implemented for int + kernel_size=(kH, kW), + dilation=(dH, dW), + padding=padding, + stride=(sH, sW), + ).to( + input_tensor.dtype + ) # (N, C*kH*kW, L) + + # Transpose to (N, L, C*kH*kW) + patches = patches.transpose(1, 2).contiguous() + + # Reshape to (N*L, C*kH*kW) + patches = patches.view(N, -1, C * kH * kW) + + # If channel_last, output should be in NHWC patch order (but im2row is always row-major) + return patches + + +@impl(m, "im2row.per_tensor") +def im2row_per_tensor( + input_tensor: torch.Tensor, + kernel_size: tuple[int, int], + dilation: tuple[int, int], + padding: tuple[int, int], + stride: tuple[int, int], + in_zero_point: int, + channel_last: bool = False, +) -> torch.Tensor: + return im2row( + input_tensor, + kernel_size, + dilation, + padding, + stride, + torch.tensor(in_zero_point, dtype=torch.int32), + channel_last, + ) diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py index 8d02c5c2963..0aa1f0a243a 100644 --- a/backends/cadence/aot/tests/test_ref_implementations.py +++ b/backends/cadence/aot/tests/test_ref_implementations.py @@ -1843,3 +1843,296 @@ def test_avg_pool2d( torch.equal(output, expected_output), f"Output values don't match expected in {name}. Got {output}, expected {expected_output}", ) + + @expand( + [ + # Basic 2x2 kernel, stride 1, no padding, NCHW + ( + "nchw_basic_2x2", + torch.tensor( + [[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]], dtype=torch.float32 + ), # (N=1, C=1, H=3, W=3) + (2, 2), # kernel_size + (1, 1), # dilation + (0, 0), # padding + (1, 1), # stride + None, # in_zero_point + False, # channel_last + False, + torch.tensor( + [ + [[1, 2, 4, 5], [2, 3, 5, 6], [4, 5, 7, 8], [5, 6, 8, 9]], + ], + dtype=torch.float32, + ), + ), + # 2x2 kernel, stride 2, no padding, NCHW + ( + "nchw_stride2", + torch.tensor( + [[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]], dtype=torch.float32 + ), + (2, 2), + (1, 1), + (0, 0), + (2, 2), + None, + False, + False, + torch.tensor( + [ + [[1, 2, 4, 5]], + ], + dtype=torch.float32, # Only every other patch in each dim + ), + ), + # 2x2 kernel, stride 1, padding 1, NCHW + ( + "nchw_padding1", + torch.tensor([[[[1, 2], [3, 4]]]], dtype=torch.float32), # (1,1,2,2) + (2, 2), + (1, 1), + (1, 1), + (1, 1), + None, + False, + False, + torch.tensor( + [ + [ + [0, 0, 0, 1], + [0, 0, 1, 2], + [0, 0, 2, 0], + [0, 1, 0, 3], + [1, 2, 3, 4], + [2, 0, 4, 0], + [0, 3, 0, 0], + [3, 4, 0, 0], + [4, 0, 0, 0], + ], + ], + dtype=torch.float32, + ), + ), + # 2x2 kernel, stride 1, no padding, NHWC + ( + "nhwc_basic_2x2", + torch.tensor( + [[[[1], [2], [3]], [[4], [5], [6]], [[7], [8], [9]]]], + dtype=torch.float32, + ), # (N=1, H=3, W=3, C=1) + (2, 2), + (1, 1), + (0, 0), + (1, 1), + None, + True, + False, + torch.tensor( + [ + [[1, 2, 4, 5], [2, 3, 5, 6], [4, 5, 7, 8], [5, 6, 8, 9]], + ], + dtype=torch.float32, + ), + ), + # 2x2 kernel, stride 1, no padding, NCHW, in_zero_point=1 + ( + "nchw_in_zero_point_no_padding", + torch.tensor([[[[2, 3, 4], [5, 6, 7], [8, 9, 10]]]], dtype=torch.int8), + (2, 2), + (1, 1), + (0, 0), + (1, 1), + torch.tensor(1, dtype=torch.int32), + False, + False, + torch.tensor( + [ + [[2, 3, 5, 6], [3, 4, 6, 7], [5, 6, 8, 9], [6, 7, 9, 10]], + ], + dtype=torch.int8, + ), + ), + ( + "nchw_in_zero_point_with_padding=1_and_stride=2", + torch.tensor([[[[2, 3, 4], [5, 6, 7], [8, 9, 10]]]], dtype=torch.int8), + (2, 2), + (1, 1), + (1, 1), + (2, 2), + torch.tensor(-1, dtype=torch.int32), + False, + False, + torch.tensor( + [ + [ + [-1, -1, -1, 2], + [-1, -1, 3, 4], + [-1, 5, -1, 8], + [6, 7, 9, 10], + ], + ], + dtype=torch.int8, + ), + ), + # 2x2 kernel, stride 1, no padding, NHWC, in_zero_point=2 + ( + "nhwc_in_zero_point", + torch.tensor( + [[[[3], [4], [5]], [[6], [7], [8]], [[9], [10], [11]]]], + dtype=torch.int8, + ), + (2, 2), + (1, 1), + (0, 0), + (1, 1), + torch.tensor(2, dtype=torch.int32), + True, + False, + torch.tensor( + [ + [[3, 4, 6, 7], [4, 5, 7, 8], [6, 7, 9, 10], [7, 8, 10, 11]], + ], + dtype=torch.int8, + ), + ), + # Multi-channel input, 2x2 kernel, stride 1, no padding, NCHW + ( + "nchw_multi_channel", + torch.tensor( + [ + [ + [[1, 2, 3], [4, 5, 6], [7, 8, 9]], # channel 0 + [[10, 11, 12], [13, 14, 15], [16, 17, 18]], # channel 1 + ] + ], + dtype=torch.float32, + ), # (1,2,3,3) + (2, 2), + (1, 1), + (0, 0), + (1, 1), + None, + False, + False, + torch.tensor( + [ + [ + [1, 2, 4, 5, 10, 11, 13, 14], + [2, 3, 5, 6, 11, 12, 14, 15], + [4, 5, 7, 8, 13, 14, 16, 17], + [5, 6, 8, 9, 14, 15, 17, 18], + ], + ], + dtype=torch.float32, + ), + ), + # Multi-channel input and multi-channel zero-point + ( + "nchw_multi_channel_and_zero_point_no_padding", + torch.tensor([[[1, 2, 3]], [[4, 5, 6]]], dtype=torch.int32), + (1, 2), + (1, 1), + (0, 0), + (1, 1), + torch.tensor([-1, -2], dtype=torch.int32), + False, + False, + torch.tensor([[[1, 2], [2, 3]], [[4, 5], [5, 6]]], dtype=torch.int32), + ), + ( + "nchw_multi_channel_and_zero_point_with_padding=1_and_stride=(2, 1)", + torch.tensor([[[1, 2, 3]], [[4, 5, 6]]], dtype=torch.int32), + (1, 2), + (1, 1), + (2, 1), + (2, 2), + torch.tensor([-1, -2], dtype=torch.int32), + False, + False, + torch.tensor( + [ + [ + [-1, -1], + [-1, -1], + [-1, 1], + [2, 3], + [-1, -1], + [-1, -1], + ], + [ + [-2, -2], + [-2, -2], + [-2, 4], + [5, 6], + [-2, -2], + [-2, -2], + ], + ], + dtype=torch.int32, + ), + ), + ( + "per_tensor", + torch.tensor( + [[[[3], [4], [5]], [[6], [7], [8]], [[9], [10], [11]]]], + dtype=torch.int8, + ), + (2, 2), + (1, 1), + (0, 0), + (1, 1), + 2, + True, + True, + torch.tensor( + [ + [[3, 4, 6, 7], [4, 5, 7, 8], [6, 7, 9, 10], [7, 8, 10, 11]], + ], + dtype=torch.int8, + ), + ), + ] + ) + def test_im2row( + self, + name: str, + input_tensor: torch.Tensor, + kernel_size: tuple[int, int], + dilation: tuple[int, int], + padding: tuple[int, int], + stride: tuple[int, int], + in_zero_point: torch.Tensor | None, + channel_last: bool, + per_tensor: bool, + expected_output: torch.Tensor, + ) -> None: + if per_tensor: + output = torch.ops.cadence.im2row.per_tensor( + input_tensor, + kernel_size, + dilation, + padding, + stride, + in_zero_point, + channel_last, + ) + else: + output = torch.ops.cadence.im2row( + input_tensor, + kernel_size, + dilation, + padding, + stride, + in_zero_point, + channel_last, + ) + self.assertEqual( + output.shape, + expected_output.shape, + f"im2row output shape mismatch in {name}", + ) + self.assertTrue( + torch.equal(output, expected_output), + f"im2row output mismatch in {name}: got {output}, expected {expected_output}", + )