backends/cadence/aot/ref_implementations.py (113 additions, 0 deletions)
@@ -1303,3 +1303,116 @@ def rope(
[x0 * cos_tensor - x1 * sin_tensor, x0 * sin_tensor + x1 * cos_tensor], dim=-1
)
return rotated.view(original_shape)


@impl(m, "im2row")
def im2row(
input_tensor: torch.Tensor,
kernel_size: tuple[int, int],
dilation: tuple[int, int],
padding: tuple[int, int],
stride: tuple[int, int],
in_zero_point: torch.Tensor,
channel_last: bool = False,
) -> torch.Tensor:
"""
Converts an input tensor into a 2D matrix where each row is a flattened sliding window (patch)
from the input, suitable for use in convolution as a matrix multiplication (im2row).

    Args:
    - input_tensor: Input tensor of shape (N, C, H, W) or (N, H, W, C) if channel_last.
      A 3D input is treated as having a singleton height dimension.
    - kernel_size: Size of the convolution kernel.
    - dilation: Dilation of the convolution kernel.
    - padding: Padding to apply to the input.
    - stride: Stride of the convolution.
    - in_zero_point: Zero point for input quantization (a scalar or one value per batch entry).
    - channel_last: If True, input is in NHWC format, else NCHW.

    Returns:
    - Tensor of shape (N, num_patches, patch_size), where patch_size = C * kH * kW.
    """
    if len(input_tensor.shape) == 3:
        # 3D inputs have no height dimension; insert a singleton height axis
        # so the 4D logic below applies uniformly.
        height_dim = 1 if channel_last else 2
        input_tensor = input_tensor.unsqueeze(height_dim)

if in_zero_point is not None:
if in_zero_point.numel() != 1 and in_zero_point.shape != (
input_tensor.shape[0],
):
            raise ValueError(
                f"Input zero point must be a scalar or have one value per batch entry, shape ({input_tensor.shape[0]},)"
            )
if in_zero_point.dtype != torch.int32:
raise ValueError("Input zero point must be an int32 tensor")

if channel_last:
input_tensor = input_tensor.movedim(-1, -3).contiguous() # NHWC -> NCHW

N, C, H, W = input_tensor.shape
kH, kW = kernel_size
dH, dW = dilation
pH, pW = padding
sH, sW = stride

    # Handle padding with zero point values: quantized tensors must be padded
    # with the zero point (which encodes real 0), not with a literal 0
    if in_zero_point is not None and (pH > 0 or pW > 0):
        # Expand the zero point to shape (N,) so each batch entry below
        # is padded with its own value
        in_zero_point = in_zero_point.expand(N)

# Pad input with the per-batch zero point values
input_tensor = torch.stack(
[
torch.nn.functional.pad(
input_tensor[i],
(pW, pW, pH, pH),
mode="constant",
value=in_zero_point[i].item(),
)
for i in range(len(input_tensor))
]
)

padding = (0, 0) # Already padded manually
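
    # For example (assumed values): with pH = pW = 1 and a zero point of 128,
    # a 2x2 uint8 image is padded to 4x4 with 128 in the border, which
    # dequantizes back to real 0.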

    # Use unfold to extract sliding local blocks:
    # torch.nn.functional.unfold maps (N, C, H, W) -> (N, C*kH*kW, L),
    # where L is the number of sliding window positions
patches = torch.nn.functional.unfold(
input_tensor.float(), # unfold not implemented for int
kernel_size=(kH, kW),
dilation=(dH, dW),
padding=padding,
stride=(sH, sW),
).to(
input_tensor.dtype
) # (N, C*kH*kW, L)

# Transpose to (N, L, C*kH*kW)
patches = patches.transpose(1, 2).contiguous()

    # Ensure the final shape is (N, L, C*kH*kW)
    patches = patches.view(N, -1, C * kH * kW)

    # Patches are always emitted in NCHW (channel-major) order, even when
    # channel_last is set, since the input was converted to NCHW above
return patches
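

# Illustrative usage sketch (not part of this diff): with float inputs and a
# zero point of 0, the patch matrix from im2row turns a convolution into a
# matmul. All tensor values below are hypothetical.
def _im2row_matmul_sketch() -> None:
    x = torch.randn(2, 3, 8, 8)
    weight = torch.randn(4, 3, 3, 3)  # (out_ch, C, kH, kW)
    zero_point = torch.tensor(0, dtype=torch.int32)
    patches = im2row(x, (3, 3), (1, 1), (1, 1), (1, 1), zero_point)  # (2, 64, 27)
    out = patches @ weight.view(4, -1).T  # (2, 64, 4)
    out = out.transpose(1, 2).view(2, 4, 8, 8)
    torch.testing.assert_close(out, torch.nn.functional.conv2d(x, weight, padding=1))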


@impl(m, "im2row.per_tensor")
def im2row_per_tensor(
input_tensor: torch.Tensor,
kernel_size: tuple[int, int],
dilation: tuple[int, int],
padding: tuple[int, int],
stride: tuple[int, int],
in_zero_point: int,
channel_last: bool = False,
) -> torch.Tensor:
return im2row(
input_tensor,
kernel_size,
dilation,
padding,
stride,
torch.tensor(in_zero_point, dtype=torch.int32),
channel_last,
)
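

# Per-tensor usage sketch (illustrative, not part of this diff): an int8 input
# is padded with its zero point, so the padded cells decode to real 0.
def _im2row_per_tensor_sketch() -> None:
    x_q = torch.full((1, 1, 2, 2), 5, dtype=torch.int8)
    patches = im2row_per_tensor(x_q, (2, 2), (1, 1), (1, 1), (1, 1), -3)
    assert patches.shape == (1, 9, 4)  # nine 2x2 windows over the padded 4x4 input
    assert patches[0, 0, 0] == -3  # the corner patch starts in the padded border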