In [None]:
import torch
import torch.nn as nn

In [None]:
class Learned2DPositionalEmbedding(nn.Module):
    """Additive, factorized 2D learned positional embedding for a square patch grid.

    Each of the ``grid_size * grid_size`` patch positions receives the sum of one
    learned vector from ``row_embeddings`` and one from ``col_embeddings``; that
    sum is added to the patch embeddings in :meth:`forward`.

    Flat patch index ``i`` is treated as grid position
    ``(i // grid_size, i % grid_size)``, which corresponds to a row-major
    flattening of the grid (presumably how the caller flattens patches --
    confirm against the patch-embedding code).

    NOTE: despite their names, ``row_embeddings`` is indexed by ``i % grid_size``
    (the patch's column) and ``col_embeddings`` by ``i // grid_size`` (the
    patch's row). This mapping is kept unchanged so existing checkpoints keep
    producing the same outputs; since both tables are learned, swapping the
    labels would not change what the module can represent.

    Args:
        num_patches: Number of patches per image. Expected to be a perfect
            square; the grid side is ``int(num_patches ** 0.5)``.
        patch_size: Unused; kept for interface compatibility.
        embedding_dim: Size of each patch embedding vector.
    """

    def __init__(self, num_patches, patch_size, embedding_dim):
        super().__init__()

        # Side length of the (assumed square) patch grid.
        self.grid_size = int(num_patches ** 0.5)

        # One learned vector per grid index along each axis, initialised from a
        # standard normal. Creation order (row, then col) is kept so that seeded
        # initialisation matches the original implementation.
        self.row_embeddings = nn.Parameter(torch.randn(self.grid_size, embedding_dim))
        self.col_embeddings = nn.Parameter(torch.randn(self.grid_size, embedding_dim))

    def forward(self, x):
        """Add the 2D positional embeddings to ``x``.

        Args:
            x: Patch embeddings of shape ``(batch_size, num_patches, embedding_dim)``
                with ``num_patches == grid_size ** 2`` and ``embedding_dim`` equal
                to the value given at construction.

        Returns:
            ``x + positional_embeddings`` with the same shape as ``x`` (dtype
            follows PyTorch type promotion between ``x`` and the parameters).

        Raises:
            ValueError: if ``x`` is not 3-dimensional, or its patch count or
                embedding size does not match this module.
        """
        if x.dim() != 3:
            raise ValueError(
                f"expected x of shape (batch, num_patches, embedding_dim), "
                f"got {tuple(x.shape)}"
            )

        _, num_patches, embedding_dim = x.shape
        expected_patches = self.grid_size * self.grid_size
        expected_dim = self.row_embeddings.shape[1]

        # Without this check a mismatched input can broadcast silently
        # (e.g. num_patches == 1, or grid_size == 1) into a wrong result.
        if num_patches != expected_patches or embedding_dim != expected_dim:
            raise ValueError(
                f"expected x of shape (batch, {expected_patches}, {expected_dim}), "
                f"got {tuple(x.shape)}"
            )

        # Build the (grid_size, grid_size, embedding_dim) table by broadcasting
        # rather than gathering with freshly created CPU index tensors:
        # grid[r, c] = col_embeddings[r] + row_embeddings[c] (see the class NOTE
        # on naming). The result lives on the parameters' device.
        grid = self.col_embeddings.unsqueeze(1) + self.row_embeddings.unsqueeze(0)

        # Row-major flatten: flat index i -> (i // grid_size, i % grid_size).
        positional_embeddings = grid.reshape(expected_patches, expected_dim)

        # Broadcast the same positional table across the batch dimension.
        return x + positional_embeddings.unsqueeze(0)
