In [2]:
import torch
from typing import Tuple

# Positional Encoding

## 绝对位置编码

为每个绝对位置分配一个位置编码向量，可以是手工设计的，也可以是可学习的。
example:
- Attention is all you need 中使用正余弦函数来计算位置编码
- BERT 中使用可学习的绝对位置编码。

## 相对位置编码
我们期望相对位置编码有这样的性质：对于查询 $q$ 和键 $k$，以及位置编码函数 $f$，希望 $\langle f(q, n), f(k, m) \rangle = g(q, k, m - n)$，即内积只通过相对距离 $m - n$ 依赖于位置。（因为内积是注意力计算中的核心运算）
通过修改 attention 的计算过程可以实现这一点。

## 旋转位置编码

用绝对位置编码的方式实现了相对位置编码，即对于每个绝对位置有一个位置编码向量，并且加入位置编码信息之后的向量内积直接和相对距离相关。

下面给出 llama 的旋转位置编码实现，主要的逻辑是：
- 预计算旋转矩阵 $M$, $M_{tj} = e^{i(freq[j]t)}$ shape: (s, d // 2)
- 将输入在最后一个维度两两成组看作复数 shape: (b, s, n, d) -> (b, s, n, d // 2)
- 调整矩阵到合适广播的形状 shape: (1, s, 1, d // 2), 也就是第 t 个绝对位置的向量的第 j 个组的复数要和旋转矩阵 (t, j) 位置元素相乘
- 旋转矩阵和输入的复数表示相乘（应用旋转位置编码）
- 调整输出到实数表示形式并调整形状

In [3]:
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """
    Build the table of unit-magnitude complex rotation factors used by rotary embeddings.

    Entry ``[t, j]`` of the result equals ``exp(1j * t * inv_freq[j])`` with
    ``inv_freq[j] = theta ** (-(2 * j) / dim)``, i.e. position ``t`` rotates the
    j-th feature pair by the angle ``t * inv_freq[j]``.

    Args:
        dim (int): Size of the feature dimension being rotated; one frequency is
            produced for each of its ``dim // 2`` feature pairs.
        end (int): Number of positions to tabulate (positions ``0 .. end - 1``).
        theta (float, optional): Base of the geometric frequency progression.
            Defaults to 10000.0.

    Returns:
        torch.Tensor: complex64 tensor of shape ``(end, dim // 2)``.
    """
    pair_count = dim // 2
    # Exponents 0/dim, 2/dim, 4/dim, ... -- one per feature pair.
    exponents = torch.arange(0, dim, 2)[:pair_count].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=inv_freq.device)  # type: ignore
    # angles[t, j] = t * inv_freq[j]
    angles = torch.outer(positions, inv_freq).float()  # type: ignore
    # Magnitude 1 and phase `angles` -> cos(angles) + 1j * sin(angles), complex64.
    unit_magnitude = torch.ones_like(angles)
    return torch.polar(unit_magnitude, angles)


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """
    Reshape frequency tensor for broadcasting it with another tensor.

    This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
    for the purpose of broadcasting the frequency tensor during element-wise operations.

    Args:
        freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
        x (torch.Tensor): Target tensor for broadcasting compatibility.

    Returns:
        torch.Tensor: Reshaped frequency tensor.

    Raises:
        AssertionError: If the frequency tensor doesn't match the expected shape.
        AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
    """
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary position embeddings to query and key tensors.

    Consecutive feature pairs on the last axis of ``xq`` / ``xk`` are reinterpreted as
    complex numbers, multiplied element-wise by the matching entries of ``freqs_cis``
    (broadcast over every axis except axis 1 and the last one), and converted back to
    a real tensor with the same shape and dtype as the corresponding input.

    Args:
        xq (torch.Tensor): Query tensor. Axis 1 is the position axis and the last axis
            is the feature axis, whose size must be even. Any rank >= 2 is accepted,
            e.g. ``(batch, seq, heads, head_dim)``.
        xk (torch.Tensor): Key tensor with the same layout; it must be broadcastable
            against the table reshaped for ``xq``.
        freqs_cis (torch.Tensor): Complex rotation table of shape
            ``(xq.shape[1], xq.shape[-1] // 2)``.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Rotated query and key tensors, with the
        shapes and dtypes of ``xq`` and ``xk`` respectively.

    Raises:
        AssertionError: Propagated from ``reshape_for_broadcast`` when ``freqs_cis``
            does not match the position/feature sizes of ``xq``.
    """
    # (..., d) -> (..., d // 2, 2) -> complex (..., d // 2); computed in float32.
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    # view_as_real appends a trailing axis of size 2; flatten(-2) merges it back into
    # the feature axis for any input rank (a fixed start_dim of 3 is only correct for
    # 4-D inputs).
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2)
    # Cast back from float32 to each input's original dtype.
    return xq_out.type_as(xq), xk_out.type_as(xk)