In [1]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as  plt


In [114]:
# Toy inputs: two independent sequences of n_tokens embeddings of width d_model.
# The global generator is seeded first so that these tensors (and any random
# initialisation executed afterwards) come out the same on every fresh run.
SEED = 0
torch.manual_seed(SEED)

n_tokens = 5
d_model = 128
x_m = torch.randn((n_tokens, d_model), dtype=torch.float32)  # source of the queries
x_n = torch.randn((n_tokens, d_model), dtype=torch.float32)  # source of the keys

In [115]:
# Bias-free linear projections for queries and keys, each mapping d_model -> d_model.
W_q = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
W_k = nn.Linear(in_features=d_model, out_features=d_model, bias=False)

In [116]:
# RoPE base frequencies: one rate per channel pair,
# theta_i = 10000 ** (-2 * i / d_model) for i = 0 .. d_model // 2 - 1.
half_d_model = d_model // 2
freq_range = torch.arange(half_d_model, dtype=torch.float32)  # pair index i
exponents = (-2 * freq_range) / d_model
freq_rates = torch.pow(10000, exponents)  # theta

In [117]:
# Project the token embeddings and prepend a batch axis of size 1 -> (B, N, D_MODEL).
q = W_q(x_m)[None, ...]
k = W_k(x_n)[None, ...]
bsz, seq_len, dim = q.size()

# Angle table: freqs[p, i] = p * theta_i, i.e. the outer product of the
# position indices with the per-pair rates -> (N, D_MODEL // 2).
pos = torch.arange(seq_len, dtype=torch.float32)
freqs = pos[:, None] * freq_rates[None, :]

In [118]:
# Rotation tables, one row per position and one column per channel pair.
# The cosine table must use torch.cos: with sin in both tables the transform
# applied to each pair is no longer a rotation.
sin_embd = torch.sin(freqs)
cos_embd = torch.cos(freqs)

In [130]:
# Pair up adjacent channels (2i, 2i + 1) of q and k.
# NOTE: despite the `_sin` / `_cos` suffixes these are slices of q / k
# (even-indexed and odd-indexed channels), not sin/cos tables.
q_sin, q_cos = q[..., ::2], q[..., 1::2]
k_sin, k_cos = k[..., ::2], k[..., 1::2]

# Add a leading axis so the (N, D_MODEL // 2) tables broadcast over the batch.
cos_b = cos_embd.unsqueeze(0)
sin_b = sin_embd.unsqueeze(0)

# Rotate every (even, odd) pair by its position-dependent angle:
#   x_even' = x_even * cos - x_odd * sin
#   x_odd'  = x_even * sin + x_odd * cos      <- '+', otherwise it is not a rotation
# The output is laid out as [all rotated even parts | all rotated odd parts];
# q and k use the same layout, so q_rot @ k_rot^T is unaffected by the reordering.
q_rot = torch.cat([q_sin * cos_b - q_cos * sin_b,
                   q_sin * sin_b + q_cos * cos_b],
                  dim=-1)
k_rot = torch.cat([k_sin * cos_b - k_cos * sin_b,
                   k_sin * sin_b + k_cos * cos_b],
                  dim=-1)

q_rot.size()

torch.Size([1, 5, 128])

In [97]:
import torch
import math

class RoPE:
    """Rotary position embedding (RoPE) using the first-half / second-half pairing.

    Channel ``i`` of the first half of the head dimension is paired with
    channel ``i`` of the second half, and each pair is rotated by the angle
    ``position * theta_i`` with ``theta_i = 10000 ** (-2 * i / dim)``.
    Positions always start at 0.
    """

    def __init__(self, dim):
        """
        Parameters:
            dim: size of the head dimension the embedding is built for;
                ``dim // 2`` rotation frequencies are precomputed.
        """
        self.dim = dim
        self.freqs = self._build_frequencies(dim)

    def _build_frequencies(self, dim):
        """
        Create the per-pair rotation frequencies for RoPE.
        The frequencies form a geometric sequence from 1 down towards 1/10000.

        Returns:
            1-D float32 tensor with ``dim // 2`` entries.
        """
        half_dim = dim // 2
        freq_range = torch.arange(half_dim, dtype=torch.float32)
        freq_rates = torch.pow(10000, -2 * freq_range / dim)
        return freq_rates

    @staticmethod
    def _rotate(x, sin_embed, cos_embed):
        """Rotate each (first-half, second-half) channel pair of ``x`` by the given angles."""
        half = x.shape[-1] // 2
        x1, x2 = x[..., :half], x[..., half:]
        return torch.cat([x1 * cos_embed - x2 * sin_embed,
                          x1 * sin_embed + x2 * cos_embed], dim=-1)

    def apply(self, q, k):
        """
        Apply RoPE to the given query and key tensors.

        Parameters:
            q: (batch_size, num_heads, seq_len, head_dim)
            k: (batch_size, num_heads, seq_len, head_dim)

        Returns:
            q_rot, k_rot: RoPE applied query and key tensors

        Raises:
            ValueError: if head_dim is odd or does not match the number of
                precomputed frequencies, or if k's (seq_len, head_dim) differ
                from q's. Without these checks a mismatch either fails with an
                opaque broadcasting error or silently broadcasts to a wrong result.
        """
        _, _, seq_len, head_dim = q.shape
        half_head_dim = head_dim // 2  # number of channel pairs that get rotated

        num_freqs = self.freqs.shape[0]
        if head_dim % 2 != 0 or half_head_dim != num_freqs:
            raise ValueError(
                f"head_dim={head_dim} is incompatible with RoPE(dim={self.dim}): "
                f"expected an even head_dim with {num_freqs} channel pairs"
            )
        if tuple(k.shape[-2:]) != (seq_len, head_dim):
            raise ValueError(
                f"k must share q's (seq_len, head_dim)=({seq_len}, {head_dim}), "
                f"got {tuple(k.shape[-2:])}"
            )

        # theta[p, i] = p * theta_i  -> (seq_len, half_head_dim)
        position_ids = torch.arange(seq_len, dtype=torch.float32, device=q.device)
        theta = position_ids[:, None] * self.freqs.to(q.device)[None, :]

        # Leading singleton axes broadcast over batch and heads: (1, 1, seq_len, half_head_dim)
        sin_embed = torch.sin(theta)[None, None]
        cos_embed = torch.cos(theta)[None, None]

        q_rot = self._rotate(q, sin_embed, cos_embed)
        k_rot = self._rotate(k, sin_embed, cos_embed)

        return q_rot, k_rot

# Example usage: rotate random queries / keys shaped
# (batch_size, num_heads, seq_len, head_dim).
batch_size, num_heads = 32, 8
seq_len, head_dim = 128, 64

qk_shape = (batch_size, num_heads, seq_len, head_dim)
q = torch.randn(qk_shape)
k = torch.randn(qk_shape)

rope = RoPE(head_dim)
q_rot, k_rot = rope.apply(q, k)
