# Rotary Positional Embeddings (RoPE)

## Summary

<img src = "../images/RoPE.png" width = "100%">

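**Rotary Positional Embeddings (RoPE)** unify absolute and relative positional information in sequence models through a geometric transformation: each token's vector is rotated by an angle set by its absolute position, and the resulting attention scores depend only on relative position.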

The key property behind RoPE is how the **dot product** behaves under rotation. When the self-attention mechanism calculates the score between a Query ($q$) at position $i$ and a Key ($k$) at position $j$, position enters the result only through the relative angle between them: $\theta_{i} - \theta_{j}$.

Because the dot product of two rotated vectors is invariant to their absolute rotation and only sensitive to their relative displacement, the model "senses" how far apart two tokens are by the degree of rotation needed to align them.
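The invariance described above can be checked numerically in a single 2D plane. The following is a minimal sketch in plain Python; the helper names `rotate` and `dot` and the sample vectors are illustrative assumptions, not part of any library.

```python
import math

def rotate(v, angle):
    """Rotate a 2D vector counter-clockwise by `angle` radians."""
    x, y = v
    c, s = math.cos(angle), math.sin(angle)
    return (c * x - s * y, s * x + c * y)

def dot(a, b):
    """Plain dot product of two equal-length sequences."""
    return sum(x * y for x, y in zip(a, b))

q = (1.0, 0.5)   # illustrative query (assumed values)
k = (0.3, -0.8)  # illustrative key (assumed values)

# Angle pairs (2, 9) and (12, 19) share the same difference of 7,
# so the rotated dot products should agree up to rounding.
score_a = dot(rotate(q, 2.0), rotate(k, 9.0))
score_b = dot(rotate(q, 12.0), rotate(k, 19.0))

# A different difference (3 instead of 7) gives a different score here.
score_c = dot(rotate(q, 2.0), rotate(k, 5.0))
```

Shifting both angles by the same amount leaves the score unchanged, while changing the gap between them does not.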

## Step by Step Explanation

### Mathematical Foundation

The fundamental objective of RoPE is to encode position $i$ by rotating the Query ($q$) and Key ($k$) vectors in a manner that preserves their relative distance.

**The Rotation Mechanism**

For a hidden dimension $d$, we treat the vector as $d/2$ pairs of coordinates. 
Assume for simplicity that our model dimension $d=4$.

$$
q = [q_1, q_2, q_3, q_4]
$$

We treat the vector as **two independent 2D planes**:

* Plane 1: $(q_1, q_2)$
* Plane 2: $(q_3, q_4)$

For each pair index $k \in \{1, \dots, d/2\}$ (not to be confused with the Key vector $k$), we define a position-dependent angle, where $\Theta$ is a fixed base constant:

$$
\theta_{i,k} = i \cdot \Theta^{-2(k-1)/d}
$$

For the $d=4$ example above, this gives two distinct angles ($\theta_{i,1}$ and $\theta_{i,2}$):

$$
\begin{aligned}
\theta_{i,1} &= i \cdot \Theta^{-2(1-1)/d} = i \\
\theta_{i,2} &= i \cdot \Theta^{-2(2-1)/d} = i \cdot \Theta^{-2/d}
\end{aligned}
$$
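As a quick sanity check, the angle formula can be evaluated directly. This is a minimal sketch; the function name `rope_angles` is hypothetical, and the default base of 10,000 is the value used in the worked example later in this document.

```python
def rope_angles(position, d, base=10000.0):
    """Return [theta_{i,1}, ..., theta_{i,d/2}] for position i."""
    if d % 2 != 0:
        raise ValueError("d must be even")
    # Pair index k runs from 1 to d/2, matching the formula in the text.
    return [position * base ** (-2 * (k - 1) / d) for k in range(1, d // 2 + 1)]

# For d = 4 the second exponent is -2/4, so the second angle is
# position / sqrt(base), i.e. position * 0.01 for base = 10000.
angles = rope_angles(position=3, d=4)
```

The first plane rotates by a full radian per position, while later planes rotate progressively more slowly.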


The rotation for each pair is governed by the $2 \times 2$ matrix $R_{i,k}$:

$$
R_{i,k} = \begin{bmatrix} \cos(\theta_{i,k}) & -\sin(\theta_{i,k}) \\ \sin(\theta_{i,k}) & \cos(\theta_{i,k}) \end{bmatrix}
$$

**The Full Transformation**

These blocks are assembled into a block-diagonal matrix $R_i$, which acts on the entire embedding vector:

$$
R_i = \begin{bmatrix} 
R_{i,1} & 0 & \dots & 0 \\ 0 & R_{i,2} & \dots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \dots & R_{i,d/2} 
\end{bmatrix}
$$
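To make the block-diagonal structure concrete, the sketch below builds $R_i$ explicitly as a nested list and multiplies it with a vector. The names `rotation_matrix` and `matvec` and the sample vector are assumptions for illustration. Materializing the full matrix like this is only for exposition, since most of its entries are zero.

```python
import math

def rotation_matrix(position, d, base=10000.0):
    """Build the d x d block-diagonal matrix R_i as a list of rows."""
    R = [[0.0] * d for _ in range(d)]
    for k in range(d // 2):  # zero-based pair index
        theta = position * base ** (-2 * k / d)
        c, s = math.cos(theta), math.sin(theta)
        r = 2 * k  # top-left corner of this 2x2 block
        R[r][r], R[r][r + 1] = c, -s
        R[r + 1][r], R[r + 1][r + 1] = s, c
    return R

def matvec(M, v):
    """Multiply matrix M (list of rows) by vector v."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]

q = [1.0, 2.0, 3.0, 4.0]  # illustrative vector (assumed values)
R3 = rotation_matrix(position=3, d=4)
q_rot = matvec(R3, q)

# Rotations preserve length, so the squared norm stays 1 + 4 + 9 + 16 = 30.
norm_sq = sum(x * x for x in q_rot)
```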

The rotated vector is computed as $q'^{(i)} = R_i q^{(i)}$. Crucially, when calculating attention between positions $i$ and $j$, the dot product satisfies:

$$
\langle R_i q^{(i)}, R_j k^{(j)} \rangle = \langle q^{(i)}, R_{j-i} k^{(j)} \rangle
$$

The identity holds because every block is a plane rotation, so $R_i^\top R_j = R_{j-i}$. The attention score therefore depends on position only through the relative displacement $j-i$.
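The identity can be verified numerically for $d=4$. The sketch below applies the rotation pair by pair instead of building the matrix; the function name `rope` and the sample vectors are illustrative assumptions.

```python
import math

def rope(v, position, base=10000.0):
    """Apply R_position to v by rotating consecutive coordinate pairs."""
    d = len(v)
    out = []
    for k in range(d // 2):
        theta = position * base ** (-2 * k / d)
        c, s = math.cos(theta), math.sin(theta)
        a, b = v[2 * k], v[2 * k + 1]
        out += [c * a - s * b, s * a + c * b]
    return out

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

q = [0.5, -1.0, 2.0, 0.25]  # illustrative query (assumed values)
k = [1.5, 0.75, -0.5, 1.0]  # illustrative key (assumed values)
i, j = 2, 9

lhs = dot(rope(q, i), rope(k, j))  # <R_i q, R_j k>
rhs = dot(q, rope(k, j - i))       # <q, R_{j-i} k>

# Shifting both positions by the same amount should not change the score.
shifted = dot(rope(q, i + 5), rope(k, j + 5))
```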

### Practical Application

Let us apply this to a concrete sequence.

* String Sequence: `"the cat ate the rat"`
* Token Sequence: `[9, 0, 2, 7, 0, 7, 3, 9, 0, 6, 7]`
* Focus (positions are zero-based indices into the token sequence):
    * Token ID 2 (the 'c' in 'cat') at position $i=2$ vs.
    * Token ID 6 (the 'r' in 'rat') at position $i=9$.

For clarity, we will assume a small embedding dimension of $d=2$ (one rotation plane) and a base constant $\Theta = 10{,}000$.

#### Step 1: Calculating the Angles ($\theta$)

For $i=2$ (Token 'c'):

$$
\theta_{2} = 2 \cdot 10000^{0} = 2 \text{ radians}
$$

For $i=9$ (Token 'r'):

$$
\theta_{9} = 9 \cdot 10000^{0} = 9 \text{ radians}
$$

#### Step 2: Constructing the Rotation Matrices

For the 'c' token ($i=2$):

$$
R_2 = \begin{bmatrix} \cos(2) & -\sin(2) \\ \sin(2) & \cos(2) \end{bmatrix} \approx \begin{bmatrix} -0.416 & -0.909 \\ 0.909 & -0.416 \end{bmatrix}
$$
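Step 2 writes out only $R_2$; the same construction gives $R_9$ for the 'r' token. The sketch below computes both so the entries can be checked. The helper name `rotation_2d` is hypothetical.

```python
import math

def rotation_2d(theta):
    """2x2 counter-clockwise rotation matrix for angle theta (radians)."""
    c, s = math.cos(theta), math.sin(theta)
    return [[c, -s], [s, c]]

R2 = rotation_2d(2.0)  # 'c' token at position 2
R9 = rotation_2d(9.0)  # 'r' token at position 9

# Print both matrices rounded to three decimals for comparison with the text.
for name, R in (("R_2", R2), ("R_9", R9)):
    print(name, [[round(x, 3) for x in row] for row in R])
```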

#### Step 3: Resulting Interaction

When the model performs self-attention between the Query of 'c' ($q^{(2)}$) and the Key of 'r' ($k^{(9)}$), the resulting score is influenced by the angular difference:

$$
\Delta\theta = \theta_9 - \theta_2 = 7 \text{ radians}
$$

Because position enters the score only through this difference, the model can "perceive" that the `'r'` in `'rat'` is 7 positions ahead of the `'c'` in `'cat'` wherever the pair occurs in the sequence, which helps it relate these subword units.
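For a single rotation plane, the rotated score has a simple closed form: $\lVert q \rVert \, \lVert k \rVert \cos(\phi_k - \phi_q + \Delta\theta)$, where $\phi_q$ and $\phi_k$ are the polar angles of the unrotated vectors. The sketch below checks this against an explicit rotation, using made-up query and key values, since the document does not specify any.

```python
import math

def rotate(v, angle):
    """Rotate a 2D vector counter-clockwise by `angle` radians."""
    x, y = v
    c, s = math.cos(angle), math.sin(angle)
    return (c * x - s * y, s * x + c * y)

q = (1.0, 0.5)   # illustrative query for 'c' (assumed values)
k = (0.3, -0.8)  # illustrative key for 'r' (assumed values)

q_rot = rotate(q, 2.0)  # theta_2 = 2 radians
k_rot = rotate(k, 9.0)  # theta_9 = 9 radians
score = q_rot[0] * k_rot[0] + q_rot[1] * k_rot[1]

# Closed form for one plane: |q| |k| cos(phi_k - phi_q + delta_theta).
delta_theta = 9.0 - 2.0
phi_q = math.atan2(q[1], q[0])
phi_k = math.atan2(k[1], k[0])
closed_form = math.hypot(*q) * math.hypot(*k) * math.cos(phi_k - phi_q + delta_theta)
```

Only `delta_theta` carries positional information in the closed form; the remaining terms come from the content of the two vectors.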

## Code

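```python
import torch
import torch.nn as nn


class RoPE(nn.Module):
    """Rotary positional embedding applied to the last dimension of the input."""

    def __init__(
        self,
        theta: float,
        d_k: int,
        max_seq_len: int,
        device: torch.device | None = None,
    ) -> None:
        super().__init__()

        self.theta = theta
        self.d_k = d_k
        self.max_seq_len = max_seq_len

        self._build_cache(device=device)

    def _build_cache(self, device: torch.device | None = None) -> None:
        # Positions 0 .. max_seq_len - 1 as a column: (max_seq_len, 1).
        position = torch.arange(self.max_seq_len, device=device).unsqueeze(1)

        # Even indices 0, 2, ..., one per rotation plane: (d_k / 2,).
        dim = torch.arange(0, self.d_k, 2, device=device)

        # Per-plane frequency theta^(-dim / d_k).
        inv_freq = 1.0 / (self.theta ** (dim / self.d_k))

        # Angle for every (position, plane) pair: (max_seq_len, d_k / 2).
        sinusoid_inp = position * inv_freq

        # Buffers follow the module across devices but are not trained.
        self.register_buffer("sin", torch.sin(sinusoid_inp))
        self.register_buffer("cos", torch.cos(sinusoid_inp))

    def forward(
        self,
        x: torch.Tensor,
        token_positions: torch.Tensor,
    ) -> torch.Tensor:
        # x: (batch, seq_len, d_k)
        # token_positions: (seq_len,) or (batch, seq_len), integer positions.
        batch_size, seq_len, dim = x.shape

        torch._check(
            dim % 2 == 0,
            lambda: "Embedding dimension must be even for RoPE",
        )

        # Share one position vector across the batch if only one is given.
        if token_positions.dim() == 1:
            token_positions = token_positions.unsqueeze(0).expand(batch_size, -1)

        # Look up cached values: (batch, seq_len, d_k / 2).
        cos = self.cos[token_positions]
        sin = self.sin[token_positions]

        # Split each (even, odd) coordinate pair into its two components.
        x1 = x[..., 0::2]
        x2 = x[..., 1::2]

        # 2D rotation of every pair by its position-dependent angle.
        real = cos * x1 - sin * x2
        imag = sin * x1 + cos * x2

        # Re-interleave the pairs back into shape (batch, seq_len, d_k).
        x_out = torch.stack((real, imag), dim=-1)
        x_out = x_out.flatten(-2)

        return x_out
```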
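For spot-checking the module on a single token, a dependency-free reference can be useful. The sketch below mirrors the same two ideas, a precomputed cos/sin table and the even/odd pairing, using plain Python lists. The names `build_cache` and `apply_rope` are hypothetical, and this is not a substitute for the tensor implementation.

```python
import math

def build_cache(theta, d_k, max_seq_len):
    """Precompute cos/sin tables of shape [max_seq_len][d_k // 2]."""
    inv_freq = [1.0 / theta ** (dim / d_k) for dim in range(0, d_k, 2)]
    cos = [[math.cos(p * f) for f in inv_freq] for p in range(max_seq_len)]
    sin = [[math.sin(p * f) for f in inv_freq] for p in range(max_seq_len)]
    return cos, sin

def apply_rope(x, position, cos, sin):
    """Rotate one token vector x using the cached row for `position`."""
    x1, x2 = x[0::2], x[1::2]  # even- and odd-indexed coordinates
    out = []
    for a, b, c, s in zip(x1, x2, cos[position], sin[position]):
        out += [c * a - s * b, s * a + c * b]  # re-interleave each pair
    return out

cos, sin = build_cache(theta=10000.0, d_k=4, max_seq_len=16)
x = [1.0, 0.0, 1.0, 0.0]  # illustrative input (assumed values)
y = apply_rope(x, position=2, cos=cos, sin=sin)
```

With this input, each pair starts on the unit x-axis, so `y` holds the cosine and sine of the two angles for position 2.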
