In [None]:
import math

class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) cos/sin lookup tables with a growable cache.

    The module stores one row of cosines and one row of sines per position.
    Tables are precomputed for ``max_position_embeddings`` positions and are
    rebuilt on demand when a longer sequence is requested.

    Args:
        dim: Size of the rotated feature dimension; frequencies are built from
            ``torch.arange(0, dim, 2)``.
        max_position_embeddings: Number of positions precomputed at construction.
        base: Base of the geometric progression of rotation frequencies.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        # Inverse rotation frequencies: base ** (-j / dim) for j = 0, 2, 4, ... < dim.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Precompute the tables so that typical forward calls only slice the cache.
        self._set_cos_sin_cache(max_position_embeddings)

    def _set_cos_sin_cache(self, seq_len):
        """Build (or rebuild) the cos/sin tables for positions ``0 .. seq_len - 1``.

        The tables are created on the device of ``inv_freq`` and registered as
        the buffers ``cos_cached`` and ``sin_cached``.
        """
        self.max_seq_len_cached = seq_len

        # Angle for every (position, frequency) pair: outer product t[i] * inv_freq[j].
        t = torch.arange(seq_len, device=self.inv_freq.device).type_as(self.inv_freq)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)

        # Duplicate the angles so the table width covers both halves of the
        # feature dimension (the layout expected by a "rotate half" formulation).
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().float())
        self.register_buffer("sin_cached", emb.sin().float())

    def forward(self, x, seq_len=None):
        """Return the cos and sin tables for the first ``seq_len`` positions.

        Args:
            x: Tensor whose device the returned tables are moved to. When
                ``seq_len`` is omitted, ``x`` is assumed to be laid out as
                ``[batch_size, seq_len, num_heads, head_dim]`` and the length
                is read from ``x.shape[1]``.
            seq_len: Number of positions to return. Defaults to ``x.shape[1]``.

        Returns:
            Tuple ``(cos, sin)``; each is the first ``seq_len`` rows of the
            corresponding cached table, on ``x.device``.
        """
        # The signature advertises seq_len=None, but comparing None with an int
        # raises TypeError; fall back to the sequence axis of x instead.
        if seq_len is None:
            seq_len = x.shape[1]

        # Grow the cache when a longer sequence than previously seen is requested.
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len)

        return (
            self.cos_cached[:seq_len].to(x.device),
            self.sin_cached[:seq_len].to(x.device),
        )

def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    """Apply rotary position embeddings to query and key tensors.

    Args:
        q, k: Tensors laid out as ``[batch_size, seq_len, num_heads, head_dim]``.
        cos, sin: Lookup tables of shape ``[num_positions, head_dim]``.
        position_ids: Integer tensor of positions used to index the tables,
            normally ``[batch_size, seq_len]``. Shapes that broadcast against
            the leading axes of ``q`` (e.g. ``[1, seq_len]`` or ``[seq_len]``)
            are also accepted.

    Returns:
        Tuple ``(q_embed, k_embed)`` holding the rotated query and key.
    """
    batch_size, seq_length, _, head_dim = q.shape

    # When position_ids carries exactly one id per (batch, position) pair, keep
    # the [batch, seq] layout; otherwise keep position_ids' own shape so that a
    # shared set of positions broadcasts over the batch instead of failing.
    if position_ids.numel() == batch_size * seq_length:
        lead_shape = (batch_size, seq_length)
    else:
        lead_shape = tuple(position_ids.shape)

    # Gather the table rows for each position and insert a singleton heads
    # axis so the result broadcasts across num_heads.
    flat_ids = position_ids.reshape(-1)
    cos = cos.index_select(0, flat_ids).reshape(*lead_shape, 1, head_dim)
    sin = sin.index_select(0, flat_ids).reshape(*lead_shape, 1, head_dim)

    # Rotation expressed as x * cos + rotate_half(x) * sin.
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)

    return q_embed, k_embed

def rotate_half(x):
    """Swap the two halves of the last dimension, negating the second half.

    The last axis is split at ``x.shape[-1] // 2``; the result is the negated
    trailing part followed by the leading part.
    """
    split_at = x.shape[-1] // 2
    leading = x[..., :split_at]
    trailing = x[..., split_at:]
    return torch.cat([trailing.neg(), leading], dim=-1)