In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import List, Dict, Optional, Tuple, Union
import os
import logging
from tqdm import tqdm

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class MoEFFNLayer(nn.Module):
    """
    Mixture of Experts FFN layer that replaces the standard FFN in Transformer blocks.
    Implements a top-1 routing mechanism.
    """
    def __init__(self, hidden_size: int, intermediate_size: int, num_experts: int = 3):
        super().__init__()
        self.num_experts = num_experts
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        
        # Router network - determines which expert to use for each token
        self.router = nn.Linear(hidden_size, num_experts, bias=False)
        
        # Expert FFNs - each consisting of gate_proj, up_proj (for SwiGLU), and down_proj
        # Structure matches Qwen's FFN architecture
        self.experts = nn.ModuleList([
            nn.ModuleDict({
                'gate_proj': nn.Linear(hidden_size, intermediate_size, bias=False),
                'up_proj': nn.Linear(hidden_size, intermediate_size, bias=False),
                'down_proj': nn.Linear(intermediate_size, hidden_size, bias=False),
            }) for _ in range(num_experts)
        ])
        
        # Initialize router weights
        self._init_router()
    
    def _init_router(self):
        """Initialize router weights with a normal distribution."""
        nn.init.normal_(self.router.weight, mean=0.0, std=0.02)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass with top-1 routing.
        
        Args:
            x: Input tensor of shape (batch_size, seq_len, hidden_size)
            
        Returns:
            Output tensor of shape (batch_size, seq_len, hidden_size)
        """
        batch_size, seq_len, hidden_size = x.shape
        
        # Reshape for efficient processing
        x_flat = x.view(-1, hidden_size)  # (batch_size * seq_len, hidden_size)
        
        # Compute routing probabilities
        router_logits = self.router(x_flat)  # (batch_size * seq_len, num_experts)
        routing_weights = F.softmax(router_logits, dim=-1)  # (batch_size * seq_len, num_experts)
        
        # Get expert indices for each token (top-1 routing)
        expert_indices = torch.argmax(routing_weights, dim=-1)  # (batch_size * seq_len)
        
        # Prepare output tensor
        final_output = torch.zeros_like(x_flat)
        
        # Efficient routing using masks
        for expert_idx in range(self.num_experts):
            # Create mask for tokens routed to this expert
            expert_mask = (expert_indices == expert_idx)
            if not expert_mask.any():
                continue
                
            # Get tokens for this expert
            expert_inputs = x_flat[expert_mask]
            
            # Pass through the expert's FFN layers (using SwiGLU activation as in Qwen)
            gate_output = self.experts[expert_idx]['gate_proj'](expert_inputs)
            up_output = self.experts[expert_idx]['up_proj'](expert_inputs)
            
            # SwiGLU activation
            intermediate_output = F.silu(gate_output) * up_output
            
            # Output projection
            expert_output = self.experts[expert_idx]['down_proj'](intermediate_output)
            
            # Route output back using the mask
            expert_routing_weights = routing_weights[expert_mask, expert_idx].unsqueeze(-1)
            final_output[expert_mask] = expert_output * expert_routing_weights
        
        # Reshape back to original dimensions
        return final_output.view(batch_size, seq_len, hidden_size)


class QwenMoEModelMerger:
    """
    Class to merge multiple Qwen models into a single Mixture of Experts model.
    """
    def __init__(self, model_paths: List[str], expert_names: List[str], output_path: str):
        """
        Initialize the merger with model paths and configurations.
        
        Args:
            model_paths: List of paths to the models to merge
            expert_names: Names for each expert (for logging and reference)
            output_path: Path to save the merged model
        """
        if len(model_paths) != len(expert_names):
            raise ValueError("Number of model paths must match number of expert names")
        
        self.model_paths = model_paths
        self.expert_names = expert_names
        self.output_path = output_path
        self.num_experts = len(model_paths)
        self.models = []
        self.tokenizer = None
        self.moe_model = None
        
        logger.info(f"Initializing QwenMoEModelMerger with {self.num_experts} experts")
    
    def load_models(self):
        """Load all source models."""
        logger.info("Loading source models...")
        
        for i, model_path in enumerate(self.model_paths):
            logger.info(f"Loading model {i+1}/{self.num_experts}: {self.expert_names[i]} from {model_path}")
            model = AutoModelForCausalLM.from_pretrained(
                model_path, 
                torch_dtype=torch.float16,  # Use fp16 to reduce memory usage
                device_map="auto"
            )
            self.models.append(model)
            
            # Load tokenizer from the first model (assuming compatible tokenizers)
            if i == 0:
                logger.info(f"Loading tokenizer from {model_path}")
                self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        
        # Verify models have compatible architectures
        self._validate_models()
        
        return self
    
    def _validate_models(self):
        """Check if all models have compatible architectures."""
        logger.info("Validating model architectures...")
        
        ref_config = self.models[0].config
        ref_state_dict = self.models[0].state_dict()
        
        for i in range(1, self.num_experts):
            config = self.models[i].config
            
            # Check key parameters
            if config.hidden_size != ref_config.hidden_size:
                raise ValueError(f"Model {i} has different hidden_size: {config.hidden_size} vs {ref_config.hidden_size}")
            if config.num_hidden_layers != ref_config.num_hidden_layers:
                raise ValueError(f"Model {i} has different num_hidden_layers: {config.num_hidden_layers} vs {ref_config.num_hidden_layers}")
            if config.intermediate_size != ref_config.intermediate_size:
                raise ValueError(f"Model {i} has different intermediate_size: {config.intermediate_size} vs {ref_config.intermediate_size}")
                
            # Check model structure
            state_dict = self.models[i].state_dict()
            if set(state_dict.keys()) != set(ref_state_dict.keys()):
                diff = set(state_dict.keys()).symmetric_difference(set(ref_state_dict.keys()))
                logger.warning(f"Model {i} has different state_dict keys: {diff}")
                # This is a warning, not an error, as some models might have additional keys
            
        logger.info("All models have compatible architectures")
    
    def create_moe_model(self):
        """Create a new model with MoE layers."""
        logger.info("Creating MoE model...")
        
        # Start with a copy of the first model
        self.moe_model = AutoModelForCausalLM.from_pretrained(
            self.model_paths[0],
            torch_dtype=torch.float16,
            device_map="auto"
        )
        
        # Get reference config
        config = self.models[0].config
        
        # Replace FFN layers with MoE layers
        for layer_idx in range(config.num_hidden_layers):
            # Create MoE FFN layer
            moe_ffn = MoEFFNLayer(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                num_experts=self.num_experts
            )
            
            # Copy expert weights from source models
            for expert_idx, model in enumerate(self.models):
                # Extract FFN weights from source model
                layer = model.model.layers[layer_idx]
                
                # Copy weights to the corresponding expert
                moe_ffn.experts[expert_idx]['gate_proj'].weight.data.copy_(
                    layer.mlp.gate_proj.weight.data
                )
                moe_ffn.experts[expert_idx]['up_proj'].weight.data.copy_(
                    layer.mlp.up_proj.weight.data
                )
                moe_ffn.experts[expert_idx]['down_proj'].weight.data.copy_(
                    layer.mlp.down_proj.weight.data
                )
            
            # Replace the original FFN with the MoE FFN
            # Store original module for reference
            original_ffn = self.moe_model.model.layers[layer_idx].mlp
            
            # Create a new MoE module and replace
            self.moe_model.model.layers[layer_idx].mlp = moe_ffn
        
        logger.info("MoE model created successfully")
        return self
    
    def save_model(self):
        """Save the merged model."""
        if self.moe_model is None:
            raise ValueError("No MoE model created yet. Call create_moe_model() first.")
        
        logger.info(f"Saving MoE model to {self.output_path}")
        os.makedirs(self.output_path, exist_ok=True)
        
        # Save model
        self.moe_model.save_pretrained(self.output_path)
        
        # Save tokenizer
        if self.tokenizer:
            self.tokenizer.save_pretrained(self.output_path)
        
        # Save merged model config with MoE details
        expert_info = {f"expert_{i}": name for i, name in enumerate(self.expert_names)}
        moe_config = {
            "moe_type": "Top1Router",
            "num_experts": self.num_experts,
            "experts": expert_info,
            "source_models": self.model_paths
        }
        
        # Update config with MoE details
        self.moe_model.config.update({"moe_config": moe_config})
        self.moe_model.config.save_pretrained(self.output_path)
        
        logger.info(f"Model successfully saved to {self.output_path}")
        return self
    
    def test_model(self, test_prompts: List[str]):
        """Run a simple test on the merged model."""
        if self.moe_model is None:
            raise ValueError("No MoE model created yet. Call create_moe_model() first.")
        
        logger.info("Testing MoE model with sample prompts")
        
        for i, prompt in enumerate(test_prompts):
            logger.info(f"Test prompt {i+1}: {prompt}")
            
            inputs = self.tokenizer(prompt, return_tensors="pt")
            
            with torch.no_grad():
                # Generate response
                outputs = self.moe_model.generate(
                    **inputs,
                    max_new_tokens=100,
                    temperature=0.7,
                    do_sample=True
                )
                
                response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
                logger.info(f"Response: {response}")
        
        return self
    
    def merge(self):
        """Execute the full merging process."""
        return (
            self.load_models()
            .create_moe_model()
            .save_model()
        )


# Example usage
if __name__ == "__main__":
    # Paths to the three model checkpoints
    model_paths = [
        "Qwen/Qwen2.5-1.5B-Instruct",
        "Qwen/Qwen2.5-Coder-1.5B-Instruct",
        "Qwen/Qwen2.5-Math-1.5B-Instruct",
    ]
    
    expert_names = ["Chat", "Coder", "Math"]
    output_path = "qwen-moe-merged"
    
    # Create and run the merger
    merger = QwenMoEModelMerger(model_paths, expert_names, output_path)
    merger.merge()
    
   


In [20]:
merger.moe_model

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 1536)
    (layers): ModuleList(
      (0-27): 28 x Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): Linear(in_features=1536, out_features=1536, bias=True)
          (k_proj): Linear(in_features=1536, out_features=256, bias=True)
          (v_proj): Linear(in_features=1536, out_features=256, bias=True)
          (o_proj): Linear(in_features=1536, out_features=1536, bias=False)
        )
        (mlp): MoEFFNLayer(
          (router): Linear(in_features=1536, out_features=3, bias=False)
          (experts): ModuleList(
            (0-2): 3 x ModuleDict(
              (gate_proj): Linear(in_features=1536, out_features=8960, bias=False)
              (up_proj): Linear(in_features=1536, out_features=8960, bias=False)
              (down_proj): Linear(in_features=8960, out_features=1536, bias=False)
            )
          )
        )
        (input_layernorm): Qwen2RMSNorm((1536,), eps=

In [None]:
model = AutoModelForCausalLM.from_pretrained("./qwen-moe-merged")
# Test the merged model
test_prompts = [
    # "Can you explain how quantum computing works?",  # Chat task
    # "Write a Python function to find prime numbers.",  # Coding task
    "Solve the equation: 3x^2 - 5x + 2 = 0"  # Math task
]
for prompt in test_prompts

2025-04-02 14:54:42,328 - __main__ - INFO - Testing MoE model with sample prompts
2025-04-02 14:54:42,329 - __main__ - INFO - Test prompt 1: Solve the equation: 3x^2 - 5x + 2 = 0


RuntimeError: Tensor for argument weight is on cpu but expected on mps

In [None]:
import torch
import os
from transformers import AutoTokenizer

class QwenMoEModel:
    def __init__(self, model_path, device="cuda"):
        """
        Load a previously merged Qwen MoE model
        
        Args:
            model_path: Path to the saved MoE model
            device: Device to load the model on ('cuda' or 'cpu')
        """
        self.device = device
        self.model_path = model_path
        
        # Load the tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        
        # Load the model architecture and state dict
        checkpoint = torch.load(os.path.join(model_path, "pytorch_model.bin"), 
                               )
        
        # Recreate the model structure
        self.model = self._build_model_from_checkpoint(checkpoint)
        self.model.to(device)
        self.model.eval()  # Set to evaluation mode
        
    def _build_model_from_checkpoint(self, checkpoint):
        """Rebuild model architecture from the saved checkpoint"""
        config = checkpoint['config']
        
        # Define the Router class again
        class Router(torch.nn.Module):
            def __init__(self, hidden_size, num_experts):
                super().__init__()
                self.router = torch.nn.Linear(hidden_size, num_experts)
                
            def forward(self, hidden_states):
                router_logits = self.router(hidden_states)
                expert_weights, expert_indices = torch.topk(router_logits, k=1, dim=-1)
                return expert_weights, expert_indices
        
        # Load the individual expert models
        # This assumes the expert paths are stored in the config or can be determined
        expert_paths = self._get_expert_paths()
        experts = []
        for path in expert_paths:
            from transformers import AutoModelForCausalLM
            model = AutoModelForCausalLM.from_pretrained(
                path, 
                torch_dtype=torch.float16,
                # device_map=self.device
            )
            experts.append(model)
        
        # Define the MoE model class
        class QwenMoE(torch.nn.Module):
            def __init__(self, experts, hidden_size):
                super().__init__()
                self.experts = torch.nn.ModuleList(experts)
                self.router = Router(hidden_size, len(experts))
                self.num_experts = len(experts)
                
            def forward(self, input_ids, attention_mask=None):
                # Get embeddings from the first model
                with torch.no_grad():
                    hidden_states = self.experts[0].transformer.word_embeddings(input_ids)
                
                # Route to the appropriate expert
                _, expert_indices = self.router(hidden_states[:, 0])
                
                # Process each input with its assigned expert
                outputs = []
                for i in range(input_ids.shape[0]):
                    expert_idx = expert_indices[i].item()
                    output = self.experts[expert_idx](
                        input_ids=input_ids[i:i+1],
                        attention_mask=attention_mask[i:i+1] if attention_mask is not None else None
                    )
                    outputs.append(output)
                
                # Combine outputs
                combined_logits = torch.cat([o.logits for o in outputs], dim=0)
                
                # Return in the expected format
                return type(outputs[0])(logits=combined_logits)
        
        # Create model instance
        moe_model = QwenMoE(experts, config.hidden_size)
        
        # Load state dict from checkpoint
        moe_model.load_state_dict(checkpoint['model_state_dict'])
        
        return moe_model
    
    
    def generate(self, prompt, max_length=100, temperature=0.7, top_p=0.9):
        """
        Generate text based on the prompt
        
        Args:
            prompt: Input text prompt
            max_length: Maximum length of generated sequence
            temperature: Sampling temperature
            top_p: Nucleus sampling parameter
            
        Returns:
            Generated text
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        
        with torch.no_grad():
            output_ids = self.model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_length=max_length,
                do_sample=True,
                temperature=temperature,
                top_p=top_p
            )
        
        # Decode the generated tokens
        generated_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
        return generated_text
    
    def get_logits(self, prompt):
        """
        Get the output logits for a given prompt
        
        Args:
            prompt: Input text prompt
            
        Returns:
            Model logits
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        
        with torch.no_grad():
            outputs = self.model(**inputs)
        
        return outputs.logits

In [15]:
model = QwenMoEModel(
    "./qwen_2.5_3x1.5b_moe",
)

UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options, [1mdo those steps only if you trust the source of the checkpoint[0m. 
	(1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
	(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
	WeightsUnpickler error: Unsupported global: GLOBAL transformers.models.qwen2.configuration_qwen2.Qwen2Config was not an allowed global by default. Please use `torch.serialization.add_safe_globals([Qwen2Config])` or the `torch.serialization.safe_globals([Qwen2Config])` context manager to allowlist this global if you trust this class/function.

Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.

In [18]:
moe_model

QwenMoE(
  (experts): ModuleList(
    (0-2): 3 x Qwen2ForCausalLM(
      (model): Qwen2Model(
        (embed_tokens): Embedding(151936, 1536)
        (layers): ModuleList(
          (0-27): 28 x Qwen2DecoderLayer(
            (self_attn): Qwen2Attention(
              (q_proj): Linear(in_features=1536, out_features=1536, bias=True)
              (k_proj): Linear(in_features=1536, out_features=256, bias=True)
              (v_proj): Linear(in_features=1536, out_features=256, bias=True)
              (o_proj): Linear(in_features=1536, out_features=1536, bias=False)
            )
            (mlp): Qwen2MLP(
              (gate_proj): Linear(in_features=1536, out_features=8960, bias=False)
              (up_proj): Linear(in_features=1536, out_features=8960, bias=False)
              (down_proj): Linear(in_features=8960, out_features=1536, bias=False)
              (act_fn): SiLU()
            )
            (input_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)
            (post_attention_lay

In [14]:
# count number of params
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Original model parameters: {count_parameters(model)}")
print(f"MoE model parameters: {count_parameters(moe_model)/1_000_000_000:.2f}B")

Original model parameters: 494032768
MoE model parameters: 3.00B


In [None]:
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

In [16]:
messages = [
    {"role": "user", "content": "Hello! How are you?"}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
# Generate the response
# outputs = moe_model.generate(
#     model_inputs.input_ids,
#     max_new_tokens=512
# )

# 

In [18]:
# decode the generated response
outputs = moe_model.generate(
    model_inputs.input_ids,
    max_new_tokens=512,
    do_sample=True,
    top_k=50,
    top_p=0.95,
    temperature=0.3,
)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
response

KeyboardInterrupt: 