507 changes: 264 additions & 243 deletions backends/cadence/aot/fuse_ops.py

Large diffs are not rendered by default.

15 changes: 12 additions & 3 deletions backends/cadence/aot/ops_registrations.py
@@ -2439,6 +2439,11 @@ def transposed_im2row_meta(
in_zero_point: torch.Tensor,
channel_last: bool = False,
) -> torch.Tensor:
"""
Shape inference for transposed_im2row operation.

Returns shape: (N, H_out * W_out, K_h * K_w * C_in)
"""
if len(input.shape) == 3:
height_dim = 1 if channel_last else 2
input = input.unsqueeze(height_dim)
@@ -2447,6 +2452,8 @@
n_input_plane = input.shape[3] if channel_last else input.shape[1]
input_height = input.shape[1] if channel_last else input.shape[2]
input_width = input.shape[2] if channel_last else input.shape[3]

# Calculate output spatial dimensions
output_height = (
(input_height - 1) * stride[0]
- 2 * padding[0]
@@ -2461,9 +2468,11 @@
+ output_padding[1]
+ 1
)
n_output_plane = n_input_plane * kernel_size[0] * kernel_size[1]
output_length = output_height * output_width
output_size = torch.Size((batch_size, output_length, n_output_plane))

# Patch size is kernel_h * kernel_w * in_channels
patch_size = kernel_size[0] * kernel_size[1] * n_input_plane
num_patches = output_height * output_width
output_size = torch.Size((batch_size, num_patches, patch_size))

return input.new_empty(output_size, dtype=input.dtype)
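
Sanity check (my own sketch, not part of this diff): the revised H_out/W_out arithmetic above matches what PyTorch's conv_transpose2d produces. All sizes below are made-up example values.

import torch

N, C_in, H_in, W_in = 1, 3, 5, 5
kernel_size, stride = (3, 3), (2, 2)
padding, dilation, output_padding = (1, 1), (1, 1), (1, 1)

# Same formula as transposed_im2row_meta above
H_out = ((H_in - 1) * stride[0] - 2 * padding[0]
         + dilation[0] * (kernel_size[0] - 1) + output_padding[0] + 1)
W_out = ((W_in - 1) * stride[1] - 2 * padding[1]
         + dilation[1] * (kernel_size[1] - 1) + output_padding[1] + 1)

x = torch.randn(N, C_in, H_in, W_in)
w = torch.randn(C_in, 4, *kernel_size)  # PyTorch layout: [in, out, kH, kW]
y = torch.nn.functional.conv_transpose2d(
    x, w, stride=stride, padding=padding,
    output_padding=output_padding, dilation=dilation,
)
assert y.shape[-2:] == (H_out, W_out)  # both 10 for these values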

223 changes: 129 additions & 94 deletions backends/cadence/aot/ref_implementations.py
@@ -1411,7 +1411,22 @@ def transposed_convolution(
channel_last: bool = False,
) -> torch.Tensor:

# Cadence transposed conv receives weights that have been transformed by the pass:
# 1. Transposed (dims 0 and 1 swapped): [out_channels, in_channels, *kernel]
# 2. Flipped (spatial dimensions reversed)
# We need to reverse both transformations to call PyTorch's conv_transpose

conv_is_1d = len(input_tensor.shape) == 3

# Determine flip dimensions based on weight dimensionality
weight_dim = len(weight.shape)
flip_dims = [-1] if weight_dim == 3 else [-1, -2]

# Reverse transformation step 1: Unflip the spatial dimensions
weight = torch.flip(weight, dims=flip_dims)

# Reverse transformation step 2: Transpose back to PyTorch format [in, out, *kernel]
weight = weight.transpose(0, 1).contiguous()
if channel_last:
if conv_is_1d:
input_tensor = input_tensor.movedim(-1, 1).contiguous()
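
For context, a minimal sketch of the two-step weight reversal described in the comments above (shapes are illustrative assumptions, not taken from the kernel):

import torch

in_ch, out_ch, K_h, K_w = 3, 8, 3, 3
# Cadence layout after the AoT pass: [out, in, kH, kW], spatially flipped
cadence_w = torch.randn(out_ch, in_ch, K_h, K_w)

w = torch.flip(cadence_w, dims=[-1, -2])  # step 1: unflip spatial dims
w = w.transpose(0, 1).contiguous()        # step 2: back to [in, out, kH, kW]

x = torch.randn(1, in_ch, 5, 5)
y = torch.nn.functional.conv_transpose2d(x, w, stride=2, padding=1)
print(y.shape)  # torch.Size([1, 8, 9, 9])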
@@ -1856,12 +1871,13 @@ def transposed_im2row(
channel_last: bool = False,
) -> torch.Tensor:
"""
Converts input tensor patches into im2row format for transposed convolutions.
This function extracts patches from input in a pattern suitable for transposed convolution.
Converts input tensor into im2row format for transposed convolutions.
For each output position, extracts the kernel-sized patch of input values that
contribute to that position in a transposed convolution.

Args:
- input_tensor: Input spatial tensor, NCHW or NHWC format (3D or 4D).
- kernel_size: Size of the convolution kernel.
- kernel_size: Size of the convolution kernel (kernel_h, kernel_w).
- dilation: Dilation of the convolution kernel.
- padding: Padding to apply to the input.
- stride: Stride of the convolution.
@@ -1886,117 +1902,136 @@
input_tensor = input_tensor.movedim(-1, -3).contiguous() # NHWC -> NCHW

N, C, H_in, W_in = input_tensor.shape

# Output: (N, C*H_in*W_in, H_out, W_out)
H_out = (
(H_in - 1) * stride[0]
+ kernel_size[0]
+ output_padding[0]
- 2 * padding[0]
+ dilation[0] * (kernel_size[0] - 1)
)
W_out = (
(W_in - 1) * stride[1]
+ kernel_size[1]
+ output_padding[1]
- 2 * padding[1]
+ dilation[1] * (kernel_size[1] - 1)
)

# For each input pixel, create a channel where the upsampled (transposed conv) patch is placed
# Output: (N, C*H_in*W_in, H_out, W_out)
inp_flat = input_tensor.reshape(N, C * H_in * W_in)
K_h, K_w = kernel_size
device = input_tensor.device

# Calculate output spatial size
H_out = (
(H_in - 1) * stride[0]
- 2 * padding[0]
+ dilation[0] * (kernel_size[0] - 1)
+ dilation[0] * (K_h - 1)
+ output_padding[0]
+ 1
)
W_out = (
(W_in - 1) * stride[1]
- 2 * padding[1]
+ dilation[1] * (kernel_size[1] - 1)
+ dilation[1] * (K_w - 1)
+ output_padding[1]
+ 1
)

# Compute the upsampled (top-left) position for each input pixel
h_idx = torch.arange(H_in, device=input_tensor.device)
w_idx = torch.arange(W_in, device=input_tensor.device)
grid_h, grid_w = torch.meshgrid(h_idx, w_idx, indexing="ij")
out_h_idx = grid_h * stride[0] - padding[0]
out_w_idx = grid_w * stride[1] - padding[1]

# Compute all input pixel positions (flattened)
ch_idx = torch.arange(C * H_in * W_in, device=input_tensor.device)
ij_idx = ch_idx % (H_in * W_in)
i_idx = ij_idx // W_in
j_idx = ij_idx % W_in

# For each input pixel, compute the output positions for the kernel window
kh_idx = torch.arange(kernel_size[0], device=input_tensor.device)
kw_idx = torch.arange(kernel_size[1], device=input_tensor.device)
kh_grid, kw_grid = torch.meshgrid(kh_idx, kw_idx, indexing="ij")
kh_grid = kh_grid.reshape(-1)
kw_grid = kw_grid.reshape(-1)
num_kernel = kernel_size[0] * kernel_size[1]

# Broadcast to all channels and kernel positions
ch_idx_b = ch_idx.repeat_interleave(num_kernel)
n_kernel = ch_idx.shape[0] * num_kernel

i_idx_b = i_idx.repeat_interleave(num_kernel)
j_idx_b = j_idx.repeat_interleave(num_kernel)
kh_b = kh_grid.repeat(ch_idx.shape[0])
kw_b = kw_grid.repeat(ch_idx.shape[0])

h_out = out_h_idx[i_idx_b, j_idx_b] + kh_b * dilation[0]
w_out = out_w_idx[i_idx_b, j_idx_b] + kw_b * dilation[1]

# Mask for valid output positions
valid = (h_out >= 0) & (h_out < H_out) & (w_out >= 0) & (w_out < W_out)

# Prepare indices for advanced indexing
n_idx = (
torch.arange(N, device=input_tensor.device)
.view(-1, 1)
.expand(N, n_kernel)
.reshape(-1)
)
ch_idx_full = ch_idx_b.expand(N, n_kernel).reshape(-1)
h_out_full = h_out.expand(N, n_kernel).reshape(-1)
w_out_full = w_out.expand(N, n_kernel).reshape(-1)
valid_full = valid.expand(N, n_kernel).reshape(-1)

# Gather input values for each channel
inp_vals = inp_flat[:, ch_idx_b].reshape(-1)

# Create output tensor
patches = torch.zeros((N, C * H_in * W_in, H_out, W_out), dtype=input_tensor.dtype)
# Create meshgrids for all output positions and kernel positions
h_out_grid = torch.arange(H_out, device=device).view(
-1, 1, 1, 1
) # [H_out, 1, 1, 1]
w_out_grid = torch.arange(W_out, device=device).view(
1, -1, 1, 1
) # [1, W_out, 1, 1]
kh_grid = torch.arange(K_h, device=device).view(1, 1, -1, 1) # [1, 1, K_h, 1]
kw_grid = torch.arange(K_w, device=device).view(1, 1, 1, -1) # [1, 1, 1, K_w]

# Compute input positions for all (h_out, w_out, kh, kw) combinations
# From C++ reference: h_im = _h - ((kernel_h - 1) * dilation_h) + _kh * dilation_h + pad_h
h_im = h_out_grid - (K_h - 1) * dilation[0] + kh_grid * dilation[0] + padding[0]
w_im = w_out_grid - (K_w - 1) * dilation[1] + kw_grid * dilation[1] + padding[1]

# Check which positions are valid (divisible by stride and within bounds)
# From C++ reference: if (h_im < 0 || h_im >= stride_h * height || h_im % stride_h != 0)
h_valid = (h_im % stride[0] == 0) & (h_im >= 0) & (h_im < stride[0] * H_in)
w_valid = (w_im % stride[1] == 0) & (w_im >= 0) & (w_im < stride[1] * W_in)
valid = h_valid & w_valid # [H_out, W_out, K_h, K_w]

# Actual input indices (h_im / stride_h from C++ reference)
h_in = h_im // stride[0]
w_in = w_im // stride[1]

# Clamp indices to valid range (will be masked out anyway)
h_in_safe = h_in.clamp(0, H_in - 1)
w_in_safe = w_in.clamp(0, W_in - 1)

# Initialize output patches with zero points (vectorized across batches)
# Layout depends on channel_last: NHWC uses [K_h, K_w, C], NCHW uses [C, K_h, K_w]
if channel_last:
# NHWC: patches layout [N, H_out, W_out, K_h, K_w, C]
patches = torch.zeros(
(N, H_out, W_out, K_h, K_w, C),
dtype=input_tensor.dtype,
device=device,
)
else:
# NCHW: patches layout [N, H_out, W_out, C, K_h, K_w]
patches = torch.zeros(
(N, H_out, W_out, C, K_h, K_w),
dtype=input_tensor.dtype,
device=device,
)

# If in_zero_point is provided, fill patches with it
# Initialize patches with zero points (vectorized)
if in_zero_point is not None:
if in_zero_point.numel() == 1:
# Scalar zero point - fill all patches
patches.fill_(in_zero_point.item())
else:
# Broadcast in_zero_point to (N, C, H_in, W_in)
assert in_zero_point.shape == (N,)
in_zero_point = in_zero_point.view(N, 1, 1, 1)
patches = patches + in_zero_point

# Scatter input values to output positions (only valid positions)
patches[
n_idx[valid_full],
ch_idx_full[valid_full],
h_out_full[valid_full],
w_out_full[valid_full],
] = inp_vals[valid_full]

# Optionally, flatten to (N, num_patches, patch_size) if needed
patches = patches.view(N, C * H_in * W_in, -1).transpose(1, 2).contiguous()
# Per-batch zero points - expand and fill
# in_zero_point: [N] -> [N, 1, 1, 1, 1, 1] (patches is 6-D in either layout)
zp_shape = [N] + [1] * (patches.ndim - 1)
patches = patches + in_zero_point.view(*zp_shape)

# Flatten the spatial and kernel dimensions for efficient gathering
# h_in_safe, w_in_safe: [H_out, W_out, K_h, K_w] (broadcast shape)
h_flat = h_in_safe.expand(H_out, W_out, K_h, K_w).contiguous().view(-1)
w_flat = w_in_safe.expand(H_out, W_out, K_h, K_w).contiguous().view(-1)

# Vectorized gathering across all batches and channels using advanced indexing
# Create index tensors with appropriate broadcasting shapes
num_positions = h_flat.shape[0]

# batch_indices: [N, 1, 1] -> broadcasts to [N, C, num_positions]
batch_indices = torch.arange(N, device=device).view(N, 1, 1)

# channel_indices: [1, C, 1] -> broadcasts to [N, C, num_positions]
channel_indices = torch.arange(C, device=device).view(1, C, 1)

# h_flat, w_flat: [1, 1, num_positions] -> broadcasts to [N, C, num_positions]
h_indices = h_flat.view(1, 1, num_positions)
w_indices = w_flat.view(1, 1, num_positions)

# Advanced indexing gathers all values at once: [N, C, num_positions]
gathered = input_tensor[batch_indices, channel_indices, h_indices, w_indices]

# Reshape based on channel_last flag
if channel_last:
# NHWC: Reshape to [N, H_out, W_out, K_h, K_w, C]
# gathered: [N, C, H_out*W_out*K_h*K_w] -> [N, H_out*W_out*K_h*K_w, C] -> [N, H_out, W_out, K_h, K_w, C]
gathered = gathered.transpose(1, 2).contiguous() # [N, num_positions, C]
gathered = gathered.view(N, H_out, W_out, K_h, K_w, C)
else:
# NCHW: Reshape to [N, H_out, W_out, C, K_h, K_w]
# gathered: [N, C, H_out*W_out*K_h*K_w] -> [N, C, H_out, W_out, K_h, K_w] -> [N, H_out, W_out, C, K_h, K_w]
gathered = gathered.view(N, C, H_out, W_out, K_h, K_w)
gathered = gathered.permute(0, 2, 3, 1, 4, 5).contiguous()

# Apply validity mask (vectorized across batches)
# valid: [H_out, W_out, K_h, K_w] -> expand to match gathered shape
if channel_last:
# gathered: [N, H_out, W_out, K_h, K_w, C]
valid_exp = valid.unsqueeze(0).unsqueeze(-1) # [1, H_out, W_out, K_h, K_w, 1]
else:
# gathered: [N, H_out, W_out, C, K_h, K_w]
valid_exp = valid.unsqueeze(0).unsqueeze(3) # [1, H_out, W_out, 1, K_h, K_w]

patches = torch.where(valid_exp, gathered, patches)

# Reshape to final output format: [N, H_out * W_out, K_h * K_w * C]
# The reshaping will preserve the correct dimension ordering
if channel_last:
# patches: [N, H_out, W_out, K_h, K_w, C] -> [N, H_out*W_out, K_h*K_w*C]
patches = patches.view(N, H_out * W_out, K_h * K_w * C)
else:
# patches: [N, H_out, W_out, C, K_h, K_w] -> [N, H_out*W_out, C*K_h*K_w]
patches = patches.view(N, H_out * W_out, C * K_h * K_w)

return patches
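
How the patches are consumed is my assumption, not code from this PR: with output shape [N, H_out*W_out, K_h*K_w*C_in], a transposed convolution reduces to one matmul per batch against weights flattened to [K_h*K_w*C_in, C_out]; the weight flattening order must match the channel_last patch layout.

import torch

N, H_out, W_out = 2, 10, 10
K_h, K_w, C_in, C_out = 3, 3, 4, 8
patches = torch.randn(N, H_out * W_out, K_h * K_w * C_in)  # channel_last layout
weight_flat = torch.randn(K_h * K_w * C_in, C_out)         # [kH, kW, in] flattened

out = patches @ weight_flat             # [N, H_out*W_out, C_out]
out = out.view(N, H_out, W_out, C_out)  # NHWC spatial output
print(out.shape)  # torch.Size([2, 10, 10, 8])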

