## Multi_Token_Prediction

传统的预训练任务为 `next-token-prediction`， 多词元预测任务(Multi-token-prediction, MTP) 是 DeepSeek-V3 的训练任务，其本质是 `next-next-token-prediction`，此概念来源于一个并行解码技术。V3 仅采用 MTP-loss 做训练，训练完毕后将额外的 lm head 去除


对于预训练文本“拿铁就是牛奶加咖啡”

- ntp: `p7 = p(y_牛|拿铁就是牛奶加)`
- mtp: 任务1: `p7 = p(y_咖|拿铁就是牛奶加)` , 任务2 `p8 = p(y_啡|拿铁就是牛奶加)`, 对于 p8 任务来说预测 “啡” 字是比较大的概率的，能够达到高准确率的话，那么就能在解码中并行预测多个token了。
- mtp basic实现：对于token ‘加’的输出 logits(hidden state) 送到两个任务头。MTP 本质是多任务学习，相较 mtp(nntp) 较 ntp 更难，学习的特征越丰富
- mtp-v3：用 NTP 构建 NNTP 学习任务

这里的 MTP 实际字面意义不严格，更准确说法是：next-i-token-prediction(nitp)， 下 i 个词元预测。

- 下 1 个词元预测：`p7 = p(y_牛|拿铁就是牛奶加)`
- 下 2 个词元预测：`p8 = p(y_啡|拿铁就是牛奶加)`

为了实现这种训练, 本文依次讲解：

1. dummy MTP dataset
2. basic MTP
3. MTP
4. MTP backward 分析

## MTP dataset

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
import math

In [2]:
x = torch.randn(2,3)
x[:1] = 0
print(x)
y = torch.roll(x, shifts=-1)
print(y)

tensor([[0.0000, 0.0000, 0.0000],
        [1.4317, 0.3677, 1.2551]])
tensor([[0.0000, 0.0000, 1.4317],
        [0.3677, 1.2551, 0.0000]])


In [3]:
from typing import Dict

IGNORE_INDEX = -100

class MTPDataset(Dataset):
    def __init__(self, x, n=5):
        super().__init__()
        if x is None:
            return 
        self.num_tokens = n
        self.x = x.clone()
        bsz, seq_len = x.shape
        self.y = torch.ones(n, bsz, seq_len, dtype=torch.long) * IGNORE_INDEX
        for i in range(n):
            self.y[i, :, :-i-1] = x[:, i+1:]
            
            
    def __len__(self):
        bsz, seq_len = x.shape
        return bsz

    def __getitem__(self, idx):
        data = {'input_ids': self.x[idx,:],
            'labels':  self.y[:, idx, :],}
        return data

bsz = 2
seq_len = 10
vocab_size = 100
dim = 512
N = 5

x = torch.randint(vocab_size, (bsz, seq_len))
dataset = MTPDataset(x, n=N)
print(dataset[0]['input_ids'])
print(dataset[0]['labels'])

tensor([83, 74, 53,  6,  3,  2, 29, 44, 93, 32])
tensor([[  74,   53,    6,    3,    2,   29,   44,   93,   32, -100],
        [  53,    6,    3,    2,   29,   44,   93,   32, -100, -100],
        [   6,    3,    2,   29,   44,   93,   32, -100, -100, -100],
        [   3,    2,   29,   44,   93,   32, -100, -100, -100, -100],
        [   2,   29,   44,   93,   32, -100, -100, -100, -100, -100]])


## MTP Model

In [4]:
class MTPLanguageModel(nn.Module):
    def __init__(self, dim, vocab_size, num_mtp = 5):
        super().__init__()
        self.dim = dim
        self.vocab_size = vocab_size
        self.num_mtp = num_mtp
        self.embd = nn.Embedding(vocab_size, dim)
        self.w = nn.Linear(dim, dim)
        self.lm_heads = nn.ModuleList(
            [ nn.Linear(dim, vocab_size) for _ in range(self.num_mtp) ]
        )

    def forward(self, x):
        X = self.embd(x)
        hidden_states = self.w(X)

        logits_list = [
            lm_head(hidden_states).unsqueeze(dim = 0) for lm_head in self.lm_heads
        ]

        return torch.cat(logits_list, dim = 0)

model = MTPLanguageModel(dim = dim, vocab_size=vocab_size, num_mtp = N)
print(model)

MTPLanguageModel(
  (embd): Embedding(100, 512)
  (w): Linear(in_features=512, out_features=512, bias=True)
  (lm_heads): ModuleList(
    (0-4): 5 x Linear(in_features=512, out_features=100, bias=True)
  )
)


In [5]:
logits_list = model(x)
# print(len(logits_list))
print(logits_list.shape)

torch.Size([5, 2, 10, 100])


## MTP Loss

In [6]:
def mtp_loss(X, labels, weight):
    n, bsz, seq_len, vocab_size = X.shape
    n, bsz, seq_len = labels.shape
    loss_fn = nn.CrossEntropyLoss(ignore_index = IGNORE_INDEX)
    loss_mtp = torch.zeros(n)

    
    for i in range(n):
        tmp_loss = loss_fn( X[i].view(bsz*seq_len, vocab_size), labels[i].view(bsz*seq_len))
        loss_mtp[i] = tmp_loss

    loss = (loss_mtp * weight).mean()
    return loss

weight = torch.tensor( [1.0, 0.5, 0.4, 0.2, 0.1])
loss = mtp_loss(logits_list, dataset.y, weight)
print(loss)

tensor(2.0573, grad_fn=<MeanBackward0>)


## Next-Token Prediction

In [7]:
# head 0, next token prediction
next_token = torch.argmax(logits_list[0, 0, -1, :], dim = -1) # heads:0, bsz:0, seq_id:-1, dim 
print(next_token)

# head-0~N, multi token token prediction
next_next_token = torch.argmax(logits_list[:, 0, -1, :], dim = -1) # heads:0, bsz:0, seq_id:-1, dim 
print(next_next_token)

tensor(31)
tensor([31, 15, 71, 52, 38])


## DeepSeek MTP

1. DeepSeek MTP 做法中，MTP 头 share 主体模型的 lm head
2. MTP 组件包含：embd(shared), norm, linear, transformer-block, lm_head(shared)
3. MTP 之间的特征是串行递归关系（可以理解为传 latent 特征，而非具体 token)
4. 每头的输入为 `[t1,t2,t3,t4]`, `[t2,t3,t4,t5]`, 这与我们上述的输入形式有区别
5. v3 的 MTP 每个头训练仍然是 next-token-prediction 形式

In [8]:
class MTPModule(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.RMSNorm_pre = nn.Linear(dim, dim)
        self.RMSNorm_cur = nn.Linear(dim, dim)
        self.Proj = nn.Linear(dim*2, dim)
        self.Transformer_block = nn.Linear(dim, dim)
        
    def forward(self, X_embd, H_pre):
        X_embd, H_pre = self.RMSNorm_cur(X_embd), self.RMSNorm_pre(H_pre)
        X = torch.cat((X_embd, H_pre), dim = -1)
        X = self.Proj(X)
        X = self.Transformer_block(X)
        return X

class DeepSeekMTP(nn.Module):
    def __init__(self, dim, vocab_size, num_mtp=5):
        '''
        lm_head 为主体输出头
        num_mtp 额外增加 MTP 输出头模块数量
        总输出头模块：1+num_mtp
        '''
        super().__init__()
        self.dim = dim
        self.vocab_size = vocab_size
        self.num_mtp = num_mtp
        self.embd = nn.Embedding(vocab_size, dim)
        self.w = nn.Linear(dim, dim)
        self.lm_head = nn.Linear(dim, vocab_size)
        
        self.mtp_heads = nn.ModuleList(
            [ MTPModule(dim) for _ in range(self.num_mtp) ]
        )

    def forward(self, x):
        print(x)
        bsz, seq_len = x.shape
        X = self.embd(x)
        hidden_states = self.w(X[:, -self.num_mtp:, :])
        lm_logits = self.lm_head(hidden_states) # bsz, len, vocab_size
        mtp_logits = torch.randn(self.num_mtp, 
                                 bsz, 
                                 seq_len - self.num_mtp, 
                                 vocab_size)

        # 以下循环为递归调用
        for i in range(self.num_mtp):
            X_cur = X[:, i+1: i+1+seq_len-self.num_mtp, :]
            # 讨论: 分析 detach 作用
            hidden_states_i = self.mtp_heads[i](X_cur, hidden_states.detach())
            mtp_logits_i = self.lm_head(hidden_states_i)
            mtp_logits[i] = mtp_logits_i
            hidden_states = hidden_states_i

        return lm_logits, mtp_logits

        
model = DeepSeekMTP(dim = dim, vocab_size = vocab_size, num_mtp = N)
print(model)

DeepSeekMTP(
  (embd): Embedding(100, 512)
  (w): Linear(in_features=512, out_features=512, bias=True)
  (lm_head): Linear(in_features=512, out_features=100, bias=True)
  (mtp_heads): ModuleList(
    (0-4): 5 x MTPModule(
      (RMSNorm_pre): Linear(in_features=512, out_features=512, bias=True)
      (RMSNorm_cur): Linear(in_features=512, out_features=512, bias=True)
      (Proj): Linear(in_features=1024, out_features=512, bias=True)
      (Transformer_block): Linear(in_features=512, out_features=512, bias=True)
    )
  )
)


In [9]:
class LMDataset(Dataset):
    def __init__(self, x, n=5):
        super().__init__()
        if x is None:
            return 
        self.num_tokens = n
        self.x = x.clone()
        x[:, 0] = IGNORE_INDEX
        self.y = torch.roll(x, shifts = -1)
            
    def __len__(self):
        bsz, seq_len = x.shape
        return bsz

    def __getitem__(self, idx):
        data = {'input_ids': self.x[idx, :],
                'labels':  self.y[idx, :],}
        return data

x = torch.randint(vocab_size, (bsz, seq_len))
dataset = LMDataset(x.clone())
print(dataset[:]['input_ids'])
print(dataset[:]['labels'])

tensor([[70, 86, 43, 19, 24, 58, 38,  6, 77, 84],
        [28, 37, 30, 61, 54, 40, 32,  4, 29, 17]])
tensor([[  86,   43,   19,   24,   58,   38,    6,   77,   84, -100],
        [  37,   30,   61,   54,   40,   32,    4,   29,   17, -100]])


In [10]:
input_ids = dataset[:]['input_ids']
lm_logits, mtp_logits = model(input_ids)
print(lm_logits.shape)
print(mtp_logits.shape)

tensor([[70, 86, 43, 19, 24, 58, 38,  6, 77, 84],
        [28, 37, 30, 61, 54, 40, 32,  4, 29, 17]])
torch.Size([2, 5, 100])
torch.Size([5, 2, 5, 100])


In [11]:
def deepseek_v3_mtp_loss(lm_logits,
                         mtp_logits, 
                         y,
                        lam=0.1):

    N, bsz, seq_len, vocab_size = mtp_logits.shape
    loss_fn = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)

    
    loss_lm_head = loss_fn(lm_logits.view(bsz*seq_len,vocab_size), 
                           y[:,:-N].reshape(bsz*seq_len))

    loss_mtp_head = torch.zeros(N)
    for i in range(N):
        loss_mtp_head[i] = loss_fn(mtp_logits[i,:,:,:].view(bsz*seq_len,vocab_size),
                                   y[:,i:-N+i].reshape(bsz*seq_len))
                                
    loss = loss_lm_head + lam*loss_mtp_head.mean()
    return loss, loss_lm_head, loss_mtp_head

loss, loss_lm,loss_mtp = deepseek_v3_mtp_loss(lm_logits, 
                     mtp_logits, 
                     y= dataset[:]['labels'],
                     lam = 0.1)
print(loss)
print(loss_lm)
print(loss_mtp)                                        

tensor(5.1376, grad_fn=<AddBackward0>)
tensor(4.6760, grad_fn=<NllLossBackward0>)
tensor([4.5976, 4.6188, 4.6195, 4.6066, 4.6396], grad_fn=<CopySlices>)


## MTP 特征分析

- Proj前cat [next_predcition_feature, embedding], 此向量包含 t1 hidden 和当前 embd, 其做法是 cat, 在 RNN 网络中通常为加法: h(t1)+x(t2)
- TransformerBlock 输出后为 t2 预测 next token(t3) 的特征
- 虽然 MTP 类损失是多任务损失，其本质仍是 next-token-prediction， 而非 next-next-token-prediction
- 准确来说 v3 是： reccurent-next-token-prediction
- mtp moudle 的输入去除一些首 token, 输入信息不完整
- inference 角度，MTP 的 NNTP 成本（单层解码块）原低于网络主体 forward（多层解码块）

## 梯度分析

1. 每个 mtp 接收上一块数据的 Last hidden，那么其 backwad 计算路径应当为 $MTP_N,MTP_{N-1},\ldots, MTP_1$
2. 如果将 上一块的输出 detach, 那么每个块可以根据局部 mtp loss backward，此时 mtp module 内部参数更新，以及整体lmhead、embd更新

## V3-MTP Next-Next-Token-Prediction

2. 试讨论以下的 inference 程序是否正确？
3. MTP-basic 与 V3-MTP NNTP 差异在哪？

Training 阶段

 - 主体模型输入 [t1,t2,t3,t4] 预测 [t2,t3,t4,t5]
 - MTP1 输入  [t2,t3,t4,t5] 预测 [t3,t4,t5,t6]

对于 training 阶段, t5 是提前知道的，在 inference 阶段时，需要 lm-head 取 argmax 解出 t5, 才能计算 t6

In [14]:
# Training
# 主体模型输入 [t1,t2,t3,t4] 预测 [t2,t3,t4,t5]

ntp = torch.argmax(lm_logits[:, -1, :], dim = -1).unsqueeze(1)
nntp = torch.argmax(mtp_logits[:, :, -1, :], dim = -1).t()
pred = torch.cat((ntp,nntp), dim = 1)
print(pred[0])

tensor([97, 55, 77,  2, 73, 31])


In [42]:
# Inference

x = dataset[:]['input_ids'].clone()
bsz, seq_len = x.shape
X = model.embd(x)
hidden_states = model.w(X) # input全部输入
lm_logits = model.lm_head(hidden_states) # bsz, len, vocab_size
next_token = torch.argmax(lm_logits[:, -1, :], dim = -1).unsqueeze(dim = 1)
print(x[0,:])
x = torch.cat((x, next_token), dim = 1)
print(x[0,:])

# # 以下循环为递归调用
for i in range(model.num_mtp):
    # input
    X = model.embd(x[:, i+1:])

    # output
    hidden_states_i = model.mtp_heads[i](X, hidden_states)
    mtp_logits_i = model.lm_head(hidden_states_i)
    hidden_states = hidden_states_i

    # next-token-prediction
    next_token = torch.argmax(mtp_logits_i[:, -1, :], dim = -1).unsqueeze(dim = 1)
    x = torch.cat((x, next_token), dim = 1)
    print(x[0,:])

tensor([70, 86, 43, 19, 24, 58, 38,  6, 77, 84])
tensor([70, 86, 43, 19, 24, 58, 38,  6, 77, 84, 97])
tensor([70, 86, 43, 19, 24, 58, 38,  6, 77, 84, 97, 49])
tensor([70, 86, 43, 19, 24, 58, 38,  6, 77, 84, 97, 49, 91])
tensor([70, 86, 43, 19, 24, 58, 38,  6, 77, 84, 97, 49, 91, 14])
tensor([70, 86, 43, 19, 24, 58, 38,  6, 77, 84, 97, 49, 91, 14, 28])
tensor([70, 86, 43, 19, 24, 58, 38,  6, 77, 84, 97, 49, 91, 14, 28, 93])


讨论：

1. 从 LM+MTP 整体来看输入 [t1,t2,t3] 预测 [t4,t5,t6,...], 但是分别从 lm_head, mtp-i 看其做的仍然是 NTP
3. Inference 时，MTP 内部有 Transformer-Block 是否需要 kv-cache？

## 总结

1. basic mtp 真正去做 next-next-token-prediciton（NNTP）， 其行为为隔空预测
2. deepseek-v3 mtp 训练做 next-token-prediction（NTP）。 释放的v3砍掉mtp头
4. V3 并没有讨论 NNTP 的性能指标，其涉及到“并行解码”技术
5. 讨论：为什么 MTP 能提升模型预测能力？V3 使用 mtp 训练，加强了模型主体的 lm_head 预测, 加强的原理可看成在时许上引入了 latent 时序 feature。 那么 Attn 做序列加权组合， FFN 做特征表示， MoE做集成学习，LM_head做NTP，而带 MTP-LM_head做时序特征预测表示，

## Reference

[DeepSeek-V3 Technical Report](https://arxiv.org/abs/2412.19437v1)

[Better & Faster Large Language Models via Multi-token Prediction](https://arxiv.org/abs/2404.19737)