In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [20]:
class BasicExpert(nn.Module):
    def __init__(self, feature_in, feature_out):
        super().__init__()
        self.fc = nn.Linear(feature_in, feature_out)
        
    def forward(self, x):
        return self.fc(x)

In [31]:
class BasicMOE(nn.Module):
    def __init__(self, feature_in, feature_out, num_experts):
        super().__init__()
        self.gate = nn.Linear(feature_in, num_experts)
        self.experts = nn.ModuleList(
            BasicExpert(feature_in, feature_out) for _ in range(num_experts)
        )
    def forward(self, x):
        #x shape batch_size, feature_in
        expert_weights = self.gate(x)       
        expert_out_list = [expert(x) for expert in self.experts]
        expert_output= [expert_output.unsqueeze(1) for expert_output in expert_out_list]
        expert_output = torch.concat(expert_output, dim = 1)
        #expert_output shape batch_size, num_experts, feature_out
        expert_weights = F.softmax(expert_weights, dim = 1)
        expert_weights = expert_weights.unsqueeze(1)
        #expert_weights shape batch_size, 1, num_experts
        print(expert_weights.shape, expert_output.shape)
        output = expert_weights @ expert_output
        return output.squeeze(1)

        
        

In [32]:
 def test_basic_moe():
    x = torch.rand(4, 512)
    basic_moe = BasicMOE(512, 128, 4)
    output = basic_moe(x)
    print(output.shape)
test_basic_moe()

torch.Size([4, 1, 4]) torch.Size([4, 4, 128])
torch.Size([4, 128])


# Sparse MOE

In [81]:
class MOEConfig:
    def __init__(self, hidden_dim, expert_number, top_k, shared_experts_number=2):
        self.hidden_dim = hidden_dim
        self.expert_number = expert_number
        self.top_k = top_k
        self.shared_experts_number = shared_experts_number
        
class MOERouter(nn.Module):
    def __init__(self, hidden_dim, expert_number, top_k):
        super().__init__()
        self.gate = nn.Linear(hidden_dim, expert_number)
        self.expert_number = expert_number
        self.top_k = top_k
        
    def forward(self, x):
        router_logits = self.gate(x) #batch * seq_len, expert_number
        router_probs = F.softmax(router_logits, dim=1, dtype = torch.float32)
        print('router_probs',router_probs)
        router_weights , selected_expert_indices = torch.topk(router_probs, self.top_k, dim = -1)
        print("router_weights , selected_expert_indices", router_weights , selected_expert_indices)
        # shape batch * seq_len top_k
        router_weights = router_weights / router_weights.sum(dim=-1, keepdim=True)
        router_weights = router_weights.to(x.dtype)
        print('router_weights',router_weights)
        expert_mask = F.one_hot(selected_expert_indices, num_classes=self.expert_number)
        print('expert_mask,',expert_mask)
        #expert masks: batch * seq_len top_k expert_number

        expert_mask = expert_mask.permute(2, 1, 0)
        #expert masks:  expert_number top_k batch * seq_len

        return router_logits, router_weights, selected_expert_indices, expert_mask
        
        

class SparseMOE(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.top_k = config.top_k
        self.hidden_dim = config.hidden_dim
        self.expert_number = config.expert_number
        
        self.experts = nn.ModuleList(
            BasicExpert(config.hidden_dim, config.hidden_dim) for _ in range(self.expert_number)
        )
        self.router = MOERouter(self.hidden_dim, self.expert_number, self.top_k)

    
    def forward(self, x):
        #x shape batch ,seq_len, hidden_dim
        print('x.size()', x.size())
        batch_size , seq_len, hidden_dim = x.size()
        hidden_states = x.view(-1, hidden_dim)
        print('hidden_states',hidden_states.size())
        router_logits, router_weights, selected_expert_indices, expert_masks = self.router(hidden_states)
        
        final_hidden_states = torch.zeros((batch_size * seq_len,
                                          hidden_dim),
                                          dtype=hidden_states.dtype,
                                          device=hidden_states.device)
        
        
        for expert_idx in range(self.expert_number):
            print('-'*30)
            print(expert_idx)
            expert_layer = self.experts[expert_idx]
            #expert masks:  expert_number top_k batch * seq_len
            current_expert_mask = expert_masks[expert_idx]  
            print('current_expert_mask',current_expert_mask)
            router_weights_idx, top_x = torch.where(current_expert_mask)
            print('router_weights_idx, top_x',router_weights_idx, top_x)
            current_state = hidden_states.unsqueeze(0)[:, top_x,:].reshape(-1,hidden_dim)
            current_state = expert_layer(current_state)
            current_token_router_weight = router_weights[top_x, router_weights_idx]
            #selected_token_number
            current_token_router_weight=current_token_router_weight.unsqueeze(1)
            #selected_token_number,1                
            current_hidden_states = current_state * current_token_router_weight
            final_hidden_states.index_add_(0,top_x,current_hidden_states.to(hidden_states.dtype))
        final_hidden_states = final_hidden_states.reshape(batch_size, seq_len, hidden_dim)
        return final_hidden_states, router_logits
            
            
        
    

In [82]:
 def test_token_level_moe():
    x = torch.rand(2,4,16)
    config = MOEConfig(hidden_dim=16,expert_number=3,top_k=2)
    token_level_moe = SparseMOE(config)
    out = token_level_moe(x)
    print(out[0].shape, out[1].shape)
test_token_level_moe()

x.size() torch.Size([2, 4, 16])
hidden_states torch.Size([8, 16])
router_probs tensor([[0.3625, 0.3123, 0.3252],
        [0.2982, 0.3196, 0.3822],
        [0.3835, 0.3093, 0.3072],
        [0.2782, 0.2733, 0.4485],
        [0.3135, 0.2490, 0.4375],
        [0.3384, 0.3509, 0.3106],
        [0.4166, 0.2528, 0.3306],
        [0.3010, 0.3389, 0.3601]], grad_fn=<SoftmaxBackward0>)
router_weights , selected_expert_indices tensor([[0.3625, 0.3252],
        [0.3822, 0.3196],
        [0.3835, 0.3093],
        [0.4485, 0.2782],
        [0.4375, 0.3135],
        [0.3509, 0.3384],
        [0.4166, 0.3306],
        [0.3601, 0.3389]], grad_fn=<TopkBackward0>) tensor([[0, 2],
        [2, 1],
        [0, 1],
        [2, 0],
        [2, 0],
        [1, 0],
        [0, 2],
        [2, 1]])
router_weights tensor([[0.5272, 0.4728],
        [0.5446, 0.4554],
        [0.5536, 0.4464],
        [0.6172, 0.3828],
        [0.5825, 0.4175],
        [0.5090, 0.4910],
        [0.5576, 0.4424],
        [0.5151, 0.

# deepseek

In [88]:
class SharedExpertMOE(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.routed_expert_moe = SparseMOE(config)
        self.shared_experts = nn.ModuleList(
            [BasicExpert(self.config.hidden_dim,self.config.hidden_dim) for  _ in range(self.config.shared_experts_number)]
        )
        
    def forward(self, x):
        batch_size, seq_len,hidden_dim = x.size()
        shared_experts_output_list = [
            expert(x) for expert in self.shared_experts
        ]
        shared_expert_output = torch.stack(
            shared_experts_output_list,
            dim = 0
        )
        #shape shared_expert_num, batch_size, seql_len, hidden_dim
        shared_expert_output = shared_expert_output.sum(dim=0, keepdim=False)
        #batch_size, seql_len, hidden_dim
        sparse_moe_output, router_logits = self.routed_expert_moe(x)
        output = shared_expert_output + sparse_moe_output
        return output, router_logits
        

def test_share_expert_moe():
    x = torch.rand(2,4,16)
    config = MOEConfig(hidden_dim=16,expert_number=3,top_k=2)
    share_expert_moe = SharedExpertMOE(config)
    out = share_expert_moe(x)
    print(out[0].shape, out[1].shape)
test_share_expert_moe()


x.size() torch.Size([2, 4, 16])
hidden_states torch.Size([8, 16])
router_probs tensor([[0.5560, 0.2837, 0.1603],
        [0.5679, 0.2029, 0.2292],
        [0.5427, 0.2516, 0.2057],
        [0.5637, 0.2429, 0.1935],
        [0.6045, 0.2406, 0.1549],
        [0.4857, 0.2921, 0.2221],
        [0.5812, 0.2671, 0.1517],
        [0.5387, 0.2678, 0.1934]], grad_fn=<SoftmaxBackward0>)
router_weights , selected_expert_indices tensor([[0.5560, 0.2837],
        [0.5679, 0.2292],
        [0.5427, 0.2516],
        [0.5637, 0.2429],
        [0.6045, 0.2406],
        [0.4857, 0.2921],
        [0.5812, 0.2671],
        [0.5387, 0.2678]], grad_fn=<TopkBackward0>) tensor([[0, 1],
        [0, 2],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1]])
router_weights tensor([[0.6621, 0.3379],
        [0.7124, 0.2876],
        [0.6833, 0.3167],
        [0.6989, 0.3011],
        [0.7153, 0.2847],
        [0.6244, 0.3756],
        [0.6851, 0.3149],
        [0.6679, 0.

In [90]:
def switch_load_balancing_loss(router_logits: torch.Tensor, num_experts: int) -> torch.Tensor:
    """
    计算 Switch Transformers 的负载均衡损失
    
    Args:
        router_logits: shape [batch_size * sequence_length, num_experts]
        num_experts: 专家数量
    
    Returns:
        total_loss: 总损失 = auxiliary_loss + z_loss
    """
    # 计算路由概率
    router_probs = torch.softmax(router_logits, dim=-1)  # [b*s, num_experts]
    
    # 获取每个token的最优专家
    _, selected_experts = torch.topk(router_probs, k=2, dim=-1)  # [b*s]
    
    # 创建one-hot矩阵表示选中的专家
    mask = torch.nn.functional.one_hot(selected_experts, num_experts).float()  # [b*s, num_experts]
    
    # 计算每个专家的期望负载 (理想情况下应该是 1/num_experts)
    expected_load = torch.ones_like(router_probs) / num_experts
    
    # 计算实际负载 (每个专家处理的token数量除以总token数量)
    # 在batch维度上计算平均值
    actual_load = mask.mean(dim=0)  # [num_experts]
    
    # 计算auxiliary loss
    # 这会惩罚负载分布与期望负载的差异
    aux_loss = torch.sum(actual_load * router_probs.mean(dim=0)) * num_experts
    
    # 计算z_loss (可选)
    # 这会惩罚过大的路由logits
    z_loss = torch.mean(torch.square(router_logits))
    z_loss_weight = 0.001  # 可调整的超参数
    
    # 总损失
    total_loss = aux_loss + z_loss * z_loss_weight
    
    return total_loss

def test_moe_training():
    # Create a simple dataset
    batch_size = 32
    seq_len = 16
    hidden_dim = 32
    num_batches = 100
    
    # Initialize model and optimizer
    config = MOEConfig(hidden_dim=hidden_dim, 
                      expert_number=4,
                      top_k=2,
                      shared_experts_number=2)
    model = SharedExpertMOE(config)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    
    # Training loop
    model.train()
    for batch in range(num_batches):
        # Generate random input data
        x = torch.randn(batch_size, seq_len, hidden_dim)
        target = torch.randn(batch_size, seq_len, hidden_dim)
        
        # Forward pass
        output, router_logits = model(x)

        # Compute losses
        # MSE loss for prediction
        mse_loss = F.mse_loss(output, target)
        aux_loss = switch_load_balancing_loss(router_logits, config.expert_number)
        # Combined loss
        total_loss = mse_loss + 0.01 * aux_loss
        
        # Backward pass and optimize
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        
        if batch % 10 == 0:
            print(f"Batch {batch}, Loss: {total_loss.item():.4f} "
                  f"(MSE: {mse_loss.item():.4f}, Aux: {aux_loss.item():.4f})")

# Run the training test
test_moe_training()

x.size() torch.Size([32, 16, 32])
hidden_states torch.Size([512, 32])
router_probs tensor([[0.2270, 0.1921, 0.3626, 0.2183],
        [0.1591, 0.3863, 0.0378, 0.4168],
        [0.3474, 0.3473, 0.0569, 0.2484],
        ...,
        [0.1297, 0.1019, 0.2882, 0.4802],
        [0.3793, 0.2226, 0.1629, 0.2352],
        [0.3899, 0.2622, 0.2743, 0.0737]], grad_fn=<SoftmaxBackward0>)
router_weights , selected_expert_indices tensor([[0.3626, 0.2270],
        [0.4168, 0.3863],
        [0.3474, 0.3473],
        ...,
        [0.4802, 0.2882],
        [0.3793, 0.2352],
        [0.3899, 0.2743]], grad_fn=<TopkBackward0>) tensor([[2, 0],
        [3, 1],
        [0, 1],
        ...,
        [3, 2],
        [0, 3],
        [0, 2]])
router_weights tensor([[0.6150, 0.3850],
        [0.5190, 0.4810],
        [0.5001, 0.4999],
        ...,
        [0.6249, 0.3751],
        [0.6172, 0.3828],
        [0.5870, 0.4130]], grad_fn=<DivBackward0>)
expert_mask, tensor([[[0, 0, 1, 0],
         [1, 0, 0, 0]],

       

Batch 0, Loss: 1.8568 (MSE: 1.8368, Aux: 2.0075)
x.size() torch.Size([32, 16, 32])
hidden_states torch.Size([512, 32])
router_probs tensor([[0.2459, 0.1171, 0.2233, 0.4138],
        [0.2424, 0.3369, 0.2685, 0.1521],
        [0.2723, 0.3875, 0.1389, 0.2013],
        ...,
        [0.5422, 0.0581, 0.1143, 0.2854],
        [0.1079, 0.1485, 0.6312, 0.1124],
        [0.2543, 0.1522, 0.2081, 0.3853]], grad_fn=<SoftmaxBackward0>)
router_weights , selected_expert_indices tensor([[0.4138, 0.2459],
        [0.3369, 0.2685],
        [0.3875, 0.2723],
        ...,
        [0.5422, 0.2854],
        [0.6312, 0.1485],
        [0.3853, 0.2543]], grad_fn=<TopkBackward0>) tensor([[3, 0],
        [1, 2],
        [1, 0],
        ...,
        [0, 3],
        [2, 1],
        [3, 0]])
router_weights tensor([[0.6272, 0.3728],
        [0.5565, 0.4435],
        [0.5873, 0.4127],
        ...,
        [0.6551, 0.3449],
        [0.8096, 0.1904],
        [0.6024, 0.3976]], grad_fn=<DivBackward0>)
expert_mask, tensor

router_weights_idx, top_x tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]) tensor([ 22,  32,  37,  41,  51,  58,  59,  65,  74,  77, 107, 109, 110, 125,
        126, 136, 137, 138, 14

x.size() torch.Size([32, 16, 32])
hidden_states torch.Size([512, 32])
router_probs tensor([[0.2005, 0.4297, 0.2956, 0.0742],
        [0.1507, 0.3576, 0.1933, 0.2984],
        [0.4646, 0.2125, 0.1672, 0.1557],
        ...,
        [0.1882, 0.2805, 0.1277, 0.4036],
        [0.2701, 0.2106, 0.2141, 0.3053],
        [0.1841, 0.1322, 0.2715, 0.4121]], grad_fn=<SoftmaxBackward0>)
router_weights , selected_expert_indices tensor([[0.4297, 0.2956],
        [0.3576, 0.2984],
        [0.4646, 0.2125],
        ...,
        [0.4036, 0.2805],
        [0.3053, 0.2701],
        [0.4121, 0.2715]], grad_fn=<TopkBackward0>) tensor([[1, 2],
        [1, 3],
        [0, 1],
        ...,
        [3, 1],
        [3, 0],
        [3, 2]])
router_weights tensor([[0.5924, 0.4076],
        [0.5451, 0.4549],
        [0.6862, 0.3138],
        ...,
        [0.5900, 0.4100],
        [0.5306, 0.4694],
        [0.6028, 0.3972]], grad_fn=<DivBackward0>)
expert_mask, tensor([[[0, 1, 0, 0],
         [0, 0, 1, 0]],

       

        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]) tensor([  1,   9,  11,  22,  31,  35,  38,  49,  51,  54,  55,  58,  59,  64,
         65,  66,  72,  80,  81,  90,  92,  95,  97, 101, 102, 104, 108, 119,
        120, 121, 123, 124, 130, 133, 136, 138, 140, 149, 151, 152, 154, 156,
        157, 160, 161, 162, 167, 169, 174, 183, 188, 190, 192, 194, 198, 202,
        207, 208, 211, 213, 216, 218, 224, 226, 228, 232, 234, 236, 241, 248,
        256, 264, 266, 269, 271, 272, 278, 279, 283, 288, 292, 293, 295, 296,
        306, 307, 309, 312, 314, 319, 322, 324, 337, 338, 339, 349, 351, 352,
        354, 356, 362, 364, 365, 366, 367, 368, 369, 371, 375, 382, 385, 392,
        393, 395, 396, 397, 403, 406, 411, 417, 418, 419, 423, 427, 430, 434,
        437, 444, 445, 452, 453, 455, 456, 458, 463, 464, 466, 469, 471, 472,
        476, 478, 483, 485, 486, 495, 500, 501, 503, 508, 509,   0,   2,   3,
          8,  15,  17,  19,  23,  25,  28,  34,  37,  40,  43,  44,  46,  50,
         56,  63

router_probs tensor([[0.2081, 0.2674, 0.4540, 0.0705],
        [0.2070, 0.3297, 0.1213, 0.3421],
        [0.1600, 0.1249, 0.5229, 0.1922],
        ...,
        [0.2489, 0.3850, 0.1500, 0.2161],
        [0.1787, 0.2102, 0.1761, 0.4350],
        [0.3210, 0.2530, 0.2327, 0.1933]], grad_fn=<SoftmaxBackward0>)
router_weights , selected_expert_indices tensor([[0.4540, 0.2674],
        [0.3421, 0.3297],
        [0.5229, 0.1922],
        ...,
        [0.3850, 0.2489],
        [0.4350, 0.2102],
        [0.3210, 0.2530]], grad_fn=<TopkBackward0>) tensor([[2, 1],
        [3, 1],
        [2, 3],
        ...,
        [1, 0],
        [3, 1],
        [0, 1]])
router_weights tensor([[0.6293, 0.3707],
        [0.5093, 0.4907],
        [0.7313, 0.2687],
        ...,
        [0.6074, 0.3926],
        [0.6742, 0.3258],
        [0.5593, 0.4407]], grad_fn=<DivBackward0>)
expert_mask, tensor([[[0, 0, 1, 0],
         [0, 1, 0, 0]],

        [[0, 0, 0, 1],
         [0, 1, 0, 0]],

        [[0, 0, 1, 0],
      

x.size() torch.Size([32, 16, 32])
hidden_states torch.Size([512, 32])
router_probs tensor([[0.2712, 0.2125, 0.2177, 0.2986],
        [0.2392, 0.2197, 0.3321, 0.2090],
        [0.2087, 0.4278, 0.1412, 0.2224],
        ...,
        [0.2339, 0.2253, 0.1824, 0.3585],
        [0.2695, 0.3315, 0.1657, 0.2334],
        [0.2350, 0.3029, 0.3160, 0.1461]], grad_fn=<SoftmaxBackward0>)
router_weights , selected_expert_indices tensor([[0.2986, 0.2712],
        [0.3321, 0.2392],
        [0.4278, 0.2224],
        ...,
        [0.3585, 0.2339],
        [0.3315, 0.2695],
        [0.3160, 0.3029]], grad_fn=<TopkBackward0>) tensor([[3, 0],
        [2, 0],
        [1, 3],
        ...,
        [3, 0],
        [1, 0],
        [2, 1]])
router_weights tensor([[0.5241, 0.4759],
        [0.5814, 0.4186],
        [0.6580, 0.3420],
        ...,
        [0.6052, 0.3948],
        [0.5516, 0.4484],
        [0.5106, 0.4894]], grad_fn=<DivBackward0>)
expert_mask, tensor([[[0, 0, 0, 1],
         [1, 0, 0, 0]],

       

router_weights_idx, top_x tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]) tensor([  3,  15,  26,  28,  31,  32,  38,  39,  40,  46,  47,  48,  51,  58,
         59,  60,  66,  68,  74,  75,  81,  85,  89,  90,  96,  97, 117, 118,
        120, 121, 130, 132, 138, 140, 145, 146, 151, 