Differentiating Top-K

[Doesn't the Top-K operation in MoE training make the model non-differentiable (discontinuous)?](https://www.zhihu.com/question/11071292653/answer/1913934460161852591)

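Conclusion:

- Mathematically, top-k is not differentiable: which elements are selected changes discontinuously as the scores cross one another.
- PyTorch's top-k is differentiable in the autograd sense. Its backward pass copies the upstream gradient to the selected positions and leaves zeros everywhere else, much like the backward pass of an embedding layer.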
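The backward behaviour described above can be checked directly. The sketch below uses random scores and an arbitrary upstream gradient, both made up for illustration. It compares `torch.topk`'s gradient with a manual `scatter` of the upstream gradient into the selected positions. It then shows the analogous sparsity in the weight gradient of `nn.Embedding`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Top-k: 2 tokens, 5 candidate scores each (illustrative random values)
scores = torch.randn(2, 5, requires_grad=True)
vals, idx = torch.topk(scores, k=2, dim=-1)

upstream = torch.tensor([[1.0, 2.0], [3.0, 4.0]])  # arbitrary dL/dvals
vals.backward(upstream)

# Manual top-k backward: scatter the upstream gradient into the selected
# positions; every unselected position gets exactly zero
manual = torch.zeros_like(scores).scatter(-1, idx, upstream)
print(torch.equal(scores.grad, manual))

# Embedding: only the looked-up rows of the weight receive a gradient
emb = nn.Embedding(5, 3)
emb(torch.tensor([1, 3])).sum().backward()
print(emb.weight.grad)  # rows 1 and 3 are ones, all other rows are zeros
```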

Appendix A of the sMoE paper replaces the hard top-k selection with a continuous probability, which the paper uses for its load-balancing loss:

> Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer

Mainstream sMoE implementations (e.g. Mixtral 8x7B) simply use PyTorch's built-in top-k operator.
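For context, the sketch below shows what a router built on `torch.topk` might look like. It assumes the softmax → `torch.topk` → renormalize pattern found in common Mixtral-style implementations. The sizes and the `route` helper are hypothetical and not taken from any specific codebase. Gradients reach the gate weights through the selected values, so top-k needs no special handling.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

hidden, num_experts, top_k = 16, 8, 2          # hypothetical sizes
gate = nn.Linear(hidden, num_experts, bias=False)

def route(x: torch.Tensor):
    """Hypothetical router: softmax over all experts, keep top-k, renormalize."""
    logits = gate(x)                                     # (tokens, num_experts)
    probs = F.softmax(logits, dim=-1)
    weights, experts = torch.topk(probs, top_k, dim=-1)  # (tokens, top_k) each
    weights = weights / weights.sum(dim=-1, keepdim=True)
    return weights, experts

x = torch.randn(4, hidden)
weights, experts = route(x)

# The renormalized weights of each token sum to 1, so a plain weights.sum()
# would have zero gradient; use a loss that depends on how the weight is split
loss = (weights[:, 0] ** 2).sum()
loss.backward()
print(experts)
print(gate.weight.grad.abs().sum() > 0)
```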

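Be careful when hand-rolling the top-element selection with `torch.max`. Depending on how the result is used (for example, building the gate from the returned indices alone), the gradient may not propagate back to the gate.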
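The sketch below illustrates the distinction; the tensor values are made up. The values returned by `torch.max(..., dim=...)` are differentiable. Its indices, however, are integers, so a hard gate built only from them (here a one-hot of the argmax) is a constant with no path back to `p`.

```python
import torch
import torch.nn.functional as F

p = torch.tensor([[0.1, 0.7, 0.2]], requires_grad=True)

# torch.max along a dim returns (values, indices); the values stay on the graph
top_val, top_idx = torch.max(p, dim=-1)
top_val.sum().backward()
print(p.grad)  # tensor([[0., 1., 0.]]): gradient lands on the max position only

# A hard gate built from the indices alone is a constant: no grad_fn,
# so any loss computed only from it cannot reach p
hard_gate = F.one_hot(top_idx, num_classes=p.shape[-1]).float()
print(hard_gate.requires_grad, hard_gate.grad_fn)  # False None
```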

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(42)

# --- PyTorch autograd ---
dim = 16
experts_num = 8
top_k = 2
bs = 2
label = torch.randn(bs, top_k)
a = torch.randn(bs, dim, requires_grad=True)   # token representations
w_gate = nn.Linear(dim, experts_num)           # gate (router)
y = w_gate(a)                                  # gate logits
y.retain_grad()
p = F.softmax(y, dim=-1)                       # routing probabilities
p.retain_grad()
v, idx = torch.topk(p, k=top_k, dim=-1)        # selected probabilities and expert ids
print(f'topk idx: {idx}')
loss = ((v - label) ** 2) / label.numel()      # mean squared error over the top-k values
loss = loss.sum()
loss.backward()
print('pytorch top-k p grad', p.grad)
print('pytorch input_gradient: ', a.grad[0, :4])
print('pytorch w_gate: ', w_gate.weight.grad.t()[0, :4])

print('-' * 50)
# --- Hand-written backward ---
dv = 2 * (v - label) / label.numel()           # dloss/dv
dp = torch.zeros(bs, experts_num)
# Top-k backward: only the selected positions receive a gradient
for i in range(bs):
    dp[i, idx[i]] = dv[i, :]
print('hand-write top-k p grad', dp)

# Softmax backward: the Jacobian diag(p) - p p^T must use the full
# distribution p[i, :], not only its top-k entries
dy = torch.zeros_like(y)
for i in range(bs):
    tmp_p = p[i, :]
    d_s = torch.diag(tmp_p) - torch.outer(tmp_p, tmp_p)
    dy[i, :] = dp[i, :] @ d_s
print(dy)  # hand-written gradient w.r.t. the gate logits y

# Linear layer backward (y = a @ W^T + b)
d_a = dy @ w_gate.weight
d_w_gate = a.t() @ dy                          # (dL/dW)^T, shape (dim, experts_num)
print('hand-write input_gradient: ', d_a[0, :4])
print('hand-write w_gate: ', d_w_gate[0, :4])
```

```
topk idx: tensor([[2, 1],
        [5, 3]])
pytorch top-k p grad tensor([[ 0.0000,  0.0044,  0.0815,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000, -0.0345,  0.0000, -0.0008,  0.0000,  0.0000]])
pytorch input_gradient:  tensor([ 0.0054, -0.0018, -0.0061, -0.0017])
pytorch w_gate:  tensor([-0.0020, -0.0089,  0.0362, -0.0038])
--------------------------------------------------
hand-write top-k p grad tensor([[ 0.0000,  0.0044,  0.0815,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000, -0.0345,  0.0000, -0.0008,  0.0000,  0.0000]],
       grad_fn=<CopySlices>)
tensor([[-0.0012, -0.0051,  0.0201, -0.0012, -0.0034, -0.0030, -0.0043, -0.0019],
        [ 0.0005,  0.0006,  0.0008, -0.0046,  0.0007,  0.0011,  0.0003,  0.0006]],
       grad_fn=<CopySlices>)
hand-write input_gradient:  tensor([ 0.0054, -0.0018, -0.0061, -0.0017], grad_fn=<SliceBackward0>)
hand-write w_gate:  tensor([-0.0020, -0.0089,  0.0362, -0.0038], grad_fn=<Sli
```
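The hand-written loop above materializes the full softmax Jacobian $\mathrm{diag}(p) - pp^\top$ for every row. Since $dp\,(\mathrm{diag}(p) - pp^\top) = p \odot (dp - \langle dp, p\rangle)$, the same gradient can be computed without building an $E \times E$ matrix. The sketch below checks both forms against autograd. It uses fresh random inputs and a random stand-in upstream gradient, not the values printed above.

```python
import torch

torch.manual_seed(0)
bs, experts_num, top_k = 2, 8, 2

y = torch.randn(bs, experts_num, requires_grad=True)   # gate logits
p = torch.softmax(y, dim=-1)
v, idx = torch.topk(p, k=top_k, dim=-1)
dv = torch.randn(bs, top_k)                            # stand-in upstream gradient
v.backward(dv)

# Top-k backward: scatter dv into the selected columns
dp = torch.zeros(bs, experts_num).scatter(-1, idx, dv)

with torch.no_grad():
    # Loop form with the explicit softmax Jacobian diag(p) - p p^T
    dy_loop = torch.stack(
        [dp[i] @ (torch.diag(p[i]) - torch.outer(p[i], p[i])) for i in range(bs)]
    )
    # Vectorized form: dy = p * (dp - <dp, p>), no (E x E) matrix needed
    dy_vec = p * (dp - (dp * p).sum(dim=-1, keepdim=True))

print(torch.allclose(dy_loop, y.grad, atol=1e-6), torch.allclose(dy_vec, y.grad, atol=1e-6))
```

A side effect worth noting is that each row of the softmax gradient sums to zero. So even though top-k only writes to $k$ entries of $dp$, every logit $y_j$ receives some gradient.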

## Appendix 

> Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer

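For a single token $x$, let the noisy gate scores be $H(x)_i = (xW_g)_i + \epsilon_i \cdot \mathrm{Softplus}\big((xW_{noise})_i\big)$ with $\epsilon_i \sim \mathcal{N}(0, 1)$. For each expert $i$, the quantile is expert $i$'s clean gate score minus the $k$-th largest noisy score among the *other* experts (i.e. with expert $i$ removed), scaled by expert $i$'s noise standard deviation:

$$z_i = \frac{(xW_g)_i - \mathrm{kth\_excluding}(H(x), k, i)}{\mathrm{Softplus}\big((xW_{noise})_i\big)}$$

The probability is then read off the standard normal CDF, $P(x, i) = \Phi(z_i)$, rather than obtained through top-k + softmax. $P(x, i)$ is the probability that expert $i$ stays in the top-k when only its own noise is resampled, and it is smooth in the gate parameters.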
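The sketch below computes $P(x, i)$ in PyTorch. It assumes random stand-ins for the clean logits $(xW_g)$ and the noise scale $\mathrm{Softplus}(xW_{noise})$, and all variable names are hypothetical. To get $\mathrm{kth\_excluding}$, it uses one convenient approach: take the top $k+1$ noisy scores once. If expert $i$ is itself in the top-k, removing it promotes the $(k+1)$-th value; otherwise the $k$-th value is unchanged.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Normal

torch.manual_seed(0)
tokens, num_experts, k = 3, 6, 2

clean = torch.randn(tokens, num_experts, requires_grad=True)  # stand-in for (x W_g)_i
noise_std = F.softplus(torch.randn(tokens, num_experts))      # stand-in for Softplus((x W_noise)_i)
noisy = clean + torch.randn(tokens, num_experts) * noise_std  # H(x)

# kth_excluding(H(x), k, i) via the top-(k+1) noisy scores
top_vals, _ = torch.topk(noisy, k + 1, dim=-1)
kth = top_vals[:, k - 1 : k]         # k-th largest overall, shape (tokens, 1)
kth_plus_1 = top_vals[:, k : k + 1]  # (k+1)-th largest overall
in_top_k = noisy >= kth
threshold = torch.where(in_top_k, kth_plus_1, kth)

z = (clean - threshold) / noise_std
prob = Normal(0.0, 1.0).cdf(z)       # P(x, i) = Phi(z_i), smooth in the clean logits

load = prob.sum(dim=0)               # smooth per-expert load estimate
load.sum().backward()                # gradients reach the clean logits
print(prob)
```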