In [None]:
'''
(https://arxiv.org/pdf/2104.09864) [Roformer: Enhanced Transformer With Rotary
Position Embedding]

I learned how RoPE works from this article: https://huggingface.co/blog/designing-positional-encoding
It motivates how RoPE was developed particularly well. I encourage you to read it to get intuition for
things like how, where and why there's rotation happening, and the key desiderata for positional embeddings that lead naturally
first to sinusoidal embeddings, and then to RoPE after that.
    A crucial observation motivating positional embeddings is that attention is a *set* operation,
    i.e. it does not use any positional information, just pairwise comparisons between tokens
    independent of their position. So if you don't put positional information in the embeddings,
    then the model has *no way of telling apart two identical words at different places in the sequence,
    i.e. no way of exploiting context to infer semantic meaning*, and so you will *cripple the model*.


We start by implementing sinusoidal embeddings from "Attention Is All You Need" as a baseline.
These are added to the [batch_size, seq_len, hidden_dim] tensors representing tokens after
they are embedded from token ids to hidden_dim in the Transformer embedding layer.

In contrast, RoPE is applied to the (q, k) attention matrices directly (by rotating them), since only those affect inter-token
computation and thus use positional information across tokens.
'''


In [55]:
import torch 
import torch.nn as nn 
import torch.nn.functional as F

# Toy problem size: batch, sequence length, hidden width.
b = 4
s = 64
d = 768

# Random stand-in for the token embeddings, shape [b, s, d].
embeddings = torch.randn((b, s, d))

# sinusoidal embeddings, like most positional embeddings, operate on the token embeddings and not on attention's q/k (unlike RoPE)
def get_sinusoidal_embeddings(s, d, base=10_000.0):
    """Build the sinusoidal position table from "Attention Is All You Need".

    For position ``pos`` and pair index ``i``::

        PE[pos, 2i]     = sin(pos / base^(2i / d))
        PE[pos, 2i + 1] = cos(pos / base^(2i / d))

    Both members of a pair share one frequency, so each pair is a point on
    the unit circle and every entry stays within [-1, 1].

    Args:
        s: sequence length (number of positions).
        d: embedding width. Odd widths are supported; the last column is
            then a sine without a matching cosine.
        base: geometric base of the wavelength progression.

    Returns:
        A float32 tensor of shape [s, d]; add it to [b, s, d] token
        embeddings and it broadcasts over the batch dimension.
    """
    positions = torch.arange(s, dtype=torch.float32)[:, None]  # [s, 1]
    even_idx = torch.arange(0, d, 2, dtype=torch.float32)      # values 2i, [ceil(d/2)]
    inv_freq = base ** (-even_idx / d)                          # 1 / base^(2i/d)
    angles = positions * inv_freq                               # [s, ceil(d/2)]

    pos_embeddings = torch.zeros(s, d)
    pos_embeddings[:, 0::2] = angles.sin()
    # There are only d // 2 odd columns, which matters when d is odd.
    pos_embeddings[:, 1::2] = angles[:, : d // 2].cos()
    return pos_embeddings  # [s, d]

# usage: `embeddings + get_sinusoidal_embeddings(s, d)` -- the [s, d] table broadcasts over the batch dimension
# RoPE, unlike traditional positional embeddings, operates on q, k in attention directly
def add_rope_embeddings(q, k):
    """Apply rotary position embeddings (RoPE) to query and key tensors.

    Each consecutive feature pair (x[2i], x[2i+1]) at sequence position
    ``pos`` is rotated by the angle ``pos * theta_i`` with
    ``theta_i = 10000^(-2i / d)``. Because both q and k are rotated, the
    dot product of a rotated query and key depends on their relative
    offset rather than on absolute positions.

    Args:
        q: query tensor of shape [..., s, d] (e.g. [b, s, d] or [b, h, s, d]).
        k: key tensor with the same sequence length and feature width as q.

    Returns:
        Tuple ``(q_rot, k_rot)`` with the same shapes as ``q`` and ``k``.

    Raises:
        ValueError: if the feature dimension ``d`` is odd (features are
            rotated in pairs).
    """
    *_, s, d = q.shape
    if d % 2 != 0:
        raise ValueError(f"RoPE needs an even feature dimension, got d={d}")

    # theta_i = 10000^(-i / (d/2)) = 10000^(-2i / d).
    # The parentheses around (d // 2) matter: `idx / d // 2` parses as
    # `(idx / d) // 2`, which floors every exponent to 0 and collapses all
    # frequencies to 1.
    pair_idx = torch.arange(d // 2, device=q.device, dtype=torch.float32)
    thetas = 1.0 / (10000 ** (pair_idx / (d // 2)))  # [d // 2]

    # Rotation angle for every (position, pair) combination via an outer product.
    positions = torch.arange(s, device=q.device, dtype=torch.float32)
    freqs = positions[:, None] * thetas  # [s, d // 2]

    # Duplicate each angle so both members of a pair see the same sin/cos;
    # the [s, d] tables then broadcast over any leading (batch/head) dims.
    sin = freqs.sin().repeat_interleave(2, dim=-1)  # [s, d]
    cos = freqs.cos().repeat_interleave(2, dim=-1)  # [s, d]

    # rotate every pair of features by 90 degrees: (x1, x2) -> (-x2, x1)
    def _rotate_every_two(x: torch.Tensor) -> torch.Tensor:
        x1, x2 = x.reshape(*x.shape[:-1], d // 2, 2).unbind(-1)
        return torch.stack((-x2, x1), dim=-1).reshape(x.shape)

    # 2-D rotation: x * cos + rot90(x) * sin
    q_rot = (q * cos) + (_rotate_every_two(q) * sin)
    k_rot = (k * cos) + (_rotate_every_two(k) * sin)
    return q_rot, k_rot


# Smoke test: random queries/keys shaped like the token embeddings.
q = torch.randn(b, s, d)
k = torch.randn(b, s, d)
# RoPE must leave the tensor shape unchanged.
add_rope_embeddings(q, k)[0].shape  # [b, s, d]


torch.Size([4, 64, 768])