backends/cadence/aot/ref_implementations.py (+156 lines)
@@ -1416,3 +1416,159 @@ def im2row_per_tensor(
torch.tensor(in_zero_point, dtype=torch.int32),
channel_last,
)


@impl(m, "transposed_im2row")
def transposed_im2row(
input_tensor: torch.Tensor,
kernel_size: tuple[int, int],
dilation: tuple[int, int],
padding: tuple[int, int],
stride: tuple[int, int],
output_padding: tuple[int, int],
in_zero_point: torch.Tensor,
channel_last: bool = False,
) -> torch.Tensor:
"""
Converts an input tensor into im2row format for transposed convolutions.
Each row of the result corresponds to one output spatial position and gathers
the input pixels whose upsampled kernel footprint covers that position.

Args:
- input_tensor: Input spatial tensor, NCHW or NHWC format (3D or 4D).
- kernel_size: Size of the convolution kernel.
- dilation: Dilation of the convolution kernel.
- padding: Padding to apply to the input.
- stride: Stride of the convolution.
- output_padding: Additional output padding for transposed convolution.
- in_zero_point: Zero point for input quantization; a scalar tensor or a per-batch tensor of shape (N,).
- channel_last: If True, input is in NHWC format, else NCHW.

Returns:
- 3D tensor of shape (N, output_h * output_w, in_c * input_h * input_w)
"""
# Handle 1D convolution case by adding height dimension
if len(input_tensor.shape) == 3:
height_dim = 1 if channel_last else 2
input_tensor = input_tensor.unsqueeze(height_dim)

if in_zero_point is not None:
if in_zero_point.dtype != torch.int32:
raise ValueError("Input zero point must be an int32 tensor")

# Move to NCHW for processing if needed
if channel_last:
input_tensor = input_tensor.movedim(-1, -3).contiguous() # NHWC -> NCHW

N, C, H_in, W_in = input_tensor.shape

# For each input pixel, create a channel where the upsampled (transposed conv) patch is placed
# Output: (N, C*H_in*W_in, H_out, W_out)
inp_flat = input_tensor.reshape(N, C * H_in * W_in)

# Calculate output spatial size
H_out = (
(H_in - 1) * stride[0]
- 2 * padding[0]
+ dilation[0] * (kernel_size[0] - 1)
+ output_padding[0]
+ 1
)
W_out = (
(W_in - 1) * stride[1]
- 2 * padding[1]
+ dilation[1] * (kernel_size[1] - 1)
+ output_padding[1]
+ 1
)
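
# Worked example (cf. the "basic_2x2_with_stride_2" test case): with
# H_in = W_in = 2, stride = (2, 2), padding = (0, 0), dilation = (1, 1),
# kernel_size = (2, 2) and output_padding = (0, 0):
# H_out = (2 - 1) * 2 - 0 + 1 * (2 - 1) + 0 + 1 = 4, and likewise W_out = 4.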

# Compute the upsampled (top-left) position for each input pixel
h_idx = torch.arange(H_in, device=input_tensor.device)
w_idx = torch.arange(W_in, device=input_tensor.device)
grid_h, grid_w = torch.meshgrid(h_idx, w_idx, indexing="ij")
out_h_idx = grid_h * stride[0] - padding[0]
out_w_idx = grid_w * stride[1] - padding[1]

# Compute all input pixel positions (flattened)
ch_idx = torch.arange(C * H_in * W_in, device=input_tensor.device)
ij_idx = ch_idx % (H_in * W_in)
i_idx = ij_idx // W_in
j_idx = ij_idx % W_in
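# e.g. with C = 1 and H_in = W_in = 2: ch_idx = 3 gives ij_idx = 3, hence
# (i_idx, j_idx) = (1, 1), the bottom-right input pixel.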

# For each input pixel, compute the output positions for the kernel window
kh_idx = torch.arange(kernel_size[0], device=input_tensor.device)
kw_idx = torch.arange(kernel_size[1], device=input_tensor.device)
kh_grid, kw_grid = torch.meshgrid(kh_idx, kw_idx, indexing="ij")
kh_grid = kh_grid.reshape(-1)
kw_grid = kw_grid.reshape(-1)
num_kernel = kernel_size[0] * kernel_size[1]

# Broadcast to all channels and kernel positions
ch_idx_b = ch_idx.repeat_interleave(num_kernel)
n_kernel = ch_idx.shape[0] * num_kernel

i_idx_b = i_idx.repeat_interleave(num_kernel)
j_idx_b = j_idx.repeat_interleave(num_kernel)
kh_b = kh_grid.repeat(ch_idx.shape[0])
kw_b = kw_grid.repeat(ch_idx.shape[0])

h_out = out_h_idx[i_idx_b, j_idx_b] + kh_b * dilation[0]
w_out = out_w_idx[i_idx_b, j_idx_b] + kw_b * dilation[1]

# Mask for valid output positions
valid = (h_out >= 0) & (h_out < H_out) & (w_out >= 0) & (w_out < W_out)
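# Out-of-range positions arise when padding > 0: the top-left anchor
# (grid * stride - padding) can go negative, and the far edge can overshoot
# H_out/W_out; such contributions are cropped away.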

# Prepare indices for advanced indexing
n_idx = (
torch.arange(N, device=input_tensor.device)
.view(-1, 1)
.expand(N, n_kernel)
.reshape(-1)
)
ch_idx_full = ch_idx_b.expand(N, n_kernel).reshape(-1)
h_out_full = h_out.expand(N, n_kernel).reshape(-1)
w_out_full = w_out.expand(N, n_kernel).reshape(-1)
valid_full = valid.expand(N, n_kernel).reshape(-1)

# Gather input values for each channel
inp_vals = inp_flat[:, ch_idx_b].reshape(-1)

# Create output tensor
patches = torch.zeros(
    (N, C * H_in * W_in, H_out, W_out),
    dtype=input_tensor.dtype,
    device=input_tensor.device,
)

# If in_zero_point is provided, fill patches with it
if in_zero_point is not None:
if in_zero_point.numel() == 1:
patches.fill_(in_zero_point.item())
else:
# Broadcast the per-batch zero point over (N, C*H_in*W_in, H_out, W_out)
assert in_zero_point.shape == (N,)
in_zero_point = in_zero_point.view(N, 1, 1, 1)
patches = patches + in_zero_point

# Scatter input values to output positions (only valid positions)
patches[
n_idx[valid_full],
ch_idx_full[valid_full],
h_out_full[valid_full],
w_out_full[valid_full],
] = inp_vals[valid_full]

# Flatten spatial dims and move them to the row axis:
# (N, H_out * W_out, C * H_in * W_in)
patches = patches.view(N, C * H_in * W_in, -1).transpose(1, 2).contiguous()
return patches
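
For illustration, a minimal usage sketch: it assumes the operator is registered as torch.ops.cadence.transposed_im2row (as the tests below exercise), and mirrors the basic_2x2 test case.

import torch

# Minimal usage sketch; values and shapes mirror the "basic_2x2" test case below.
x = torch.tensor([[[[1, 2], [3, 4]]]], dtype=torch.int32)  # (N=1, C=1, H=2, W=2)
patches = torch.ops.cadence.transposed_im2row(
    x,
    (2, 2),  # kernel_size
    (1, 1),  # dilation
    (0, 0),  # padding
    (1, 1),  # stride
    (0, 0),  # output_padding
    None,    # in_zero_point
    False,   # channel_last
)
# H_out = W_out = (2 - 1) * 1 - 0 + 1 * (2 - 1) + 0 + 1 = 3,
# so patches has shape (1, 9, 4) = (N, H_out * W_out, C * H_in * W_in).
print(patches.shape)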
backends/cadence/aot/tests/test_ref_implementations.py (+170 lines)
@@ -2136,3 +2136,173 @@ def test_im2row(
torch.equal(output, expected_output),
f"im2row output mismatch in {name}: got {output}, expected {expected_output}",
)

@expand(
[
(
"basic_2x2",
torch.tensor([[[[1, 2], [3, 4]]]], dtype=torch.int32),
(2, 2),
(1, 1),
(0, 0),
(1, 1),
(0, 0),
None,
False,
torch.tensor(
[
[
[1, 0, 0, 0],
[1, 2, 0, 0],
[0, 2, 0, 0],
[1, 0, 3, 0],
[1, 2, 3, 4],
[0, 2, 0, 4],
[0, 0, 3, 0],
[0, 0, 3, 4],
[0, 0, 0, 4],
]
],
dtype=torch.int32,
),
),
(
"basic_2x2_with_zero_point",
torch.tensor([[[[1, 2], [3, 4]]]], dtype=torch.int32),
(2, 2),
(1, 1),
(0, 0),
(1, 1),
(0, 0),
torch.tensor(100, dtype=torch.int32),
False,
torch.tensor(
[
[
[1, 100, 100, 100],
[1, 2, 100, 100],
[100, 2, 100, 100],
[1, 100, 3, 100],
[1, 2, 3, 4],
[100, 2, 100, 4],
[100, 100, 3, 100],
[100, 100, 3, 4],
[100, 100, 100, 4],
]
],
dtype=torch.int32,
),
),
(
"basic_2x2_with_stride_2",
torch.tensor([[[[1, 2], [3, 4]]]], dtype=torch.int32),
(2, 2), # kernel size
(1, 1), # dilation
(0, 0), # padding
(2, 2), # stride
(0, 0), # output padding
None,
False,
torch.tensor(
[
[
[1, 0, 0, 0],
[1, 0, 0, 0],
[0, 2, 0, 0],
[0, 2, 0, 0],
[1, 0, 0, 0],
[1, 0, 0, 0],
[0, 2, 0, 0],
[0, 2, 0, 0],
[0, 0, 3, 0],
[0, 0, 3, 0],
[0, 0, 0, 4],
[0, 0, 0, 4],
[0, 0, 3, 0],
[0, 0, 3, 0],
[0, 0, 0, 4],
[0, 0, 0, 4],
]
],
dtype=torch.int32,
),
),
(
"batch2_with_batch2_zero_point",
torch.tensor(
[
[[[1, 2], [3, 4]]],
[[[5, 6], [7, 8]]],
],
dtype=torch.int32,
), # input: (2,1,2,2)
(2, 2), # kernel_size
(1, 1), # dilation
(0, 0), # padding
(1, 1), # stride
(0, 0), # output_padding
torch.tensor([100, 200], dtype=torch.int32), # in_zero_point per batch
False, # channel_last
torch.tensor(
[
[
[1, 100, 100, 100],
[1, 2, 100, 100],
[100, 2, 100, 100],
[1, 100, 3, 100],
[1, 2, 3, 4],
[100, 2, 100, 4],
[100, 100, 3, 100],
[100, 100, 3, 4],
[100, 100, 100, 4],
],
[
[5, 200, 200, 200],
[5, 6, 200, 200],
[200, 6, 200, 200],
[5, 200, 7, 200],
[5, 6, 7, 8],
[200, 6, 200, 8],
[200, 200, 7, 200],
[200, 200, 7, 8],
[200, 200, 200, 8],
],
],
dtype=torch.int32,
),
),
]
)
def test_transposed_im2row(
self,
name: str,
input_tensor: torch.Tensor,
kernel_size: tuple[int, int],
dilation: tuple[int, int],
padding: tuple[int, int],
stride: tuple[int, int],
output_padding: tuple[int, int],
in_zero_point: torch.Tensor | int | None,
channel_last: bool,
expected_output: torch.Tensor,
) -> None:
output = torch.ops.cadence.transposed_im2row(
input_tensor,
kernel_size,
dilation,
padding,
stride,
output_padding,
in_zero_point,
channel_last,
)

self.assertEqual(
output.shape,
expected_output.shape,
f"transposed_im2row output shape mismatch in {name}: got {output.shape}, expected {expected_output.shape}",
)
self.assertTrue(
torch.equal(output, expected_output),
f"transposed_im2row output mismatch in {name}: got {output}, expected {expected_output}",
)