# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np

class SparseRouterMoE(nn.Module):
    """
    Top-k router for a Mixture of Experts.

    A single linear layer scores every expert. Softmax turns the scores into
    probabilities, the k most probable experts are kept, and their
    probabilities are renormalised to sum to 1.
    """
    def __init__(self, input_size, num_experts, k=1):
        """
        Args:
            input_size: width of the feature vector fed to the router.
            num_experts: number of experts to score.
            k: number of experts selected per input (1 <= k <= num_experts).

        Raises:
            ValueError: if ``k`` is outside ``[1, num_experts]``.
        """
        super().__init__()
        # Reject impossible k up front instead of failing later inside torch.topk.
        if not 1 <= k <= num_experts:
            raise ValueError(f"k must be in [1, {num_experts}], got {k}")
        self.input_size = input_size
        self.num_experts = num_experts
        self.k = k  # Number of experts to route to

        # Router network (maps input features to one score per expert)
        self.router = nn.Linear(input_size, num_experts)

    def forward(self, x):
        """
        Score all experts for ``x`` and keep the top-k.

        Args:
            x: features whose last dimension is ``input_size``. A floating input of
               another dtype (e.g. float16 embeddings) is cast to the router's
               parameter dtype first.

        Returns:
            Tuple ``(dispatch_tensor, top_k_probs, top_k_indices, router_logits)``:
              - dispatch_tensor: same shape as ``router_logits``. It holds the
                renormalised top-k probabilities in their expert slots and 0
                elsewhere (weights, not a binary mask).
              - top_k_probs: ``[..., k]`` renormalised probabilities of the chosen
                experts (each row sums to 1).
              - top_k_indices: ``[..., k]`` indices of the chosen experts, most
                probable first.
              - router_logits: ``[..., num_experts]`` raw scores, e.g. for a
                cross-entropy loss.
        """
        # Without this cast, a float16 input against the float32 Linear raises
        # a dtype-mismatch error.
        x = x.to(self.router.weight.dtype)
        router_logits = self.router(x)

        router_probs = F.softmax(router_logits, dim=-1)
        top_k_probs, top_k_indices = torch.topk(router_probs, self.k, dim=-1)

        # Renormalise so the selected experts' weights sum to 1.
        top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)

        # Place the selected weights into a dense [..., num_experts] tensor.
        dispatch_tensor = torch.zeros_like(router_probs).scatter(
            -1, top_k_indices, top_k_probs
        )

        return dispatch_tensor, top_k_probs, top_k_indices, router_logits

class MoEInput:
    """Bundle token ids and an optional attention mask for routing.

    Values are stored exactly as passed; nothing is validated, copied or moved
    to a device.
    """
    def __init__(self, input_ids, attention_mask=None):
        # Assigned left to right: input_ids first, then attention_mask.
        self.input_ids, self.attention_mask = input_ids, attention_mask

class QwenMoEEnsemble(nn.Module):
    """
    Ensemble model using a sparse Mixture-of-Experts router with Qwen models
    as experts.

    A small learned router reads a mean-pooled embedding of the prompt and
    selects the top-k experts. The embedding comes from the first expert's
    input-embedding layer. Only the router is meant to be trained; expert
    forward passes run under ``torch.no_grad()``.

    The experts and tokenizers live in plain Python lists, so they are not
    registered submodules: ``.to()``, ``.train()`` and ``state_dict()`` on this
    module do not reach them. The first expert's embedding layer is registered,
    because it is assigned to ``self.embedding``.
    """
    def __init__(self, expert_models, device="cuda", router_dim=None, k=1):
        """
        Args:
            expert_models: list of dicts, each with a ``"source_model"`` key that is
                passed to ``from_pretrained``. The list order defines the expert
                indices the router produces.
            device: device the experts and the router are moved to.
            router_dim: input width of the router. ``None`` (default) infers it from
                the first expert's embedding width; an explicit value must match it.
            k: number of experts selected per input.

        Raises:
            ValueError: if ``expert_models`` is empty or ``router_dim`` does not
                match the embedding width.
        """
        super().__init__()
        if not expert_models:
            raise ValueError("expert_models must contain at least one model")
        self.device = device
        self.num_experts = len(expert_models)
        self.k = k

        # Load expert models and tokenizers
        self.experts = []
        self.tokenizers = []

        print(f"Loading {self.num_experts} expert models...")
        for model_info in expert_models:
            model_name = model_info["source_model"]
            print(f"Loading {model_name}...")

            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16,
                trust_remote_code=True,
            ).to(device)

            tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                trust_remote_code=True,
            )

            self.experts.append(model)
            self.tokenizers.append(tokenizer)

        # Routing features come from the first expert's token embeddings.
        self.embedding = self.experts[0].get_input_embeddings()
        embedding_dim = self.embedding.weight.shape[-1]

        # The router's input width must equal the embedding width. A fixed
        # default (e.g. 768) breaks any expert with a different hidden size.
        if router_dim is None:
            router_dim = embedding_dim
        elif router_dim != embedding_dim:
            raise ValueError(
                f"router_dim={router_dim} does not match the embedding dimension "
                f"({embedding_dim}) of {expert_models[0]['source_model']}"
            )
        self.router_dim = router_dim

        # Move the router to the same device as the embeddings it reads. Its
        # parameters keep nn.Linear's default float32 dtype.
        self.router = SparseRouterMoE(router_dim, self.num_experts, k=k).to(device)

    def get_router_input(self, input_ids, attention_mask=None):
        """
        Build one routing feature per sequence from mean-pooled token embeddings.

        Args:
            input_ids: ``[batch, seq]`` token ids from ``self.tokenizers[0]``.
            attention_mask: optional ``[batch, seq]`` mask. Positions where it is 0
                (e.g. padding) are excluded from the mean; ``None`` averages every
                position.

        Returns:
            ``[batch, embedding_dim]`` tensor in the router's parameter dtype.
        """
        router_dtype = next(self.router.parameters()).dtype

        # The embedding table belongs to expert 0 and is not trained here.
        # Without no_grad, every backward pass would allocate a full gradient for
        # it, and an optimizer over router parameters never clears that gradient.
        with torch.no_grad():
            emb = self.embedding(input_ids)
        # Expert weights are loaded in float16; match the router's dtype.
        emb = emb.to(router_dtype)

        if attention_mask is None:
            return emb.mean(dim=1)

        mask = attention_mask.to(device=emb.device, dtype=emb.dtype).unsqueeze(-1)
        summed = (emb * mask).sum(dim=1)
        # The clamp avoids 0/0 for a fully masked row (that row pools to zeros).
        counts = mask.sum(dim=1).clamp(min=1.0)
        return summed / counts

    def forward(self, input_ids, attention_mask=None, return_expert_weights=False):
        """
        Route each sequence to its top-k experts and mix their logits.

        Args:
            input_ids: ``[batch, seq]`` token ids from ``self.tokenizers[0]``.
            attention_mask: optional ``[batch, seq]`` mask. It is forwarded to the
                experts and also keeps padding out of the routing feature.
            return_expert_weights: if True, also return the router's dispatch tensor.

        Returns:
            For each sequence, the selected experts' logits are scaled by their
            router weights and summed; the results are concatenated along dim 0.
            With ``return_expert_weights``, the router's dispatch tensor is
            returned as well. Gradients reach only the router, through the weights.
        """
        router_input = self.get_router_input(input_ids, attention_mask)
        dispatch_tensor, top_k_probs, top_k_indices, _ = self.router(router_input)

        expert_outputs = []
        for batch_idx in range(input_ids.shape[0]):
            # Slice a range (not a single index) so each expert still sees batch size 1.
            sample_ids = input_ids[batch_idx:batch_idx + 1]
            sample_mask = None
            if attention_mask is not None:
                sample_mask = attention_mask[batch_idx:batch_idx + 1]

            weighted = []
            # .tolist() gives plain Python ints for indexing self.experts.
            selected = top_k_indices[batch_idx].tolist()
            for expert_idx, weight in zip(selected, top_k_probs[batch_idx]):
                with torch.no_grad():
                    logits = self.experts[expert_idx](
                        input_ids=sample_ids,
                        attention_mask=sample_mask,
                    ).logits
                weighted.append(logits * weight)

            expert_outputs.append(torch.stack(weighted).sum(dim=0))

        combined_logits = torch.cat(expert_outputs, dim=0)

        if return_expert_weights:
            return combined_logits, dispatch_tensor
        return combined_logits

    def generate(self, input_text, max_length=512, temperature=0.7, top_p=0.9,
                 num_return_sequences=1, return_expert_info=False):
        """
        Generate text with the single highest-weighted expert for ``input_text``.

        Routing always tokenizes with ``self.tokenizers[0]``, which matches the
        embedding layer the router reads. The chosen expert then receives the
        prompt tokenized with its own tokenizer, and its output is decoded with
        that same tokenizer.

        Returns:
            The list of decoded sequences. With ``return_expert_info``, a
            ``(texts, info)`` tuple where ``info`` holds the chosen expert's index,
            weight, class name and the dispatch tensor as a NumPy array.
        """
        routing_tokenizer = self.tokenizers[0]
        routing_inputs = routing_tokenizer(input_text, return_tensors="pt").to(self.device)

        with torch.no_grad():
            router_input = self.get_router_input(
                routing_inputs.input_ids, routing_inputs.get("attention_mask")
            )
            dispatch_tensor, top_k_probs, top_k_indices, _ = self.router(router_input)

        # Use only the top expert for generation.
        expert_idx = top_k_indices[0][0].item()
        weight = top_k_probs[0][0].item()
        expert = self.experts[expert_idx]
        expert_tokenizer = self.tokenizers[expert_idx]

        print(f"Using expert {expert_idx} with weight {weight:.4f}")

        # Re-tokenize for the chosen expert. Ids from tokenizer 0 are only
        # valid for it if the two tokenizers share a vocabulary.
        expert_inputs = expert_tokenizer(input_text, return_tensors="pt").to(self.device)

        with torch.no_grad():
            outputs = expert.generate(
                expert_inputs.input_ids,
                attention_mask=expert_inputs.get("attention_mask"),
                max_length=max_length,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                num_return_sequences=num_return_sequences,
            )

        generated_texts = expert_tokenizer.batch_decode(outputs, skip_special_tokens=True)

        if return_expert_info:
            expert_info = {
                "expert_idx": expert_idx,
                "expert_weight": weight,
                "expert_model": expert.__class__.__name__,
                "dispatch_tensor": dispatch_tensor.detach().cpu().numpy(),
            }
            return generated_texts, expert_info

        return generated_texts

# Helper function to save and load the MoE ensemble
def save_moe_ensemble(model, path):
    """
    Save the trainable part of a MoE ensemble (its router) to ``path``.

    Expert weights are not written; only the router ``state_dict`` and a small
    config are saved. The config holds ``num_experts``, ``k`` and the router's
    input width, so the router can be rebuilt with matching shapes.

    Args:
        model: ensemble exposing ``router``, ``num_experts`` and ``k``.
        path: destination file passed to ``torch.save``.
    """
    torch.save(
        {
            "router_state_dict": model.router.state_dict(),
            "model_config": {
                "num_experts": model.num_experts,
                "k": model.k,
                # Stored so a loader does not have to guess the router width.
                # It is None if the router exposes no ``input_size``.
                "router_dim": getattr(model.router, "input_size", None),
            },
        },
        path,
    )
    
def load_moe_ensemble(path, expert_models, device="cuda", router_dim=None):
    """
    Rebuild a ``QwenMoEEnsemble`` and restore its router weights from ``path``.

    The experts are re-loaded from ``expert_models``; the checkpoint is expected
    to contain the router ``state_dict`` and a ``model_config`` dict with ``k``.

    Args:
        path: checkpoint file to read.
        expert_models: list of ``{"source_model": ...}`` dicts, in the same order
            used when the router was trained. The router's output indices refer
            to this list.
        device: device for the experts and router.
        router_dim: router input width. ``None`` (default) uses
            ``model_config["router_dim"]`` when present, otherwise the input
            width of the stored router weight.

    Returns:
        The reconstructed ensemble with the router weights loaded.
    """
    # NOTE(review): torch.load unpickles the file. Only load trusted checkpoints,
    # or pass weights_only=True on torch versions that support it.
    # With map_location="cpu", a checkpoint saved on GPU also loads on a CPU-only
    # host. load_state_dict then copies the tensors into the router's parameters
    # on whatever device they live.
    checkpoint = torch.load(path, map_location="cpu")
    config = checkpoint["model_config"]
    router_state = checkpoint["router_state_dict"]

    if router_dim is None:
        router_dim = config.get("router_dim")
    if router_dim is None:
        # Checkpoints without "router_dim": nn.Linear stores its weight as
        # [out_features, in_features], so dim 1 is the router's input width.
        router_dim = router_state["router.weight"].shape[1]

    # Create new ensemble
    model = QwenMoEEnsemble(
        expert_models=expert_models,
        device=device,
        router_dim=router_dim,
        k=config["k"],
    )

    # Load router state
    model.router.load_state_dict(router_state)

    return model

# Training function for fine-tuning the router
def train_moe_router(model, train_dataset, optimizer, num_epochs=3, batch_size=8):
    """
    Train the router of the MoE ensemble from supervised expert labels.

    Args:
        model: the MoE ensemble. It must provide ``tokenizers``,
            ``get_router_input``, ``router``, ``experts``, ``device`` and ``train()``.
        train_dataset: map-style dataset of ``(input_text, expert_label)`` pairs.
            ``expert_label`` is the integer index of the expert that should
            handle the text.
        optimizer: optimizer for the router parameters.
        num_epochs: number of passes over the dataset.
        batch_size: DataLoader batch size.

    Returns:
        List holding the mean batch loss of each epoch, in order.
    """
    model.train()

    # Only the router is meant to train; keep the experts in eval mode
    # (e.g. no dropout).
    for expert in model.experts:
        expert.eval()

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True
    )
    num_batches = len(train_loader)

    # The tokenizer does not change between batches, so fetch it once.
    tokenizer = model.tokenizers[0]

    epoch_losses = []
    for epoch in range(num_epochs):
        total_loss = 0.0
        for input_texts, expert_labels in train_loader:
            # NOTE(review): padding=True adds pad tokens to shorter prompts, and only
            # input_ids reach get_router_input. If that method mean-pools, pads
            # shift the routing feature; pass inputs.attention_mask if it accepts one.
            inputs = tokenizer(input_texts, return_tensors="pt", padding=True).to(model.device)

            router_input = model.get_router_input(inputs.input_ids)
            _, _, _, router_logits = model.router(router_input)

            # Put labels wherever the router produced its logits, which need not
            # be model.device.
            loss = F.cross_entropy(router_logits, expert_labels.to(router_logits.device))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / num_batches
        epoch_losses.append(avg_loss)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

    return epoch_losses

# Example usage
if __name__ == "__main__":
    # Experts in router-index order: 0 = general Instruct, 1 = Coder, 2 = Math.
    expert_models = [
        {"source_model": name}
        for name in (
            "Qwen/Qwen2.5-1.5B-Instruct",
            "Qwen/Qwen2.5-Coder-1.5B-Instruct",
            "Qwen/Qwen2.5-Math-1.5B-Instruct",
        )
    ]

    # Prefer the GPU when one is available.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    moe_model = QwenMoEEnsemble(expert_models, device=device, k=1)

    # --- Example: routed generation (uncomment to run) ---
    # input_text = "Write a function to find the prime numbers up to n"
    # generated_texts, expert_info = moe_model.generate(
    #     input_text, max_length=512, return_expert_info=True
    # )
    # print(f"Input: {input_text}")
    # print(f"Using expert: {expert_info['expert_idx']} (weight: {expert_info['expert_weight']:.4f})")
    # print(f"Generated text: {generated_texts[0]}")

    # --- Example: training the router on labelled prompts (uncomment to run) ---
    # class ExampleDataset(torch.utils.data.Dataset):
    #     def __init__(self, examples):
    #         self.examples = examples
    #
    #     def __len__(self):
    #         return len(self.examples)
    #
    #     def __getitem__(self, idx):
    #         return self.examples[idx]
    #
    # # (input_text, expert_label) pairs; labels index expert_models above.
    # training_examples = [
    #     ("Write a function to calculate fibonacci numbers", 1),  # Programming → Coder model
    #     ("Solve the equation 3x + 5 = 20", 2),                   # Math → Math model
    #     ("Tell me about climate change", 0),                     # General → Instruct model
    # ]
    # train_dataset = ExampleDataset(training_examples)
    #
    # # Optimise only the router's parameters.
    # optimizer = torch.optim.Adam(moe_model.router.parameters(), lr=1e-4)
    # train_moe_router(moe_model, train_dataset, optimizer, num_epochs=5)
    #
    # # Save the trained ensemble (router weights + config).
    # save_moe_ensemble(moe_model, "qwen_moe_ensemble.pt")