In [3]:
plan = """
1. prepare data
2. define model
    - model a, mask a
    - model b, mask b
make only masks trainable.
make sure it has correct inference.
3. define loss function in Trainer.
"""

## prepare data

## modeling

In [1]:
import math
from typing import List, Optional, Tuple, Union

import datasets
import torch
import numpy as np
import torch.nn as nn
from datasets import load_dataset
import logging
import copy

from transformers import (
    PreTrainedModel,
    PretrainedConfig,
    AutoConfig,
    AutoModelForCausalLM,
    LlamaForCausalLM,
    LlamaConfig,
    Trainer,
    TrainingArguments,
    AutoTokenizer,
    HfArgumentParser
)

from modeling_qwen2 import (
    Qwen2RMSNorm, 
    Qwen2RotaryEmbedding, 
    Qwen2MLP, 
    Qwen2Attention, 
    Qwen2FlashAttention2, 
    Qwen2SdpaAttention, 
    Qwen2DecoderLayer, 
    Qwen2PreTrainedModel, 
    Qwen2Model, 
    Qwen2ForCausalLM
)

from configuration_qwen2 import Qwen2Config

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

## attempt 1

In [3]:
class MergerConfig(PretrainedConfig):
    def __init__(
        self,
        model_paths: List[str] = None,
        **kwargs,
    ):
        self.model_paths = model_paths
        super().__init__(**kwargs)

In [4]:
merge_config = MergerConfig(
    model_paths = [
        "/workspace/models/Arcee-VyLinh/",
        "/workspace/models/Qwen2.5-Coder-3B/"
    ]
)

merge_config

MergerConfig {
  "model_paths": [
    "/workspace/models/Arcee-VyLinh/",
    "/workspace/models/Qwen2.5-Coder-3B/"
  ],
  "transformers_version": "4.46.3"
}

In [4]:
merge_config.model_paths

['/workspace/models/Arcee-VyLinh/', '/workspace/models/Qwen2.5-Coder-3B/']

In [5]:
class Merger(PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.tokenizer = AutoTokenizer.from_pretrained(
            config.model_paths[0]
        )
        
        self.models = nn.ModuleList([
            Qwen2ForCausalLM.from_pretrained(
            # AutoModelForCausalLM.from_pretrained(
                model_path,
                # torch_dtype=torch.bfloat16,
                # device_map={"":0}
            ) for model_path in config.model_paths
        ])
        self.__post_init__()
        
    def __post_init__(self):
        # self.masks = torch.nn
        pass
        
    def forward(self, tensor, labels=None):
        """
        activations = []
        for i in range(num_layers):
            L1 = models[0].layers[i]
            L2 = models[1].layers[i]
            Lm = alpha * L1 + beta * L2
            h1 = L1(h)
            h2 = L2(h)
            h = Lm(h)
            activations.append({
                "1": h1, "2": h2, "merged": copy(h)
            })
        """

        """
        - embed_tokens
        - norm
        - layers
            - input_layernorm
            - self_attn
            - mlp
            - post_attention_norm
        - lm_head
        """
        pass

In [6]:
merger = Merger(merge_config)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [8]:
merger.models[0].device

device(type='cpu')

In [9]:
attn1 = merger.models[0].model.layers[0].self_attn
attn2 = merger.models[1].model.layers[0].self_attn
mlp1 = merger.models[0].model.layers[0].mlp
mlp2 = merger.models[1].model.layers[0].mlp

In [10]:
attn1.q_proj.weight

Parameter containing:
tensor([[-0.0159, -0.0432, -0.0080,  ...,  0.0081,  0.0096,  0.0132],
        [-0.0330,  0.0110,  0.0085,  ...,  0.0226, -0.0082,  0.0457],
        [-0.0092,  0.0111, -0.0134,  ...,  0.0298,  0.0113, -0.0038],
        ...,
        [-0.0085,  0.0601, -0.0325,  ...,  0.0525, -0.0222,  0.0403],
        [-0.0374, -0.0325,  0.0620,  ..., -0.0206,  0.0806,  0.0376],
        [ 0.0356,  0.0151,  0.0087,  ..., -0.0306, -0.0072,  0.0378]],
       requires_grad=True)

In [11]:
import torch
import torch.nn as nn
import copy
from typing import List, Dict

def merge_linear(weights: List[nn.Linear], factors: List[float]) -> nn.Linear:
    """
    Merges multiple linear layers by taking a weighted average of their weights and biases.

    Args:
        weights: A list of nn.Linear layers to merge.
        factors: A list of scaling factors corresponding to each layer in 'weights'.

    Returns:
        A new nn.Linear layer that is the weighted average of the input layers.

    Raises:
        ValueError: If the number of weights and factors don't match, or if the
                    layers have incompatible dimensions, device or dtype.
    """
    if len(weights) != len(factors):
        raise ValueError("The number of weights and factors must be equal.")

    # Check for compatibility, device, and dtype
    device = weights[0].weight.device
    dtype = weights[0].weight.dtype
    if not all(
        w.in_features == weights[0].in_features
        and w.out_features == weights[0].out_features
        and w.weight.device == device
        and w.weight.dtype == dtype
        for w in weights
    ):
        raise ValueError(
            "Incompatible linear layers for merging. They must have the same in_features, out_features, device, and dtype."
        )

    # Create a new linear layer with the same dimensions, device and dtype
    merged_linear = nn.Linear(
        in_features=weights[0].in_features,
        out_features=weights[0].out_features,
        bias=False,
        device=device,
        dtype=dtype
    )

    # Calculate the merged weight and bias
    merged_weight = torch.zeros_like(weights[0].weight)
    merged_bias = (
        torch.zeros_like(weights[0].bias, device=device, dtype=dtype)
        if weights[0].bias is not None
        else None
    )

    for i, w in enumerate(weights):
        merged_weight += factors[i] * w.weight
        if w.bias is not None:
            if merged_bias is None:
                raise ValueError("Cannot merge linear layers if only some have biases.")
            merged_bias += factors[i] * w.bias

    # Assign the merged weight and bias to the new layer
    with torch.no_grad():
        merged_linear.weight.copy_(merged_weight)
        if merged_bias is not None:
            merged_linear.bias = nn.Parameter(merged_bias)

    return merged_linear

def merge_module_recursive(
    target_module: nn.Module, modules_dict: Dict[str, List[nn.Module]], factors: List[float]
):
    """
    Recursively merges multiple modules by taking a weighted average of their Linear layer weights and biases.

    Args:
        target_module: The target module where the merged weights will be stored.
        modules_dict: A dictionary where keys are module names and values are lists of modules to merge.
        factors: A list of scaling factors corresponding to each list of modules.
    """

    for name, module in target_module.named_modules():
        if isinstance(module, nn.Linear):
            if name not in modules_dict:
                raise ValueError(
                    f"Missing module {name} in modules_dict. Make sure all linear layer weights are provided"
                )
            merged_linear = merge_linear(modules_dict[name], factors)
            # Find the parent module
            parent_module_name = ".".join(name.split(".")[:-1])
            layer_name = name.split(".")[-1]

            if parent_module_name:
                parent_module = target_module.get_submodule(parent_module_name)
            else:
                parent_module = target_module

            # Replace the original Linear layer with merged one
            setattr(parent_module, layer_name, merged_linear)

def merge_modules(modules: List[nn.Module], factors: List[float]) -> nn.Module:
    """
    Merges multiple modules by taking a weighted average of their Linear layer weights and biases.
    The merged weights are stored into a deepcopy of the first module in the list.

    Args:
        modules: A list of nn.Modules to merge.
        factors: A list of scaling factors corresponding to each module in 'modules'.

    Returns:
        A new nn.Module that is the weighted average of the input modules.

    Raises:
        ValueError: If the number of modules and factors don't match.
    """
    if len(modules) != len(factors):
        raise ValueError("The number of modules and factors must be equal.")

    # Check device and dtype consistency across all modules
    device = modules[0].parameters().__next__().device
    dtype = modules[0].parameters().__next__().dtype
    if not all(p.device == device and p.dtype == dtype for module in modules for p in module.parameters()):
        raise ValueError("All modules must be on the same device and have the same dtype.")

    # Create a deep copy of the first module to store the merged weights
    merged_module = copy.deepcopy(modules[0])

    # Dictionary to hold corresponding linear layers from each module
    modules_to_merge = {
        name: [] for name, _ in merged_module.named_modules() if isinstance(_, nn.Linear)
    }
    for module in modules:
        for name, layer in module.named_modules():
            if isinstance(layer, nn.Linear):
                modules_to_merge[name].append(layer)

    # Merge the modules recursively
    merge_module_recursive(merged_module, modules_to_merge, factors)

    # Ensure the merged module has the correct device and dtype
    merged_module.to(device=device, dtype=dtype)

    return merged_module

In [12]:
merged_mlp = merge_modules(modules=[mlp1, mlp2], factors=[0.5, 0.5])

In [32]:
Qwen2MLP.forward??

[0;31mSignature:[0m [0mQwen2MLP[0m[0;34m.[0m[0mforward[0m[0;34m([0m[0mself[0m[0;34m,[0m [0mhidden_state[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m
Define the computation performed at every call.

Should be overridden by all subclasses.

.. note::
    Although the recipe for forward pass needs to be defined within
    this function, one should call the :class:`Module` instance afterwards
    instead of this since the former takes care of running the
    registered hooks while the latter silently ignores them.
[0;31mSource:[0m   
    [0;32mdef[0m [0mforward[0m[0;34m([0m[0mself[0m[0;34m,[0m [0mhidden_state[0m[0;34m)[0m[0;34m:[0m[0;34m[0m
[0;34m[0m        [0;32mreturn[0m [0mself[0m[0;34m.[0m[0mdown_proj[0m[0;34m([0m[0mself[0m[0;34m.[0m[0mact_fn[0m[0;34m([0m[0mself[0m[0;34m.[0m[0mgate_proj[0m[0;34m([0m[0mhidden_state[0m[0;34m)[0m[0;34m)[0m [0;34m*[0m [0mself[0m[0;34m.[0m[0mup_proj[0m[0;34m([0

In [13]:
def forward_merged_mlp(x: torch.Tensor, modules: List[nn.Module], factors: List[float]) -> torch.Tensor:
    """
    Performs a forward pass that simulates the behavior of a merged MLP module.

    Args:
        modules: A list of MLP modules (e.g., Qwen2MLP instances).
        factors: A list of scaling factors corresponding to each module in 'modules'.
        x: The input tensor.

    Returns:
        The output tensor after the forward pass.
    """
    factors = torch.tensor(factors).to(x.device, dtype=x.dtype).view(-1, 1, 1, 1)  # Reshape factors for broadcasting
    gate_output = torch.stack([m.gate_proj(x) for m in modules]).mul(factors).sum(0)
    up_output = torch.stack([m.up_proj(x) for m in modules]).mul(factors).sum(0)
    act_output = modules[0].act_fn(gate_output)  # Assuming all modules have the same activation function
    result = torch.stack([m.down_proj(act_output * up_output) for m in modules]).mul(factors).sum(0)
    return result

In [15]:
import torch
device = merged_mlp.parameters().__next__().device
dtype = merged_mlp.parameters().__next__().dtype
x = torch.rand(1, 4, 2048).to(device, dtype=dtype)
o1 = merged_mlp(x)
o2 = forward_merged_mlp(x, modules=[mlp1, mlp2], factors=[0.5, 0.5])

In [18]:
torch.testing.assert_close(o1, o2)

In [58]:
modules[0].down_proj.weight
gate_output.dtype

torch.float32

In [38]:
import torch
device = "cuda:0"
h = torch.rand(1, 4, 2048, dtype=torch.bfloat16).to(device)
p = torch.arange(4, dtype=torch.bfloat16, device=device).unsqueeze(0)
attn1.forward(h, position_ids=p)

(tensor([[[ 0.3750, -0.3301, -0.0016,  ..., -0.0093, -0.2695,  0.0986],
          [ 0.2930, -0.2949,  0.0933,  ...,  0.0483, -0.2949, -0.0649],
          [ 0.2676, -0.3789,  0.1973,  ...,  0.0134, -0.7227,  0.1367],
          [ 0.2715, -0.2988, -0.0547,  ..., -0.1777, -0.7695,  0.0850]]],
        device='cuda:0', dtype=torch.bfloat16, grad_fn=<UnsafeViewBackward0>),
 None,
 None)

In [50]:
def count_parameters(model, param_bits):
    total_params = 0
    trainable_params = 0
    non_trainable_params = 0
    
    for param in model.parameters():
        num_params = param.numel()  # Get the number of elements in the parameter
        total_params += num_params
        if param.requires_grad:
            trainable_params += num_params
        else:
            non_trainable_params += num_params

    total_gigabytes = total_params * (param_bits / 8) / (1024**3)
    memory = f"{total_gigabytes:.2f} GB"
    
    return total_params, memory

In [55]:
count_parameters(attn1, 16), count_parameters(mlp1, 16)

((9439744, '0.02 GB'), (67633152, '0.13 GB'))

## attempt 2

In [2]:
from utils import are_tokenizers_same
are_tokenizers_same(
    paths = [
        "/workspace/models/Arcee-VyLinh/",
        "/workspace/models/Qwen2.5-Coder-3B/"
    ]
)

2024-12-13 09:13:21,147 - INFO - Comparing tokenizer at /workspace/models/Arcee-VyLinh/ with tokenizer at /workspace/models/Qwen2.5-Coder-3B/
2024-12-13 09:13:21,151 - INFO - Tokenizer at /workspace/models/Arcee-VyLinh/ and /workspace/models/Qwen2.5-Coder-3B/ are the same based on the defined criteria


True

In [3]:
class MaskConfig(PretrainedConfig):
    def __init__(
        self,
        mode: str = None,
        value: Union[float, torch.Tensor] = None,
        size: torch.Size = None,
        **kwargs,
    ):
        self.mode = mode
        self.value = value
        self.size = size
        super().__init__(**kwargs)

class Mask(nn.Module):
    def __init__(
        self, 
        mask_config: MaskConfig
    ):
        super().__init__()
        """
        now only support mode == scalar
        """
        self.mode = mask_config.mode
        if mask_config.mode == "scalar":
            value = mask_config.value if mask_config.value is not None else 1
            self.weight = nn.Parameter(torch.tensor(value)) # Corrected typo here
        else:
            raise ValueError(f"Unsupported mask mode: {mask_config.mode}")
            
        self.size = mask_config.size ## Full size of the mask after broadcast.
        try:
            self.weight * torch.rand(self.size)
        except RuntimeError:
            print("mask initialized with an incompatible shape.")

    def forward(self, x):
        x = self.weight * x
        return x

In [4]:
mask_config = MaskConfig(
    mode="scalar", value=0.5, size=torch.Size((4, 8))
)
mask_config

MaskConfig {
  "mode": "scalar",
  "size": [
    4,
    8
  ],
  "transformers_version": "4.46.3",
  "value": 0.5
}

In [5]:
class LinearWithMask(nn.Module):
    def __init__(self, linear, mask_config: MaskConfig):
        super().__init__()
        self.linear = linear
        self.mask_config = mask_config
        if linear.weight.shape != mask_config.size:
            print("Mask shape is not imcompatible with linear, reinitializing...")
        self.mask_config.size = linear.weight.shape
        self.mask = Mask(self.mask_config)
        
    def forward(self, x):
        masked_linear = self.mask(self.linear.weight)
        return nn.functional.linear(x, masked_linear, self.linear.bias)

class LinearsWithMasks(nn.Module):
    def __init__(
        self, 
        linears: List[nn.Module], 
        modes: List[str] = ["scalar"], 
        values: List[float] = None
    ):
        super().__init__()
        sizes = [linear.weight.shape for linear in linears]
        if values is None or len(values) != len(linears):
            raise ValueError(f"values for masks: {values} do not match with linear layers: {linears}")
            
        mask_configs = [
            MaskConfig(mode, value, size) 
            for mode, value, size in zip(modes, values, sizes)
        ]
        self.masked_linears = nn.ModuleList(
            [LinearWithMask(linear, mask_config) 
             for linear, mask_config in zip(linears, mask_configs)]
        )
        
    def forward(self, x):
        output = 0.0
        for masked_linear in self.masked_linears:
            output += masked_linear(x)
        return output
            
def replace_linears_with_masked(module):
    """
    Recursively replaces Linear layers in a module with LinearWithMask layers.
    
    Args:
      module: The module in which to replace layers.
    """
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(module, name, LinearWithMask(child))
        else:
            replace_linears_with_masked(child)

In [6]:
# --- Testing ---
def test_multiple_linear_components(input_size: int, output_size: int, num_components_list: List[int]):
    for num_components in num_components_list:
        linears = [nn.Linear(input_size, output_size, bias=False) for _ in range(num_components)]
        x = torch.rand(1, input_size)

        for _ in range(10):  # Reduced number of iterations for faster testing in a notebook
            values = np.random.rand(num_components).tolist() # cast to list
            weights_with_masks = LinearsWithMasks(linears=linears, modes=["scalar"] * num_components, values=values)

            individual_outputs = [linear(x) for linear in linears]
            expected_output = sum(val * out for val, out in zip(values, individual_outputs))
            actual_output = weights_with_masks(x)

            torch.testing.assert_close(actual_output, expected_output, rtol=1e-6, atol=1e-6)
        print(f"Test with {num_components} Linear components passed!")

# Set seed for reproducibility
torch.manual_seed(42)

# Define input and output sizes
input_size = 1024
output_size = 1024

# Run tests
test_multiple_linear_components(input_size, output_size, [i + 1 for i in range(20)])

Test with 1 Linear components passed!
Test with 2 Linear components passed!
Test with 3 Linear components passed!
Test with 4 Linear components passed!
Test with 5 Linear components passed!
Test with 6 Linear components passed!
Test with 7 Linear components passed!
Test with 8 Linear components passed!
Test with 9 Linear components passed!
Test with 10 Linear components passed!
Test with 11 Linear components passed!
Test with 12 Linear components passed!
Test with 13 Linear components passed!
Test with 14 Linear components passed!
Test with 15 Linear components passed!
Test with 16 Linear components passed!
Test with 17 Linear components passed!
Test with 18 Linear components passed!
Test with 19 Linear components passed!
Test with 20 Linear components passed!


In [7]:
class RMSNormWithMask(nn.Module):
    def __init__(self, rms_norm: Qwen2RMSNorm, mask_config: MaskConfig):
        super().__init__()
        self.rms_norm = rms_norm
        self.mask_config = mask_config
        if rms_norm.weight.shape != mask_config.size:
            print("Mask shape is not compatible with RMSNorm, reinitializing...")
        self.mask_config.size = rms_norm.weight.shape
        self.mask = Mask(self.mask_config)

    def forward(self, hidden_states):
        masked_weight = self.mask(self.rms_norm.weight)
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.rms_norm.variance_epsilon)
        return masked_weight * hidden_states.to(input_dtype)

class RMSNormsWithMasks(nn.Module):
    def __init__(
        self,
        rms_norms: List[Qwen2RMSNorm],
        modes: List[str] = ["scalar"],
        values: List[float] = None
    ):
        super().__init__()
        sizes = [rms_norm.weight.shape for rms_norm in rms_norms]
        if values is None or len(values) != len(rms_norms):
            raise ValueError(f"values for masks: {values} do not match with RMSNorm layers: {rms_norms}")

        mask_configs = [
            MaskConfig(mode, value, size)
            for mode, value, size in zip(modes, values, sizes)
        ]
        self.masked_rms_norms = nn.ModuleList(
            [RMSNormWithMask(rms_norm, mask_config)
             for rms_norm, mask_config in zip(rms_norms, mask_configs)]
        )

    def forward(self, hidden_states):
        output = 0.0
        for masked_rms_norm in self.masked_rms_norms:
            output += masked_rms_norm(hidden_states)
        return output

In [8]:
def test_multiple_rms_norm_components(hidden_size: int, num_components_list: List[int]):
    for num_components in num_components_list:
        rms_norms = [Qwen2RMSNorm(hidden_size) for _ in range(num_components)]
        hidden_states = torch.rand(2, 4, hidden_size)

        for _ in range(10):
            values = np.random.rand(num_components).tolist()
            rms_norms_with_masks = RMSNormsWithMasks(rms_norms=rms_norms, modes=["scalar"] * num_components, values=values)

            individual_outputs = [rms_norm(hidden_states) for rms_norm in rms_norms]
            expected_output = sum(val * out for val, out in zip(values, individual_outputs))
            actual_output = rms_norms_with_masks(hidden_states)
            
            torch.testing.assert_close(actual_output, expected_output, rtol=1e-6, atol=1e-6)
        print(f"Test with {num_components} RMSNorm components passed!")

# Set seed for reproducibility
torch.manual_seed(42)

# Define input and output sizes
hidden_size = 2048

# Run tests for RMSNormsWithMasks
test_multiple_rms_norm_components(hidden_size, [i+1 for i in range(5)])

Test with 1 RMSNorm components passed!
Test with 2 RMSNorm components passed!
Test with 3 RMSNorm components passed!
Test with 4 RMSNorm components passed!
Test with 5 RMSNorm components passed!


In [9]:
class EmbeddingWithMask(nn.Module):
    def __init__(self, embedding: nn.Embedding, mask_config: MaskConfig):
        super().__init__()
        self.embedding = embedding
        self.mask_config = mask_config
        if embedding.weight.shape != mask_config.size:
            print("Mask shape is not compatible with Embedding, reinitializing...")
        self.mask_config.size = embedding.weight.shape
        self.mask = Mask(self.mask_config)

    def forward(self, input_ids):
        masked_weight = self.mask(self.embedding.weight)
        return nn.functional.embedding(
            input_ids,
            masked_weight,
            padding_idx=self.embedding.padding_idx,
            max_norm=self.embedding.max_norm,
            norm_type=self.embedding.norm_type,
            scale_grad_by_freq=self.embedding.scale_grad_by_freq,
            sparse=self.embedding.sparse,
        )

class EmbeddingsWithMasks(nn.Module):
    def __init__(
        self,
        embeddings: List[nn.Embedding],
        modes: List[str] = ["scalar"],
        values: List[float] = None
    ):
        super().__init__()
        sizes = [embedding.weight.shape for embedding in embeddings]
        if values is None or len(values) != len(embeddings):
            raise ValueError(f"values for masks: {values} do not match with Embedding layers: {embeddings}")

        mask_configs = [
            MaskConfig(mode, value, size)
            for mode, value, size in zip(modes, values, sizes)
        ]
        self.masked_embeddings = nn.ModuleList(
            [EmbeddingWithMask(embedding, mask_config)
             for embedding, mask_config in zip(embeddings, mask_configs)]
        )

    def forward(self, input_ids):
        output = 0.0
        for masked_embedding in self.masked_embeddings:
            output += masked_embedding(input_ids)
        return output

In [10]:
def test_multiple_embedding_components(num_embeddings: int, embedding_dim: int, num_components_list: List[int]):
    for num_components in num_components_list:
        embeddings = [nn.Embedding(num_embeddings, embedding_dim) for _ in range(num_components)]
        input_ids = torch.randint(0, num_embeddings, (2, 5))  # Example input_ids

        for _ in range(10):
            values = np.random.rand(num_components).tolist()
            embeddings_with_masks = EmbeddingsWithMasks(embeddings=embeddings, modes=["scalar"] * num_components, values=values)

            individual_outputs = [embedding(input_ids) for embedding in embeddings]
            expected_output = sum(val * out for val, out in zip(values, individual_outputs))
            actual_output = embeddings_with_masks(input_ids)

            torch.testing.assert_close(actual_output, expected_output, rtol=1e-6, atol=1e-6)
        print(f"Test with {num_components} Embedding components passed!")

# Set seed for reproducibility
torch.manual_seed(42)

# Define parameters for Embedding
num_embeddings = 2048
embedding_dim = 2048

# Run tests for EmbeddingsWithMasks
test_multiple_embedding_components(num_embeddings, embedding_dim, [i + 1 for i in range(5)])

Test with 1 Embedding components passed!
Test with 2 Embedding components passed!
Test with 3 Embedding components passed!
Test with 4 Embedding components passed!
Test with 5 Embedding components passed!


In [11]:
import os
import shutil
from safetensors import safe_open

def load_layer(path, layer_idx=33):
	state_dict = {}
	shard_paths = [f for f in os.listdir(path) if f.endswith('.safetensors')]
	for shard_path in sorted(shard_paths, key=lambda x: int(x.split('-')[1])):
		apath = os.path.join(path, shard_path)
		with safe_open(apath, framework="pt", device="cpu") as f:
			for key in f.keys():
				if f"layers.{str(layer_idx)}." in key:
					state_dict[key] = f.get_tensor(key)
	return state_dict

def strip_prefix(state_dict, prefix="model.layers."):
    """Strips 'model.layers.*.' prefix from 'input_layernorm.weight' keys."""
    return {
      k.replace(f"{prefix}{k.split('.')[2]}.", "") if k.startswith(prefix)
      else k: v for k, v in state_dict.items()
    }

In [12]:
def place_masks(target_module, ref_modules):
    """
    Recursively replaces normal components with masked components.
    
    Args:
      module: The module in which to replace layers.
    """
    for name, target_child in target_module.named_children():
        ref_children = [getattr(module, name) for module in ref_modules]
        modes = ["scalar" for _ in ref_children]
        values = [0.0 for _ in ref_children]
        values[0] = 1.0
        if isinstance(target_child, nn.Linear):
            setattr(target_module, name, LinearsWithMasks(
                ref_children, modes, values
            ))
        elif isinstance(target_child, nn.Embedding):
            setattr(target_module, name, EmbeddingsWithMasks(
                ref_children, modes, values
            ))
        elif type(target_child).__name__ == Qwen2RMSNorm.__name__:
            setattr(target_module, name, RMSNormsWithMasks(
                ref_children, modes, values
            ))
        else:
            place_masks(target_child, ref_children)

class MergerConfig(PretrainedConfig):
    def __init__(
        self,
        model_paths: List[str] = None,
        **kwargs,
    ):
        self.model_paths = model_paths
        super().__init__(**kwargs)

class DecoderMerger(PreTrainedModel):
    def __init__(self, merge_config):
        super().__init__(merge_config)
        """
        Need to check whether models are mergeable (having some sort of the same config)
        """
        self.merge_config = merge_config
        self.configs = [Qwen2Config.from_pretrained(path) 
                        for path in merge_config.model_paths]
        
        # self.merger = Qwen2ForCausalLM(self.config)
        self.decoders = nn.ModuleList(
            Qwen2DecoderLayer(config, layer_idx=1) for config in self.configs
        )
        for i in range(len(self.decoders)):
            path = merge_config.model_paths[i]
            state_dict = load_layer(path, layer_idx=1)
            state_dict = strip_prefix(state_dict)
            self.decoders[i].load_state_dict(
                state_dict=state_dict
            )
        self.__post_init__(merge_config)
        
    def __post_init__(self, merge_config):
        self.merger = copy.deepcopy(self.decoders[0])
        place_masks(self.merger, self.decoders)
        
    def forward(self, tensor, labels=None):
        pass
        
class Merger(PreTrainedModel):
    def __init__(self, merge_config):
        super().__init__(merge_config)
        """
        Need to check whether models are mergeable (having some sort of the same config)
        """
        self.merge_config = merge_config
        self.num_models = len(merge_config.model_paths)
        self.configs = [
            AutoConfig.from_pretrained(path) 
            for path in merge_config.model_paths
        ]
        # self.merger = Qwen2ForCausalLM(self.config)
        self.models = nn.ModuleList([
            AutoModelForCausalLM.from_pretrained(
                merge_config.model_paths[i], 
                config=self.configs[i],
                torch_dtype=torch.bfloat16
            ) 
            for i in range(self.num_models)
        ])
        self.__post_init__(merge_config)
        
    def __post_init__(self, merge_config):
        # dummy_config = copy.deepcopy(self.configs[0])
        # dummy_config.update({"hidden_size": 1, "intermediate_size": 1})
        # self.merger = AutoModelForCausalLM.from_config(dummy_config)
        self.merger = copy.deepcopy(self.models[0])
        place_masks(self.merger, self.models)
        
    def forward(self, tensor, labels=None):
        pass

In [17]:
def lerp(
    t: float, v0: Union[np.ndarray, torch.Tensor], v1: Union[np.ndarray, torch.Tensor]
) -> Union[np.ndarray, torch.Tensor]:
    return (1 - t) * v0 + t * v1

def weighted_sum(
    factors: List[float], 
    tensors: Union[List[np.ndarray], List[torch.Tensor]]
) -> Union[np.ndarray, torch.Tensor]:
    result = 0.0
    for factor, tensor in zip(factors, tensors):
        result += factor * tensor
    return result

def merge_modules(modules, factors):
    module_out = copy.deepcopy(modules[0])
    out_dict = module_out.state_dict()
    
    tensor_dicts_list = [m.state_dict() for m in modules]
    tensor_names = [key for key in tensor_dicts_list[0].keys()]
    
    for tensor_name in tensor_names:
        tensors_list = [tensor_dicts_list[i][tensor_name]
                       for i in range(len(modules))]
        tensor_computed = (
            weighted_sum(
                factors=factors,
                tensors=tensors_list
            )
            .to(tensors_list[0].dtype)
            .to(tensors_list[0].device)
        )
        out_dict[tensor_name] = tensor_computed
    module_out.load_state_dict(out_dict)
    return module_out

In [14]:
merge_config = MergerConfig(
    model_paths = [
        "/workspace/models/Arcee-VyLinh/",
        "/workspace/models/Qwen2.5-Coder-3B/"
    ]
)
merge_config

MergerConfig {
  "model_paths": [
    "/workspace/models/Arcee-VyLinh/",
    "/workspace/models/Qwen2.5-Coder-3B/"
  ],
  "transformers_version": "4.46.3"
}

In [15]:
merger = Merger(merge_config)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [18]:
merger.merger.model.embed_tokens.masked_embeddings[1].mask.state_dict()

tensor(0.)

In [37]:
tokenizer = AutoTokenizer.from_pretrained(merge_config.model_paths[0])

In [16]:
merger.to("cuda:0")

Merger(
  (models): ModuleList(
    (0-1): 2 x Qwen2ForCausalLM(
      (model): Qwen2Model(
        (embed_tokens): Embedding(151936, 2048)
        (layers): ModuleList(
          (0-35): 36 x Qwen2DecoderLayer(
            (self_attn): Qwen2SdpaAttention(
              (q_proj): Linear(in_features=2048, out_features=2048, bias=True)
              (k_proj): Linear(in_features=2048, out_features=256, bias=True)
              (v_proj): Linear(in_features=2048, out_features=256, bias=True)
              (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
              (rotary_emb): Qwen2RotaryEmbedding()
            )
            (mlp): Qwen2MLP(
              (gate_proj): Linear(in_features=2048, out_features=11008, bias=False)
              (up_proj): Linear(in_features=2048, out_features=11008, bias=False)
              (down_proj): Linear(in_features=11008, out_features=2048, bias=False)
              (act_fn): SiLU()
            )
            (input_layernorm): Qwen2RMS

In [21]:
def compare_modules(module1, module2, rtol=1e-05, atol=1e-08, verbose=True):
    """
    Compares the weights of two modules using torch.testing.assert_close.
    Assumes modules have identical keys.
    """
    state_dict1 = module1.state_dict()
    state_dict2 = module2.state_dict()

    # Iterate directly through the keys of one module's state_dict
    for key in state_dict1:  
        tensor1 = state_dict1[key]
        tensor2 = state_dict2[key]

        # No need for shape check, assumed to be identical
        try:
            torch.testing.assert_close(tensor1, tensor2, rtol=rtol, atol=atol)
            if verbose:
                print(f"  OK: Tensor '{key}' is close within tolerance.")
        except AssertionError as e:
            if verbose:
                print(f"  ERROR: Tensor '{key}' is NOT close within tolerance.")
            raise AssertionError(f"Tensor '{key}' comparison failed: {e}") from e
    print("--- All tensors are identical! ---")

In [65]:
merger.to("cuda:0")

Merger(
  (models): ModuleList(
    (0-1): 2 x Qwen2ForCausalLM(
      (model): Qwen2Model(
        (embed_tokens): Embedding(151936, 2048)
        (layers): ModuleList(
          (0-35): 36 x Qwen2DecoderLayer(
            (self_attn): Qwen2SdpaAttention(
              (q_proj): Linear(in_features=2048, out_features=2048, bias=True)
              (k_proj): Linear(in_features=2048, out_features=256, bias=True)
              (v_proj): Linear(in_features=2048, out_features=256, bias=True)
              (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
              (rotary_emb): Qwen2RotaryEmbedding()
            )
            (mlp): Qwen2MLP(
              (gate_proj): Linear(in_features=2048, out_features=11008, bias=False)
              (up_proj): Linear(in_features=2048, out_features=11008, bias=False)
              (down_proj): Linear(in_features=11008, out_features=2048, bias=False)
              (act_fn): SiLU()
            )
            (input_layernorm): Qwen2RMS

In [24]:
m1 = merger.models[0].model.layers[1]
m2 = merger.models[1].model.layers[2]

mo = merge_modules([m1, m2], [0.0, 1.0])
compare_modules(mo, m2, verbose=True)

mo = merge_modules([m1, m2], [1.0, 0.0])
compare_modules(mo, m1, verbose=True)

# mo = merge_modules([m1, m2], [0.5, 0.5])
# compare_modules(mo, m1, verbose=True)

  OK: Tensor 'self_attn.q_proj.weight' is close within tolerance.
  OK: Tensor 'self_attn.q_proj.bias' is close within tolerance.
  OK: Tensor 'self_attn.k_proj.weight' is close within tolerance.
  OK: Tensor 'self_attn.k_proj.bias' is close within tolerance.
  OK: Tensor 'self_attn.v_proj.weight' is close within tolerance.
  OK: Tensor 'self_attn.v_proj.bias' is close within tolerance.
  OK: Tensor 'self_attn.o_proj.weight' is close within tolerance.
  OK: Tensor 'mlp.gate_proj.weight' is close within tolerance.
  OK: Tensor 'mlp.up_proj.weight' is close within tolerance.
  OK: Tensor 'mlp.down_proj.weight' is close within tolerance.
  OK: Tensor 'input_layernorm.weight' is close within tolerance.
  OK: Tensor 'post_attention_layernorm.weight' is close within tolerance.
--- All tensors are identical! ---
  OK: Tensor 'self_attn.q_proj.weight' is close within tolerance.
  OK: Tensor 'self_attn.q_proj.bias' is close within tolerance.
  OK: Tensor 'self_attn.k_proj.weight' is close withi

In [43]:
cm.state_dict()['q_proj.weight']

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], dtype=torch.bfloat16)

In [48]:
from transformers import GenerationConfig, TextStreamer
def generate(prompt, model, tokenizer, max_new_tokens=1024):
    input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(model.device)
    model.eval()
    with torch.no_grad():
        generation_config = GenerationConfig(
            repetition_penalty=1.13,
            max_new_tokens=max_new_tokens,
            temperature=0.4,
            top_p=0.95,
            # top_k=20,
            # bos_token_id=tokenizer.bos_token_id,
            # eos_token_id=tokenizer.eos_token_id,
            # eos_token_id=0, # for open-end generation.
            pad_token_id=tokenizer.pad_token_id,
            do_sample=True,
            use_cache=True,
            return_dict_in_generate=True,
            output_attentions=False,
            output_hidden_states=False,
            output_scores=False,
        )
        streamer = TextStreamer(tokenizer, skip_prompt=True)
        generated = model.generate(
            inputs=input_ids,
            generation_config=generation_config,
            streamer=streamer,
        )
    gen_tokens = generated["sequences"].cpu()[:, len(input_ids[0]):]
    output = tokenizer.batch_decode(gen_tokens)[0]
    output = output.split(tokenizer.eos_token)[0]
    return output.strip()

def get_logits(text, model, tokenizer):
    input_ids = tokenizer(text, return_tensors="pt").to(model.device)
    model.eval()
    with torch.no_grad():
        logits = model(**input_ids).logits
    return logits

In [42]:
system = "You are a helpful assistant."
prompt = "Continue this text: A dog is a cat"
messages = [
    {"role": "system", "content": system},
    {"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
text

'<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nContinue this text: A dog is a cat<|im_end|>\n<|im_start|>assistant\n'

In [51]:
logits_merger = get_logits(text, model=merger.merger, tokenizer=tokenizer)

In [53]:
logits_1 = get_logits(text, model=merger.models[0], tokenizer=tokenizer)

In [55]:
logits_1, logits_merger

(tensor([[[ 8.0625,  6.9688,  3.6094,  ..., -0.8320, -0.8320, -0.8320],
          [ 6.0312,  6.9688,  7.0938,  ..., -1.4219, -1.4219, -1.4219],
          [ 5.8438,  8.7500, 11.1875,  ..., -3.9844, -3.9844, -3.9844],
          ...,
          [ 1.4141,  5.7500, -3.6562,  ..., -0.3164, -0.3164, -0.3164],
          [ 9.3750,  7.2500,  8.1250,  ..., -6.0312, -6.0312, -6.0312],
          [16.2500, 17.8750,  9.1875,  ..., -1.8438, -1.8438, -1.8438]]],
        device='cuda:0', dtype=torch.bfloat16),
 tensor([[[ 2.8438,  2.4844,  2.6719,  ..., -0.5352, -0.5352, -0.5352],
          [ 2.7344,  3.3281,  3.1406,  ..., -0.8789, -0.8789, -0.8789],
          [ 2.5938,  2.7344,  4.2500,  ..., -2.1562, -2.1562, -2.1562],
          ...,
          [ 5.8438,  3.5781,  3.8906,  ..., -3.2969, -3.2969, -3.2969],
          [ 6.3438,  3.3438,  5.0312,  ..., -3.2031, -3.2031, -3.2031],
          [ 5.8750,  3.5469,  5.5000,  ..., -3.0000, -3.0000, -3.0000]]],
        device='cuda:0', dtype=torch.bfloat16))

In [57]:
answer = generate(merger.merger, text, max_new_tokens=100)

, A
 the cat in the A is an helpful assistant for A and the cat, A.
 A  help to the A and the cat in A is a helpful assistant
. A cat is not a, A, A

 A

 the cat is A, A
 and the cat in A
 is a useful cat.
 A is
 the cat in A is A
 the A
 and A cat in A is the A
 and A cat, A is A
 in


In [None]:
for name, param in 

In [45]:
type(merger.merger.model.layers[0].mlp)

transformers.models.qwen2.modeling_qwen2.Qwen2MLP

In [42]:
isinstance(merger.models[1].model.norm, Qwen2RMSNorm)

False

In [46]:
type(merger.models[1].model.norm)

transformers.models.qwen2.modeling_qwen2.Qwen2RMSNorm

In [1]:
(type(Qwen2RMSNorm(8)))

NameError: name 'Qwen2RMSNorm' is not defined

In [None]:
# generate(merger.merger, text, max_new_tokens=16)

In [None]:
# merger.merger.self_attn.q_proj.masked_linears[1].linear.weight.data_ptr()

In [None]:
# merger.decoders[1].self_attn.q_proj.weight.data_ptr()

In [None]:
# origin = merger.decoders[0].mlp.gate_proj.weight
# ref = merger.merger.mlp.gate_proj.weight

In [None]:
# model1 = Qwen2ForCausalLM.from_pretrained(
#     merge_config.model_paths[0]
# )

In [None]:
# model1.model.norm

In [None]:
# mlp1 = model1.model.layers[0].mlp

In [None]:
# masked_mlp1 = replace_linears_with_masked(mlp1)

In [None]:
# mlp1