In [1]:
import math
from typing import List, Optional, Tuple, Union

import datasets
import torch
import numpy as np
import torch.nn as nn
import logging
import copy
import gc

from datasets import load_dataset
from tqdm import tqdm
from transformers import (
    PreTrainedModel,
    PretrainedConfig,
    AutoConfig,
    AutoModelForCausalLM,
    LlamaForCausalLM,
    LlamaConfig,
    Trainer,
    TrainingArguments,
    AutoTokenizer,
    HfArgumentParser
)

from modeling_qwen2 import (
    Qwen2RMSNorm, 
    Qwen2RotaryEmbedding, 
    Qwen2MLP, 
    Qwen2Attention, 
    Qwen2FlashAttention2, 
    Qwen2SdpaAttention, 
    Qwen2DecoderLayer, 
    Qwen2PreTrainedModel, 
    Qwen2Model, 
    Qwen2ForCausalLM,
)

from configuration_qwen2 import Qwen2Config

from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
)

# Configure logging for the notebook: INFO level, records formatted as
# "<timestamp> - <level> - <message>".
# `logger` is the module-level logger created for this notebook's namespace.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

In [25]:
# Scratch cell: a small bias-free linear layer used by the next cells to try
# out how a mask broadcasts against a weight matrix.
in_features = 24
out_features = 50
lin = nn.Linear(in_features, out_features, bias=False)

In [43]:
# Fix the seed so the random mask displayed below is reproducible.
torch.manual_seed(42)
# One uniform random coefficient per input feature, shape (1, in_features).
mask = torch.rand(1, in_features) 
mask

tensor([[0.8823, 0.9150, 0.3829, 0.9593, 0.3904, 0.6009, 0.2566, 0.7936, 0.9408,
         0.1332, 0.9346, 0.5936, 0.8694, 0.5677, 0.7411, 0.4294, 0.8854, 0.5739,
         0.2666, 0.6274, 0.2696, 0.4414, 0.2969, 0.8317]])

In [44]:
# The (1, in_features) mask broadcasts over the rows of `lin.weight`, scaling
# each input-feature column by its own coefficient.
mask * lin.weight

tensor([[-0.1422, -0.0861, -0.0221,  ..., -0.0394,  0.0350,  0.0304],
        [ 0.0915, -0.1138, -0.0774,  ..., -0.0537,  0.0545,  0.0566],
        [ 0.1733, -0.1541, -0.0775,  ...,  0.0612, -0.0440, -0.0906],
        ...,
        [ 0.0543,  0.0347, -0.0459,  ...,  0.0071,  0.0598,  0.0986],
        [ 0.1008,  0.0748,  0.0605,  ..., -0.0351, -0.0046, -0.0413],
        [-0.0910, -0.0593, -0.0283,  ...,  0.0228,  0.0181, -0.1110]],
       grad_fn=<MulBackward0>)

In [2]:
# Compare the tokenizers of the two local checkpoints; the helper logs the
# comparison and the cell displays its boolean result.
# NOTE(review): this import lives outside the top import cell, and the paths are
# absolute, machine-specific locations - consider a configurable model directory.
from utils import are_tokenizers_same
are_tokenizers_same(
    paths = [
        "/workspace/models/Arcee-VyLinh/",
        "/workspace/models/Qwen2.5-Coder-3B/"
    ]
)

2024-12-16 10:24:20,672 - INFO - Comparing tokenizer at /workspace/models/Arcee-VyLinh/ with tokenizer at /workspace/models/Qwen2.5-Coder-3B/
2024-12-16 10:24:20,675 - INFO - Tokenizer at /workspace/models/Arcee-VyLinh/ and /workspace/models/Qwen2.5-Coder-3B/ are the same based on the defined criteria


True

In [2]:
def load_layer(path, layer_idx=33):
    """Collect the tensors of one decoder layer from a safetensors checkpoint.

    Args:
        path: Directory containing one or more ``*.safetensors`` files.
        layer_idx: Index of the layer to extract. A tensor is kept when its
            name contains ``"layers.<layer_idx>."``.

    Returns:
        Dict mapping the original tensor names to tensors loaded on CPU.

    NOTE(review): ``safe_open`` is not imported in the visible part of the
    notebook; it presumably comes from ``safetensors`` - confirm and add it to
    the top import cell.
    """
    import os  # local import: the notebook's import cell does not import `os`

    def _shard_order(filename):
        # Sharded checkpoints are named like "model-00001-of-00004.safetensors".
        # Anything else (e.g. a single "model.safetensors") is ordered by name
        # after the numbered shards instead of crashing the sort.
        parts = filename.split("-")
        if len(parts) > 1 and parts[1].isdigit():
            return (0, int(parts[1]), filename)
        return (1, 0, filename)

    layer_tag = f"layers.{layer_idx}."
    state_dict = {}
    shard_names = [name for name in os.listdir(path) if name.endswith(".safetensors")]
    for shard_name in sorted(shard_names, key=_shard_order):
        shard_file = os.path.join(path, shard_name)
        with safe_open(shard_file, framework="pt", device="cpu") as shard:
            for key in shard.keys():
                if layer_tag in key:
                    state_dict[key] = shard.get_tensor(key)
    return state_dict

def strip_prefix(state_dict, prefix="model.layers."):
    """Drop the leading ``<prefix><layer_idx>.`` from state-dict keys.

    Example: ``"model.layers.33.input_layernorm.weight"`` becomes
    ``"input_layernorm.weight"``. Keys that do not start with ``prefix``, or
    that have nothing after the layer index, are returned unchanged.

    Args:
        state_dict: Mapping of tensor names to values.
        prefix: Leading part of the key that precedes the layer index.

    Returns:
        A new dict with the same values and the shortened keys.
    """
    stripped = {}
    for key, value in state_dict.items():
        if key.startswith(prefix):
            # The layer index is whatever sits between the prefix and the next
            # dot; locating it relative to the prefix keeps custom prefixes working.
            _layer_idx, separator, remainder = key[len(prefix):].partition(".")
            if separator:
                key = remainder
        stripped[key] = value
    return stripped

In [3]:
def lerp(
    t: float, v0: Union[np.ndarray, torch.Tensor], v1: Union[np.ndarray, torch.Tensor]
) -> Union[np.ndarray, torch.Tensor]:
    """Linearly interpolate between ``v0`` (at ``t == 0``) and ``v1`` (at ``t == 1``)."""
    start_part = (1 - t) * v0
    end_part = t * v1
    return start_part + end_part

def weighted_sum(
    factors: List[float], 
    tensors: Union[List[np.ndarray], List[torch.Tensor]]
) -> Union[np.ndarray, torch.Tensor]:
    """Return ``sum(tensor_i * factor_i)`` over the paired inputs.

    Pairs are formed with ``zip``, so surplus entries of the longer list are
    ignored; with no pairs at all the result is the integer ``0``.
    """
    total = 0
    for tensor, factor in zip(tensors, factors):
        total = total + tensor * factor
    return total

def merge_tensors(modules, weight_factors, bias_factors):
    """Merge structurally identical modules into one by a weighted sum of tensors.

    Every entry of the modules' ``state_dict`` whose name contains ``"weight"``
    is combined with ``weight_factors``; entries whose name contains ``"bias"``
    use ``bias_factors``. The result is cast back to the dtype/device of the
    first module's tensor and loaded into a deep copy of ``modules[0]``.

    Args:
        modules: Indexable collection of modules with identical parameter names.
        weight_factors: One factor per module, used for weight tensors.
        bias_factors: One factor per module, used for bias tensors.

    Returns:
        A new module (deep copy of ``modules[0]``) holding the merged tensors.

    Raises:
        ValueError: If ``modules`` is empty, if a tensor is neither a weight nor
            a bias, or if the factor list needed for a tensor does not have
            exactly one entry per module.
        AssertionError: If the modules do not share the same parameter names.
    """
    if len(modules) == 0:
        raise ValueError("merge_tensors needs at least one module.")

    reference_names = sorted(name for name, _ in modules[0].named_parameters())
    for module in modules:
        candidate_names = sorted(name for name, _ in module.named_parameters())
        assert reference_names == candidate_names, "Mismatch tensor names."

    merged_module = copy.deepcopy(modules[0])
    merged_state = merged_module.state_dict()
    state_dicts = [module.state_dict() for module in modules]

    for tensor_name in list(state_dicts[0].keys()):
        candidates = [state[tensor_name] for state in state_dicts]
        if "weight" in tensor_name:
            factors, factors_label = weight_factors, "weight_factors"
        elif "bias" in tensor_name:
            factors, factors_label = bias_factors, "bias_factors"
        else:
            raise ValueError("Hey this tensor is neither weight or bias.")

        # zip() inside weighted_sum would silently drop modules when the factor
        # list is too short, so insist on exactly one factor per module.
        if factors is None or len(factors) != len(candidates):
            raise ValueError(
                f"{factors_label} must contain one factor per module "
                f"({len(candidates)} expected) to merge '{tensor_name}', got {factors}."
            )

        merged_state[tensor_name] = (
            weighted_sum(factors=factors, tensors=candidates)
            .to(candidates[0].dtype)
            .to(candidates[0].device)
        )

    merged_module.load_state_dict(merged_state)
    return merged_module

In [4]:
def init_factors(components: List[torch.Tensor], strategy="naive"):
    """Draw initial merge factors, one per component.

    Args:
        components: The tensors (or placeholders) to be merged; only the length
            of the list is used.
        strategy: ``"naive"`` draws uniform random numbers from NumPy's global
            random state and normalizes them to sum to 1. ``"slerp"`` is
            reserved but not implemented yet.

    Returns:
        A list of Python floats with ``len(components)`` entries.

    Raises:
        ValueError: For ``"slerp"`` (wrong component count, or not implemented)
            and for any unknown strategy.
    """
    if strategy == "naive":
        raw = np.random.rand(len(components))
        return (raw / np.sum(raw)).tolist()

    if strategy == "slerp":
        # Validate the input first so the more specific message is reachable.
        if len(components) != 2:
            raise ValueError(f"Initialization strategy {strategy.upper()} only works for 2 components.")
        raise ValueError(f"Initialization strategy {strategy} has not been implemented.")

    raise ValueError(f"Initialization strategy {strategy} has not been implemented.")

def find_modules_to_add_masks(target_module):
    """List the dotted names of all submodules that should receive masks.

    A submodule qualifies when it is an ``nn.Linear``, an ``nn.Embedding``, or
    its class name contains ``"RMSNorm"``. ``target_module`` itself is never
    listed; names are relative to it.
    """
    def _is_maskable(module):
        return isinstance(module, (nn.Linear, nn.Embedding)) or "RMSNorm" in type(module).__name__

    return [
        f"{parent_name}.{child_name}" if parent_name else child_name
        for parent_name, parent in target_module.named_modules()
        for child_name, child in parent.named_children()
        if _is_maskable(child)
    ]

def initialize_masks(target_module, ref_modules, strategy="naive"):
    """
    Swap every maskable submodule of ``target_module`` for its masked counterpart.

    The maskable submodules are those reported by ``find_modules_to_add_masks``.
    Each one is replaced by a ``LinearsWithMasks`` / ``EmbeddingsWithMasks`` /
    ``RMSNormsWithMasks`` wrapper built around the submodules found at the same
    dotted path in every reference module.

    Args:
        target_module: The module whose submodules are replaced in place.
        ref_modules: Modules providing the wrapped submodules (one per model).
        strategy: Forwarded to ``init_factors`` to draw the initial mask values.
    """
    eligible_names = find_modules_to_add_masks(target_module)

    for dotted_name in tqdm(eligible_names, desc="Initializing masks"):
        *parent_path, leaf_name = dotted_name.split(".")

        # Walk down to the parent of the submodule, in the target and in every
        # reference module simultaneously.
        parent_module = target_module
        ref_children = ref_modules
        for attr in parent_path:
            parent_module = getattr(parent_module, attr)
            ref_children = [getattr(ref, attr) for ref in ref_children]

        target_child = getattr(parent_module, leaf_name)
        ref_children = [getattr(ref, leaf_name) for ref in ref_children]
        modes = ["scalar"] * len(ref_children)

        if isinstance(target_child, nn.Linear):
            # Weight factors are drawn before bias factors.
            weight_factors = init_factors(
                [ref.weight.data for ref in ref_children], strategy=strategy
            )
            bias_factors = init_factors(
                [ref.bias.data if ref.bias is not None else None for ref in ref_children],
                strategy=strategy,
            )
            new_module = LinearsWithMasks(
                linears=ref_children,
                weight_modes=modes,
                weight_values=weight_factors,
                bias_modes=modes,
                bias_values=bias_factors,
            )

        elif isinstance(target_child, nn.Embedding):
            factors = init_factors([ref.weight.data for ref in ref_children], strategy=strategy)
            new_module = EmbeddingsWithMasks(ref_children, modes, factors)

        elif "RMSNorm" in type(target_child).__name__:
            factors = init_factors([ref.weight.data for ref in ref_children], strategy=strategy)
            new_module = RMSNormsWithMasks(ref_children, modes, factors)

        # Install the masked wrapper where the original submodule used to be.
        setattr(parent_module, leaf_name, new_module)


def initialize_masks_recursive(target_module, ref_modules, strategy="naive"):
    """
    Recursively replace maskable children of ``target_module`` with masked wrappers.

    Direct children that are ``nn.Linear``, ``nn.Embedding`` or whose class name
    contains ``"RMSNorm"`` are replaced in place; every other child is descended
    into by calling this function again.

    Args:
        target_module: The module whose children are replaced in place.
        ref_modules: Modules providing the wrapped children (one per model);
            each must expose the same child names as ``target_module``.
        strategy: Forwarded to ``init_factors`` to draw the initial mask values.
    """
    # Snapshot the children first: they are replaced while we loop.
    for name, target_child in list(target_module.named_children()):
        ref_children = [getattr(module, name) for module in ref_modules]
        modes = ["scalar"] * len(ref_children)

        if isinstance(target_child, nn.Linear):
            weights = [ref.weight.data for ref in ref_children]
            biases = [ref.bias.data if ref.bias is not None else None
                      for ref in ref_children]

            # Weight factors are drawn before bias factors.
            weight_factors = init_factors(weights, strategy=strategy)
            bias_factors = init_factors(biases, strategy=strategy)

            masked_child = LinearsWithMasks(
                linears=ref_children,
                weight_modes=modes,
                weight_values=weight_factors,
                bias_modes=modes,
                bias_values=bias_factors,
            )

        elif isinstance(target_child, nn.Embedding):
            weights = [ref.weight.data for ref in ref_children]
            factors = init_factors(weights, strategy=strategy)
            masked_child = EmbeddingsWithMasks(ref_children, modes, factors)

        elif "RMSNorm" in type(target_child).__name__:
            weights = [ref.weight.data for ref in ref_children]
            factors = init_factors(weights, strategy=strategy)
            masked_child = RMSNormsWithMasks(ref_children, modes, factors)

        else:
            # Not maskable itself: descend into it (this is the actual recursion).
            initialize_masks_recursive(target_child, ref_children, strategy=strategy)
            continue

        setattr(target_module, name, masked_child)

In [5]:
def free_memory(logger=None):
    """Run garbage collection and empty the CUDA cache, logging allocated memory.

    Args:
        logger: Object with an ``info`` method used for reporting. When None,
            logging is configured via ``logging.basicConfig`` and the module
            logger is used.

    Returns:
        None. When CUDA is unavailable only a single informational message is logged.
    """
    bytes_per_gb = 1024**3

    if logger is None:
        # Fall back to the module logger when the caller did not supply one.
        logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        logger = logging.getLogger(__name__)

    if torch.cuda.is_available():
        before = torch.cuda.memory_allocated()
        logger.info(f"Initial GPU memory allocated: {before / bytes_per_gb:.2f} GB")

        # Drop unreachable Python objects first, then release cached blocks.
        gc.collect()
        torch.cuda.empty_cache()

        after = torch.cuda.memory_allocated()
        logger.info(f"Final GPU memory allocated: {after / bytes_per_gb:.2f} GB")
        logger.info(f"Freed GPU memory: {(before - after) / bytes_per_gb:.2f} GB")
    else:
        logger.info("CUDA is not available. No GPU memory to free.")

In [8]:
from transformers import GenerationConfig, TextStreamer
def generate(prompt, model, tokenizer, max_new_tokens=1024, do_sample=False):
    """Generate a completion for ``prompt`` and return the decoded new text.

    The generated tokens are also handed to a ``TextStreamer`` (prompt skipped)
    while generation runs.

    Args:
        prompt: Text to condition on.
        model: Causal LM exposing ``device``, ``eval()`` and ``generate()``.
        tokenizer: Tokenizer matching ``model``.
        max_new_tokens: Upper bound on the number of generated tokens.
        do_sample: When False (default) decoding is greedy, as before. When True
            the sampling settings ``temperature=0.4`` and ``top_p=0.95`` are
            applied; they are only passed in that case because they are
            sampling-only options.

    Returns:
        The generated text up to the first EOS token, stripped of surrounding
        whitespace.
    """
    input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(model.device)

    config_kwargs = dict(
        repetition_penalty=1.13,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.pad_token_id,
        do_sample=do_sample,
        use_cache=True,
        return_dict_in_generate=True,
        output_attentions=False,
        output_hidden_states=False,
        output_scores=False,
    )
    if do_sample:
        # Sampling-only knobs; setting them for greedy decoding is contradictory.
        config_kwargs.update(temperature=0.4, top_p=0.95)
    generation_config = GenerationConfig(**config_kwargs)
    streamer = TextStreamer(tokenizer, skip_prompt=True)

    model.eval()
    with torch.no_grad():
        generated = model.generate(
            inputs=input_ids,
            generation_config=generation_config,
            streamer=streamer,
        )

    # Keep only the tokens produced after the prompt.
    prompt_length = input_ids.shape[1]
    gen_tokens = generated["sequences"].cpu()[:, prompt_length:]
    output = tokenizer.batch_decode(gen_tokens)[0]
    # str.split(None) would split on whitespace, so only cut when an EOS token exists.
    if tokenizer.eos_token:
        output = output.split(tokenizer.eos_token)[0]
    return output.strip()

def get_logits(text, model, tokenizer):
    """Tokenize ``text``, run ``model`` in eval mode without gradients, return ``.logits``."""
    model_inputs = tokenizer(text, return_tensors="pt").to(model.device)
    model.eval()
    with torch.no_grad():
        return model(**model_inputs).logits

def get_hidden_states(text, model, tokenizer):
    """Run ``model`` on ``text`` with hidden states requested and caching disabled.

    Returns the model's full output object (not just the hidden states).
    """
    model_inputs = tokenizer(text, return_tensors="pt").to(model.device)
    model.eval()
    with torch.no_grad():
        return model(**model_inputs, output_hidden_states=True, use_cache=False)

## debugging utils

In [8]:
def init_input(text, model, tokenizer):
    """Build the keyword arguments for a decoder layer from raw ``text``.

    The text is tokenized and embedded with ``model.embed_tokens``; positions,
    the causal mask (``model._update_causal_mask``) and rotary position
    embeddings (``model.rotary_emb``) are then derived for a fresh sequence
    with no KV cache.

    Returns:
        Dict with keys ``hidden_states``, ``attention_mask``, ``position_ids``,
        ``past_key_value``, ``output_attentions``, ``use_cache``,
        ``cache_position`` and ``position_embeddings``.
    """
    encoded = tokenizer(text, return_tensors="pt").to(model.device)

    # Fixed settings: no cache, no attention maps.
    past_key_values = None
    output_attentions = False
    use_cache = False

    model.eval()
    with torch.no_grad():
        inputs_embeds = model.embed_tokens(encoded["input_ids"])

        # Nothing is cached, so positions simply run from 0 to seq_len - 1.
        seq_len = inputs_embeds.shape[1]
        cache_position = torch.arange(0, seq_len, device=inputs_embeds.device)
        position_ids = cache_position.unsqueeze(0)

        causal_mask = model._update_causal_mask(
            encoded["attention_mask"],
            inputs_embeds,
            cache_position,
            past_key_values,
            output_attentions,
        )

        # Position embeddings are computed once and shared by the decoder layers.
        position_embeddings = model.rotary_emb(inputs_embeds, position_ids)

    return {
        "hidden_states": inputs_embeds,
        "attention_mask": causal_mask,
        "position_ids": position_ids,
        "past_key_value": past_key_values,
        "output_attentions": output_attentions,
        "use_cache": use_cache,
        "cache_position": cache_position,
        "position_embeddings": position_embeddings,
    }

In [9]:
def mlp_forward(mlp, x: torch.Tensor):
    """
    Run a gated MLP step by step and keep every intermediate tensor.

    Equivalent to ``mlp.down_proj(mlp.act_fn(mlp.gate_proj(x)) * mlp.up_proj(x))``.

    Returns:
        Dict with ``outputs`` (the down-projection result) and ``debugging``
        (an ordered dict of the intermediate tensors, keyed by step label).
    """
    gate = mlp.gate_proj(x)
    up = mlp.up_proj(x)
    act = mlp.act_fn(gate)        # activation is applied to the gate branch only
    act_up = act * up             # gate the up projection
    down = mlp.down_proj(act_up)  # project back down

    trace = {
        "step 0 (input)": x,
        "step 1 (gate)": gate,
        "step 2 (up)": up,
        "step 3 (activation)": act,
        "step 4 (act_up)": act_up,
        "step 5 (down - output)": down,
    }
    return {"outputs": down, "debugging": trace}

# def mlp_forward(self, hidden_state):
#     return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))

In [10]:
def decoder_forward(
    decoder,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: Optional[bool] = False,
    use_cache: Optional[bool] = False,
    cache_position: Optional[torch.LongTensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
    **kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
    """Run one decoder layer step by step, recording every intermediate state.

    Recorded steps: 1 input, 2 after ``input_layernorm``, 3 after ``self_attn``,
    4 after the first residual add, 5 after ``post_attention_layernorm``,
    6 after ``mlp``, 7 after the second residual add.

    Returns:
        A dict (despite the tuple annotation) with ``outputs`` - the usual
        layer tuple ``(hidden_states[, attn_weights][, present_key_value])`` -
        and ``debugging`` - the recorded steps keyed ``"step 1"`` .. ``"step 7"``.
    """
    trace = {"step 1": hidden_states}
    skip = hidden_states

    normed = decoder.input_layernorm(hidden_states)
    trace["step 2"] = normed

    # Self attention
    attn_out, self_attn_weights, present_key_value = decoder.self_attn(
        hidden_states=normed,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_value=past_key_value,
        output_attentions=output_attentions,
        use_cache=use_cache,
        cache_position=cache_position,
        position_embeddings=position_embeddings,
    )
    trace["step 3"] = attn_out

    after_attn = skip + attn_out
    trace["step 4"] = after_attn

    # Feed-forward part, with its own residual connection
    post_normed = decoder.post_attention_layernorm(after_attn)
    trace["step 5"] = post_normed

    mlp_out = decoder.mlp(post_normed)
    trace["step 6"] = mlp_out

    layer_out = after_attn + mlp_out
    trace["step 7"] = layer_out

    outputs = (layer_out,)
    if output_attentions:
        outputs = outputs + (self_attn_weights,)
    if use_cache:
        outputs = outputs + (present_key_value,)

    return {"outputs": outputs, "debugging": trace}

In [11]:
def model_forward(text, model, tokenizer, num_layers=2):
    """Run the first decoder layers of ``model`` on ``text``, tracing every layer.

    Inputs are prepared with ``init_input`` and each layer is executed through
    ``decoder_forward`` so its intermediate states can be inspected. The final
    ``model.norm`` is applied to the output of the last executed layer.

    Args:
        text: Input text.
        model: Decoder-only base model exposing ``layers`` and ``norm`` plus the
            attributes ``init_input`` relies on.
        tokenizer: Tokenizer matching ``model``.
        num_layers: How many leading decoder layers to run (default 2, the
            previously hard-coded value). ``None`` runs all layers.

    Returns:
        Dict with ``outputs`` (a ``BaseModelOutputWithPast`` whose
        ``hidden_states`` holds the input of every executed layer followed by
        the normalized final state) and ``debugging`` (the per-layer step traces
        keyed ``"layer <i>"``).
    """
    # init_input returns exactly the keyword arguments decoder_forward expects.
    layer_inputs = init_input(text, model, tokenizer)
    hidden_states = layer_inputs.pop("hidden_states")

    layers = model.layers if num_layers is None else model.layers[:num_layers]
    all_hidden_states = ()
    all_decoder_steps = {}

    model.eval()
    with torch.no_grad():
        for i, decoder_layer in enumerate(layers):
            all_hidden_states += (hidden_states,)

            layer_result = decoder_forward(decoder_layer, hidden_states, **layer_inputs)

            # decoder_forward returns a dict; the layer tuple sits under
            # "outputs" and the step trace under "debugging".
            hidden_states = layer_result["outputs"][0]
            all_decoder_steps[f"layer {i}"] = layer_result["debugging"]

        hidden_states = model.norm(hidden_states)

        # add the (normalized) hidden state of the last executed layer
        all_hidden_states += (hidden_states,)

    outputs = BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=(),
        hidden_states=all_hidden_states,
        attentions=()
    )
    return dict(
        outputs=outputs,
        debugging=all_decoder_steps
    )

## modeling

In [16]:
# Scratch to-do list for the masking work, kept as a plain string.
# NOTE(review): this reads like notes rather than data - a markdown cell may fit better.
plan = """
- add vector mask
- add slerp mask
- add mask manager (imposing some constraints on mask.)
"""

In [33]:
# Scratch check: size in bytes of the `torch.Size` object describing the weight's
# shape (this is not the memory used by the weight tensor itself).
nn.Linear(4, 3).weight.shape.__sizeof__()

48

In [9]:
class MaskConfig(PretrainedConfig):
    """Configuration for a single mask.

    Args:
        mode: Name of the mask mode; the ``Mask`` class in this notebook only
            accepts ``"scalar"``.
        value: Initial value of the mask.
        size: Shape the mask is expected to broadcast to.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """
    def __init__(
        self,
        mode: str = None,
        value: Union[float, torch.Tensor] = None,
        size: torch.Size = None,
        **kwargs,
    ):
        # Store the mask-specific fields, then hand the remaining kwargs to the base class.
        self.mode = mode
        self.value = value
        self.size = size
        super().__init__(**kwargs)

class Mask(nn.Module):
    """Learnable multiplicative mask.

    Only ``mode == "scalar"`` is supported: the mask is a single trainable
    number that multiplies whatever tensor is passed to ``forward``.

    Args:
        mask_config: Object exposing ``mode``, ``value`` (initial mask value,
            defaults to 1 when None) and ``size`` (optional shape the mask is
            expected to broadcast to).

    Raises:
        ValueError: If the mode is unsupported or the mask cannot broadcast to
            ``size``.
    """

    def __init__(
        self, 
        mask_config: "MaskConfig"
    ):
        super().__init__()
        self.mode = mask_config.mode
        if mask_config.mode == "scalar":
            value = mask_config.value if mask_config.value is not None else 1
            if isinstance(value, torch.Tensor):
                initial = value.detach().clone()
            else:
                initial = torch.tensor(value)
            # nn.Parameter requires a floating (or complex) dtype; an integer
            # value such as the default 1 would otherwise be rejected.
            if not (initial.is_floating_point() or initial.is_complex()):
                initial = initial.to(torch.get_default_dtype())
            self.weight = nn.Parameter(initial)
        else:
            raise ValueError(f"Unsupported mask mode: {mask_config.mode}")

        self.size = mask_config.size ## Full size of the mask after broadcast.
        if self.size is not None:
            try:
                # Shape-only compatibility check: no tensor of `size` is allocated.
                torch.broadcast_shapes(self.weight.shape, self.size)
            except RuntimeError as err:
                raise ValueError("Mask initialized with an incompatible shape.") from err

    def forward(self, x):
        """
        Return ``self.weight * x``.

        The operand order matters for tests: the masking is deliberately written
        as ``self.weight * x`` rather than ``x * self.weight``. Neither form is
        superior, but the two can differ by tiny numerical amounts, so staying
        consistent avoids spurious ``torch.testing.assert_close`` failures.

        A shape mismatch between ``x`` and the configured ``size`` is only
        reported with a printed message; the product is still returned.
        """
        if self.size is not None and self.size != x.shape:
            print("The shape of input does not match that of the mask.")
        return self.weight * x

In [9]:
class LinearWithMask(nn.Module):
    """Wrap a linear layer so its weight (and optionally bias) pass through a ``Mask``.

    The mask configs' ``size`` fields are overwritten with the layer's actual
    weight/bias shapes when they differ (a message is printed in that case).
    A bias mask is only created when the layer has a bias and a bias config is given.
    """

    def __init__(self, linear, weight_mask_config: MaskConfig, bias_mask_config: MaskConfig = None):
        super().__init__()
        self.linear = linear
        self.weight_mask_config = weight_mask_config
        self.bias_mask_config = bias_mask_config

        weight_shape = linear.weight.shape
        if weight_shape != weight_mask_config.size:
            print("Weight mask shape is not compatible with linear, reinitializing...")
            self.weight_mask_config.size = weight_shape
        self.weight_mask = Mask(self.weight_mask_config)

        bias_mask = None
        if linear.bias is not None and bias_mask_config is not None:
            bias_shape = linear.bias.shape
            if bias_shape != bias_mask_config.size:
                print("Bias mask shape is not compatible with linear, reinitializing...")
                self.bias_mask_config.size = bias_shape
            bias_mask = Mask(self.bias_mask_config)
        self.bias_mask = bias_mask

    def forward(self, x):
        """Apply the linear map using the masked weight and (if masked) bias."""
        weight = self.weight_mask(self.linear.weight)
        bias = self.linear.bias
        if bias is not None and self.bias_mask is not None:
            bias = self.bias_mask(bias)
        return nn.functional.linear(x, weight, bias)

class LinearsWithMasks(nn.Module):
    """Merge several ``nn.Linear`` layers on the fly through per-layer masks.

    Every wrapped layer gets its own weight mask (and a bias mask when it has a
    bias). ``forward`` sums the masked weights and biases and applies a single
    linear operation with the merged parameters.

    Args:
        linears: The ``nn.Linear`` layers to merge.
        weight_modes: Mask mode per layer; a single entry is applied to every layer.
        weight_values: Initial weight-mask value per layer (required).
        bias_modes: Mask mode per layer for the bias; a single entry is applied
            to every layer.
        bias_values: Initial bias-mask value per layer; ``None`` lets each mask
            use its default.

    Raises:
        ValueError: If ``linears`` contains non-``nn.Linear`` entries or any of
            the per-layer lists has an incompatible length.
    """

    def __init__(
        self,
        linears: List[nn.Module],
        weight_modes: List[str] = ["scalar"],
        weight_values: List[float] = None,
        bias_modes: List[str] = ["scalar"],
        bias_values: List[float] = None,
    ):
        super().__init__()

        if not all(isinstance(linear, nn.Linear) for linear in linears):
            raise ValueError("All elements in 'linears' must be instances of nn.Linear.")

        num_linears = len(linears)
        if weight_values is None or len(weight_values) != num_linears:
            raise ValueError(f"weight_values for masks: {weight_values} do not match with linear layers: {linears}")
        if bias_values is None:
            bias_values = [None] * num_linears
        if len(bias_values) != num_linears:
            raise ValueError(f"bias_values for masks: {bias_values} do not match with linear layers: {linears}")

        # zip() below stops at the shortest input, so a too-short modes list
        # (including the one-element default) would silently drop layers.
        weight_modes = self._expand_modes(weight_modes, num_linears, "weight_modes")
        bias_modes = self._expand_modes(bias_modes, num_linears, "bias_modes")

        weight_sizes = [linear.weight.shape for linear in linears]
        bias_sizes = [linear.bias.shape if linear.bias is not None else None for linear in linears]

        weight_mask_configs = [
            MaskConfig(mode, value, size)
            for mode, value, size in zip(weight_modes, weight_values, weight_sizes)
        ]
        bias_mask_configs = [
            MaskConfig(mode, value, size) if size is not None else None
            for mode, value, size in zip(bias_modes, bias_values, bias_sizes)
        ]

        self.masked_linears = nn.ModuleList(
            [LinearWithMask(linear, weight_mask_config, bias_mask_config)
             for linear, weight_mask_config, bias_mask_config
             in zip(linears, weight_mask_configs, bias_mask_configs)]
        )

    @staticmethod
    def _expand_modes(modes, count, label):
        """Return one mode per layer, repeating a single entry ``count`` times."""
        modes = list(modes)
        if len(modes) == 1:
            return modes * count
        if len(modes) != count:
            raise ValueError(f"{label}: {modes} do not match the number of linear layers ({count}).")
        return modes

    def forward(self, x):
        """Apply one linear op whose weight/bias are the sums of the masked ones."""
        weights = [masked.weight_mask(masked.linear.weight)
                   for masked in self.masked_linears]
        merged_weight = sum(weights)

        biases = [
            masked.bias_mask(masked.linear.bias)
            if masked.linear.bias is not None and masked.bias_mask is not None
            else masked.linear.bias
            for masked in self.masked_linears
        ]

        if all(b is None for b in biases):
            merged_bias = None
        else:
            # Layers without a bias contribute zeros of length out_features.
            biases = [
                b if b is not None
                else torch.zeros_like(weights[0][:, 0])
                for b in biases
            ]
            merged_bias = sum(biases)

        return nn.functional.linear(x, merged_weight, merged_bias)

In [10]:
class RMSNormWithMask(nn.Module):
    """RMSNorm whose scale vector is passed through a ``Mask`` before use.

    The mask config's ``size`` is always set to the norm weight's shape; a
    message is printed when it differed beforehand.
    """

    def __init__(self, rms_norm: Qwen2RMSNorm, mask_config: MaskConfig):
        super().__init__()
        self.rms_norm = rms_norm
        self.mask_config = mask_config

        norm_shape = rms_norm.weight.shape
        if norm_shape != mask_config.size:
            print("Mask shape is not compatible with RMSNorm, reinitializing...")
        self.mask_config.size = norm_shape
        self.mask = Mask(self.mask_config)

    def forward(self, hidden_states):
        """Normalize in float32, cast back to the input dtype, scale by the masked weight."""
        scale = self.mask(self.rms_norm.weight)
        original_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalized = upcast * torch.rsqrt(mean_square + self.rms_norm.variance_epsilon)
        return scale * normalized.to(original_dtype)

class RMSNormsWithMasks(nn.Module):
    """Merge several RMSNorm layers on the fly through per-layer masks.

    ``forward`` sums the masked scale vectors of all wrapped norms and applies a
    single RMS normalization with the merged scale.

    Args:
        rms_norms: The RMSNorm layers to merge.
        modes: Mask mode per layer; a single entry is applied to every layer.
        values: Initial mask value per layer (required).

    Raises:
        ValueError: If ``values`` or ``modes`` has an incompatible length.
    """

    def __init__(
        self,
        rms_norms: List["Qwen2RMSNorm"],
        modes: List[str] = ["scalar"],
        values: List[float] = None
    ):
        super().__init__()
        num_norms = len(rms_norms)
        if values is None or len(values) != num_norms:
            raise ValueError(f"values for masks: {values} do not match with RMSNorm layers: {rms_norms}")

        # zip() below stops at the shortest input, so a too-short modes list
        # (including the one-element default) would silently drop layers.
        modes = list(modes)
        if len(modes) == 1:
            modes = modes * num_norms
        elif len(modes) != num_norms:
            raise ValueError(f"modes: {modes} do not match the number of RMSNorm layers ({num_norms}).")

        sizes = [rms_norm.weight.shape for rms_norm in rms_norms]
        mask_configs = [
            MaskConfig(mode, value, size)
            for mode, value, size in zip(modes, values, sizes)
        ]
        self.masked_rms_norms = nn.ModuleList(
            [RMSNormWithMask(rms_norm, mask_config)
             for rms_norm, mask_config in zip(rms_norms, mask_configs)]
        )

    def forward(self, hidden_states):
        """Normalize in float32, cast back, and scale by the merged masked weight."""
        weights = [rms.mask(rms.rms_norm.weight) for rms in self.masked_rms_norms]
        merged_weight = sum(weights)

        # A single epsilon is used for the merged norm, so all norms must agree.
        variance_epsilon = self.masked_rms_norms[0].rms_norm.variance_epsilon
        for rms in self.masked_rms_norms:
            assert variance_epsilon == rms.rms_norm.variance_epsilon, "Variance epsilon among models must be consistent"

        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
        return merged_weight * hidden_states.to(input_dtype)

In [11]:
class EmbeddingWithMask(nn.Module):
    """Embedding layer whose weight matrix is passed through a ``Mask`` on lookup.

    The mask config's ``size`` is always set to the embedding weight's shape; a
    message is printed when it differed beforehand.
    """

    def __init__(self, embedding: nn.Embedding, mask_config: MaskConfig):
        super().__init__()
        self.embedding = embedding
        self.mask_config = mask_config

        table_shape = embedding.weight.shape
        if table_shape != mask_config.size:
            print("Mask shape is not compatible with Embedding, reinitializing...")
        self.mask_config.size = table_shape
        self.mask = Mask(self.mask_config)

    def forward(self, input_ids):
        """Look up ``input_ids`` in the masked table, reusing the wrapped layer's options."""
        source = self.embedding
        lookup_options = dict(
            padding_idx=source.padding_idx,
            max_norm=source.max_norm,
            norm_type=source.norm_type,
            scale_grad_by_freq=source.scale_grad_by_freq,
            sparse=source.sparse,
        )
        masked_table = self.mask(source.weight)
        return nn.functional.embedding(input_ids, masked_table, **lookup_options)

class EmbeddingsWithMasks(nn.Module):
    """Merge several embedding tables on the fly through per-table masks.

    ``forward`` sums the masked weight matrices of all wrapped embeddings and
    performs one lookup in the merged table.

    Args:
        embeddings: The ``nn.Embedding`` layers to merge.
        modes: Mask mode per layer; a single entry is applied to every layer.
        values: Initial mask value per layer (required).

    Raises:
        ValueError: If ``values`` or ``modes`` has an incompatible length.
    """

    def __init__(
        self,
        embeddings: List[nn.Embedding],
        modes: List[str] = ["scalar"],
        values: List[float] = None
    ):
        super().__init__()
        num_embeddings = len(embeddings)
        if values is None or len(values) != num_embeddings:
            raise ValueError(f"values for masks: {values} do not match with Embedding layers: {embeddings}")

        # zip() below stops at the shortest input, so a too-short modes list
        # (including the one-element default) would silently drop layers.
        modes = list(modes)
        if len(modes) == 1:
            modes = modes * num_embeddings
        elif len(modes) != num_embeddings:
            raise ValueError(f"modes: {modes} do not match the number of Embedding layers ({num_embeddings}).")

        sizes = [embedding.weight.shape for embedding in embeddings]
        mask_configs = [
            MaskConfig(mode, value, size)
            for mode, value, size in zip(modes, values, sizes)
        ]
        self.masked_embeddings = nn.ModuleList(
            [EmbeddingWithMask(embedding, mask_config)
             for embedding, mask_config in zip(embeddings, mask_configs)]
        )

    def forward(self, input_ids):
        """Look up ``input_ids`` in the sum of the masked embedding tables."""
        weights = [emb.mask(emb.embedding.weight) for emb in self.masked_embeddings]
        merged_weight = sum(weights)

        # The lookup options of the first embedding are used for the merged
        # table, so every wrapped embedding must share them.
        an_embedding = self.masked_embeddings[0].embedding
        for other in self.masked_embeddings:
            other_embedding = other.embedding
            assert an_embedding.padding_idx == other_embedding.padding_idx
            assert an_embedding.max_norm == other_embedding.max_norm
            assert an_embedding.norm_type == other_embedding.norm_type
            assert an_embedding.scale_grad_by_freq == other_embedding.scale_grad_by_freq
            assert an_embedding.sparse == other_embedding.sparse

        return nn.functional.embedding(
            input_ids,
            merged_weight,
            padding_idx=an_embedding.padding_idx,
            max_norm=an_embedding.max_norm,
            norm_type=an_embedding.norm_type,
            scale_grad_by_freq=an_embedding.scale_grad_by_freq,
            sparse=an_embedding.sparse,
        )

In [16]:
class MergerConfig(PretrainedConfig):
    """Configuration for ``Merger``: the list of checkpoints to merge.

    Args:
        model_paths: Paths of the models to merge, in order. ``Merger`` passes
            each one to ``AutoConfig.from_pretrained`` and
            ``AutoModelForCausalLM.from_pretrained``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """
    def __init__(
        self,
        model_paths: Optional[List[str]] = None,
        **kwargs,
    ):
        # Store the custom field first, then delegate to the base constructor.
        self.model_paths = model_paths
        super().__init__(**kwargs)

class Merger(PreTrainedModel):
    """Loads several causal LMs and builds a mask-merged copy of the first one.

    Construction only loads the source models (in bfloat16). The merged model
    ``self.merger`` is created by calling ``__post_init__`` explicitly
    afterwards; ``__init__`` does not call it.

    Attributes:
        merge_config: The ``MergerConfig`` passed to the constructor.
        num_models: Number of entries in ``merge_config.model_paths``.
        configs: One ``AutoConfig`` per model path.
        models: ``nn.ModuleList`` of the loaded source models.
        merger: Deep copy of ``models[0]`` with masks installed; only present
            after ``__post_init__`` has run.
    """

    def __init__(self, merge_config):
        """Load the config and weights of every model in ``merge_config``.

        Args:
            merge_config: Configuration whose ``model_paths`` lists the
                checkpoints to load.

        Raises:
            ValueError: If ``merge_config.model_paths`` is ``None`` or empty.
        """
        super().__init__(merge_config)
        # TODO: check that the models are mergeable (i.e. share a compatible
        # config) before loading them.
        model_paths = merge_config.model_paths
        if not model_paths:
            # Fail here with a clear message instead of a TypeError from
            # len(None) or a later IndexError in __post_init__.
            raise ValueError(
                "merge_config.model_paths must list at least one model path, "
                f"got: {model_paths!r}"
            )

        self.merge_config = merge_config
        self.num_models = len(model_paths)
        self.configs = [AutoConfig.from_pretrained(path) for path in model_paths]
        self.models = nn.ModuleList([
            AutoModelForCausalLM.from_pretrained(
                path,
                config=config,
                torch_dtype=torch.bfloat16
            )
            for path, config in zip(model_paths, self.configs)
        ])

    def __post_init__(self, merge_config, strategy="naive"):
        """Freeze the source models and build the masked ``self.merger``.

        Must be called manually after construction.

        Args:
            merge_config: Accepted for interface compatibility; not used here.
            strategy: Forwarded to ``initialize_masks``.
        """
        # Source models are never trained; disable gradients on all of them.
        for model in self.models:
            model.requires_grad_(False)

        # Start from a full copy of the first model, then let initialize_masks
        # (defined elsewhere in this notebook) install the masks that reference
        # all source models.
        self.merger = copy.deepcopy(self.models[0])
        initialize_masks(self.merger, self.models, strategy=strategy)
        free_memory()

    def forward(self, tensor, labels=None):
        """Placeholder: not implemented yet, always returns ``None``."""
        # TODO: implement the forward pass (e.g. through self.merger).
        return None

In [18]:
def find_mask_parameter_names(module, mask_param_names_list, parent_name=""):
    """
    Walk ``module`` depth-first and collect the dotted names of every parameter
    owned by a submodule whose class is named "Mask".

    Names are appended to ``mask_param_names_list`` in place (nothing is
    returned). ``parent_name`` is the dotted prefix of ``module`` itself and is
    prepended to every collected name when non-empty.
    """
    prefix = f"{parent_name}." if parent_name else ""
    for child_name, submodule in module.named_children():
        qualified = prefix + child_name
        if type(submodule).__name__ == "Mask":
            mask_param_names_list.extend(
                f"{qualified}.{leaf}" for leaf, _ in submodule.named_parameters()
            )
        # Always descend, whether or not this child was a Mask itself.
        find_mask_parameter_names(submodule, mask_param_names_list, qualified)
        
def set_masks(merger, strategy="slerp"):
    """Overwrite the mask parameters of ``merger`` according to ``strategy``.

    Supported strategies:
        * ``"select_<i>"``: every mask parameter whose model index is ``i`` is
          filled with 1.0 and every other mask parameter with 0.0. The model
          index is read from the third-from-last component of the parameter's
          dotted name (e.g. ``...masked_linears.0.weight_mask.weight``).
        * ``"slerp"``: not implemented yet; currently a no-op.

    Args:
        merger: Module that exposes ``models`` and contains the "Mask"
            submodules to update.
        strategy: Strategy name as described above.

    Raises:
        AssertionError: If the index in ``"select_<i>"`` is outside
            ``[0, len(merger.models))``.
        ValueError: If ``strategy`` is not recognised, or the index part of
            ``"select_<i>"`` is not an integer.
    """
    if strategy.startswith("select_"):
        selected_idx = int(strategy[len("select_"):])
        # A negative index used to pass the check and zero out every mask.
        assert 0 <= selected_idx < len(merger.models), (
            f"selected index {selected_idx} is out of range for "
            f"{len(merger.models)} models"
        )
        mask_param_names = []
        find_mask_parameter_names(merger, mask_param_names)

        with torch.no_grad():
            for param_name in tqdm(
                mask_param_names,
                desc=f"Making self.merger identical to self.models[{selected_idx}]"
            ):
                # Resolve the parameter object from its full dotted name.
                param = merger.get_parameter(param_name)
                idx = int(param_name.split(".")[-3])
                value = 1.0 if idx == selected_idx else 0.0
                param.data.fill_(value)
    elif strategy == "slerp":
        # TODO: implement; kept as a no-op so existing default calls still work.
        pass
    else:
        # An unknown strategy used to leave the masks untouched without notice.
        raise ValueError(
            f"Unknown strategy {strategy!r}; expected 'select_<index>' or 'slerp'"
        )

In [19]:
# Checkpoints to merge, in order. Index 0 is the one Merger.__post_init__
# deep-copies as the template for the merged model.
# NOTE(review): absolute, machine-specific paths -- adjust for your environment.
merge_config = MergerConfig(
    model_paths = [
        "/workspace/models/Arcee-VyLinh/",
        "/workspace/models/Qwen2.5-Coder-3B/"
    ]
)
merge_config

MergerConfig {
  "model_paths": [
    "/workspace/models/Arcee-VyLinh/",
    "/workspace/models/Qwen2.5-Coder-3B/"
  ],
  "transformers_version": "4.46.3"
}

In [20]:
# Loads the config and weights of every checkpoint listed in merge_config.
merger = Merger(merge_config)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [21]:
# Move the wrapper, including all loaded source models, to the first GPU.
merger = merger.to("cuda:0")

In [22]:
# Called manually (Merger.__init__ does not call it): freezes the source models,
# deep-copies models[0] into merger.merger and initializes the masks with the
# default strategy "naive".
merger.__post_init__(merge_config)

Initializing masks: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 327/327 [01:03<00:00,  5.13it/s]
2024-12-16 10:30:20,016 - INFO - Initial GPU memory allocated: 11.71 GB
2024-12-16 10:30:20,489 - INFO - Final GPU memory allocated: 11.71 GB
2024-12-16 10:30:20,490 - INFO - Freed GPU memory: 0.00 GB


In [24]:
# "select_0": fill every mask whose model index is 0 with 1.0 and all others with 0.0.
set_masks(merger, strategy="select_0")

Making self.merger identical to self.models[0]: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 870/870 [00:00<00:00, 33242.64it/s]


In [26]:
# Spot-check one mask with model index 0; after "select_0" it should hold 1.0.
merger.merger.model.layers[0].mlp.gate_proj.masked_linears[0].weight_mask.weight

Parameter containing:
tensor(1., requires_grad=True)

In [27]:
# Use the tokenizer shipped with the first checkpoint.
tokenizer = AutoTokenizer.from_pretrained(merge_config.model_paths[0])

In [28]:
# Build a chat-formatted prompt (Vietnamese).
# system: "You are a helpful assistant."
# prompt: "Which country does Hoàng Sa belong to?"
system = "Bạn là một trợ lý hữu ích."
prompt = "Hoàng Sa là của nước nào?"
messages = [
    {"role": "system", "content": system},
    {"role": "user", "content": prompt}
]
# tokenize=False returns the rendered template as a string;
# add_generation_prompt=True appends the opening of the assistant turn.
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
text

'<|im_start|>system\nBạn là một trợ lý hữu ích.<|im_end|>\n<|im_start|>user\nHoàng Sa là của nước nào?<|im_end|>\n<|im_start|>assistant\n'

In [29]:
# Generate with the merged model (generate is defined elsewhere in this notebook).
answer = generate(text, merger.merger, tokenizer, max_new_tokens=100)



Hoàng Sa và Trường Sa đều nằm trong quần thể đảo ngọc Trường Sa, thuộc chủ quyền của Việt Nam. Đây là vùng biển quan trọng với nhiều tài nguyên thiên nhiên quý giá và có ý nghĩa chiến lược địa chính trị đối với Việt Nam.<|im_end|>


From v4.47 onwards, when a model cache is to be returned, `generate` will return a `Cache` instance instead by default (as opposed to the legacy tuple of tuples format). If you want to keep returning the legacy format, please set `return_legacy_cache=True`.


In [30]:
# Same prompt through the unmerged source model 0, for comparison with the cell above.
# NOTE(review): max_new_tokens differs from the merged-model run (512 vs 100).
answer = generate(text, merger.models[0], tokenizer, max_new_tokens=512)

Hoàng Sa và Trường Sa đều nằm trong quần thể đảo ngọc Trường Sa, thuộc chủ quyền của Việt Nam. Đây là vùng biển quan trọng với nhiều tài nguyên thiên nhiên quý giá và có ý nghĩa chiến lược địa chính trị đối với Việt Nam.<|im_end|>


In [31]:
# Compare the merged model against source model 0 on the same prompt
# (get_logits is defined elsewhere in this notebook).
logits_merged = get_logits(text, merger.merger, tokenizer)
logits_0 = get_logits(text, merger.models[0], tokenizer)
# With atol=0 and rtol=0 this only passes when the logits are exactly equal.
torch.allclose(logits_merged, logits_0, atol=0, rtol=0)

True