In [45]:
import jax
import jax.numpy as jnp
from jax import lax, vmap

@jax.jit
def apply_shifted_interpolation_patch_alternative(patch_data, dx, dy):
    """Bilinearly resample a patch at a sub-pixel offset via two 1-D convolutions.

    Each output pixel is the bilinear interpolation of ``patch_data`` at
    position ``(i + dy, j + dx)``::

        out[i, j] = (1 - dy) * ((1 - dx) * p[i, j]     + dx * p[i, j + 1])
                  +      dy  * ((1 - dx) * p[i + 1, j] + dx * p[i + 1, j + 1])

    The 2x2 bilinear kernel is separable, so it is applied as a width-2
    horizontal convolution followed by a height-2 vertical convolution,
    both with 'VALID' padding.

    Args:
        patch_data: Array of shape (H, W, C). Non-floating dtypes are
            promoted to float32 before interpolating.
        dx: Scalar horizontal offset. Values in [0, 1] interpolate between
            neighbouring columns; values outside that range extrapolate
            linearly.
        dy: Scalar vertical offset, same convention as ``dx`` for rows.

    Returns:
        Array of shape (H - 1, W - 1, C).

    Raises:
        ValueError: If H < 2 or W < 2 (no 2x2 neighbourhood exists).
    """
    H, W, C = patch_data.shape
    if H < 2 or W < 2:
        # Shapes are static under jit, so this check happens at trace time.
        raise ValueError(
            "patch_data needs at least 2 rows and 2 columns, "
            f"got shape {patch_data.shape}"
        )

    # lax convolutions require lhs and rhs to share a dtype; the weights are
    # fractional, so integer patches are promoted to float32.
    if jnp.issubdtype(patch_data.dtype, jnp.floating):
        dtype = patch_data.dtype
    else:
        dtype = jnp.float32
    patch_data = patch_data.astype(dtype)

    # lax.conv_general_dilated computes a cross-correlation: with weights
    # (w0, w1) the output is y[k] = w0 * x[k] + w1 * x[k + 1]. Weights
    # (1 - t, t) therefore linearly interpolate the signal at k + t.
    kernel_x_coeffs = jnp.stack([1 - dx, dx]).astype(dtype)
    kernel_y_coeffs = jnp.stack([1 - dy, dy]).astype(dtype)

    # Both passes are 1-D convolutions over a (batch, length, 1) view:
    # lhs/out are (N, spatial W, C), kernel is (spatial W, In, Out).
    dimension_numbers_1d = ('NWC', 'WIO', 'NWC')

    # --- Horizontal pass: interpolate along W ---
    # (H, W, C) -> (H, C, W) -> (H*C, W, 1): each row of each channel is an
    # independent 1-D signal.
    lhs_x = patch_data.transpose(0, 2, 1).reshape(H * C, W, 1)
    rhs_x = kernel_x_coeffs[:, None, None]  # (2, 1, 1)
    out_x = lax.conv_general_dilated(
        lhs=lhs_x,
        rhs=rhs_x,
        window_strides=(1,),
        padding='VALID',
        dimension_numbers=dimension_numbers_1d,
    )  # (H*C, W-1, 1)
    # (H*C, W-1, 1) -> (H, C, W-1) -> (H, W-1, C)
    intermediate_x = out_x.reshape(H, C, W - 1).transpose(0, 2, 1)

    # --- Vertical pass: interpolate along H ---
    # (H, W-1, C) -> (W-1, C, H) -> ((W-1)*C, H, 1): each column of each
    # channel is an independent 1-D signal.
    lhs_y = intermediate_x.transpose(1, 2, 0).reshape((W - 1) * C, H, 1)
    rhs_y = kernel_y_coeffs[:, None, None]  # (2, 1, 1)
    out_y = lax.conv_general_dilated(
        lhs=lhs_y,
        rhs=rhs_y,
        window_strides=(1,),
        padding='VALID',
        dimension_numbers=dimension_numbers_1d,
    )  # ((W-1)*C, H-1, 1)

    # ((W-1)*C, H-1, 1) -> (W-1, C, H-1) -> (H-1, W-1, C)
    return out_y.reshape(W - 1, C, H - 1).transpose(2, 0, 1)

# --- Example usage: a batch of two 4x4, 3-channel patches ---
num_patches = 2
patch_size = 4
in_channels = 3

# Consecutive float32 values laid out as
# (num_patches, patch_size, patch_size, in_channels).
patches_data_batch = jnp.arange(
    num_patches * patch_size**2 * in_channels, dtype=jnp.float32
).reshape((num_patches, patch_size, patch_size, in_channels))

# One (dx, dy) sub-pixel offset per patch.
dx_batch = jnp.asarray([0.2, 0.7])
dy_batch = jnp.asarray([0.3, 0.8])

# Map over the leading (patch) axis of all three arguments.
interpolated_patches_batch_alt = vmap(
    apply_shifted_interpolation_patch_alternative, in_axes=0
)(patches_data_batch, dx_batch, dy_batch)

print("Input patch shape (single):", patches_data_batch[0].shape)
print(
    "Output interpolated patch shape (single, alternative):",
    interpolated_patches_batch_alt[0].shape,
)
print("Batch output shape (alternative):", interpolated_patches_batch_alt.shape)

# Cross-check: if the earlier implementation's result is available
# (presumably as `interpolated_patches_batch` from a previous cell), then
#     jnp.allclose(interpolated_patches_batch, interpolated_patches_batch_alt)
# is expected to be True.

Input patch shape (single): (4, 4, 3)
Output interpolated patch shape (single, alternative): (3, 3, 3)
Batch output shape (alternative): (2, 3, 3, 3)
