In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM, Qwen2MoeForCausalLM
import torch
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
chat_model_name = "Qwen/Qwen2.5-1.5B-Instruct"
math_model_name = "Qwen/Qwen2.5-Math-1.5B-Instruct"
coder_model_name = "Qwen/Qwen2.5-Coder-1.5B-Instruct"
# tokenizer = AutoTokenizer.from_pretrained(chat_model_name)
# chat_model = AutoModelForCausalLM.from_pretrained(chat_model_name, device_map="auto")
# math_model = AutoModelForCausalLM.from_pretrained(math_model_name, device_map="auto")
# coder_model = AutoModelForCausalLM.from_pretrained(coder_model_name, device_map="auto")

In [4]:
chat_model

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 1536)
    (layers): ModuleList(
      (0-27): 28 x Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): Linear(in_features=1536, out_features=1536, bias=True)
          (k_proj): Linear(in_features=1536, out_features=256, bias=True)
          (v_proj): Linear(in_features=1536, out_features=256, bias=True)
          (o_proj): Linear(in_features=1536, out_features=1536, bias=False)
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=1536, out_features=8960, bias=False)
          (up_proj): Linear(in_features=1536, out_features=8960, bias=False)
          (down_proj): Linear(in_features=8960, out_features=1536, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((1536,), eps=1e-06)
    (rotary_emb): Qw

In [5]:
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from transformers import Qwen2MoeForCausalLM, Qwen2MoeConfig, Qwen2Model, Qwen2Config

In [None]:
from transformers.models.qwen2.modeling_qwen2 import (
    Qwen2Attention,
    Qwen2RMSNorm,
    Qwen2RotaryEmbedding,
    Qwen2MLP,
)
from transformers import Qwen2MoeConfig
import torch.nn as nn


class Qwen2_5MoEConfig(Qwen2MoeConfig):
    model_type = "qwen2_5moe"

    def __init__(
        self,
        num_experts=8,
        top_k=2,
        capacity_factor=1.5,
        aux_loss_weight=0.01,
        router_jitter=0.01,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.num_experts = num_experts
        self.top_k = top_k
        self.capacity_factor = capacity_factor
        self.aux_loss_weight = aux_loss_weight
        self.router_jitter = router_jitter

        self.dtype = kwargs.get("dtype", torch.float32)
        self.dropout_rate = kwargs.get("dropout_rate", 0.01)


class Qwen2_5MoEExpertRouter(torch.nn.Module):
    """
    Mixture of Experts Router Layer

    Takes attention outputs as input and routes to top-k MLPs (experts)
    using a learned routing mechanism.
    """

    def __init__(
        self,
        input_dim,  # Dimension of input (from attention layer)
        num_experts,  # Total number of experts available
        top_k=1,  # Number of experts to route each token to
        capacity_factor=1.5,  # Scaling factor for expert capacity
        aux_loss_weight=0.01,  # Weight for auxiliary load balancing loss
        router_jitter=0.01,  # Optional noise to add during training
        dtype=torch.float32,
    ):
        super().__init__()

        self.input_dim = input_dim
        self.num_experts = num_experts
        self.top_k = top_k
        self.capacity_factor = capacity_factor
        self.aux_loss_weight = aux_loss_weight
        self.router_jitter = router_jitter
        self.dtype = dtype

        # Router projection layer: maps input to expert selection logits
        self.router = nn.Linear(input_dim, num_experts, bias=False, dtype=dtype)

        # Initialize with small weights to encourage equal expert usage early in training
        nn.init.normal_(self.router.weight, mean=0.0, std=0.01)

    def forward(self, inp, training=True):
        """
        Forward pass for the router

        Args:
            inp: Tensor of shape [batch_size, seq_len, input_dim] from attention layer
            training: Whether the model is in training mode

        Returns:
            dispatch_tensor: Sparse tensor for dispatching inputs to experts
            combine_tensor: Sparse tensor for combining expert outputs
            router_logits: Raw router logits
            aux_loss: Load balancing auxiliary loss
        """
        # Get shape information
        batch_size, seq_len, _ = inp.shape
        num_tokens = batch_size * seq_len

        # Reshape for routing
        inp_reshaped = inp.reshape(
            -1, self.input_dim
        )  # [batch_size * seq_len, input_dim]

        # Get router logits
        router_logits = self.router(inp_reshaped)  # [batch_size * seq_len, num_experts]

        # Add jitter noise during training for stability
        if training and self.router_jitter > 0:
            router_logits += torch.randn_like(router_logits) * self.router_jitter

        # Calculate expert capacity: how many tokens can be routed to each expert
        # We scale by capacity_factor to allow for some experts to receive more tokens
        capacity = int(
            self.capacity_factor * num_tokens * self.top_k / self.num_experts
        )

        # Convert router logits to probabilities using softmax
        router_probs = F.softmax(
            router_logits, dim=-1
        )  # [batch_size * seq_len, num_experts]

        # Get top-k experts and their routing probabilities
        top_k_probs, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1)

        # Normalize the top-k probabilities
        top_k_probs_sum = top_k_probs.sum(dim=-1, keepdim=True)
        top_k_probs_normalized = top_k_probs / top_k_probs_sum

        # Create mask for valid routing
        # Each token routes to top_k experts
        expert_mask = torch.zeros(
            num_tokens, self.num_experts, device=router_logits.device, dtype=torch.bool
        )

        # Create indices for scatter operation
        token_indices = (
            torch.arange(num_tokens, device=router_logits.device)
            .unsqueeze(1)
            .expand(-1, self.top_k)
        )

        # Populate the expert mask
        expert_mask.scatter_(1, top_k_indices, True)

        # Calculate auxiliary load balancing loss
        # We want to encourage all experts to be used equally
        # 1. Compute the fraction of router probability assigned to each expert
        router_prob_per_expert = router_probs.mean(0)

        # 2. Compute auxiliary loss: minimize the variance in expert utilization
        aux_loss = torch.mean(
            self.num_experts * router_prob_per_expert * router_prob_per_expert
        )

        # Create dispatch and combine tensors
        # These will be used to route inputs to experts and combine expert outputs

        # Create dispatch mask tracking which tokens go to which experts with what weights
        dispatch_tensor = torch.zeros(
            num_tokens, self.num_experts, device=router_logits.device, dtype=self.dtype
        )

        # For each token and its top-k experts, set the corresponding weight
        for token_idx in range(num_tokens):
            for k in range(self.top_k):
                expert_idx = top_k_indices[token_idx, k].item()
                prob = top_k_probs_normalized[token_idx, k].item()
                dispatch_tensor[token_idx, expert_idx] = prob

        # The combine tensor is the same as the dispatch tensor in this implementation
        # Some implementations might use different weights for combining
        combine_tensor = dispatch_tensor.clone()

        return {
            "dispatch_tensor": dispatch_tensor,
            "combine_tensor": combine_tensor,
            "router_logits": router_logits,
            "router_probs": router_probs,
            "aux_loss": aux_loss,
            "top_k_indices": top_k_indices,
            "top_k_probs": top_k_probs_normalized,
        }


class Qwen2_5MoEDecoderLayer(torch.nn.Module):
    def __init__(self, config: Qwen2_5MoEConfig, layer_idx: int):
        super().__init__()
        # In Qwen2_5MoEDecoderLayer.__init__
        self.input_dim = config.hidden_size
        self.output_dim = config.hidden_size
        self.hidden_size = config.hidden_size
        self.self_attn = Qwen2Attention(config=config, layer_idx=layer_idx)
        # self.mlp = Qwen2MLP(config)
        self.router = Qwen2_5MoEExpertRouter(
            input_dim=config.hidden_size,
            num_experts=config.num_experts,
            top_k=config.top_k,
            capacity_factor=config.capacity_factor,
            aux_loss_weight=config.aux_loss_weight,
            router_jitter=config.router_jitter,
        )

        # self.experts = nn.ModuleList([
        #     nn.Sequential(
        #         nn.Linear(input_dim, expert_dim, dtype=dtype),
        #         nn.GELU(),
        #         nn.Dropout(expert_dropout),
        #         nn.Linear(expert_dim, output_dim, dtype=dtype)
        #     )
        #     for _ in range(num_experts)
        # ])
        self.experts = nn.ModuleList(
            [
                Qwen2MLP(
                    config,
                )
                for _ in range(config.num_experts)
            ]
        )

        self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen2RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        if config.sliding_window and config._attn_implementation != "flash_attention_2":
            print(
                f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
                "unexpected results may be encountered."
            )
        # Output projection layer for better integration with the next layer
        # self.output_proj = nn.Linear(output_dim, output_dim, dtype=dtype)

    def forward(self, inp, training=True):
        """
        Forward pass for the MoE layer

        Args:
            inp: Tensor of shape [batch_size, seq_len, input_dim] from attention layer
            training: Whether the model is in training mode

        Returns:
            outputs: Tensor of shape [batch_size, seq_len, output_dim]
            aux_loss: Load balancing auxiliary loss
        """
        batch_size, seq_len, _ = inp.shape
        num_tokens = batch_size * seq_len

        # Get routing information
        router_outputs = self.router(inp, training=training)
        dispatch_tensor = router_outputs["dispatch_tensor"]
        combine_tensor = router_outputs["combine_tensor"]
        aux_loss = router_outputs["aux_loss"]

        # Reshape input for expert processing
        inp_reshaped = inp.reshape(
            -1, self.input_dim
        )  # [batch_size * seq_len, input_dim]

        # Initialize expert outputs
        expert_outputs = torch.zeros(
            num_tokens, self.output_dim, device=inp.device, dtype=inp.dtype
        )

        # Process each expert
        for expert_idx, expert in enumerate(self.experts):
            # Get tokens routed to this expert
            expert_mask = dispatch_tensor[:, expert_idx] > 0
            if not expert_mask.any():
                continue

            # Select tokens for this expert
            expert_inputs = inp_reshaped[expert_mask]

            # Get expert weights for these tokens
            expert_weights = dispatch_tensor[expert_mask, expert_idx].unsqueeze(1)

            # Process inputs with the expert
            processed = expert(expert_inputs)

            # Weight the outputs by router probabilities
            weighted_outputs = processed * expert_weights

            # Add to the total outputs
            expert_outputs[expert_mask] += weighted_outputs

        # Reshape back to original dimensions
        merged_outputs = expert_outputs.reshape(batch_size, seq_len, self.output_dim)

        # Apply final output projection to improve integration with normalization layer
        outputs = self.input_layernorm(merged_outputs)
        outputs = self.post_attention_layernorm(outputs)

        return outputs, aux_loss


from dataclasses import dataclass
from typing import Optional, Tuple

from transformers.modeling_outputs import ModelOutput


@dataclass
class MoEModelOutputWithPast(ModelOutput):
    """
    Output class for MoE models.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
            `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
            encoder_sequence_length, embed_size_per_head)`.
            Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
            `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
            input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

        aux_loss (`torch.FloatTensor`, *optional*):
            Auxiliary load balancing loss for the MoE layers.
    """

    last_hidden_state: torch.FloatTensor
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    aux_loss: Optional[torch.FloatTensor] = None


class Qwen2_5MoEModel(PreTrainedModel):
    def __init__(self, config: Qwen2_5MoEConfig):
        super().__init__(config)
        self.config = config
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList(
            [Qwen2_5MoEDecoderLayer(config, i) for i in range(config.num_hidden_layers)]
        )
        self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen2RotaryEmbedding(config)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """
        Forward pass of the Qwen2_5MoEModel.

        Args:
            input_ids (torch.LongTensor): Indices of input sequence tokens (batch_size, seq_len)
            attention_mask (torch.Tensor): Mask to avoid attention on padding tokens (batch_size, seq_len)
            position_ids (torch.LongTensor): Indices of positions (batch_size, seq_len)
            past_key_values (tuple): Cached past key and values for faster inference
            inputs_embeds (torch.FloatTensor): Embedded inputs (batch_size, seq_len, hidden_size)
            use_cache (bool): Whether to return a cache for faster inference
            output_attentions (bool): Whether to return attention weights
            output_hidden_states (bool): Whether to return all hidden states
            return_dict (bool): Whether to return a ModelOutput or tuple

        Returns:
            BaseModelOutputWithPast or tuple: Model outputs
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # Create embedding for input tokens if not provided
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        batch_size, seq_length = inputs_embeds.shape[:2]

        # Create position ids if not provided
        if position_ids is None:
            position_ids = torch.arange(
                0, seq_length, dtype=torch.long, device=inputs_embeds.device
            ).unsqueeze(0)

        # Compute rotary embeddings
        cos, sin = self.rotary_emb(inputs_embeds, seq_length)

        # Process through each decoder layer
        hidden_states = inputs_embeds
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None
        all_router_logits = ()
        total_aux_loss = 0.0

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = (
                past_key_values[idx] if past_key_values is not None else None
            )

            # Process through decoder layer
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cos=cos,
                sin=sin,
                training=self.training,
            )

            hidden_states = layer_outputs[0]
            total_aux_loss += layer_outputs[1]

            # Handle caching and outputs
            if output_attentions:
                all_self_attns += (layer_outputs[2],)

            if use_cache:
                next_decoder_cache += (layer_outputs[3],)

        # Final layer norm
        hidden_states = self.norm(hidden_states)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attns,
                    total_aux_loss,
                ]
                if v is not None
            )

        return MoEModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            # router_logits=all_router_logits,
            aux_loss=total_aux_loss,
        )


class Qwen2_5MoePreTrainedModel(PreTrainedModel):
    config_class = Qwen2_5MoEConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Qwen2_5MoEDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


from transformers.generation import GenerationMixin


class Qwen2_5MoEForCausalLM(Qwen2_5MoePreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config: Qwen2_5MoEConfig):
        super().__init__(config)
        self.model = Qwen2_5MoEModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def forward(
        self, input_ids, attention_mask=None, position_ids=None, head_mask=None
    ):
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
        )
        hidden_states = outputs[0]
        lm_logits = self.lm_head(hidden_states)
        return lm_logits

    def generate(
        self, input_ids, attention_mask=None, position_ids=None, head_mask=None
    ):
        lm_logits = self(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
        )
        # Implement your generation logic here
        # For example, you can use beam search or sampling to generate text
        # This is a placeholder for the actual generation logic
        generated_text = torch.argmax(lm_logits, dim=-1)
        return generated_text
    
    def merge_models(self, model_name_or_path: str):
        """
        Merge the model with the specified model name or path.
        This is a placeholder for the actual merging logic.
        """
        # Implement your merging logic here
        pass

In [7]:
from transformers import Qwen2Config

chat_model_config = Qwen2Config.from_pretrained(chat_model_name)
moe_config = Qwen2_5MoEConfig(
    vocab_size=chat_model_config.vocab_size,
    hidden_size=chat_model_config.hidden_size,
    intermediate_size=chat_model_config.intermediate_size,
    num_hidden_layers=chat_model_config.num_hidden_layers,
    num_attention_heads=chat_model_config.num_attention_heads,
    num_experts=3,
    top_k=1,
    # capacity_factor=1.5,
    # aux_loss_weight=0.01,
)

moe_model = Qwen2_5MoEForCausalLM(moe_config)



In [8]:
moe_model

Qwen2_5MoEForCausalLM(
  (model): Qwen2_5MoEModel(
    (embed_tokens): Embedding(151936, 1536)
    (layers): ModuleList(
      (0-27): 28 x Qwen2_5MoEDecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): Linear(in_features=1536, out_features=1536, bias=True)
          (k_proj): Linear(in_features=1536, out_features=2048, bias=True)
          (v_proj): Linear(in_features=1536, out_features=2048, bias=True)
          (o_proj): Linear(in_features=1536, out_features=1536, bias=False)
        )
        (router): Qwen2_5MoEExpertRouter(
          (router): Linear(in_features=1536, out_features=3, bias=False)
        )
        (experts): ModuleList(
          (0-2): 3 x Qwen2MLP(
            (gate_proj): Linear(in_features=1536, out_features=8960, bias=False)
            (up_proj): Linear(in_features=1536, out_features=8960, bias=False)
            (down_proj): Linear(in_features=8960, out_features=1536, bias=False)
            (act_fn): SiLU()
          )
        )
        (

In [None]:
for layer_idx in range(len(chat_model.model.layers)):
    decoder_layer = chat_model.model.layers[layer_idx]
    print(f"Layer {decoder_layer}")

Layer Qwen2DecoderLayer(
  (self_attn): Qwen2Attention(
    (q_proj): Linear(in_features=1536, out_features=1536, bias=True)
    (k_proj): Linear(in_features=1536, out_features=256, bias=True)
    (v_proj): Linear(in_features=1536, out_features=256, bias=True)
    (o_proj): Linear(in_features=1536, out_features=1536, bias=False)
  )
  (mlp): Qwen2MLP(
    (gate_proj): Linear(in_features=1536, out_features=8960, bias=False)
    (up_proj): Linear(in_features=1536, out_features=8960, bias=False)
    (down_proj): Linear(in_features=8960, out_features=1536, bias=False)
    (act_fn): SiLU()
  )
  (input_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)
  (post_attention_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)
)
Layer Qwen2DecoderLayer(
  (self_attn): Qwen2Attention(
    (q_proj): Linear(in_features=1536, out_features=1536, bias=True)
    (k_proj): Linear(in_features=1536, out_features=256, bias=True)
    (v_proj): Linear(in_features=1536, out_features=256, bias=True)
    (o_proj): Linear(in_