# RoPE — Rotary Position Embedding Demo

**Paper:** Su et al. (2021) — *RoFormer: Enhanced Transformer with Rotary Position Embedding*  
**Used in:** LLaMA, GPT-NeoX, Falcon, Mistral, and most modern LLMs  
**Formula:** `x_rotated = x * cos(m·θ) + rotate_half(x) * sin(m·θ)`, where m is the token position and θ_i = 10000^(-2i/d)  
**Properties:** No learnable params | Encodes relative positions | Standard choice in modern LLMs  
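**Note:** Canonical RoPE rotates the query/key vectors inside attention. This demo applies it to the input embeddings instead, so it can be compared with other positional encodings through the same interface. The relative-position property then holds for dot products of the rotated embeddings themselves, not after further linear projections.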

In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

# Pick the compute device: CUDA when a GPU is visible, otherwise CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Device:', device)

In [None]:
class RoPEPositionalEncoding(nn.Module):
    '''
    Rotary Position Embedding (RoPE) — Su et al. (2021)

    Rotates feature pairs by a position-dependent angle. Feature i is paired
    with feature i + d_model/2 (the "rotate_half" layout), and the pair at
    position m is rotated by m * theta_i:

        theta_i   = 1 / base^(2i/d_model),   i = 0 .. d_model/2 - 1
        x_rotated = x * cos(m * theta) + rotate_half(x) * sin(m * theta)

    Args:
        d_model: size of the last (feature) dimension of the input. Must be
            even, because features are rotated in pairs.
        max_seq_len: number of positions whose cos/sin tables are
            precomputed as buffers. Longer inputs are still accepted; their
            tables are computed on the fly in forward().
        dropout: dropout probability applied to the rotated output.
        base: frequency base (10000 in the paper).

    Shape:
        forward() expects x of shape (batch, seq_len, d_model), with positions
        0 .. seq_len-1 along dim 1, and returns a tensor of the same shape.

    Raises:
        ValueError: if d_model is odd.
    '''
    def __init__(self, d_model, max_seq_len=512, dropout=0.1, base=10000):
        super().__init__()
        if d_model % 2 != 0:
            raise ValueError(f'd_model must be even for RoPE, got {d_model}')
        self.dropout = nn.Dropout(p=dropout)
        self.d_model = d_model

        # Rotation frequencies, one per feature pair: shape (d_model/2,)
        theta = 1.0 / (base ** (torch.arange(0, d_model, 2).float() / d_model))
        # Kept (non-persistent, so state_dict keys stay {'cos', 'sin'}) to
        # build tables for sequences longer than max_seq_len.
        self.register_buffer('inv_freq', theta, persistent=False)

        positions = torch.arange(max_seq_len, dtype=torch.float)
        angles = torch.outer(positions, theta)  # (max_seq_len, d_model/2)

        self.register_buffer('cos', torch.cos(angles))
        self.register_buffer('sin', torch.sin(angles))

    def _rotate_half(self, x):
        '''Map (x1, x2) -> (-x2, x1), where x1/x2 are the two halves of the last dim.'''
        half = x.shape[-1] // 2
        x1, x2 = x[..., :half], x[..., half:]
        return torch.cat([-x2, x1], dim=-1)

    def forward(self, x):
        seq_len = x.size(1)
        if seq_len <= self.cos.size(0):
            cos, sin = self.cos[:seq_len], self.sin[:seq_len]
        else:
            # Longer than the precomputed table: compute the angles for every
            # position directly instead of failing with a broadcast error.
            positions = torch.arange(seq_len, dtype=torch.float, device=self.inv_freq.device)
            angles = torch.outer(positions, self.inv_freq)
            cos, sin = torch.cos(angles), torch.sin(angles)

        # Duplicate the (seq_len, d_model/2) tables so that feature i and
        # feature i + d_model/2 share pair i's angle, then add a batch axis.
        cos = torch.cat([cos, cos], dim=-1).unsqueeze(0)
        sin = torch.cat([sin, sin], dim=-1).unsqueeze(0)
        if x.is_floating_point():
            # Match the input precision so fp16/bf16 inputs are not silently
            # promoted to fp32 by the float32 tables.
            cos, sin = cos.to(x.dtype), sin.to(x.dtype)
        return self.dropout(x * cos + self._rotate_half(x) * sin)

print('RoPEPositionalEncoding defined.')

In [None]:
# Sanity check — shape, zero learnable params, norm preservation, and the
# relative-position property of RoPE.
d_model, seq_len, batch = 64, 50, 4
pe_layer = RoPEPositionalEncoding(d_model, max_seq_len=512, dropout=0.0)

x = torch.randn(batch, seq_len, d_model)
out = pe_layer(x)

print(f'Input shape : {x.shape}')
print(f'Output shape: {out.shape}')
print(f'Learnable params: {sum(p.numel() for p in pe_layer.parameters())}')  # should be 0

# RoPE property 1: rotation preserves vector magnitude
input_norm = x[0, :5].norm(dim=-1)
output_norm = out[0, :5].norm(dim=-1)
print(f'\nInput norms  (first 5 pos): {input_norm.detach().tolist()}')
print(f'Output norms (first 5 pos): {output_norm.detach().tolist()}')
print('(Norms should be ~equal — rotation preserves magnitude)')
norms_ok = torch.allclose(out.norm(dim=-1), x.norm(dim=-1), atol=1e-4)
print(f'Norms preserved at all positions: {norms_ok}')

# RoPE property 2: <R_m q, R_n k> depends only on the offset n - m.
# Place the same q / k vectors at every position, then compare two
# position pairs that share the same offset.
q = torch.randn(1, 1, d_model).expand(1, seq_len, d_model)
k = torch.randn(1, 1, d_model).expand(1, seq_len, d_model)
with torch.no_grad():
    rq, rk = pe_layer(q)[0], pe_layer(k)[0]
offset = 3
dot_a = torch.dot(rq[2], rk[2 + offset])
dot_b = torch.dot(rq[20], rk[20 + offset])
print(f'\n<q@2, k@{2 + offset}>   = {dot_a.item():.5f}')
print(f'<q@20, k@{20 + offset}> = {dot_b.item():.5f}')
print(f'Relative-position property holds: {torch.isclose(dot_a, dot_b, atol=1e-3).item()}')

In [None]:
# Heatmap — show the cos table and what RoPE does to a constant input
d_model, seq_len = 64, 60
pe_layer = RoPEPositionalEncoding(d_model, max_seq_len=512, dropout=0.0)

# cos(position * theta_i), duplicated across both halves of the feature axis
# the same way forward() broadcasts it
cos_half = pe_layer.cos[:seq_len]
cos_matrix = torch.cat([cos_half, cos_half], dim=-1).numpy()

# Rotate an all-ones input so only the position-dependent rotation shows up
uniform_x = torch.ones(1, seq_len, d_model)
with torch.no_grad():
    pe_matrix = pe_layer(uniform_x)[0].numpy()

panels = [
    (cos_matrix, 'RoPE cos(theta) pattern'),
    (pe_matrix, 'RoPE applied to uniform input'),
]

fig, axes = plt.subplots(1, 2, figsize=(16, 4))
for ax, (matrix, title) in zip(axes, panels):
    image = ax.imshow(matrix.T, aspect='auto', cmap='RdYlBu', origin='lower')
    ax.set_xlabel('Position')
    ax.set_ylabel('Dimension')
    ax.set_title(title)
    plt.colorbar(image, ax=ax)

plt.tight_layout()
plt.savefig('demo_rope_heatmap.png', dpi=150, bbox_inches='tight')
plt.show()
print('Saved: demo_rope_heatmap.png')

In [None]:
# Summary of RoPE's properties. Length generalization is stated
# conservatively: any position can be computed, but extrapolating past the
# training length is known to degrade without RoPE scaling/interpolation.
print('=== RoPE Summary ===')
print('Learnable parameters: 0')
print('Encodes relative positions: YES')
print('Preserves vector magnitude: YES (rotation is isometric)')
print('Generalizes to unseen lengths: LIMITED (any position can be computed, '
      'but quality degrades beyond the training length without RoPE scaling)')
print('Used in: LLaMA, Mistral, Falcon, GPT-NeoX')
print()
print('Ready to use in experiments.')
print('Import: from PE.rope import RoPEPositionalEncoding')