In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np

class SparseRouterMoE(nn.Module):
    """
    Sparse router for Mixture of Experts that decides which expert(s) to use for each input.

    A linear layer scores every expert, the scores are softmaxed, and only the top-k
    experts are kept; their probabilities are renormalized so they sum to 1.

    Args:
        input_size: width of the last dimension of the router input.
        num_experts: number of experts to score.
        k: number of experts kept per input. Clamped to ``num_experts`` so that
            ``torch.topk`` never requests more experts than exist.

    Raises:
        ValueError: if ``k < 1``.
    """
    def __init__(self, input_size, num_experts, k=1):
        super().__init__()
        if k < 1:
            raise ValueError(f"k must be >= 1, got {k}")
        self.input_size = input_size
        self.num_experts = num_experts
        self.k = min(k, num_experts)  # Number of experts to route to

        # Router network (maps input to one logit per expert)
        self.router = nn.Linear(input_size, num_experts)

    def forward(self, x):
        """
        Route ``x`` to its top-k experts.

        Args:
            x: tensor whose last dimension is ``input_size``
               (e.g. [batch, input_size] or [batch, seq_len, input_size]).

        Returns:
            dispatch_tensor: same shape as ``router_logits``; zero everywhere except at the
                selected experts, which hold their renormalized top-k probabilities
                (so each routing row sums to 1).
            top_k_probs: [..., k] renormalized probabilities of the selected experts,
                in descending order.
            top_k_indices: [..., k] indices of the selected experts.
            router_logits: [..., num_experts] raw (pre-softmax) router scores, e.g. for
                a cross-entropy routing loss.
        """
        router_logits = self.router(x)
        router_probs = F.softmax(router_logits, dim=-1)

        # Keep the k most probable experts and renormalize their mass to 1.
        top_k_probs, top_k_indices = torch.topk(router_probs, self.k, dim=-1)
        top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)

        # Dense view of the routing decision: renormalized probs at the chosen experts, 0 elsewhere.
        dispatch_tensor = torch.zeros_like(router_probs).scatter(-1, top_k_indices, top_k_probs)

        return dispatch_tensor, top_k_probs, top_k_indices, router_logits

class MoEInput:
    """Plain container pairing token IDs with an optional attention mask for routing.

    Args:
        input_ids: token IDs to be routed.
        attention_mask: optional mask aligned with ``input_ids``; ``None`` when omitted.
    """
    def __init__(self, input_ids, attention_mask=None):
        self.input_ids, self.attention_mask = input_ids, attention_mask

class QwenMoEEnsemble(nn.Module):
    """
    Ensemble of whole Qwen causal LMs behind a sparse Mixture-of-Experts router.

    Each expert is a complete ``AutoModelForCausalLM``. A single ``SparseRouterMoE``
    scores the experts from the mean-pooled input-token embeddings (looked up in the
    first expert's embedding table), and the logits of each sequence's top-k experts
    are mixed with the renormalized routing weights. Experts always run under
    ``torch.no_grad``, so only the router receives gradients.

    Args:
        expert_models: list of dicts, each with a ``"source_model"`` hub ID / local path.
        device: device the experts and the router are moved to.
        router_dim: router input width. ``None`` (default) uses the first expert's
            embedding width, which is what ``get_router_input`` produces; an explicit
            value must match it.
        k: number of experts selected per sequence.

    Raises:
        ValueError: if an explicit ``router_dim`` differs from the embedding width.
    """
    def __init__(self, expert_models, device="cuda", router_dim=None, k=1):
        super().__init__()
        self.device = device
        self.num_experts = len(expert_models)
        self.k = k

        # Experts/tokenizers live in plain lists (not nn.ModuleList), so the expert LMs
        # are not registered as submodules of this module.
        self.experts = []
        self.tokenizers = []

        print(f"Loading {self.num_experts} expert models...")
        for model_info in expert_models:
            model_name = model_info["source_model"]
            print(f"Loading {model_name}...")

            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16,
                trust_remote_code=True
            ).to(device)

            tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                trust_remote_code=True
            )

            self.experts.append(model)
            self.tokenizers.append(tokenizer)

        # Routing features come from the first expert's embedding table.
        self.embedding = self.experts[0].get_input_embeddings()
        embedding_dim = self.embedding.weight.shape[-1]

        # The router consumes pooled embeddings, so its input width must equal the embedding width.
        if router_dim is None:
            router_dim = embedding_dim
        elif router_dim != embedding_dim:
            raise ValueError(
                f"router_dim={router_dim} does not match the embedding width {embedding_dim} "
                "of the first expert; pass router_dim=None to infer it."
            )

        # Router keeps its default (fp32) dtype for stable training, on the experts' device.
        self.router = SparseRouterMoE(router_dim, self.num_experts, k=k).to(device)

    def get_router_input(self, input_ids, attention_mask=None):
        """
        Compute each sequence's routing feature: the mean of its token embeddings.

        Args:
            input_ids: [batch, seq_len] token IDs (tokenized with the first expert's tokenizer).
            attention_mask: optional [batch, seq_len] mask; when given, positions where it is
                0 (padding) are excluded from the mean.

        Returns:
            [batch, embedding_dim] tensor on the router's device and in the router's dtype.
        """
        router_weight = self.router.router.weight

        # The embedding table belongs to a frozen expert: don't build an autograd graph through it.
        with torch.no_grad():
            emb = self.embedding(input_ids)

        # Experts are loaded in fp16; pool in the router's precision so its Linear accepts the input.
        emb = emb.to(device=router_weight.device, dtype=router_weight.dtype)

        if attention_mask is None:
            return emb.mean(dim=1)

        mask = attention_mask.to(device=emb.device, dtype=emb.dtype).unsqueeze(-1)
        # clamp avoids 0/0 for an all-padding row
        return (emb * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)

    def forward(self, input_ids, attention_mask=None, return_expert_weights=False):
        """
        Forward pass through the MoE ensemble.

        Args:
            input_ids: [batch, seq_len] token IDs.
            attention_mask: optional [batch, seq_len] mask, used for routing and passed to experts.
            return_expert_weights: also return the dense routing tensor.

        Returns:
            combined_logits: routing-weighted sum of the selected experts' logits, one row
                per sequence (the experts must share a vocabulary size for the stack to work).
            dispatch_tensor (only if ``return_expert_weights``): [batch, num_experts] routing weights.
        """
        batch_size = input_ids.shape[0]

        router_input = self.get_router_input(input_ids, attention_mask)
        dispatch_tensor, top_k_probs, top_k_indices, router_logits = self.router(router_input)

        expert_outputs = []
        for batch_idx in range(batch_size):
            batch_input_ids = input_ids[batch_idx:batch_idx + 1]
            batch_attention_mask = None
            if attention_mask is not None:
                batch_attention_mask = attention_mask[batch_idx:batch_idx + 1]

            batch_outputs = []
            for expert_idx, weight in zip(top_k_indices[batch_idx], top_k_probs[batch_idx]):
                expert = self.experts[int(expert_idx)]

                # Experts are frozen; gradients reach the router only through ``weight``.
                with torch.no_grad():
                    expert_output = expert(
                        input_ids=batch_input_ids,
                        attention_mask=batch_attention_mask
                    )

                batch_outputs.append(expert_output.logits * weight)

            expert_outputs.append(torch.sum(torch.stack(batch_outputs), dim=0))

        combined_logits = torch.cat(expert_outputs, dim=0)

        if return_expert_weights:
            return combined_logits, dispatch_tensor
        return combined_logits

    def generate(self, input_text, max_length=512, temperature=0.7, top_p=0.9,
                 num_return_sequences=1, return_expert_info=False):
        """
        Generate text with the single highest-weighted expert for ``input_text``.

        The prompt is tokenized with the first expert's tokenizer (used for routing and as
        the generation input) and the output is decoded with the chosen expert's tokenizer.

        Returns:
            List of generated strings, plus an ``expert_info`` dict when
            ``return_expert_info`` is True.
        """
        tokenizer = self.tokenizers[0]
        inputs = tokenizer(input_text, return_tensors="pt").to(self.device)
        input_ids = inputs.input_ids
        attention_mask = inputs.get("attention_mask")

        # Routing is inference-only here.
        with torch.no_grad():
            router_input = self.get_router_input(input_ids, attention_mask)
            dispatch_tensor, top_k_probs, top_k_indices, _ = self.router(router_input)

        # For simplicity, use only the top expert for generation
        expert_idx = top_k_indices[0][0].item()
        weight = top_k_probs[0][0].item()
        expert = self.experts[expert_idx]
        expert_tokenizer = self.tokenizers[expert_idx]

        print(f"Using expert {expert_idx} with weight {weight:.4f}")

        with torch.no_grad():
            outputs = expert.generate(
                input_ids,
                attention_mask=attention_mask,
                max_length=max_length,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                num_return_sequences=num_return_sequences,
            )

        generated_texts = expert_tokenizer.batch_decode(outputs, skip_special_tokens=True)

        if return_expert_info:
            expert_info = {
                "expert_idx": expert_idx,
                "expert_weight": weight,
                "expert_model": expert.__class__.__name__,
                "dispatch_tensor": dispatch_tensor.detach().cpu().numpy(),
            }
            return generated_texts, expert_info

        return generated_texts

# Helper function to save and load the MoE ensemble
def save_moe_ensemble(model, path):
    """Save the trainable part of a MoE ensemble (router weights + config) to ``path``.

    Expert LM weights are not stored. The config records ``num_experts`` and ``k``, plus
    ``router_dim`` when the router exposes ``input_size``, so a loader can rebuild a router
    of the right width.
    """
    model_config = {
        'num_experts': model.num_experts,
        'k': model.k,
    }
    router_dim = getattr(model.router, 'input_size', None)
    if router_dim is not None:
        model_config['router_dim'] = router_dim

    torch.save({
        'router_state_dict': model.router.state_dict(),
        'model_config': model_config,
    }, path)
    
def load_moe_ensemble(path, expert_models, device="cuda", router_dim=None):
    """Rebuild a QwenMoEEnsemble from ``expert_models`` and restore saved router weights.

    Args:
        path: checkpoint written by ``save_moe_ensemble``.
        expert_models: same list of ``{"source_model": ...}`` dicts used at save time.
        device: device for the rebuilt ensemble; the checkpoint tensors are mapped there,
            so a checkpoint saved on GPU can be loaded on a CPU-only machine.
        router_dim: router input width. ``None`` (default) takes it from the checkpoint
            (``model_config['router_dim']`` if present, else the saved router weight shape).

    Raises:
        ValueError: if the checkpoint was saved for a different number of experts.
    """
    checkpoint = torch.load(path, map_location=device)
    model_config = checkpoint['model_config']
    router_state = checkpoint['router_state_dict']

    if len(expert_models) != model_config['num_experts']:
        raise ValueError(
            f"Checkpoint was saved with {model_config['num_experts']} experts, "
            f"but {len(expert_models)} expert models were given"
        )

    if router_dim is None:
        router_dim = model_config.get('router_dim')
        if router_dim is None:
            # Older checkpoints: the router Linear weight is [num_experts, router_dim].
            router_dim = router_state['router.weight'].shape[1]

    model = QwenMoEEnsemble(
        expert_models=expert_models,
        device=device,
        router_dim=router_dim,
        k=model_config['k']
    )

    model.router.load_state_dict(router_state)
    # Make sure the restored router sits with the experts.
    model.router.to(device)

    return model

# Training function for fine-tuning the router
def train_moe_router(model, train_dataset, optimizer, num_epochs=3, batch_size=8):
    """
    Train the router of the MoE ensemble with supervised expert labels.

    Args:
        model: The MoE ensemble model (needs ``router``, ``experts``, ``tokenizers``,
            ``device`` and ``get_router_input``).
        train_dataset: Dataset of (input_text, expert_label) pairs
        optimizer: Optimizer for the router
        num_epochs: Number of training epochs
        batch_size: Batch size for training

    Returns:
        list[float]: mean training loss of each epoch.
    """
    model.train()

    # The expert LMs are frozen; keep them in eval mode (e.g. no dropout).
    for expert in model.experts:
        expert.eval()

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True
    )

    # Loop invariants: routing always uses the first tokenizer; the router's params fix its device/dtype.
    tokenizer = model.tokenizers[0]
    router_param = next(model.router.parameters())

    epoch_losses = []
    for epoch in range(num_epochs):
        total_loss = 0.0
        num_batches = 0
        for input_texts, expert_labels in train_loader:
            inputs = tokenizer(input_texts, return_tensors="pt", padding=True).to(model.device)

            # Only the router is trained: no autograd graph through the frozen embedding features.
            with torch.no_grad():
                router_input = model.get_router_input(inputs.input_ids)
            # Embeddings may be fp16 (and on another device) while the router is fp32.
            router_input = router_input.to(device=router_param.device, dtype=router_param.dtype)

            _, _, _, router_logits = model.router(router_input)

            loss = F.cross_entropy(router_logits, expert_labels.to(router_logits.device))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        avg_loss = total_loss / max(num_batches, 1)
        epoch_losses.append(avg_loss)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

    return epoch_losses

# Example usage
if __name__ == "__main__":
    # One expert per specialty; list position is the expert index the router predicts
    # (0 = general instruct, 1 = coder, 2 = math).
    expert_models = [
        {"source_model": model_id}
        for model_id in (
            "Qwen/Qwen2.5-1.5B-Instruct",
            "Qwen/Qwen2.5-Coder-1.5B-Instruct",
            "Qwen/Qwen2.5-Math-1.5B-Instruct",
        )
    ]

    # Prefer the GPU when one is present.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    moe_model = QwenMoEEnsemble(expert_models, device=device, k=1)

    # --- Optional walkthrough (uncomment to run) -----------------------------------
    # # Route a prompt and generate with the chosen expert
    # input_text = "Write a function to find the prime numbers up to n"
    # generated_texts, expert_info = moe_model.generate(
    #     input_text, max_length=512, return_expert_info=True
    # )
    # print(f"Input: {input_text}")
    # print(f"Using expert: {expert_info['expert_idx']} (weight: {expert_info['expert_weight']:.4f})")
    # print(f"Generated text: {generated_texts[0]}")
    #
    # # Minimal map-style dataset of (input_text, expert_label) pairs
    # class ExampleDataset(torch.utils.data.Dataset):
    #     def __init__(self, examples):
    #         self.examples = examples
    #
    #     def __len__(self):
    #         return len(self.examples)
    #
    #     def __getitem__(self, idx):
    #         return self.examples[idx]
    #
    # training_examples = [
    #     ("Write a function to calculate fibonacci numbers", 1),  # Programming → Coder model
    #     ("Solve the equation 3x + 5 = 20", 2),                   # Math → Math model
    #     ("Tell me about climate change", 0),                     # General → Instruct model
    # ]
    # train_dataset = ExampleDataset(training_examples)
    #
    # # Only the router is optimized
    # optimizer = torch.optim.Adam(moe_model.router.parameters(), lr=1e-4)
    # train_moe_router(moe_model, train_dataset, optimizer, num_epochs=5)
    #
    # # Persist the trained router
    # save_moe_ensemble(moe_model, "qwen_moe_ensemble.pt")

In [2]:
from transformers import AutoModelForCausalLM, AutoTokenizer

In [None]:
# Hugging Face Hub IDs of the three 1.5B experts: general chat, code, math.
chat_model_name, code_model_name, math_model_name = (
    "Qwen/Qwen2.5-1.5B-Instruct",
    "Qwen/Qwen2.5-Coder-1.5B-Instruct",
    "Qwen/Qwen2.5-Math-1.5B-Instruct",
)

In [13]:
# One tokenizer per expert, loaded in chat → code → math order.
chat_tokenizer, code_tokenizer, math_tokenizer = [
    AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    for model_id in (chat_model_name, code_model_name, math_model_name)
]

tokenizer_config.json:   0%|          | 0.00/7.30k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/2.78M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/1.67M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/7.03M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/7.32k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/2.78M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/1.67M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/7.03M [00:00<?, ?B/s]

In [4]:
# Load the three expert LMs (chat → code → math). No torch_dtype is passed, so they load in
# transformers' default dtype (fp32 unless configured otherwise).
# NOTE(review): none of the cells visible here use chat_model / code_model / math_model
# afterwards (Qwen2MoEForCausalLM loads its own fp16 copies), so this cell mainly costs
# memory — consider dropping it or loading with torch_dtype=torch.float16.
chat_model, code_model, math_model = [
    AutoModelForCausalLM.from_pretrained(model_id)
    for model_id in (chat_model_name, code_model_name, math_model_name)
]

Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


model.safetensors:  13%|#2        | 388M/3.09G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/242 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/656 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/3.09G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]

In [14]:
# Display the Math expert's raw Jinja chat template. The output shows it falls back to a
# "Please reason step by step, and put your final answer within \boxed{}." system prompt
# when the first message is not a system message (visible in the tools branch).
math_tokenizer.chat_template

'{%- if tools %}\n    {{- \'<|im_start|>system\\n\' }}\n    {%- if messages[0][\'role\'] == \'system\' %}\n        {{- messages[0][\'content\'] }}\n    {%- else %}\n        {{- \'Please reason step by step, and put your final answer within \\\\boxed{}.\' }}\n    {%- endif %}\n    {{- "\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>" }}\n    {%- for tool in tools %}\n        {{- "\\n" }}\n        {{- tool | tojson }}\n    {%- endfor %}\n    {{- "\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\"name\\": <function-name>, \\"arguments\\": <args-json-object>}\\n</tool_call><|im_end|>\\n" }}\n{%- else %}\n    {%- if messages[0][\'role\'] == \'system\' %}\n        {{- \'<|im_start|>system\\n\' + messages[0][\'content\'] + \'<|im_end|>\\n\' }}\n    {%- else %}\n  

In [None]:
from typing import Tuple


class ExpertRouter(nn.Module):
    """
    Router network that determines which expert(s) to use for a given input.

    Scores experts with a linear layer, softmaxes the scores, keeps the top-k and
    renormalizes their weights to sum to 1. ``top_k`` is clamped to ``num_experts``
    (as MoERouter does) so ``torch.topk`` never requests more experts than exist.
    """
    def __init__(self, hidden_size: int, num_experts: int, top_k: int = 2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_experts = num_experts
        self.top_k = min(top_k, num_experts)
        self.router = nn.Linear(hidden_size, num_experts)

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            hidden_states: tensor whose last dimension is ``hidden_size``.

        Returns:
            routing_weights: [..., top_k] renormalized weights of the selected experts.
            expert_indices: [..., top_k] indices of the selected experts.
        """
        routing_logits = self.router(hidden_states)  # [..., num_experts]

        routing_weights, expert_indices = torch.topk(
            F.softmax(routing_logits, dim=-1),
            self.top_k,
            dim=-1
        )

        # Renormalize so the kept experts' weights sum to 1
        routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)

        return routing_weights, expert_indices

In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Dict, Tuple, Optional, Union
from dataclasses import dataclass
from transformers import PreTrainedModel, Qwen2ForCausalLM, Qwen2Config, Qwen2Model
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import logging

logger = logging.get_logger(__name__)

class MoERouter(nn.Module):
    """Token-level top-k gate over ``num_experts`` experts.

    The gate is a bias-free linear projection followed by a softmax; only the top-k
    experts per token are kept and their weights renormalized to sum to 1. ``top_k``
    is clamped to ``num_experts``.
    """
    def __init__(self, hidden_size: int, num_experts: int, top_k: int = 2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_experts = num_experts
        self.top_k = top_k if top_k < num_experts else num_experts
        self.router = nn.Linear(hidden_size, num_experts, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            hidden_states: [batch_size, seq_len, hidden_size]
        Returns:
            routing_weights: [batch_size, seq_len, top_k], summing to 1 over the last dim
            expert_indices: [batch_size, seq_len, top_k]
        """
        probs = self.router(hidden_states).softmax(dim=-1)
        kept_probs, kept_idx = probs.topk(self.top_k, dim=-1)
        return kept_probs / kept_probs.sum(dim=-1, keepdim=True), kept_idx

class Qwen2MoELayer(nn.Module):
    """Decoder-layer slot whose output is a token-routed mixture of the experts' layers.

    Every expert contributes its own decoder layer at ``layer_idx``. A MoERouter picks
    the top-k experts per token, and the output is the routing-weighted sum of those
    experts' outputs. Each expert layer runs at most once per call, over the full
    sequence; tokens that did not select it get a gate of 0.
    """
    def __init__(
        self,
        expert_models: List[Qwen2ForCausalLM],
        layer_idx: int,
        hidden_size: int = 1536,
        num_experts: int = 3,
        top_k: int = 2,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_experts = num_experts
        self.top_k = top_k
        self.layer_idx = layer_idx

        # The same decoder layer object from each expert (shared, not copied)
        self.expert_layers = nn.ModuleList([
            expert.model.layers[layer_idx] for expert in expert_models
        ])

        # Router for selecting experts (created on CPU in default dtype)
        self.router = MoERouter(hidden_size, num_experts, top_k)

        # Co-locate the router with the expert weights. Skip "meta" (offloaded) params,
        # which hold no data and would wipe the router's weights.
        anchor = next(self.expert_layers[0].parameters(), None)
        if anchor is not None and anchor.device.type != "meta":
            self.router.to(anchor.device)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs,
    ) -> Tuple:
        """
        Args:
            hidden_states: [batch, seq_len, hidden_size] input to this layer.
            attention_mask, position_ids, past_key_value, output_attentions, use_cache:
                forwarded unchanged to every selected expert layer.
            **kwargs: extra keyword arguments (e.g. ``position_embeddings``) forwarded
                to every selected expert layer.

        Returns:
            ``(output,)``, or ``(output, past_key_value)`` when ``use_cache``.
            Attention weights are never returned.
        """
        # Route in the router's own dtype/device (hidden states are typically fp16).
        router_weight = self.router.router.weight
        routing_weights, expert_indices = self.router(
            hidden_states.to(device=router_weight.device, dtype=router_weight.dtype)
        )
        routing_weights = routing_weights.to(hidden_states.device)
        expert_indices = expert_indices.to(hidden_states.device)

        output = torch.zeros_like(hidden_states)

        for expert_idx in range(self.num_experts):
            # [batch, seq_len, k]: which top-k slots of each token chose this expert
            selected = expert_indices == expert_idx

            # Skip experts no token was routed to
            if not selected.any():
                continue

            # Per-token gate for this expert (top-k indices are distinct, so at most one slot matches).
            gate = (routing_weights * selected).sum(dim=-1, keepdim=True)  # [batch, seq_len, 1]

            layer_out = self.expert_layers[expert_idx](
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                **kwargs,
            )
            # Older HF decoder layers return a tuple (hidden, ...); newer ones may return the tensor itself.
            expert_hidden = layer_out[0] if isinstance(layer_out, (tuple, list)) else layer_out

            output += (expert_hidden.to(output.device) * gate).to(output.dtype)

        # NOTE(review): all expert layers receive the same ``past_key_value`` object and share
        # ``layer_idx``; if it is a transformers Cache that layers update in place, several
        # experts append to the same slot. Per-expert caches are likely needed before relying
        # on use_cache=True — TODO confirm.
        if use_cache:
            return (output, past_key_value)
        return (output,)

class Qwen2MoEModel(nn.Module):
    """MoE version of Qwen2Model with specialized expert layers.

    Layers listed in ``moe_layers`` become Qwen2MoELayer (a routed mixture of the experts'
    layers at that index). All other decoder layers, the token embeddings, the final norm
    and the rotary embedding are taken (shared, not copied) from the first expert.

    Args:
        expert_models: loaded expert models; the first one supplies the shared parts.
        config: model config (uses ``num_hidden_layers``, ``hidden_size`` and output flags).
        moe_layers: layer indices to replace with MoE layers; ``None`` → every other layer
            starting at 1.
        top_k: number of experts each token is routed to in every MoE layer.
    """
    def __init__(
        self,
        expert_models: List[Qwen2ForCausalLM],
        config: Qwen2Config,
        moe_layers: List[int] = None,  # Layer indices to replace with MoE
        top_k: int = 2,
    ):
        super().__init__()
        self.config = config
        self.num_experts = len(expert_models)

        # If no MoE layers specified, use every other layer starting with layer 1
        if moe_layers is None:
            moe_layers = list(range(1, config.num_hidden_layers, 2))
        self.moe_layers = set(moe_layers)

        # Use the embedding layer from the first expert
        self.embed_tokens = expert_models[0].model.embed_tokens

        # Create a mix of standard and MoE layers
        self.layers = nn.ModuleList()
        for i in range(config.num_hidden_layers):
            if i in self.moe_layers:
                self.layers.append(
                    Qwen2MoELayer(
                        expert_models=expert_models,
                        layer_idx=i,
                        hidden_size=config.hidden_size,
                        num_experts=self.num_experts,
                        top_k=top_k,
                    )
                )
            else:
                # Standard layer (from first expert)
                self.layers.append(expert_models[0].model.layers[i])

        # Normalization layer from first expert
        self.norm = expert_models[0].model.norm

        # Rotary embeddings from first expert
        self.rotary_emb = expert_models[0].model.rotary_emb

        # Which layers accept precomputed rotary embeddings as a keyword (decided once here).
        self._pass_position_embeddings = [
            self._accepts_kwarg(layer, "position_embeddings") for layer in self.layers
        ]

    @staticmethod
    def _accepts_kwarg(module: nn.Module, name: str) -> bool:
        """True if ``module.forward`` takes keyword ``name`` explicitly or via ``**kwargs``."""
        import inspect

        try:
            params = inspect.signature(module.forward).parameters
        except (TypeError, ValueError):
            return False
        return name in params or any(
            p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
        )

    @staticmethod
    def _past_length(past_key_values) -> int:
        """Best-effort number of already-cached positions in ``past_key_values`` (0 if unknown)."""
        if past_key_values is None:
            return 0
        if hasattr(past_key_values, "get_seq_length"):
            return int(past_key_values.get_seq_length())
        for layer_past in past_key_values:
            if isinstance(layer_past, (tuple, list)) and len(layer_past) > 0 and torch.is_tensor(layer_past[0]):
                return layer_past[0].shape[-2]
        return 0

    @staticmethod
    def _split_layer_outputs(layer_outputs, has_attn: bool, use_cache: bool):
        """Unpack a decoder layer's return value into (hidden_states, attn_weights, present).

        Handles a bare tensor as well as tuples laid out as
        (hidden, attn_weights if has_attn, present_key_value if use_cache); missing
        entries come back as None.
        """
        if torch.is_tensor(layer_outputs):
            return layer_outputs, None, None
        hidden = layer_outputs[0]
        rest = layer_outputs[1:]
        attn = rest[0] if has_attn and len(rest) > 0 else None
        offset = 1 if has_attn else 0
        present = rest[offset] if use_cache and len(rest) > offset else None
        return hidden, attn, present

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """
        Run embeddings → decoder layers → final norm.

        Returns a dict with ``last_hidden_state``, ``past_key_values``, ``hidden_states`` and
        ``attentions`` when ``return_dict``; otherwise a tuple of the non-None entries in that
        order. Attention weights are collected only from non-MoE layers.

        Raises:
            ValueError: if both ``input_ids`` and ``inputs_embeds`` are given.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        hidden_states = inputs_embeds

        # Rotary embeddings need explicit positions; default to consecutive positions after the cache.
        if position_ids is None:
            past_length = self._past_length(past_key_values)
            position_ids = torch.arange(
                past_length, past_length + hidden_states.shape[1], device=hidden_states.device
            ).unsqueeze(0)

        position_embeddings = None
        if any(self._pass_position_embeddings):
            position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # NOTE(review): ``attention_mask`` is forwarded as given (typically a 2D [batch, seq]
        # padding mask). HF eager/sdpa decoder layers presumably expect an already-prepared
        # (4D causal) mask — confirm for the installed transformers version.

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            layer_kwargs = {}
            if position_embeddings is not None and self._pass_position_embeddings[idx]:
                layer_kwargs["position_embeddings"] = position_embeddings

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                **layer_kwargs,
            )

            # MoE layers never return attention weights, so only standard layers have that slot.
            is_moe = idx in self.moe_layers
            hidden_states, layer_attn, layer_cache = self._split_layer_outputs(
                layer_outputs,
                has_attn=bool(output_attentions) and not is_moe,
                use_cache=bool(use_cache),
            )

            if use_cache:
                next_decoder_cache += (layer_cache,)

            if output_attentions and not is_moe:
                all_self_attns += (layer_attn,)

        # Final layer norm
        hidden_states = self.norm(hidden_states)

        # Add last hidden state
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None

        if return_dict:
            return {
                "last_hidden_state": hidden_states,
                "past_key_values": next_cache,
                "hidden_states": all_hidden_states,
                "attentions": all_self_attns,
            }
        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)

class Qwen2MoEForCausalLM(PreTrainedModel):
    """
    Qwen2 model with Mixture of Experts for causal language modeling.

    Loads three Qwen2 checkpoints (chat / code / math) as experts in fp16, builds a
    Qwen2MoEModel that mixes their decoder layers, and reuses the chat expert's LM head.
    Expert modules are shared with the loaded expert models, not copied.
    """
    config_class = Qwen2Config

    def __init__(
        self,
        chat_model_path: str,
        code_model_path: str,
        math_model_path: str,
        config: Optional[Qwen2Config] = None,
        moe_layers: List[int] = None,
        device_map="auto",
    ):
        import copy

        # Load each expert in half precision; ``device_map`` is forwarded unchanged
        # ("auto" lets accelerate place / offload weights to avoid OOM).
        expert_models = [
            Qwen2ForCausalLM.from_pretrained(
                path,
                device_map=device_map,
                torch_dtype=torch.float16,
                low_cpu_mem_usage=True,
            )
            for path in (chat_model_path, code_model_path, math_model_path)
        ]

        # Use config from first model if not provided. Work on a copy so the attributes
        # set below (and by PreTrainedModel.__init__) don't leak into the chat expert's own
        # config object, which its layers keep referencing.
        if config is None:
            config = copy.deepcopy(expert_models[0].config)

        # Store the model paths in the config
        config.expert_model_paths = {
            'chat': chat_model_path,
            'code': code_model_path,
            'math': math_model_path
        }

        super().__init__(config)

        self.model = Qwen2MoEModel(
            expert_models=expert_models,
            config=config,
            moe_layers=moe_layers,
        )

        # Record the effective MoE layer indices (including the default choice)
        self.config.moe_layers = sorted(self.model.moe_layers)

        # Use LM head from first expert
        self.lm_head = expert_models[0].lm_head

        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Compute next-token logits and, when ``labels`` is given, the shifted causal-LM loss.

        Returns a ``CausalLMOutputWithPast``, or when ``return_dict`` is False a tuple
        ``(loss?, logits, past_key_values?, hidden_states?, attentions?)`` with None entries dropped.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Always ask the backbone for a dict; the tuple form is rebuilt below if requested.
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        hidden_states = outputs["last_hidden_state"]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n. Upcast the (fp16) logits for a stable loss and
            # move labels next to them (device_map may place the head on another device).
            shift_logits = logits[..., :-1, :].float().contiguous()
            shift_labels = labels[..., 1:].to(shift_logits.device).contiguous()

            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            # ``outputs`` is a dict, so slicing it is not possible; collect the extras by key.
            extras = tuple(
                outputs[key]
                for key in ("past_key_values", "hidden_states", "attentions")
                if outputs.get(key) is not None
            )
            output = (logits,) + extras
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.get("past_key_values"),
            hidden_states=outputs.get("hidden_states"),
            attentions=outputs.get("attentions"),
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        # With a cache, only the newest token needs to be fed
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        # Reorder every cached tensor along the batch dimension to follow the selected beams
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),
            )
        return reordered_past

    def save_pretrained(self, save_directory, safe_serialization=False, max_shard_size="10GB", **kwargs):
        """
        Save the model via PreTrainedModel.save_pretrained, always with safe_serialization=False.

        The MoE layers share expert modules (and the chat expert's embeddings/head), which
        safetensors serialization rejects as shared weights. The pickle-based format is
        therefore forced.

        Args:
            save_directory: Directory to save the model to
            safe_serialization: accepted for signature compatibility but ignored (always False),
                so callers such as a Trainer passing True still save successfully.
            max_shard_size: Maximum size for each model shard (e.g., "10GB", "5GB")
            **kwargs: Additional arguments to pass to save_pretrained
        """
        super().save_pretrained(
            save_directory,
            safe_serialization=False,
            max_shard_size=max_shard_size,
            **kwargs
        )

In [20]:

# Same directory the save cell writes to
OUTPUT_PATH = "./qwen2_5_moe_model"

# Build the MoE model from the three experts. By default every other decoder layer
# (1, 3, 5, ...) is replaced with a routed MoE layer.
print("Creating Qwen2MoEForCausalLM model...")
model = Qwen2MoEForCausalLM(
    chat_model_path=chat_model_name,
    code_model_path=code_model_name,  # the code expert must load the Coder checkpoint, not the chat one
    math_model_path=math_model_name,
    # Optional: choose which layers to replace with MoE layers, e.g.
    # moe_layers=[2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26]
)

Creating Qwen2MoEForCausalLM model...


In [22]:
OUTPUT_PATH = "./qwen2_5_moe_model"
# Write the MoE weights + config, then ship the chat expert's tokenizer in the same directory.
model.save_pretrained(save_directory=OUTPUT_PATH)
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=chat_model_name,
    trust_remote_code=True,
)
tokenizer.save_pretrained(save_directory=OUTPUT_PATH)

('./qwen2_5_moe_model\\tokenizer_config.json',
 './qwen2_5_moe_model\\special_tokens_map.json',
 './qwen2_5_moe_model\\vocab.json',
 './qwen2_5_moe_model\\merges.txt',
 './qwen2_5_moe_model\\added_tokens.json',
 './qwen2_5_moe_model\\tokenizer.json')