# RoPE Embeddings

In [None]:
import torch
import torch.nn.functional as F

class RoPEEmbedding(torch.nn.Module):
    def __init__(self, dim, base=10000.0):
        """
        Rotary Positional Embedding (RoPE) implementation in PyTorch.

        Args:
        - dim: The embedding dimension (must be even).
        - base: Base of the geometric frequency progression; pair ``i`` is
          rotated by ``position * base ** (-2 * i / dim)``. Defaults to 10000.
        """
        super().__init__()
        assert dim % 2 == 0, "Embedding dimension must be even for RoPE."
        self.dim = dim
        self.base = base

    def forward(self, x):
        """
        Applies RoPE to input tensor.

        Args:
        - x: Tensor of shape (..., seq_len, dim), e.g. (batch_size, seq_len, dim)
          or (batch_size, num_heads, seq_len, dim). The second-to-last axis is
          treated as the position axis.

        Returns:
        - Tensor with the same shape as x. Input features are paired as
          (x[..., 2i], x[..., 2i + 1]); the rotated first components of all
          pairs fill the first half of the last axis and the rotated second
          components fill the second half (the output is not re-interleaved).
          Floating-point inputs keep their dtype; other dtypes follow normal
          promotion against the float32 rotation tables.
        """
        seq_len = x.shape[-2]
        device = x.device

        # Build the angle table directly on the input's device so no
        # host-to-device copy is needed on every call. The angles are always
        # computed in float32 to keep the trigonometry accurate.
        pair_index = torch.arange(0, self.dim, 2, dtype=torch.float32, device=device)
        theta = self.base ** (-pair_index / self.dim)
        position = torch.arange(seq_len, dtype=torch.float32, device=device).unsqueeze(1)
        freqs = position * theta.unsqueeze(0)  # (seq_len, dim/2)

        # Compute sin and cos components
        sin = torch.sin(freqs)
        cos = torch.cos(freqs)
        if x.is_floating_point():
            # Match the input dtype so fp16/bf16 inputs are not silently
            # promoted to float32 by the multiplication below.
            sin = sin.to(x.dtype)
            cos = cos.to(x.dtype)

        # Split features into the two components of each rotated pair
        x1, x2 = x[..., ::2], x[..., 1::2]  # Even and odd parts

        # Apply the 2-D rotation to every pair; (seq_len, dim/2) tables
        # broadcast over all leading axes of x.
        rotated_x = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
        return rotated_x


In [4]:
# Demo: push a random batch through RoPE and confirm the shape is preserved.
batch_size, seq_len, dim = 2, 5, 16  # dim has to be even

rope = RoPEEmbedding(dim)
x = torch.randn((batch_size, seq_len, dim))  # random example input
output = rope(x)

# Expected shape: (batch_size, seq_len, dim)
print(output.shape)


torch.Size([2, 5, 8]) torch.Size([2, 5, 8])
torch.Size([2, 5, 16])
