# sMoE Load Balancing


Training an sMoE raises new problems:

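1. When the gate network selects sparse experts, the experts outside the top-k receive zero gradient. During training the token distribution across experts becomes unbalanced, and the MoE fails to achieve what it was designed for.
2. top-k selection is not differentiable (in the implementation above, torch's `topk` handles this discretely: gradients flow only to the selected values, and the returned indices carry no gradient).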

For problem 1, we can design a load-balancing mechanism that ensures every expert in the sMoE participates sufficiently in training.

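1. Intuitively, balance means every expert processes roughly the same number of tokens. With 20 tokens, top-k = 2, and 8 experts, each expert handles on average `20*2/8=5` tokens.
2. We need statistics that quantify how balanced the sparse experts are.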
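
The arithmetic in item 1 as a tiny helper (`expected_load` is a hypothetical name used only for illustration): each token is sent to `topk` experts, so under perfectly even routing each expert receives `n_tokens * topk / n_experts` assignments.

```python
def expected_load(n_tokens: int, topk: int, n_experts: int) -> float:
    """Average number of token assignments per expert under perfectly even routing."""
    # every token is dispatched to `topk` experts: n_tokens * topk assignments in total
    return n_tokens * topk / n_experts

print(expected_load(20, 2, 8))  # 5.0
```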

During training, a load-balancing loss term can be added to the objective to enforce this balance.

## Sparse Gate

The simplest Sparse Gate that tracks load balance works as follows:

1. Count the number of tokens assigned to each expert.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
torch.manual_seed(42)


In [2]:
class SparseGate(nn.Module):
    def __init__(self, dim=512, num_experts=8, topk=2):
        super().__init__()
        self.dim = dim
        self.num_experts = num_experts
        self.topk = topk
        self.gate = nn.Linear(self.dim, self.num_experts)
        
    def forward(self, x):
        bsz, seq_len, dim = x.shape
        x = x.view(bsz * seq_len, dim)  # flatten to (n_tokens, dim)

        # gate: softmax over experts, keep the top-k, then renormalize the kept weights
        gates = self.gate(x)
        weight = F.softmax(gates, dim=-1)
        v, idx = torch.topk(weight, dim=-1, k=self.topk)
        v = v / v.sum(dim=-1, keepdim=True)

        # counts: number of tokens routed to each expert
        counts = [0] * self.num_experts
        for i in range(self.num_experts):
            counts[i] = len(torch.where(idx == i)[0].tolist())

        return counts, weight, idx, v
dim=512
num_experts=8
k=2
sGate = SparseGate(dim=dim, num_experts=num_experts, topk=k)

In [3]:
bsz=2
seq_len=10
n_tokens=bsz*seq_len

x=torch.randn(bsz, seq_len, dim)

In [4]:
counts, weight, idx, v = sGate(x)
print(counts)

[3, 8, 7, 4, 8, 1, 3, 6]
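
The per-expert loop in `SparseGate.forward` can be replaced by a single `torch.bincount` call. The sketch below is standalone: it builds a random `idx` of shape `(n_tokens, k)` as a stand-in for the gate's output, checks that `bincount` matches the loop, and confirms the counts sum to `n_tokens * k` (which is why the output above sums to 40).

```python
import torch

torch.manual_seed(0)
n_tokens, n_experts, k = 20, 8, 2
# random top-k expert indices, k distinct experts per token (stand-in for the gate's idx)
idx = torch.stack([torch.randperm(n_experts)[:k] for _ in range(n_tokens)])

# vectorized count: how many (token, slot) pairs point at each expert
counts_vec = torch.bincount(idx.flatten(), minlength=n_experts).tolist()
# the loop-based count used in SparseGate.forward
counts_loop = [len(torch.where(idx == i)[0].tolist()) for i in range(n_experts)]

print(counts_vec == counts_loop, sum(counts_vec))  # True 40
```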


In [5]:
def loss_balance_basic(counts):
    c = torch.tensor(counts, dtype=torch.float)
    loss = ((c-c.mean())**2).mean()
    return loss
    
loss = loss_balance_basic(counts)
print(loss)

# A highly unbalanced token distribution across experts gives a much larger loss
dummy_counts = [1,1,1,16,1,1,1,2]
print(loss_balance_basic(dummy_counts))

tensor(6.)
tensor(24.2500)
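
A limitation of `loss_balance_basic` worth making explicit: the counts come from the integer indices returned by `torch.topk`, which carry no gradient, so a count-only loss cannot train the gate by itself. A minimal check, using random logits as a stand-in for the gate output:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(20, 8, requires_grad=True)  # stand-in for gate outputs
probs = F.softmax(logits, dim=-1)
vals, idx = torch.topk(probs, k=2, dim=-1)

# count-based loss, same formula as loss_balance_basic
counts = torch.bincount(idx.flatten(), minlength=8).float()
count_loss = ((counts - counts.mean()) ** 2).mean()

# the top-k values are differentiable; the indices and the count loss are not
print(vals.requires_grad, idx.requires_grad, count_loss.requires_grad)  # True False False
```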


## Load-balance metrics

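The rule above is simple and only looks at token counts. Looking back at the mixture-of-experts output, the final result also depends on the sparse gate weights $g_i$: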

$$
o = \sum_{i=1}^{N} g_i E_i(x)
$$
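
To make $o = \sum_i g_i E_i(x)$ concrete, the sketch below uses hypothetical toy experts (one `nn.Linear` each; architecture and sizes are assumptions for illustration) and checks that evaluating only the top-k selected experts per token gives the same result as the dense sum with $g_i = 0$ for unselected experts.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
dim, n_experts, k, n_tokens = 16, 8, 2, 5
x = torch.randn(n_tokens, dim)
gate = nn.Linear(dim, n_experts)
# hypothetical toy experts: a single linear layer each
experts = nn.ModuleList([nn.Linear(dim, dim) for _ in range(n_experts)])

with torch.no_grad():
    probs = F.softmax(gate(x), dim=-1)
    g, idx = torch.topk(probs, k, dim=-1)
    g = g / g.sum(dim=-1, keepdim=True)  # renormalized top-k weights g_i

    # sparse evaluation: only the k selected experts run for each token
    out = torch.zeros_like(x)
    for t in range(n_tokens):
        for j in range(k):
            out[t] += g[t, j] * experts[idx[t, j].item()](x[t])

    # dense reference: every expert runs, unselected experts get weight 0
    g_dense = torch.zeros(n_tokens, n_experts).scatter(1, idx, g)
    all_out = torch.stack([e(x) for e in experts], dim=1)  # (n_tokens, n_experts, dim)
    out_dense = (g_dense.unsqueeze(-1) * all_out).sum(dim=1)

print(torch.allclose(out, out_dense, atol=1e-5))  # True
```

The sparse loop is what saves compute in an sMoE; the dense form is only a reference for the formula.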

For example, top-2 weights of `(0.9, 0.1)` are less balanced than `(0.45, 0.55)`.

Load balance therefore covers two aspects:

1. The number of tokens processed by each expert is balanced.
2. The gate weight assigned to each expert is balanced.

In [6]:
x=torch.randn(1, 5, dim)
counts, weight, idx, v =sGate(x)

In [7]:
def sparse_to_matrix(idx, weight, n_experts=8):
    N, k = idx.shape
    mat_idx = torch.zeros(N,n_experts)
    mat_weight = torch.zeros(N,n_experts)
    for i in range(N):
        for j in range(k):
            mat_idx[i, idx[i, j]] = 1
            mat_weight[i, idx[i, j]] = weight[i,j]
    return mat_idx, mat_weight

idx_mat, weight_mat = sparse_to_matrix(idx, v)
print(idx_mat.to(torch.long))
print(weight_mat)

tensor([[0, 0, 1, 0, 1, 0, 0, 0],
        [0, 0, 0, 0, 0, 1, 0, 1],
        [1, 0, 1, 0, 0, 0, 0, 0],
        [0, 1, 0, 1, 0, 0, 0, 0],
        [0, 0, 1, 0, 1, 0, 0, 0]])
tensor([[0.0000, 0.0000, 0.4731, 0.0000, 0.5269, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3119, 0.0000, 0.6881],
        [0.5515, 0.0000, 0.4485, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4695, 0.0000, 0.5305, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.4544, 0.0000, 0.5456, 0.0000, 0.0000, 0.0000]],
       grad_fn=<CopySlices>)
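
`sparse_to_matrix` loops over every (token, slot) pair. A vectorized variant can use `Tensor.scatter`. `sparse_to_matrix_vec` is a hypothetical name, and the small `idx`/`weight` tensors are copied from the first three rows of the output above (slot order within a row does not matter here).

```python
import torch

def sparse_to_matrix_vec(idx, weight, n_experts=8):
    """Hypothetical vectorized variant of sparse_to_matrix using scatter."""
    N = idx.shape[0]
    # 1 where token i is routed to expert idx[i, j], else 0
    mat_idx = torch.zeros(N, n_experts).scatter(1, idx, 1.0)
    # same positions, filled with the corresponding top-k weight
    mat_weight = torch.zeros(N, n_experts, dtype=weight.dtype).scatter(1, idx, weight)
    return mat_idx, mat_weight

idx = torch.tensor([[2, 4], [5, 7], [0, 2]])
weight = torch.tensor([[0.4731, 0.5269], [0.3119, 0.6881], [0.5515, 0.4485]])
mat_idx, mat_weight = sparse_to_matrix_vec(idx, weight)
print(mat_idx.sum(dim=1))  # tensor([2., 2., 2.])
```

Out-of-place `scatter` also propagates gradients back to `weight` when it requires grad, so this variant can replace the loop inside a loss.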


According to

> Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer (Shazeer et al., 2017)

the load-balancing loss is:

$$
\text{Importance}(X) = \sum_{x \in X} G(x)
$$

$$
L_{\text{importance}}(X) = w_{\text{importance}} \cdot \text{CV}(\text{Importance}(X))^2
$$

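Definition: the importance of an expert is the sum of its gate values over the batch. The more often an expert is selected, or the larger its weights, the more it contributes to the MoE output. Here $G(x)$ is the gate vector of token $x$, and $w_{importance}$ is a hand-tuned scaling factor.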

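CV is the Coefficient of Variation, a normalized statistic that measures the relative dispersion of data.

CV expresses how large the standard deviation is relative to the mean:

$$
CV = \frac{\sigma}{\mu}
$$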
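
A pure-Python check of the CV values computed in the cells below, taking CV as standard deviation over mean (as `CoefficientVariation` does). `statistics.stdev` uses the sample (n - 1) denominator, which matches the default Bessel correction of `torch.Tensor.std`.

```python
import statistics

def coefficient_of_variation(xs):
    # sample standard deviation (n - 1 denominator) divided by the mean
    return statistics.stdev(xs) / statistics.mean(xs)

print(round(coefficient_of_variation([3, 0.7, 0, 0.1]), 4))  # 1.4749
print(round(coefficient_of_variation([1.1, 1, 1, 0.9]), 4))  # 0.0816
```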

In [8]:
def importance(x):
    return x.sum(dim=0)
    
print(importance(idx_mat))
print(importance(weight_mat))

tensor([1., 1., 3., 1., 2., 1., 0., 1.])
tensor([0.5515, 0.4695, 1.3760, 0.5305, 1.0725, 0.3119, 0.0000, 0.6881],
       grad_fn=<SumBackward1>)


In [9]:
def CoefficientVariation(x):
    return x.std()/x.mean()

imp = importance(weight_mat)
cv = CoefficientVariation(imp)
# print(cv)
print(cv**2) # loss

tensor(0.4737, grad_fn=<PowBackward0>)


In [16]:
imp1 = torch.tensor([3,0.7,0, 0.1])
imp2 = torch.tensor([1.1,1,1, 0.9])
print(CoefficientVariation(imp1))
print(CoefficientVariation(imp2))

tensor(1.4749)
tensor(0.0816)


In [10]:
def load_balance_loss_smoe(idx, weight, n_experts):
    idx_mat, weight_mat = sparse_to_matrix(idx, weight, n_experts)
    imp = importance(weight_mat)
    cv = CoefficientVariation(imp)
    return cv**2

x=torch.randn(1, 2, dim)
_, _, idx, v =sGate(x)
loss = load_balance_loss_smoe(idx, v, num_experts)
print(loss)

x=torch.randn(1, 10, dim) 
_, _, idx, v =sGate(x)
loss = load_balance_loss_smoe(idx, v, num_experts)
print(loss)

tensor(1.2271, grad_fn=<PowBackward0>)
tensor(0.4051, grad_fn=<PowBackward0>)


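Because the gate parameters are randomly initialized, routing is roughly uniform in expectation, so the larger `bsz * seq_len` is, the more evenly tokens tend to spread across experts and the smaller the loss.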
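
The batch-size claim can be checked by averaging the importance-based CV² over many random draws. This sketch assumes random Gaussian logits as an untrained gate; `importance_cv2` is a hypothetical helper that mirrors `load_balance_loss_smoe` with a `scatter`-based importance. The averages should shrink as the token count grows.

```python
import torch
import torch.nn.functional as F

def importance_cv2(n_tokens, n_experts=8, k=2):
    # random logits stand in for an untrained gate
    probs = F.softmax(torch.randn(n_tokens, n_experts), dim=-1)
    v, idx = torch.topk(probs, k, dim=-1)
    v = v / v.sum(dim=-1, keepdim=True)  # renormalized top-k weights
    # importance: per-expert sum of the top-k weights scattered into a dense matrix
    imp = torch.zeros(n_tokens, n_experts).scatter(1, idx, v).sum(dim=0)
    return (imp.std() / imp.mean()) ** 2

torch.manual_seed(0)
avg_loss = {}
for n in (10, 100, 1000):
    # average over 50 random draws to smooth out noise
    avg_loss[n] = torch.stack([importance_cv2(n) for _ in range(50)]).mean().item()
    print(n, round(avg_loss[n], 4))
```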

## Quantitative analysis

In [11]:
N = 100
num_experts = 8


gates = torch.randn(N, num_experts)

def gate_function(gates, k, num_experts):
    weight = F.softmax(gates, dim = -1)
    v, idx = torch.topk(weight, dim = -1, k=k)
    v /= v.sum(dim = -1, keepdim=True)
    return idx, v

gates = torch.randn(N, num_experts)
idx, v = gate_function(gates, k, num_experts)
loss = load_balance_loss_smoe(idx, v, num_experts)
# print(idx)
# print(v)
print(loss)

tensor(0.0965)


In [12]:
gates10 = gates.clone()
gates10[:, 3] = gates[:, 3] * 10 # amplify expert 3's gate logits
idx, v = gate_function(gates10, k, num_experts)
loss = load_balance_loss_smoe(idx, v, num_experts)
print(loss)

gates100 = gates.clone()
gates100[:, 3] = gates[:, 3] * 100 # amplify expert 3's gate logits
idx, v = gate_function(gates100, k, num_experts)
loss = load_balance_loss_smoe(idx, v, num_experts)
print(loss)

tensor(0.7825)
tensor(1.1264)
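
To see why the loss grows, it helps to look at how often expert 3 lands in the top-k set. This sketch draws fresh random logits (so the numbers will differ from the cell above). At scale 1 the share should be near k/N = 0.25; at large scales expert 3 wins for roughly the tokens where its logit is positive and is pushed out elsewhere, so its share should move toward about one half.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, n_experts, k = 1000, 8, 2
base = torch.randn(N, n_experts)

share = {}
for scale in (1, 10, 100):
    g = base.clone()
    g[:, 3] = base[:, 3] * scale  # amplify expert 3's gate logits
    _, idx = torch.topk(F.softmax(g, dim=-1), k, dim=-1)
    # fraction of tokens whose top-k set contains expert 3
    share[scale] = (idx == 3).any(dim=-1).float().mean().item()
    print(scale, round(share[scale], 3))
```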


## Switch Transformer load balance

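The loss most widely used today comes from Switch Transformer (Fedus, Zoph, and Shazeer). Its load balancing accounts for both the number of tokens each expert processes and the gate weight it receives.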

> Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity

$$
\text{loss} = \alpha \cdot N \cdot \sum_{i=1}^{N} f_i \cdot P_i
$$

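where $N$ is the number of experts and $\alpha$ is a small coefficient on this auxiliary loss (the paper uses $\alpha = 10^{-2}$).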

$$
f_i = \frac{1}{T} \sum_{x \in \mathcal{B}} 1\{\text{argmax } p(x) = i\}
$$
    
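where $T$ is the number of tokens, $\mathcal{B}$ is the set of all tokens in the batch, and $\text{argmax}\, p(x) = i$ means that expert $i$ has the largest router probability for token $x$. Switch Transformer routes each token to its top-1 expert; for top-k routing the indicator generalizes to $i \in \text{top-k}\, p(x)$. $f_i$ is the fraction of tokens dispatched to expert $i$.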

$$
P_i = \frac{1}{T} \sum_{x \in \mathcal{B}} p_i(x)
$$

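where $P_i$ is the fraction of router probability allocated to expert $i$, i.e. its accumulated gate weight (importance) averaged over the batch. Note that in the paper $p_i(x)$ is the full softmax router probability; the implementation below uses the renormalized top-k weights `v` instead, so its $P_i$ only approximates the paper's.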
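
A sketch closer to the paper's formulation, assuming top-1 routing, $p_i(x)$ taken from the full softmax (not renormalized top-k weights), and the paper's $\alpha = 10^{-2}$. `switch_load_balance_loss` is a hypothetical name; the `backward()` call only confirms that gradients reach the router logits.

```python
import torch
import torch.nn.functional as F

def switch_load_balance_loss(logits, alpha=1e-2):
    """Hypothetical sketch of the Switch auxiliary loss; logits: (T, N) router logits."""
    T, N = logits.shape
    probs = F.softmax(logits, dim=-1)           # p_i(x): full softmax over experts
    top1 = probs.argmax(dim=-1)                 # argmax p(x)
    f = F.one_hot(top1, N).float().mean(dim=0)  # f_i: fraction of tokens sent to expert i
    P = probs.mean(dim=0)                       # P_i: mean router probability of expert i
    return alpha * N * (f * P).sum()

torch.manual_seed(0)
logits = torch.randn(100, 8, requires_grad=True)
loss = switch_load_balance_loss(logits)
loss.backward()
print(loss.item(), logits.grad.abs().sum().item() > 0)
```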

In [13]:
def load_balance_loss_switch(idx, weight, n_experts):
    N, k = idx.shape
    
    idx_mat, weight_mat = sparse_to_matrix(idx, weight, n_experts)

    fi = idx_mat.mean(dim = 0)
    pi = weight_mat.mean(dim = 0)

    return n_experts * (fi * pi).sum()
    
idx, v = gate_function(gates, k, num_experts)
loss = load_balance_loss_switch(idx, v, num_experts)
print(loss)

tensor(2.1702)
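
Why the value above sits near 2: with the renormalized top-k weights used in `load_balance_loss_switch`, perfectly even routing gives $f_i = k/N$ and $P_i = 1/N$, so $N \sum_i f_i P_i = N \cdot N \cdot \frac{k}{N} \cdot \frac{1}{N} = k$. For top-2 this baseline is 2; imbalance where the experts with large $f_i$ also have large $P_i$ pushes the loss above it. A quick check with a hypothetical helper:

```python
def switch_loss_uniform(n_experts, k):
    # perfectly even routing: each expert gets k/N of the dispatches
    # and 1/N of the (renormalized) gate weight
    f = [k / n_experts] * n_experts
    P = [1 / n_experts] * n_experts
    return n_experts * sum(fi * pi for fi, pi in zip(f, P))

print(switch_loss_uniform(8, 2))  # 2.0
print(switch_loss_uniform(8, 1))  # 1.0
```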


In [14]:
gates10 = gates.clone()
gates10[:, 3] = gates[:, 3] * 10 # amplify expert 3's gate logits
idx, v = gate_function(gates10, k, num_experts)
loss = load_balance_loss_switch(idx, v, num_experts)
print(loss)

gates100 = gates.clone()
gates100[:, 3] = gates[:, 3] * 100 # amplify expert 3's gate logits
idx, v = gate_function(gates100, k, num_experts)
loss = load_balance_loss_switch(idx, v, num_experts)
print(loss)

tensor(2.5607)
tensor(2.7190)


## Discussion

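1. Load balancing is critical for sMoE systems: during training it keeps routing from collapsing onto a single expert.
2. MoE training systems use a parallelism technique called expert parallelism, in which experts are placed on different devices (e.g. each device holds one expert). Load balancing keeps both computation and communication even across devices.
3. The Switch Transformer loss can be read as a weight-importance metric weighted by each expert's dispatch fraction.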

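Two further cases to consider:

1. A transformer has many blocks, each with its own sMoE component. Should the load-balancing loss be computed block-wise or model-wise?
2. The load balance analyzed in this notebook is computed batch-wise. Analyze whether computing it at the sequence level is reasonable for LLMs.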

Question: how is the gradient of load_balance_loss computed during backpropagation?
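
One way to think about the backward pass of the Switch loss: $f_i$ is built from argmax indices, so it is a constant with respect to the router parameters, and gradients flow only through $P_i$, with $\partial\,\text{loss}/\partial P_i = \alpha N f_i$. Experts that already receive many tokens therefore get their router probability pushed down hardest. A minimal check, assuming top-1 routing and random logits:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, N, alpha = 100, 8, 1e-2
logits = torch.randn(T, N, requires_grad=True)
probs = F.softmax(logits, dim=-1)

f = F.one_hot(probs.argmax(dim=-1), N).float().mean(dim=0)  # built from indices: no grad
P = probs.mean(dim=0)
P.retain_grad()  # keep the gradient of this non-leaf tensor for inspection
loss = alpha * N * (f * P).sum()
loss.backward()

print(f.requires_grad)                        # False
print(torch.allclose(P.grad, alpha * N * f))  # True
```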

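## References

Shazeer et al. (2017). Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer. ICLR.

Fedus, Zoph & Shazeer (2022). Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity. JMLR.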
