In [45]:
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from typing import Optional, Unpack
from transformers.cache_utils import Cache, DynamicCache
from transformers import AutoConfig
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.utils.generic import check_model_inputs
from transformers.modeling_layers import (
    GradientCheckpointingLayer,
)

from typing import Optional, Union

from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
from transformers.modeling_utils import PreTrainedModel
from transformers.models.qwen2.modeling_qwen2 import (
    Qwen2MLP,
    Qwen2Attention,
    Qwen2RMSNorm,
    Qwen2RotaryEmbedding,
)

from transformers.generation import GenerationMixin

from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask

In [46]:
tokenizer = AutoTokenizer.from_pretrained(
    "Qwen/Qwen2.5-0.5B",
    trust_remote_code=True
)

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)

config = AutoConfig.from_pretrained("Qwen/Qwen2.5-0.5B")
config._attn_implementation = "sdpa"
print(model)
print("-" * 100)
print(config)


loading file vocab.json from cache at C:\Users\luequ\.cache\huggingface\hub\models--Qwen--Qwen2.5-0.5B\snapshots\060db6499f32faf8b98477b0a26969ef7d8b9987\vocab.json
loading file merges.txt from cache at C:\Users\luequ\.cache\huggingface\hub\models--Qwen--Qwen2.5-0.5B\snapshots\060db6499f32faf8b98477b0a26969ef7d8b9987\merges.txt
loading file tokenizer.json from cache at C:\Users\luequ\.cache\huggingface\hub\models--Qwen--Qwen2.5-0.5B\snapshots\060db6499f32faf8b98477b0a26969ef7d8b9987\tokenizer.json
loading file added_tokens.json from cache at None
loading file special_tokens_map.json from cache at None
loading file tokenizer_config.json from cache at C:\Users\luequ\.cache\huggingface\hub\models--Qwen--Qwen2.5-0.5B\snapshots\060db6499f32faf8b98477b0a26969ef7d8b9987\tokenizer_config.json
loading file chat_template.jinja from cache at None
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
loading configuration file config.

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 896)
    (layers): ModuleList(
      (0-23): 24 x Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): Linear(in_features=896, out_features=896, bias=True)
          (k_proj): Linear(in_features=896, out_features=128, bias=True)
          (v_proj): Linear(in_features=896, out_features=128, bias=True)
          (o_proj): Linear(in_features=896, out_features=896, bias=False)
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=896, out_features=4864, bias=False)
          (up_proj): Linear(in_features=896, out_features=4864, bias=False)
          (down_proj): Linear(in_features=4864, out_features=896, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((896,), eps=1e-06)
    (rotary_emb): Qwen2

In [47]:
class Qwen2DecoderLayer(GradientCheckpointingLayer):
    def __init__(self, config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = Qwen2Attention(config=config, layer_idx=layer_idx)

        self.mlp = Qwen2MLP(config)
        self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.attention_type = config.layer_types[layer_idx]

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        # Self Attention
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states
        
class Qwen2PreTrainedModel(PreTrainedModel):
    config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Qwen2DecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True

    _can_compile_fullgraph = True
    _supports_attention_backend = True
    _can_record_outputs = {
        "hidden_states": Qwen2DecoderLayer,
        "attentions": Qwen2Attention,
    }



class Qwen2Model(Qwen2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen2RotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        self.has_sliding_layers = "sliding_attention" in self.config.layer_types
        self.post_init()

    @check_model_inputs
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> BaseModelOutputWithPast:
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # It may already have been prepared by e.g. `generate`
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            # Prepare mask arguments
            mask_kwargs = {
                "config": self.config,
                "input_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "position_ids": position_ids,
            }
            # Create the masks
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
            }
            # The sliding window alternating layers are not always activated depending on the config
            if self.has_sliding_layers:
                causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)

        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                position_embeddings=position_embeddings,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
        )


class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin):
    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen2Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        Example:

        ```python
        >>> from transformers import AutoTokenizer, Qwen2ForCausalLM

        >>> model = Qwen2ForCausalLM.from_pretrained("meta-qwen2/Qwen2-2-7b-hf")
        >>> tokenizer = AutoTokenizer.from_pretrained("meta-qwen2/Qwen2-2-7b-hf")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

In [48]:
custom_model = Qwen2ForCausalLM(config)
custom_model.load_state_dict(model.state_dict(), strict=True)
# Save inv_freq before conversion
inv_freq_backup = custom_model.model.rotary_emb.inv_freq.clone()
custom_model = custom_model.to(dtype=torch.bfloat16, device="cuda")
# Restore inv_freq to float32
custom_model.model.rotary_emb.inv_freq = inv_freq_backup.to("cuda")


Generate config GenerationConfig {
  "bos_token_id": 151643,
  "eos_token_id": 151643
}



In [49]:
def load_hf_weights(custom_model, hf_model):
    """
    Load weights from a HuggingFace model into the custom model.
    
    Args:
        custom_model: Custom Qwen2ForCausalLM instance
        hf_model: HuggingFace AutoModelForCausalLM instance
    
    Returns:
        custom_model with loaded weights
    """
    hf_state = hf_model.state_dict()
    
    # Load state dict with strict=True to ensure exact match
    custom_model.load_state_dict(hf_state, strict=True)
    
    print(f"Successfully loaded {len(hf_state)} parameter tensors")
    return custom_model

# Load weights and convert to bfloat16
custom_model = load_hf_weights(custom_model, model)

# Save inv_freq before dtype conversion
inv_freq = custom_model.model.rotary_emb.inv_freq.clone()

# Convert model to bfloat16
custom_model = custom_model.to(dtype=torch.bfloat16, device="cuda")

# Restore inv_freq to float32 (critical for numerical accuracy)
custom_model.model.rotary_emb.inv_freq = inv_freq.to(device="cuda", dtype=torch.float32)

Successfully loaded 291 parameter tensors


In [50]:
prompt = "Prove that there multiplication is a continuous map using a delta epsilon argument"
inputs = tokenizer(prompt, return_tensors="pt")
inputs.input_ids = inputs.input_ids.to(model.device)

In [51]:
generate_ids = custom_model.generate(inputs.input_ids, max_new_tokens=120)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


In [52]:
tokenizer.decode(generate_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)

'Prove that there multiplication is a continuous map using a delta epsilon argument. I am trying to prove that the multiplication map $m: \\mathbb{R} \\times \\mathbb{R} \\to \\mathbb{R}$ is continuous. I am trying to use a delta epsilon argument, but I am not sure how to do it. I know that the multiplication map is continuous if and only if it is continuous at every point in $\\mathbb{R} \\times \\mathbb{R}$, but I am not sure how to use this to prove that it is continuous at every point in $\\mathbb{R} \\times \\'

In [53]:
# === Forward Pass Equivalence Test ===

def test_forward_equivalence(custom_model, hf_model, tokenizer, test_prompts=None):
    """
    Test that both models produce identical outputs for the same inputs.
    """
    if test_prompts is None:
        test_prompts = [
            "Hello world",
            "The quick brown fox jumps over the lazy dog.",
            "In mathematics, a prime number is",
            "def fibonacci(n):",
        ]
    
    print("=" * 60)
    print("TEST: Forward Pass Equivalence")
    print("=" * 60)
    
    custom_model.eval()
    hf_model.eval()
    
    all_passed = True
    
    for prompt in test_prompts:
        inputs = tokenizer(prompt, return_tensors="pt")
        input_ids = inputs.input_ids.to(hf_model.device)
        
        with torch.no_grad():
            # HuggingFace model output
            hf_output = hf_model(input_ids)
            hf_logits = hf_output.logits
            
            # Custom model output
            custom_output = custom_model(input_ids)
            custom_logits = custom_output["logits"]
        
        # Compare logits
        max_diff = (hf_logits.float() - custom_logits.float()).abs().max().item()
        mean_diff = (hf_logits.float() - custom_logits.float()).abs().mean().item()
        
        # Use a reasonable tolerance for bfloat16
        is_close = torch.allclose(hf_logits.float(), custom_logits.float(), atol=1e-4, rtol=1e-3)
        
        status = "✓" if is_close else "✗"
        print(f"{status} Prompt: '{prompt[:40]}...'")
        print(f"   Max diff: {max_diff:.2e}, Mean diff: {mean_diff:.2e}")
        
        if not is_close:
            all_passed = False
    
    print("\n" + "=" * 60)
    if all_passed:
        print("✓ Forward pass equivalence test PASSED!")
    else:
        print("✗ Forward pass equivalence test FAILED!")
    print("=" * 60)
    
    return all_passed, hf_logits, custom_logits

# Run forward pass test
_, hf_logits, custom_logits = test_forward_equivalence(custom_model, model, tokenizer)


TEST: Forward Pass Equivalence
✓ Prompt: 'Hello world...'
   Max diff: 0.00e+00, Mean diff: 0.00e+00
✓ Prompt: 'The quick brown fox jumps over the lazy ...'
   Max diff: 0.00e+00, Mean diff: 0.00e+00
✓ Prompt: 'In mathematics, a prime number is...'
   Max diff: 0.00e+00, Mean diff: 0.00e+00
✓ Prompt: 'def fibonacci(n):...'
   Max diff: 0.00e+00, Mean diff: 0.00e+00

✓ Forward pass equivalence test PASSED!


       Shape |  Time (µs)
--------------------------
128x128x128 |      57.02
127x127x127 |     101.51
256x256x256 |      39.78
255x255x255 |      32.00
4096x4096x4096 |    6473.58
4095x4095x4095 |    7260.30
