# DeepSeek-V3 MoE

DeepSeek-V3 671B 是一个大型的 MoE 模型，其 MoE 模型特点为：

1. 包含路由专家(router experts)和共享专家(shared experts), 前者是 sparse 的
2. 增加 bias 来帮助模型学习负载均衡
3. 序列级别负载均衡 loss, 计算方式跟随 Switch Transformer 模式

MoE 公式

\begin{align}
    \mathbf{h}_{t}^{\prime} &= \mathbf{u}_{t} + \sum_{i=1}^{N_{s}} {\text{FFN}^{(s)}_{i}\left( \mathbf{u}_{t} \right)} + \sum_{i=1}^{N_r} {g_{i,t} \text{FFN}^{(r)}_{i}\left( \mathbf{u}_{t} \right)}, \\
    g_{i,t} &= \frac{g^{\prime}_{i,t}}{\sum_{j=1}^{N_r} g^{\prime}_{j,t}}, \\
    g^{\prime}_{i,t} &= \begin{cases} 
    s_{i,t}, & s_{i,t} \in \text{Topk} (\{ s_{j, t} | 1 \leq j \leq N_r \}, K_{r}), \\
    0, & \text{otherwise}, 
    \end{cases} \\
    s_{i,t} &= \text{Sigmoid} \left( {\mathbf{u}_{t}}^{T} \mathbf{e}_{i} \right),
\end{align}


可学习的负载均衡偏置

\begin{align}
    g^{\prime}_{i,t} &= \begin{cases} 
    s_{i,t}, & s_{i,t} + b_i \in \operatorname{Topk} (\{ s_{j, t} + b_j | 1 \leq j \leq N_r \}, K_{r}), \\
    0, & \text{otherwise}.
    \end{cases}
\end{align}


序列级别负载均衡

\begin{align}
    \mathcal{L}_{\mathrm{Bal}} &= \alpha \sum_{i=1}^{N_r}{f_i P_i}, \\
    % f_i = \frac{N_r}{K_r T} \sum_{t=1}^{T} \mathcal{1}( \text{Expert $i$ } & \text{belongs to the Top-$K_r$ set for Token $t$} ), \\
    f_i = \frac{N_r}{K_r T} \sum_{t=1}^{T} \mathcal{1} & \left( s_{i,t} \in \operatorname{Topk} ( \{ s_{j, t} | 1 \leq j \leq N_r \}, K_{r} ) \right), \\
    s^{\prime}_{i,t} &= \frac{s_{i,t}}{\sum_{j=1}^{N_r} s_{j,t}}, \\
    P_i &= \frac{1}{T} \sum_{t=1}^{T}{s^{\prime}_{i,t}},
\end{align}

## DeepSeek-V3 MoE basic 

实现一个简易版的模型, 了解其主要工作模块

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
torch.manual_seed(42)

<torch._C.Generator at 0x10b78d9f0>

In [9]:
class BasicDeepSeekV3MoE(nn.Module):
    def __init__(self, dim, 
                 expert_nums = 8, 
                 top_k = 2, 
                 shared_expert_nums = 4):
        super().__init__()

        # Route Experts: 
        self.expert_nums = expert_nums
        self.k = top_k
        self.experts = nn.ModuleList(
            [ nn.Linear(dim, dim) for _ in range(self.expert_nums)]
        )
        self.w_gate = nn.Linear(dim, self.expert_nums)
        
        # Shared Experts: 
        self.shared_expert_nums = shared_expert_nums
        self.shared_experts = nn.ModuleList(
            [ nn.Linear(dim, dim) for _ in range(self.shared_expert_nums)]
        )

    def forward(self, x):
        # 整体前向计算
        y_route, weight, idx = self.forward_route(x) 
        y_shared = self.forward_shared(x) 
        y = x + y_route + y_shared
        load_loss = self.load_balance_sequence_wise(weight, idx)
        return y, load_loss
    
    def forward_route(self, x):
        # 默认稀疏选择专家 0
        return self.experts[0](x), None, None

    def forward_shared(self, x):
        y = torch.zeros_like(x)
        for i in range(self.shared_expert_nums):
            y += self.shared_experts[i](x) # not gate
        return y
        
    def load_balance_sequence_wise(self, weight, idx):
        return 1.0

In [10]:
# config
bs = 2
seq_len = 8
dim = 64
expert_nums = 20
expert_shared_nums = 4
topk = 2
x = torch.randn(bs, seq_len, dim)
label = torch.randn(bs, seq_len, dim)

# model
model = BasicDeepSeekV3MoE(dim, expert_nums=expert_nums, top_k = topk, shared_expert_nums = expert_shared_nums)
print(model)

DeepSeekV3MoE(
  (experts): ModuleList(
    (0-19): 20 x Linear(in_features=64, out_features=64, bias=True)
  )
  (w_gate): Linear(in_features=64, out_features=20, bias=True)
  (shared_experts): ModuleList(
    (0-3): 4 x Linear(in_features=64, out_features=64, bias=True)
  )
)


In [11]:
print(x.shape)
y, load_loss = model(x)
print(y.shape)
print(load_loss)

torch.Size([2, 8, 64])
torch.Size([2, 8, 64])
1.0


## Expert 

在公式里标记为 FFN 作为一个 expert, 实现一个 SwiGLU

In [20]:
class Expert(nn.Module): # expert
    def __init__(self, dim):
        super().__init__()
        self.dim_in = dim
        self.dim_out = self.dim_in * 8 // 3
        self.w1 = nn.Linear(self.dim_in, self.dim_out , bias = False)
        self.w_act = nn.Linear(self.dim_in, self.dim_out, bias = False) 
        self.w2 = nn.Linear(self.dim_out, self.dim_in, bias = False)  
        self.SiLU = nn.SiLU()
    
    def forward(self, x):
        h = self.w1(x)
        h_act = self.w_act(x)
        h_act_up = self.SiLU(h_act) * h
        output = self.w2(h_act_up)
        return output

## MoE

\begin{align}
    \mathbf{h}_{t}^{\prime} &= \mathbf{u}_{t} + \sum_{i=1}^{N_{s}} {\operatorname{FFN}^{(s)}_{i}\left( \mathbf{u}_{t} \right)} + \sum_{i=1}^{N_r} {g_{i,t} \operatorname{FFN}^{(r)}_{i}\left( \mathbf{u}_{t} \right)}, \\
    g_{i,t} &= \frac{g^{\prime}_{i,t}}{\sum_{j=1}^{N_r} g^{\prime}_{j,t}}, \\
    g^{\prime}_{i,t} &= \begin{cases} 
    s_{i,t}, & s_{i,t} \in \operatorname{Topk} (\{ s_{j, t} | 1 \leq j \leq N_r \}, K_{r}), \\
    0, & \text{otherwise}, 
    \end{cases} \\
    s_{i,t} &= \operatorname{Sigmoid} \left( {\mathbf{u}_{t}}^{T} \mathbf{e}_{i} \right),
\end{align}

In [35]:
class DeepSeekV3MoE(nn.Module):
    def __init__(self, dim, expert_nums = 8, top_k = 2, shared_expert_nums = 4):
        super().__init__()

        # Route Experts: 
        self.expert_nums = expert_nums
        self.k = top_k
        self.experts = nn.ModuleList([ Expert(dim) for _ in range(self.expert_nums)])
        self.w_gate = nn.Linear(dim, self.expert_nums)

        # Auxiliary-Loss-Free Load Balancing
        self.bias = torch.nn.Parameter(torch.zeros( expert_nums )) 
        self.alpha = 0.001
        
        # Shared Experts: 
        self.shared_expert_nums = shared_expert_nums
        self.shared_experts = nn.ModuleList([ Expert(dim) for _ in range(self.shared_expert_nums)])

    def forward(self, x):
        y_route, weight, idx = self.forward_route(x) 
        y_shared = self.forward_shared(x) 
        y = x + y_route + y_shared
        # load_loss = self.load_balance_sequence_wise(weight, idx)
        # load_loss = 1.0
        return y, weight, idx
    
    def forward_route(self, x):

        # gate 处理
        g = self.w_gate(x)
        g = F.sigmoid(g) # sigmoid 代替 softmax
        weight, idx = torch.topk(g, k = self.k, dim = -1) 
        weight_norm = weight / (weight.sum(dim=-1, keepdim= True) + 1e-20)

        # dispatch
        expert_results = [None] * self.expert_nums
        for i in range(self.expert_nums):
            cur_pos = torch.where(idx == i) 
            x_select = x[cur_pos[0], cur_pos[1], :] 
            if x_select.shape[0] > 0: 
                expert_results[i] = self.experts[i](x_select)
        
        # combine
        y_result = torch.zeros_like(x) 
        for i in range(self.expert_nums):
            cur_pos = torch.where(idx == i) 
            if expert_results[i] != None:
                y_result[cur_pos[0], cur_pos[1], :] += 
                expert_results[i] * weight_norm[cur_pos[0], cur_pos[1], cur_pos[2]].unsqueeze(-1)
        
        return y_result, g, idx 

    def forward_shared(self, x):
        y = torch.zeros_like(x)
        for i in range(self.shared_expert_nums):
            y += self.shared_experts[i](x) # not gate
        return y

# model
model = DeepSeekV3MoE(dim, expert_nums=expert_nums, top_k = topk, shared_expert_nums = expert_shared_nums)
print(model)

DeepSeekV3MoE(
  (experts): ModuleList(
    (0-19): 20 x Expert(
      (w1): Linear(in_features=64, out_features=170, bias=False)
      (w_act): Linear(in_features=64, out_features=170, bias=False)
      (w2): Linear(in_features=170, out_features=64, bias=False)
      (SiLU): SiLU()
    )
  )
  (w_gate): Linear(in_features=64, out_features=20, bias=True)
  (shared_experts): ModuleList(
    (0-3): 4 x Expert(
      (w1): Linear(in_features=64, out_features=170, bias=False)
      (w_act): Linear(in_features=64, out_features=170, bias=False)
      (w2): Linear(in_features=170, out_features=64, bias=False)
      (SiLU): SiLU()
    )
  )
)


In [36]:
print(x.shape)
y, weight, idx = model(x)
print(y.shape)
print(load_loss)

torch.Size([2, 8, 64])
torch.Size([2, 8, 64])
1.0


## Load Balance

In [39]:
def load_balance_sequence_wise(model, s, idx):
    '''
    sequence-wise 是为了在 inference 阶段, prefill 时能够均匀分散在各 GPU 中。
    更直观的理解是当 bsz 为 1 时, prefill 可以直接受益。
    '''
    Nr = model.expert_nums # n routes expert
    bs, seq_len, dim = s.shape # seq : pad token 要去除，或者增加 mask, 避免计入 loss 里

    l_lab = torch.zeros(1)

    for k in range(bs):

        # Compute fi
        fi = torch.zeros(Nr)
        # pi = torch.zeros(self.expert_nums)

        # seq-count 
        seq_expert_count = torch.zeros(Nr)
        idx_seq = idx[k,:,:]

        for i in range(model.expert_nums):
            seq_expert_count[i] = torch.where(idx_seq == i)[1].numel()
        
        fi = Nr / (model.k * seq_len) * seq_expert_count

        # Compute pi
        s_seq = s[k, :, :]
        si_ = s / s.sum(dim = -1, keepdim = True)
        pi = si_.sum(dim = 0) / seq_len

        l_bal_seq = (fi * pi).sum() / seq_len # seq_len_no_pad or use mask
        l_lab += l_bal_seq
        
    l_lab = model.alpha * l_lab
    return l_lab

loss = load_balance_sequence_wise(model, weight, idx)
print(loss)

tensor([0.0005], grad_fn=<MulBackward0>)


## 修正门控权重

基于 $s_{j, t} + b_j$ 求 top-k 专家，但是权重值为未修正的 $s_{i,t}$

\begin{align}
    g^{\prime}_{i,t} &= \begin{cases} 
    s_{i,t}, & s_{i,t} + b_i \in \operatorname{Topk} (\{ s_{j, t} + b_j | 1 \leq j \leq N_r \}, K_{r}), \\
    0, & \text{otherwise}.
    \end{cases}
\end{align}

In [83]:
# fix gate 处理

# bias = model.bias
bias = torch.randn(model.expert_nums)

x_view = x.reshape(bs*seq_len, dim)

s = F.sigmoid(model.w_gate(x_view)) # sigmoid 代替 softmax
sb = s + bias[None, :] # bias 是可学习参数

# weight, idx = torch.topk(s, k = model.k, dim = -1) 
sb_weight, idx = torch.topk(sb, k = model.k, dim = -1) 

weight = torch.zeros(bs*seq_len, model.k)
for i in range(bs*seq_len):
    weight[i, :] = s[i, idx[i, :]]

weight_norm = weight / (weight.sum(dim=-1, keepdim= True) + 1e-20) 
sb_weight_norm = sb_weight / (sb_weight.sum(dim=-1, keepdim= True) + 1e-20) # ignore

print(sb_weight_norm - weight_norm)

weight_norm = weight_norm.reshape(bs, seq_len, model.k)
print(weight_norm.shape)

tensor([[-0.1617,  0.1617],
        [-0.0920,  0.0920],
        [-0.1056,  0.1056],
        [ 0.0563, -0.0563],
        [ 0.1383, -0.1383],
        [ 0.0130, -0.0130],
        [-0.0878,  0.0878],
        [-0.1019,  0.1019],
        [ 0.0464, -0.0464],
        [ 0.0575, -0.0575],
        [ 0.0681, -0.0681],
        [ 0.0763, -0.0763],
        [-0.1013,  0.1013],
        [-0.1906,  0.1906],
        [ 0.0014, -0.0014],
        [ 0.0119, -0.0119]], grad_fn=<SubBackward0>)
torch.Size([2, 8, 2])


## 结论

1. DeepSeekV3 MoE 由路由专家、共享专家、负载均衡偏置参数、序列级负载均衡组成
2. 序列级负载均衡目的是 prefill 各序列均衡，当 bsz = 1 时，其均衡效用显著

讨论：

1. batch、seq、layers 之间的整体 load-balance 计算形式
2. sigmoid 代替 softmax 原因
3. 讨论 sMoE、SwithTransformer、V3-MoE之间的 稀疏归一化 差异
4. 讨论 Mixtral 8x7B load balance 是 MoE 内计算， 还是基于所有 MoE 层专家进行计算？ 后者目的是什么？
5. V3 稀疏 MoE 专家数量为 256 个，如果进行分布式专家并行部署，是否一定需要 256 个GPU以上？
6. 讨论 shared expert 与 router expert 二者集成学习有什么差异
7. 从 Mixtral 8x7B, specilist 类问题是否会集中分配给特定专家吗？如果不会是什么原因造成的？