# 【手撕LLM-Mixtral】

![mistral](./mistral.png)

In [3]:
# Point the Hugging Face Hub client at a mirror endpoint.
# `!export VAR=...` only sets the variable inside a short-lived subshell, so
# the kernel process never sees it; write it into this process's environment.
# NOTE(review): this presumably has to run before transformers/huggingface_hub
# is first imported, which is the case for this cell order - confirm.
import os

os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

## 1. 创建Mixtral 模型

In [16]:
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
import torch
import torch.nn.functional as F
# Seed PyTorch's global RNG so the randomly initialised weights and the random
# inputs drawn in later cells are reproducible.
seed_value = 42
torch.manual_seed(seed_value)

model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"

# Build a small model for debugging: start from the published configuration
# and shrink depth and width.
config = AutoConfig.from_pretrained(model_id)
config.num_hidden_layers = 2
config.num_attention_heads = 8
config.hidden_size = 128
# The expert FFN width is derived from the hidden size (2 x 128 = 256). The
# earlier literal assignment of 256 was dead code, overwritten by this line.
config.intermediate_size = config.hidden_size * 2
config.num_experts_per_tok = 2  # experts selected per token (top-2)
config.num_local_experts = 8    # total number of experts
# print(config)

# Instantiate the architecture from the configuration object and show it.
model = AutoModelForCausalLM.from_config(config)
print(model)

## 创建MoE层

In [17]:
import torch
from torch import nn
from transformers import MixtralConfig

class MixtralBLockSparseTop2MLP(nn.Module):
    """One MoE expert: a gated feed-forward network with SiLU activation.

    The forward pass computes ``w2(SiLU(w1(x)) * w3(x))``. All three linear
    projections are created without bias.

    Args:
        config: object exposing ``hidden_size`` (input/output width) and
            ``intermediate_size`` (inner width of the feed-forward network).
    """

    def __init__(self, config: MixtralConfig):
        super().__init__()
        inner_width = config.intermediate_size
        model_width = config.hidden_size
        self.ffn_dim = inner_width
        self.hidden_dim = model_width

        # The w1, w2, w3 creation order is kept on purpose: it fixes the
        # parameter ordering and the order in which random-init draws happen.
        self.w1 = nn.Linear(model_width, inner_width, bias=False)
        self.w2 = nn.Linear(inner_width, model_width, bias=False)
        self.w3 = nn.Linear(model_width, inner_width, bias=False)

        self.act_fn = nn.SiLU()

    def forward(self, hidden_states):
        """Apply the gated FFN to the last dimension of ``hidden_states``."""
        gate_branch = self.act_fn(self.w1(hidden_states))
        up_branch = self.w3(hidden_states)
        # Element-wise gating, then project back to the model width.
        return self.w2(gate_branch * up_branch)

# Smoke-test a single expert on a random input of shape (1, 64, hidden).
# The feature width is read from `config` rather than hard-coded, so this
# cell keeps working if the hidden size chosen above is changed.
# The input is drawn before the expert is built, so the order of RNG draws is
# the same as before.
x = torch.randn(1, 64, config.hidden_size)
expert = MixtralBLockSparseTop2MLP(config)
print('单个专家为原LLaMA的MLP层')
print(expert)
g = expert(x)
print('单个专家输入:', x.shape)
print('单个专家输出结果：', g.shape)


## 创建混合专家

![mixtral-mode](./moe.png)

In [18]:
class MixtralSparseMoeBlock(nn.Module):
    """Sparse mixture-of-experts block: a linear router plus a list of experts.

    Args:
        config: object exposing ``hidden_size``, ``intermediate_size``,
            ``num_local_experts`` (total experts) and ``num_experts_per_tok``
            (experts selected per token).
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_dim = config.hidden_size
        self.ffn_dim = config.intermediate_size
        self.num_experts = config.num_local_experts
        self.top_k = config.num_experts_per_tok

        # gating: one logit per expert for every token
        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)

        self.experts = nn.ModuleList(
            [MixtralBLockSparseTop2MLP(config) for _ in range(self.num_experts)]
        )

    def forward(self, hidden_states):
        """Route every token to its top-k experts and mix their outputs.

        Args:
            hidden_states: tensor of shape (batch, sequence, hidden_dim).

        Returns:
            Tensor with the same shape as ``hidden_states``. Router logits are
            not returned; call ``self.gate`` on the flattened input to get them.
        """
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        flat_states = hidden_states.reshape(-1, hidden_dim)

        # Softmax over experts, keep the top-k probabilities per token and
        # renormalise them so the kept weights sum to 1.
        router_logits = self.gate(flat_states)
        routing_weights = nn.functional.softmax(router_logits, dim=-1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
        routing_weights = routing_weights.to(flat_states.dtype)

        # (tokens, top_k, num_experts) -> (num_experts, top_k, tokens)
        expert_mask = nn.functional.one_hot(
            selected_experts, num_classes=self.num_experts
        ).permute(2, 1, 0)

        mixed = torch.zeros_like(flat_states)
        for expert_idx, expert_layer in enumerate(self.experts):
            # slot_idx: which top-k slot picked this expert; token_idx: which token.
            slot_idx, token_idx = torch.where(expert_mask[expert_idx])
            if token_idx.numel() == 0:
                continue  # no token was routed to this expert
            weighted = expert_layer(flat_states[token_idx]) * routing_weights[token_idx, slot_idx, None]
            # A token receives contributions from top_k experts, hence add.
            mixed.index_add_(0, token_idx, weighted.to(flat_states.dtype))

        return mixed.reshape(batch_size, sequence_length, hidden_dim)

# Instantiate the mixture-of-experts block and print its structure.
# The feature width is read from `config` rather than hard-coded to 128.
# The input is drawn before the block is built, so the order of RNG draws is
# the same as before.
x = torch.randn(1, 64, config.hidden_size)
experts = MixtralSparseMoeBlock(config)
print('多个专家混合专家')
print(experts)
# The end-to-end call below is left disabled: it only works when
# MixtralSparseMoeBlock defines forward(). The following cells walk through
# the forward computation step by step instead.
# gs = experts(x)
# print('多个专家输入:', x.shape)
# print('多个专家混合输出结果：', gs.shape)

## 混合专家 Forward

In [55]:
# Stage 1: compute the sparse gating values.
tokens = 6
# 6 tokens; the width comes from the MoE block so it always matches the gate.
x = torch.randn(1, tokens, experts.hidden_dim)
hidden_states = x
batch_size, sequence_length, hidden_dim = hidden_states.shape
# Flatten (batch, seq, hidden) -> (batch * seq, hidden): routing is per token.
hidden_states = hidden_states.view(-1, hidden_dim)

# Every MoE layer produces router_logits; they are what a load-balance loss
# is computed from later.
router_logits = experts.gate(hidden_states)
print(f'experts.gate output router logits : \n {router_logits}')

# Softmax over the expert dimension, then keep the top-k probabilities and
# the indices of the selected experts.
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
print(f'softmax weight  : \n {routing_weights}')

routing_weights, selected_experts = torch.topk(
    routing_weights, experts.top_k, dim=-1
)
print(f'expert select : \n {selected_experts}')
print(f'topk : \n {routing_weights}')

# Renormalise so each token's kept weights sum to 1.
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
print(f'topk归一化 : \n {routing_weights}')

routing_weights = routing_weights.to(hidden_states.dtype)

# One-hot encoding of the selection:
# (tokens, top_k, num_experts) -> permute -> (num_experts, top_k, tokens)
expert_mask = torch.nn.functional.one_hot(
    selected_experts, num_classes=experts.num_experts
).permute(2, 1, 0)
for i in range(tokens):
    # Slice for token i: rows are experts, columns are top-k slots.
    print(f'【x_{i}】\n', expert_mask[:,:,i])

In [56]:
# Stage 2: accumulate the final result, one row per token.
final_hidden_states = torch.zeros(
    (batch_size * sequence_length, hidden_dim),
    dtype=hidden_states.dtype, device=hidden_states.device
)
print(f'final moe result shape for each token: {final_hidden_states.shape}')

# Each expert gathers the tokens routed to it, processes them, and adds its
# weighted output into the result table.
for expert_idx in range(experts.num_experts):

    print(f'----------------- expert {expert_idx} --------------------')

    expert_layer = experts.experts[expert_idx]
    print(expert_mask[expert_idx])
    # expert_mask[expert_idx] is (top_k, tokens): idx is the top-k slot
    # (0 = top1, 1 = top2) and top_x is the token index.
    idx, top_x = torch.where(expert_mask[expert_idx])
    print(f'专家 {expert_idx} 选到的样本编号:',top_x.tolist())  # token indices routed to this expert
    print(f'专家 {expert_idx} top1:0, top2:1 ',idx.tolist())    # 0 is top1, 1 is top2
    print(f'有 {len(top_x)} / {x.shape[1]} token 选到专家 {expert_idx}')

    # Index with the tensors directly; converting them to Python lists and
    # back is an unnecessary round-trip and gives the same selection.
    current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
    token_weights = routing_weights[top_x, idx, None]

    # expert(x) * routing weight of that expert for each routed token
    current_hidden_states = expert_layer(current_state) * token_weights

    # Add this expert's contribution into the rows of the tokens it handled.
    final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))

    print(current_state.shape)
    print(token_weights.shape)
    print(current_hidden_states.shape)
    print(final_hidden_states.shape)

    # if expert_idx == 1: break


## load balance loss

![moe-math](./moe_math.png)

In [70]:
import torch

# Toy setup for the auxiliary load-balancing loss.
num_experts = 8
batch = 10
seq_length = 6
top_k = 2

print(f'sMoE num_experts:{num_experts} top_k:{top_k} batch:{batch} seq_length:{seq_length}')

# One (batch * seq_length, num_experts) logits tensor per MoE layer.
router_logits_1 = torch.randn(batch, seq_length, num_experts).view(-1,num_experts) # layer 1
router_logits_2 = torch.randn(batch, seq_length, num_experts).view(-1,num_experts) # layer 2
router_logits = [router_logits_1, router_logits_2]

# Stack all layers along the token dimension.
concatenated_gate_logits = torch.cat(router_logits, dim = 0)
print('单层gating的路由logits:', router_logits_1.shape)
print('两层gating的路由logits:', concatenated_gate_logits.shape)

print('根据logits top-k 计算热独编码')
routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
_, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
# (tokens, top_k) -> (tokens, top_k, num_experts)
expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
print(expert_mask.shape)

# Per top-k slot: the count and the fraction of tokens sent to each expert.
tokens_sum_expert = torch.sum(expert_mask.float(), dim=0)
tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
# The first line prints a count (sum), not an average, so the label says so.
print('top1 每个专家处理的token总数   :', tokens_sum_expert[0])
print('top2 每个专家平均处理的token fi:', tokens_per_expert[1])
print('top1与top2水平合计', tokens_per_expert.sum(dim=1))

# Compute the average probability of routing to these experts
router_prob_per_expert = torch.mean(routing_weights, dim=0)
print('router_prob_per_expert Pi: ' , router_prob_per_expert)

print( '每个专家的负载：',  tokens_per_expert * router_prob_per_expert.unsqueeze(0))
# Auxiliary loss: num_experts * sum_i f_i * P_i. The num_experts factor was
# missing; with it, perfectly uniform routing gives a loss equal to top_k
# regardless of how many experts there are.
overall_loss = num_experts * torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
print('final loss:', overall_loss)
