# 🏗️ DeepSeek-Coder-V2: Mixture-of-Experts (MoE) Architecture

## 🎯 Learning Objectives

Build a deep understanding of the **Mixture-of-Experts (MoE)** architecture used in DeepSeek-Coder-V2, including:

1. **MoE Fundamentals**: operating principles and architecture design
2. **Parameter Efficiency**: how MoE achieves efficiency through sparse activation
3. **Routing Mechanism**: expert selection and load balancing
4. **Implementation Details**: code implementations from basic to advanced
5. **Performance Analysis**: comparing MoE and dense models

## 📚 Paper References

**Section 3.2: Model Architecture**
> "Our architecture aligns with that of DeepSeekV2. The hyperparameters settings, 16B and 236B, correspond to those used in DeepSeek-V2-Lite and DeepSeek-V2, respectively."

**Key MoE Statistics from Paper:**
- DeepSeek-Coder-V2: **236B total params**, **21B active params** (about 8.9% of parameters active per token)
- DeepSeek-Coder-V2-Lite: **16B total params**, **2.4B active params** (about 15% of parameters active per token)
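
The percentages above follow directly from the quoted parameter counts. A minimal sketch of the arithmetic, where the helper name `active_fraction` is introduced here purely for illustration:

```python
# Parameter counts in billions, as quoted from the paper above.
models = {
    "DeepSeek-Coder-V2": (236.0, 21.0),
    "DeepSeek-Coder-V2-Lite": (16.0, 2.4),
}

def active_fraction(total_b, active_b):
    """Share of parameters that are active for a single token."""
    return active_b / total_b

for name, (total_b, active_b) in models.items():
    share = active_fraction(total_b, active_b)
    print(f"{name}: {share:.1%} of parameters active per token")
```

The ratio describes how much of the model participates in each forward pass per token; it says nothing about memory, since all parameters still have to be stored.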

## 🔧 Environment Setup

In [None]:
# Core dependencies
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Dict, Tuple, Optional
import math
import warnings
warnings.filterwarnings('ignore')

# Set up plotting
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

print("🚀 MoE Architecture Learning Environment Ready!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

## 🧠 MoE Theory Deep Dive

### 💡 What is Mixture-of-Experts?

**Mixture-of-Experts** is a neural network architecture that uses multiple specialized sub-networks (experts) in place of a single large dense network.

### 🔑 Key Concepts:

1. **Experts**: specialized sub-networks
2. **Gating Network**: a router that decides which experts are activated
3. **Sparse Activation**: only a subset of experts is used for each input
4. **Load Balancing**: keeps expert usage roughly even

### 📊 Mathematical Foundation:

Given an input $x$, the MoE output is computed as:

$$y = \sum_{i=1}^{n} G(x)_i \cdot E_i(x)$$

where:
- $G(x)_i$: gating weight for expert $i$ (zero for experts that are not selected under sparse routing)
- $E_i(x)$: output of expert $i$
- $n$: number of experts
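
To make the formula concrete, here is a toy sketch in plain Python. It assumes scalar inputs, four hypothetical "experts" written as simple functions, and hand-picked router logits; none of these values come from the paper. Only the top-k gates are non-zero, so only those experts are evaluated.

```python
import math

def softmax(values):
    """Numerically stable softmax over a list of floats."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

def moe_output(x, experts, router_logits, top_k):
    """Sparse form of y = sum_i G(x)_i * E_i(x): only top_k gates are non-zero."""
    # Indices of the top_k largest router logits (ties keep the lower index first).
    order = sorted(range(len(experts)), key=lambda i: router_logits[i], reverse=True)
    selected = order[:top_k]
    # Softmax over the selected logits only, so the active gates sum to 1.
    gates = softmax([router_logits[i] for i in selected])
    y = sum(g * experts[i](x) for g, i in zip(gates, selected))
    return y, dict(zip(selected, gates))

# Hypothetical scalar "experts" standing in for feed-forward networks.
experts = [lambda x: 2 * x, lambda x: x + 1, lambda x: -x, lambda x: x * x]
y, gates = moe_output(3.0, experts, router_logits=[1.0, 1.0, -2.0, 0.0], top_k=2)
print(f"y = {y}, gates = {gates}")
```

Experts 0 and 1 share the highest logit, so each receives a gate of 0.5 and the output is the average of their two results; experts 2 and 3 are never called.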

In [None]:
class Expert(nn.Module):
    """Single expert in an MoE layer.
    
    Each expert is a simple feed-forward network,
    similar to the FFN in a Transformer block.
    """
    
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.d_model = d_model
        self.d_ff = d_ff
        
        # Gated MLP with SwiGLU activation (as in DeepSeek): three bias-free projections
        self.w1 = nn.Linear(d_model, d_ff, bias=False)  # Gate projection
        self.w2 = nn.Linear(d_ff, d_model, bias=False)  # Down projection
        self.w3 = nn.Linear(d_model, d_ff, bias=False)  # Up projection
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """SwiGLU: (SiLU(x @ W1) * (x @ W3)) @ W2"""
        gate = F.silu(self.w1(x))  # SiLU activation
        up = self.w3(x)
        return self.w2(self.dropout(gate * up))

class TopKGating(nn.Module):
    """Top-k gating mechanism for MoE.
    
    Selects the top-k experts for each token and computes their gating weights.
    """
    
    def __init__(self, d_model: int, num_experts: int, top_k: int = 2):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        
        # Gating network: simple linear layer
        self.gate = nn.Linear(d_model, num_experts, bias=False)
        
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute top-k gating
        
        Returns:
            gates: Gating weights [batch_size, seq_len, top_k]
            indices: Expert indices [batch_size, seq_len, top_k]  
            load: Load balancing loss
        """
        batch_size, seq_len, d_model = x.shape
        
        # Compute gating logits
        logits = self.gate(x)  # [batch_size, seq_len, num_experts]
        
        # Top-k selection
        top_k_logits, top_k_indices = torch.topk(logits, self.top_k, dim=-1)
        
        # Softmax over top-k experts
        top_k_gates = F.softmax(top_k_logits, dim=-1)
        
        # Load-balancing auxiliary loss: N * sum(mean_prob^2). It equals 1 when the mean
        # routing probabilities are uniform and grows as routing concentrates on few experts.
        gates_mean = F.softmax(logits, dim=-1).mean(dim=[0, 1])  # [num_experts]
        load_loss = self.num_experts * torch.sum(gates_mean * gates_mean)
        
        return top_k_gates, top_k_indices, load_loss

# Demo basic components
print("🧪 Testing MoE Components:")
print("=" * 40)

# Test Expert
d_model, d_ff = 512, 2048
expert = Expert(d_model, d_ff)
x = torch.randn(2, 10, d_model)  # [batch=2, seq_len=10, d_model=512]
expert_out = expert(x)
print(f"✅ Expert output shape: {expert_out.shape}")

# Test Gating
num_experts, top_k = 8, 2
gating = TopKGating(d_model, num_experts, top_k)
gates, indices, load_loss = gating(x)
print(f"✅ Gating weights shape: {gates.shape}")
print(f"✅ Expert indices shape: {indices.shape}")
print(f"✅ Load balancing loss: {load_loss.item():.4f}")
print(f"📊 Selected experts for first token: {indices[0, 0].tolist()}")
print(f"📊 Gating weights for first token: {gates[0, 0].tolist()}")

## 🏗️ Complete MoE Layer Implementation

### 📋 DeepSeek-V2 MoE Architecture Details:

According to the paper, DeepSeek-V2 uses:
- **Multi-head Latent Attention (MLA)**: an efficient attention mechanism
- **DeepSeekMoE**: shared experts + routed experts
- **Expert specialization**: different experts for different types of tokens
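
As a rough sketch of how shared and routed experts combine, the layer output can be written in the notation of the gating formula used earlier. The symbols $N_s$ (number of shared experts), $N_r$ (number of routed experts), $E^{(s)}_j$ and $E^{(r)}_i$ are labels introduced here for illustration:

$$y = x + \sum_{j=1}^{N_s} E^{(s)}_j(x) + \sum_{i=1}^{N_r} G(x)_i \cdot E^{(r)}_i(x)$$

Shared experts contribute to every token with no gate, while $G(x)_i$ is non-zero for only the top-$k$ routed experts. The `MoELayer` in this notebook additionally applies a LayerNorm to this sum, which is a choice of this implementation rather than something stated in the paper excerpt above.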

In [None]:
class MoELayer(nn.Module):
    """Complete MoE Layer Implementation
    
    Based on the DeepSeek-V2 architecture with shared + routed experts
    """
    
    def __init__(
        self, 
        d_model: int, 
        d_ff: int,
        num_experts: int,
        num_shared_experts: int = 2,
        top_k: int = 6,  # DeepSeek-V2 uses top-6
        dropout: float = 0.1,
        expert_capacity_factor: float = 1.25
    ):
        super().__init__()
        self.d_model = d_model
        self.d_ff = d_ff
        self.num_experts = num_experts
        self.num_shared_experts = num_shared_experts
        self.top_k = top_k
        self.expert_capacity_factor = expert_capacity_factor
        
        # Shared experts (always activated)
        self.shared_experts = nn.ModuleList([
            Expert(d_model, d_ff, dropout) for _ in range(num_shared_experts)
        ])
        
        # Routed experts (sparsely activated)
        self.routed_experts = nn.ModuleList([
            Expert(d_model, d_ff, dropout) for _ in range(num_experts)
        ])
        
        # Gating network
        self.gate = TopKGating(d_model, num_experts, top_k)
        
        # Layer norm
        self.norm = nn.LayerNorm(d_model)
        
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass through MoE layer
        
        Args:
            x: Input tensor [batch_size, seq_len, d_model]
            
        Returns:
            output: MoE output [batch_size, seq_len, d_model]
            aux_loss: Auxiliary loss for load balancing
        """
        batch_size, seq_len, d_model = x.shape
        
        # 1. Shared experts (always computed)
        shared_output = torch.zeros_like(x)
        for expert in self.shared_experts:
            shared_output += expert(x)
        
        # 2. Routed experts (sparsely computed)
        gates, indices, aux_loss = self.gate(x)
        
        # Reshape for expert computation
        flat_x = x.view(-1, d_model)  # [batch_size * seq_len, d_model]
        flat_gates = gates.view(-1, self.top_k)  # [batch_size * seq_len, top_k]
        flat_indices = indices.view(-1, self.top_k)  # [batch_size * seq_len, top_k]
        
        # Compute routed expert outputs
        routed_output = torch.zeros_like(flat_x)
        
        for i, expert in enumerate(self.routed_experts):
            # Find tokens routed to this expert
            expert_mask = (flat_indices == i)
            if expert_mask.any():
                # Get token indices and weights for this expert
                token_indices, expert_positions = torch.where(expert_mask)
                
                if len(token_indices) > 0:
                    # Expert computation for selected tokens
                    expert_tokens = flat_x[token_indices]
                    expert_output = expert(expert_tokens)
                    
                    # Weight by gating values
                    weights = flat_gates[token_indices, expert_positions].unsqueeze(-1)
                    weighted_output = expert_output * weights
                    
                    # Accumulate in routed_output
                    routed_output[token_indices] += weighted_output
        
        # Reshape back
        routed_output = routed_output.view(batch_size, seq_len, d_model)
        
        # 3. Combine shared + routed outputs
        output = shared_output + routed_output
        
        # 4. Residual connection and layer norm
        output = self.norm(x + output)
        
        return output, aux_loss
    
    def get_expert_usage_stats(self, x: torch.Tensor) -> Dict[str, float]:
        """Analyze expert usage patterns"""
        with torch.no_grad():
            gates, indices, _ = self.gate(x)
            
            # Count expert usage
            expert_counts = torch.zeros(self.num_experts)
            for i in range(self.num_experts):
                expert_counts[i] = (indices == i).sum().item()
            
            total_selections = indices.numel()
            usage_percentages = expert_counts / total_selections * 100
            
            return {
                'expert_usage': usage_percentages.tolist(),
                'max_usage': usage_percentages.max().item(),
                'min_usage': usage_percentages.min().item(),
                'usage_std': usage_percentages.std().item(),
                'total_selections': total_selections
            }

# Test complete MoE layer
print("\n🏗️ Testing Complete MoE Layer:")
print("=" * 50)

# MoE configuration similar to DeepSeek-V2
moe_config = {
    'd_model': 512,
    'd_ff': 2048,
    'num_experts': 64,  # DeepSeek-V2 uses 160 experts
    'num_shared_experts': 2,
    'top_k': 6,
    'dropout': 0.1
}

moe_layer = MoELayer(**moe_config)

# Test forward pass
batch_size, seq_len = 4, 32
x = torch.randn(batch_size, seq_len, moe_config['d_model'])

output, aux_loss = moe_layer(x)
print(f"✅ MoE output shape: {output.shape}")
print(f"✅ Auxiliary loss: {aux_loss.item():.6f}")

# Analyze expert usage
usage_stats = moe_layer.get_expert_usage_stats(x)
print(f"📊 Expert usage statistics:")
print(f"   Max usage: {usage_stats['max_usage']:.2f}%")
print(f"   Min usage: {usage_stats['min_usage']:.2f}%")
print(f"   Usage std: {usage_stats['usage_std']:.2f}%")
print(f"   Total selections: {usage_stats['total_selections']}")

# Calculate parameter efficiency
total_params = sum(p.numel() for p in moe_layer.parameters())
shared_params = sum(p.numel() for expert in moe_layer.shared_experts for p in expert.parameters())
single_expert_params = sum(p.numel() for p in moe_layer.routed_experts[0].parameters())
active_expert_params = moe_config['top_k'] * single_expert_params
gate_params = sum(p.numel() for p in moe_layer.gate.parameters())

active_params = shared_params + active_expert_params + gate_params
efficiency = active_params / total_params * 100

print(f"\n⚡ Parameter Efficiency:")
print(f"   Total parameters: {total_params:,}")
print(f"   Active parameters: {active_params:,}")
print(f"   Efficiency: {efficiency:.1f}%")
print(f"   Reference: DeepSeek-Coder-V2 activates 8.9% (21B/236B)")

## 📊 MoE vs Dense Model Comparison

### 🔬 Performance Analysis

Compare the compute cost and memory footprint of MoE and dense models.
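
Before timing anything, it helps to estimate parameter counts by hand. The sketch below assumes bias-free SwiGLU experts with three projections each and a linear router, and it ignores normalization parameters; the function names are illustrative, not part of any library.

```python
def swiglu_ffn_params(d_model, d_ff):
    """Three bias-free projections: gate, up, and down."""
    return 3 * d_model * d_ff

def moe_param_counts(d_model, d_ff, num_routed, num_shared, top_k):
    """Return (total, active-per-token) parameter counts for one MoE layer."""
    per_expert = swiglu_ffn_params(d_model, d_ff)
    router = d_model * num_routed  # linear router without bias
    total = (num_routed + num_shared) * per_expert + router
    active = (top_k + num_shared) * per_expert + router
    return total, active

total, active = moe_param_counts(d_model=512, d_ff=2048, num_routed=64, num_shared=2, top_k=6)
dense = swiglu_ffn_params(512, 2048)
print(f"MoE total:  {total:,}")
print(f"MoE active: {active:,} ({active / total:.1%} of total)")
print(f"Dense FFN with the same d_ff: {dense:,}")
```

With 2 shared and 6 of 64 routed experts active, roughly one eighth of the layer's parameters take part in each token's computation, yet every expert must still be held in memory.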

In [None]:
class DenseFFN(nn.Module):
    """Dense feed-forward network used as a baseline against MoE"""
    
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_ff, d_model, bias=False)
        self.w3 = nn.Linear(d_model, d_ff, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_model)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.silu(self.w1(x))
        up = self.w3(x)
        output = self.w2(self.dropout(gate * up))
        return self.norm(x + output)

def benchmark_models(d_model: int, seq_len: int, batch_size: int = 1, num_trials: int = 10):
    """Benchmark MoE vs Dense models"""
    
    # Model configurations
    configs = {
        'MoE-64': {
            'model': MoELayer(d_model, d_model*4, num_experts=64, top_k=6),
            'type': 'MoE'
        },
        'MoE-32': {
            'model': MoELayer(d_model, d_model*4, num_experts=32, top_k=4),
            'type': 'MoE'
        },
        'Dense-Large': {
            'model': DenseFFN(d_model, d_model*8),  # Wider dense baseline; still far fewer total params than the MoE layers
            'type': 'Dense'
        },
        'Dense-Small': {
            'model': DenseFFN(d_model, d_model*4),
            'type': 'Dense'
        }
    }
    
    results = {}
    x = torch.randn(batch_size, seq_len, d_model)
    
    for name, config in configs.items():
        model = config['model']
        model.eval()
        
        # Parameter count
        total_params = sum(p.numel() for p in model.parameters())
        
        # Active parameters per token (a proxy for compute, not memory)
        if config['type'] == 'MoE':
            # All experts must stay resident in memory; only shared + top-k experts run per token
            shared_params = sum(p.numel() for expert in model.shared_experts for p in expert.parameters())
            single_expert_params = sum(p.numel() for p in model.routed_experts[0].parameters())
            active_params = shared_params + model.top_k * single_expert_params
        else:
            active_params = total_params
        
        # Speed benchmark (wall-clock time)
        import time
        times = []
        
        with torch.no_grad():
            # Warmup
            for _ in range(3):
                _ = model(x)
            
            # Benchmark; MoE layers return (output, aux_loss), dense layers return a tensor
            for _ in range(num_trials):
                start_time = time.perf_counter()
                _ = model(x)
                end_time = time.perf_counter()
                times.append(end_time - start_time)
        
        avg_time = np.mean(times) * 1000  # Convert to ms
        
        results[name] = {
            'total_params': total_params,
            'active_params': active_params,
            'efficiency': active_params / total_params * 100,
            'avg_time_ms': avg_time,
            'type': config['type']
        }
    
    return results

# Run benchmark
print("🏃 Running MoE vs Dense Benchmark...")
print("=" * 50)

benchmark_results = benchmark_models(
    d_model=512, 
    seq_len=128, 
    batch_size=4, 
    num_trials=20
)

# Display results
print(f"{'Model':<15} {'Type':<6} {'Total Params':<12} {'Active Params':<12} {'Efficiency':<10} {'Time (ms)':<10}")
print("-" * 80)

for name, stats in benchmark_results.items():
    print(f"{name:<15} {stats['type']:<6} {stats['total_params']:<12,} "
          f"{stats['active_params']:<12,} {stats['efficiency']:<10.1f}% "
          f"{stats['avg_time_ms']:<10.2f}")

# Visualization
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

models = list(benchmark_results.keys())
total_params = [benchmark_results[m]['total_params']/1e6 for m in models]  # Convert to millions
active_params = [benchmark_results[m]['active_params']/1e6 for m in models]
times = [benchmark_results[m]['avg_time_ms'] for m in models]
colors = ['red' if benchmark_results[m]['type'] == 'MoE' else 'blue' for m in models]

# Parameter comparison
x = np.arange(len(models))
width = 0.35

axes[0].bar(x - width/2, total_params, width, label='Total', alpha=0.7)
axes[0].bar(x + width/2, active_params, width, label='Active', alpha=0.7)
axes[0].set_xlabel('Models')
axes[0].set_ylabel('Parameters (Millions)')
axes[0].set_title('Parameter Count Comparison')
axes[0].set_xticks(x)
axes[0].set_xticklabels(models, rotation=45)
axes[0].legend()

# Efficiency
efficiency = [benchmark_results[m]['efficiency'] for m in models]
bars = axes[1].bar(models, efficiency, color=colors, alpha=0.7)
axes[1].set_ylabel('Parameter Efficiency (%)')
axes[1].set_title('Parameter Efficiency')
axes[1].tick_params(axis='x', rotation=45)

# Add DeepSeek-V2 reference line
axes[1].axhline(y=8.9, color='orange', linestyle='--', alpha=0.8, label='DeepSeek-V2 (8.9%)')
axes[1].legend()

# Timing
bars = axes[2].bar(models, times, color=colors, alpha=0.7)
axes[2].set_ylabel('Inference Time (ms)')
axes[2].set_title('Inference Speed')
axes[2].tick_params(axis='x', rotation=45)

# Add legend for colors
from matplotlib.patches import Patch
legend_elements = [Patch(facecolor='red', alpha=0.7, label='MoE'),
                   Patch(facecolor='blue', alpha=0.7, label='Dense')]
fig.legend(handles=legend_elements, loc='upper right')

plt.tight_layout()
plt.show()

print("\n🔍 Key Insights:")
print("• MoE models achieve higher capacity with similar active parameters")
print("• Parameter efficiency varies with expert count and top-k")
print("• Inference speed depends on routing overhead vs computation savings")
print("• DeepSeek-Coder-V2 keeps about 8.9% of its parameters active per token (21B of 236B)")

## 🎯 Expert Specialization Analysis

### 🧪 Understanding How Experts Specialize

Analyze how the experts in an MoE layer come to specialize in different types of tokens.
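
One way to quantify specialization is the entropy of the expert-usage distribution. This is an alternative to the standard-deviation "concentration" score used in this notebook, not a metric taken from the paper. The sketch below assumes per-expert selection counts are already available; the example counts are made up.

```python
import math

def normalized_usage_entropy(counts):
    """Entropy of the expert-usage distribution, scaled to [0, 1].

    1.0 means all experts are used equally; 0.0 means one expert gets everything.
    """
    total = sum(counts)
    probs = [c / total for c in counts if c > 0]
    entropy = -sum(p * math.log(p) for p in probs)
    return entropy / math.log(len(counts))

# Hypothetical selection counts for four experts.
uniform = [25, 25, 25, 25]
skewed = [97, 1, 1, 1]
collapsed = [100, 0, 0, 0]

for name, counts in [("uniform", uniform), ("skewed", skewed), ("collapsed", collapsed)]:
    print(f"{name:>9}: {normalized_usage_entropy(counts):.3f}")
```

Lower entropy for a given input type suggests that a few experts handle most of its tokens. Comparing which experts dominate across input types is what separates specialization from plain routing collapse.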

In [None]:
def analyze_expert_specialization(moe_layer: MoELayer, test_samples: Dict[str, torch.Tensor]):
    """Analyze expert specialization patterns
    
    Args:
        moe_layer: Trained MoE layer
        test_samples: Dict of different input types
    """
    
    specialization_stats = {}
    
    with torch.no_grad():
        for sample_type, sample_data in test_samples.items():
            gates, indices, _ = moe_layer.gate(sample_data)
            
            # Count expert usage for this sample type
            expert_counts = torch.zeros(moe_layer.num_experts)
            for i in range(moe_layer.num_experts):
                expert_counts[i] = (indices == i).sum().item()
            
            # Normalize to percentages
            total_selections = indices.numel()
            usage_percentages = expert_counts / total_selections * 100
            
            specialization_stats[sample_type] = {
                'usage': usage_percentages.numpy(),
                'top_experts': torch.topk(usage_percentages, 5).indices.tolist(),
                'concentration': torch.std(usage_percentages).item()
            }
    
    return specialization_stats

def create_synthetic_code_patterns(d_model: int, seq_len: int = 32, batch_size: int = 8):
    """Create synthetic data patterns to simulate different code types"""
    
    patterns = {}
    
    # Different "code" patterns (simulated with different distributions)
    
    # 1. "Function definitions" - structured patterns
    func_pattern = torch.randn(batch_size, seq_len, d_model)
    func_pattern[:, :5] *= 2.0  # Strong start pattern
    patterns['functions'] = func_pattern
    
    # 2. "Control flow" - repetitive patterns  
    control_pattern = torch.randn(batch_size, seq_len, d_model)
    control_pattern[:, ::3] *= 1.5  # Periodic patterns
    patterns['control_flow'] = control_pattern
    
    # 3. "Data structures" - complex patterns
    data_pattern = torch.randn(batch_size, seq_len, d_model)
    data_pattern += torch.sin(torch.arange(seq_len).float()).unsqueeze(0).unsqueeze(-1) * 0.5
    patterns['data_structures'] = data_pattern
    
    # 4. "Comments" - sparse patterns
    comment_pattern = torch.randn(batch_size, seq_len, d_model) * 0.5
    patterns['comments'] = comment_pattern
    
    # 5. "Math expressions" - dense patterns
    math_pattern = torch.randn(batch_size, seq_len, d_model) * 1.5
    math_pattern += torch.cos(torch.arange(seq_len).float()).unsqueeze(0).unsqueeze(-1) * 0.3
    patterns['math_expressions'] = math_pattern
    
    return patterns

# Create test patterns
print("🎨 Creating Synthetic Code Patterns...")
test_patterns = create_synthetic_code_patterns(d_model=512)

# Initialize a fresh MoE layer for specialization analysis
specialization_moe = MoELayer(
    d_model=512,
    d_ff=2048,
    num_experts=32,  # Fewer experts for clearer visualization
    top_k=4
)

# Simulate "training" by running patterns through the model multiple times
print("🏋️ Simulating Expert Specialization Training...")
specialization_moe.train()

# Simple training simulation
optimizer = torch.optim.Adam(specialization_moe.parameters(), lr=0.001)
for epoch in range(50):  # Quick training simulation
    total_loss = 0
    for pattern_type, pattern_data in test_patterns.items():
        optimizer.zero_grad()
        output, aux_loss = specialization_moe(pattern_data)
        
        # Simple reconstruction loss + auxiliary loss
        recon_loss = F.mse_loss(output, pattern_data)
        loss = recon_loss + 0.01 * aux_loss
        
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    if epoch % 10 == 0:
        print(f"Epoch {epoch}: Loss = {total_loss:.4f}")

# Analyze specialization
specialization_moe.eval()
print("\n🔬 Analyzing Expert Specialization...")
specialization_results = analyze_expert_specialization(specialization_moe, test_patterns)

# Visualization
fig, axes = plt.subplots(2, 3, figsize=(18, 10))
axes = axes.flatten()

pattern_types = list(test_patterns.keys())
num_experts = specialization_moe.num_experts

for i, pattern_type in enumerate(pattern_types):
    if i < len(axes):
        usage = specialization_results[pattern_type]['usage']
        top_experts = specialization_results[pattern_type]['top_experts']
        concentration = specialization_results[pattern_type]['concentration']
        
        # Bar plot of expert usage
        bars = axes[i].bar(range(num_experts), usage, alpha=0.7)
        
        # Highlight top experts
        for expert_idx in top_experts[:3]:  # Top 3
            bars[expert_idx].set_color('red')
            bars[expert_idx].set_alpha(0.9)
        
        axes[i].set_title(f'{pattern_type.replace("_", " ").title()}\nConcentration: {concentration:.2f}')
        axes[i].set_xlabel('Expert Index')
        axes[i].set_ylabel('Usage (%)')
        axes[i].set_ylim(0, max(usage) * 1.1)

# Remove empty subplot
if len(pattern_types) < len(axes):
    axes[-1].remove()

plt.suptitle('Expert Specialization Patterns\n(Red bars = Top 3 experts for each pattern type)', 
             fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

# Summary statistics
print("\n📊 Specialization Summary:")
print("=" * 50)

for pattern_type, stats in specialization_results.items():
    top_experts = stats['top_experts'][:3]
    concentration = stats['concentration']
    max_usage = max(stats['usage'])
    
    print(f"{pattern_type.replace('_', ' ').title():>15}: "
          f"Top experts: {top_experts}, "
          f"Max usage: {max_usage:.1f}%, "
          f"Concentration: {concentration:.2f}")

print("\n🔍 Specialization Insights:")
print("• Higher concentration = more specialized experts")
print("• Different patterns activate different expert subsets")
print("• This is a synthetic demo; it does not show how real DeepSeek-V2 experts specialize")
print("• Specialization improves with more training and better routing")

## 🛠️ Advanced MoE Techniques

### 🔧 Load Balancing and Expert Dropout

Implementing advanced techniques to improve MoE training stability.
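
For orientation, a widely used load-balancing loss is the Switch Transformer form, $N \sum_i f_i P_i$, where $f_i$ is the fraction of token assignments sent to expert $i$ and $P_i$ is the mean router probability for expert $i$. The plain-Python sketch below is illustrative only: the function name is hypothetical, and normalizing $f_i$ over all top-k assignments is an assumption made so the same code covers $k > 1$.

```python
def switch_style_balance_loss(router_probs, assignments, num_experts):
    """Compute N * sum_i f_i * P_i for a batch of tokens.

    router_probs: per-token lists of probabilities over experts.
    assignments:  per-token lists of selected expert indices.
    """
    num_tokens = len(router_probs)
    # f_i: share of all token-to-expert assignments that went to expert i.
    f = [0.0] * num_experts
    for chosen in assignments:
        for e in chosen:
            f[e] += 1.0 / (num_tokens * len(chosen))
    # P_i: mean router probability assigned to expert i.
    p = [sum(tok[e] for tok in router_probs) / num_tokens for e in range(num_experts)]
    return num_experts * sum(fi * pi for fi, pi in zip(f, p))

# Balanced routing: two experts, tokens alternate between them.
balanced = switch_style_balance_loss([[0.5, 0.5]] * 4, [[0], [1], [0], [1]], num_experts=2)
# Collapsed routing: every token prefers and selects expert 0.
collapsed_loss = switch_style_balance_loss([[0.9, 0.1]] * 4, [[0]] * 4, num_experts=2)
print(f"balanced: {balanced:.2f}, collapsed: {collapsed_loss:.2f}")
```

The loss is 1 under perfectly uniform routing and grows as both the assignments and the router probabilities concentrate on the same experts, which is what pushes the router back toward balance.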

In [None]:
class AdvancedMoELayer(nn.Module):
    """Advanced MoE with enhanced load balancing and expert dropout"""
    
    def __init__(
        self,
        d_model: int,
        d_ff: int,
        num_experts: int,
        top_k: int = 2,
        dropout: float = 0.1,
        expert_dropout: float = 0.1,
        load_balancing_loss_coef: float = 0.01,
        router_z_loss_coef: float = 0.001
    ):
        super().__init__()
        self.d_model = d_model
        self.num_experts = num_experts
        self.top_k = top_k
        self.expert_dropout = expert_dropout
        self.load_balancing_loss_coef = load_balancing_loss_coef
        self.router_z_loss_coef = router_z_loss_coef
        
        # Experts
        self.experts = nn.ModuleList([
            Expert(d_model, d_ff, dropout) for _ in range(num_experts)
        ])
        
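        # Router plus a noise projection. Note: add_noise_to_logits uses fixed-scale Gaussian
        # noise, so noise_generator is defined but not used in this sketch.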
        self.router = nn.Linear(d_model, num_experts, bias=False)
        self.noise_generator = nn.Linear(d_model, num_experts, bias=False)
        
        # Layer norm
        self.norm = nn.LayerNorm(d_model)
        
    def add_noise_to_logits(self, logits: torch.Tensor, noise_epsilon: float = 1e-2) -> torch.Tensor:
        """Add noise to router logits for better load balancing"""
        if self.training:
            noise = torch.randn_like(logits) * noise_epsilon
            return logits + noise
        return logits
    
    def compute_routing_weights(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
        """Compute routing weights with enhanced load balancing"""
        batch_size, seq_len, d_model = x.shape
        
        # Router logits
        logits = self.router(x)  # [batch_size, seq_len, num_experts]
        
        # Add noise during training
        logits = self.add_noise_to_logits(logits)
        
        # Top-k selection
        top_k_logits, top_k_indices = torch.topk(logits, self.top_k, dim=-1)
        top_k_weights = F.softmax(top_k_logits, dim=-1)
        
        # Compute auxiliary losses
        aux_losses = {}
        
        # 1. Load balancing loss
        router_probs = F.softmax(logits, dim=-1)
        expert_usage = router_probs.mean(dim=[0, 1])  # [num_experts]
        
        # Squared coefficient of variation (variance / mean^2) of the mean routing probabilities
        cv_squared = expert_usage.var() / expert_usage.mean().clamp(min=1e-10).pow(2)
        load_balancing_loss = self.num_experts * cv_squared
        aux_losses['load_balancing'] = load_balancing_loss
        
        # 2. Router z-loss: mean squared log-sum-exp of the logits (discourages large logits)
        router_z_loss = torch.logsumexp(logits, dim=-1).pow(2).mean()
        aux_losses['router_z'] = router_z_loss
        
        return top_k_weights, top_k_indices, aux_losses
    
    def expert_dropout_mask(self, expert_indices: torch.Tensor) -> torch.Tensor:
        """Apply expert dropout during training"""
        if self.training and self.expert_dropout > 0:
            # Randomly drop some expert selections
            dropout_mask = torch.rand_like(expert_indices.float()) > self.expert_dropout
            return dropout_mask
        return torch.ones_like(expert_indices, dtype=torch.bool)
    
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Forward pass with advanced MoE techniques"""
        batch_size, seq_len, d_model = x.shape
        
        # Compute routing
        weights, indices, aux_losses = self.compute_routing_weights(x)
        
        # Apply expert dropout
        dropout_mask = self.expert_dropout_mask(indices)
        
        # Flatten for expert computation
        flat_x = x.view(-1, d_model)
        flat_weights = weights.view(-1, self.top_k)
        flat_indices = indices.view(-1, self.top_k)
        flat_dropout_mask = dropout_mask.view(-1, self.top_k)
        
        # Expert computation
        expert_outputs = torch.zeros_like(flat_x)
        
        for expert_idx, expert in enumerate(self.experts):
            # Find tokens assigned to this expert
            expert_mask = (flat_indices == expert_idx) & flat_dropout_mask
            
            if expert_mask.any():
                token_indices, k_indices = torch.where(expert_mask)
                
                if len(token_indices) > 0:
                    # Expert computation
                    expert_tokens = flat_x[token_indices]
                    expert_output = expert(expert_tokens)
                    
                    # Apply weights
                    expert_weights = flat_weights[token_indices, k_indices].unsqueeze(-1)
                    weighted_output = expert_output * expert_weights
                    
                    # Accumulate
                    expert_outputs[token_indices] += weighted_output
        
        # Reshape and apply residual + norm
        expert_outputs = expert_outputs.view(batch_size, seq_len, d_model)
        output = self.norm(x + expert_outputs)
        
        # Combine auxiliary losses
        total_aux_loss = (
            self.load_balancing_loss_coef * aux_losses['load_balancing'] +
            self.router_z_loss_coef * aux_losses['router_z']
        )
        
        aux_losses['total'] = total_aux_loss
        
        return output, aux_losses

# Test advanced MoE
print("🚀 Testing Advanced MoE Layer:")
print("=" * 40)

advanced_moe = AdvancedMoELayer(
    d_model=512,
    d_ff=2048,
    num_experts=32,
    top_k=4,
    expert_dropout=0.1,
    load_balancing_loss_coef=0.01,
    router_z_loss_coef=0.001
)

# Test forward pass
x = torch.randn(4, 32, 512)
output, aux_losses = advanced_moe(x)

print(f"✅ Output shape: {output.shape}")
print(f"📊 Auxiliary losses:")
for loss_name, loss_value in aux_losses.items():
    print(f"   {loss_name}: {loss_value.item():.6f}")

# Compare load balancing with and without the advanced techniques
def compare_load_balancing(num_trials: int = 100):
    """Compare load balancing between basic and advanced MoE"""
    
    basic_moe = MoELayer(d_model=512, d_ff=2048, num_experts=32, top_k=4)
    advanced_moe_test = AdvancedMoELayer(d_model=512, d_ff=2048, num_experts=32, top_k=4)
    
    basic_usage_stds = []
    advanced_usage_stds = []
    
    for _ in range(num_trials):
        x = torch.randn(4, 32, 512)
        
        # Basic MoE
        with torch.no_grad():
            basic_stats = basic_moe.get_expert_usage_stats(x)
            basic_usage_stds.append(basic_stats['usage_std'])
        
        # Advanced MoE (eval mode disables routing noise, so this compares two untrained routers)
        with torch.no_grad():
            advanced_moe_test.eval()
            _, indices_adv, _ = advanced_moe_test.compute_routing_weights(x)
            expert_counts = torch.zeros(32)
            for i in range(32):
                expert_counts[i] = (indices_adv == i).sum().item()
            usage_percentages = expert_counts / indices_adv.numel() * 100
            advanced_usage_stds.append(usage_percentages.std().item())
    
    return np.mean(basic_usage_stds), np.mean(advanced_usage_stds)

print("\n⚖️ Comparing Load Balancing...")
basic_std, advanced_std = compare_load_balancing(50)
print(f"Basic MoE usage std: {basic_std:.2f}%")
print(f"Advanced MoE usage std: {advanced_std:.2f}%")
print(f"Relative difference: {((basic_std - advanced_std) / basic_std * 100):.1f}% lower usage std (both layers are untrained)")

print("\n🎯 Advanced MoE Benefits:")
print("• Better load balancing through noise injection")
print("• Expert dropout prevents overfitting")
print("• Router z-loss stabilizes training")
print("• Multiple auxiliary losses guide optimization")
print("• Related auxiliary balance losses are used in DeepSeek-V2 and other MoE models")

## 🏁 Summary & Key Takeaways

### 📋 What We Learned về MoE Architecture

1. **Core Concept**: MoE enables scaling model capacity without proportional compute increase
2. **Parameter Efficiency**: DeepSeek-V2 achieves 8.9% efficiency (21B/236B active/total)
3. **Expert Specialization**: Different experts learn to handle different types of input patterns
4. **Load Balancing**: Critical for preventing expert collapse và ensuring utilization
5. **Advanced Techniques**: Noise injection, expert dropout, multiple auxiliary losses
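
The load-balancing point above can be made concrete with a small worked example. The sketch below follows the auxiliary loss popularized by Switch Transformer (number of experts times the sum over experts of dispatch fraction times mean router probability). It is written in dependency-free Python for readability, and the function name `load_balancing_loss` is a hypothetical helper, not part of this notebook's `AdvancedMoELayer` or of DeepSeek's code.

```python
from typing import Sequence


def load_balancing_loss(router_probs: Sequence[Sequence[float]], top_k: int) -> float:
    """Switch-style auxiliary loss: N * sum_i f_i * P_i.

    router_probs[t][i] is the softmax router probability of expert i for token t.
    f_i is the fraction of top-k assignments that went to expert i.
    P_i is the mean router probability given to expert i.
    (Illustrative helper; the name and signature are assumptions.)
    """
    num_tokens = len(router_probs)
    num_experts = len(router_probs[0])
    counts = [0] * num_experts
    mean_prob = [0.0] * num_experts

    for probs in router_probs:
        # Pick the top_k experts with the highest router probability.
        chosen = sorted(range(num_experts), key=lambda i: probs[i], reverse=True)[:top_k]
        for i in chosen:
            counts[i] += 1
        for i, p in enumerate(probs):
            mean_prob[i] += p / num_tokens

    dispatch_frac = [c / (num_tokens * top_k) for c in counts]
    return num_experts * sum(f * p for f, p in zip(dispatch_frac, mean_prob))


# Four tokens, four experts: each token prefers a different expert.
balanced = [
    [0.7, 0.1, 0.1, 0.1],
    [0.1, 0.7, 0.1, 0.1],
    [0.1, 0.1, 0.7, 0.1],
    [0.1, 0.1, 0.1, 0.7],
]
# Every token prefers expert 0 (expert collapse).
collapsed = [[0.7, 0.1, 0.1, 0.1]] * 4

print(load_balancing_loss(balanced, top_k=1))
print(load_balancing_loss(collapsed, top_k=1))
```

With perfectly even routing the loss is about 1.0, its minimum, while the collapsed case gives about 2.8, so adding this term to the training objective penalizes routers that concentrate traffic on a few experts.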

### 🔬 Research Directions

1. **Dynamic Expert Selection**: Adaptive top-k based on input complexity
2. **Hierarchical MoE**: Multi-level expert routing
3. **Cross-lingual Expert Sharing**: Experts specialized for different programming languages
4. **Hardware-aware MoE**: Optimizing for specific accelerators
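
As one possible reading of the first direction, adaptive top-k can be sketched as "keep adding experts until the router's cumulative probability passes a threshold". The helper `adaptive_top_k` and its `mass_threshold` and `max_k` parameters are hypothetical names chosen for illustration. This is not how DeepSeek-V2 routes tokens, since it uses a fixed number of routed experts.

```python
from typing import List, Sequence, Tuple


def adaptive_top_k(
    probs: Sequence[float],
    mass_threshold: float = 0.8,
    max_k: int = 4,
) -> Tuple[List[int], List[float]]:
    """Select experts until their cumulative router probability reaches
    mass_threshold, using at most max_k experts.

    Returns the chosen expert indices and their renormalized weights.
    (Illustrative rule only; threshold and cap are assumed hyperparameters.)
    """
    if not probs:
        raise ValueError("probs must contain at least one expert probability")

    # Experts ordered from most to least probable.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)

    chosen: List[int] = []
    total = 0.0
    for i in order[:max_k]:
        chosen.append(i)
        total += probs[i]
        if total >= mass_threshold:
            break

    # Renormalize so the selected experts' weights sum to 1.
    weights = [probs[i] / total for i in chosen]
    return chosen, weights


# A confident router needs one expert; an uncertain one spreads over several.
print(adaptive_top_k([0.9, 0.05, 0.03, 0.02]))
print(adaptive_top_k([0.4, 0.3, 0.2, 0.1]))
```

Under this rule, "easy" tokens with a peaked router distribution use fewer experts and therefore less compute, while ambiguous tokens use more. The trade-off is that per-token compute becomes variable, which complicates batching on real hardware.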

In [None]:
# Final MoE architecture summary
def create_moe_architecture_summary():
    """Create comprehensive MoE architecture summary"""
    
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    
    # 1. DeepSeek-V2 MoE Configuration
    ax1 = axes[0, 0]
    models = ['DeepSeek-Coder-V2\n(236B)', 'DeepSeek-Coder-V2-Lite\n(16B)']
    total_params = [236, 16]
    active_params = [21, 2.4]
    
    x = np.arange(len(models))
    width = 0.35
    
    bars1 = ax1.bar(x - width/2, total_params, width, label='Total Parameters (B)', alpha=0.7, color='lightblue')
    bars2 = ax1.bar(x + width/2, active_params, width, label='Active Parameters (B)', alpha=0.7, color='darkblue')
    
    ax1.set_ylabel('Parameters (Billions)')
    ax1.set_title('DeepSeek-Coder-V2 MoE Parameter Efficiency')
    ax1.set_xticks(x)
    ax1.set_xticklabels(models)
    ax1.legend()
    
    # Add efficiency labels
    for i, (total, active) in enumerate(zip(total_params, active_params)):
        efficiency = active / total * 100
        ax1.text(i, total + 5, f'{efficiency:.1f}%\nefficiency', 
                ha='center', va='bottom', fontweight='bold')
    
    # 2. Expert routing visualization
    ax2 = axes[0, 1]
    
    # Create mock routing heatmap
    num_tokens = 20
    num_experts = 16
    top_k = 4
    
    routing_matrix = np.zeros((num_tokens, num_experts))
    
    # Simulate routing with uniformly random expert choices (illustrative, not learned routing)
    for token in range(num_tokens):
        # Each token selects top_k experts
        selected_experts = np.random.choice(num_experts, top_k, replace=False)
        weights = np.random.dirichlet(np.ones(top_k))  # Softmax-like weights
        routing_matrix[token, selected_experts] = weights
    
    im = ax2.imshow(routing_matrix, cmap='Blues', aspect='auto')
    ax2.set_xlabel('Expert Index')
    ax2.set_ylabel('Token Index')
    ax2.set_title('Expert Routing Pattern\n(Darker = Higher Weight)')
    plt.colorbar(im, ax=ax2, fraction=0.046, pad=0.04)
    
    # 3. Load balancing comparison
    ax3 = axes[1, 0]
    
    # Illustrative placeholder values (not measured in this notebook)
    scenarios = ['Random\nRouting', 'Basic\nMoE', 'Advanced\nMoE\n(w/ Load Balancing)']
    usage_stds = [15.2, 8.7, 3.1]  # Std dev of expert usage in %; lower is better
    
    bars = ax3.bar(scenarios, usage_stds, color=['red', 'orange', 'green'], alpha=0.7)
    ax3.set_ylabel('Expert Usage Std Dev (%)')
    ax3.set_title('Load Balancing Effectiveness\n(Lower = Better Balance)')
    
    # Add value labels
    for bar, value in zip(bars, usage_stds):
        height = bar.get_height()
        ax3.text(bar.get_x() + bar.get_width()/2., height + 0.2,
                f'{value}%', ha='center', va='bottom', fontweight='bold')
    
    # 4. Scaling comparison
    ax4 = axes[1, 1]
    
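    model_sizes = [1, 7, 13, 30, 70, 175]  # Total model sizes in billions of parameters
    # Toy curves for illustration only (the exponents are not fitted to data)
    dense_compute = [s**1.0 for s in model_sizes]  # Dense: per-token compute grows roughly linearly with parameters
    moe_compute = [s**0.8 for s in model_sizes]    # MoE: only active parameters cost compute, so growth is sublinear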
    
    ax4.plot(model_sizes, dense_compute, 'o-', label='Dense Models', linewidth=2, markersize=6)
    ax4.plot(model_sizes, moe_compute, 's-', label='MoE Models', linewidth=2, markersize=6)
    
    ax4.set_xlabel('Model Size (B parameters)')
    ax4.set_ylabel('Relative Compute Cost')
    ax4.set_title('Compute Scaling: Dense vs MoE')
    ax4.set_yscale('log')
    ax4.legend()
    ax4.grid(True, alpha=0.3)
    
    plt.suptitle('MoE Architecture Deep Dive: Key Concepts & Benefits', 
                 fontsize=16, fontweight='bold', y=0.98)
    plt.tight_layout()
    plt.show()
    
    # Print key statistics
    print("🎯 MoE Architecture Key Numbers:")
    print("=" * 50)
    print(f"📊 DeepSeek-Coder-V2 Parameter Efficiency: {21/236*100:.1f}%")
    print(f"📊 DeepSeek-Coder-V2-Lite Parameter Efficiency: {2.4/16*100:.1f}%")
    print("🎯 Top-K Routing: DeepSeek-V2 activates K=6 routed experts per token")
    print("⚖️ Load Balancing: Critical for expert utilization")
    print("🚀 Compute Scaling: MoE compute grows sublinearly with total model size")
    
    print("\n💡 Key Implementation Insights:")
    print("• Shared + Routed experts architecture")
    print("• SwiGLU activation in experts")
    print("• Noise injection for load balancing")
    print("• Multiple auxiliary losses")
    print("• Expert dropout for regularization")
    print("• Hardware-efficient sparse computation")

create_moe_architecture_summary()

print("\n🎉 MoE Architecture Deep Dive Complete!")
print("\n📚 Further Reading:")
print("• Switch Transformer (Google, 2021)")
print("• GLaM: Efficient Scaling of Language Models (Google, 2021)")
print("• PaLM-2 Technical Report (Google, 2023)")
print("• DeepSeek-V2 Paper (2024)")
print("\n✨ Next: Explore YaRN Context Extension! ✨")