# ⚖️ Load Balancing Loss & Mutual Information in MoL

## 🎯 Learning Objectives

Gain a deep understanding of:
1. The **Load Balancing Problem** in Mixture-of-Experts systems
2. **Mutual Information-based Loss** for improving component utilization
3. **Component Collapse** and how to prevent it
4. **Conditional Computation** in MoL
5. The **Mathematical Foundations** of Information Theory in ML

## 📖 Excerpts from the Paper

### Section 2.2 - Load Balancing Loss:

> *"We propose techniques to retrieve the approximate top-k results using MoL with tight error bounds... enhanced by our proposed mutual information-based load balancing loss"*

> *"Our approximate top-k algorithms outperform baselines by up to 66× in latency while achieving >.99 recall rate compared to exact algorithms."*

### Key Concepts:

- **Component Collapse**: A few components dominate while the others go unused
- **Load Balancing**: Ensures equal utilization across components
- **Mutual Information**: Measures the dependence between gating decisions and inputs

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Tuple, List, Dict, Optional
import math
from scipy import stats
from sklearn.metrics import mutual_info_score
import warnings
warnings.filterwarnings('ignore')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🔧 Device: {device}")

plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

## 🔍 Part 1: Understanding Component Collapse

### 📊 The Problem:

In Mixture-of-Experts systems, **component collapse** occurs when:
- A few components are used heavily
- Most components are rarely or never used
- Model capacity is wasted and performance drops

### 🎯 The Solution:
A **Load Balancing Loss** encourages:
- A uniform distribution of gating weights
- Equal utilization across components
- Better use of model capacity

In [None]:
def demonstrate_component_collapse():
    """
    Demonstrate component collapse phenomenon
    """
    print("🚨 Component Collapse Demonstration")
    print("=" * 50)
    
    num_components = 8
    num_samples = 1000
    
    # Simulate different scenarios
    scenarios = {
        'uniform': np.random.dirichlet(np.ones(num_components), num_samples),
        'slightly_biased': np.random.dirichlet(np.array([2, 2, 2, 1, 1, 1, 1, 1]), num_samples),
        'heavily_biased': np.random.dirichlet(np.array([10, 5, 1, 1, 1, 1, 1, 1]), num_samples),
        'collapsed': np.random.dirichlet(np.array([100, 1, 1, 1, 1, 1, 1, 1]), num_samples)
    }
    
    results = {}
    
    for scenario_name, weights in scenarios.items():
        # Compute utilization statistics
        mean_weights = np.mean(weights, axis=0)
        std_weights = np.std(weights, axis=0)
        
        # Effective number of components (Perplexity)
        entropy = -np.sum(mean_weights * np.log(mean_weights + 1e-8))
        perplexity = np.exp(entropy)
        
        # Gini coefficient (inequality measure)
        sorted_weights = np.sort(mean_weights)
        n = len(sorted_weights)
        index = np.arange(1, n + 1)
        gini = (2 * np.sum(index * sorted_weights)) / (n * np.sum(sorted_weights)) - (n + 1) / n
        
        results[scenario_name] = {
            'mean_weights': mean_weights,
            'std_weights': std_weights,
            'perplexity': perplexity,
            'gini': gini,
            'entropy': entropy
        }
        
        print(f"\n📊 {scenario_name.upper()}:")
        print(f"   Effective components: {perplexity:.2f} / {num_components}")
        print(f"   Gini coefficient: {gini:.3f} (0=equal, {(num_components - 1) / num_components:.3f}=one component takes all)")
        print(f"   Entropy: {entropy:.3f} (max={np.log(num_components):.3f})")
        print(f"   Weight distribution: {mean_weights}")
    
    return results, scenarios

collapse_results, collapse_scenarios = demonstrate_component_collapse()

## 🧮 Part 2: Information Theory Foundations

### 📖 Theory:

**Entropy** (Shannon Entropy):
$$H(X) = -\sum_{i} p(x_i) \log p(x_i)$$

**Mutual Information**:
$$I(X; Y) = H(X) - H(X|Y) = \sum_{x,y} p(x,y) \log \frac{p(x,y)}{p(x)p(y)}$$

**KL Divergence**:
$$D_{KL}(P||Q) = \sum_{i} p_i \log \frac{p_i}{q_i}$$

### 🎯 Application to MoL:
- **Z**: Latent representations (queries/items)
- **G**: Gating decisions (component assignments)
- **Goal**: Minimize I(Z; G) → components independent of input patterns
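
For soft gating weights, $I(Z; G)$ can be estimated directly from a batch through the identity $I(Z; G) = H(G) - H(G \mid Z)$, where $H(G)$ is the entropy of the average usage and $H(G \mid Z)$ is the average per-input gating entropy. The sketch below is illustrative and not taken from the paper: the helper name `gating_mutual_information` is hypothetical, and it assumes every input in the batch is equally likely.

```python
import numpy as np

def gating_mutual_information(gating_weights: np.ndarray, eps: float = 1e-12) -> dict:
    """Plug-in estimate of I(Z; G) in nats from soft gating weights.

    gating_weights: [num_inputs, num_components], each row sums to 1.
    Assumption of this sketch: every input z has probability 1 / num_inputs.
    """
    p = np.asarray(gating_weights, dtype=np.float64)
    marginal = p.mean(axis=0)                                      # p(g)
    h_marginal = -np.sum(marginal * np.log(marginal + eps))        # H(G)
    h_conditional = -np.mean(np.sum(p * np.log(p + eps), axis=1))  # H(G|Z)
    return {"H(G)": h_marginal, "H(G|Z)": h_conditional,
            "I(Z;G)": h_marginal - h_conditional}

# Gating ignores the input: every row is uniform
ignore_input = np.full((4, 4), 0.25)
# Gating is input-specific and balanced: one-hot rows, each component used once
specialised = np.eye(4)
# Collapsed: every input is routed to component 0
collapsed = np.tile(np.array([1.0, 0.0, 0.0, 0.0]), (4, 1))

for name, w in [("ignore_input", ignore_input),
                ("specialised", specialised),
                ("collapsed", collapsed)]:
    stats = gating_mutual_information(w)
    print(name, {k: round(float(v), 3) for k, v in stats.items()})
```

Both the input-agnostic case and the collapsed case give $I(Z; G) \approx 0$, so the mutual information term alone cannot tell them apart. Only $H(G)$ separates balanced usage from collapse, which is why the loss in Part 3 keeps an explicit entropy term next to the MI term.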

In [None]:
class InformationTheoryUtils:
    """
    Utility functions for Information Theory calculations
    """
    
    @staticmethod
    def entropy(probabilities: np.ndarray, base: float = 2) -> float:
        """
        Calculate Shannon entropy
        """
        p = probabilities + 1e-12  # Avoid log(0)
        return -np.sum(p * np.log(p) / np.log(base))
    
    @staticmethod
    def mutual_information_discrete(X: np.ndarray, Y: np.ndarray) -> float:
        """
        Calculate mutual information for discrete variables (in nats; sklearn uses the natural log)
        """
        return mutual_info_score(X, Y)
    
    @staticmethod
    def kl_divergence(P: np.ndarray, Q: np.ndarray) -> float:
        """
        Calculate KL divergence D_KL(P||Q)
        """
        P = P + 1e-12
        Q = Q + 1e-12
        return np.sum(P * np.log(P / Q))
    
    @staticmethod
    def perplexity(probabilities: np.ndarray) -> float:
        """
        Calculate perplexity (effective number of components)
        """
        entropy = InformationTheoryUtils.entropy(probabilities, base=np.e)
        return np.exp(entropy)
    
    @staticmethod
    def estimate_mutual_information_continuous(X: torch.Tensor, Y: torch.Tensor, 
                                             bins: int = 20) -> float:
        """
        Estimate mutual information for continuous variables using binning
        """
        # Convert to numpy
        X_np = X.detach().cpu().numpy().flatten()
        Y_np = Y.detach().cpu().numpy().flatten()
        
        # Discretize using histograms
        X_discrete = np.digitize(X_np, np.histogram(X_np, bins=bins)[1][:-1])
        Y_discrete = np.digitize(Y_np, np.histogram(Y_np, bins=bins)[1][:-1])
        
        return mutual_info_score(X_discrete, Y_discrete)

# Demonstrate information theory concepts
print("📊 Information Theory Demonstrations")
print("=" * 50)

utils = InformationTheoryUtils()

# Entropy examples
distributions = {
    'uniform': np.ones(8) / 8,
    'peaked': np.array([0.7, 0.1, 0.05, 0.05, 0.03, 0.03, 0.02, 0.02]),
    'bimodal': np.array([0.4, 0.4, 0.05, 0.05, 0.05, 0.05, 0, 0]),
    'delta': np.array([1.0, 0, 0, 0, 0, 0, 0, 0])
}

print("\n🔢 Entropy Analysis:")
for name, dist in distributions.items():
    entropy = utils.entropy(dist)
    perplexity = utils.perplexity(dist)
    print(f"   {name:8s}: H = {entropy:.3f} bits, Perplexity = {perplexity:.2f}")

# Mutual Information demonstration
print("\n🔗 Mutual Information Examples:")

# Independent variables
X_indep = np.random.randint(0, 4, 1000)
Y_indep = np.random.randint(0, 4, 1000)
mi_indep = utils.mutual_information_discrete(X_indep, Y_indep)
print(f"   Independent variables: I(X;Y) = {mi_indep:.3f}")

# Dependent variables
X_dep = np.random.randint(0, 4, 1000)
Y_dep = (X_dep + np.random.randint(0, 2, 1000)) % 4  # Partial dependence
mi_dep = utils.mutual_information_discrete(X_dep, Y_dep)
print(f"   Dependent variables: I(X;Y) = {mi_dep:.3f}")

# Perfectly correlated
X_corr = np.random.randint(0, 4, 1000)
Y_corr = X_corr.copy()  # Perfect correlation
mi_corr = utils.mutual_information_discrete(X_corr, Y_corr)
print(f"   Perfectly correlated: I(X;Y) = {mi_corr:.3f}")

## ⚖️ Part 3: Load Balancing Loss Implementation

### 🎯 Goals:
Design a loss function that will:
1. **Encourage uniform component usage**
2. **Minimize mutual information** between inputs and gating decisions
3. **Prevent component collapse**
4. **Maintain model expressiveness**

In [None]:
class AdvancedLoadBalancingLoss(nn.Module):
    """
    Advanced Load Balancing Loss with multiple components:
    1. Uniform Distribution Loss (KL divergence)
    2. Entropy Maximization Loss
    3. Mutual Information Minimization Loss
    4. Variance Minimization Loss
    """
    
    def __init__(self, 
                 num_components: int,
                 lambda_uniform: float = 0.01,
                 lambda_entropy: float = 0.01,
                 lambda_mi: float = 0.005,
                 lambda_variance: float = 0.01):
        super().__init__()
        self.num_components = num_components
        self.lambda_uniform = lambda_uniform
        self.lambda_entropy = lambda_entropy
        self.lambda_mi = lambda_mi
        self.lambda_variance = lambda_variance
        
        # Target uniform distribution
        self.register_buffer('uniform_target', 
                           torch.ones(num_components) / num_components)
    
    def forward(self, 
               gating_weights: torch.Tensor,
               input_features: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """
        Compute comprehensive load balancing loss
        
        Args:
            gating_weights: [batch_size, num_components] - gating decisions
            input_features: [batch_size, feature_dim] - input representations
        
        Returns:
            Dictionary of loss components
        """
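        # Flatten any leading dimensions so that each row is one gating decision
        if gating_weights.dim() > 2:
            gating_weights = gating_weights.reshape(-1, self.num_components)
            if input_features is not None:
                input_features = input_features.reshape(gating_weights.size(0), -1)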
        
        # Component usage statistics
        component_usage = gating_weights.mean(dim=0)  # [num_components]
        
        losses = {}
        
        # 1. Uniform Distribution Loss: KL(uniform || usage); F.kl_div takes log-probabilities as input
        uniform_loss = F.kl_div(
            (component_usage + 1e-8).log(),
            self.uniform_target,
            reduction='sum'
        )
        losses['uniform'] = self.lambda_uniform * uniform_loss
        
        # 2. Entropy Maximization Loss
        entropy = -torch.sum(component_usage * torch.log(component_usage + 1e-8))
        max_entropy = math.log(self.num_components)  # Maximum possible entropy
        entropy_loss = max_entropy - entropy  # Gap to maximum entropy; zero when usage is uniform
        losses['entropy'] = self.lambda_entropy * entropy_loss
        
        # 3. Variance Minimization Loss (encourage equal usage)
        variance_loss = torch.var(component_usage)
        losses['variance'] = self.lambda_variance * variance_loss
        
        # 4. Mutual Information Minimization (if input features available)
        if input_features is not None:
            mi_loss = self._approximate_mutual_information_loss(
                input_features, gating_weights
            )
            losses['mutual_info'] = self.lambda_mi * mi_loss
        else:
            losses['mutual_info'] = torch.tensor(0.0, device=gating_weights.device)
        
        # Total loss
        losses['total'] = sum(losses.values())
        
        return losses
    
    def _approximate_mutual_information_loss(self, 
                                           inputs: torch.Tensor, 
                                           gating_weights: torch.Tensor) -> torch.Tensor:
        """
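        Proxy for a mutual information penalty, based on similarity alignment.
        
        This does not estimate I(Z; G) itself. It penalizes the case where
        inputs that are similar also receive similar gating weights.
        For an actual estimate, you might use:
        - MINE (Mutual Information Neural Estimation)
        - InfoNCE
        - Histogram-based estimation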
        """
        batch_size = inputs.size(0)
        
        # Simple approximation: encourage independence
        # by minimizing correlation between input patterns and gating decisions
        
        # Compute pairwise similarities in input space
        input_norm = F.normalize(inputs, dim=1)
        input_sim = torch.mm(input_norm, input_norm.t())  # [batch, batch]
        
        # Compute pairwise similarities in gating space
        gating_norm = F.normalize(gating_weights, dim=1)
        gating_sim = torch.mm(gating_norm, gating_norm.t())  # [batch, batch]
        
        # Mean product of the two similarity matrices (uncentered, diagonal included).
        # Minimizing it discourages similar inputs from receiving similar gating weights.
        correlation = torch.mean(input_sim * gating_sim)
        
        return correlation

# Demonstration
print("⚖️ Advanced Load Balancing Loss Demo")
print("=" * 50)

num_components = 8
batch_size = 100
feature_dim = 64

# Create load balancing loss
lb_loss = AdvancedLoadBalancingLoss(num_components)

# Test different gating weight scenarios
scenarios = {
    'uniform': torch.ones(batch_size, num_components) / num_components,
    'biased': F.softmax(torch.tensor([3., 2., 1., 1., 0.5, 0.5, 0.2, 0.2]).unsqueeze(0).expand(batch_size, -1), dim=1),
    'collapsed': F.softmax(torch.tensor([10., 1., 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]).unsqueeze(0).expand(batch_size, -1), dim=1)
}

input_features = torch.randn(batch_size, feature_dim)

print("\n📊 Load Balancing Loss Analysis:")
for scenario_name, weights in scenarios.items():
    with torch.no_grad():
        losses = lb_loss(weights, input_features)
        
        component_usage = weights.mean(dim=0)
        entropy = -torch.sum(component_usage * torch.log(component_usage + 1e-8))
        
        print(f"\n   {scenario_name.upper()}:")
        print(f"     Total Loss: {losses['total'].item():.4f}")
        print(f"     Uniform Loss: {losses['uniform'].item():.4f}")
        print(f"     Entropy Loss: {losses['entropy'].item():.4f}")
        print(f"     Variance Loss: {losses['variance'].item():.4f}")
        print(f"     MI Loss: {losses['mutual_info'].item():.4f}")
        print(f"     Component Entropy: {entropy.item():.3f} / {math.log(num_components):.3f}")
        print(f"     Usage: {component_usage.numpy()}")

## 🧪 Part 4: Training with Load Balancing

### 🎯 Experiment Design:
Compare training with and without the load balancing loss to see its impact:

In [None]:
class SimpleMoLWithLoadBalancing(nn.Module):
    """
    Simplified MoL model with integrated load balancing
    """
    
    def __init__(self, input_dim: int, num_components: int = 6, component_dim: int = 32):
        super().__init__()
        self.num_components = num_components
        self.component_dim = component_dim
        
        # Component embeddings
        self.query_embeddings = nn.ModuleList([
            nn.Linear(input_dim, component_dim) for _ in range(num_components)
        ])
        self.item_embeddings = nn.ModuleList([
            nn.Linear(input_dim, component_dim) for _ in range(num_components)
        ])
        
        # Gating network
        self.gating_network = nn.Sequential(
            nn.Linear(input_dim * 2, 64),
            nn.ReLU(),
            nn.Linear(64, num_components),
            nn.Softmax(dim=-1)
        )
        
        # Load balancing
        self.load_balancer = AdvancedLoadBalancingLoss(num_components)
    
    def forward(self, queries: torch.Tensor, items: torch.Tensor, 
               return_gating_info: bool = False) -> Dict:
        """
        Forward pass with optional gating information
        """
        batch_q, batch_i = queries.size(0), items.size(0)
        
        similarities = torch.zeros(batch_q, batch_i, device=queries.device)
        all_gating_weights = []
        all_input_features = []
        
        # Compute component embeddings
        q_components = [F.normalize(emb(queries), dim=-1) for emb in self.query_embeddings]
        i_components = [F.normalize(emb(items), dim=-1) for emb in self.item_embeddings]
        
        # Compute similarities for all pairs
        for i in range(batch_q):
            for j in range(batch_i):
                # Combined features for gating
                combined_features = torch.cat([queries[i], items[j]], dim=0)
                gating_weights = self.gating_network(combined_features.unsqueeze(0)).squeeze(0)
                
                if return_gating_info:
                    all_gating_weights.append(gating_weights)
                    all_input_features.append(combined_features)
                
                # Compute weighted similarity
                similarity = 0.0
                for p in range(self.num_components):
                    component_sim = torch.dot(q_components[p][i], i_components[p][j])
                    similarity += gating_weights[p] * component_sim
                
                similarities[i, j] = similarity
        
        result = {'similarities': similarities}
        
        if return_gating_info and all_gating_weights:
            result['gating_weights'] = torch.stack(all_gating_weights)
            result['input_features'] = torch.stack(all_input_features)
        
        return result
    
    def compute_load_balancing_loss(self, gating_weights: torch.Tensor, 
                                  input_features: torch.Tensor) -> Dict:
        """
        Compute load balancing loss
        """
        return self.load_balancer(gating_weights, input_features)

def train_with_load_balancing_comparison(num_epochs: int = 100):
    """
    Compare training with and without load balancing
    """
    print("\n🧪 Load Balancing Training Comparison")
    print("=" * 50)
    
    # Generate synthetic data
    num_queries, num_items = 30, 50
    input_dim = 32
    
    queries = torch.randn(num_queries, input_dim)
    items = torch.randn(num_items, input_dim)
    
    # Create relevance labels (simplified)
    labels = torch.sigmoid(torch.mm(queries, items.t()) + 0.5 * torch.randn(num_queries, num_items))
    
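    # Two models: with and without load balancing
    model_with_lb = SimpleMoLWithLoadBalancing(input_dim, num_components=6)
    model_without_lb = SimpleMoLWithLoadBalancing(input_dim, num_components=6)
    # Start both models from identical weights so the comparison isolates the loss term
    model_without_lb.load_state_dict(model_with_lb.state_dict())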
    
    # Optimizers
    opt_with_lb = torch.optim.Adam(model_with_lb.parameters(), lr=0.01)
    opt_without_lb = torch.optim.Adam(model_without_lb.parameters(), lr=0.01)
    
    # Training history
    history = {
        'with_lb': {'main_loss': [], 'lb_loss': [], 'entropy': [], 'gini': []},
        'without_lb': {'main_loss': [], 'lb_loss': [], 'entropy': [], 'gini': []}
    }
    
    for epoch in range(num_epochs):
        # Train model WITH load balancing
        opt_with_lb.zero_grad()
        
        output_with_lb = model_with_lb(queries, items, return_gating_info=True)
        similarities_with_lb = output_with_lb['similarities']
        
        main_loss_with_lb = F.mse_loss(similarities_with_lb, labels)
        
        if 'gating_weights' in output_with_lb:
            lb_losses = model_with_lb.compute_load_balancing_loss(
                output_with_lb['gating_weights'], 
                output_with_lb['input_features']
            )
            total_loss_with_lb = main_loss_with_lb + lb_losses['total']
        else:
            lb_losses = {'total': torch.tensor(0.0)}
            total_loss_with_lb = main_loss_with_lb
        
        total_loss_with_lb.backward()
        opt_with_lb.step()
        
        # Train model WITHOUT load balancing
        opt_without_lb.zero_grad()
        
        output_without_lb = model_without_lb(queries, items, return_gating_info=True)
        similarities_without_lb = output_without_lb['similarities']
        
        main_loss_without_lb = F.mse_loss(similarities_without_lb, labels)
        main_loss_without_lb.backward()
        opt_without_lb.step()
        
        # Compute statistics
        with torch.no_grad():
            for model_name, output in [('with_lb', output_with_lb), ('without_lb', output_without_lb)]:
                if 'gating_weights' in output:
                    gating_weights = output['gating_weights']
                    component_usage = gating_weights.mean(dim=0)
                    
                    # Compute entropy
                    entropy = -torch.sum(component_usage * torch.log(component_usage + 1e-8))
                    
                    # Compute Gini coefficient
                    sorted_weights = torch.sort(component_usage)[0]
                    n = len(sorted_weights)
                    index = torch.arange(1, n + 1, dtype=torch.float32)
                    gini = (2 * torch.sum(index * sorted_weights)) / (n * torch.sum(sorted_weights)) - (n + 1) / n
                    
                    history[model_name]['entropy'].append(entropy.item())
                    history[model_name]['gini'].append(gini.item())
                else:
                    history[model_name]['entropy'].append(0)
                    history[model_name]['gini'].append(1.0)
        
        # Record losses
        history['with_lb']['main_loss'].append(main_loss_with_lb.item())
        history['with_lb']['lb_loss'].append(lb_losses['total'].item())
        history['without_lb']['main_loss'].append(main_loss_without_lb.item())
        history['without_lb']['lb_loss'].append(0.0)
        
        if epoch % 20 == 0:
            print(f"Epoch {epoch:3d}: WITH LB - Main: {main_loss_with_lb.item():.4f}, LB: {lb_losses['total'].item():.4f}")
            print(f"           WITHOUT LB - Main: {main_loss_without_lb.item():.4f}")
    
    return history

# Run training comparison
training_history = train_with_load_balancing_comparison(num_epochs=80)
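
The nested Python loops in `SimpleMoLWithLoadBalancing.forward` score one (query, item) pair at a time, which gets slow as the batch grows. One possible vectorized equivalent is sketched below. The function name `mol_similarity_vectorized` and the tensor layouts are assumptions made for this example, not part of the notebook's model.

```python
import torch
import torch.nn.functional as F

def mol_similarity_vectorized(q_components: torch.Tensor,
                              i_components: torch.Tensor,
                              gating_weights: torch.Tensor) -> torch.Tensor:
    """
    q_components:   [P, num_queries, dim]        normalized query embeddings per component
    i_components:   [P, num_items, dim]          normalized item embeddings per component
    gating_weights: [num_queries, num_items, P]  softmax weights per (query, item) pair
    Returns similarities of shape [num_queries, num_items].
    """
    # Per-component dot products for all pairs: [P, num_queries, num_items]
    component_sims = torch.einsum('pqd,pid->pqi', q_components, i_components)
    # Gating-weighted sum over the component axis
    return torch.einsum('pqi,qip->qi', component_sims, gating_weights)

torch.manual_seed(0)
P, nq, ni, d = 3, 4, 5, 8
q = F.normalize(torch.randn(P, nq, d), dim=-1)
it = F.normalize(torch.randn(P, ni, d), dim=-1)
gates = F.softmax(torch.randn(nq, ni, P), dim=-1)

fast = mol_similarity_vectorized(q, it, gates)

# Reference: the same triple loop used in the notebook's forward pass
slow = torch.zeros(nq, ni)
for a in range(nq):
    for b in range(ni):
        for p in range(P):
            slow[a, b] += gates[a, b, p] * torch.dot(q[p, a], it[p, b])

print(torch.allclose(fast, slow, atol=1e-5))
```

To feed the gating network in one call, the pairwise inputs can be built with `torch.cat([queries[:, None, :].expand(-1, num_items, -1), items[None, :, :].expand(num_queries, -1, -1)], dim=-1)`, which has shape `[num_queries, num_items, 2 * input_dim]`.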

## 📈 Part 5: Visualization and Analysis

In [None]:
# Create comprehensive visualizations
fig, axes = plt.subplots(3, 3, figsize=(18, 15))

# 1. Component Collapse Scenarios
scenario_names = list(collapse_results.keys())
perplexities = [collapse_results[name]['perplexity'] for name in scenario_names]
ginis = [collapse_results[name]['gini'] for name in scenario_names]

axes[0, 0].bar(scenario_names, perplexities, alpha=0.7, color='skyblue')
axes[0, 0].axhline(y=8, color='red', linestyle='--', alpha=0.7, label='Max Perplexity')
axes[0, 0].set_title('Component Collapse: Effective Components')
axes[0, 0].set_ylabel('Perplexity (Effective Components)')
axes[0, 0].legend()
axes[0, 0].tick_params(axis='x', rotation=45)

# 2. Gini Coefficient
colors = ['green', 'yellow', 'orange', 'red']
bars = axes[0, 1].bar(scenario_names, ginis, alpha=0.7, color=colors)
axes[0, 1].set_title('Component Collapse: Inequality (Gini)')
axes[0, 1].set_ylabel('Gini Coefficient')
axes[0, 1].set_ylim(0, 1)
axes[0, 1].tick_params(axis='x', rotation=45)

# Add text annotations
for bar, gini in zip(bars, ginis):
    height = bar.get_height()
    axes[0, 1].text(bar.get_x() + bar.get_width()/2., height + 0.02,
                   f'{gini:.3f}', ha='center', va='bottom')

# 3. Weight Distributions
for i, (name, weights) in enumerate(collapse_scenarios.items()):
    if i < 4:  # Show first 4 scenarios
        mean_weights = np.mean(weights, axis=0)
        axes[0, 2].plot(mean_weights, 'o-', label=name, alpha=0.8, linewidth=2)

axes[0, 2].set_title('Component Weight Distributions')
axes[0, 2].set_xlabel('Component Index')
axes[0, 2].set_ylabel('Average Weight')
axes[0, 2].legend()
axes[0, 2].grid(True, alpha=0.3)

# 4. Training Loss Comparison
epochs = range(len(training_history['with_lb']['main_loss']))
axes[1, 0].plot(epochs, training_history['with_lb']['main_loss'], 'b-', label='With Load Balancing', linewidth=2)
axes[1, 0].plot(epochs, training_history['without_lb']['main_loss'], 'r-', label='Without Load Balancing', linewidth=2)
axes[1, 0].set_title('Training: Main Loss Comparison')
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('Main Loss (MSE)')
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)

# 5. Load Balancing Loss
axes[1, 1].plot(epochs, training_history['with_lb']['lb_loss'], 'g-', linewidth=2)
axes[1, 1].set_title('Load Balancing Loss Over Time')
axes[1, 1].set_xlabel('Epoch')
axes[1, 1].set_ylabel('Load Balancing Loss')
axes[1, 1].grid(True, alpha=0.3)

# 6. Entropy Evolution
max_entropy = math.log(6)  # For 6 components
axes[1, 2].plot(epochs, training_history['with_lb']['entropy'], 'b-', label='With LB', linewidth=2)
axes[1, 2].plot(epochs, training_history['without_lb']['entropy'], 'r-', label='Without LB', linewidth=2)
axes[1, 2].axhline(y=max_entropy, color='green', linestyle='--', alpha=0.7, label='Max Entropy')
axes[1, 2].set_title('Component Usage Entropy')
axes[1, 2].set_xlabel('Epoch')
axes[1, 2].set_ylabel('Entropy')
axes[1, 2].legend()
axes[1, 2].grid(True, alpha=0.3)

# 7. Gini Evolution
axes[2, 0].plot(epochs, training_history['with_lb']['gini'], 'b-', label='With LB', linewidth=2)
axes[2, 0].plot(epochs, training_history['without_lb']['gini'], 'r-', label='Without LB', linewidth=2)
axes[2, 0].set_title('Component Usage Inequality (Gini)')
axes[2, 0].set_xlabel('Epoch')
axes[2, 0].set_ylabel('Gini Coefficient')
axes[2, 0].legend()
axes[2, 0].grid(True, alpha=0.3)

# 8. Information Theory Concepts Visualization
entropy_values = []
perplexity_values = []
prob_values = np.linspace(0.1, 1.0, 10)

for p in prob_values:
    # Binary distribution
    dist = np.array([p, 1-p])
    entropy = utils.entropy(dist, base=2)
    perplexity = utils.perplexity(dist)
    entropy_values.append(entropy)
    perplexity_values.append(perplexity)

axes[2, 1].plot(prob_values, entropy_values, 'go-', label='Entropy', linewidth=2)
axes[2, 1].set_title('Binary Distribution: Entropy vs Probability')
axes[2, 1].set_xlabel('P(X=1)')
axes[2, 1].set_ylabel('Entropy (bits)')
axes[2, 1].grid(True, alpha=0.3)

# 9. Perplexity vs Probability
axes[2, 2].plot(prob_values, perplexity_values, 'mo-', linewidth=2)
axes[2, 2].set_title('Binary Distribution: Perplexity vs Probability')
axes[2, 2].set_xlabel('P(X=1)')
axes[2, 2].set_ylabel('Perplexity')
axes[2, 2].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\n📊 Comprehensive visualization completed")

## 🎓 Key Insights and Practical Guidelines

### 🔍 Observations from the Experiments:

1. **Component Collapse Impact**:
   - Uniform distribution: Perplexity ≈ 8/8 (full utilization)
   - Collapsed distribution: Perplexity << 8 (wasted capacity)
   - A higher Gini coefficient means more unequal component usage

2. **Load Balancing Benefits**:
   - Higher entropy → better component utilization
   - Lower Gini coefficient → more equal distribution
   - Potentially better generalization

3. **Training Dynamics** (expected trends; check them against the plots from Part 5):
   - The load balancing loss should decrease over training
   - Main task performance should not be hurt significantly when λ is small
   - Component usage should become more balanced

### 📖 Theoretical Insights:

**Information Theory Perspective**:
- **High Entropy** → Uniform component usage → Better capacity utilization
- **Low Mutual Information** → Gating decisions independent of input patterns
- **Perplexity** → Effective number of components being used

**Mathematical Foundation**:
```
Optimal Load Balancing:
- Maximize H(G) (entropy of gating decisions)
- Minimize I(Z; G) (mutual information with inputs)
- Minimize Var(usage) (variance in component usage)
```
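
A short derivation shows how the first two terms of `AdvancedLoadBalancingLoss` relate. Here $\bar{u}$ denotes the average component usage (`component_usage` in the code) and $U$ the uniform distribution over $P$ components; both symbols are introduced only for this derivation, and the small `1e-8` stabilizer is ignored.

$$D_{KL}(\bar{u} \,\|\, U) = \sum_{p=1}^{P} \bar{u}_p \log \frac{\bar{u}_p}{1/P} = \log P + \sum_{p=1}^{P} \bar{u}_p \log \bar{u}_p = \log P - H(\bar{u})$$

So the entropy term `max_entropy - entropy` is exactly $D_{KL}(\bar{u} \,\|\, U)$. The `uniform` term, computed as `F.kl_div(log_usage, uniform_target)`, evaluates the reverse direction:

$$D_{KL}(U \,\|\, \bar{u}) = \sum_{p=1}^{P} \frac{1}{P} \log \frac{1/P}{\bar{u}_p} = -\log P - \frac{1}{P} \sum_{p=1}^{P} \log \bar{u}_p$$

Both quantities are zero only when $\bar{u}$ is uniform. The reverse direction grows without bound as any $\bar{u}_p \to 0$, so it penalizes dead components more sharply than the entropy gap does.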

### 🚀 Practical Implementation Guidelines:

1. **Loss Function Design**:
   ```python
   total_loss = (main_loss
                 + lambda_uniform * uniform_loss
                 + lambda_entropy * entropy_loss
                 + lambda_variance * variance_loss
                 + lambda_mi * mi_loss)
   ```

2. **Hyperparameter Tuning**:
   - Start with small λ values (0.001 - 0.01)
   - Monitor component usage statistics
   - Adjust based on Gini coefficient and perplexity

3. **Monitoring Metrics**:
   - **Perplexity**: Should be close to num_components
   - **Gini Coefficient**: Should be close to 0
   - **Entropy**: Should be close to log(num_components)

4. **When to Use Load Balancing**:
   - Large number of components (P > 4)
   - Complex datasets with diverse patterns
   - When component collapse is observed
   - In production systems where efficiency matters
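
The monitoring metrics from item 3 can be computed from a batch of gating weights in a few lines. This is a minimal sketch: `usage_metrics` is a hypothetical helper, and it uses the same entropy and Gini formulas as the training loop in Part 4.

```python
import math
import torch

def usage_metrics(gating_weights: torch.Tensor, eps: float = 1e-8) -> dict:
    """Summarize component usage from gating weights of shape [batch, num_components]."""
    usage = gating_weights.detach().mean(dim=0)
    usage = usage / usage.sum()  # guard against rows that do not sum exactly to 1
    n = usage.numel()
    entropy = -(usage * torch.log(usage + eps)).sum()
    sorted_usage, _ = torch.sort(usage)
    ranks = torch.arange(1, n + 1, dtype=usage.dtype, device=usage.device)
    gini = 2 * (ranks * sorted_usage).sum() / (n * sorted_usage.sum()) - (n + 1) / n
    return {
        "entropy": entropy.item(),
        "max_entropy": math.log(n),
        "perplexity": torch.exp(entropy).item(),
        "gini": gini.item(),
        "max_gini": (n - 1) / n,  # reached when one component takes all the weight
    }

balanced = torch.full((16, 8), 1 / 8)
collapsed = torch.zeros(16, 8)
collapsed[:, 0] = 1.0

print(usage_metrics(balanced))
print(usage_metrics(collapsed))
```

Logging these values every few epochs makes it easier to tell whether a λ setting is too weak (perplexity drifting toward 1) or stronger than needed (usage pinned at uniform while the main loss stalls).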

### ⚠️ Common Pitfalls:

1. **Over-regularization**: Too strong load balancing → worse main task performance
2. **Under-regularization**: Component collapse → wasted model capacity
3. **Ignoring MI term**: Components may still correlate with input patterns
4. **Fixed λ values**: May need adaptive scheduling during training
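
Pitfall 4 can be addressed by scheduling λ instead of fixing it. The sketch below shows two simple options: a linear warm-up and a Gini-driven scale factor. The function names, the `target_gini` threshold, and the `gain` and `max_scale` constants are all illustrative assumptions, not values from the paper.

```python
def lambda_schedule(epoch: int, warmup_epochs: int = 20, lambda_max: float = 0.01) -> float:
    """Linear warm-up of the load-balancing weight from 0 to lambda_max."""
    if warmup_epochs <= 0:
        return lambda_max
    return lambda_max * min(1.0, epoch / warmup_epochs)

def adaptive_lambda(base_lambda: float, gini: float, target_gini: float = 0.1,
                    gain: float = 5.0, max_scale: float = 10.0) -> float:
    """Scale lambda up when the observed Gini exceeds target_gini; otherwise keep it."""
    excess = max(0.0, gini - target_gini)
    return base_lambda * min(max_scale, 1.0 + gain * excess)

for epoch in (0, 10, 20, 40):
    print(epoch, lambda_schedule(epoch))

print(adaptive_lambda(0.01, gini=0.05))  # usage already balanced: lambda unchanged
print(adaptive_lambda(0.01, gini=0.5))   # unequal usage: lambda is scaled up
```

In a training loop, the returned value would multiply the load balancing loss before it is added to the main loss, and the Gini input would come from the usage statistics recorded in the previous epoch.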

### 🎯 Advanced Techniques:

1. **Adaptive Load Balancing**: Adjust λ based on current component usage
2. **Curriculum Learning**: Gradually increase load balancing strength
3. **Temperature Annealing**: Start with high temperature, decrease over time
4. **Component-specific Penalties**: Different penalties for different components
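
Temperature annealing (item 3) can be sketched as follows, assuming the gating network exposes raw logits rather than ending in `nn.Softmax` as it does in `SimpleMoLWithLoadBalancing`. The schedule shape and the `t_start` and `t_end` values are illustrative choices.

```python
import torch
import torch.nn.functional as F

def annealed_temperature(epoch: int, num_epochs: int,
                         t_start: float = 5.0, t_end: float = 1.0) -> float:
    """Exponentially interpolate the softmax temperature from t_start to t_end."""
    if num_epochs <= 1:
        return t_end
    frac = min(1.0, epoch / (num_epochs - 1))
    return t_start * (t_end / t_start) ** frac

def tempered_gating(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    """Softmax over components; a higher temperature gives flatter gating weights."""
    return F.softmax(logits / temperature, dim=-1)

logits = torch.tensor([[3.0, 2.0, 1.0, 1.0, 0.5, 0.5]])
for epoch in (0, 40, 79):
    t = annealed_temperature(epoch, num_epochs=80)
    w = tempered_gating(logits, t)
    entropy = -(w * w.log()).sum().item()
    print(f"epoch {epoch:2d}  T={t:.2f}  entropy={entropy:.3f}")
```

A high initial temperature keeps all components in play early on, when the gating network has little signal, and the lower final temperature lets it commit to sharper assignments later.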

### 📚 Research Directions:

1. **Better MI Estimation**: MINE, InfoNCE, contrastive methods
2. **Dynamic Component Count**: Adaptive number of components
3. **Hierarchical Load Balancing**: Multi-level component organization
4. **Hardware-aware Load Balancing**: Consider GPU memory and compute constraints