# rotary position embedding

In [1]:
!pip install torch;

Looking in indexes: http://mirrors.cloud.aliyuncs.com/pypi/simple/
[0m

## 1. math from paper
Basic equation: in 2-D, RoPE rotates the pair $(q_0, q_1)$ at position $m$ by the angle $m\theta$:
$$f(q,m)=
\begin{pmatrix}
 \cos m\theta & -\sin m\theta  \\
 \sin m\theta & \cos m\theta
\end{pmatrix}
\begin{pmatrix}
 q_{0}  \\
 q_{1} 
\end{pmatrix}
$$

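Expanding to the full $d$-dimensional case gives a block-diagonal rotation matrix with one $2\times2$ block per channel pair, where $\theta_i = 10000^{-2i/d}$ for $i = 0, \dots, d/2-1$: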

$$
f(q,m)=\begin{pmatrix}
 \cos m\theta_{0} & -\sin m\theta_{0} & 0 & 0 & \cdots & 0 & 0 \\
 \sin m\theta_{0} & \cos m\theta_{0} & 0 & 0 & \cdots & 0 & 0 \\
 0 & 0 & \cos m\theta_{1} & -\sin m\theta_{1} & \cdots & 0 & 0 \\
 0 & 0 & \sin m\theta_{1} & \cos m\theta_{1} & \cdots  & 0 & 0 \\
 \vdots &  \vdots &  \vdots &  \vdots & \ddots & \vdots & \vdots \\
 0 & 0 & 0 & 0 & \cdots & \cos m\theta_{d/2-1} & -\sin m\theta_{d/2-1} \\
 0 & 0 & 0 & 0 & \cdots & \sin m\theta_{d/2-1} & \cos m\theta_{d/2-1}\\
\end{pmatrix}
\begin{pmatrix}
 q_{0} \\
 q_{1} \\
 q_{2} \\
 q_{3} \\
 \vdots \\
 q_{d-2} \\
 q_{d-1} \\
\end{pmatrix}
$$

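Because the matrix is sparse, this can be rewritten in a computationally efficient element-wise form ($\otimes$ denotes element-wise multiplication):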

$$
f(q,m)=
\begin{pmatrix}
 q_{0} \\
 q_{1} \\
 q_{2} \\
 q_{3} \\
 \vdots \\
 q_{d-2} \\
 q_{d-1} \\
\end{pmatrix}
\otimes
\begin{pmatrix}
 \cos m\theta_{0} \\
 \cos m\theta_{0} \\
 \cos m\theta_{1} \\
 \cos m\theta_{1} \\
 \vdots \\
 \cos m\theta_{d/2-1} \\
 \cos m\theta_{d/2-1} \\
\end{pmatrix}
+
\begin{pmatrix}
 -q_{1} \\
 q_{0} \\
 -q_{3} \\
 q_{2} \\
 \vdots \\
 -q_{d-1} \\
 q_{d-2} \\
\end{pmatrix}
\otimes
\begin{pmatrix}
 \sin m\theta_{0} \\
 \sin m\theta_{0} \\
 \sin m\theta_{1} \\
 \sin m\theta_{1} \\
 \vdots \\
 \sin m\theta_{d/2-1} \\
 \sin m\theta_{d/2-1} \\
\end{pmatrix}
$$

In [2]:
import torch
import numpy as np
import math

# config
dim = 4096  # model width, also embed_size
max_seq_len = 2048  # maximum input length
bsz = 16  # batch_size
seq_len = 6  # current input length
num_heads = 32  # number of attention heads
head_dim = dim // num_heads  # 128

# RoPE rotates channel pairs, so every head needs an even width and the heads
# must tile the model width exactly; the rotary helpers below also assume the
# current input fits in the precomputed position range.
assert dim % num_heads == 0, "dim must be divisible by num_heads"
assert head_dim % 2 == 0, "head_dim must be even for rotary embedding"
assert seq_len <= max_seq_len, "seq_len must not exceed max_seq_len"

# fix the RNG so the random inputs / weights in the test cells are reproducible
torch.manual_seed(0)

## 2. Implementation in RoFormer (the original RoPE paper)

Compute the sinusoidal positions:

https://github.com/JunnYu/RoFormer_pytorch/blob/roformer_v2/src/roformer/modeling_roformer.py#L156

Apply the rotary embedding:

https://github.com/JunnYu/RoFormer_pytorch/blob/roformer_v2/src/roformer/modeling_roformer.py#L441

The code follows the math above by rotating adjacent pairs $(q_{2i}, q_{2i+1})$.

In [3]:
import torch

def roformer_apply_rotary(x):
    """Apply RoFormer-style rotary position embedding ("rotate every two").

    Mirrors JunnYu/RoFormer_pytorch: each adjacent channel pair
    ``(x[..., 2i], x[..., 2i + 1])`` at position ``m`` is rotated by the angle
    ``m * theta_i`` with ``theta_i = 10000 ** (-2i / head_dim)``, i.e. the
    block-diagonal rotation from the paper.

    Sequence length and head width are read from the last two dims of ``x``
    rather than from notebook-level globals, so any ``(..., seq_len, head_dim)``
    input works.

    Args:
        x: tensor of shape ``(..., seq_len, head_dim)`` with even ``head_dim``,
            e.g. ``(bsz, num_heads, seq_len, head_dim)``.

    Returns:
        ``(rotated, cos, sin)``: ``rotated`` has the same shape as ``x``;
        ``cos`` and ``sin`` have shape ``(1, 1, seq_len, head_dim // 2)``.

    Raises:
        ValueError: if ``head_dim`` is odd (channels cannot be paired).
    """
    seq_len, head_dim = x.shape[-2], x.shape[-1]
    if head_dim % 2:
        raise ValueError(f"head_dim must be even for rotary embedding, got {head_dim}")

    # RoFormerSinusoidalPositionalEmbedding: channels j = 2i and 2i + 1 share the
    # angle pos / 10000^(2i / head_dim). Vectorised (instead of nested list
    # comprehensions) so seq_len == 0 still yields a (0, head_dim) table.
    pos = np.arange(seq_len)[:, None]
    j = np.arange(head_dim)[None, :]
    position_enc = pos / np.power(10000, 2 * (j // 2) / head_dim)  # (seq_len, head_dim)

    # even columns -> sin, odd columns -> cos (both hold the same angles);
    # built on x's device so GPU inputs work too.
    sin = torch.tensor(np.sin(position_enc[:, 0::2]), dtype=torch.float32, device=x.device)  # (seq_len, head_dim // 2)
    cos = torch.tensor(np.cos(position_enc[:, 1::2]), dtype=torch.float32, device=x.device)  # (seq_len, head_dim // 2)

    # RoFormerSelfAttention.apply_rotary: rotate each (even, odd) pair, then
    # re-interleave. The 2-D sin/cos broadcast against any leading dims of x.
    x1, x2 = x[..., 0::2], x[..., 1::2]
    rotated = torch.stack([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1).flatten(-2, -1)

    # Upstream RoFormerEncoder.forward concatenates [sin | cos] and re-splits
    # them with chunk(2), handing attention (1, 1, seq_len, head_dim // 2) tables.
    return rotated, cos[None, None], sin[None, None]

## 3. Implementation in GPTNeo
https://github.com/EleutherAI/gpt-neo/blob/master/models/layers.py#L355

This is **different** from the paper because it uses `rotate_half` instead of `rotate_every_two`: dimension $i$ is paired with dimension $i + d/2$ rather than $2i$ with $2i+1$.

As confirmed in https://github.com/EleutherAI/gpt-neox/pull/241 and discussed in https://github.com/EleutherAI/gpt-neox/issues/834, this variant is computationally efficient and causes no loss in model performance.

## 4. Implementation in Facebook LLaMA

The LLaMA paper states that its rotary embedding is inspired by GPTNeo,

but its RoPE implementation uses a complex-number trick that actually follows the original paper (rotating adjacent pairs), NOT GPTNeo's `rotate_half`.

https://github.com/facebookresearch/llama/blob/main/llama/model.py#L63

In [4]:
def llama_apply_rotary_emb(x):
    """Apply facebookresearch/llama rotary embedding via complex multiplication.

    Each adjacent channel pair ``(x[..., 2i], x[..., 2i + 1])`` is viewed as the
    complex number ``x_{2i} + j * x_{2i+1}`` and multiplied by
    ``exp(j * m * theta_i)`` with ``theta_i = 10000 ** (-2i / head_dim)``, which
    is exactly the paper's pairwise rotation.

    Sequence length and head width are read from ``x`` instead of notebook
    globals, and only the positions actually needed are materialised.

    Args:
        x: tensor of shape ``(bsz, num_heads, seq_len, head_dim)`` with even
            ``head_dim`` (the layout used by the rest of this notebook).

    Returns:
        Tensor with the same shape and dtype as ``x`` (computed in float32, then
        cast back like upstream's ``.type_as(xq)``).
    """
    # Upstream llama keeps x as (bsz, seqlen, n_local_heads, head_dim), so move
    # into that layout first.
    x_ = x.transpose(1, 2)  # (bsz, seq_len, num_heads, head_dim)
    seq_len, head_dim = x_.shape[1], x_.shape[-1]

    # precompute_freqs_cis (upstream precomputes 2 * max_seq_len positions; the
    # first seq_len rows are identical, so only those are built here)
    theta = 10000
    freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2, device=x.device)[: (head_dim // 2)].float() / head_dim))
    t = torch.arange(seq_len, device=x.device)
    freqs = torch.outer(t, freqs).float()  # (seq_len, head_dim // 2)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64, (seq_len, head_dim // 2)

    # apply_rotary_emb: last dim [x0, x1, x2, x3, ...] -> pairs [[x0, x1], [x2, x3], ...]
    x_ = x_.float().reshape(*x_.shape[:-1], -1, 2)  # (bsz, seq_len, num_heads, head_dim // 2, 2)
    x_ = torch.view_as_complex(x_)  # (bsz, seq_len, num_heads, head_dim // 2)

    # reshape_for_broadcast: keep the seq (dim 1) and last dims, 1 elsewhere.
    # Uses the complex tensor's ndim (the original compared against the outer x).
    ndim = x_.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x_.shape[1], x_.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x_.shape)]
    freqs_cis = freqs_cis.view(*shape)  # (1, seq_len, 1, head_dim // 2)

    o = torch.view_as_real(x_ * freqs_cis)  # (bsz, seq_len, num_heads, head_dim // 2, 2)
    o = o.flatten(3)  # (bsz, seq_len, num_heads, head_dim)
    # cast back to the input dtype as upstream does, then restore the caller's layout
    return o.type_as(x).transpose(1, 2)  # (bsz, num_heads, seq_len, head_dim)


## 5. Implementation in Hugging Face transformers Llama

https://github.com/fpgaminer/transformers/blob/main/src/transformers/models/llama/modeling_llama.py

In [5]:
def transformers_apply_rotary_pos_emb(x):
    """Apply Hugging Face transformers Llama rotary embedding ("rotate_half").

    Unlike the paper, channel ``i`` is paired with channel ``i + head_dim // 2``
    (GPT-NeoX style): ``out = x * cos + rotate_half(x) * sin``, where ``cos`` and
    ``sin`` repeat the ``head_dim // 2`` frequencies twice along the last dim.

    Sequence length and head width are read from ``x`` instead of notebook
    globals, so inputs longer than the old ``max_seq_len`` cache no longer fail.

    Args:
        x: tensor of shape ``(bsz, num_heads, seq_len, head_dim)`` with even
            ``head_dim``.

    Returns:
        ``(rotated, cos, sin)``: ``rotated`` has the same shape as ``x``;
        ``cos`` and ``sin`` have shape ``(1, 1, seq_len, head_dim)``.
    """
    seq_len, head_dim = x.shape[-2], x.shape[-1]

    # LlamaModel.forward: default position ids 0 .. seq_len - 1
    position_ids = torch.arange(0, seq_len, device=x.device).unsqueeze(0)  # (1, seq_len)

    # LlamaRotaryEmbedding (upstream caches max_position_embeddings rows; only the
    # first seq_len are ever read here, so only those are built)
    inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2, device=x.device).float() / head_dim))  # (head_dim // 2,)
    t = torch.arange(seq_len, dtype=inv_freq.dtype, device=x.device)
    freqs = torch.einsum("i,j->ij", t, inv_freq)  # (seq_len, head_dim // 2)
    emb = torch.cat((freqs, freqs), dim=-1)  # (seq_len, head_dim): frequencies repeated, not interleaved
    cos_cached = emb.cos()[None, None, :, :]  # (1, 1, seq_len, head_dim)
    sin_cached = emb.sin()[None, None, :, :]
    cos = cos_cached[:, :, :seq_len, ...]
    sin = sin_cached[:, :, :seq_len, ...]

    def rotate_half(v):
        """Map [v1 | v2] (split at the middle of the last dim) to [-v2 | v1]."""
        half = v.shape[-1] // 2
        return torch.cat((-v[..., half:], v[..., :half]), dim=-1)

    # apply_rotary_pos_emb
    cos = cos.squeeze(1).squeeze(0)  # (seq_len, head_dim)
    sin = sin.squeeze(1).squeeze(0)  # (seq_len, head_dim)
    cos = cos[position_ids].unsqueeze(1)  # (1, 1, seq_len, head_dim)
    sin = sin[position_ids].unsqueeze(1)  # (1, 1, seq_len, head_dim)

    return (x * cos) + (rotate_half(x) * sin), cos, sin


## 6. test

Show that
- different RoPE implementations lead to different embedding results
- roformer = facebook llama != transformers llama / GPTNeo

In [6]:
# Feed the same projected queries through all three rotary implementations.
with torch.no_grad():
    x = torch.randn((bsz, seq_len, dim))
    q = torch.nn.Linear(dim, dim, bias=False)
    heads = q(x).view(bsz, seq_len, num_heads, head_dim)  # (bsz, seq_len, num_heads, head_dim)
    query_states = heads.transpose(1, 2)  # (bsz, num_heads, seq_len, head_dim)

    e1, cos1, sin1 = roformer_apply_rotary(query_states)
    e2 = llama_apply_rotary_emb(query_states)
    e3, cos3, sin3 = transformers_apply_rotary_pos_emb(query_states)

    # expected: e1 == e2 (both rotate adjacent pairs), e2 != e3 (rotate_half)
    for lhs, rhs, atol in ((e1, e2, 1e-5), (e2, e3, 1e-4)):
        print(torch.allclose(lhs, rhs, atol=atol))


True
False


Show that
- the sinusoidal position values are the same; only their shape/layout differs
- cos3 = [cos1, cos1], sin3 = [sin1, sin1]

In [7]:
# cos3/sin3 repeat the head_dim // 2 frequencies twice along the last dim;
# each half should match the RoFormer cos1/sin1 table.
with torch.no_grad():
    for ref, cached in ((cos1, cos3), (sin1, sin3)):
        for half in cached.chunk(2, dim=-1):
            print(torch.allclose(ref, half))

True
True
True
True


Show that

- to make `rotate_half` behave the same as the original pairwise rotation, transformers' trick is this `permute` function applied to `wq` and `wk` during weight conversion https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py#L101
- confirmed by https://discuss.huggingface.co/t/why-llama-weight-in-huggingface-need-to-do-permute-on-wq-wk/37643

- the permutation only reorders channels within each head, so q and k differ but their dot products (the attention scores) are unchanged; hence using transformers LlamaForCausalLM for inference works the same as Facebook LLaMA

In [8]:
def split_heads(states):
    """Reshape (bsz, seq_len, dim) projections into (bsz, num_heads, seq_len, head_dim)."""
    return states.view(bsz, seq_len, num_heads, head_dim).transpose(1, 2)


def attend(q_rot, k_rot, v):
    """Scaled dot-product attention; returns (softmax probabilities, attended values)."""
    logits = torch.matmul(q_rot, k_rot.transpose(2, 3)) / math.sqrt(head_dim)
    probs = torch.nn.functional.softmax(logits.float(), dim=-1)
    return probs, torch.matmul(probs, v)


def permute(w):
    """transformers' convert_llama_weights_to_hf permutation of wq / wk.

    Within every head, output rows are reordered from interleaved pairs
    (0, 1), (2, 3), ... into [even rows | odd rows], so rotate_half pairs the
    same channels that the original pairwise rotation does.
    """
    return w.view(num_heads, head_dim // 2, 2, dim).transpose(1, 2).reshape(dim, dim)


def relation(a, b, **tolerance):
    """'==' when a and b are allclose under the given tolerances, else '!='."""
    return '==' if torch.allclose(a, b, **tolerance) else '!='


with torch.no_grad():
    x = torch.randn((bsz, seq_len, dim))

    # reference path: original weights + RoFormer pairwise rotation
    wq, wk, wv = (torch.nn.Linear(dim, dim, bias=False) for _ in range(3))
    query_states, key_states, values = (split_heads(proj(x)) for proj in (wq, wk, wv))
    eq, _, _ = roformer_apply_rotary(query_states)
    ek, _, _ = roformer_apply_rotary(key_states)
    scores, output = attend(eq, ek, values)

    # transformers path: permuted wq / wk + rotate_half rotation (values shared)
    wq2 = torch.nn.Linear(dim, dim, bias=False)
    wq2.weight.copy_(permute(wq.weight))
    wk2 = torch.nn.Linear(dim, dim, bias=False)
    wk2.weight.copy_(permute(wk.weight))
    query_states2, key_states2 = split_heads(wq2(x)), split_heads(wk2(x))
    eq2, _, _ = transformers_apply_rotary_pos_emb(query_states2)
    ek2, _, _ = transformers_apply_rotary_pos_emb(key_states2)
    scores2, output2 = attend(eq2, ek2, values)

    checks = [
        ('wq {} wq2', wq.weight, wq2.weight, {}),
        ('query_states {} query_states2', query_states, query_states2, {}),
        ('key_states {} key_states2', key_states, key_states2, {}),
        ('eq {} eq2', eq, eq2, {}),
        ('ek {} ek2', ek, ek2, {}),
        ('scores {} scores2', scores, scores2, {}),
        ('output {} output2', output, output2, {'atol': 1e-6}),
    ]
    for template, lhs, rhs, tolerance in checks:
        print(template.format(relation(lhs, rhs, **tolerance)))

wq != wq2
query_states != query_states2
key_states != key_states2
eq != eq2
ek != ek2
scores == scores2
output == output2
