# DeepSeek-V3 

V3 各组件：

1. `lecture/lc5_deepseek_v3/DeepSeek-MoE.ipynb`
2. `lecture/lc5_deepseek_v3/Load_Balance.ipynb`
3. `lecture/lc5_deepseek_v3/Multi_Latent_Attention.ipynb`
4. `lecture/lc5_deepseek_v3/Multi_Token_Prediction.ipynb`
5. `lecture/lc5_deepseek_v3/YaRN.ipynb`

前置理解 Notebook：

1. `lecture/lc5_deepseek_v3/Mixture-of-Experts.ipynb`
2. `lecture/lc5_deepseek_v3/top-k_backward.ipynb`

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
torch.manual_seed(42)

<torch._C.Generator at 0x114664db0>

# config

In [2]:
from dataclasses import dataclass

@dataclass
class DeepSeekV3Config:
    learning_rate: float = 0.001
    vocab_size: int = 200
    dim: int = 512
    n_heads: int = 8
    head_dim: int = dim//n_heads
    num_layers: int = 12
    pad_token_id: int = 0
    attention_bias: bool = False # without bias

    # MoE
    expert_nums: int=20
    shared_expert_nums: int=4
    top_k: int = 4

    # MLA
    dc_kv: int = 32
    dc_q: int = 32

    # YaRN
    position_encoding_base: float = 10000.0
    base_scale: float = 10.0
    yarn_alpha: int = 1
    yarn_beta: int = 32
    max_len: int = 512

    # MTP
    num_mtp: int = 5

    # loss

config = DeepSeekV3Config()
print(config)

DeepSeekV3Config(learning_rate=0.001, vocab_size=200, dim=512, n_heads=8, head_dim=64, num_layers=12, pad_token_id=0, attention_bias=False, expert_nums=20, shared_expert_nums=4, top_k=4, dc_kv=32, dc_q=32, position_encoding_base=10000.0, base_scale=10.0, yarn_alpha=1, yarn_beta=32, max_len=512, num_mtp=5)


## RMSNorm

In [3]:
class RMSNorm(nn.Module):
    def __init__(self, dim, epsilon = 0.0001):
        super(RMSNorm, self).__init__()
        self.dim = dim 
        self.epsilon = epsilon
        self.gamma = torch.nn.Parameter(torch.ones(self.dim))
                                        
    def forward(self, x):
        RMS = torch.mean(x ** 2.0 , dim = -1, keepdim = True) 
        x_hat = x / torch.sqrt( RMS + self.epsilon )
        x_out = x_hat * self.gamma 
        return x_out

## YaRN

In [4]:
def _apply_rotary_emb(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> torch.Tensor:
    _, _, seq_len, _ = x.size()
    cos = cos[:seq_len,:]
    sin = sin[:seq_len,:]

    cos = cos.unsqueeze(0).unsqueeze(0)
    sin = sin.unsqueeze(0).unsqueeze(0)

    x1, x2 = torch.chunk(x, 2, dim=-1)
    o1 = x1 * cos - x2 * sin
    o2 = x2 * cos + x1 * sin
    return torch.cat((o1, o2), dim=-1)
    
class YaRN(torch.nn.Module):
    def __init__(
        self,
        config,
    ) -> None:
        super().__init__()
        self.head_dim = config.head_dim
        self.base = config.position_encoding_base
        self.initial_context_length = config.max_len
        self.scaling_factor = config.base_scale
        self.ntk_alpha = config.yarn_alpha
        self.ntk_beta = config.yarn_beta

        cos, sin = self._compute_cos_sin(self.initial_context_length)
        self.register_buffer('cos', cos)
        self.register_buffer('sin', sin)

    def _compute_concentration_and_inv_freq(self) -> torch.Tensor:
        """See YaRN paper: https://arxiv.org/abs/2309.00071"""
        freq = self.base ** (
            torch.arange(0, self.head_dim, 2, dtype=torch.float)
            / self.head_dim
        )        
        # self.scaling_factor = max( 1, self.cur_context_length  / self.initial_context_length)
        if self.scaling_factor > 1.0:
            concentration = (
                0.1 * math.log(self.scaling_factor) + 1.0
            )  
            print('t:',  1/concentration**2)

            d_half = self.head_dim / 2
            low = (
                d_half
                * math.log(self.initial_context_length / (self.ntk_beta * 2 * math.pi))
                / math.log(self.base)
            )
            high = (
                d_half
                * math.log(self.initial_context_length / (self.ntk_alpha * 2 * math.pi))
                / math.log(self.base)
            )
            assert 0 < low < high < d_half - 1

            interpolation = 1.0 / (self.scaling_factor * freq)
            extrapolation = 1.0 / freq

            ramp = (
                torch.arange(d_half, dtype=torch.float32, ) - low
            ) / (high - low)
            mask = 1 - ramp.clamp(0, 1)
            inv_freq = interpolation * (1 - mask) + extrapolation * mask
        else:
            concentration = 1.0
            inv_freq = 1.0 / freq

        return concentration, inv_freq

    def _compute_cos_sin(self, num_tokens: int):
        concentration, inv_freq = self._compute_concentration_and_inv_freq()
        t = torch.arange(num_tokens, dtype=torch.float32)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        cos = freqs.cos() * concentration
        sin = freqs.sin() * concentration
        return cos, sin

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        query = _apply_rotary_emb(query, self.cos, self.sin)
        key = _apply_rotary_emb(key, self.cos, self.sin)
        return query, key

bsz = 2
seq_len = 16

Q = torch.randn(bsz, config.n_heads, seq_len, config.head_dim)
K = torch.randn(bsz, config.n_heads, seq_len, config.head_dim)

yarn = YaRN(config)
q_rope, k_rope = yarn.forward(Q, K)
_apply_rotary_emb(Q, yarn.cos, yarn.sin).shape

t: 0.6607044696629862
YaRN Re-Scale: 1.2302585092994045


torch.Size([2, 8, 16, 64])

## MLA

In [72]:
class MLA(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_heads = config.n_heads
        self.dim = config.dim
        self.head_dim = config.head_dim
        self.dc_kv = config.dc_kv
        self.dc_q = config.dc_q

        # Q
        self.wq_down = nn.Linear(self.dim, self.dc_q, bias=False,)
        self.wq_up = nn.Linear(self.dc_q, self.dim , bias=False,)

        # 单个 C 映射到 K、V
        self.wkv_down = nn.Linear(self.dim, self.dc_kv, bias=False,)
        self.wk_up = nn.Linear(self.dc_kv, self.dim, bias=False,)
        self.wv_up = nn.Linear(self.dc_kv, self.dim, bias=False,)
        
        self.wo = nn.Linear(self.dim, self.dim, bias=False,)

        # RoPE Weight
        # K 每头一样， Q每头不一样
        self.wq_up_rope = nn.Linear(self.dc_q, self.dim, bias=False,)
        self.wk_head_rope = nn.Linear(self.dim, self.head_dim , bias=False,)
        
    def forward(self, X, sin, cos, mask = None):
        bsz, seq_len, _ = X.shape
        C_Q = self.wq_down(X)
        Q = self.wq_up(C_Q)
        C_KV = self.wkv_down(X)
        K = self.wk_up(C_KV)
        V = self.wv_up(C_KV)

        
        R_Q = self.wq_up_rope(C_Q) #多头
        R_K = self.wk_head_rope(X) #单头
        
        R_Q = R_Q.view(bsz, seq_len, self.n_heads, self.head_dim).transpose(1,2)
        R_K = R_K.unsqueeze(dim = 1).repeat( repeats = [1, self.n_heads, 1, 1])
        RoPE_Q_head = _apply_rotary_emb(R_Q, sin, cos)
        RoPE_K_head = _apply_rotary_emb(R_K, sin, cos)
        
        # 与传统多头注意力无差别
        Q = Q.view(bsz, seq_len, self.n_heads, self.head_dim)
        K = K.view(bsz, seq_len, self.n_heads, self.head_dim)
        V = V.view(bsz, seq_len, self.n_heads, self.head_dim)
        Q, K, V = Q.transpose(1,2), K.transpose(1,2), V.transpose(1,2)
        

        Q = torch.cat((Q, RoPE_Q_head), dim = -1)
        K = torch.cat((K, RoPE_K_head), dim = -1)

        S = Q @ K.transpose(2, 3) / math.sqrt(2 * self.head_dim) # cat后dim维度变化，底数也有变化
        P = F.softmax(S.float(), dim=-1)
        Z = P @ V 
        Z = Z.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
        O = self.wo(Z)
        return O

mla = MLA(config)
print(mla)

MLA(
  (wq_down): Linear(in_features=512, out_features=32, bias=False)
  (wq_up): Linear(in_features=32, out_features=512, bias=False)
  (wkv_down): Linear(in_features=512, out_features=32, bias=False)
  (wk_up): Linear(in_features=32, out_features=512, bias=False)
  (wv_up): Linear(in_features=32, out_features=512, bias=False)
  (wo): Linear(in_features=512, out_features=512, bias=False)
  (wq_up_rope): Linear(in_features=32, out_features=512, bias=False)
  (wk_head_rope): Linear(in_features=512, out_features=64, bias=False)
)


In [73]:
X = torch.randn(bsz, seq_len, config.dim)
X_ = mla(X, mask = None, sin = yarn.sin, cos = yarn.cos)
print(X_.shape)

torch.Size([2, 16, 512])


## MoE

In [74]:
class Expert(nn.Module): # expert
    def __init__(self, dim):
        super().__init__()
        self.dim_in = dim
        self.dim_out = self.dim_in * 8 // 3
        self.w1 = nn.Linear(self.dim_in, self.dim_out , bias = False)
        self.w_act = nn.Linear(self.dim_in, self.dim_out, bias = False) 
        self.w2 = nn.Linear(self.dim_out, self.dim_in, bias = False)  
        self.SiLU = nn.SiLU()
    
    def forward(self, x):
        h = self.w1(x)
        h_act = self.w_act(x)
        h_act_up = self.SiLU(h_act) * h
        output = self.w2(h_act_up)
        return output

In [75]:
class DeepSeekV3MoE(nn.Module):
    def __init__(self, config):
        super().__init__()

        # Route Experts: 
        self.expert_nums = config.expert_nums
        self.k = config.top_k
        self.dim = config.dim
        self.experts = nn.ModuleList([ Expert(self.dim) for _ in range(self.expert_nums)])
        self.w_gate = nn.Linear(self.dim, self.expert_nums)

        # Auxiliary-Loss-Free Load Balancing
        self.bias = torch.nn.Parameter(torch.zeros( self.expert_nums )) 
        self.alpha = 0.001
        
        # Shared Experts: 
        self.shared_expert_nums = config.shared_expert_nums
        self.shared_experts = nn.ModuleList([ Expert(self.dim) for _ in range(self.shared_expert_nums)])

    def forward(self, x):
        y_route, weight, idx = self.forward_route(x) 
        y_shared = self.forward_shared(x) 
        y = x + y_route + y_shared
        # load_loss = self.load_balance_sequence_wise(weight, idx)
        # load_loss = 1.0
        return y, weight, idx
    
    def forward_route(self, x):
        # gate 处理
        g = self.w_gate(x)
        g = F.sigmoid(g) # sigmoid 代替 softmax
        weight, idx = torch.topk(g, k = self.k, dim = -1) 
        weight_norm = weight / (weight.sum(dim=-1, keepdim= True) + 1e-20)

        # dispatch
        expert_results = [None] * self.expert_nums
        for i in range(self.expert_nums):
            cur_pos = torch.where(idx == i) 
            x_select = x[cur_pos[0], cur_pos[1], :] 
            if x_select.shape[0] > 0: 
                expert_results[i] = self.experts[i](x_select)
        
        # combine
        y_result = torch.zeros_like(x) 
        for i in range(self.expert_nums):
            cur_pos = torch.where(idx == i) 
            if expert_results[i] != None:
                y_result[cur_pos[0], cur_pos[1], :] += expert_results[i] * weight_norm[cur_pos[0], cur_pos[1], cur_pos[2]].unsqueeze(-1)
        
        return y_result, g, idx 

    def forward_shared(self, x):
        y = torch.zeros_like(x)
        for i in range(self.shared_expert_nums):
            y += self.shared_experts[i](x) # not gate
        return y

# model
model = DeepSeekV3MoE(config)
print(model)

DeepSeekV3MoE(
  (experts): ModuleList(
    (0-19): 20 x Expert(
      (w1): Linear(in_features=512, out_features=1365, bias=False)
      (w_act): Linear(in_features=512, out_features=1365, bias=False)
      (w2): Linear(in_features=1365, out_features=512, bias=False)
      (SiLU): SiLU()
    )
  )
  (w_gate): Linear(in_features=512, out_features=20, bias=True)
  (shared_experts): ModuleList(
    (0-3): 4 x Expert(
      (w1): Linear(in_features=512, out_features=1365, bias=False)
      (w_act): Linear(in_features=512, out_features=1365, bias=False)
      (w2): Linear(in_features=1365, out_features=512, bias=False)
      (SiLU): SiLU()
    )
  )
)


In [76]:
Y, _, _ = model(X)
print(Y.shape)

torch.Size([2, 16, 512])


## DeepSeek-V3 Decode Block

In [77]:
class DeepSeekV3Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.Norm1 = RMSNorm(config.dim)
        self.Norm2 = RMSNorm(config.dim)
        self.MLA = MLA(config)
        self.MoE = DeepSeekV3MoE(config)

    def forward(self, X, mask=None, sin=None, cos=None):
        X = X + self.MLA(self.Norm1(X), sin, cos)
        X_moe, weight, idx = self.MoE(self.Norm2(X))
        X = X + X_moe
        return X, weight, idx

block = DeepSeekV3Block(config)
print(block)

DeepSeekV3Block(
  (Norm1): RMSNorm()
  (Norm2): RMSNorm()
  (MLA): MLA(
    (wq_down): Linear(in_features=512, out_features=32, bias=False)
    (wq_up): Linear(in_features=32, out_features=512, bias=False)
    (wkv_down): Linear(in_features=512, out_features=32, bias=False)
    (wk_up): Linear(in_features=32, out_features=512, bias=False)
    (wv_up): Linear(in_features=32, out_features=512, bias=False)
    (wo): Linear(in_features=512, out_features=512, bias=False)
    (wq_up_rope): Linear(in_features=32, out_features=512, bias=False)
    (wk_head_rope): Linear(in_features=512, out_features=64, bias=False)
  )
  (MoE): DeepSeekV3MoE(
    (experts): ModuleList(
      (0-19): 20 x Expert(
        (w1): Linear(in_features=512, out_features=1365, bias=False)
        (w_act): Linear(in_features=512, out_features=1365, bias=False)
        (w2): Linear(in_features=1365, out_features=512, bias=False)
        (SiLU): SiLU()
      )
    )
    (w_gate): Linear(in_features=512, out_features=20, 

In [78]:
Y,_,_ = block(X, sin=yarn.sin, cos=yarn.cos)
Y.shape

torch.Size([2, 16, 512])

## DeepSeek-V3

In [79]:
class DeepSeekV3Model(nn.Module):
    def __init__(self, config):
        '''
        lm_head 为主体输出头
        num_mtp 额外增加 MTP 输出头模块数量
        总输出头模块：1+num_mtp
        '''
        super().__init__()
        self.dim = config.dim
        self.vocab_size = config.vocab_size
        
        self.embd = nn.Embedding(self.vocab_size, self.dim)

        self.yarn = YaRN(config)

        self.decoder = nn.ModuleList(
            [ DeepSeekV3Block(config) for _ in range(config.num_layers) ]
        )
        # self.mtp_body =  MTPModule(config)
        
        self.lm_head = nn.Linear(self.dim, self.vocab_size)
        

    def forward(self, x, mask):
        X = self.embd(x)  
        moe_weight_list = []
        moe_idx_list = []
        
        for decoder_block in self.decoder:
            X, weight, idx = decoder_block(X, mask, yarn.sin, yarn.cos)
            moe_weight_list.append(weight)
            moe_idx_list.append(idx)
            
        logits = self.lm_head(X)
        return logits, moe_weight_list, moe_idx_list
        
model = DeepSeekV3Model(config)
print(model)

t: 0.6607044696629862
YaRN Re-Scale: 1.2302585092994045
DeepSeekV3Model(
  (embd): Embedding(200, 512)
  (yarn): YaRN()
  (decoder): ModuleList(
    (0-11): 12 x DeepSeekV3Block(
      (Norm1): RMSNorm()
      (Norm2): RMSNorm()
      (MLA): MLA(
        (wq_down): Linear(in_features=512, out_features=32, bias=False)
        (wq_up): Linear(in_features=32, out_features=512, bias=False)
        (wkv_down): Linear(in_features=512, out_features=32, bias=False)
        (wk_up): Linear(in_features=32, out_features=512, bias=False)
        (wv_up): Linear(in_features=32, out_features=512, bias=False)
        (wo): Linear(in_features=512, out_features=512, bias=False)
        (wq_up_rope): Linear(in_features=32, out_features=512, bias=False)
        (wk_head_rope): Linear(in_features=512, out_features=64, bias=False)
      )
      (MoE): DeepSeekV3MoE(
        (experts): ModuleList(
          (0-19): 20 x Expert(
            (w1): Linear(in_features=512, out_features=1365, bias=False)
       

In [80]:
input_ids = torch.randint(config.vocab_size, (bsz, seq_len))
logits, _, _ = model(input_ids, mask = None)
print(input_ids.shape)
print(logits.shape)

torch.Size([2, 16])
torch.Size([2, 16, 200])


## MTP

1. 如果考虑 位置编码、mask 的话， MTP 实现稍微麻烦些。
2. 当前以一个 Linear 来替换，便于实现。后续优化

In [94]:
class MTPModule(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dim = config.dim
        self.RMSNorm_pre = RMSNorm(self.dim)
        self.RMSNorm_cur = RMSNorm(self.dim)
        self.Proj = nn.Linear(self.dim*2, self.dim)
        
        # TODO: use decode block
        # self.Transformer_block = DeepSeekV3Block(config)
        self.Transformer_block = nn.Linear(self.dim, self.dim)
        
    def forward(self, X_embd, H_pre,):
        X_embd, H_pre = self.RMSNorm_cur(X_embd), self.RMSNorm_pre(H_pre)
        X = torch.cat((X_embd, H_pre), dim = -1)
        X = self.Proj(X)
        X = self.Transformer_block(X) # 需要传参 mask, rope
        return X

class DeepSeekV3ModelMTP(nn.Module):
    def __init__(self, config):
        '''
        lm_head 为主体输出头
        num_mtp 额外增加 MTP 输出头模块数量
        总输出头模块：1+num_mtp
        '''
        super().__init__()
        self.dim = config.dim
        self.vocab_size = config.vocab_size
        self.num_mtp = config.num_mtp
        self.embd = nn.Embedding(self.vocab_size, self.dim)
        self.yarn = YaRN(config)
        self.decoder = nn.ModuleList(
            [ DeepSeekV3Block(config) for _ in range(config.num_layers) ]
        )
        self.lm_head = nn.Linear(self.dim, self.vocab_size)

        # 新增 MTP 头 
        self.mtp_heads = nn.ModuleList(
            [ MTPModule(config) for _ in range(self.num_mtp) ]
        )

    def forward(self, x, mask=None):
        # main: input
        bsz, seq_len = x.shape
        X = self.embd(x)
        X = X[:, :-self.num_mtp, :]

        # main: decoder
        moe_weight_list = []
        moe_idx_list = []
        for decoder_block in self.decoder:
            X, weight, idx = decoder_block(X,
                                           mask, 
                                           yarn.sin, 
                                           yarn.cos)
            moe_weight_list.append(weight)
            moe_idx_list.append(idx)
        # main: output
        lm_logits = self.lm_head(X) # bsz, len, vocab_size
        hidden_states = X.clone()

        # MTP: 
        X = self.embd(x)
        mtp_logits = torch.zeros(self.num_mtp, 
                                 bsz, 
                                 seq_len - self.num_mtp, 
                                 self.vocab_size)
        # MTP Recurrent
        for i in range(self.num_mtp):
            # MTP: input
            X_cur = X[:, i+1: i+1+seq_len-self.num_mtp, :]

            # MTP: body
            hidden_states_i = self.mtp_heads[i](X_cur, hidden_states.detach())

            # MTP: head
            mtp_logits_i = self.lm_head(hidden_states_i)

            # MTP: update
            mtp_logits[i] = mtp_logits_i
            hidden_states = hidden_states_i

        return lm_logits, moe_weight_list, moe_idx_list, mtp_logits
        
model = DeepSeekV3ModelMTP(config)
# print(model)

t: 0.6607044696629862
YaRN Re-Scale: 1.2302585092994045


In [95]:
print(input_ids.shape)
lm_logits, moe_weight_list, moe_idx_list, mtp_logits = model(input_ids)
print(input_ids.shape)
print(lm_logits.shape)
print(mtp_logits.shape)

torch.Size([2, 16])
torch.Size([2, 16])
torch.Size([2, 11, 200])
torch.Size([5, 2, 11, 200])


## Load Balance loss

In [98]:
def load_balance_sequence_wise(config, s, idx):
    Nr = config.expert_nums # n routes expert
    bs, seq_len, dim = s.shape # seq : pad token 要去除，或者增加 mask, 避免计入 loss 里
    l_lab = torch.zeros(1)
    for k in range(bs):

        # Compute fi
        fi = torch.zeros(Nr)
        seq_expert_count = torch.zeros(Nr)
        idx_seq = idx[k,:,:]
        for i in range(config.expert_nums):
            seq_expert_count[i] = torch.where(idx_seq == i)[1].numel()
        fi = Nr / (config.top_k * seq_len) * seq_expert_count

        # Compute pi
        s_seq = s[k, :, :]
        si_ = s / s.sum(dim = -1, keepdim = True)
        pi = si_.sum(dim = 0) / seq_len
        l_bal_seq = (fi * pi).sum() / seq_len # seq_len_no_pad or use mask
        l_lab += l_bal_seq
        
    l_lab = 0.001 * l_lab
    return l_lab

load_balance_loss = load_balance_sequence_wise(config, moe_weight_list[0], moe_idx_list[0])
print(load_balance_loss)

tensor([0.0004], grad_fn=<MulBackward0>)


In [101]:
IGNORE_INDEX = -100
def DeepSeekV3LMLoss(lm_logits,
                     mtp_logits, 
                     y,
                     lam=0.1):

    N, bsz, seq_len, vocab_size = mtp_logits.shape
    loss_fn = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)
    loss_lm_head = loss_fn(lm_logits.view(bsz*seq_len,vocab_size), 
                           y[:,:-N].reshape(bsz*seq_len))

    if mtp_logits is None:
        loss_mtp_head = None
    else:
        loss_mtp_head = torch.zeros(N)
        for i in range(N):
            loss_mtp_head[i] = loss_fn(mtp_logits[i,:,:,:].view(bsz*seq_len,vocab_size),
                                       y[:,i:-N+i].reshape(bsz*seq_len))
                                
    loss = loss_lm_head + lam*loss_mtp_head.mean()
    return loss, loss_lm_head, loss_mtp_head

labels = torch.zeros_like(input_ids)
labels[:, :-1] = input_ids[:, 1:]
labels[:, -1] = IGNORE_INDEX

loss, loss_lm, loss_mtp = DeepSeekV3LMLoss(lm_logits, 
                     mtp_logits, 
                     y=labels,
                     lam = 0.1)
print(loss)
print(loss_lm)
print(loss_mtp)                                        

tensor(23.5676, grad_fn=<AddBackward0>)
tensor(23.0371, grad_fn=<NllLossBackward0>)
tensor([5.3156, 5.3155, 5.2962, 5.3688, 5.2321], grad_fn=<CopySlices>)


## Training Loss

In [103]:
def DeepSeekV3Loss(
    config,
    lm_logits,
     mtp_logits, 
     y,
     lam=0.1,
    weight_list=None,
    idx_list=None,
    lam_len=0.1,
):
    loss, loss_lm, loss_mtp = DeepSeekV3LMLoss(lm_logits, mtp_logits, y=labels, lam = 0.1)
    loss_load = torch.zeros(config.num_layers)
    i = 0
    for s, idx in zip(weight_list, idx_list):
        loss_load[i] = load_balance_sequence_wise(config, s, idx)
        i = i+1
    loss_total = loss + lam_len * loss_load.mean()
    return loss_total

## Train

## MLA Absorb

## Inference