# Details about RoPE embedding

- RoPE embedding needs to be applied to **every attention layer**. It's operating in the attention score space, not affecting the weighted_v
- The precompute freqs_cis should have length of $2*max\_seq\_len$, as (m-n) could range over $2*max\_seq\_len$.

# Implementation of RopE Positional Embedding

The RoPE embedding is applied for each individual head, so after spliting head


In [None]:
import torch

# 1. Create precomputed RoPE embeddings
def precompute_freqs_cis(seq_len, head_dim, base=10000.0): 
    '''Output frequency in complex space: cos + i sin
    Output dimension is [seq_len, head_dim/2]'''
    freqs = 1.0 / base ** (torch.arange(0, head_dim, 2).float() / head_dim)
    t = torch.arange(seq_len)
    freqs = torch.outer(t, freqs).float()
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis


# 2. Apply RoPE rotation
def apply_rope_rotation(xq, xk, freqs_cis):
    '''
    xq and xk has dimsion [Batch_size, Seq_len, Head_num, Head_dim],
    freqs_cis has dimsion [Seq_len, Head_dim/2]. We need to reshape freqs_cis before broadcasting
    '''

    ## Reshape xq & xk to [B,S,D/2,2]
    xq_ = xq.view(*xq.shape[:-1], -1, 2)
    xk_ = xk.view(*xq.shape[:-1], -1, 2)

    ## Convert to complex numbers [B,S,H,D/2]
    xq_ = torch.view_as_complex(xq_)
    xk_ = torch.view_as_complex(xk_)

    print(xq_.shape)
    ## Apply rotation in complex space - Complex multiplication
    # xq_ = xq_ * freqs_cis

    ## reshape freqs_cis for broadcasting
    freqs_cis = xq_ * freqs_cis.unsqueeze(0).unsqueeze(2)
    xq_ = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_ = torch.view_as_real(xk_ * freqs_cis).flatten(3)

    # usually there's xq_.astype(xq) and xk_.astype(xk) to make sure they sure they are on same data and device type 
    return xq_, xk_



# Example Usecase: RoPE + Attention

In [146]:
import torch
import math
from torch import nn

class MultiHeadAttentionWithRoPE(nn.Module):
    def __init__(self, head_num, dim, max_seq_len):
        super().__init__()
        self.head_num = head_num
        self.dim = dim
        self.head_dim = dim / head_num
        self.Wq = nn.Linear(dim, dim)
        self.Wk = nn.Linear(dim, dim)
        self.Wv = nn.Linear(dim, dim)

        self.precompute_freqs_cis = precompute_freqs_cis(max_seq_len*2, self.head_dim)

    def forward(self, x, start_pos=0):
        # int put x has dimension of [B, S, D]
        q = self.Wq(x)
        k = self.Wk(x)
        v = self.Wv(x)

        # reshape to [B, S, H, Hd]
        q = q.reshape(*q.shape[:-1], self.head_num, -1)
        k = k.reshape(*k.shape[:-1], self.head_num, -1)
        v = v.reshape(*v.shape[:-1], self.head_num, -1)

        # apply RoPE embedding
        q, k = apply_rope_rotation(q, k, self.precompute_freqs_cis[start_pos : start_pos+x.shape[1]])

        # compute attention
        q = q.reshape(q.shape[0], q.shape[2], q.shape[1], q.shape[3])  # reshape to [B, H, S, Hd]
        k = k.reshape(k.shape[0], k.shape[2], k.shape[1], k.shape[3])
        v = v.reshape(v.shape[0], v.shape[2], v.shape[1], v.shape[3])

        attention_scores = torch.matmul(q, k.transpose(2,3)) / math.sqrt(self.head_dim) # [B, H, Sq, Sk]
        attention_scores = torch.softmax(attention_scores, dim=-1)  # [B, H, Sq, Sk]

        weighted_v = torch.matmul(attention_scores, v)  # [B, H, Sq, Hd]
        weighted_v = weighted_v.reshape(weighted_v.shape[0], weighted_v.shape[2], -1)  # [B, Sq, D]
        return weighted_v



In [147]:

B, S, D = 5, 100, 64*8
H = 8 

x = torch.randn(B, S, D).float()

model = MultiHeadAttentionWithRoPE(H, D, S)
output = model(x)
print(output.shape)  # Expected output shape: [B, S, D]

torch.Size([5, 100, 8, 32])
torch.Size([5, 100, 512])
