In [50]:
import torch
from torch import nn
import torch.nn.functional as F

### Baisc MOE
* Single `nn.Linear` layer as the expert module
* `nn.Linear` layer as the dense router
* input with shape `(b, feature_in)`, output with shape `(b, feature_out)`

In [51]:
class BasicExpert(nn.Module):
    def __init__(self, feature_in, feature_out):
        super().__init__()
        self.fc = nn.Linear(feature_in, feature_out)
    
    def forward(self, x):
        return self.fc(x)

In [52]:
class BasicMOE(nn.Module):
    def __init__(self, feature_in, feature_out, num_experts):
        super().__init__()
        self.gate = nn.Linear(feature_in, num_experts)
        # output shape (batch_size, num_experts)
        self.experts = nn.ModuleList(
            BasicExpert(
                feature_in, feature_out
            ) for _ in range(num_experts)
        )
    
    def forward(self, x):
        # x shape (batch_size, feature_in)
        expert_weights = self.gate(x)
        expert_out_list = [
            expert(x).unsqueeze(1) for expert in self.experts
        ] # each expert output in shape (batch_size, feature_out), unsqueeze to (batch_size, 1, feature_out)

        expert_output = torch.concat(expert_out_list, dim=1)
        # expert_output shape (b, num_experts, feature_out)

        #expert weights
        expert_weights = F.softmax(expert_weights, dim=1)
        expert_weights = expert_weights.unsqueeze(1)
        # exopert_weights shape (b, 1, num_experts)

        output = expert_weights @ expert_output
        return output.squeeze()

In [53]:
def test_basic_moe():
    x = torch.rand(10, 512)
    basic_moe = BasicMOE(512, 128, 4)
    output = basic_moe(x)
    print(output.shape)

test_basic_moe()

torch.Size([10, 128])


### Sparse MOE
* top-K router
* input with shape `(b, seq_len, hidden_dim)`

In [54]:
class MOEConfig:
    def __init__(self, hidden_dim, expert_num, top_k, shared_expert_num=2):
        self.hidden_dim = hidden_dim
        self.expert_num = expert_num
        self.top_k = top_k
        self.shared_expert_num = shared_expert_num

class MOERouter(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.gate = nn.Linear(config.hidden_dim, config.expert_num)
        self.top_k = config.top_k
        self.expert_num = config.expert_num
    
    def forward(self, x):
        router_logits = self.gate(x)
        # router_logits shape (b*seq_len, expert_num)
        router_probs = F.softmax(router_logits, dim=-1)

        router_weights, top_k_indices = router_probs.topk(k=self.top_k, dim=-1)
        # router_weight & top_k_indices shape (b*seq_len, top_k)

        router_weights = router_weights / router_weights.sum(dim=-1, keepdim=True)

        expert_masks = F.one_hot(top_k_indices, self.expert_num)
        # expert_mask shape (b*seq_len, top_k, expert_num)
        expert_masks = expert_masks.permute(2, 1, 0)
        # expert_mask shape (expert_num, top_k, b*seq_len)
        
        return router_logits, router_weights, top_k_indices, expert_masks


class SparseMOE(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.top_k = config.top_k
        self.hidden_dim = config.hidden_dim
        self.expert_num = config.expert_num

        self.experts = nn.ModuleList([
            BasicExpert(
                config.hidden_dim,
                config.hidden_dim
            ) for _ in range(config.expert_num)
        ])
        self.router = MOERouter(config)

    def forward(self, x):
        # x shape (b, seq_len, hidden_dim)
        batch_size, seq_len, hidden_dim = x.shape

        # reshape to (b*seq_len, hidden_dim) for token level calculation
        x = x.reshape(-1, hidden_dim)

        router_logtis, router_weights, top_k_indices, expert_masks = self.router(x)

        final_x = torch.zeros((batch_size*seq_len, hidden_dim), dtype=x.dtype, device=x.device)

        for expert_idx in range(self.expert_num):
            expert_layer = self.experts[expert_idx]

            current_expert_mask = expert_masks[expert_idx]
            # current_expert_mask shape (top_k, b*seq_len)
            
            top_idx, token_idx = torch.where(current_expert_mask)
            # top_idx shape: (top_k), token_idx shape: (b*seq_len)

            current_x = x[token_idx]
            # current_x shape (selected_token_num, hidden_dim)
            current_x = expert_layer(current_x)
            current_token_router_weight = router_weights[token_idx, top_idx]
            # current_token_router_weight shape (selected_token_num)
            current_token_router_weight = current_token_router_weight.unsqueeze(-1)
            # current_token_router_weight shape (selected_token_num, 1)

            current_x *= current_token_router_weight

            final_x.index_add_(0, token_idx, current_x)

        final_x = final_x.reshape(batch_size, seq_len, hidden_dim)

        return final_x, router_logtis


In [55]:
def test_SparseMOE():
    x = torch.rand(10, 100, 512)
    config = MOEConfig(512, 16, 4)
    sparse_moe = SparseMOE(config)
    output = sparse_moe(x)
    print(output[0].shape, output[1].shape)

test_SparseMOE()

torch.Size([10, 100, 512]) torch.Size([1000, 16])


### DeepSeek MOE
* include shared expert

In [56]:
class SharedExpertMOE(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.routed_expert_moe = SparseMOE(config)
        self.shared_experts = nn.ModuleList(
            [
                BasicExpert(self.config.hidden_dim, self.config.hidden_dim)
            ]
        )

    def forward(self, x):
        # x shape: (b, seq_len, hidden_dim)
        b, seq_len, hidden_dim = x.shape

        shared_expert_output_list = [expert(x) for expert in self.shared_experts]
        shared_expert_output = torch.stack(shared_expert_output_list, dim=0)
        # shared_expert_output shape (shared_expert_num, b, seq_len, hidden_dim)

        shared_expert_output = shared_expert_output.sum(dim=0)
        # shared_expert_output shape (b, seq_len, hidden_dim)
        spared_moe_out, router_logits = self.routed_expert_moe(x)

        output = spared_moe_out + shared_expert_output

        return output, router_logits



In [57]:
def test_SharedExpertMOE():
    x = torch.rand(10, 100, 512)
    config = MOEConfig(512, 16, 4)
    sparse_moe = SharedExpertMOE(config)
    output = sparse_moe(x)
    print(output[0].shape, output[1].shape)

test_SharedExpertMOE()

torch.Size([10, 100, 512]) torch.Size([1000, 16])


### Load balancing loss

In [None]:
def switch_load_balancing_loss(router_logits, expert_num, top_k):
    router_probs = F.softmax(router_logits, dim=-1)
    # (b*seq_len, expert_num)
    
    _, selected_expert = torch.topk(router_probs, k=top_k, dim=-1)
    # (b*seq_len, top_k)
    expert_mask = F.one_hot(selected_expert, expert_num).to(torch.float)
    # (b*seq_len, top_k, expert_num)
    actual_load = expert_mask.mean(dim=0)
    # (top_k, expert_num)

    aux_loss = torch.sum(actual_load * router_probs.mean(dim=0)) * expert_num

    z_loss = torch.mean(router_logits**2)
    z_loss_weight = 0.01

    return aux_loss + z_loss * z_loss_weight


In [69]:
def test_moe_training():
    batch_size = 32
    seq_len = 16
    hidden_dim = 32
    num_batches = 1000
    
    # Initialize model and optimizer
    config = MOEConfig(hidden_dim=hidden_dim, 
                      expert_num=4,
                      top_k=2,
                      shared_expert_num=2)
    model = SharedExpertMOE(config)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    
    # Training loop
    model.train()
    for batch in range(num_batches):
        # Generate random input data
        x = torch.randn(batch_size, seq_len, hidden_dim)
        target = torch.randn(batch_size, seq_len, hidden_dim)
        
        # Forward pass
        output, router_logits = model(x)

        # Compute losses
        # MSE loss for prediction
        mse_loss = F.mse_loss(output, target)
        
        aux_loss = switch_load_balancing_loss(router_logits, config.expert_num, config.top_k)
        # Combined loss
        total_loss = 0.5 * mse_loss + 0.5 * aux_loss
        
        # Backward pass and optimize
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        
        if batch % 10 == 0:
            print(f"Batch {batch}, Loss: {total_loss.item():.4f} "
                  f"(MSE: {mse_loss.item():.4f}, Aux: {aux_loss.item():.4f})")

# Run the training test
test_moe_training()

Batch 0, Loss: 1.7621 (MSE: 1.5131, Aux: 2.0110)
Batch 10, Loss: 1.7215 (MSE: 1.4367, Aux: 2.0063)
Batch 20, Loss: 1.6744 (MSE: 1.3442, Aux: 2.0047)
Batch 30, Loss: 1.6353 (MSE: 1.2658, Aux: 2.0049)
Batch 40, Loss: 1.6094 (MSE: 1.2162, Aux: 2.0025)
Batch 50, Loss: 1.5835 (MSE: 1.1636, Aux: 2.0034)
Batch 60, Loss: 1.5631 (MSE: 1.1247, Aux: 2.0015)
Batch 70, Loss: 1.5526 (MSE: 1.1025, Aux: 2.0027)
Batch 80, Loss: 1.5522 (MSE: 1.1022, Aux: 2.0022)
Batch 90, Loss: 1.5279 (MSE: 1.0545, Aux: 2.0012)
Batch 100, Loss: 1.5264 (MSE: 1.0506, Aux: 2.0022)
Batch 110, Loss: 1.5311 (MSE: 1.0607, Aux: 2.0014)
Batch 120, Loss: 1.5301 (MSE: 1.0579, Aux: 2.0022)
Batch 130, Loss: 1.5235 (MSE: 1.0443, Aux: 2.0027)
Batch 140, Loss: 1.5193 (MSE: 1.0363, Aux: 2.0024)
Batch 150, Loss: 1.5212 (MSE: 1.0413, Aux: 2.0010)
Batch 160, Loss: 1.5230 (MSE: 1.0424, Aux: 2.0037)
Batch 170, Loss: 1.5134 (MSE: 1.0244, Aux: 2.0025)
Batch 180, Loss: 1.5135 (MSE: 1.0214, Aux: 2.0055)
Batch 190, Loss: 1.5164 (MSE: 1.0314, Aux: