In [None]:
import torch
import math

In [None]:
import torch
import math

class FourierFeatureEmbedding2D(torch.nn.Module):
    """Fixed (parameter-free) 2D Fourier-feature positional embedding for image patches.

    Every patch of the ``num_patches_y x num_patches_x`` grid gets a vector of
    sines and cosines of its normalized (x, y) grid coordinates, evaluated at
    ``num_frequencies`` logarithmically spaced frequencies.
    """

    def __init__(self, img_size, patch_size, num_frequencies=10, embedding_dim=768, scale=1.0):
        """
        Args:
            img_size (tuple): The size of the input image (height, width).
            patch_size (tuple): The size of each patch (height, width).
            num_frequencies (int): Number of frequency bands to use for sine/cosine.
            embedding_dim (int): The dimension of the patch embeddings (usually 768 or 1024 for ViT).
                Stored for callers; ``forward`` itself always returns
                ``4 * num_frequencies`` features per patch.
            scale (float): Scaling factor multiplied into the normalized patch
                coordinates before the sine/cosine are taken.
        """
        super().__init__()

        self.img_size = img_size
        self.patch_size = patch_size
        self.num_frequencies = num_frequencies
        self.embedding_dim = embedding_dim
        self.scale = scale

        # Calculate number of patches along each axis (height and width)
        self.num_patches_y = img_size[0] // patch_size[0]
        self.num_patches_x = img_size[1] // patch_size[1]

        # Frequencies: logarithmically spaced from 1 (= 10**0) up to 2*pi.
        frequencies = torch.logspace(0, math.log10(2 * math.pi), num_frequencies, base=10)
        # Registered as a buffer so it follows .to()/.cuda() with the module;
        # non-persistent so the module's state_dict keys stay unchanged.
        self.register_buffer("frequencies", frequencies, persistent=False)

    def forward(self):
        """
        Generate the 2D Fourier feature positional embeddings for each patch.

        Patches are enumerated row by row (y-major). For each frequency the
        features are laid out as ``[sin(f*x), cos(f*x), sin(f*y), cos(f*y)]``.

        Returns:
            tensor: Positional embeddings for each patch
                (shape: [num_patches, 4 * num_frequencies])
        """
        device = self.frequencies.device
        dtype = self.frequencies.dtype

        # Generate patch grid coordinates on the same device as the frequencies
        y_coords = torch.arange(self.num_patches_y, device=device, dtype=dtype)
        x_coords = torch.arange(self.num_patches_x, device=device, dtype=dtype)

        # Explicit "ij" indexing: grid_y varies along rows, grid_x along columns
        grid_y, grid_x = torch.meshgrid(y_coords, x_coords, indexing="ij")

        # Normalize to range [0, 1]; max(..., 1) keeps a single-row/column grid
        # at coordinate 0 instead of producing 0/0 = NaN.
        grid_y = grid_y / max(self.num_patches_y - 1, 1)
        grid_x = grid_x / max(self.num_patches_x - 1, 1)

        # Stack coordinates as (x, y) pairs, shape [num_patches, 2], and apply the scale
        positions = torch.stack([grid_x.flatten(), grid_y.flatten()], dim=-1) * self.scale

        # Broadcast positions against all frequencies at once: [num_patches, num_frequencies, 2]
        angles = positions[:, None, :] * self.frequencies[None, :, None]

        # Trailing axis holds (sin, cos): [num_patches, num_frequencies, 2, 2]
        features = torch.stack([torch.sin(angles), torch.cos(angles)], dim=-1)

        # Flatten per patch -> [num_patches, 4 * num_frequencies]
        return features.flatten(start_dim=1)


If the positional-embedding (PE) dimension does not match the patch embedding dimension, project it with a linear layer:

<code>
# If the resulting embedding dimension does not match the patch embedding size,
# apply a learned linear projection to bring it to `self.embedding_dim`.
if embeddings.size(1) != self.embedding_dim:
    # Create the projection only once: building a new Linear on every call
    # would re-randomize its weights each time, so it could never be trained.
    # Prefer constructing this layer in __init__ when possible, so that its
    # parameters are registered before the optimizer is created.
    if getattr(self, "linear", None) is None:
        self.linear = torch.nn.Linear(embeddings.size(1), self.embedding_dim).to(
            device=embeddings.device, dtype=embeddings.dtype
        )
    embeddings = self.linear(embeddings)
</code>