In [477]:
import sys

# Extend the import search path so the `src.` package imported in the next cell can be found.
# NOTE(review): the relative path assumes the kernel's working directory sits two levels
# below the repository root — confirm before running from another location.
sys.path.append("../..")

In [478]:
from src.vendored.croco.models.croco import CroCoNet
from PIL import Image



In [479]:
# !mkdir -p pretrained_models/
# !wget https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo.pth -P pretrained_models/

In [480]:
from pathlib import Path

In [None]:
import torch
import torchvision.transforms
from torchvision.transforms import ToTensor, Normalize, Compose

# Run on the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() and torch.cuda.device_count() > 0 else "cpu")

# load 224x224 images and transform them to tensor
# The ImageNet channel statistics are used twice: by `Normalize` when loading the images,
# and as (1, 3, 1, 1) tensors to undo that normalization after the forward pass.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_mean_tensor = torch.tensor(imagenet_mean).view(1, 3, 1, 1).to(device, non_blocking=True)
imagenet_std = [0.229, 0.224, 0.225]
imagenet_std_tensor = torch.tensor(imagenet_std).view(1, 3, 1, 1).to(device, non_blocking=True)
trfs = Compose([ToTensor(), Normalize(mean=imagenet_mean, std=imagenet_std)])
# Each image is forced to RGB, normalized, moved to `device` and given a leading batch axis.
image1 = trfs(Image.open("../assets/Chateau1.png").convert("RGB")).to(device, non_blocking=True).unsqueeze(0)
image2 = trfs(Image.open("../assets/Chateau2.png").convert("RGB")).to(device, non_blocking=True).unsqueeze(0)

# load model
# The checkpoint is read onto the CPU from a path relative to the current working directory.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
ckpt = torch.load(Path.cwd() / "pretrained_models/CroCo.pth", "cpu")
# Constructor arguments come from the checkpoint when present, otherwise the model defaults are used.
model = CroCoNet(**ckpt.get("croco_kwargs", {})).to(device)
model.eval()
# strict=True makes load_state_dict raise on any missing or unexpected key; `msg` is not used afterwards.
msg = model.load_state_dict(ckpt["model"], strict=True)

# forward
# `target` is returned by the model but not used in the rest of this cell.
with torch.inference_mode():
    out, mask, target = model(image1, image2)

# the output is normalized, thus use the mean/std of the actual image to go back to RGB space
# Statistics are taken over the last axis of the patchified first image (one mean/variance per patch);
# the small epsilon keeps the square root away from zero for constant patches.
patchified = model.patchify(image1)
mean = patchified.mean(dim=-1, keepdim=True)
var = patchified.var(dim=-1, keepdim=True)
decoded_image = model.unpatchify(out * (var + 1.0e-6) ** 0.5 + mean)
# undo imagenet normalization, prepare masked image
decoded_image = decoded_image * imagenet_std_tensor + imagenet_mean_tensor
input_image = image1 * imagenet_std_tensor + imagenet_mean_tensor
ref_image = image2 * imagenet_std_tensor + imagenet_mean_tensor
# Expand the per-patch mask to pixel resolution by patchifying an all-ones image,
# scaling every patch by its mask value and unpatchifying again.
# NOTE(review): presumably mask == 1 marks a hidden patch, so (1 - image_masks) keeps the visible ones — confirm.
image_masks = model.unpatchify(model.patchify(torch.ones_like(ref_image)) * mask[:, :, None])
masked_input_image = (1 - image_masks) * input_image

# make visualization
# Panels are placed side by side along the width axis: reference, masked input, reconstruction, input.
visualization = torch.cat(
    (ref_image, masked_input_image, decoded_image, input_image), dim=3
)  # 4*(B, 3, H, W) -> B, 3, H, W*4
B, C, H, W = visualization.shape
# Stack the batch entries vertically so the whole batch becomes one (C, B * H, W) image.
visualization = visualization.permute(1, 0, 2, 3).reshape(C, B * H, W)
# Values are clamped to [0, 1] before conversion to a PIL image.
visualization = torchvision.transforms.functional.to_pil_image(torch.clamp(visualization, 0, 1))
fname = "demo_output.png"
visualization.save(fname)
print("Visualization save in " + fname)

### Patchify/Unpatchify

#### Implementation

In [None]:
import torch
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms as transforms
import numpy as np


def patchify(imgs, patch_size):
    """
    Split a batch of images into a sequence of flattened, non-overlapping square patches.

    Patches are emitted in row-major order over the patch grid. Inside each
    flattened patch the values are ordered as (row, column, channel).

    Parameters:
    - imgs (torch.Tensor): Batch of images with shape (B, C, H, W).
    - patch_size (int): Side length p of every square patch; must divide H and W.

    Returns:
    - patches (torch.Tensor): Tensor of shape (B, L, p^2 * C) with L = (H // p) * (W // p).
    - num_patches_h (int): Patch count along the height, H // p.
    - num_patches_w (int): Patch count along the width, W // p.

    Raises:
    - AssertionError: If H or W is not a multiple of patch_size.
    """
    batch, chans, height, width = imgs.shape
    assert not (
        height % patch_size or width % patch_size
    ), "Image dimensions must be divisible by the patch size."

    num_patches_h, num_patches_w = height // patch_size, width // patch_size

    # (B, C, H, W) -> (B, C, nh, p, nw, p) -> (B, nh, nw, p, p, C) -> (B, nh * nw, p * p * C)
    patches = (
        imgs.reshape(batch, chans, num_patches_h, patch_size, num_patches_w, patch_size)
        .permute(0, 2, 4, 3, 5, 1)
        .reshape(batch, num_patches_h * num_patches_w, patch_size * patch_size * chans)
    )
    return patches, num_patches_h, num_patches_w


def unpatchify(patches, patch_size, num_patches_h, num_patches_w, channels=3):
    """
    Reassemble images from a sequence of flattened square patches.

    Expects the layout produced by `patchify`: patches in row-major order over
    the patch grid, each flattened as (row, column, channel).

    Parameters:
    - patches (torch.Tensor): Tensor of shape (B, L, p^2 * C).
    - patch_size (int): Side length p of every square patch.
    - num_patches_h (int): Patch count along the height.
    - num_patches_w (int): Patch count along the width.
    - channels (int): Channel count C of the images (default: 3).

    Returns:
    - imgs (torch.Tensor): Images of shape (B, C, num_patches_h * p, num_patches_w * p).

    Raises:
    - AssertionError: If the last dimension is not a multiple of `channels`, does not
      equal p^2 * C, or if L differs from num_patches_h * num_patches_w.
    """
    batch, seq_len, patch_dim = patches.shape
    assert patch_dim % channels == 0, "Patch dimension is not compatible with the number of channels."

    side = int(patch_size)
    assert side * side == patch_dim // channels, "Patch size does not match patch dimension."

    grid_cells = num_patches_h * num_patches_w
    assert (
        seq_len == grid_cells
    ), f"Number of patches (L={seq_len}) does not match num_patches_h * num_patches_w ({grid_cells})."

    height, width = num_patches_h * patch_size, num_patches_w * patch_size

    # (B, L, p * p * C) -> (B, nh, nw, p, p, C) -> (B, C, nh, p, nw, p) -> (B, C, H, W)
    return (
        patches.reshape(batch, num_patches_h, num_patches_w, patch_size, patch_size, channels)
        .permute(0, 5, 1, 3, 2, 4)
        .reshape(batch, channels, height, width)
    )

#### Tests

In [None]:
def test_patchify_unpatchify_round_trip():
    """
    Check that patchify followed by unpatchify reproduces a random batch of square images.
    """
    print("Running Round-Trip Test...")
    patch_size = 32
    batch_shape = (2, 3, 128, 128)  # (B, C, H, W), square images
    imgs = torch.randn(*batch_shape)

    patches, grid_h, grid_w = patchify(imgs, patch_size)
    restored = unpatchify(patches, patch_size, grid_h, grid_w, channels=batch_shape[1])

    matches = torch.allclose(imgs, restored, atol=1e-6)
    assert matches, "Round-trip patchify -> unpatchify failed for square images."
    print("Round-Trip Test Passed for Square Images.")


def test_patchify_unpatchify_specific_case():
    """
    Check patchify and unpatchify against a hand-computed example.

    The input is a single-channel 4x6 image whose pixels are 0..23 in row-major
    order, cut into 2x2 patches (a 2x3 patch grid, visited row by row).
    """
    print("Running Specific Case Test...")
    patch_size = 2
    height, width = 4, 6  # Non-square image
    channels = 1
    imgs = torch.arange(height * width).reshape(1, channels, height, width).float()

    # One row per patch; each row is that 2x2 patch read left-to-right, top-to-bottom:
    #   top-left     [[0, 1],   [6, 7]]
    #   top-middle   [[2, 3],   [8, 9]]
    #   top-right    [[4, 5],   [10, 11]]
    #   bottom-left  [[12, 13], [18, 19]]
    #   bottom-mid   [[14, 15], [20, 21]]
    #   bottom-right [[16, 17], [22, 23]]
    patch_rows = [
        [0.0, 1.0, 6.0, 7.0],
        [2.0, 3.0, 8.0, 9.0],
        [4.0, 5.0, 10.0, 11.0],
        [12.0, 13.0, 18.0, 19.0],
        [14.0, 15.0, 20.0, 21.0],
        [16.0, 17.0, 22.0, 23.0],
    ]
    expected_patches = torch.tensor([patch_rows])  # Shape: (1, 6, 4)

    patches, grid_h, grid_w = patchify(imgs, patch_size)
    assert torch.allclose(patches, expected_patches, atol=1e-6), "Patchify specific case failed."
    print("Patchify Specific Case Test Passed.")

    # The inverse must give back the original image
    restored = unpatchify(patches, patch_size, grid_h, grid_w, channels=channels)
    assert torch.allclose(imgs, restored, atol=1e-6), "Unpatchify specific case failed."
    print("Unpatchify Specific Case Test Passed.")


def test_patchify_unpatchify_non_square():
    """
    Check the patchify -> unpatchify round trip on a random non-square image.
    """
    print("Running Non-Square Image Test...")
    patch_size = 16
    channels = 3
    imgs = torch.randn(1, channels, 64, 32)  # (B, C, H, W) with H != W

    patches, grid_h, grid_w = patchify(imgs, patch_size)
    restored = unpatchify(patches, patch_size, grid_h, grid_w, channels=channels)

    matches = torch.allclose(imgs, restored, atol=1e-6)
    assert matches, "Patchify -> Unpatchify failed for non-square images."
    print("Non-Square Image Test Passed.")


def test_invalid_inputs():
    """
    Check that malformed inputs are rejected with an AssertionError.

    Two cases are covered: an image whose sides are not multiples of the patch
    size, and a patch sequence that is one patch short of the declared grid.
    """
    print("Running Invalid Input Test...")
    patch_size = 16

    # Case 1: a 65x33 image cannot be tiled by 16x16 patches
    indivisible_imgs = torch.randn(1, 3, 65, 33)
    try:
        patchify(indivisible_imgs, patch_size)
    except AssertionError as e:
        print(f"Properly caught invalid input: {e}")
    else:
        assert False, "Failed to catch invalid input where H and W are not divisible by patch_size."

    # Case 2: drop the last patch so the sequence no longer matches the grid
    channels = 3
    valid_imgs = torch.randn(1, channels, 64, 32)
    patches, grid_h, grid_w = patchify(valid_imgs, patch_size)
    truncated_patches = patches[:, :-1, :]
    try:
        unpatchify(truncated_patches, patch_size, grid_h, grid_w, channels=channels)
    except AssertionError as e:
        print(f"Properly caught unpatchify with incorrect number of patches: {e}")
    else:
        assert False, "Failed to catch unpatchify with incorrect number of patches."

    print("Invalid Input Test Passed.")


def run_all_tests():
    """
    Execute every patchify/unpatchify test in order and report overall success.
    """
    suite = (
        test_patchify_unpatchify_round_trip,
        test_patchify_unpatchify_specific_case,
        test_patchify_unpatchify_non_square,
        test_invalid_inputs,
    )
    for test_case in suite:
        test_case()
    print("All Tests Passed Successfully.")


# Run the suite as soon as the cell executes.
run_all_tests()

### Patch Embed

#### Implementation

In [None]:
from typing import Dict, Tuple
import torch
import torch.nn as nn


class PositionGetter:
    """
    Generates and caches integer (row, column) positions for a grid of image patches.

    Positions are cached per grid size so repeated requests for the same grid do not
    rebuild the tensor. The cache key is (num_patches_h, num_patches_w); the cached
    tensor is moved to the requested device on every lookup, so a getter that is
    reused after the model moves to another device never returns a tensor that
    lives on the previous one.
    """

    def __init__(self):
        """Initialize an empty cache for position encodings."""
        self.cache_positions: Dict[Tuple[int, int], torch.Tensor] = {}

    def __call__(self, batch_size: int, num_patches_h: int, num_patches_w: int, device: torch.device) -> torch.Tensor:
        """
        Generate or retrieve cached position encodings for the specified patch grid.

        Parameters:
            batch_size (int): Number of samples in the batch
            num_patches_h (int): Number of patches in height dimension
            num_patches_w (int): Number of patches in width dimension
            device (torch.device): Device the returned tensor must live on

        Returns:
            torch.Tensor: Position encodings of shape (batch_size, num_patches_h * num_patches_w, 2),
            located on `device`. The batch axis is a broadcast view (`expand`), not a copy.
        """
        grid_key = (num_patches_h, num_patches_w)
        positions = self.cache_positions.get(grid_key)
        if positions is None:
            positions = self._generate_patch_positions_with_dimension(num_patches_h, num_patches_w, device)
        else:
            # The key does not include the device, so a hit may live elsewhere.
            # `.to` returns the very same tensor when it is already on `device`.
            positions = positions.to(device)
        # Remember the tensor on the most recently requested device.
        self.cache_positions[grid_key] = positions
        return self._expand_patch_positions_with_batch_size(positions, batch_size)

    def _generate_patch_positions_with_dimension(
        self, num_patches_h: int, num_patches_w: int, device: torch.device
    ) -> torch.Tensor:
        """
        Generate position encodings for a single sample.

        Parameters:
            num_patches_h (int): Number of patches in height dimension
            num_patches_w (int): Number of patches in width dimension
            device (torch.device): Device to place the tensors on

        Returns:
            torch.Tensor: Position encodings of shape (num_patches_h * num_patches_w, 2),
            listing (row, column) pairs in row-major order.
        """
        y = torch.arange(num_patches_h, device=device)
        x = torch.arange(num_patches_w, device=device)
        # Cartesian product with y first: the column index varies fastest.
        positions_for_patch = torch.cartesian_prod(y, x)
        return positions_for_patch

    def _expand_patch_positions_with_batch_size(self, patch_positions: torch.Tensor, batch_size: int) -> torch.Tensor:
        """
        Expand position encodings to match batch size.

        Parameters:
            patch_positions (torch.Tensor): Position encodings for a single sample, shape (N, 2)
            batch_size (int): Number of samples in the batch

        Returns:
            torch.Tensor: Position encodings of shape (batch_size, N, 2); a view that shares
            memory with `patch_positions`.
        """
        num_positions = patch_positions.size(0)
        unsqueezed_patch_positions = patch_positions.unsqueeze(0)
        expanded_patch_positions_with_batch_size = unsqueezed_patch_positions.expand(batch_size, num_positions, 2)
        return expanded_patch_positions_with_batch_size


class PatchEmbed(nn.Module):
    """
    Embeds image patches into a specified embedding dimension.

    This module splits an image into non-overlapping patches, projects each patch
    into a high-dimensional space using a convolutional layer, applies normalization,
    and provides integer grid positions for each patch.
    """

    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 16,
        in_channels: int = 3,
        embed_dim: int = 768,
        norm_layer: nn.Module = None,
        flatten: bool = True,
    ):
        """
        Initialize the PatchEmbed module.

        Parameters:
            img_size (int): Size of the input image (assumed square)
            patch_size (int): Size of each patch (assumed square)
            in_channels (int): Number of input image channels
            embed_dim (int): Dimension of the patch embeddings
            norm_layer (nn.Module, optional): Normalization layer constructor, called as
                norm_layer(embed_dim). None selects nn.Identity.
            flatten (bool): If True, flatten spatial dimensions after projection

        Raises:
            ValueError: If flatten is False and norm_layer is anything other than
                None or nn.Identity.
        """
        super().__init__()

        # Check compatibility of norm_layer and flatten settings
        if not flatten and norm_layer is not None and norm_layer != nn.Identity:
            raise ValueError(
                "LayerNorm cannot be used with flatten=False. "
                "When flatten=False, the output shape is (B, embed_dim, H, W) "
                "which is incompatible with LayerNorm's expected shape (B, H*W, embed_dim)."
            )

        self.img_size = (img_size, img_size)
        self.patch_size = (patch_size, patch_size)

        self.num_patches_h = self.img_size[0] // self.patch_size[0]
        self.num_patches_w = self.img_size[1] // self.patch_size[1]
        self.num_patches = self.num_patches_h * self.num_patches_w

        self.flatten = flatten

        # Project patches to embedding dimension: kernel == stride == patch size,
        # so every output location sees exactly one patch.
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=self.patch_size, stride=self.patch_size)

        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
        self.position_getter = PositionGetter()

        self._init_weights()

    def _init_weights(self):
        """
        Initialize the projection like an nn.Linear: Xavier-uniform weights, zero bias.

        The conv kernel acts as a linear map from a flattened patch (C * p * p values)
        to embed_dim values, so the Xavier bound is computed on the flattened
        (embed_dim, C * p * p) view. Applying Xavier to the 4D kernel instead would
        scale both fans by the kernel area and give a much smaller range.
        """
        weight = self.proj.weight.data
        # The view shares storage with the kernel, so this fills the kernel in place.
        nn.init.xavier_uniform_(weight.view(weight.shape[0], -1))
        if self.proj.bias is not None:
            nn.init.zeros_(self.proj.bias)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Process input images through the patch embedding module.

        Parameters:
            x (torch.Tensor): Input images of shape (B, C, H, W)

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                - Patch embeddings of shape:
                    - If flatten=True: (B, num_patches, embed_dim)
                    - If flatten=False: (B, embed_dim, num_patches_h, num_patches_w)
                - Patch grid positions of shape (B, num_patches, 2)

        Raises:
            AssertionError: If H or W differs from the image size given at construction.
        """
        B, C, H, W = x.shape

        assert (
            H == self.img_size[0]
        ), f"Input image height ({H}) doesn't match model's expected height ({self.img_size[0]})."
        assert (
            W == self.img_size[1]
        ), f"Input image width ({W}) doesn't match model's expected width ({self.img_size[1]})."

        x = self.proj(x)  # (B, embed_dim, num_patches_h, num_patches_w)

        pos_encodings = self.position_getter(
            batch_size=B, num_patches_h=self.num_patches_h, num_patches_w=self.num_patches_w, device=x.device
        )  # (B, num_patches, 2)

        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # (B, num_patches, embed_dim)
        x = self.norm(x)

        return x, pos_encodings

#### Tests

In [None]:
import torch
import torch.nn as nn
import numpy as np


def test_position_getter():
    """Exercise PositionGetter: empty cache on creation, generated positions, cache reuse."""
    print("Testing PositionGetter...")

    # A fresh getter owns an empty dict cache
    pos_getter = PositionGetter()
    assert isinstance(pos_getter.cache_positions, dict), "Cache should be a dictionary"
    assert len(pos_getter.cache_positions) == 0, "Cache should be empty at initialization"
    print("✓ Initialization test passed")

    # Request positions for a 3x4 grid and a batch of two
    batch_size, num_patches_h, num_patches_w = 2, 3, 4
    num_positions = num_patches_h * num_patches_w
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    positions = pos_getter(batch_size, num_patches_h, num_patches_w, device)

    # Batched output shape
    expected_shape = (batch_size, num_positions, 2)
    assert (
        positions.shape == expected_shape
    ), f"Expected shape ({batch_size}, {num_positions}, 2), got {positions.shape}"

    # The un-batched tensor is stored under the (h, w) key
    cache_key = (num_patches_h, num_patches_w)
    assert cache_key in pos_getter.cache_positions, "Positions should be cached"
    cached_shape = pos_getter.cache_positions[cache_key].shape
    assert cached_shape == (num_positions, 2), "Cached positions should have shape (num_patches_h * num_patches_w, 2)"

    # First and last entries are the grid corners
    first_expected = torch.tensor([0, 0], device=device)
    last_expected = torch.tensor([num_patches_h - 1, num_patches_w - 1], device=device)
    assert torch.all(positions[0, 0] == first_expected), "First position should be (0, 0)"
    assert torch.all(
        positions[0, -1] == last_expected
    ), f"Last position should be ({num_patches_h-1}, {num_patches_w-1})"
    print("✓ Position generation test passed")

    # A second identical request must reuse the single cache entry
    repeated = pos_getter(batch_size, num_patches_h, num_patches_w, device)
    assert torch.equal(positions, repeated), "Cached positions should be identical"
    assert len(pos_getter.cache_positions) == 1, "Cache should have only one entry"
    print("✓ Cache reuse test passed")

    print("All PositionGetter tests passed!\n")


def test_patch_embed():
    """
    Test the PatchEmbed class.

    Covers construction (valid and invalid norm/flatten combinations), output shapes
    for flatten=True and flatten=False, rejection of a wrongly sized input, and the
    range of the initial weights.

    "Expected an exception" failures are raised from the `else` branch of each
    try statement, so they can never be swallowed by the `except` clause that is
    meant to catch the exception under test.
    """
    print("Testing PatchEmbed...")

    # Test valid initialization with flatten=True and LayerNorm
    patch_embed = PatchEmbed(
        img_size=224, patch_size=16, in_channels=3, embed_dim=768, norm_layer=nn.LayerNorm, flatten=True
    )
    print("✓ Valid initialization with flatten=True and LayerNorm passed")

    # Test initialization with flatten=False and no norm
    patch_embed_no_norm = PatchEmbed(
        img_size=224, patch_size=16, in_channels=3, embed_dim=768, norm_layer=None, flatten=False
    )
    print("✓ Valid initialization with flatten=False and no norm passed")

    # Test initialization with flatten=False and Identity norm (must simply not raise)
    PatchEmbed(img_size=224, patch_size=16, in_channels=3, embed_dim=768, norm_layer=nn.Identity, flatten=False)
    print("✓ Valid initialization with flatten=False and Identity norm passed")

    # Test that LayerNorm with flatten=False raises error
    try:
        PatchEmbed(img_size=224, patch_size=16, in_channels=3, embed_dim=768, norm_layer=nn.LayerNorm, flatten=False)
    except ValueError as e:
        assert "LayerNorm cannot be used with flatten=False" in str(
            e
        ), "Should raise specific error about LayerNorm incompatibility"
    else:
        raise AssertionError("Should have raised ValueError")
    print("✓ LayerNorm incompatibility check passed")

    # Test forward pass with flatten=True
    batch_size = 4
    x = torch.randn(batch_size, 3, 224, 224)

    embeddings, pos_encodings = patch_embed(x)
    assert embeddings.shape == (
        batch_size,
        196,
        768,
    ), f"Expected embeddings shape (4, 196, 768), got {embeddings.shape}"
    assert pos_encodings.shape == (
        batch_size,
        196,
        2,
    ), f"Expected positions shape (4, 196, 2), got {pos_encodings.shape}"
    print("✓ Forward pass with flatten=True test passed")

    # Test forward pass with flatten=False
    embeddings, pos_encodings = patch_embed_no_norm(x)
    assert embeddings.shape == (
        batch_size,
        768,
        14,
        14,
    ), f"Expected embeddings shape (4, 768, 14, 14), got {embeddings.shape}"
    assert pos_encodings.shape == (
        batch_size,
        196,
        2,
    ), f"Expected positions shape (4, 196, 2), got {pos_encodings.shape}"
    print("✓ Forward pass with flatten=False test passed")

    # Test invalid input size. The module signals this with an AssertionError, so the
    # "nothing was raised" failure must live in `else`, outside the guarded call.
    x_invalid = torch.randn(batch_size, 3, 256, 224)
    try:
        patch_embed(x_invalid)
    except AssertionError as e:
        assert "Input image height" in str(e), "Should raise error about invalid input height"
    else:
        raise AssertionError("Should have raised AssertionError")
    print("✓ Invalid input size check passed")

    # Test weight initialization
    def check_xavier_uniform(tensor):
        """Return whether every entry of a 2D (fan_out, fan_in) tensor lies within the Xavier-uniform bound."""
        fan_in = tensor.size(1)
        fan_out = tensor.size(0)
        bound = np.sqrt(6.0 / (fan_in + fan_out))
        return torch.all(tensor >= -bound) and torch.all(tensor <= bound)

    # The kernel is checked as a flattened (embed_dim, C * p * p) matrix.
    w = patch_embed.proj.weight
    assert check_xavier_uniform(w.view(w.size(0), -1)), "Weights should follow Xavier uniform distribution"

    if patch_embed.proj.bias is not None:
        assert torch.all(patch_embed.proj.bias == 0), "Bias should be initialized to zero"
    print("✓ Weight initialization test passed")

    print("All PatchEmbed tests passed!\n")


# Run both test functions when the cell executes; a failed check surfaces as an AssertionError.
test_position_getter()
test_patch_embed()

### Positional Embeddings

#### Implementation 1d

In [None]:
import numpy as np

def get_1d_sincos_pos_embed_from_grid(embed_dim: int, coordinate_grid: np.ndarray) -> np.ndarray:
    """
    Encode scalar positions as sinusoidal embeddings.

    Each position p is multiplied by embed_dim / 2 frequencies
    1 / 10000 ** (k / (embed_dim / 2)) for k = 0 .. embed_dim / 2 - 1. The first half
    of the output holds the sines of those products, the second half the cosines.

    Args:
        embed_dim: Dimension of the output embeddings (must be even)
        coordinate_grid: Array of positions to encode; any shape, it is flattened first

    Returns:
        np.ndarray: Position embeddings with shape [coordinate_grid.size, embed_dim]

    Raises:
        AssertionError: If embed_dim is odd
    """
    assert embed_dim % 2 == 0, "Embedding dimension must be even"

    # Exponents k / (D / 2) run over [0, 1); frequencies therefore fall from 1 towards 1 / 10000
    half_dim = embed_dim // 2
    exponents = np.arange(half_dim, dtype=float) / (embed_dim / 2.0)
    inverse_wavelengths = 1.0 / 10000**exponents  # Shape: (D/2,)

    # Outer product: row m holds position m scaled by every frequency
    angles = np.outer(coordinate_grid.reshape(-1), inverse_wavelengths)  # Shape: (M, D/2)

    # Sines first, cosines second -> Shape: (M, D)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)

#### Visualization

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import seaborn as sns

def visualize_embeddings(embeddings, coordinate_grid, plot_freqency_components=True):
    """
    Plot sinusoidal embeddings as heatmaps and, optionally, as per-frequency curves.

    The heatmap type is chosen from the number of axes `coordinate_grid` has after
    size-1 axes are squeezed away: one axis -> 1D heatmap, two axes -> 2D heatmaps.
    Any other rank only prints a notice.

    Args:
        embeddings: Embedding array passed on to the plotting helpers
        coordinate_grid: Array of the positions that were encoded
        plot_freqency_components: When true, also plot the individual sin/cos components
    """
    # Pick the heatmap routine from the rank of the squeezed grid
    n_dims = coordinate_grid.squeeze().ndim
    heatmap_plotters = {1: plot_1d_heatmap, 2: plot_2d_heatmap}
    plot_heatmap = heatmap_plotters.get(n_dims)
    if plot_heatmap is None:
        print(f"Heatmap visualization not supported for {n_dims}D grids")
    else:
        plot_heatmap(embeddings, coordinate_grid)

    # Show individual frequencies
    if plot_freqency_components:
        plot_frequency_components(embeddings, coordinate_grid)

def visualize_1d_embeddings(embeddings, coordinate_grid):
    """
    Plot embeddings as heatmaps, then always plot their sin/cos components.

    Unlike `visualize_embeddings`, the heatmap type is chosen from the rank of
    `embeddings` itself: rank 2 -> 1D heatmap, rank 3 -> 2D heatmaps. Other ranks
    skip the heatmap.

    Args:
        embeddings: The embeddings array
        coordinate_grid: Array of the positions that were encoded, can be any shape
    """
    rank = embeddings.ndim
    if rank == 2:  # (positions, embed_dim)
        plot_1d_heatmap(embeddings, coordinate_grid)
    elif rank == 3:  # (h, w, embed_dim)
        plot_2d_heatmap(embeddings, coordinate_grid)

    # Show individual frequencies
    plot_frequency_components(embeddings, coordinate_grid)

def plot_1d_heatmap(embeddings, coordinate_grid):
    """
    Draw one heatmap of the embeddings for a 1D coordinate grid and print their values.

    Args:
        embeddings: 2D array drawn with one heatmap row per position and one column
            per embedding dimension
        coordinate_grid: Positions that were encoded; flattened and used as y tick labels
    """
    plt.figure(figsize=(12, 4))
    # Diverging colormap centred on zero so positive and negative values are told apart.
    sns.heatmap(embeddings, cmap="RdBu", center=0)
    print("Embeddings:")
    print(embeddings.round(2))
    plt.title("1D Position Embeddings")
    plt.xlabel("Embedding Dimension")
    plt.ylabel("Position")

    # Add position values to y-axis
    # The +0.5 offset places each tick at the centre of its heatmap row.
    positions = coordinate_grid.flatten()
    plt.yticks(np.arange(len(positions)) + 0.5,
               labels=positions.round(2),
               rotation=0)
    plt.show()

def plot_2d_heatmap(embeddings, coordinate_grid):
    """
    Draw one heatmap per embedding dimension for a 2D coordinate grid.

    Subplots are laid out in rows of at most four columns.

    Args:
        embeddings: Array of shape (h, w, embed_dim)
        coordinate_grid: Not used by this function
    """
    h, w, embed_dim = embeddings.shape
    n_cols = min(4, embed_dim)
    # Ceiling division: enough rows to hold every dimension.
    n_rows = (embed_dim + n_cols - 1) // n_cols

    plt.figure(figsize=(4*n_cols, 4*n_rows))
    for i in range(embed_dim):
        plt.subplot(n_rows, n_cols, i+1)
        # Generate tick positions
        x_ticks = np.linspace(0, w-1, min(w, 7))  # Limit to 7 ticks for readability
        y_ticks = np.linspace(0, h-1, min(h, 7))

        # NOTE(review): the label lists hold at most 7 entries and show `j+1` (one-based)
        # values; for w or h above 7 they have fewer entries than the heatmap has
        # columns/rows — confirm how seaborn places them.
        sns.heatmap(embeddings[:, :, i].reshape(h, w),
                   cmap="RdBu",
                   center=0,
                   xticklabels=[f"{j+1:.1f}" for j in x_ticks],
                   yticklabels=[f"{j+1:.1f}" for j in y_ticks])
        plt.title(f"Dimension {i}")
        # Label the y axis only on the first column of subplots.
        if i % n_cols == 0:
            plt.ylabel("Y Position")
        # Label the x axis only on the last n_cols subplots.
        if i >= embed_dim - n_cols:
            plt.xlabel("X Position")
    plt.tight_layout()
    plt.show()

def plot_frequency_components(embeddings, coordinate_grid):
    """
    Plot each embedding dimension as a curve over position.

    Dimension i is drawn solid and labelled "sin_dim_i"; dimension i + embed_dim // 2
    is drawn dashed and labelled "cos_dim_i". The labels therefore assume the first
    half of the last axis holds sines and the second half cosines.

    Args:
        embeddings: Either a 2D array (positions, embed_dim), plotted against
            `coordinate_grid`, or an (h, w, embed_dim) array, for which the middle
            row and the middle column are plotted against their indices
        coordinate_grid: X values for the 2D-array case; unused otherwise
    """
    # Check input dimensionality
    if len(embeddings.shape) == 2:  # 1D case
        n_pos, embed_dim = embeddings.shape
        plt.figure(figsize=(12, 4))
        for i in range(embed_dim // 2):
            plt.plot(coordinate_grid,
                    embeddings[:, i],
                    label=f"sin_dim_{i}",
                    linestyle="-")
            plt.plot(coordinate_grid,
                    embeddings[:, i + embed_dim // 2],
                    label=f"cos_dim_{i}",
                    linestyle="--")
        plt.title("Sinusoidal Components")
        plt.xlabel("Position")
        plt.ylabel("Value")
        # Legend is anchored outside the axes, to the right.
        plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
        plt.tight_layout()
        plt.show()

    else:  # 2D case
        h, w, embed_dim = embeddings.shape
        # Create x and y coordinate arrays
        # These are plain indices 0..w-1 and 0..h-1, not values taken from coordinate_grid.
        x_coords = np.linspace(0, w-1, w)
        y_coords = np.linspace(0, h-1, h)

        # Create subplots for both x and y variations
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8))

        # Plot variation along x-axis (middle row of embeddings)
        mid_y = h // 2
        for i in range(embed_dim // 2):
            ax1.plot(x_coords,
                    embeddings[mid_y, :, i],
                    label=f"sin_dim_{i}",
                    linestyle="-")
            ax1.plot(x_coords,
                    embeddings[mid_y, :, i + embed_dim // 2],
                    label=f"cos_dim_{i}",
                    linestyle="--")
        ax1.set_title("Sinusoidal Components Along X-axis (at middle Y)")
        ax1.set_xlabel("X Position")
        ax1.set_ylabel("Value")
        ax1.legend(bbox_to_anchor=(1.05, 1), loc="upper left")

        # Plot variation along y-axis (middle column of embeddings)
        mid_x = w // 2
        for i in range(embed_dim // 2):
            ax2.plot(y_coords,
                    embeddings[:, mid_x, i],
                    label=f"sin_dim_{i}",
                    linestyle="-")
            ax2.plot(y_coords,
                    embeddings[:, mid_x, i + embed_dim // 2],
                    label=f"cos_dim_{i}",
                    linestyle="--")
        ax2.set_title("Sinusoidal Components Along Y-axis (at middle X)")
        ax2.set_xlabel("Y Position")
        ax2.set_ylabel("Value")
        ax2.legend(bbox_to_anchor=(1.05, 1), loc="upper left")

        plt.tight_layout()
        plt.show()

In [None]:
# Positions 0, 1, ..., 6: unit steps from 0 up to (but excluding) 2*pi.
# The array is already one-dimensional, so .squeeze() leaves it unchanged.
coordinate_grid = np.arange(0, 2*np.pi).squeeze()
# embed_dim=2 gives a single frequency: one sine column and one cosine column.
embeddings = get_1d_sincos_pos_embed_from_grid(embed_dim=2, coordinate_grid=coordinate_grid)
visualize_embeddings(embeddings, coordinate_grid)

In [None]:
# coordinate_grid = np.arange(0, 2*np.pi).squeeze()
# embeddings = get_1d_sincos_pos_embed_from_grid(embed_dim=8, coordinate_grid=coordinate_grid)
# visualize_embeddings(embeddings, coordinate_grid)

In [None]:
# coordinate_grid = np.arange(0, 2*np.pi).squeeze()
# embeddings = get_1d_sincos_pos_embed_from_grid(embed_dim=32, coordinate_grid=coordinate_grid)
# visualize_embeddings(embeddings, coordinate_grid)

Each position's encoding is the vector of values it takes across all embedding dimensions; taken together, these values identify the position uniquely. The same construction extends naturally to the 2D case, as shown below.

#### Implementation 2d

In [None]:
def get_2d_sincos_pos_embed_from_grid(embed_dim: int, coordinate_grid: np.ndarray) -> np.ndarray:
    """
    Build sinusoidal positional embeddings from a two-channel coordinate grid.

    Each of the two coordinate channels is encoded independently with
    embed_dim / 2 dimensions by the 1D routine; the channel-0 encoding fills the
    first half of the output and the channel-1 encoding the second half.

    Args:
        embed_dim: Total embedding dimension (must be even)
        coordinate_grid: Grid of coordinates with shape [2, 1, H, W]

    Returns:
        np.ndarray: Positional embeddings with shape [H*W, embed_dim]

    Raises:
        AssertionError: If embed_dim is odd, or if embed_dim / 2 is odd (raised by
            the 1D routine)
    """
    assert embed_dim % 2 == 0, "Embedding dimension must be even"

    # Half of the output dimensions go to each coordinate channel
    half_dim = embed_dim // 2

    # One (H*W, D/2) block per channel, in channel order
    per_channel_embeddings = [
        get_1d_sincos_pos_embed_from_grid(half_dim, coordinate_grid=coordinate_grid[channel])
        for channel in (0, 1)
    ]

    # Side by side -> Shape: (H*W, D)
    return np.concatenate(per_channel_embeddings, axis=1)

def construct_coordinate_grid_2d(grid_size_x, grid_size_y):
    """
    Build a two-channel integer coordinate grid for a grid_size_y x grid_size_x lattice.

    Channel 0 holds, for every cell, its index along the last (x / column) axis;
    channel 1 holds its index along the second-to-last (y / row) axis. For square
    grids this is exactly the layout the previous implementation produced. For
    non-square grids the previous implementation reshaped arrays of shape
    (grid_size_x, grid_size_y) into (grid_size_y, grid_size_x), which scrambled the
    coordinates; the arrays are now created with the target shape directly.

    Args:
        grid_size_x: Number of cells along the x (column) axis
        grid_size_y: Number of cells along the y (row) axis

    Returns:
        np.ndarray: Integer array of shape [2, 1, grid_size_y, grid_size_x]
    """
    coordinate_grid_x = np.arange(grid_size_x)
    coordinate_grid_y = np.arange(grid_size_y)
    # With the default "xy" indexing, meshgrid(x, y) returns arrays of shape
    # (len(y), len(x)): the first varies along columns, the second along rows.
    column_coordinates, row_coordinates = np.meshgrid(coordinate_grid_x, coordinate_grid_y)
    coordinate_grid_2d = np.stack([column_coordinates, row_coordinates], axis=0)
    # Only inserts the singleton axis; the spatial layout is already (grid_size_y, grid_size_x).
    coordinate_grid_2d = coordinate_grid_2d.reshape([2, 1, grid_size_y, grid_size_x])
    return coordinate_grid_2d

In [None]:
# The grid size sets the domain of the positional embeddings (which positions get encoded).
# The embedding dimension sets how many frequencies each position is encoded with.
# The 2D embeddings are built from reshaped 1D embeddings: instead of shape (N, D), with
# N = number of tokens and D = embed_dim, the result has shape (H * W, D), with
# H = number of tokens along the height and W = number of tokens along the width.
# int(2 * pi) + 1 == 7 positions per axis.
grid_size = int(2 * np.pi) + 1
grid_size_x = grid_size
grid_size_y = grid_size
coordinate_grid_2d = construct_coordinate_grid_2d(grid_size_x=grid_size_x, grid_size_y=grid_size_y)

In [None]:
# Lets see first the H embeddings
embed_dim = 2
coordinate_grid_h = coordinate_grid_2d[0]  # first coordinate channel, shape (1, H, W)
# Read the spatial shape from the 2D channel being encoded. The variable `coordinate_grid`
# still holds the 1D positions from the earlier 1D demo, so unpacking its shape into two
# values fails.
coordinate_h, coordinate_w = coordinate_grid_h.shape[-2:]
embeddings = get_1d_sincos_pos_embed_from_grid(embed_dim=embed_dim, coordinate_grid=coordinate_grid_h)
# (H * W, embed_dim) -> (H, W, embed_dim) so every grid cell keeps its own embedding vector
reshaped_embeddings = embeddings.reshape(coordinate_h, coordinate_w, embed_dim)
# Pass the 2D channel so the visualiser selects the 2D heatmap branch.
visualize_embeddings(reshaped_embeddings, coordinate_grid_h)

In [None]:
# Then the W embeddings
embed_dim = 2
coordinate_grid_w = coordinate_grid_2d[1]  # second coordinate channel, shape (1, H, W)
# Read the spatial shape from the 2D channel being encoded. The variable `coordinate_grid`
# still holds the 1D positions from the earlier 1D demo, so unpacking its shape into two
# values fails.
coordinate_h, coordinate_w = coordinate_grid_w.shape[-2:]
embeddings = get_1d_sincos_pos_embed_from_grid(embed_dim=embed_dim, coordinate_grid=coordinate_grid_w)
# (H * W, embed_dim) -> (H, W, embed_dim) so every grid cell keeps its own embedding vector
reshaped_embeddings = embeddings.reshape(coordinate_h, coordinate_w, embed_dim)
# Pass the 2D channel so the visualiser selects the 2D heatmap branch.
visualize_embeddings(reshaped_embeddings, coordinate_grid_w)

In [None]:
# A per-cell embedding of shape (H, W, embed_dim) can always be flattened to (H * W, embed_dim),
# so get_2d_sincos_pos_embed_from_grid returns its result directly in the flattened
# (H * W, embed_dim) layout.
# Each of the two coordinate channels receives embed_dim_1d dimensions, hence embed_dim_1d * 2 in total.
embed_dim_1d = 2
embeddings = get_2d_sincos_pos_embed_from_grid(embed_dim_1d * 2, coordinate_grid_2d)

In [None]:
# One embedding row per grid cell, with embed_dim_1d * 2 features in every row.
assert embeddings.shape == (grid_size_y * grid_size_x, embed_dim_1d * 2)