In [None]:
import math

import torch
import torch.nn as nn

class RotaryEmbedding(nn.Module):
    """Precompute and serve RoPE cosine/sine tables indexed by absolute position.

    Args:
        dim: Rotary dimension (typically the per-head dimension). Assumed even:
            ``inv_freq`` has ``ceil(dim / 2)`` entries and the tables are built by
            concatenating two copies of those frequencies.
        max_position_embeddings: Number of positions precomputed at construction.
        base: Base of the geometric frequency progression.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        # Rotation frequencies: inv_freq[i] = base ** (-2i / dim), i = 0 .. dim/2 - 1.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Precompute the cos/sin tables so forward() is only a slice.
        self._set_cos_sin_cache(max_position_embeddings)

    def _set_cos_sin_cache(self, seq_len):
        """(Re)build ``cos_cached`` / ``sin_cached`` for positions ``0 .. seq_len - 1``."""
        self.max_seq_len_cached = seq_len

        # Angle for every (position, frequency) pair: freqs[m, i] = m * inv_freq[i].
        t = torch.arange(seq_len, device=self.inv_freq.device).type_as(self.inv_freq)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)

        # Duplicate the frequencies so column i and column i + dim/2 share an angle,
        # matching rotate_half's (first half, second half) split of the last axis.
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().float())
        self.register_buffer("sin_cached", emb.sin().float())

    def forward(self, x, seq_len=None):
        """Return the ``(cos, sin)`` tables for the first ``seq_len`` positions.

        Args:
            x: Tensor laid out as [batch_size, seq_len, num_heads, head_dim]. It
                supplies the output device and, when ``seq_len`` is None, the
                sequence length.
            seq_len: Number of positions to return. Defaults to ``x.shape[1]``.

        Returns:
            Tuple ``(cos, sin)``, each of shape [seq_len, dim], on ``x.device``.
        """
        if seq_len is None:
            # Comparing None with an int below would raise TypeError, so infer
            # the length from x's sequence axis.
            seq_len = x.shape[1]
        if seq_len > self.max_seq_len_cached:
            # Grow the cache on demand for sequences longer than precomputed.
            self._set_cos_sin_cache(seq_len)

        return (
            self.cos_cached[:seq_len].to(x.device),
            self.sin_cached[:seq_len].to(x.device),
        )

def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    """Rotate ``q`` and ``k`` by the RoPE angles of the given positions.

    Args:
        q, k: Query/key tensors of shape [batch_size, seq_len, num_heads, head_dim].
        cos, sin: Lookup tables of shape [max_seq_len, head_dim], e.g. as returned
            by ``RotaryEmbedding.forward``.
        position_ids: Integer positions of shape [batch_size, seq_len] (a leading
            size of 1 also broadcasts over the batch), used to select rows of
            ``cos`` / ``sin``.

    Returns:
        ``(q_embed, k_embed)`` with the same shapes as ``q`` and ``k``.
    """
    # Gather the per-token rows, then insert a singleton heads axis so the tables
    # broadcast over num_heads: [batch_size, seq_len, 1, head_dim].
    cos = cos[position_ids].unsqueeze(2)
    sin = sin[position_ids].unsqueeze(2)

    # x * cos + rotate_half(x) * sin rotates each (x[i], x[i + head_dim/2]) pair.
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)

    return q_embed, k_embed

def rotate_half(x):
    """Map the last-axis halves (x1, x2) of ``x`` to (-x2, x1).

    The first ``x.shape[-1] // 2`` entries form x1 and the remaining ones form x2.
    """
    split_at = x.shape[-1] // 2
    first, second = x.split([split_at, x.shape[-1] - split_at], dim=-1)
    return torch.cat([second.neg(), first], dim=-1)

In [None]:
import torch 
import torch.nn as nn
class RopeEmbed(nn.Module):
    """Compute per-batch RoPE cosine/sine tables for a sequence.

    Args:
        dim: Rotary dimension (the per-head dimension). Assumed even: the tables
            are built by concatenating two copies of the ``dim / 2`` frequencies.
        base: Base of the geometric frequency progression.
    """

    def __init__(self, dim, base=10000):
        super().__init__()
        # One angle per pair of dimensions: theta_i = base ** (-2i / dim) for
        # i = 0 .. dim/2 - 1. arange(0, dim, 2) already yields the even indices 2i,
        # so no extra factor of 2 is applied to the exponent.
        theta = 1.0 / (base ** (torch.arange(0, dim, 2) / dim))
        self.register_buffer("xita", theta)

    def forward(self, hidden_states):
        """Return ``(cos, sin)`` for positions ``0 .. seq_len - 1``.

        Args:
            hidden_states: Tensor of shape [bs, seq_len, ...]. Only its batch size
                and sequence length are read.

        Returns:
            Tuple ``(cos, sin)``, each float32 of shape [bs, seq_len, dim], on the
            same device as the ``xita`` buffer.
        """
        bs, seq_len = hidden_states.shape[0], hidden_states.shape[1]
        # Build positions next to the frequency buffer so the product stays on
        # one device after the module is moved (e.g. with .to("cuda")).
        positions = torch.arange(seq_len, device=self.xita.device, dtype=torch.float32)
        # freqs[m, i] = m * theta_i  ->  [seq_len, dim/2]
        freqs = positions[:, None] * self.xita.float()[None, :]
        # Column i and column i + dim/2 share an angle, matching rotate()'s
        # (first half, second half) split of the last axis.
        emb = torch.cat((freqs, freqs), dim=-1)
        emb = emb.unsqueeze(0).expand(bs, -1, -1)
        return emb.cos(), emb.sin()

def rotate(hidden_states):
    """Return (-x2, x1), where x1 / x2 are the first / second halves of the last axis.

    The split point is ``hidden_states.shape[-1] // 2``.
    """
    width = hidden_states.shape[-1]
    pivot = width // 2
    head = torch.narrow(hidden_states, -1, 0, pivot)
    tail = torch.narrow(hidden_states, -1, pivot, width - pivot)
    return torch.cat((tail.neg(), head), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=2):
    """Apply Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): Query tensor, e.g. [bs, seq_len, heads, head_dim].
        k (`torch.Tensor`): Key tensor with the same layout as ``q``.
        cos (`torch.Tensor`): Cosine table of shape [bs, seq_len, head_dim], as
            returned by ``RopeEmbed``.
        sin (`torch.Tensor`): Sine table with the same shape as ``cos``.
        unsqueeze_dim (`int`, *optional*, defaults to 2):
            Axis at which a singleton dimension is inserted into ``cos`` and ``sin``
            so they broadcast against ``q`` and ``k``. Use 2 when q/k are
            [bs, seq_len, heads, head_dim] and 1 when they are
            [bs, heads, seq_len, head_dim].

    Returns:
        `tuple(torch.Tensor)`: ``(q_rot, k_rot)``, the rotated query and key tensors.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    # x * cos + rotate(x) * sin rotates each (x[i], x[i + head_dim/2]) pair.
    q_rot = (q * cos) + (rotate(q) * sin)
    k_rot = (k * cos) + (rotate(k) * sin)
    return q_rot, k_rot
# Smoke test: batch 2, sequence length 128, 8 heads, head_dim 128.
ropeemb = RopeEmbed(128)
cos, sin = ropeemb(torch.randn(2, 128, 128))
q, k = torch.randn(2, 128, 8, 128), torch.randn(2, 128, 8, 128)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
# RoPE only rotates coordinate pairs, so shapes and per-vector norms are unchanged.
assert q_rot.shape == q.shape and k_rot.shape == k.shape
assert torch.allclose(q_rot.norm(dim=-1), q.norm(dim=-1), atol=1e-4)
assert torch.allclose(k_rot.norm(dim=-1), k.norm(dim=-1), atol=1e-4)