# Scaled RoPE

In [1]:
from typing import Optional

import torch

def precompute_rope_params(
    seq_len: int,
    head_dim: int,
    theta_base: float = 10000.0,
    freq_config: Optional[dict] = None,
):
    """
    Precompute the RoPE sine and cosine tables, optionally with frequency scaling.

    The inverse frequencies are ``theta_base ** -(i / (head_dim / 2))`` for
    ``i = 0 .. head_dim / 2 - 1``. When ``freq_config`` is given, each frequency
    is adjusted according to its wavelength ``2 * pi / inv_freq``:

    * wavelength > original_context_length / low_freq_factor:
      the inverse frequency is divided by ``factor``;
    * wavelength < original_context_length / high_freq_factor:
      the inverse frequency is left unchanged;
    * in between (bounds inclusive): a linear blend of the scaled and
      unscaled inverse frequency.

    Args:
        seq_len: number of positions to precompute (rows of the tables)
        head_dim: embedding dimension per head (must be even)
        theta_base: base for the inverse frequency calculation (default 10_000)
        freq_config: optional dict with keys:
            - original_context_length: int, original training context length
            - low_freq_factor: float, low freq threshold factor
            - high_freq_factor: float, high freq threshold factor
              (must differ from low_freq_factor, it is used as a divisor)
            - factor: float, scaling factor applied to low frequencies

    Returns:
        sin, cos: float32 tensors of shape (seq_len, head_dim), in that order.

    Raises:
        ValueError: if head_dim is odd.
        KeyError: if freq_config is given but one of its keys is missing.
    """
    # Explicit raise instead of `assert`: asserts are stripped under `python -O`,
    # which would let an odd head_dim silently yield tables of width head_dim - 1.
    if head_dim % 2 != 0:
        raise ValueError("head_dim must be even")

    half_dim = head_dim // 2

    # One inverse frequency per pair of embedding dimensions
    exponents = torch.arange(half_dim, dtype=torch.float32) / half_dim
    inv_freq = 1.0 / (theta_base ** exponents)  # shape (half_dim,)

    if freq_config is not None:
        orig_len = freq_config["original_context_length"]
        low_factor = freq_config["low_freq_factor"]
        high_factor = freq_config["high_freq_factor"]
        scale_factor = freq_config["factor"]

        wavelen = 2 * torch.pi / inv_freq  # shape (half_dim,)

        # Wavelength thresholds separating the low / medium / high bands
        low_wavelen = orig_len / low_factor
        high_wavelen = orig_len / high_factor

        # Low-frequency band (long wavelengths): scale down by `factor`
        inv_freq_scaled = torch.where(
            wavelen > low_wavelen, inv_freq / scale_factor, inv_freq
        )

        # Blend weight: 0 -> fully scaled, 1 -> unscaled
        smooth_factor = (orig_len / wavelen - low_factor) / (high_factor - low_factor)
        smooth_factor = smooth_factor.clamp(0.0, 1.0)

        smoothed_inv_freq = (
            (1 - smooth_factor) * (inv_freq / scale_factor) + smooth_factor * inv_freq
        )

        # Only the medium band takes the blended value
        is_medium = (wavelen <= low_wavelen) & (wavelen >= high_wavelen)
        inv_freq = torch.where(is_medium, smoothed_inv_freq, inv_freq_scaled)

    # Outer product of positions and inverse frequencies
    positions = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)  # (seq_len, 1)
    angles = positions * inv_freq.unsqueeze(0)  # (seq_len, half_dim)

    # Concatenate (not interleave) the two halves: columns i and i + half_dim
    # share the same frequency.
    angles = torch.cat([angles, angles], dim=-1)  # (seq_len, head_dim)

    return torch.sin(angles), torch.cos(angles)


In [3]:
# Llama 3 RoPE hyperparameters shared by the demo cell
llama_3_context_len = 8_192  # sequence length of the dummy query/key tensors
llama_3_theta_base = 500_000  # value passed as theta_base

In [None]:
from Llama2_v1 import rotary_pos_emb

# Settings for the dummy attention tensors
batch_size = 2
num_heads = 4
head_dim = 16

# Precompute the RoPE tables.
# precompute_rope_params takes `seq_len` (it has no `context_length` parameter)
# and returns the pair in (sin, cos) order, so unpack it in that order.
sin, cos = precompute_rope_params(
    seq_len=llama_3_context_len,
    head_dim=head_dim,
    theta_base=llama_3_theta_base,
)

# Dummy query and key tensors of shape (batch, heads, seq_len, head_dim)
torch.manual_seed(123)
queries = torch.randn(batch_size, num_heads, llama_3_context_len, head_dim)
keys = torch.randn(batch_size, num_heads, llama_3_context_len, head_dim)

# Apply rotary position embeddings
# NOTE(review): argument order (x, cos, sin) is kept from the original call;
# confirm it against rotary_pos_emb's signature in Llama2_v1.
queries_rot = rotary_pos_emb(queries, cos, sin)
keys_rot = rotary_pos_emb(keys, cos, sin)