In [1]:
import torch

For an input with dimension $d_h$, RoPE applies $\frac{d_h}{2}$ 2D rotations, one to each consecutive pair of dimensions.

The $i$-th 2D rotation is applied to dimensions $(2i - 1, 2i)$ (1-indexed):


| $i$-th 2D rotation  | Dimensions $(2i - 1, 2i)$ |
|----------------|----------------------|
| 1              | $(1,\, 2)$           |
| 2              | $(3,\, 4)$           |
| 3              | $(5,\, 6)$           |
| $\vdots$       | $\vdots$             |


There are $\frac{d_h}{2}$ basis frequencies for an input with dimension $d_h$:

$\Theta = \{\theta_1, \theta_2, \ldots, \theta_{\frac{d_h}{2}} \}$ where $\theta_i = \text{base}^{-\frac{2(i-1)}{d_h}}$ for $i = 1, \ldots, \frac{d_h}{2}$.

The exact rotation matrices then depend on the token position. For token position $t$, the $i$-th 2D rotation matrix is

$R_{t,i} = \begin{bmatrix}
\cos t \theta_i & -\sin t \theta_i \\
\sin t \theta_i & \cos t \theta_i
\end{bmatrix}$

The whole rotation matrix for token at position $t$ is 
$\text{RoPE}(t) =
\begin{bmatrix}
R_{t,1} & 0      & 0      & \cdots & 0 \\
0      & R_{t,2} & 0      & \cdots & 0 \\
0      & 0      & R_{t,3} & \cdots & 0 \\
\vdots & \vdots & \vdots & \ddots & \vdots \\
0      & 0      & 0      & \cdots & R_{t,\,\frac{d_h}{2}}
\end{bmatrix} \in \mathbb{R}^{d_h \times d_h}$

In [2]:
# Seed the RNG so the random query/key tensors are reproducible.
torch.manual_seed(42)

seq_len = 20   # number of token positions
head_dim = 8   # per-head embedding size; RoPE rotates it in pairs

# One row per token position. `query` is drawn before `key`, so the
# values depend on this order.
query = torch.rand(seq_len, head_dim)
key = torch.rand(seq_len, head_dim)

# Frequency basis: inv_freq[i] = base^(-2i / head_dim), i = 0 .. head_dim/2 - 1
base = 100.0
i_range_idx = torch.arange(head_dim // 2, dtype=torch.float32)
exponents = 2 * i_range_idx / head_dim
inv_freq = 1.0 / base ** exponents  # [head_dim / 2, ]


In [3]:
# Build one block-diagonal rotation matrix per token position.
R_matrices = []

for t in range(seq_len):
    rotation_blocks = []
    for theta in inv_freq:
        # Rotation angle for this (position, frequency) pair
        angle = t * theta
        cos_a, sin_a = torch.cos(angle), torch.sin(angle)

        # 2x2 rotation by `angle`
        rotation_blocks.append(
            torch.tensor([
                [cos_a, -sin_a],
                [sin_a,  cos_a],
            ])
        )

    # Lay the 2x2 blocks along the diagonal => [head_dim, head_dim]
    R_matrices.append(torch.block_diag(*rotation_blocks))

In [4]:
# Inspect the block-diagonal rotation matrix built for token position t = 3
R_matrices[3]

tensor([[-0.9900, -0.1411,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.1411, -0.9900,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.5828, -0.8126,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.8126,  0.5828,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.9553, -0.2955,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.2955,  0.9553,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.9955, -0.0947],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0947,  0.9955]])

In [5]:
# Apply RoPE explicitly using the full block-diagonal matrices.
#
# Each query[t] / key[t] is a 1-D row vector x. The rotated vector is R_t x
# (first output = x1*cos - x2*sin, second = x1*sin + x2*cos). For a row
# vector that is x @ R_t^T. Plain x @ R_t would compute R_t^T x, i.e. a
# rotation by the opposite angle.
query_rot_explicit = torch.zeros_like(query)
key_rot_explicit   = torch.zeros_like(key)

for t in range(seq_len):
    R_t = R_matrices[t]                           # [head_dim, head_dim]
    query_rot_explicit[t] = query[t] @ R_t.T      # [head_dim]
    key_rot_explicit[t]   = key[t]   @ R_t.T      # [head_dim]


For an 8-dimensional input, we have 4 frequencies:

$\Theta = \{\theta_1, \theta_2, \theta_3, \theta_4 \}$

$
\boldsymbol{x}_t =
\begin{bmatrix}
x_{t,1} &
x_{t,2} &
x_{t,3} &
x_{t,4} &
x_{t,5} &
x_{t,6} &
x_{t,7} &
x_{t,8}
\end{bmatrix}.
$

$
\mathrm{RoPE}(t) =
\begin{bmatrix}
\cos(t\theta_1) & -\sin(t\theta_1) & 0 & 0 & 0 & 0 & 0 & 0 \\
\sin(t\theta_1) & \;\cos(t\theta_1) & 0 & 0 & 0 & 0 & 0 & 0 \\
0 & 0 & \cos(t\theta_2) & -\sin(t\theta_2) & 0 & 0 & 0 & 0 \\
0 & 0 & \sin(t\theta_2) & \;\cos(t\theta_2) & 0 & 0 & 0 & 0 \\
0 & 0 & 0 & 0 & \cos(t\theta_3) & -\sin(t\theta_3) & 0 & 0 \\
0 & 0 & 0 & 0 & \sin(t\theta_3) & \;\cos(t\theta_3) & 0 & 0 \\
0 & 0 & 0 & 0 & 0 & 0 & \cos(t\theta_4) & -\sin(t\theta_4) \\
0 & 0 & 0 & 0 & 0 & 0 & \sin(t\theta_4) & \;\cos(t\theta_4)
\end{bmatrix}.
$

$
\boldsymbol{x}_t \, \mathrm{RoPE}(t)^{T} =
\begin{bmatrix}
x_{t,1}\cos(t\theta_1) - x_{t,2}\sin(t\theta_1) \\[0.5em]
x_{t,1}\sin(t\theta_1) + x_{t,2}\cos(t\theta_1) \\[0.5em]
x_{t,3}\cos(t\theta_2) - x_{t,4}\sin(t\theta_2) \\[0.5em]
x_{t,3}\sin(t\theta_2) + x_{t,4}\cos(t\theta_2) \\[0.5em]
x_{t,5}\cos(t\theta_3) - x_{t,6}\sin(t\theta_3) \\[0.5em]
x_{t,5}\sin(t\theta_3) + x_{t,6}\cos(t\theta_3) \\[0.5em]
x_{t,7}\cos(t\theta_4) - x_{t,8}\sin(t\theta_4) \\[0.5em]
x_{t,7}\sin(t\theta_4) + x_{t,8}\cos(t\theta_4)
\end{bmatrix}^T
$

$
\boldsymbol{x}_t^{(\mathrm{odd})}=
\begin{bmatrix}
x_{t,1} & x_{t,3} & x_{t,5} & x_{t,7}
\end{bmatrix},
\qquad
\boldsymbol{x}_t^{(\mathrm{even})}=
\begin{bmatrix}
x_{t,2} & x_{t,4} & x_{t,6} & x_{t,8}
\end{bmatrix}.
$

$
\cos(t\Theta) =
\begin{bmatrix}
\cos(t\theta_1) &
\cos(t\theta_2) &
\cos(t\theta_3) &
\cos(t\theta_4)
\end{bmatrix}, \qquad
\sin(t\Theta) =
\begin{bmatrix}
\sin(t\theta_1) &
\sin(t\theta_2) &
\sin(t\theta_3) &
\sin(t\theta_4)
\end{bmatrix}.
$

$\boldsymbol{x}_t^{(\mathrm{odd})} \odot \cos(t\Theta)=
\begin{bmatrix}
x_{t,1}\cos(t\theta_1) &
x_{t,3}\cos(t\theta_2) &
x_{t,5}\cos(t\theta_3) &
x_{t,7}\cos(t\theta_4)
\end{bmatrix}$

$\boldsymbol{x}_t^{(\mathrm{odd})} \odot \sin(t\Theta)=
\begin{bmatrix}
x_{t,1}\sin(t\theta_1) &
x_{t,3}\sin(t\theta_2) &
x_{t,5}\sin(t\theta_3) &
x_{t,7}\sin(t\theta_4)
\end{bmatrix}$

$\boldsymbol{x}_t^{(\mathrm{even})} \odot \sin(t\Theta)=
\begin{bmatrix}
x_{t,2}\sin(t\theta_1) &
x_{t,4}\sin(t\theta_2) &
x_{t,6}\sin(t\theta_3) &
x_{t,8}\sin(t\theta_4)
\end{bmatrix}.$

$\boldsymbol{x}_t^{(\mathrm{even})} \odot \cos(t\Theta)=
\begin{bmatrix}
x_{t,2}\cos(t\theta_1) &
x_{t,4}\cos(t\theta_2) &
x_{t,6}\cos(t\theta_3) &
x_{t,8}\cos(t\theta_4)
\end{bmatrix}$


$
\boldsymbol{x}'^{(\mathrm{odd})}_t=
\boldsymbol{x}_t^{(\mathrm{odd})} \odot \cos(t\Theta)-\boldsymbol{x}_t^{(\mathrm{even})} \odot \sin(t\Theta)=
\begin{bmatrix}
x_{t,1}\cos(t\theta_1) - x_{t,2}\sin(t\theta_1) \\
x_{t,3}\cos(t\theta_2) - x_{t,4}\sin(t\theta_2) \\
x_{t,5}\cos(t\theta_3) - x_{t,6}\sin(t\theta_3) \\
x_{t,7}\cos(t\theta_4) - x_{t,8}\sin(t\theta_4)
\end{bmatrix}^{T}.
$

$
\boldsymbol{x}'^{(\mathrm{even})}_t=
\boldsymbol{x}_t^{(\mathrm{odd})} \odot \sin(t\Theta)
+
\boldsymbol{x}_t^{(\mathrm{even})} \odot \cos(t\Theta)=
\begin{bmatrix}
x_{t,1}\sin(t\theta_1) + x_{t,2}\cos(t\theta_1) \\
x_{t,3}\sin(t\theta_2) + x_{t,4}\cos(t\theta_2) \\
x_{t,5}\sin(t\theta_3) + x_{t,6}\cos(t\theta_3) \\
x_{t,7}\sin(t\theta_4) + x_{t,8}\cos(t\theta_4)
\end{bmatrix}^{T}.
$


In [8]:
# Fast RoPE implementation (matching the explicit version's structure):
# rotate each (first, second) pair element-wise instead of multiplying by
# a full rotation matrix.

# Per-position lookup tables: row t holds cos / sin of t * inv_freq,
# so cos_table[t] and sin_table[t] have shape [head_dim/2].
half_dim = head_dim // 2
cos_table = torch.zeros(seq_len, half_dim)
sin_table = torch.zeros(seq_len, half_dim)

for t in range(seq_len):
    angles = t * inv_freq                      # [head_dim/2]
    cos_table[t], sin_table[t] = torch.cos(angles), torch.sin(angles)


query_rot_fast = torch.zeros_like(query)
key_rot_fast   = torch.zeros_like(key)

# Query and key receive exactly the same rotation, so one loop body
# handles both (source tensor -> destination tensor).
for t in range(seq_len):
    cos_t, sin_t = cos_table[t], sin_table[t]

    for src, dst in ((query, query_rot_fast), (key, key_rot_fast)):
        first  = src[t, 0::2]   # x1, x3, x5, x7
        second = src[t, 1::2]   # x2, x4, x6, x8

        # Write the rotated halves back into their interleaved slots
        dst[t, 0::2] = first * cos_t - second * sin_t
        dst[t, 1::2] = first * sin_t + second * cos_t

