# LLaMA
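LLaMA (Large Language Model Meta AI) is a family of foundation language models with 7B to 65B parameters. The models are comparatively small yet perform strongly, which lowers the compute and resources needed to try new approaches, validate other people's work, and explore new use cases.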

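LLaMA was trained on the order of a trillion tokens (1.0T to 1.4T depending on model size, according to the LLaMA paper) and shows that state-of-the-art (SOTA) models can be trained on publicly available datasets alone. The training data is large-scale unlabeled text with the following mixture: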
- 67.0% CommonCrawl
- 15.0% C4
- 4.5% GitHub
- 4.5% Wikipedia
- 4.5% Books
- 2.5% ArXiv
- 2.0% StackExchange

Trained on this diverse mixture, `LLaMA-13B` outperforms GPT-3 (175B) on most benchmarks, and `LLaMA-65B` is competitive with the best models, `Chinchilla-70B` and `PaLM-540B`.

1. Official documentation: [Llama documentation](https://www.llama.com/docs/overview)
2. Hugging Face documentation: [LLaMA](https://huggingface.co/docs/transformers/main/en/model_doc/llama)

In [1]:
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.nn import functional as F
from typing_extensions import Self


def find_multiple(n: int, k: int) -> int:
    if n % k == 0:
        return n
    return n + k - (n % k)

MaskCache = torch.Tensor
RoPECache = torch.Tensor
KVCache = Tuple[torch.Tensor, torch.Tensor]
llama_configs = {
    "0B": dict(n_layer=2, n_head=4, n_embd=128),
    "7B": dict(n_layer=32, n_head=32, n_embd=4096),
    "13B": dict(n_layer=40, n_head=40, n_embd=5120),
    "30B": dict(n_layer=60, n_head=52, n_embd=6656),
    "65B": dict(n_layer=80, n_head=64, n_embd=8192),
}

@dataclass
class LLaMAConfig:
    block_size: int = 2048
    vocab_size: int = 32000
    padded_vocab_size: Optional[int] = None
    n_layer: int = 32
    n_head: int = 32
    n_embd: int = 4096

    def __post_init__(self):
        if self.padded_vocab_size is None:
            self.padded_vocab_size = find_multiple(self.vocab_size, 64)

    @classmethod
    def from_name(cls, name: str) -> Self:
        return cls(**llama_configs[name])


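## Model architecture
`LLaMA` is based on the `Decoder` part of the `Transformer` and makes the following changes:
- Pre-Normalization, with `RMSNorm` as the normalization function
- `SwiGLU` as the activation function
- Rotary position embeddings (`RoPE`)

The model architecture is shown in the figure below: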
<p align="center">
    <img src="./_img/llama_arch.png" width="65%"/>
</p>

The three changes are described in detail below:


### Pre-Normalization
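Pre-Normalization applies the normalization in front of each sublayer. The original `Transformer` uses `Post-Normalization`, i.e. $\text{Norm}(x + \text{Sublayer}(x))$, whereas `Pre-Normalization` applies the normalization before each sublayer (the attention layer and the MLP layer):
$$\text{PreNorm}(x) = x + \text{Sublayer}(\text{Norm}(x))$$
The figure below compares where the different `Normalization` variants sit in the model: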
<p align="center">
    <img src="./_img/diff_norm.png" width="40%"/>
Position of the normalization module in the different variants
</p>

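Compared with `Post-Normalization`, `Pre-Normalization` places the normalization in front of each sublayer, so only the sublayer's input is normalized and the residual path stays an identity mapping. This helps against **exploding gradients** and **vanishing gradients**. Models with `Pre-Normalization` train more stably, although they are often reported to reach slightly lower quality than models with `Post-Normalization`.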
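
The two orderings can be sketched side by side. This is a minimal illustration rather than the notebook's own code: the class names `PostNormResidual` and `PreNormResidual` are hypothetical, `nn.LayerNorm` stands in for whichever normalization is used, and a single `nn.Linear` stands in for the attention or MLP sublayer.

```python
import torch
import torch.nn as nn


class PostNormResidual(nn.Module):
    """Original Transformer ordering: y = Norm(x + Sublayer(x))."""

    def __init__(self, dim: int, sublayer: nn.Module) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.sublayer = sublayer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.norm(x + self.sublayer(x))


class PreNormResidual(nn.Module):
    """Pre-Normalization ordering: y = x + Sublayer(Norm(x))."""

    def __init__(self, dim: int, sublayer: nn.Module) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.sublayer = sublayer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.sublayer(self.norm(x))


torch.manual_seed(0)
dim = 8
x = torch.randn(2, 4, dim)
pre = PreNormResidual(dim, nn.Linear(dim, dim))    # stand-in sublayer (assumption)
post = PostNormResidual(dim, nn.Linear(dim, dim))
y_pre, y_post = pre(x), post(x)

# In the pre-norm block the residual path is the identity,
# so y_pre - x is exactly what the sublayer produced.
residual = y_pre - x
```

In the pre-norm variant the input `x` reaches the output untouched, which is the property usually credited for the more stable training.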

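In addition, compared with the conventional `LayerNorm`, `RMSNorm` keeps training stable and convergence fast while being cheaper to compute.

`RMSNorm` is computed as:
$$\text{RMSNorm}(x) = \frac{x}{\sqrt{\text{Mean}(x^2) + \epsilon}} \cdot \gamma$$

$$\text{Mean}(x^2) = \frac{1}{N} \sum_{i=1}^{N} x_i^2$$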

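RMSNorm can be cheaper because its authors found that the benefit of LayerNorm comes from rescaling invariance (the normalization adapts to the scale of the input, so the network is insensitive to such scaling) rather than from recentering invariance (if the mean of the input shifts while the shape and range of the distribution stay the same, the output of a recentering-invariant function does not change). Based on this finding they dropped the mean computation from the normalization, which makes the algorithm simpler without hurting quality and noticeably cheaper to run.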
<p align="center">
    <img src="./_img/LayerNorm_comp_RMSNorm.png" width="80%"/> <br>
    Difference between the equations of layer normalization (LayerNorm) and root mean square normalization (RMSNorm)
</p>

For more on `RMSNorm`, see [RMSNorm.ipynb](../../1-LLM-base-fundamentals/1.3-模型核心组件/归一化方法/RMSNorm.ipynb)
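
The two invariances can be checked numerically. The sketch below uses two hypothetical helper functions, `rms_norm` and `layer_norm`, written without learnable parameters (γ fixed to 1, no bias); they are illustrations, not the `RMSNorm` module used in this notebook.

```python
import torch


def rms_norm(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # x / sqrt(Mean(x^2) + eps), with gamma fixed to 1.
    return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + eps)


def layer_norm(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # LayerNorm without affine parameters: subtract the mean, divide by the std.
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, keepdim=True, unbiased=False)
    return (x - mean) * torch.rsqrt(var + eps)


torch.manual_seed(0)
x = torch.randn(4, 16)

# Rescaling: both normalizations give (almost) the same output for x and 10 * x.
rms_scale_gap = (rms_norm(x) - rms_norm(10 * x)).abs().max().item()
ln_scale_gap = (layer_norm(x) - layer_norm(10 * x)).abs().max().item()

# Recentering: only LayerNorm is unaffected when a constant is added to x.
rms_shift_gap = (rms_norm(x) - rms_norm(x + 5.0)).abs().max().item()
ln_shift_gap = (layer_norm(x) - layer_norm(x + 5.0)).abs().max().item()
```

The scale gaps are tiny for both functions (they are not exactly zero only because of `eps` and rounding), while the shift gap is large for `rms_norm` and negligible for `layer_norm`. RMSNorm therefore keeps the rescaling invariance and gives up the recentering invariance.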

In [2]:
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.

    Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
    https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.
    """

    def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None:
        super().__init__()
        self.scale = nn.Parameter(torch.ones(size))
        self.eps = eps
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # NOTE: the original RMSNorm paper implementation is not equivalent
        # norm_x = x.norm(2, dim=self.dim, keepdim=True)
        # rms_x = norm_x * d_x ** (-1. / 2)
        # x_normed = x / (rms_x + self.eps)
        norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)
        x_normed = x * torch.rsqrt(norm_x + self.eps)
        return self.scale * x_normed

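### SwiGLU activation
`LLaMA` uses an `FFN` with the `SwiGLU` activation in its fully connected layers. `SwiGLU` is defined as:
$$\text{Swish}_\beta(x) = x \cdot \text{sigmoid}(\beta x)$$
$$\text{SwiGLU}(x) = \text{Swish}_1(W^G x) \odot (W^U x)$$

This change is intended to improve model quality over the `ReLU` activation of the original `Transformer`. The core differences between the two are:
- `ReLU` sets every negative input to zero and leaves positive inputs unchanged.
- **`Swish` has a parameter β that controls how sharply the function interpolates.** As β grows, `Swish` approaches `ReLU`. `LLaMA` uses β = 1, i.e. `SiLU` (`F.silu` in the `MLP` code of this notebook), so this implementation has no learnable β.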


<p align="center">
    <img src="./_img/SwiGLU.png" width="57%"/>
    <img src="./_img/regeswish.png" width="41.5%"/>
Behavior of ReLU and SwiGLU for different values of β (at β = 100 the two curves almost coincide), next to a comparison of several activation functions.
</p>
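
The effect of β can be reproduced with a few lines. This is a small numerical sketch; `swish` is a hypothetical helper that implements the general form x · sigmoid(βx), not a function from the notebook.

```python
import torch
import torch.nn.functional as F


def swish(x: torch.Tensor, beta: float = 1.0) -> torch.Tensor:
    # General Swish: x * sigmoid(beta * x). beta = 1 is SiLU.
    return x * torch.sigmoid(beta * x)


x = torch.linspace(-5.0, 5.0, steps=101)

# beta = 1 reproduces F.silu, the activation applied in the notebook's MLP.
silu_gap = (swish(x, beta=1.0) - F.silu(x)).abs().max().item()

# The maximum distance to ReLU on this grid shrinks as beta grows.
gaps = {
    beta: (swish(x, beta) - F.relu(x)).abs().max().item()
    for beta in (1.0, 10.0, 100.0)
}
```

The values in `gaps` decrease from β = 1 to β = 100, which matches the curves in the figure.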


In [3]:
class MLP(nn.Module):
    def __init__(self, config: LLaMAConfig) -> None:
        super().__init__()
        hidden_dim = 4 * config.n_embd
        n_hidden = int(2 * hidden_dim / 3)
        n_hidden = find_multiple(n_hidden, 256)

        self.c_fc1 = nn.Linear(config.n_embd, n_hidden, bias=False)
        self.c_fc2 = nn.Linear(config.n_embd, n_hidden, bias=False)
        self.c_proj = nn.Linear(n_hidden, config.n_embd, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.silu(self.c_fc1(x)) * self.c_fc2(x)
        x = self.c_proj(x)
        return x

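### Rotary Positional Embeddings
For position encoding, `LLaMA` replaces the original absolute position encoding with Rotary Positional Embeddings (RoPE). RoPE builds on complex-number rotations: it encodes the absolute position of each token in such a way that the attention score depends only on the relative position.

For a detailed explanation of `RoPE`, see [RoPE.ipynb](../../1-LLM-base-fundamentals/1.3-模型核心组件/位置编码/RoPE.ipynb)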
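
The "absolute in, relative out" property can be verified directly. The sketch below uses a hypothetical helper `rotate` that rotates one vector for one position; it pairs neighboring elements the same way as the notebook's `apply_rope`, but it is a simplified stand-in and uses `float64` only to keep the comparison tight.

```python
import torch


def rotate(x: torch.Tensor, pos: int, base: float = 10000.0) -> torch.Tensor:
    """Rotate the pairs (x[0], x[1]), (x[2], x[3]), ... by the angle pos * theta_i."""
    d = x.shape[-1]
    theta = 1.0 / (base ** (torch.arange(0, d, 2, dtype=torch.float64) / d))
    cos, sin = torch.cos(pos * theta), torch.sin(pos * theta)
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    out = torch.stack([x_even * cos - x_odd * sin, x_odd * cos + x_even * sin], dim=-1)
    return out.flatten(-2)


torch.manual_seed(0)
q = torch.randn(8, dtype=torch.float64)
k = torch.randn(8, dtype=torch.float64)

# Both position pairs have the same offset m - n = 3.
score_a = torch.dot(rotate(q, 5), rotate(k, 2))
score_b = torch.dot(rotate(q, 105), rotate(k, 102))

# A different offset generally gives a different score.
score_c = torch.dot(rotate(q, 5), rotate(k, 4))
```

`score_a` and `score_b` agree up to rounding even though the absolute positions differ by 100, because each rotation is orthogonal and only the angle difference survives in the dot product.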

In [4]:
def apply_rope(x: torch.Tensor, rope_cache: RoPECache) -> torch.Tensor:
    # truncate to support variable sizes
    T = x.size(1)
    rope_cache = rope_cache[:T]

    # cast because the reference does
    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
    rope_cache = rope_cache.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
            xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
        ],
        -1,
    )

    x_out2 = x_out2.flatten(3)
    return x_out2.type_as(x)

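## Attention module
`Llama` also uses masked (causal) attention. The implementation is close to the `CausalSelfAttention` of the `GPT` notebook covered earlier; the main difference is the addition of `RoPECache`, `MaskCache` and `KVCache`, which avoid recomputing position encodings, masks, and past keys/values during training and inference. The point of reading this code is therefore to see how the caches are used.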
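
The attention code in this notebook calls `F.scaled_dot_product_attention` with a boolean mask and keeps the manual formulation only as a comment. As a sanity check, the two can be compared on small random tensors. The shapes below are arbitrary assumptions, and the call requires PyTorch 2.0 or later.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, nh, T, hs = 2, 4, 6, 8  # batch, heads, sequence length, head size (arbitrary)
q, k, v = (torch.randn(B, nh, T, hs) for _ in range(3))

# Lower-triangular boolean mask: True means "may attend".
mask = torch.tril(torch.ones(T, T, dtype=torch.bool)).view(1, 1, T, T)

# Manual masked attention: softmax(q k^T / sqrt(hs)) v.
att = (q @ k.transpose(-2, -1)) / math.sqrt(hs)
att = att.masked_fill(~mask, float("-inf"))
y_manual = F.softmax(att, dim=-1) @ v

# Fused version, called the same way as in CausalSelfAttention.
y_fused = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
```

Both results have shape `(B, nh, T, hs)` and match up to floating-point tolerance, so the fused call is a drop-in replacement for the commented-out lines.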

In [5]:
class CausalSelfAttention(nn.Module):
    def __init__(self, config: LLaMAConfig) -> None:
        super().__init__()
        assert config.n_embd % config.n_head == 0

        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)

        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.block_size = config.block_size

    def forward(
        self,
        x: torch.Tensor,
        rope: RoPECache,
        mask: MaskCache,
        max_seq_length: int,
        input_pos: Optional[torch.Tensor] = None,
        kv_cache: Optional[KVCache] = None,
    ) -> Tuple[torch.Tensor, Optional[KVCache]]:
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)

        head_size = C // self.n_head
        k = k.view(B, T, self.n_head, head_size)
        q = q.view(B, T, self.n_head, head_size)
        v = v.view(B, T, self.n_head, head_size)

        q = apply_rope(q, rope)
        k = apply_rope(k, rope)

        k = k.transpose(1, 2)  # (B, nh, T, hs)
        q = q.transpose(1, 2)  # (B, nh, T, hs)
        v = v.transpose(1, 2)  # (B, nh, T, hs)

        if kv_cache is not None:
            cache_k, cache_v = kv_cache
            # check if reached token limit
            if input_pos[-1] >= max_seq_length:
                input_pos = torch.tensor(max_seq_length - 1, device=input_pos.device)
                # shift 1 position to the left
                cache_k = torch.roll(cache_k, -1, dims=2)
                cache_v = torch.roll(cache_v, -1, dims=2)
            k = cache_k.index_copy(2, input_pos, k)
            v = cache_v.index_copy(2, input_pos, v)
            kv_cache = k, v

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        #  att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        #  att = att.masked_fill(mask[:,:,:T,:T] == 0, float('-inf'))
        #  att = F.softmax(att, dim=-1)
        #  y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)

        # efficient attention using Flash Attention CUDA kernels
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side

        # output projection
        y = self.c_proj(y)

        return y, kv_cache

In [6]:
def build_rope_cache(
    seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000
) -> RoPECache:
    """Enhanced Transformer with Rotary Position Embedding.

    Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
    transformers/rope/__init__.py. MIT License:
    https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
    """
    # $\Theta = \{\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]\}$
    theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem))

    # Create position indexes `[0, 1, ..., seq_len - 1]`
    seq_idx = torch.arange(seq_len, dtype=dtype, device=device)

    # Calculate the product of position index and $\theta_i$
    idx_theta = torch.outer(seq_idx, theta).float()

    cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)

    # this is to mimic the behaviour of complex32, else we will get different results
    if dtype in (torch.float16, torch.bfloat16, torch.int8):
        cache = cache.half()
    return cache

## Block
The following code builds one `Llama` block, i.e. the implementation of a single `DecoderLayer`.

In [7]:
class Block(nn.Module):
    def __init__(self, config: LLaMAConfig) -> None:
        super().__init__()
        self.rms_1 = RMSNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.rms_2 = RMSNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(
        self,
        x: torch.Tensor,
        rope: RoPECache,
        mask: MaskCache,
        max_seq_length: int,
        input_pos: Optional[torch.Tensor] = None,
        kv_cache: Optional[KVCache] = None,
    ) -> Tuple[torch.Tensor, Optional[KVCache]]:
        h, new_kv_cache = self.attn(self.rms_1(x), rope, mask, max_seq_length, input_pos, kv_cache)
        x = x + h
        x = x + self.mlp(self.rms_2(x))
        return x, new_kv_cache

## LLaMA
The following is the implementation of the `LLaMA` model.

In [8]:

class LLaMA(nn.Module):
    def __init__(self, config: LLaMAConfig) -> None:
        super().__init__()
        assert config.padded_vocab_size is not None
        self.config = config

        self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=False)
        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
                h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),
                ln_f=RMSNorm(config.n_embd),
            )
        )

        self.rope_cache: Optional[RoPECache] = None
        self.mask_cache: Optional[MaskCache] = None
        self.kv_caches: List[KVCache] = []

    def _init_weights(self, module: nn.Module) -> None:
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer))
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer))

    def forward(
        self, idx: torch.Tensor, max_seq_length: Optional[int] = None, input_pos: Optional[torch.Tensor] = None
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[KVCache]]]:
        B, T = idx.size()

        block_size = self.config.block_size
        if max_seq_length is None:
            max_seq_length = block_size
        assert T <= max_seq_length, f"Cannot forward sequence of length {T}, max seq length is only {max_seq_length}"
        assert max_seq_length <= block_size, f"Cannot attend to {max_seq_length}, block size is only {block_size}"
        assert T <= block_size, f"Cannot forward sequence of length {T}, block size is only {block_size}"

        if self.rope_cache is None:
            self.rope_cache = self.build_rope_cache(idx)
        if self.mask_cache is None:
            self.mask_cache = self.build_mask_cache(idx)

        if input_pos is not None:
            rope = self.rope_cache.index_select(0, input_pos)
            mask = self.mask_cache.index_select(2, input_pos)
            mask = mask[:, :, :, :max_seq_length]
        else:
            rope = self.rope_cache[:T]
            mask = self.mask_cache[:, :, :T, :T]

        # forward the model itself
        x = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)

        if input_pos is None:  # proxy for use_cache=False
            for block in self.transformer.h:
                x, _ = block(x, rope, mask, max_seq_length)
        else:
            if not self.kv_caches:
                head_size = self.config.n_embd // self.config.n_head
                cache_shape = (B, self.config.n_head, max_seq_length, head_size)
                self.kv_caches = [
                    (torch.zeros(cache_shape, device=x.device, dtype=x.dtype), torch.zeros(cache_shape, device=x.device, dtype=x.dtype))
                    for _ in range(self.config.n_layer)
                ]
            for i, block in enumerate(self.transformer.h):
                x, self.kv_caches[i] = block(x, rope, mask, max_seq_length, input_pos, self.kv_caches[i])

        x = self.transformer.ln_f(x)

        logits = self.lm_head(x)  # (b, t, vocab_size)

        return logits

    @classmethod
    def from_name(cls, name: str) -> Self:
        return cls(LLaMAConfig.from_name(name))

    def build_rope_cache(self, idx: torch.Tensor) -> RoPECache:
        return build_rope_cache(
            seq_len=self.config.block_size,
            n_elem=self.config.n_embd // self.config.n_head,
            dtype=idx.dtype,
            device=idx.device,
        )

    def build_mask_cache(self, idx: torch.Tensor) -> MaskCache:
        ones = torch.ones((self.config.block_size, self.config.block_size), device=idx.device, dtype=torch.bool)
        return torch.tril(ones).unsqueeze(0).unsqueeze(0)

    def reset_cache(self) -> None:
        self.kv_caches.clear()
        if self.mask_cache.device.type == "xla":
            # https://github.com/Lightning-AI/lit-parrot/pull/83#issuecomment-1558150179
            self.rope_cache = None
            self.mask_cache = None

## Running a test
First, run the `LLaMA` model once.

In [9]:
block_size=1024
vocab_size=32000

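Load the model configuration. Models of `7B` and above have many layers, so initialization and inference are slow; a custom `0B` configuration with only 2 layers is used here instead.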
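
To get a feeling for the size difference without instantiating the large models, the parameter count can be derived from the layer shapes used in this notebook. The function `count_params` below is a hypothetical helper; it assumes the layer layout of the `Block`, `MLP` and `LLaMA` classes in this notebook (no biases, untied `wte` and `lm_head`).

```python
def find_multiple(n: int, k: int) -> int:
    # Same behavior as the helper in the first code cell: round n up to a multiple of k.
    return n if n % k == 0 else n + k - (n % k)


def count_params(n_layer: int, n_embd: int, vocab_size: int = 32000) -> int:
    padded_vocab = find_multiple(vocab_size, 64)
    n_hidden = find_multiple(int(2 * (4 * n_embd) / 3), 256)
    attn = 3 * n_embd * n_embd + n_embd * n_embd  # c_attn + c_proj
    mlp = 3 * n_embd * n_hidden                   # c_fc1 + c_fc2 + c_proj
    norms = 2 * n_embd                            # rms_1 + rms_2
    embeddings = 2 * padded_vocab * n_embd        # wte + lm_head (not tied)
    final_norm = n_embd                           # ln_f
    return n_layer * (attn + mlp + norms) + embeddings + final_norm


sizes = {
    "0B": count_params(n_layer=2, n_embd=128),
    "7B": count_params(n_layer=32, n_embd=4096),
    "13B": count_params(n_layer=40, n_embd=5120),
}
```

Under these assumptions the `0B` configuration has about 8.7 million parameters and the `7B` configuration about 6.74 billion, which is why the tiny model is far more convenient for a walkthrough.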

In [10]:
config = LLaMAConfig.from_name("0B")  # n_layer=2, n_head=4, n_embd=128
config.block_size = block_size
print("0B config", config)

0B config LLaMAConfig(block_size=1024, vocab_size=32000, padded_vocab_size=32000, n_layer=2, n_head=4, n_embd=128)


In [11]:
model = LLaMA(config)
print(model)

LLaMA(
  (lm_head): Linear(in_features=128, out_features=32000, bias=False)
  (transformer): ModuleDict(
    (wte): Embedding(32000, 128)
    (h): ModuleList(
      (0-1): 2 x Block(
        (rms_1): RMSNorm()
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=128, out_features=384, bias=False)
          (c_proj): Linear(in_features=128, out_features=128, bias=False)
        )
        (rms_2): RMSNorm()
        (mlp): MLP(
          (c_fc1): Linear(in_features=128, out_features=512, bias=False)
          (c_fc2): Linear(in_features=128, out_features=512, bias=False)
          (c_proj): Linear(in_features=512, out_features=128, bias=False)
        )
      )
    )
    (ln_f): RMSNorm()
  )
)


## Test data
Generate random test data. `input` is the tokenized sequence; each value is a `token id`.

In [12]:
input = torch.randint(0,vocab_size,(16,1024))
target = torch.randint(0,vocab_size,(16,1024))

print(input.shape)
print(target.shape)
print("batch size: {}, token length: {}".format(input.shape[0], input.shape[1]))

torch.Size([16, 1024])
torch.Size([16, 1024])
batch size: 16, token length: 1024


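Run one forward pass and compute the loss. The output is a tensor of logits over the vocabulary (a softmax turns them into a probability distribution), the same as described in [GPT](../../2-主流模型架构/2.1-GPT系列/nanoGPT-run.ipynb).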

In [13]:
logits = model(input)
print(f"logits.shape:{logits.shape}----Vocab size: {config.vocab_size}")
loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), target.view(-1), ignore_index=-1)
print(f"Loss: {loss.item()}")

logits.shape:torch.Size([16, 1024, 32000])----Vocab size: 32000
Loss: 10.540298461914062
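
A loss of about 10.54 is roughly what an untrained model should produce: a model that assigns equal probability to every token has a cross-entropy of ln(vocab_size). The following sketch checks this reference value with all-zero logits; the batch of 4 targets is an arbitrary assumption.

```python
import math

import torch
import torch.nn.functional as F

vocab_size = 32000

# Uniform prediction: identical logits for every token in the vocabulary.
uniform_logits = torch.zeros(4, vocab_size)
targets = torch.randint(0, vocab_size, (4,))

uniform_loss = F.cross_entropy(uniform_logits, targets).item()
expected = math.log(vocab_size)  # ln(32000), about 10.37
```

The randomly initialized model lands slightly above this value because its logits are random rather than perfectly uniform.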


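## `forward` in detail
This section walks through the forward pass of the `forward` function line by line.

Input parameters:
- `idx: torch.Tensor`: the input token tensor with shape `[batch_size, token_length]`.
- `max_seq_length: Optional[int] = None`: optional maximum sequence length. If it is not given, `block_size` from the configuration is used.
- `input_pos: Optional[torch.Tensor] = None`: optional tensor of position indices. If given, it should have shape `(token_length,)` and holds the position index of each token in the input sequence.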



In [28]:
idx = input
max_seq_length = None
input_pos = None

# Re-create the model so that all caches are empty
model = LLaMA(config)

The input has `batch_size=16` and `token_length=1024`. The function first checks that the input length does not exceed `max_seq_length`; the sizes must satisfy `T <= max_seq_length <= block_size`.

In [21]:
B, T = idx.size()
print("batch size: {}, token length: {}".format(B, T))

block_size = model.config.block_size
if max_seq_length is None:
    max_seq_length = block_size
assert T <= max_seq_length, f"Cannot forward sequence of length {T}, max seq length is only {max_seq_length}"
assert max_seq_length <= block_size, f"Cannot attend to {max_seq_length}, block size is only {block_size}"
assert T <= block_size, f"Cannot forward sequence of length {T}, block size is only {block_size}"

batch size: 16, token length: 1024


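### Building the RoPE cache and the mask cache

The following is a line-by-line reading of `build_rope_cache`.
Input parameters:
- seq_len (int): sequence length.
- n_elem (int): number of elements per position, here the size of one attention head (`n_embd // n_head`).
- dtype (torch.dtype): data type.
- device (torch.device): device (CPU or GPU).
- base (int): base used to compute the frequencies, default 10000.

`theta` is computed as:
$\Theta = \{\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]\}$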


In [22]:
def build_rope_cache(
    seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000
) -> RoPECache:
    # $\Theta = \{\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]\}$
    # Compute the frequencies theta. torch.arange(0, n_elem, 2, ...) yields 0, 2, ..., n_elem - 2;
    # these are divided by n_elem, used as exponents of base, and the reciprocal is taken.
    theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem))

    # Create position indexes `[0, 1, ..., seq_len - 1]`
    seq_idx = torch.arange(seq_len, dtype=dtype, device=device)

    # Calculate the product of position index and $\theta_i$
    # idx_theta is the outer product of seq_idx and theta, shape (seq_len, n_elem // 2).
    idx_theta = torch.outer(seq_idx, theta).float()

    # Compute cosine and sine and stack them along a new last dimension;
    # the resulting cache has shape (seq_len, n_elem // 2, 2).
    cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)

    # this is to mimic the behaviour of complex32, else we will get different results
    if dtype in (torch.float16, torch.bfloat16, torch.int8):
        cache = cache.half()
    return cache

For the functions used in `build_mask_cache`, see:
- [torch.ones](../../1-LLM-base-fundamentals/1.6-Pytorch常用函数/生成矩阵.ipynb)
- [torch.tril](../../1-LLM-base-fundamentals/1.6-Pytorch常用函数/矩阵操作.ipynb)
- [torch.unsqueeze](../../1-LLM-base-fundamentals/1.6-Pytorch常用函数/矩阵操作.ipynb)
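
With a small block size the result of these three calls is easy to inspect. The sketch repeats the body of `build_mask_cache` with an assumed `block_size` of 4 instead of the value from the configuration.

```python
import torch

block_size = 4  # small value chosen for readability (assumption)
ones = torch.ones((block_size, block_size), dtype=torch.bool)

# tril keeps the lower triangle; the two unsqueeze calls add batch and head dims.
mask_cache = torch.tril(ones).unsqueeze(0).unsqueeze(0)  # (1, 1, block_size, block_size)

# Row t lists which key positions the query at position t may attend to.
row_2 = mask_cache[0, 0, 2].tolist()
```

`row_2` is `[True, True, True, False]`: the token at position 2 can see positions 0 to 2 but not position 3.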

In [29]:
# At this point the model's rope_cache and mask_cache are both still empty
print(f"ropes_cache before init: {model.rope_cache}")
if model.rope_cache is None:
    model.rope_cache = model.build_rope_cache(idx)
if model.mask_cache is None:
    model.mask_cache = model.build_mask_cache(idx)
    print(model.mask_cache.shape)

ropes_cache before init: None
torch.Size([1, 1, 1024, 1024])


### input_pos
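> `input_pos` is an optional `torch.Tensor` holding the position index of each token in the input. It lets the model use specific position encodings instead of the default consecutive positions `0..T-1`.
> If `input_pos` is not `None`, it is used to select the matching position encodings and mask rows from `rope_cache` and `mask_cache`, so the input does not have to start at position 0. The code also treats `input_pos is None` as a proxy for running without the KV cache.
> Specifically:
> - `rope = self.rope_cache.index_select(0, input_pos)`: selects the position encodings given by `input_pos` from `rope_cache`.
> - `mask = self.mask_cache.index_select(2, input_pos)`: selects the mask rows given by `input_pos` from `mask_cache`.
>
> This makes the position handling flexible enough for different kinds of input.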
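
The selection can be tried on a small mask. The numbers below (block size 8, cache length 6, a single token at position 3) are assumptions chosen to mimic one incremental decoding step.

```python
import torch

block_size, max_seq_length = 8, 6
mask_cache = torch.tril(
    torch.ones(block_size, block_size, dtype=torch.bool)
).view(1, 1, block_size, block_size)

# Hypothetical decoding step: only the token at position 3 is fed to the model.
input_pos = torch.tensor([3])

mask = mask_cache.index_select(2, input_pos)  # keep the query row for position 3
mask = mask[:, :, :, :max_seq_length]         # keys are limited to the cache length
```

The resulting mask has shape `(1, 1, 1, 6)` and allows the new token to attend to key positions 0 to 3 only.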



In [30]:
if input_pos is not None:
    rope = model.rope_cache.index_select(0, input_pos)
    mask = model.mask_cache.index_select(2, input_pos)
    mask = mask[:, :, :, :max_seq_length]
else:
    rope = model.rope_cache[:T]
    mask = model.mask_cache[:, :, :T, :T]

### embeddings
After the embedding layer the shape is `[batch_size, token_length, n_embd]`.

In [31]:
x = model.transformer.wte(idx)
print(f"embedding dim: {config.n_embd}")
print(f"shape before embedding :{idx.shape} shape after embedding: {x.shape}")

embedding dim: 128
shape before embedding :torch.Size([16, 1024]) shape after embedding: torch.Size([16, 1024, 128])


### Attention and MLP computation
Next comes the core of `forward`: the `Attention` and `MLP` computations. Each `Block` runs both once, so they are executed `n_layer` times in total.

In [35]:
print(f"block layers: {len(model.transformer.h)}")
if input_pos is None:  # proxy for use_cache=False
    for block in model.transformer.h:
        x, _ = block(x, rope, mask, max_seq_length)
        print("shape after block: ", x.shape)
else:
    if not model.kv_caches:
        head_size = model.config.n_embd // model.config.n_head
        cache_shape = (B, model.config.n_head, max_seq_length, head_size)
        model.kv_caches = [
            (torch.zeros(cache_shape, device=x.device, dtype=x.dtype), torch.zeros(cache_shape, device=x.device, dtype=x.dtype))
            for _ in range(model.config.n_layer)
        ]
    for i, block in enumerate(model.transformer.h):
        x, model.kv_caches[i] = block(x, rope, mask, max_seq_length, input_pos, model.kv_caches[i])

block layers: 2
shape after block:  torch.Size([16, 1024, 128])
shape after block:  torch.Size([16, 1024, 128])


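### Final RMSNorm
After the last `Block`, the hidden states pass through the final `RMSNorm` layer `ln_f`; the shape does not change.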


In [36]:
x = model.transformer.ln_f(x)
print(f"after ln_f shape: {x.shape}")

after ln_f shape: torch.Size([16, 1024, 128])


### Projection to the vocabulary

In [37]:
logits = model.lm_head(x)
print("output logits ",logits.shape)

output logits  torch.Size([16, 1024, 32000])
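
The logits are not probabilities yet. A minimal sketch of the usual next step, softmax followed by a greedy choice of the next token, is shown below; random logits stand in for the model output and the batch and sequence sizes are assumptions.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Stand-in for the model output: (batch, sequence, vocab) logits.
logits = torch.randn(2, 5, 32000)

# Only the last position is needed to predict the next token.
probs = F.softmax(logits[:, -1, :], dim=-1)  # (batch, vocab), each row sums to 1
next_token = torch.argmax(probs, dim=-1)     # greedy choice, shape (batch,)
```

Sampling strategies such as temperature or top-k would replace the `argmax` line; greedy decoding is used here only because it is the simplest choice.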


## KVCache
For a detailed explanation of KVCache, see:
[9-推理加速/KVCache/KVCache.ipynb](../../9-推理加速/KVCache/KVCache.ipynb)
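
As a rough illustration of the caching step in `CausalSelfAttention`, the sketch below writes one new key per decoding step into a preallocated cache with `index_copy`, the same call the attention code uses. The sizes and the constant key values are assumptions made for readability.

```python
import torch

B, n_head, max_seq_length, head_size = 1, 2, 4, 3
cache_k = torch.zeros(B, n_head, max_seq_length, head_size)

# Hypothetical decoding loop: one new key per step, written at position `pos`.
for pos in range(3):
    new_k = torch.full((B, n_head, 1, head_size), float(pos + 1))
    input_pos = torch.tensor([pos])
    # Out-of-place copy along the sequence dimension (dim=2).
    cache_k = cache_k.index_copy(2, input_pos, new_k)

# First feature of head 0 across all cache slots.
filled = cache_k[0, 0, :, 0].tolist()
```

After three steps the first three slots hold the keys of steps 0 to 2 and the last slot is still zero, so each step only has to compute the key and value of the newest token.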