# MoE

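Mixture of Experts (MoE) is an ensemble-style method that routes each input to specific experts, which improves the efficiency of feature learning. An MoE layer has three parts:

1. Gate: takes the input and outputs the weight assigned to each expert.
2. Experts: each expert processes the data. In a neural network an expert can be any feature learner (FFN, MLP, or Linear), and an expert group is a set of such learners. In Transformer-style models, an MoE layer usually consists of several FFNs plus one gating network.
3. Mixture: combines the expert outputs by a weighted sum to produce the MoE output.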
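
The three parts above can be written as one formula. This is a sketch in notation introduced here for illustration only: $E_i$ is the $i$-th expert, $N$ the number of experts, and $W_g, b_g$ the parameters of a linear gate (the gate used later in this notebook is an `nn.Linear`, so this form should match it).

$$
g(x) = \texttt{Softmax}(W_g x + b_g), \qquad
y = \sum_{i=1}^{N} g_i(x)\, E_i(x)
$$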

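This notebook discusses MoE, SMoE, and the balance loss, and raises the following questions:

- Why is the FFN so important to what a Transformer learns, and why scale the parameter count (model capacity) in the FFN rather than in attention?
- What is MoE learning in essence? Does it match the intuition that each expert handles features of a specific domain?
- Is gating applied at the token level or at the sentence level?
- Does introducing MoE significantly increase computation? Are the features output by different experts redundant?
- Does sparse gating make MoE features unstable?
- Is sparse gating differentiable? If not, how is the model trained?
- What makes sparse MoE hard to train, and which strategies mitigate the problem?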
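
One commonly cited difficulty behind the last question is load imbalance: the gate may send most tokens to a few experts. The sketch below shows an auxiliary loss in the style of the Switch Transformers paper, `num_experts * sum_i f_i * P_i`, under the assumption of top-1 routing. The function name `load_balance_loss` and the toy logits are illustrative, and the scaling coefficient that normally multiplies this loss is left out.

```python
import torch
import torch.nn.functional as F

def load_balance_loss(router_logits, num_experts):
    """Switch-style auxiliary loss (sketch): num_experts * sum_i f_i * P_i.

    router_logits: [num_tokens, num_experts]
    f_i: fraction of tokens whose top-1 expert is i (no gradient flows through it)
    P_i: mean router probability given to expert i (differentiable)
    """
    probs = F.softmax(router_logits, dim=-1)
    top1 = probs.argmax(dim=-1)
    f = F.one_hot(top1, num_classes=num_experts).float().mean(dim=0)
    P = probs.mean(dim=0)
    return num_experts * torch.sum(f * P)

num_experts = 4
# Balanced routing: token t strongly prefers expert t % 4.
balanced_logits = 10.0 * F.one_hot(torch.arange(8) % num_experts, num_experts).float()
# Collapsed routing: every token strongly prefers expert 0.
collapsed_logits = 10.0 * F.one_hot(torch.zeros(8, dtype=torch.long), num_experts).float()

balanced = load_balance_loss(balanced_logits, num_experts)
collapsed = load_balance_loss(collapsed_logits, num_experts)
print(balanced.item(), collapsed.item())
```

The loss is smallest (about 1) when tokens are spread evenly and approaches `num_experts` when routing collapses onto one expert, so minimizing it pushes the gate toward balanced routing.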

## Key MoE Work

1. MoE, 1991: Adaptive Mixtures of Local Experts
2. SMoE, 2017: Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer
3. EP (expert parallelism), 2021: GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding
4. Loss, 2022: Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity
5. MoE-LLM, 2023.12: Mixtral of Experts
6. DeepSeek-V1, 2024.01
7. DeepSeek-V2, 2024.05
8. DeepSeek-V3, 2024.12

## MoE



In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
torch.manual_seed(42)

<torch._C.Generator at 0x10bc0d070>

In [3]:
class Expert(nn.Module):
    def __init__(self, dim=512):
        '''
        In an LLM the FFN block acts as the expert;
        a single Linear layer stands in for it here.
        '''
        super().__init__()
        self.dim = dim
        self.w = nn.Linear(self.dim, self.dim)
    def forward(self, x):
        return self.w(x)

class MoEBasic(nn.Module):
    def __init__(self, dim=512, num_experts=8):
        super().__init__()
        self.dim = dim
        self.num_experts = num_experts
        self.experts = nn.ModuleList(
            [ Expert(dim = self.dim) for _ in range(self.num_experts) ]
        )
        self.gate = nn.Linear(self.dim, self.num_experts)

    def forward(self, x):
        '''
        x: bsz, seq_len, dim
        '''
        weight = self.gate(x)
        weight = F.softmax(weight, dim = -1) # bsz, seq_len, num_experts

        output = torch.zeros_like(x)
        for i, expert in enumerate(self.experts):
            output += weight[...,i].unsqueeze(-1) * expert(x)

        return output

dim = 512
num_experts = 8
moe = MoEBasic(dim=dim,  num_experts=num_experts)
print(moe)

MoEBasic(
  (experts): ModuleList(
    (0-7): 8 x Expert(
      (w): Linear(in_features=512, out_features=512, bias=True)
    )
  )
  (gate): Linear(in_features=512, out_features=8, bias=True)
)


In [4]:
bsz = 2
seq_len = 3

X = torch.randn(bsz, seq_len, dim)
Y = moe(X)
print(Y.shape)

torch.Size([2, 3, 512])


In [5]:
# dim broadcast
x = torch.ones(1,2,3)
weight = torch.randn(1,2,1) 
print(weight)
x*weight 

tensor([[[ 0.4706],
         [-0.2382]]])


tensor([[[ 0.4706,  0.4706,  0.4706],
         [-0.2382, -0.2382, -0.2382]]])

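The MoE above is the most basic version: the experts cooperate to build the feature representation. There is also a variant trained with a loss that makes the experts compete.

Later improvements are mostly cooperative (weighted features). Given N experts:

- MoE: selects all N experts (top-N)
- sMoE: selects the top-k experts, $k\in[1,N]$
- Switch Transformer: selects the top-1 expert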
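
The three variants above can be seen as one gating function with a different `k`. The following is a minimal sketch under that view; `topk_gate` is a name made up for this example, and it renormalizes the kept probabilities the same way the notebook does later. One caveat: after renormalization the `k=1` weight is always 1, whereas the Switch Transformer paper keeps the raw softmax probability as the gate value.

```python
import torch
import torch.nn.functional as F

def topk_gate(logits, k):
    """Return [num_tokens, num_experts] weights with only the top-k entries non-zero."""
    probs = F.softmax(logits, dim=-1)
    v, idx = torch.topk(probs, k=k, dim=-1)
    v = v / v.sum(dim=-1, keepdim=True)          # renormalize the kept experts
    return torch.zeros_like(probs).scatter(-1, idx, v)

torch.manual_seed(0)
logits = torch.randn(3, 4)                       # 3 tokens, N = 4 experts

dense = topk_gate(logits, k=4)                   # k = N: plain softmax (dense MoE)
sparse = topk_gate(logits, k=2)                  # 1 < k < N: sparse MoE
top1 = topk_gate(logits, k=1)                    # k = 1: Switch-style routing
print(dense, sparse, top1, sep="\n")
```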


## SMoE

Select the top-k experts by weight; for experts outside the top-k, set the logits to $-\infty$:


$$
G(x) = \texttt{Softmax}(\texttt{KeepTopK}(H(x), k))
$$

$$
\texttt{KeepTopK}(v, k)_i = \begin{cases}
            v_i & \text{if $v_i$ is in the top $k$ elements of $v$.} \\
            -\infty & \text{otherwise.}
        \end{cases}
$$

This notebook implements it as follows:

1. Apply softmax directly to the logits: p1, p2, ..., p8
2. Pick the top-2 experts by probability, e.g. p2 and p8
3. Renormalize the two: p2 / (p2 + p8), p8 / (p2 + p8)
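
The two formulations give the same gate values, because $p_i / \sum_{j \in \text{top-}k} p_j = e^{v_i} / \sum_{j \in \text{top-}k} e^{v_j}$: the softmax denominator cancels. The snippet below is a small numerical check of that identity. Random logits stand in for $H(x)$ here, and the variable names are chosen for this example only.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 8)   # 4 tokens, 8 experts
k = 2

# (a) Formula above: set non-top-k logits to -inf, then softmax.
top_logits, idx_a = torch.topk(logits, k=k, dim=-1)
masked = torch.full_like(logits, float("-inf")).scatter(-1, idx_a, top_logits)
g_masked = F.softmax(masked, dim=-1)

# (b) Notebook steps: softmax first, keep top-k probabilities, renormalize.
probs = F.softmax(logits, dim=-1)
v, idx_b = torch.topk(probs, k=k, dim=-1)
g_renorm = torch.zeros_like(probs).scatter(-1, idx_b, v / v.sum(dim=-1, keepdim=True))

print(torch.allclose(g_masked, g_renorm, atol=1e-6))
```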

In [6]:
data = torch.randn(3,5)
print(data)
v, idx = torch.topk(data, k=2, dim = -1)
print(v, idx)

tensor([[ 0.7493, -1.3844, -2.1473,  0.1755,  0.7047],
        [ 0.9629, -0.1419, -0.5430,  0.2240, -0.2433],
        [ 1.9174,  0.9862,  1.0908, -1.2830,  0.0711]])
tensor([[0.7493, 0.7047],
        [0.9629, 0.2240],
        [1.9174, 1.0908]]) tensor([[0, 4],
        [0, 3],
        [0, 2]])


In [9]:
class MoESparse(MoEBasic):
    def __init__(self, dim=512, num_experts=8, topk = 2):
        # Forward dim/num_experts to the base class and keep the requested topk
        # (previously all three arguments were silently ignored).
        super().__init__(dim=dim, num_experts=num_experts)
        self.topk = topk

    def forward(self, x):
        bsz, seq_len, _ = x.shape
        weight = self.gate(x)
        weight = F.softmax(weight, dim = -1) # bsz, seq_len, num_experts

        # process top-k index
        v, idx = torch.topk(weight, dim = -1, k=self.topk)
        print('expert ids chosen by each token differ:', idx)
        
        weight_sparse = torch.zeros_like(weight)
        # Loop-free equivalent of:
        # for i in range(bsz):
        #     for j in range(seq_len):
        #         weight_sparse[i, j, idx[i,j]] = weight[i, j, idx[i,j]]
        # weight_sparse /= weight_sparse.sum(dim = -1, keepdim=True) # renormalize
        weight_sparse.scatter_(-1, idx, v)
        weight_sparse /= weight_sparse.sum(dim = -1, keepdim=True).clamp_min(1e-12) # renormalize

        
        print('renormalized expert weights:', weight_sparse)
        
        output = torch.zeros_like(x)
        for i, expert in enumerate(self.experts):
            output += weight_sparse[...,i].unsqueeze(-1) * expert(x)

        return output

k = 2
smoe = MoESparse(dim=dim, num_experts=num_experts, topk=k)
print(smoe)

MoESparse(
  (experts): ModuleList(
    (0-7): 8 x Expert(
      (w): Linear(in_features=512, out_features=512, bias=True)
    )
  )
  (gate): Linear(in_features=512, out_features=8, bias=True)
)


In [10]:
Y = smoe(X)
print(Y.shape)

# top-k index selection is not differentiable, but backward still succeeds:
# gradients flow through the selected weights
loss = Y.mean()
loss.backward()

expert ids chosen by each token differ: tensor([[[2, 3],
         [4, 3],
         [1, 7]],

        [[0, 5],
         [7, 4],
         [0, 2]]])
renormalized expert weights: tensor([[[0.0000, 0.0000, 0.6649, 0.3351, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.4914, 0.5086, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.6973, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3027]],

        [[0.5072, 0.0000, 0.0000, 0.0000, 0.0000, 0.4928, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.4353, 0.0000, 0.0000, 0.5647],
         [0.5771, 0.0000, 0.4229, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]],
       grad_fn=<DivBackward0>)
torch.Size([2, 3, 512])


The code above prints the top-k experts. There are 2x3 = 6 tokens in total, and each token selects its own expert ids.
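
Why does `backward()` succeed even though top-k selection is not differentiable? The indices are treated as constants, and gradients flow only through the selected values. The sketch below illustrates this on a single token; `expert_out` is a made-up scalar output per expert, used only to keep the example small. With the softmax-then-renormalize gate, the logits of unselected experts receive (numerically) zero gradient.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(1, 4, requires_grad=True)     # 1 token, 4 experts
expert_out = torch.tensor([1.0, 2.0, 3.0, 4.0])    # hypothetical scalar expert outputs

probs = F.softmax(logits, dim=-1)
v, idx = torch.topk(probs, k=2, dim=-1)            # idx carries no gradient
v = v / v.sum(dim=-1, keepdim=True)
y = (v * expert_out[idx]).sum()                    # weighted mix of the 2 selected experts
y.backward()

selected = torch.zeros(4, dtype=torch.bool)
selected[idx[0]] = True
print("selected experts:", idx[0].tolist())
print("grad on logits:  ", logits.grad[0].tolist())
```

The practical consequence is that an expert that is never selected gets no learning signal from the task loss through the gate for that token.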

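## SMoE: Reducing Computation

In the implementation above, top-k selects the experts, yet the non-top-k experts are still evaluated (pointlessly, since their weight is 0).

To reduce computation, evaluate only the top-k experts.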

In [8]:
class MoESparseEfficient(MoEBasic):
    def __init__(self, dim=512, num_experts=8, topk = 2):
        super().__init__(dim=dim, num_experts=num_experts)
        self.topk = topk

    def forward(self, x):
        bsz, seq_len, _ = x.shape
        weight = self.gate(x)
        weight = F.softmax(weight, dim = -1) # bsz, seq_len, num_experts

        # process top-k index
        v, idx = torch.topk(weight, dim = -1, k=self.topk)
        v /= v.sum(dim = -1, keepdim=True)

        output = torch.zeros_like(x)
        
        # token-wise for-loop 
        for i in range(bsz):
            for j in range(seq_len):
                for k in range(self.topk):
                    # experts are called bsz*seq_len*topk times in total
                    expert_id = idx[i,j,k]
                    output[i,j,:] += v[i,j,k] * self.experts[expert_id]( x[i,j,:])
                    
        return output

# k = 2
smoe_forloop = MoESparseEfficient(dim=dim, num_experts=num_experts, topk=k)
print(smoe_forloop)

MoESparseEfficient(
  (experts): ModuleList(
    (0-7): 8 x Expert(
      (w): Linear(in_features=512, out_features=512, bias=True)
    )
  )
  (gate): Linear(in_features=512, out_features=8, bias=True)
)


In [9]:
Y = smoe_forloop(X)
print(Y.shape)

# top-k index selection is not differentiable, but backward still succeeds
loss = Y.mean()
loss.backward()

torch.Size([2, 3, 512])


## SMoE dispatch mode

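In the forward pass above, the per-token loop is inefficient. The table below marks which experts each token selects (1 = selected); the last row lists the tokens dispatched to each expert.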

|              | E1      | E2      | E3          | E4        |
| ------------ | ------- | ------- | ----------- | --------- |
| Token 1      | 1       | 0       | 1           | 0         |
| Token 2      | 0       | 0       | 1           | 1         |
| Token 3      | 0       | 1       | 1           | 1         |
| **dispatch** | Token 1 | Token 3 | Token 1,2,3 | Token 2,3 |

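From an expert's point of view, E3 for example can forward tokens 1, 2, and 3 in a single call.

Compared with the for-loop version, this needs only `num_experts` expert forward calls.

E3 forward yields `E3 = [3, dim]`
E4 forward yields `E4 = [2, dim]`

The output for token 2 is then `weight[1, 0]*E4[0,:] + weight[1, 1]*E3[1,:]`, where token 2's top-k expert ids are `[E4, E3]`.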
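
A token-by-expert mask like the table above can be built directly from the top-k indices. The following is a minimal sketch with made-up top-2 indices for three tokens (expert ids are 0-based in code and printed 1-based to match the table's naming); it uses `F.one_hot` as an alternative to calling `torch.where` once per expert.

```python
import torch
import torch.nn.functional as F

num_experts = 4
# Hypothetical top-2 expert ids for 3 tokens (0-based expert ids).
idx = torch.tensor([[0, 2],
                    [2, 3],
                    [1, 3]])

# [num_tokens, topk, num_experts] -> [num_tokens, num_experts]
mask = F.one_hot(idx, num_classes=num_experts).sum(dim=1)
print(mask)

# Read the mask column by column to get each expert's token batch.
dispatch = {}
for e in range(num_experts):
    token_ids = torch.where(mask[:, e] > 0)[0]
    dispatch[e] = token_ids.tolist()
    print(f"E{e + 1} <- tokens {(token_ids + 1).tolist()}")
```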

In [10]:
idx = torch.randint(4, (3,2))
print(idx)
print(torch.where(idx == 0))
print(torch.where(idx == 1))
print(torch.where(idx == 2))
print(torch.where(idx == 3))

# Note: some experts may not be selected by any token

tensor([[2, 2],
        [1, 3],
        [0, 0]])
(tensor([2, 2]), tensor([0, 1]))
(tensor([1]), tensor([0]))
(tensor([0, 0]), tensor([0, 1]))
(tensor([1]), tensor([1]))


In [11]:
def is_empty_expert(idx):
    # True when no token was routed to this expert
    return idx.numel() == 0
print( is_empty_expert( torch.tensor([]) ))
print( is_empty_expert( torch.tensor([1]) ))

True
False


In [12]:
class MoESparseExpertLoop(MoEBasic):
    def __init__(self, dim=512, num_experts=8, topk = 2):
        super().__init__(dim=dim, num_experts=num_experts)
        self.topk = topk

    def forward(self, x):
        bsz, seq_len, dim = x.shape
        N = bsz * seq_len 
        x = x.view(N, dim) # N tokens in total

        weight = self.gate(x)
        weight = F.softmax(weight, dim = -1) # N, num_experts

        # process top-k index
        v, idx = torch.topk(weight, dim = -1, k=self.topk)
        v /= v.sum(dim = -1, keepdim=True)
        print(idx)
        print(v)

        token_to_expert = [None] * self.num_experts
        for i in range(self.num_experts):
            token_id = torch.where(idx == i) # dim 0 is token id
            if is_empty_expert(token_id[0]):
                continue
            token_to_expert[i] = token_id
            
        output = torch.zeros_like(x)
        for i in range(self.num_experts):
            if token_to_expert[i] is None:
                print('expert ', i, ' **empty select token**')
                continue
            cur_token = token_to_expert[i][0]
            cur_weight = v[token_to_expert[i][0], token_to_expert[i][1]] 
            dispatch_x = x[cur_token, :]
            
            print('expert:', i,'select_token:', cur_token, '\t\tcur weight:', cur_weight)
            output[cur_token, :] += cur_weight.unsqueeze(-1) * self.experts[i](dispatch_x)

        return output.reshape(bsz, seq_len, dim), idx, v, token_to_expert

# k = 2
smoe_eloop = MoESparseExpertLoop(dim=dim, num_experts=num_experts, topk=k)
print(smoe_eloop)

MoESparseExpertLoop(
  (experts): ModuleList(
    (0-7): 8 x Expert(
      (w): Linear(in_features=512, out_features=512, bias=True)
    )
  )
  (gate): Linear(in_features=512, out_features=8, bias=True)
)


In [13]:
X = torch.randn(bsz, seq_len, dim)
Y, idx, weight, token_to_expert = smoe_eloop(X)
print(Y.shape)

# top-k index selection is not differentiable, but backward still succeeds
loss = Y.mean()
loss.backward()

tensor([[7, 6],
        [7, 5],
        [6, 2],
        [4, 2],
        [3, 1],
        [1, 3]])
tensor([[0.5352, 0.4648],
        [0.6510, 0.3490],
        [0.6306, 0.3694],
        [0.5105, 0.4895],
        [0.6084, 0.3916],
        [0.5695, 0.4305]], grad_fn=<DivBackward0>)
expert  0  **empty select token**
expert: 1 select_token: tensor([4, 5]) 		cur weight: tensor([0.3916, 0.5695], grad_fn=<IndexBackward0>)
expert: 2 select_token: tensor([2, 3]) 		cur weight: tensor([0.3694, 0.4895], grad_fn=<IndexBackward0>)
expert: 3 select_token: tensor([4, 5]) 		cur weight: tensor([0.6084, 0.4305], grad_fn=<IndexBackward0>)
expert: 4 select_token: tensor([3]) 		cur weight: tensor([0.5105], grad_fn=<IndexBackward0>)
expert: 5 select_token: tensor([1]) 		cur weight: tensor([0.3490], grad_fn=<IndexBackward0>)
expert: 6 select_token: tensor([0, 2]) 		cur weight: tensor([0.4648, 0.6306], grad_fn=<IndexBackward0>)
expert: 7 select_token: tensor([0, 1]) 		cur weight: tensor([0.5352, 0.6510], grad_fn=

## SMoE dispatch-combine mode

To simplify the code, implement SMoE in three stages:

1. dispatch
2. compute
3. combine

In [14]:
# class SMoEOuput:
#     gates = None
#     weight = None
#     v = None
#     idx = None
#     token_to_expert = None

In [15]:
class SparseMixtreOfExpert(MoEBasic):
    def __init__(self, dim=512, num_experts=8, topk = 2):
        super().__init__(dim=dim, num_experts=num_experts)
        self.topk = topk

    def forward(self, x):
        bsz, seq_len, dim = x.shape
        N = bsz * seq_len 
        x = x.view(N, dim) # N tokens in total

        # 0. gate
        gates = self.gate(x)
        weight = F.softmax(gates, dim = -1) # N, num_experts
        v, idx = torch.topk(weight, dim = -1, k=self.topk)
        v /= v.sum(dim = -1, keepdim=True)

        # 1. dispatch
        token_to_expert = [None] * self.num_experts
        for i in range(self.num_experts):
            token_id = torch.where(idx == i) # dim 0 is token id
            if is_empty_expert(token_id[0]):
                continue
            token_to_expert[i] = token_id

        # 2. compute
        dispatch_y = [None] * self.num_experts
        for i in range(self.num_experts):
            if token_to_expert[i] is not None:
                cur_token = token_to_expert[i][0]
                dispatch_x = x[cur_token, :]
                dispatch_y[i] = self.experts[i](dispatch_x)

        # 3. combine
        y = torch.zeros_like(x)
        for i in range(self.num_experts):
            if dispatch_y[i] is not None:
                cur_weight = v[token_to_expert[i][0], token_to_expert[i][1]]
                y[token_to_expert[i][0], :] += cur_weight.unsqueeze(dim = -1) * dispatch_y[i]

        # 4. reshape y
        y = y.reshape(bsz, seq_len, dim)

        return y, token_to_expert, v

my_moe = SparseMixtreOfExpert(dim=dim, num_experts=num_experts, topk=k)
print(my_moe)

SparseMixtreOfExpert(
  (experts): ModuleList(
    (0-7): 8 x Expert(
      (w): Linear(in_features=512, out_features=512, bias=True)
    )
  )
  (gate): Linear(in_features=512, out_features=8, bias=True)
)


In [16]:
X = torch.randn(bsz, seq_len, dim)
Y, _, _ = my_moe(X)
print(Y.shape)

# top-k index selection is not differentiable, but backward still succeeds
loss = Y.mean()
loss.backward()

torch.Size([2, 3, 512])
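
A shape check alone does not show that the dispatch-combine path computes the same thing as the dense masked version from earlier. The snippet below is a self-contained sanity check on small made-up sizes. It re-creates the gate and experts as plain `nn.Linear` layers rather than reusing the notebook classes, and uses `index_add_` for the combine step, which is one possible choice rather than the notebook's own code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
dim, num_experts, topk, n_tokens = 16, 4, 2, 6
experts = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_experts)])
gate = nn.Linear(dim, num_experts)
x = torch.randn(n_tokens, dim)

with torch.no_grad():
    probs = F.softmax(gate(x), dim=-1)
    v, idx = torch.topk(probs, k=topk, dim=-1)
    v = v / v.sum(dim=-1, keepdim=True)

    # Dense reference: every expert sees every token; unselected weights are 0.
    w = torch.zeros_like(probs).scatter(-1, idx, v)
    y_dense = sum(w[:, i:i + 1] * experts[i](x) for i in range(num_experts))

    # Dispatch -> compute -> combine: each expert only sees its own tokens.
    y_sparse = torch.zeros_like(x)
    for i in range(num_experts):
        token_id, slot = torch.where(idx == i)   # rows = token ids, cols = top-k slot
        if token_id.numel() == 0:
            continue                              # expert i received no tokens
        out = experts[i](x[token_id])
        y_sparse.index_add_(0, token_id, v[token_id, slot].unsqueeze(-1) * out)

print(torch.allclose(y_dense, y_sparse, atol=1e-5))
```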


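## Summary

1. In an LLM, an sMoE layer can be built from several FFNs plus one gating network; its output is an ensemble feature.
2. Each token is routed to different experts according to the gate.
3. sMoE is an effective way to scale model parameters while limiting per-token computation.
4. Exercise: analyze the total and activated parameter counts of sMoE.

Discussion:

1. Are the expert outputs redundant? The k in sMoE's top-k is currently fixed; could k be adaptive, for example top-p experts?
2. sMoE has a gate, and SwiGLU has a GLU; how do the two gating mechanisms differ?
3. If sMoE is viewed as a weighted combination of expert features, how does that differ from the weighted combination in attention?
4. If the experts differ greatly, can a small shift in the gate cause a large change in the final sMoE output?
5. Design an attention component with a sparse gate.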
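
For the parameter-count point above, here is a small worked example for the toy layer used in this notebook. It assumes each expert is one `Linear(dim, dim)` with bias and the gate is `Linear(dim, num_experts)` with bias, as in the printed module structure. `moe_param_counts` is a helper name introduced only for this illustration, and a real FFN expert would have more parameters.

```python
def moe_param_counts(dim, num_experts, topk):
    """Total vs. per-token activated parameters of the toy MoE layer."""
    expert_params = dim * dim + dim                   # weight + bias of one expert
    gate_params = dim * num_experts + num_experts     # weight + bias of the gate
    total = num_experts * expert_params + gate_params
    activated = topk * expert_params + gate_params    # only top-k experts run per token
    return total, activated

total, activated = moe_param_counts(dim=512, num_experts=8, topk=2)
print(total, activated, round(activated / total, 4))
```

With `dim=512`, 8 experts, and top-2 routing, each token touches roughly a quarter of the layer's parameters, which is the sense in which sMoE grows capacity faster than per-token compute.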
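
As a starting point for discussion item 1, the sketch below shows one possible form of "top-p experts": keep the smallest set of experts whose cumulative gate probability reaches `p`. This is only an illustration of the idea, not a method taken from the papers listed above; `top_p_gate` and the threshold value are assumptions made for this example.

```python
import torch
import torch.nn.functional as F

def top_p_gate(logits, p=0.7):
    """Keep experts in descending probability order until their mass reaches p."""
    probs = F.softmax(logits, dim=-1)
    sorted_p, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    cum = sorted_p.cumsum(dim=-1)
    # Keep an expert if the mass accumulated *before* it is still below p,
    # so the top-1 expert is always kept.
    keep_sorted = (cum - sorted_p) < p
    keep = torch.zeros_like(probs, dtype=torch.bool).scatter(-1, sorted_idx, keep_sorted)
    w = probs * keep
    return w / w.sum(dim=-1, keepdim=True)

logits = torch.tensor([[4.0, 0.0, 0.0, 0.0],    # confident token
                       [1.0, 1.0, 1.0, 1.0]])   # uncertain token
w = top_p_gate(logits, p=0.7)
num_selected = (w > 0).sum(dim=-1)
print(num_selected)
```

A confident token ends up with fewer experts than an uncertain one. The trade-off is that the number of experts per token is no longer fixed, which makes batching and load balancing harder than with a fixed k.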