# RoPE Implementation

In [1]:
%pip install torch



In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Dict, Optional, Tuple

In [4]:
def build_config(
    hidden_size=128,
    num_heads=16,
    num_kv_heads=4,
    max_positional_encodings=256,
    rope_theta=10000.0,
    rms_norm_eps=1e-5,
    attention_bias=False,
    attention_dropout=0.1,
    use_qk_norm=True,
):
    """Bundle the attention hyperparameters into a plain dict.

    Every argument is stored under a key of the same name, followed by a
    derived ``"head_dim"`` entry equal to ``hidden_size // num_heads``.

    Raises:
        AssertionError: if ``hidden_size`` is not a multiple of ``num_heads``.
    """
    assert not hidden_size % num_heads, "hidden_size must be divisible by num_heads"

    config = dict(
        hidden_size=hidden_size,
        num_heads=num_heads,
        num_kv_heads=num_kv_heads,
        max_positional_encodings=max_positional_encodings,
        rope_theta=rope_theta,
        rms_norm_eps=rms_norm_eps,
        attention_bias=attention_bias,
        attention_dropout=attention_dropout,
        use_qk_norm=use_qk_norm,
    )
    # Derived value goes last so the key order matches the argument order.
    config["head_dim"] = hidden_size // num_heads
    return config


# Default configuration shared by the cells below.
CONFIG = build_config()

In [8]:
class RoPE(nn.Module):
  """Rotary position embedding (RoPE) backed by precomputed cos/sin tables.

  Channels are rotated in adjacent pairs ``(x[..., 2i], x[..., 2i + 1])``.
  Pair ``i`` at position ``p`` is rotated by the angle
  ``p * rope_theta ** (-2i / head_dim)``.

  Args:
    head_dim: Size of the last axis of the tensors to rotate. Must be even,
      because channels are consumed two at a time.
    max_position_embeddings: Number of positions to precompute; valid
      positions are ``0 .. max_position_embeddings - 1``.
    rope_theta: Base of the geometric progression of rotation frequencies.

  Raises:
    ValueError: If ``head_dim`` is odd.
  """

  def __init__(self, head_dim: int, max_position_embeddings: int = 2048, rope_theta: float = 10000.0):
    super().__init__()
    # An odd head_dim yields tables one channel wider than x and an unpaired
    # last channel, so apply_rotary could never succeed; fail early instead.
    if head_dim % 2 != 0:
      raise ValueError(f"head_dim must be even for RoPE, got {head_dim}")

    self.head_dim = head_dim
    self.max_position_embeddings = max_position_embeddings
    self.rope_theta = rope_theta

    # One inverse frequency per channel pair: theta ** (-2i / head_dim).
    inv_freq = 1.0 / (self.rope_theta**(torch.arange(0, self.head_dim, 2).float()/self.head_dim))
    t = torch.arange(self.max_position_embeddings, dtype=torch.float)
    # freqs[p, i] = p * inv_freq[i], shape (max_position_embeddings, head_dim // 2).
    freqs = torch.outer(t, inv_freq)

    self.register_buffers(freqs)

  def register_buffers(self, freqs: torch.Tensor):
    """Store ``cos(freqs)`` and ``sin(freqs)`` as ``cos_cached`` / ``sin_cached``.

    The tables are registered as module buffers so they follow the module
    through ``.to()`` / ``.cuda()`` / ``.double()``. ``persistent=False`` keeps
    them out of ``state_dict``; they are fully determined by the constructor
    arguments, so checkpoints are unaffected.
    """
    self.register_buffer("cos_cached", freqs.cos(), persistent=False)
    self.register_buffer("sin_cached", freqs.sin(), persistent=False)

  def rotate_half(self, x):
    """Map every adjacent channel pair ``(a, b)`` of the last axis to ``(-b, a)``."""
    x1 = x[..., ::2]   # even-indexed channels
    x2 = x[..., 1::2]  # odd-indexed channels
    # Stacking on a new last axis and reshaping re-interleaves the pairs.
    return torch.stack((-x2, x1), dim=-1).reshape_as(x)

  def apply_rotary(self, x, pos):
    """Rotate ``x`` by the angles cached for the positions in ``pos``.

    Args:
      x: Tensor whose last axis has size ``head_dim`` and which broadcasts
        against ``(len(pos), 1, 1, head_dim)``, e.g.
        ``(seq, batch, heads, head_dim)``.
        NOTE(review): this is a position-first layout; a
        ``(batch, seq, heads, head_dim)`` tensor would need a transpose
        first -- confirm against the calling attention code.
      pos: 1-D index (tensor, list or slice) into the cached tables; the
        indexing below requires ``cos_cached[pos]`` to be 2-D.

    Returns:
      The rotated tensor, ``x * cos + rotate_half(x) * sin``.
    """
    cos = self.cos_cached[pos][:, None, None, :]
    sin = self.sin_cached[pos][:, None, None, :]

    # Duplicate each per-pair angle so both channels of a pair share it.
    cos = torch.repeat_interleave(cos, 2, dim=-1)
    sin = torch.repeat_interleave(sin, 2, dim=-1)

    return (x * cos) + (self.rotate_half(x) * sin)

  def forward(self, x, pos):
    """Module call entry point; equivalent to ``apply_rotary(x, pos)``."""
    return self.apply_rotary(x, pos)

In [9]:
# Build the rotary embedding from the shared config. Keyword arguments make the
# mapping explicit: the config key is "max_positional_encodings" while the
# constructor parameter is named max_position_embeddings.
rope = RoPE(
    head_dim=CONFIG["head_dim"],
    max_position_embeddings=CONFIG["max_positional_encodings"],
    rope_theta=CONFIG["rope_theta"],
)