In [3]:
import json
import pandas as pd
# Parse the trainer-state JSON (presumably written by the Hugging Face Trainer
# -- TODO confirm). The file handle is only needed for json.load, so the
# `with` block is closed before any further processing.
with open('/home/user/wyf/train_moe_from_scratch/20000step_trainer_state.json', 'r', encoding='utf-8') as f:
    data = json.load(f)

log_history = data['log_history']

# Keep only the records that actually carry a training loss. Selecting on the
# 'loss' key (instead of blindly dropping the last record) avoids a KeyError
# on any record without 'loss' -- e.g. a trailing run summary or an
# evaluation record, if present -- and does not lose the final data point
# when the last record is a regular training log.
training_logs = [entry for entry in log_history if 'loss' in entry]
steps = [entry['step'] for entry in training_logs]
losses = [entry['loss'] for entry in training_logs]

import seaborn as sns
import plotly.express as px

# Collect the logged steps and losses into a two-column DataFrame for plotting.
df = pd.DataFrame(dict(Step=steps, Loss=losses))

# Draw the loss curve as a Plotly Express line chart, relabel both axes in
# lower case, and fix the figure size (600x400) in the same chained call.
fig = px.line(
    df,
    x='Step',
    y='Loss',
    title='MoE预训练损失变化',
    labels={'Step': 'step', 'Loss': 'loss'},
).update_layout(width=600, height=400)

# Render the figure in the notebook output.
fig.show()

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [5]:
# Seed the global RNG so the random input below (and any randomly initialised
# layers created afterwards) are identical on every Restart & Run All.
torch.manual_seed(0)
# Toy input of shape (2, 4, 10); presumably (batch, tokens, hidden_dim).
x = torch.randn(2, 4, 10)

In [6]:
# Linear layer mapping the 10 input features to 4 scores per position
# (one score per expert, judging by how the output is used for routing below).
gate = nn.Linear(in_features=10, out_features=4)

In [7]:
# Run the gate over the input; the last dimension becomes the 4 gate scores.
logits = gate(x)
# Show the resulting shape as a sanity check.
logits.size()

torch.Size([2, 4, 4])

In [8]:
# Display the raw gate scores computed by `gate(x)` for inspection.
logits

tensor([[[-0.4174, -0.8181,  0.2931, -1.0983],
         [ 0.4230,  0.5691,  0.7379, -0.3439],
         [-1.4290, -0.6250, -0.0913,  0.8918],
         [ 0.4221,  0.4372, -0.2936, -0.2649]],

        [[ 0.4224, -0.3879, -0.1051,  0.6133],
         [-0.6835, -0.4645,  0.2799,  0.1902],
         [ 0.0440,  0.6685, -0.5246,  0.0843],
         [ 0.3311, -0.2890, -1.1392,  0.5881]]], grad_fn=<ViewBackward0>)

In [10]:
# For every position keep the 2 largest gate scores along the last dimension,
# together with the indices (column positions) they came from.
logits_topk, indices = torch.topk(logits, k=2, dim=-1)
(logits_topk, indices)

(tensor([[[ 0.2931, -0.4174],
          [ 0.7379,  0.5691],
          [ 0.8918, -0.0913],
          [ 0.4372,  0.4221]],
 
         [[ 0.6133,  0.4224],
          [ 0.2799,  0.1902],
          [ 0.6685,  0.0843],
          [ 0.5881,  0.3311]]], grad_fn=<TopkBackward0>),
 tensor([[[2, 0],
          [2, 1],
          [3, 2],
          [1, 0]],
 
         [[3, 0],
          [2, 3],
          [1, 3],
          [3, 0]]]))

In [11]:
# NOTE: despite its name, this tensor is filled with -inf, not zeros. It has
# the same shape/dtype as `logits` and serves as the "masked out" background
# for the scatter in the next step.
zeros = torch.full_like(logits, fill_value=-float("inf"))

In [12]:
# Write the top-k scores back to their original columns (out-of-place);
# every column that was not selected keeps the -inf background value.
sparse_logits = torch.scatter(zeros, -1, indices, logits_topk)
sparse_logits

tensor([[[-0.4174,    -inf,  0.2931,    -inf],
         [   -inf,  0.5691,  0.7379,    -inf],
         [   -inf,    -inf, -0.0913,  0.8918],
         [ 0.4221,  0.4372,    -inf,    -inf]],

        [[ 0.4224,    -inf,    -inf,  0.6133],
         [   -inf,    -inf,  0.2799,  0.1902],
         [   -inf,  0.6685,    -inf,  0.0843],
         [ 0.3311,    -inf,    -inf,  0.5881]]], grad_fn=<ScatterBackward0>)

In [13]:
# Normalise the masked scores into routing weights: softmax over the last
# dimension sends the -inf entries to exactly 0 and makes the kept entries
# sum to 1.
#
# The masked scores are rebuilt from `zeros`, `indices` and `logits_topk`
# rather than read from the current `sparse_logits`, so this cell is
# idempotent: re-running it can never apply softmax to weights that were
# already normalised (which would turn the zeros into non-zero weights).
sparse_logits = F.softmax(
    zeros.scatter(dim=-1, index=indices, src=logits_topk),
    dim=-1,
)
sparse_logits

tensor([[[0.3295, 0.0000, 0.6705, 0.0000],
         [0.0000, 0.4579, 0.5421, 0.0000],
         [0.0000, 0.0000, 0.2723, 0.7277],
         [0.4962, 0.5038, 0.0000, 0.0000]],

        [[0.4524, 0.0000, 0.0000, 0.5476],
         [0.0000, 0.0000, 0.5224, 0.4776],
         [0.0000, 0.6420, 0.0000, 0.3580],
         [0.4361, 0.0000, 0.0000, 0.5639]]], grad_fn=<SoftmaxBackward0>)

In [14]:
# Zero-filled accumulator with the same shape, dtype and device as the input.
final_outputs = x.new_zeros(x.shape)
final_outputs

tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]])

In [16]:
# Collapse all leading dimensions so each row is one position's feature vector.
x_flat = x.view(-1, x.size(-1))
(x_flat, x_flat.shape)

(tensor([[ 0.8376,  1.6742,  1.5400, -2.3451,  0.9166, -0.6159,  1.2018,  1.5354,
          -0.1724,  0.3897],
         [-0.4154, -0.7223, -0.9576,  0.2360,  0.5441,  0.0049, -0.1311, -0.6673,
          -1.4495,  2.5249],
         [ 1.5304,  0.3498,  0.5821, -0.2789, -0.2131,  0.7828, -0.7165,  0.8480,
           3.1155,  0.9457],
         [ 0.7512, -0.5504, -0.3900,  0.9416,  0.2048, -0.8649,  0.7229,  0.0534,
          -0.2791, -0.3948],
         [-0.1167,  1.1454,  0.4875,  0.6351, -1.6597, -0.2928,  0.0223, -0.8946,
           0.8921,  1.7179],
         [ 0.1254, -0.5159,  1.1645, -1.1828,  0.8253, -0.5996, -0.9029,  0.6295,
           0.9861,  1.1496],
         [-0.1167,  0.0729, -0.7747,  1.1776,  0.7829,  0.6341,  0.5290, -0.1743,
           0.5624,  1.1735],
         [-2.2027,  0.9008,  1.6723,  1.0737,  0.5271,  0.8475, -2.7148, -2.8659,
           0.0842,  1.7391]]),
 torch.Size([8, 10]))

In [17]:
# Flatten the routing weights the same way as the input: one row per
# position, one column per gate output.
sparse_logits_flat = sparse_logits.view(-1, sparse_logits.size(-1))
(sparse_logits_flat, sparse_logits_flat.shape)

(tensor([[0.3295, 0.0000, 0.6705, 0.0000],
         [0.0000, 0.4579, 0.5421, 0.0000],
         [0.0000, 0.0000, 0.2723, 0.7277],
         [0.4962, 0.5038, 0.0000, 0.0000],
         [0.4524, 0.0000, 0.0000, 0.5476],
         [0.0000, 0.0000, 0.5224, 0.4776],
         [0.0000, 0.6420, 0.0000, 0.3580],
         [0.4361, 0.0000, 0.0000, 0.5639]], grad_fn=<ViewBackward0>),
 torch.Size([8, 4]))

In [23]:
# True for every position whose top-k selection contains index 0,
# i.e. the positions routed to the first expert.
expert_mask = indices.eq(0).any(dim=-1)
expert_mask.shape

torch.Size([2, 4])

In [24]:
# 1-D version of the mask, aligned row-for-row with the flattened tensors.
expert_mask_flat = expert_mask.flatten()
expert_mask_flat

tensor([ True, False, False,  True,  True, False, False,  True])

In [26]:
# Gather the rows of the flattened input that are routed to this expert.
expert_input = x_flat[expert_mask_flat, :]
(expert_input, expert_input.shape)

(tensor([[ 0.8376,  1.6742,  1.5400, -2.3451,  0.9166, -0.6159,  1.2018,  1.5354,
          -0.1724,  0.3897],
         [ 0.7512, -0.5504, -0.3900,  0.9416,  0.2048, -0.8649,  0.7229,  0.0534,
          -0.2791, -0.3948],
         [-0.1167,  1.1454,  0.4875,  0.6351, -1.6597, -0.2928,  0.0223, -0.8946,
           0.8921,  1.7179],
         [-2.2027,  0.9008,  1.6723,  1.0737,  0.5271,  0.8475, -2.7148, -2.8659,
           0.0842,  1.7391]]),
 torch.Size([4, 10]))

In [27]:
# A single toy expert: a linear layer that keeps the feature size at 10.
expert = nn.Linear(in_features=10, out_features=10)

In [28]:
# Apply the expert to its selected rows.
# NOTE(review): `export_output` looks like a typo for `expert_output`; the
# name is kept because a later cell refers to it.
export_output = expert(expert_input)
(export_output, export_output.size())

(tensor([[ 0.3335, -0.4578, -0.3970, -1.2636, -0.0809, -0.0451, -0.6618, -0.3685,
          -0.5423, -0.0130],
         [-0.3068,  0.1800,  0.3231,  0.2886,  0.1836, -0.3680, -0.1322, -0.3956,
          -0.1619,  0.0895],
         [ 0.5858, -0.9985,  0.8756, -0.6312,  1.3548, -0.8566, -0.4663, -0.0338,
           0.7733, -0.5747],
         [ 1.7103, -0.2132,  0.8046,  0.3739,  0.7678, -0.9101, -0.9059, -0.2524,
           0.1489,  0.6372]], grad_fn=<AddmmBackward0>),
 torch.Size([4, 10]))

In [34]:
# Routing weight of expert 0 (column 0) for each selected row, shaped (n, 1)
# so it broadcasts across the feature dimension of the expert output.
gate_scores = sparse_logits_flat[expert_mask_flat, 0][:, None]
gate_scores

tensor([[0.3295],
        [0.4962],
        [0.4524],
        [0.4361]], grad_fn=<UnsqueezeBackward0>)

In [35]:
# Scale every expert output row by that row's routing weight (broadcast
# over the feature dimension).
weighted_output = torch.mul(export_output, gate_scores)
(weighted_output, weighted_output.shape)

(tensor([[ 0.1099, -0.1508, -0.1308, -0.4163, -0.0267, -0.0149, -0.2180, -0.1214,
          -0.1787, -0.0043],
         [-0.1522,  0.0893,  0.1603,  0.1432,  0.0911, -0.1826, -0.0656, -0.1963,
          -0.0803,  0.0444],
         [ 0.2651, -0.4518,  0.3962, -0.2856,  0.6130, -0.3876, -0.2110, -0.0153,
           0.3499, -0.2600],
         [ 0.7458, -0.0930,  0.3509,  0.1631,  0.3348, -0.3969, -0.3950, -0.1101,
           0.0649,  0.2779]], grad_fn=<MulBackward0>),
 torch.Size([4, 10]))

In [37]:
# Reset the accumulator to zeros (same shape, dtype and device as the input)
# right before the weighted expert output is added into it.
final_outputs = x.new_zeros(x.shape)
final_outputs

tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]])

In [38]:
# Add the weighted expert output into the positions selected by the mask;
# positions not routed to this expert are left unchanged.
final_outputs[expert_mask] = final_outputs[expert_mask] + weighted_output
final_outputs

tensor([[[ 0.1099, -0.1508, -0.1308, -0.4163, -0.0267, -0.0149, -0.2180,
          -0.1214, -0.1787, -0.0043],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000],
         [-0.1522,  0.0893,  0.1603,  0.1432,  0.0911, -0.1826, -0.0656,
          -0.1963, -0.0803,  0.0444]],

        [[ 0.2651, -0.4518,  0.3962, -0.2856,  0.6130, -0.3876, -0.2110,
          -0.0153,  0.3499, -0.2600],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000],
         [ 0.7458, -0.0930,  0.3509,  0.1631,  0.3348, -0.3969, -0.3950,
          -0.1101,  0.0649,  0.2779]]], grad_fn=<IndexPutBackward0>)