In [1]:
import math
from typing import List, Optional, Tuple, Union

import datasets
import torch
import numpy as np
import torch.nn as nn
from datasets import load_dataset
import logging
import copy

from transformers import (
    PreTrainedModel,
    PretrainedConfig,
    AutoConfig,
    AutoModelForCausalLM,
    LlamaForCausalLM,
    LlamaConfig,
    Trainer,
    TrainingArguments,
    AutoTokenizer,
    HfArgumentParser
)

from modeling_qwen2 import (
    Qwen2RMSNorm, 
    Qwen2RotaryEmbedding, 
    Qwen2MLP, 
    Qwen2Attention, 
    Qwen2FlashAttention2, 
    Qwen2SdpaAttention, 
    Qwen2DecoderLayer, 
    Qwen2PreTrainedModel, 
    Qwen2Model, 
    Qwen2ForCausalLM,
    # logger
)

from configuration_qwen2 import Qwen2Config

from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
)

# Configure logging for the notebook: INFO level, records formatted as
# "<timestamp> - <level> - <message>".
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Logger named after the current module (`__name__`).
logger = logging.getLogger(__name__)

In [2]:
# from utils import are_tokenizers_same
# are_tokenizers_same(
#     paths = [
#         "/workspace/models/Arcee-VyLinh/",
#         "/workspace/models/Qwen2.5-Coder-3B/"
#     ]
# )

In [2]:
def load_layer(path, layer_idx=33):
    """Collect every tensor of one layer from the safetensors files in a directory.

    Args:
        path: Directory that contains one or more ``*.safetensors`` files.
        layer_idx: Layer index; a tensor is selected when its key contains
            the substring ``layers.<layer_idx>.``.

    Returns:
        Dict mapping the full tensor key to the tensor (opened with
        ``framework="pt"``, ``device="cpu"``). Empty when the directory holds
        no ``*.safetensors`` file or no key matches.
    """
    # Imported locally: `os` is not imported anywhere in the visible notebook.
    import os

    def _shard_order(file_name):
        # Shard files are expected to look like `model-00001-of-00003.safetensors`
        # and are ordered by their numeric second field. Files without such a
        # field (e.g. a single `model.safetensors`) are ordered by name after
        # the numbered shards instead of raising.
        fields = file_name.split('-')
        if len(fields) > 1 and fields[1].isdigit():
            return (0, int(fields[1]), file_name)
        return (1, 0, file_name)

    shard_names = sorted(
        (name for name in os.listdir(path) if name.endswith('.safetensors')),
        key=_shard_order,
    )

    state_dict = {}
    if not shard_names:
        return state_dict

    # Imported lazily (only needed once there is something to open);
    # `safe_open` is not imported anywhere in the visible notebook either.
    from safetensors import safe_open

    needle = f"layers.{layer_idx}."
    for shard_name in shard_names:
        shard_path = os.path.join(path, shard_name)
        with safe_open(shard_path, framework="pt", device="cpu") as shard:
            for key in shard.keys():
                if needle in key:
                    state_dict[key] = shard.get_tensor(key)
    return state_dict

def strip_prefix(state_dict, prefix="model.layers."):
    """Drop the leading ``<prefix><layer_idx>.`` from every key that starts with ``prefix``.

    Example: ``model.layers.33.input_layernorm.weight`` becomes
    ``input_layernorm.weight``.

    Args:
        state_dict: Mapping from tensor name to value.
        prefix: Text that precedes the layer index, including its trailing dot.

    Returns:
        A new dict with the same values. Keys that do not start with
        ``prefix``, or that have no ``.`` after the layer index, are kept
        unchanged.
    """
    stripped = {}
    for key, value in state_dict.items():
        if key.startswith(prefix):
            # The layer index is whatever sits between the prefix and the next
            # dot, so this works for any prefix rather than assuming the index
            # is the third dot-separated field. Only the leading occurrence is
            # removed.
            _layer_idx, dot, remainder = key[len(prefix):].partition(".")
            if dot:
                key = remainder
        stripped[key] = value
    return stripped

In [3]:
def lerp(
    t: float, v0: Union[np.ndarray, torch.Tensor], v1: Union[np.ndarray, torch.Tensor]
) -> Union[np.ndarray, torch.Tensor]:
    """Linear interpolation: returns ``v0`` at ``t == 0`` and ``v1`` at ``t == 1``."""
    start_share = (1 - t) * v0
    end_share = t * v1
    return start_share + end_share

def weighted_sum(
    factors: List[float], 
    tensors: Union[List[np.ndarray], List[torch.Tensor]]
) -> Union[np.ndarray, torch.Tensor]:
    """Return ``sum(tensor_i * factor_i)`` over the paired inputs.

    Args:
        factors: One scalar per tensor.
        tensors: Arrays/tensors to combine; they are added together, so their
            shapes must be compatible.

    Returns:
        The weighted sum. For empty inputs this is the integer ``0`` (the start
        value of the built-in ``sum``).

    Raises:
        ValueError: If ``factors`` and ``tensors`` differ in length. Without
            this check ``zip`` would silently ignore the surplus entries.
    """
    if len(factors) != len(tensors):
        raise ValueError(
            f"Got {len(factors)} factors for {len(tensors)} tensors; lengths must match."
        )
    return sum(tensor * factor for tensor, factor in zip(tensors, factors))

def merge_modules(modules, factors):
    """
    Blend several modules into one, using a single fixed scalar per module.

    Every tensor in the state dict of ``modules[0]`` is replaced by the
    weighted sum of the same-named tensors of all modules, cast back to the
    dtype and device of the first module's tensor. The same factors are used
    for every sub-module, so this is not suited to fine-grained merging.

    Returns a deep copy of ``modules[0]`` loaded with the blended tensors; the
    input modules are not modified.
    """
    merged_module = copy.deepcopy(modules[0])
    merged_state = merged_module.state_dict()
    source_states = [module.state_dict() for module in modules]

    for name in list(source_states[0].keys()):
        candidates = [state[name] for state in source_states]
        reference = candidates[0]
        blended = weighted_sum(factors=factors, tensors=candidates)
        merged_state[name] = blended.to(reference.dtype).to(reference.device)

    merged_module.load_state_dict(merged_state)
    return merged_module

def merge_linears(modules, weight_factors, bias_factors):
    """
    Blend several modules, using separate factor lists for weights and biases.

    All modules must expose the same parameter names (checked with an assert).
    Each state-dict tensor whose name contains ``"weight"`` is combined with
    ``weight_factors``; otherwise, if the name contains ``"bias"``, with
    ``bias_factors``. Any other tensor name raises ``ValueError``.

    Returns a deep copy of ``modules[0]`` loaded with the blended tensors, each
    cast back to the dtype and device of the first module's tensor.
    """
    expected_names = sorted(name for name, _ in modules[0].named_parameters())
    for module in modules:
        actual_names = sorted(name for name, _ in module.named_parameters())
        assert expected_names == actual_names, "Mismatch tensor names."

    merged_module = copy.deepcopy(modules[0])
    merged_state = merged_module.state_dict()
    source_states = [module.state_dict() for module in modules]

    for name in list(source_states[0].keys()):
        candidates = [state[name] for state in source_states]

        # Pick the factor list from the tensor's name.
        if "weight" in name:
            chosen_factors = weight_factors
        elif "bias" in name:
            chosen_factors = bias_factors
        else:
            raise ValueError("Hey this tensor is neither weight or bias.")

        reference = candidates[0]
        blended = weighted_sum(factors=chosen_factors, tensors=candidates)
        merged_state[name] = blended.to(reference.dtype).to(reference.device)

    merged_module.load_state_dict(merged_state)
    return merged_module

In [4]:
def place_masks(target_module, ref_modules, factors):
    """
    Recursively replace the layers of ``target_module`` with masked multi-model layers.

    For every direct child of ``target_module`` the same-named child of each
    reference module is collected, then:

    * ``nn.Linear``      -> ``LinearsWithMasks`` (weights and biases share ``factors``)
    * ``nn.Embedding``   -> ``EmbeddingsWithMasks``
    * ``Qwen2RMSNorm``   -> ``RMSNormsWithMasks`` (matched by class name)
    * anything else      -> recurse into that child

    All masks are created in ``"scalar"`` mode. ``target_module`` is modified
    in place and nothing is returned.

    Args:
      target_module: The module whose children are replaced.
      ref_modules: Modules providing the layers to be merged; each must have
        the same child names as ``target_module``.
      factors: One initial mask value per reference module.
    """
    assert len(ref_modules) == len(factors)
    num_components = len(ref_modules)

    for child_name, child in target_module.named_children():
        counterparts = [getattr(ref, child_name) for ref in ref_modules]
        scalar_modes = ["scalar" for _ in counterparts]

        if isinstance(child, nn.Linear):
            replacement = LinearsWithMasks(
                linears=counterparts,
                weight_modes=["scalar"] * num_components,
                weight_values=factors,
                bias_modes=["scalar"] * num_components,
                bias_values=factors,
            )
        elif isinstance(child, nn.Embedding):
            replacement = EmbeddingsWithMasks(counterparts, scalar_modes, factors)
        elif type(child).__name__ == Qwen2RMSNorm.__name__:
            replacement = RMSNormsWithMasks(counterparts, scalar_modes, factors)
        else:
            # Container module: descend and leave the container itself in place.
            place_masks(child, counterparts, factors)
            continue

        setattr(target_module, child_name, replacement)
In [5]:
from transformers import GenerationConfig, TextStreamer
def generate(prompt, model, tokenizer, max_new_tokens=1024):
    """Generate a continuation of ``prompt``, streaming it as it is produced.

    The prompt is tokenized and moved to ``model.device``, the model is put in
    eval mode, and ``model.generate`` runs under ``torch.no_grad()`` with a
    ``TextStreamer`` that skips the prompt.

    NOTE(review): ``temperature`` and ``top_p`` are set together with
    ``do_sample=False`` - confirm whether sampling was intended.

    Returns:
        The decoded new tokens of the first sequence, cut at the first
        occurrence of ``tokenizer.eos_token`` and stripped of surrounding
        whitespace.
    """
    prompt_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(model.device)
    model.eval()

    generation_settings = dict(
        repetition_penalty=1.13,
        max_new_tokens=max_new_tokens,
        temperature=0.4,
        top_p=0.95,
        pad_token_id=tokenizer.pad_token_id,
        do_sample=False,
        use_cache=True,
        return_dict_in_generate=True,
        output_attentions=False,
        output_hidden_states=False,
        output_scores=False,
    )

    with torch.no_grad():
        generation_config = GenerationConfig(**generation_settings)
        streamer = TextStreamer(tokenizer, skip_prompt=True)
        generated = model.generate(
            inputs=prompt_ids,
            generation_config=generation_config,
            streamer=streamer,
        )

    # Keep only what was generated after the prompt.
    prompt_length = len(prompt_ids[0])
    new_tokens = generated["sequences"].cpu()[:, prompt_length:]
    decoded = tokenizer.batch_decode(new_tokens)[0]
    before_eos = decoded.split(tokenizer.eos_token)[0]
    return before_eos.strip()

def get_logits(text, model, tokenizer):
    """Tokenize ``text``, run the model in eval mode without gradients, and return ``.logits``."""
    encoded = tokenizer(text, return_tensors="pt").to(model.device)
    model.eval()
    with torch.no_grad():
        return model(**encoded).logits

def get_hidden_states(text, model, tokenizer):
    """Run the model on ``text`` without gradients and return its full output object.

    The model is called with ``output_hidden_states=True`` and
    ``use_cache=False``.
    """
    encoded = tokenizer(text, return_tensors="pt").to(model.device)
    model.eval()
    with torch.no_grad():
        return model(**encoded, output_hidden_states=True, use_cache=False)

In [6]:
def init_input(text, model, tokenizer):
    """Build the keyword arguments for a single decoder-layer call from raw text.

    ``model`` must expose ``embed_tokens``, ``_update_causal_mask`` and
    ``rotary_emb``. No KV cache is used, so positions start at 0.

    Returns:
        Dict with the keys ``hidden_states`` (token embeddings),
        ``attention_mask`` (the causal mask), ``position_ids``,
        ``past_key_value`` (``None``), ``output_attentions`` (``False``),
        ``use_cache`` (``False``), ``cache_position`` and
        ``position_embeddings``.
    """
    encoded = tokenizer(text, return_tensors="pt").to(model.device)
    token_ids = encoded["input_ids"]
    padding_mask = encoded["attention_mask"]

    # Fixed settings: no cache, no attention weights.
    past_key_values = None
    output_attentions = False
    use_cache = False

    model.eval()
    with torch.no_grad():
        inputs_embeds = model.embed_tokens(token_ids)

        # Nothing is cached, so positions run from 0 to the sequence length.
        cache_position = torch.arange(
            0, inputs_embeds.shape[1], device=inputs_embeds.device
        )
        position_ids = cache_position.unsqueeze(0)

        causal_mask = model._update_causal_mask(
            padding_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        # Position embeddings shared across the decoder layers.
        position_embeddings = model.rotary_emb(inputs_embeds, position_ids)

    return {
        "hidden_states": inputs_embeds,
        "attention_mask": causal_mask,
        "position_ids": position_ids,
        "past_key_value": past_key_values,
        "output_attentions": output_attentions,
        "use_cache": use_cache,
        "cache_position": cache_position,
        "position_embeddings": position_embeddings,
    }

In [7]:
def mlp_forward(mlp, x: torch.Tensor):
    """
    Run a gated MLP step by step and keep every intermediate tensor.

    Computes ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``.

    Returns a dict with ``output`` (the down projection) and ``steps`` (an
    ordered dict of the input and each intermediate result).
    """
    gate = mlp.gate_proj(x)
    up = mlp.up_proj(x)
    activated_gate = mlp.act_fn(gate)      # activation is applied to the gate branch only
    gated = activated_gate * up            # the activated gate scales the up projection
    down = mlp.down_proj(gated)

    steps = {
        "step 0 (input)": x,
        "step 1 (gate)": gate,
        "step 2 (up)": up,
        "step 3 (activation)": activated_gate,
        "step 4 (act_up)": gated,
        "step 5 (down - output)": down,
    }
    return {"output": down, "steps": steps}

# def mlp_forward(self, hidden_state):
#     return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))

In [8]:
def decoder_forward(
    decoder,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: Optional[bool] = False,
    use_cache: Optional[bool] = False,
    cache_position: Optional[torch.LongTensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
    **kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
    """Run one decoder layer while recording the hidden state after every stage.

    Recorded stages (keys of the returned ``steps`` dict):
      step 1: layer input
      step 2: after ``input_layernorm``
      step 3: after ``self_attn``
      step 4: after the first residual addition
      step 5: after ``post_attention_layernorm``
      step 6: after ``mlp``
      step 7: after the second residual addition (layer output)

    Returns:
        Tuple ``(hidden_states, [attn_weights], [present_key_value], steps)``;
        the bracketed items appear only when ``output_attentions`` /
        ``use_cache`` are truthy, and ``steps`` is always last. Extra
        ``**kwargs`` are accepted and ignored.
    """
    steps = {}

    # Attention half: pre-norm -> self-attention -> residual add.
    attn_residual = hidden_states
    steps["step 1"] = hidden_states

    normed_input = decoder.input_layernorm(hidden_states)
    steps["step 2"] = normed_input

    attn_output, self_attn_weights, present_key_value = decoder.self_attn(
        hidden_states=normed_input,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_value=past_key_value,
        output_attentions=output_attentions,
        use_cache=use_cache,
        cache_position=cache_position,
        position_embeddings=position_embeddings,
    )
    steps["step 3"] = attn_output

    post_attn = attn_residual + attn_output
    steps["step 4"] = post_attn

    # Feed-forward half: pre-norm -> MLP -> residual add.
    normed_post_attn = decoder.post_attention_layernorm(post_attn)
    steps["step 5"] = normed_post_attn

    mlp_output = decoder.mlp(normed_post_attn)
    steps["step 6"] = mlp_output

    layer_output = post_attn + mlp_output
    steps["step 7"] = layer_output

    outputs = [layer_output]
    if output_attentions:
        outputs.append(self_attn_weights)
    if use_cache:
        outputs.append(present_key_value)
    outputs.append(steps)
    return tuple(outputs)

In [9]:
def model_forward(text, model, tokenizer, num_layers: Optional[int] = 2):
    """Run ``text`` through the first decoder layers of ``model``, recording each layer's steps.

    This is a tracing variant of the base model's forward pass: it embeds the
    tokens, builds the causal mask and rotary position embeddings, then calls
    ``decoder_forward`` on each selected layer and finally applies
    ``model.norm``. No KV cache is used and attention weights are not
    collected.

    Args:
        text: Input text to tokenize.
        model: Base model exposing ``embed_tokens``, ``_update_causal_mask``,
            ``rotary_emb``, ``layers`` and ``norm``.
        tokenizer: Tokenizer matching ``model``.
        num_layers: How many leading decoder layers to run. Defaults to ``2``
            (the value previously hard-coded); ``None`` runs every layer.

    Returns:
        ``BaseModelOutputWithPast`` where
        * ``last_hidden_state`` is the output of ``model.norm``,
        * ``hidden_states`` holds the input of every executed layer plus the
          final normalized state,
        * ``attentions`` holds the per-layer ``steps`` dicts returned by
          ``decoder_forward`` (NOT attention weights),
        * ``past_key_values`` is an empty tuple.

    Raises:
        ValueError: If ``num_layers`` is negative.
    """
    if num_layers is not None and num_layers < 0:
        raise ValueError(f"num_layers must be non-negative or None, got {num_layers}")

    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    # Fixed tracing configuration: no cache, no attention weights.
    past_key_values = None
    output_attentions = False
    use_cache = False

    model.eval()
    with torch.no_grad():
        inputs_embeds = model.embed_tokens(input_ids)

        # Nothing is cached, so positions run from 0 to the sequence length.
        cache_position = torch.arange(
            0, inputs_embeds.shape[1], device=inputs_embeds.device
        )
        position_ids = cache_position.unsqueeze(0)

        causal_mask = model._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        hidden_states = inputs_embeds

        # Position embeddings shared across the decoder layers.
        position_embeddings = model.rotary_emb(hidden_states, position_ids)

        all_hidden_states = ()
        all_decoder_steps = ()

        for decoder_layer in model.layers[:num_layers]:
            all_hidden_states += (hidden_states,)

            layer_outputs = decoder_forward(
                decoder_layer,
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )

            hidden_states = layer_outputs[0]
            # `decoder_forward` always puts its steps dict last.
            all_decoder_steps += (layer_outputs[-1],)

        hidden_states = model.norm(hidden_states)

        # Add the final (normalized) hidden state.
        all_hidden_states += (hidden_states,)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=(),
            hidden_states=all_hidden_states,
            attentions=all_decoder_steps,
        )

## modeling

In [11]:
class MaskConfig(PretrainedConfig):
    """Configuration of a single mask.

    Stores three fields and forwards any remaining keyword arguments to
    ``PretrainedConfig``:

    * ``mode``  - kind of mask (``Mask`` currently accepts only ``"scalar"``).
    * ``value`` - initial mask value (``Mask`` falls back to 1 when ``None``).
    * ``size``  - full shape of the mask after broadcasting.
    """

    def __init__(self, mode: str = None, value: Union[float, torch.Tensor] = None,
                 size: torch.Size = None, **kwargs):
        # Fields are assigned before the base-class initializer runs.
        self.mode, self.value, self.size = mode, value, size
        super().__init__(**kwargs)

class Mask(nn.Module):
    """Learnable multiplicative mask.

    Only ``mode == "scalar"`` is supported for now: the mask is a single
    learnable parameter (``self.weight``) that multiplies its input.
    """

    def __init__(
        self, 
        mask_config: "MaskConfig"
    ):
        """Create the mask parameter from ``mask_config``.

        Reads ``mask_config.mode``, ``mask_config.value`` (defaults to 1 when
        ``None``) and ``mask_config.size`` (the full shape of the mask after
        broadcasting, or ``None`` to skip shape checks).

        Raises:
            ValueError: If ``mask_config.mode`` is not ``"scalar"``.
        """
        super().__init__()
        self.mode = mask_config.mode
        if mask_config.mode != "scalar":
            raise ValueError(f"Unsupported mask mode: {mask_config.mode}")

        value = mask_config.value if mask_config.value is not None else 1.0
        if isinstance(value, torch.Tensor):
            # Copy so the parameter does not alias the caller's tensor.
            initial = value.detach().clone()
        else:
            initial = torch.tensor(value)
        # `nn.Parameter` requires gradients by default, which only floating
        # point / complex tensors support; integer inputs such as `1` are
        # therefore promoted to the default float dtype.
        if not (initial.is_floating_point() or initial.is_complex()):
            initial = initial.to(torch.get_default_dtype())
        self.weight = nn.Parameter(initial)

        self.size = mask_config.size ## Full size of the mask after broadcast.
        if self.size is not None:
            try:
                # Shape-only compatibility check: allocates no tensor and does
                # not touch the global random number generator.
                torch.broadcast_shapes(self.weight.shape, self.size)
            except RuntimeError:
                print("mask initialized with an incompatible shape.")

    def forward(self, x):
        """
        Return ``self.weight * x``.

        The operand order is deliberate: the masking must stay
        ``self.weight * x`` rather than ``x * self.weight``. Neither form is
        superior, but tests need one consistent order because the two can
        differ by small numerical imprecision, which may fail
        ``torch.testing.assert_close``.

        When a ``size`` was configured and differs from ``x.shape`` a message
        is printed; the product is returned either way.
        """
        if self.size is not None and self.size != x.shape:
            print("The shape of input does not match that of the mask.")
        return self.weight * x

In [12]:
class LinearWithMask(nn.Module):
    """A linear layer whose weight (and optionally bias) is multiplied by a ``Mask``.

    The wrapped ``linear`` is kept as a sub-module. A bias mask is created
    only when the layer has a bias and a bias mask config is supplied;
    otherwise ``bias_mask`` is ``None`` and the bias is used as is.

    The ``size`` of a supplied config is overwritten in place with the actual
    parameter shape when the two differ.
    """

    def __init__(self, linear, weight_mask_config: MaskConfig, bias_mask_config: MaskConfig = None):
        super().__init__()
        self.linear = linear
        self.weight_mask_config = weight_mask_config
        self.bias_mask_config = bias_mask_config

        weight_shape = linear.weight.shape
        if weight_shape != weight_mask_config.size:
            print("Weight mask shape is not compatible with linear, reinitializing...")
            self.weight_mask_config.size = weight_shape
        self.weight_mask = Mask(self.weight_mask_config)

        if linear.bias is None or bias_mask_config is None:
            self.bias_mask = None
        else:
            bias_shape = linear.bias.shape
            if bias_shape != bias_mask_config.size:
                print("Bias mask shape is not compatible with linear, reinitializing...")
                self.bias_mask_config.size = bias_shape
            self.bias_mask = Mask(self.bias_mask_config)

    def forward(self, x):
        """Apply the linear map using the masked weight and, if masked, the masked bias."""
        weight = self.weight_mask(self.linear.weight)
        bias = self.linear.bias
        if bias is not None and self.bias_mask is not None:
            bias = self.bias_mask(bias)
        return nn.functional.linear(x, weight, bias)

class LinearsWithMasks(nn.Module):
    """Several ``nn.Linear`` layers merged on the fly through per-layer masks.

    ``forward`` sums the masked weights (and masked biases) of all wrapped
    layers and applies one linear map with the merged parameters.
    """

    def __init__(
        self,
        linears: List[nn.Module],
        weight_modes: List[str] = ["scalar"],
        weight_values: List[float] = None,
        bias_modes: List[str] = ["scalar"],
        bias_values: List[float] = None,
    ):
        """
        Args:
            linears: The ``nn.Linear`` layers to merge.
            weight_modes: Mask mode per layer; a single entry is reused for
                every layer.
            weight_values: Initial weight-mask value per layer (required).
            bias_modes: Mask mode per layer for the biases; a single entry is
                reused for every layer.
            bias_values: Initial bias-mask value per layer; defaults to 1.0
                for every layer.

        Raises:
            ValueError: If an element of ``linears`` is not ``nn.Linear``, or
                if a values/modes list does not line up with ``linears``.
        """
        super().__init__()
        
        if not all(isinstance(linear, nn.Linear) for linear in linears):
            raise ValueError("All elements in 'linears' must be instances of nn.Linear.")

        num_linears = len(linears)
        weight_sizes = [linear.weight.shape for linear in linears]
        bias_sizes = [linear.bias.shape if linear.bias is not None else None for linear in linears]
        
        if weight_values is None or len(weight_values) != num_linears:
            raise ValueError(f"weight_values for masks: {weight_values} do not match with linear layers: {linears}")
        if bias_values is None:
            # `Mask` substitutes 1 for a missing value; use the float
            # explicitly so the mask parameter is always floating point.
            bias_values = [1.0] * num_linears
        if len(bias_values) != num_linears:
            raise ValueError(f"bias_values for masks: {bias_values} do not match with linear layers: {linears}")

        # Without this, `zip` below would silently drop every layer beyond
        # the length of the modes list (e.g. with the one-element defaults).
        weight_modes = self._expand_modes(weight_modes, num_linears, "weight_modes")
        bias_modes = self._expand_modes(bias_modes, num_linears, "bias_modes")

        weight_mask_configs = [
            MaskConfig(mode, value, size)
            for mode, value, size in zip(weight_modes, weight_values, weight_sizes)
        ]
        bias_mask_configs = [
            MaskConfig(mode, value, size) if size is not None else None
            for mode, value, size in zip(bias_modes, bias_values, bias_sizes)
        ]

        self.masked_linears = nn.ModuleList(
            [LinearWithMask(linear, weight_mask_config, bias_mask_config)
             for linear, weight_mask_config, bias_mask_config 
             in zip(linears, weight_mask_configs, bias_mask_configs)]
        )

    @staticmethod
    def _expand_modes(modes, count, label):
        """Return a fresh list of ``count`` modes, repeating a single entry if needed."""
        modes = list(modes)  # copy: never mutate the (default) argument
        if len(modes) == 1:
            return modes * count
        if len(modes) != count:
            raise ValueError(
                f"{label} for masks: {modes} do not match with {count} linear layers"
            )
        return modes

    def forward(self, x):
        """Apply one linear map whose weight/bias are the sums of the masked per-layer ones."""
        weights = [linear.weight_mask(linear.linear.weight) 
                   for linear in self.masked_linears]
        merged_weight = sum(weights)

        biases = [
            linear.bias_mask(linear.linear.bias) 
            if linear.linear.bias is not None and linear.bias_mask is not None 
            else linear.linear.bias for linear in self.masked_linears
        ]
        
        if all(b is None for b in biases):
            merged_bias = None
        else:
            # Layers without a bias contribute zeros shaped like one weight
            # column (i.e. one value per output feature).
            biases = [
                b if b is not None
                else torch.zeros_like(weights[0][:, 0])
                for b in biases
            ]
            merged_bias = sum(biases)

        return nn.functional.linear(x, merged_weight, merged_bias)

In [13]:
class RMSNormWithMask(nn.Module):
    """RMS normalization whose scale weight is multiplied by a ``Mask``.

    The wrapped ``rms_norm`` supplies the weight and ``variance_epsilon``.
    The config's ``size`` is always overwritten with the weight's shape; a
    message is printed when it differed beforehand.
    """

    def __init__(self, rms_norm: Qwen2RMSNorm, mask_config: MaskConfig):
        super().__init__()
        self.rms_norm = rms_norm
        self.mask_config = mask_config

        weight_shape = rms_norm.weight.shape
        if weight_shape != mask_config.size:
            print("Mask shape is not compatible with RMSNorm, reinitializing...")
        self.mask_config.size = weight_shape
        self.mask = Mask(self.mask_config)

    def forward(self, hidden_states):
        """Normalize over the last dimension in float32, then scale by the masked weight."""
        scale = self.mask(self.rms_norm.weight)
        original_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalized = upcast * torch.rsqrt(mean_square + self.rms_norm.variance_epsilon)
        return scale * normalized.to(original_dtype)

class RMSNormsWithMasks(nn.Module):
    """Several RMSNorm layers merged on the fly through per-layer masks.

    ``forward`` sums the masked scale weights of all wrapped layers and applies
    one RMS normalization with the merged weight.
    """

    def __init__(
        self,
        rms_norms: List[Qwen2RMSNorm],
        modes: List[str] = ["scalar"],
        values: List[float] = None
    ):
        """
        Args:
            rms_norms: The RMSNorm layers to merge.
            modes: Mask mode per layer; a single entry is reused for every
                layer.
            values: Initial mask value per layer (required).

        Raises:
            ValueError: If ``values`` or ``modes`` does not line up with
                ``rms_norms``.
        """
        super().__init__()
        sizes = [rms_norm.weight.shape for rms_norm in rms_norms]
        if values is None or len(values) != len(rms_norms):
            raise ValueError(f"values for masks: {values} do not match with RMSNorm layers: {rms_norms}")

        # Without this, `zip` below would silently drop every layer beyond
        # the length of `modes` (e.g. with the one-element default).
        modes = list(modes)  # copy: never mutate the (default) argument
        if len(modes) == 1:
            modes = modes * len(rms_norms)
        elif len(modes) != len(rms_norms):
            raise ValueError(f"modes for masks: {modes} do not match with RMSNorm layers: {rms_norms}")

        mask_configs = [
            MaskConfig(mode, value, size)
            for mode, value, size in zip(modes, values, sizes)
        ]
        self.masked_rms_norms = nn.ModuleList(
            [RMSNormWithMask(rms_norm, mask_config)
             for rms_norm, mask_config in zip(rms_norms, mask_configs)]
        )

    def forward(self, hidden_states):
        """Normalize over the last dimension in float32, then scale by the merged masked weight."""
        weights = [rms.mask(rms.rms_norm.weight) for rms in self.masked_rms_norms]
        merged_weight = sum(weights)

        # All layers must share one epsilon, since a single normalization is applied.
        variance_epsilon = self.masked_rms_norms[0].rms_norm.variance_epsilon
        for rms in self.masked_rms_norms:
            assert variance_epsilon == rms.rms_norm.variance_epsilon, "Variance epsilon among models must be consistent"
        
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
        return merged_weight * hidden_states.to(input_dtype)

In [14]:
class EmbeddingWithMask(nn.Module):
    """An embedding table whose weight is multiplied by a ``Mask`` before lookup.

    The config's ``size`` is always overwritten with the table's shape; a
    message is printed when it differed beforehand.
    """

    def __init__(self, embedding: nn.Embedding, mask_config: MaskConfig):
        super().__init__()
        self.embedding = embedding
        self.mask_config = mask_config

        table_shape = embedding.weight.shape
        if table_shape != mask_config.size:
            print("Mask shape is not compatible with Embedding, reinitializing...")
        self.mask_config.size = table_shape
        self.mask = Mask(self.mask_config)

    def forward(self, input_ids):
        """Look up ``input_ids`` in the masked table, reusing the wrapped embedding's options."""
        source = self.embedding
        masked_table = self.mask(source.weight)
        lookup_options = dict(
            padding_idx=source.padding_idx,
            max_norm=source.max_norm,
            norm_type=source.norm_type,
            scale_grad_by_freq=source.scale_grad_by_freq,
            sparse=source.sparse,
        )
        return nn.functional.embedding(input_ids, masked_table, **lookup_options)

class EmbeddingsWithMasks(nn.Module):
    """Several embedding tables merged on the fly through per-table masks.

    ``forward`` sums the masked tables and performs one lookup in the merged
    table.
    """

    def __init__(
        self,
        embeddings: List[nn.Embedding],
        modes: List[str] = ["scalar"],
        values: List[float] = None
    ):
        """
        Args:
            embeddings: The embedding layers to merge.
            modes: Mask mode per table; a single entry is reused for every
                table.
            values: Initial mask value per table (required).

        Raises:
            ValueError: If ``values`` or ``modes`` does not line up with
                ``embeddings``.
        """
        super().__init__()
        sizes = [embedding.weight.shape for embedding in embeddings]
        if values is None or len(values) != len(embeddings):
            raise ValueError(f"values for masks: {values} do not match with Embedding layers: {embeddings}")

        # Without this, `zip` below would silently drop every table beyond
        # the length of `modes` (e.g. with the one-element default).
        modes = list(modes)  # copy: never mutate the (default) argument
        if len(modes) == 1:
            modes = modes * len(embeddings)
        elif len(modes) != len(embeddings):
            raise ValueError(f"modes for masks: {modes} do not match with Embedding layers: {embeddings}")

        mask_configs = [
            MaskConfig(mode, value, size)
            for mode, value, size in zip(modes, values, sizes)
        ]
        self.masked_embeddings = nn.ModuleList(
            [EmbeddingWithMask(embedding, mask_config)
             for embedding, mask_config in zip(embeddings, mask_configs)]
        )

    def forward(self, input_ids):
        """Look up ``input_ids`` in the sum of the masked tables.

        The lookup options are taken from the first table; all tables must
        agree on them (checked with asserts).
        """
        weights = [emb.mask(emb.embedding.weight) for emb in self.masked_embeddings]
        merged_weight = sum(weights)
        an_embedding = self.masked_embeddings[0].embedding
        for other in self.masked_embeddings:
            other_embedding = other.embedding
            assert an_embedding.padding_idx == other_embedding.padding_idx
            assert an_embedding.max_norm == other_embedding.max_norm
            assert an_embedding.norm_type == other_embedding.norm_type
            assert an_embedding.scale_grad_by_freq == other_embedding.scale_grad_by_freq
            assert an_embedding.sparse == other_embedding.sparse
            
        return nn.functional.embedding(
            input_ids,
            merged_weight,
            padding_idx=an_embedding.padding_idx,
            max_norm=an_embedding.max_norm,
            norm_type=an_embedding.norm_type,
            scale_grad_by_freq=an_embedding.scale_grad_by_freq,
            sparse=an_embedding.sparse,
        )

In [15]:
class MergerConfig(PretrainedConfig):
    """Configuration for ``Merger``: the list of checkpoint paths to merge.

    Any additional keyword arguments are forwarded to ``PretrainedConfig``.
    """

    def __init__(self, model_paths: List[str] = None, **kwargs):
        # Assigned before the base-class initializer runs.
        setattr(self, "model_paths", model_paths)
        super().__init__(**kwargs)

class Merger(PreTrainedModel):
    """Holds several Qwen2 causal LMs and, after ``__post_init__``, a merged copy.

    TODO (from the original author): check whether the models are mergeable
    (having some sort of the same config).
    """

    def __init__(self, merge_config):
        """Load one ``Qwen2ForCausalLM`` per entry of ``merge_config.model_paths``.

        Args:
            merge_config: Config exposing ``model_paths`` (a non-empty list of
                checkpoint paths).

        Raises:
            ValueError: If ``merge_config.model_paths`` is ``None`` or empty.
        """
        super().__init__(merge_config)
        model_paths = merge_config.model_paths
        # Fail early with a clear message: `None` would otherwise surface as a
        # TypeError from len(), and an empty list as an IndexError much later
        # in __post_init__.
        if not model_paths:
            raise ValueError(
                "merge_config.model_paths must contain at least one model path, "
                f"got: {model_paths!r}"
            )

        self.merge_config = merge_config
        self.num_models = len(model_paths)
        # Qwen2Config / Qwen2ForCausalLM come from the local modeling files;
        # the Auto* classes were the alternatives left commented out originally.
        self.configs = [Qwen2Config.from_pretrained(path) for path in model_paths]
        self.models = nn.ModuleList([
            Qwen2ForCausalLM.from_pretrained(
                path,
                config=config,
                torch_dtype=torch.bfloat16
            )
            for path, config in zip(model_paths, self.configs)
        ])

    def __post_init__(self, merge_config, factors):
        """Create ``self.merger`` as a deep copy of the first model and mask it.

        Must be called explicitly; ``__init__`` does not call it.

        Args:
            merge_config: Accepted for interface compatibility; not used here.
            factors: Passed through to ``place_masks`` as ``factors``.
        """
        self.merger = copy.deepcopy(self.models[0])
        place_masks(self.merger, self.models, factors=factors)

    def forward(self, tensor, labels=None):
        """Placeholder: not implemented yet, returns ``None``."""
        pass

In [16]:
# Checkpoints to merge. The list order determines each model's index in `Merger.models`.
# NOTE(review): absolute, machine-specific paths; consider a configurable base directory.
merge_config = MergerConfig(
    model_paths = [
        "/workspace/models/Arcee-VyLinh/",
        "/workspace/models/Qwen2.5-Coder-3B/"
    ]
)
# Bare last expression so the notebook displays the config.
merge_config

MergerConfig {
  "model_paths": [
    "/workspace/models/Arcee-VyLinh/",
    "/workspace/models/Qwen2.5-Coder-3B/"
  ],
  "transformers_version": "4.46.3"
}

In [17]:
# Constructing the Merger loads every checkpoint listed in `merge_config.model_paths`.
merger = Merger(merge_config)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [18]:
# Move the wrapper (and the source models it holds) to the first CUDA device.
# `merger.merger` is only created by the later `__post_init__` call.
merger = merger.to("cuda:0")

In [19]:
# Build the merged model: deep-copies models[0] and installs masks through `place_masks`.
# factors=[0.0, 1.0] presumably weights model 0 by 0 and model 1 by 1, so the merged
# model is expected to reproduce models[1] — TODO confirm against `place_masks`.
merger.__post_init__(merge_config, factors=[0.0, 1.0])

In [20]:
# The tokenizer is loaded from the first checkpoint only.
# NOTE(review): assumes both checkpoints share a tokenizer — confirm (the
# `are_tokenizers_same` check near the top of the notebook is commented out).
tokenizer = AutoTokenizer.from_pretrained(merge_config.model_paths[0])

In [21]:
# Build a system + user chat and render it into a single prompt string.
system = "You are a helpful assistant."
prompt = "Continue this text: A dog is a cat"
messages = [
    {"role": "system", "content": system},
    {"role": "user", "content": prompt}
]
# tokenize=False returns the rendered string instead of token ids;
# add_generation_prompt=True appends the assistant header (visible in the output below).
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
# Display the rendered prompt.
text

'<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nContinue this text: A dog is a cat<|im_end|>\n<|im_start|>assistant\n'

In [23]:
# Generate a continuation with the merged model (`generate` is defined elsewhere in this notebook).
answer = generate(text, merger.merger, tokenizer, max_new_tokens=100)



A dog is indeed a type of animal, specifically a mammal belonging to the Canidae family. Dogs have been domesticated for thousands of years and serve various purposes such as companionship, hunting, herding, protection, and assistance in tasks like guide dogs or service animals.

Dogs come in different breeds with varying physical characteristics, sizes, colors, and temperaments. Some common types include:

1. Labrador Retrievers - known for their friendly nature, intelligence, and love of water.



From v4.47 onwards, when a model cache is to be returned, `generate` will return a `Cache` instance instead by default (as opposed to the legacy tuple of tuples format). If you want to keep returning the legacy format, please set `return_legacy_cache=True`.


In [24]:
# Same prompt through source model 1, to compare with the merged model's output.
# NOTE: this overwrites `answer` from the previous cell.
answer = generate(text, merger.models[1], tokenizer, max_new_tokens=100)

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


A dog is indeed a type of animal, specifically a mammal belonging to the Canidae family. Dogs have been domesticated for thousands of years and serve various purposes such as companionship, hunting, herding, protection, and assistance in tasks like guide dogs or service animals.

Dogs come in different breeds with varying physical characteristics, sizes, colors, and temperaments. Some common types include:

1. Labrador Retrievers - Known for their friendly nature, intelligence, and love for water.



In [20]:
# embeddings_with_masks = merger.merger.model.embed_tokens
# embedding1 = merger.models[1].model.embed_tokens

In [21]:
# torch.manual_seed(42)
# num_embeddings = embedding1.weight.data.shape[0]
# device = merger.device
# input_ids = torch.randint(0, num_embeddings, (2, 5)).to(device=device)  # Example input_ids
# input_ids

In [22]:
# o_merged = embeddings_with_masks(input_ids)
# o1 = embedding1(input_ids)
# torch.allclose(o_merged, o1, rtol=1e-10, atol=1e-10)

In [23]:
# Logits for the same prompt from the merged model and from source model 1
# (`get_logits` is defined elsewhere in this notebook).
logits_merged = get_logits(text, merger.merger, tokenizer)
logits1 = get_logits(text, merger.models[1], tokenizer)

In [24]:
# Display both logit tensors together for a visual comparison.
logits_merged, logits1

(tensor([[[14.3750, 11.8750, 12.8125,  ...,  2.4062,  2.4062,  2.4062],
          [11.9375, 14.3125, 13.6875,  ...,  2.6875,  2.6875,  2.6875],
          [15.0000, 12.9375, 17.5000,  ...,  1.6719,  1.6719,  1.6719],
          ...,
          [ 8.2500,  8.2500,  6.3750,  ..., -0.9414, -0.9414, -0.9414],
          [10.4375, 12.5625,  9.8750,  ...,  1.9766,  1.9766,  1.9766],
          [ 8.5625, 10.4375,  7.0938,  ...,  0.0248,  0.0248,  0.0248]]],
        device='cuda:0', dtype=torch.bfloat16),
 tensor([[[14.3750, 11.8750, 12.8125,  ...,  2.4062,  2.4062,  2.4062],
          [11.9375, 14.3125, 13.6875,  ...,  2.6875,  2.6875,  2.6875],
          [15.0000, 12.9375, 17.5000,  ...,  1.6719,  1.6719,  1.6719],
          ...,
          [ 8.2500,  8.2500,  6.3750,  ..., -0.9414, -0.9414, -0.9414],
          [10.4375, 12.5625,  9.8750,  ...,  1.9766,  1.9766,  1.9766],
          [ 8.5625, 10.4375,  7.0938,  ...,  0.0248,  0.0248,  0.0248]]],
        device='cuda:0', dtype=torch.bfloat16))

In [25]:
# Hidden states for the same prompt from the merged model and from source model 1
# (`get_hidden_states` is defined elsewhere in this notebook).
outputs_merged = get_hidden_states(text, merger.merger, tokenizer)
outputs1 = get_hidden_states(text, merger.models[1], tokenizer)

In [26]:
# Forward pass through the merged model's inner `.model` using the `model_forward`
# helper (defined elsewhere); its `.attentions` field is compared per layer below.
outputs_test_merged = model_forward(text, merger.merger.model, tokenizer)

In [27]:
# Same helper on source model 1's inner `.model`, as the reference for the comparison.
outputs_test_1 = model_forward(text, merger.models[1].model, tokenizer)

In [28]:
# Compare, layer by layer, the per-step tensors recorded for the merged model against
# those of source model 1.
# NOTE(review): the 7 steps per layer are hard-coded; presumably `model_forward`
# records keys "step 1" .. "step 7" for every layer — confirm.
for j, layer_output in enumerate(outputs_test_merged.attentions):
    other_output = outputs_test_1.attentions[j]
    for i in range(7):
        key = f"step {i+1}"
        # atol=0 and rtol=0: only exactly equal tensors pass.
        if torch.allclose(layer_output[key], other_output[key], atol=0, rtol=0):
            print(f"layer {j}, step {i+1} passed!")
        else:
            print(f"FAIL AT layer {j}, step {i+1}")

layer 0, step 1 passed!
layer 0, step 2 passed!
layer 0, step 3 passed!
layer 0, step 4 passed!
layer 0, step 5 passed!
layer 0, step 6 passed!
layer 0, step 7 passed!
layer 1, step 1 passed!
layer 1, step 2 passed!
layer 1, step 3 passed!
layer 1, step 4 passed!
layer 1, step 5 passed!
layer 1, step 6 passed!
layer 1, step 7 passed!


In [34]:
# Isolate the first decoder layer of the merged model and of source model 1.
decoder_merged = merger.merger.model.layers[0]
decoder_1 = merger.models[1].model.layers[0]

In [35]:
# Prepare decoder inputs with `init_input` (defined elsewhere); the result is indexed
# below with "hidden_states" and "position_embeddings".
model_inputs = init_input(text, merger.merger.model, tokenizer)

In [36]:
# Feed identical inputs to both first-layer decoders so their outputs can be compared.
out_dec_merged = decoder_merged(
    model_inputs['hidden_states'], 
    position_embeddings=model_inputs["position_embeddings"]
)
out_dec_1 = decoder_1(
    model_inputs['hidden_states'], 
    position_embeddings=model_inputs["position_embeddings"]
)

In [37]:
# Require exact equality (atol=0, rtol=0) between the two decoder outputs.
# The recorded output below shows this assertion failing.
torch.testing.assert_close(out_dec_merged, out_dec_1, atol=0, rtol=0)

AssertionError: Tensor-likes are not equal!

Mismatched elements: 27697 / 57344 (48.3%)
Greatest absolute difference: 0.03125 at index (0, 2, 1874)
Greatest relative difference: inf at index (0, 0, 533)

The failure occurred for item [0]

In [38]:
# Run only the MLP of each first decoder layer on the same input (`mlp_forward` is a
# helper defined elsewhere; its result is indexed with "output" here and "steps" below),
# then require exact equality of the two outputs.
out_mlp_merged = mlp_forward(decoder_merged.mlp, model_inputs["hidden_states"])
out_mlp_1 = mlp_forward(decoder_1.mlp, model_inputs["hidden_states"])
torch.testing.assert_close(out_mlp_1["output"], out_mlp_merged["output"], atol=0, rtol=0)

AssertionError: Tensor-likes are not equal!

Mismatched elements: 38040 / 57344 (66.3%)
Greatest absolute difference: 3.0517578125e-05 at index (0, 1, 71)
Greatest relative difference: inf at index (0, 1, 926)

In [39]:
# Walk the MLP's recorded intermediate tensors in sorted key order and report which
# steps differ between source model 1 and the merged model.
step_keys = sorted(out_mlp_merged["steps"].keys())
print(" -> ".join(step_keys))
print("---" * 10)
for key in step_keys:
    tensor_1 = out_mlp_1["steps"][key]
    tensor_merged = out_mlp_merged["steps"][key]
    # atol=0 and rtol=0: only exactly equal tensors pass.
    if torch.allclose(tensor_1, tensor_merged, atol=0, rtol=0):
        print(f"{key} passed!")
    else:
        print(f"FAIL AT {key}")

step 0 (input) -> step 1 (gate) -> step 2 (up) -> step 3 (activation) -> step 4 (act_up) -> step 5 (down - output)
------------------------------
step 0 (input) passed!
FAIL AT step 1 (gate)
FAIL AT step 2 (up)
FAIL AT step 3 (activation)
FAIL AT step 4 (act_up)
FAIL AT step 5 (down - output)


In [40]:
# Display the gate_proj output of source model 1's first-layer MLP (reference values).
decoder_1.mlp.gate_proj(model_inputs["hidden_states"])

tensor([[[-0.0179,  0.0153,  0.0164,  ..., -0.0186, -0.0030,  0.0052],
         [ 0.0194, -0.0164,  0.0254,  ...,  0.0014, -0.0325,  0.0043],
         [ 0.0347, -0.0330, -0.0273,  ...,  0.0630,  0.0035, -0.0332],
         ...,
         [-0.0179,  0.0153,  0.0164,  ..., -0.0186, -0.0030,  0.0052],
         [ 0.0344,  0.0045,  0.0212,  ...,  0.0070,  0.0222,  0.0010],
         [ 0.0347, -0.0330, -0.0273,  ...,  0.0630,  0.0035, -0.0332]]],
       device='cuda:0', dtype=torch.bfloat16, grad_fn=<UnsafeViewBackward0>)

In [41]:
# Display the gate_proj output of the merged model's first-layer MLP, to compare with the cell above.
decoder_merged.mlp.gate_proj(model_inputs["hidden_states"])

tensor([[[-0.0179,  0.0153,  0.0162,  ..., -0.0186, -0.0031,  0.0052],
         [ 0.0194, -0.0164,  0.0254,  ...,  0.0014, -0.0325,  0.0042],
         [ 0.0349, -0.0327, -0.0273,  ...,  0.0630,  0.0035, -0.0332],
         ...,
         [-0.0179,  0.0153,  0.0162,  ..., -0.0186, -0.0031,  0.0052],
         [ 0.0344,  0.0045,  0.0212,  ...,  0.0070,  0.0221,  0.0011],
         [ 0.0349, -0.0327, -0.0273,  ...,  0.0630,  0.0035, -0.0332]]],
       device='cuda:0', dtype=torch.bfloat16, grad_fn=<ViewBackward0>)

In [42]:
# IPython `??` introspection (not plain Python): shows the source of the merged
# gate_proj's forward method. Interactive debugging aid only.
decoder_merged.mlp.gate_proj.forward??

[0;31mSignature:[0m [0mdecoder_merged[0m[0;34m.[0m[0mmlp[0m[0;34m.[0m[0mgate_proj[0m[0;34m.[0m[0mforward[0m[0;34m([0m[0mx[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m
Define the computation performed at every call.

Should be overridden by all subclasses.

.. note::
    Although the recipe for forward pass needs to be defined within
    this function, one should call the :class:`Module` instance afterwards
    instead of this since the former takes care of running the
    registered hooks while the latter silently ignores them.
[0;31mSource:[0m   
    [0;32mdef[0m [0mforward[0m[0;34m([0m[0mself[0m[0;34m,[0m [0mx[0m[0;34m)[0m[0;34m:[0m[0;34m[0m
[0;34m[0m        [0mweights[0m [0;34m=[0m [0;34m[[0m[0mlinear[0m[0;34m.[0m[0mweight_mask[0m[0;34m([0m[0mlinear[0m[0;34m.[0m[0mlinear[0m[0;34m.[0m[0mweight[0m[0;34m)[0m [0;34m[0m
[0;34m[0m                   [0;32mfor[0m [0mlinear[0m [0;32min[0m [0mself

In [48]:
def get_weight_and_bias(masked):
    """Recompute the merged weight and bias of a masked-linear container.

    Args:
        masked: Object exposing ``masked_linears``. Each entry provides
            ``linear`` (with ``weight`` and ``bias``), ``weight_mask`` and
            ``bias_mask``.

    Returns:
        dict with two keys: ``weight`` is the sum of all masked weights;
        ``bias`` is the sum of all (masked) biases, or ``None`` when no entry
        contributes a bias.
    """
    # Weights first: apply each entry's weight mask, then add everything up.
    masked_weights = []
    for entry in masked.masked_linears:
        masked_weights.append(entry.weight_mask(entry.linear.weight))
    total_weight = sum(masked_weights)

    # Biases: mask only when both a bias and a bias mask exist; otherwise keep
    # the raw bias (which may be None).
    per_linear_bias = []
    for entry in masked.masked_linears:
        raw_bias = entry.linear.bias
        if raw_bias is None or entry.bias_mask is None:
            per_linear_bias.append(raw_bias)
        else:
            per_linear_bias.append(entry.bias_mask(raw_bias))

    if all(candidate is None for candidate in per_linear_bias):
        # No bias anywhere: keep it None instead of materialising zeros.
        total_bias = None
    else:
        # Missing biases are replaced by zeros shaped like one weight column.
        total_bias = sum(
            torch.zeros_like(masked_weights[0][:, 0]) if candidate is None else candidate
            for candidate in per_linear_bias
        )

    return {"weight": total_weight, "bias": total_bias}

In [49]:
# Recompute the merged gate_proj's effective weight and bias outside of its forward pass.
a_masked_linears_params = get_weight_and_bias(decoder_merged.mlp.gate_proj)

In [50]:
# The merged weight should match source model 1's gate_proj weight.
# NOTE(review): unlike the atol=0 / rtol=0 checks above, this uses assert_close's
# default tolerances — confirm that is intended.
torch.testing.assert_close(a_masked_linears_params["weight"], decoder_1.mlp.gate_proj.weight.data)

In [51]:
# The merged bias must agree with source model 1's gate_proj bias.
# A bare `assert a == b` only works when both sides are None: for two multi-element
# tensors `==` yields a tensor, which `assert` cannot reduce to a single bool.
merged_bias = a_masked_linears_params["bias"]
reference_bias = decoder_1.mlp.gate_proj.bias
if merged_bias is None or reference_bias is None:
    # A missing bias must be missing on both sides (None vs. a tensor is a mismatch).
    assert merged_bias is reference_bias
else:
    torch.testing.assert_close(merged_bias, reference_bias, atol=0, rtol=0)

In [29]:
# Free-form debugging note stored in a string variable (nothing in this part of the
# notebook reads it). It records the root cause of the merged-vs-source mismatch: the
# old LinearsWithMasks.forward substituted a zero bias where the bias should stay None.
big_bounty = """
Damn there is something wrong at step 6 decoder.
Which is the MLP :D
I'll fix it tmr.
-------------------
The error is fixed now. nn.Linear is kinda weird, which results in
some number instability. In short, bias=torch.zeros_like() is different
from bias=None when initializing an nn.Linear.

In specific, in older implementations, I handle biases in LinearsWithMasks 
like this:
```
def forward(...):
    ...
    biases = [
        linear.bias_mask(linear.linear.bias) 
        if linear.linear.bias is not None and linear.bias_mask is not None 
        else linear.linear.bias for linear in self.masked_linears
    ]
    biases = [
        b if b is not None
        else torch.zeros_like(weights[0][:, 0])
        for b in biases
    ]
    # merged_bias = torch.sum(torch.stack(biases))
    merged_bias = sum(biases)
    if all(b is None for b in biases):
        merged_bias = None
    ...
```
This implementation is not accurate as it will always assign torch.zeros_like()
to biases when they are supposed to be None, making the last `if` redundant.

A quick and accurate fix should be:
```
def forward(...):
    ...
    biases = [
        linear.bias_mask(linear.linear.bias) 
        if linear.linear.bias is not None and linear.bias_mask is not None 
        else linear.linear.bias for linear in self.masked_linears
    ]
    
    if all(b is None for b in biases):
        merged_bias = None
    else:
        biases = [
            b if b is not None
            else torch.zeros_like(weights[0][:, 0])
            for b in biases
        ]
        merged_bias = sum(biases)
    ...
```
Yeah, this is it.
"""

In [30]:
# Free-form write-up stored in a string variable (nothing in this part of the notebook
# reads it). It describes the debugging approach used here — standalone copies of
# forward methods with extra instrumentation — and the timeline of how the
# LinearsWithMasks bias bug was localized.
debug_strategy = """
One of the cool things I figured out myself is how to do surgeon to class methods.
Instead of having to re-implement a class with a modified method, like adding
logging statements between lines of codes.

For example, with the class Qwen2MLP, instead of writing a new class Qwen2MLPDebug
with a new forward method when I want to inspect intermediate steps within the 
forward method of Qwen2MLP, I could just write a dedicated function for debugging:

def mlp_forward(mlp, *args, **kwargs):
    ```
    this function is a copy of Qwen2MLP.forward, but:
    - the input and output signatures largely reflect what Qwen2MLP.forward has,
    however, I can plant additional debugging flags and logging statements.
    - change `self` to `mlp`. not world-changing. you can also decide to let 
    `self` in the mlp_forward's input signature, which will save you some minutes
    from replacing `self` in Qwen2MLP.forward to `mlp` like I do.
    ```
    ...

This function is exactly the Qwen2MLP.forward method, however I replace self -> mlp,
with mlp is an instance of Qwen2MLP. I know I know this is obvious, but by applying 
the pattern for debugging, it saves so much time (do not have write a class from 
scratch and import everything), ensures correctness (it's a copy of an established 
class method duh), and most importantly leaves the existing classes and methods 
intact (extremely helpful when experimenting, because I don't have to restart 
everything in my notebook to reflect a minor change in my debugging code, which 
would have been the case if I decided to implement a Qwen2MLPDebug class).
-----------
In this notebook, I've implemented 3 debugging function with the principle above,
which has helped me localize the bug causing inconsistencies between outputs of
merged_model = merge(models=[model1, model2], factors=[1.0, 0.0]) and model1.
These two outputs should be identical, but due to a weird feature of torch, this 
properties did not hold.

In retrorespect, here is what happened:
- Inspected logits of `merged_model` and `model1`, expected them to be identical.
  They did not.
- Implemented mlp_forward to inspect at what part of Qwen2ForCausalLM's forward pass 
did the intermediate states diverged between two models.
- Outputs of embed_tokens layer OKAY, they were identical.
- Outputs of the first decoder layer failed.
- Implemented decoder_forward to insptect at what step it failed (input norm, 
two skip connections, self attn, mlp, or output norm)
- It was mlp that failed the consistency test. What happened? I did unittests at 
`debug_1.ipynb`, might missed something.
- Implemented mlp_forward.
- It was the gate_proj and up_proj that failed, both of which are nn.Linear with no
bias. This indicated that my LinearsWithMasks implementation was not correct.
- It was indeed not correct. Having inspected further, I found that gate_proj of a
MLP within the merged_model had bias=torch.zeros_like(), while it should be None 
instead. Theoretically W * X should be IDENTICAL to W * X + torch.zeros_like(n_columns).
But it's not. This weird fuck bug.
- Fixed the forward pass within the LinearsWithMasks class. Everything OKAY now. 
Fully functional scalar merging.
"""