In [3]:
plan = """
1. prepare data
2. define model
    - model a, mask a
    - model b, mask b
make only masks trainable.
make sure it has correct inference.
3. define loss function in Trainer.
"""

## prepare data

## modeling

In [1]:
import math
from typing import List, Optional, Tuple, Union

import datasets
import torch
import torch.nn as nn
from datasets import load_dataset

from transformers import (
    PreTrainedModel,
    PretrainedConfig,
    AutoModelForCausalLM,
    LlamaForCausalLM,
    LlamaConfig,
    Trainer,
    TrainingArguments,
    AutoTokenizer,
    HfArgumentParser
)

from modeling_qwen2 import (
    Qwen2RMSNorm, 
    Qwen2RotaryEmbedding, 
    Qwen2MLP, 
    Qwen2Attention, 
    Qwen2FlashAttention2, 
    Qwen2SdpaAttention, 
    Qwen2DecoderLayer, 
    Qwen2PreTrainedModel, 
    Qwen2Model, 
    Qwen2ForCausalLM
)

from configuration_qwen2 import Qwen2Config

## attempt 1

In [3]:
class MergerConfig(PretrainedConfig):
    def __init__(
        self,
        model_paths: List[str] = None,
        **kwargs,
    ):
        self.model_paths = model_paths
        super().__init__(**kwargs)

In [4]:
merge_config = MergerConfig(
    model_paths = [
        "/workspace/models/Arcee-VyLinh/",
        "/workspace/models/Qwen2.5-Coder-3B/"
    ]
)

merge_config

MergerConfig {
  "model_paths": [
    "/workspace/models/Arcee-VyLinh/",
    "/workspace/models/Qwen2.5-Coder-3B/"
  ],
  "transformers_version": "4.46.3"
}

In [4]:
merge_config.model_paths

['/workspace/models/Arcee-VyLinh/', '/workspace/models/Qwen2.5-Coder-3B/']

In [5]:
class Merger(PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.tokenizer = AutoTokenizer.from_pretrained(
            config.model_paths[0]
        )
        
        self.models = nn.ModuleList([
            Qwen2ForCausalLM.from_pretrained(
            # AutoModelForCausalLM.from_pretrained(
                model_path,
                # torch_dtype=torch.bfloat16,
                # device_map={"":0}
            ) for model_path in config.model_paths
        ])
        self.__post_init__()
        
    def __post_init__(self):
        # self.masks = torch.nn
        pass
        
    def forward(self, tensor, labels=None):
        """
        activations = []
        for i in range(num_layers):
            L1 = models[0].layers[i]
            L2 = models[1].layers[i]
            Lm = alpha * L1 + beta * L2
            h1 = L1(h)
            h2 = L2(h)
            h = Lm(h)
            activations.append({
                "1": h1, "2": h2, "merged": copy(h)
            })
        """

        """
        - embed_tokens
        - norm
        - layers
            - input_layernorm
            - self_attn
            - mlp
            - post_attention_norm
        - lm_head
        """
        pass

In [6]:
merger = Merger(merge_config)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [8]:
merger.models[0].device

device(type='cpu')

In [9]:
attn1 = merger.models[0].model.layers[0].self_attn
attn2 = merger.models[1].model.layers[0].self_attn
mlp1 = merger.models[0].model.layers[0].mlp
mlp2 = merger.models[1].model.layers[0].mlp

In [10]:
attn1.q_proj.weight

Parameter containing:
tensor([[-0.0159, -0.0432, -0.0080,  ...,  0.0081,  0.0096,  0.0132],
        [-0.0330,  0.0110,  0.0085,  ...,  0.0226, -0.0082,  0.0457],
        [-0.0092,  0.0111, -0.0134,  ...,  0.0298,  0.0113, -0.0038],
        ...,
        [-0.0085,  0.0601, -0.0325,  ...,  0.0525, -0.0222,  0.0403],
        [-0.0374, -0.0325,  0.0620,  ..., -0.0206,  0.0806,  0.0376],
        [ 0.0356,  0.0151,  0.0087,  ..., -0.0306, -0.0072,  0.0378]],
       requires_grad=True)

In [11]:
import torch
import torch.nn as nn
import copy
from typing import List, Dict

def merge_linear(weights: List[nn.Linear], factors: List[float]) -> nn.Linear:
    """
    Merges multiple linear layers by taking a weighted average of their weights and biases.

    Args:
        weights: A list of nn.Linear layers to merge.
        factors: A list of scaling factors corresponding to each layer in 'weights'.

    Returns:
        A new nn.Linear layer that is the weighted average of the input layers.

    Raises:
        ValueError: If the number of weights and factors don't match, or if the
                    layers have incompatible dimensions, device or dtype.
    """
    if len(weights) != len(factors):
        raise ValueError("The number of weights and factors must be equal.")

    # Check for compatibility, device, and dtype
    device = weights[0].weight.device
    dtype = weights[0].weight.dtype
    if not all(
        w.in_features == weights[0].in_features
        and w.out_features == weights[0].out_features
        and w.weight.device == device
        and w.weight.dtype == dtype
        for w in weights
    ):
        raise ValueError(
            "Incompatible linear layers for merging. They must have the same in_features, out_features, device, and dtype."
        )

    # Create a new linear layer with the same dimensions, device and dtype
    merged_linear = nn.Linear(
        in_features=weights[0].in_features,
        out_features=weights[0].out_features,
        bias=False,
        device=device,
        dtype=dtype
    )

    # Calculate the merged weight and bias
    merged_weight = torch.zeros_like(weights[0].weight)
    merged_bias = (
        torch.zeros_like(weights[0].bias, device=device, dtype=dtype)
        if weights[0].bias is not None
        else None
    )

    for i, w in enumerate(weights):
        merged_weight += factors[i] * w.weight
        if w.bias is not None:
            if merged_bias is None:
                raise ValueError("Cannot merge linear layers if only some have biases.")
            merged_bias += factors[i] * w.bias

    # Assign the merged weight and bias to the new layer
    with torch.no_grad():
        merged_linear.weight.copy_(merged_weight)
        if merged_bias is not None:
            merged_linear.bias = nn.Parameter(merged_bias)

    return merged_linear

def merge_module_recursive(
    target_module: nn.Module, modules_dict: Dict[str, List[nn.Module]], factors: List[float]
):
    """
    Recursively merges multiple modules by taking a weighted average of their Linear layer weights and biases.

    Args:
        target_module: The target module where the merged weights will be stored.
        modules_dict: A dictionary where keys are module names and values are lists of modules to merge.
        factors: A list of scaling factors corresponding to each list of modules.
    """

    for name, module in target_module.named_modules():
        if isinstance(module, nn.Linear):
            if name not in modules_dict:
                raise ValueError(
                    f"Missing module {name} in modules_dict. Make sure all linear layer weights are provided"
                )
            merged_linear = merge_linear(modules_dict[name], factors)
            # Find the parent module
            parent_module_name = ".".join(name.split(".")[:-1])
            layer_name = name.split(".")[-1]

            if parent_module_name:
                parent_module = target_module.get_submodule(parent_module_name)
            else:
                parent_module = target_module

            # Replace the original Linear layer with merged one
            setattr(parent_module, layer_name, merged_linear)

def merge_modules(modules: List[nn.Module], factors: List[float]) -> nn.Module:
    """
    Merges multiple modules by taking a weighted average of their Linear layer weights and biases.
    The merged weights are stored into a deepcopy of the first module in the list.

    Args:
        modules: A list of nn.Modules to merge.
        factors: A list of scaling factors corresponding to each module in 'modules'.

    Returns:
        A new nn.Module that is the weighted average of the input modules.

    Raises:
        ValueError: If the number of modules and factors don't match.
    """
    if len(modules) != len(factors):
        raise ValueError("The number of modules and factors must be equal.")

    # Check device and dtype consistency across all modules
    device = modules[0].parameters().__next__().device
    dtype = modules[0].parameters().__next__().dtype
    if not all(p.device == device and p.dtype == dtype for module in modules for p in module.parameters()):
        raise ValueError("All modules must be on the same device and have the same dtype.")

    # Create a deep copy of the first module to store the merged weights
    merged_module = copy.deepcopy(modules[0])

    # Dictionary to hold corresponding linear layers from each module
    modules_to_merge = {
        name: [] for name, _ in merged_module.named_modules() if isinstance(_, nn.Linear)
    }
    for module in modules:
        for name, layer in module.named_modules():
            if isinstance(layer, nn.Linear):
                modules_to_merge[name].append(layer)

    # Merge the modules recursively
    merge_module_recursive(merged_module, modules_to_merge, factors)

    # Ensure the merged module has the correct device and dtype
    merged_module.to(device=device, dtype=dtype)

    return merged_module

In [12]:
merged_mlp = merge_modules(modules=[mlp1, mlp2], factors=[0.5, 0.5])

In [32]:
Qwen2MLP.forward??

[0;31mSignature:[0m [0mQwen2MLP[0m[0;34m.[0m[0mforward[0m[0;34m([0m[0mself[0m[0;34m,[0m [0mhidden_state[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m
Define the computation performed at every call.

Should be overridden by all subclasses.

.. note::
    Although the recipe for forward pass needs to be defined within
    this function, one should call the :class:`Module` instance afterwards
    instead of this since the former takes care of running the
    registered hooks while the latter silently ignores them.
[0;31mSource:[0m   
    [0;32mdef[0m [0mforward[0m[0;34m([0m[0mself[0m[0;34m,[0m [0mhidden_state[0m[0;34m)[0m[0;34m:[0m[0;34m[0m
[0;34m[0m        [0;32mreturn[0m [0mself[0m[0;34m.[0m[0mdown_proj[0m[0;34m([0m[0mself[0m[0;34m.[0m[0mact_fn[0m[0;34m([0m[0mself[0m[0;34m.[0m[0mgate_proj[0m[0;34m([0m[0mhidden_state[0m[0;34m)[0m[0;34m)[0m [0;34m*[0m [0mself[0m[0;34m.[0m[0mup_proj[0m[0;34m([0

In [13]:
def forward_merged_mlp(x: torch.Tensor, modules: List[nn.Module], factors: List[float]) -> torch.Tensor:
    """
    Performs a forward pass that simulates the behavior of a merged MLP module.

    Args:
        modules: A list of MLP modules (e.g., Qwen2MLP instances).
        factors: A list of scaling factors corresponding to each module in 'modules'.
        x: The input tensor.

    Returns:
        The output tensor after the forward pass.
    """
    factors = torch.tensor(factors).to(x.device, dtype=x.dtype).view(-1, 1, 1, 1)  # Reshape factors for broadcasting
    gate_output = torch.stack([m.gate_proj(x) for m in modules]).mul(factors).sum(0)
    up_output = torch.stack([m.up_proj(x) for m in modules]).mul(factors).sum(0)
    act_output = modules[0].act_fn(gate_output)  # Assuming all modules have the same activation function
    result = torch.stack([m.down_proj(act_output * up_output) for m in modules]).mul(factors).sum(0)
    return result

In [15]:
import torch
device = merged_mlp.parameters().__next__().device
dtype = merged_mlp.parameters().__next__().dtype
x = torch.rand(1, 4, 2048).to(device, dtype=dtype)
o1 = merged_mlp(x)
o2 = forward_merged_mlp(x, modules=[mlp1, mlp2], factors=[0.5, 0.5])

In [18]:
torch.testing.assert_close(o1, o2)

In [58]:
modules[0].down_proj.weight
gate_output.dtype

torch.float32

In [38]:
import torch
device = "cuda:0"
h = torch.rand(1, 4, 2048, dtype=torch.bfloat16).to(device)
p = torch.arange(4, dtype=torch.bfloat16, device=device).unsqueeze(0)
attn1.forward(h, position_ids=p)

(tensor([[[ 0.3750, -0.3301, -0.0016,  ..., -0.0093, -0.2695,  0.0986],
          [ 0.2930, -0.2949,  0.0933,  ...,  0.0483, -0.2949, -0.0649],
          [ 0.2676, -0.3789,  0.1973,  ...,  0.0134, -0.7227,  0.1367],
          [ 0.2715, -0.2988, -0.0547,  ..., -0.1777, -0.7695,  0.0850]]],
        device='cuda:0', dtype=torch.bfloat16, grad_fn=<UnsafeViewBackward0>),
 None,
 None)

In [50]:
def count_parameters(model, param_bits):
    total_params = 0
    trainable_params = 0
    non_trainable_params = 0
    
    for param in model.parameters():
        num_params = param.numel()  # Get the number of elements in the parameter
        total_params += num_params
        if param.requires_grad:
            trainable_params += num_params
        else:
            non_trainable_params += num_params

    total_gigabytes = total_params * (param_bits / 8) / (1024**3)
    memory = f"{total_gigabytes:.2f} GB"
    
    return total_params, memory

In [55]:
count_parameters(attn1, 16), count_parameters(mlp1, 16)

((9439744, '0.02 GB'), (67633152, '0.13 GB'))

## attemtp 2

In [2]:
from utils import are_tokenizers_same
are_tokenizers_same(
    paths = [
        "/workspace/models/Arcee-VyLinh/",
        "/workspace/models/Qwen2.5-Coder-3B/"
    ]
)

2024-12-12 07:12:16,406 - INFO - Comparing tokenizer at /workspace/models/Arcee-VyLinh/ with tokenizer at /workspace/models/Qwen2.5-Coder-3B/
2024-12-12 07:12:16,409 - INFO - Tokenizer at /workspace/models/Arcee-VyLinh/ and /workspace/models/Qwen2.5-Coder-3B/ are the same based on the defined criteria


True

In [3]:
class MaskConfig(PretrainedConfig):
    def __init__(
        self,
        mode: str = None,
        value: Union[float, torch.Tensor] = None,
        size: torch.Size = None,
        **kwargs,
    ):
        self.mode = mode
        self.value = value
        self.size = size
        super().__init__(**kwargs)

class Mask(nn.Module):
    def __init__(
        self, 
        mask_config: MaskConfig
    ):
        super().__init__()
        """
        now only support mode == scalar
        """
        self.mode = mask_config.mode
        if mask_config.mode == "scalar":
            value = mask_config.value if mask_config.value is not None else 1
            self.mask = nn.Parameter(torch.tensor(value)) # Corrected typo here
        else:
            raise ValueError(f"Unsupported mask mode: {mask_config.mode}")
            
        self.size = mask_config.size ## Full size of the mask after broadcast.
        try:
            self.mask * torch.rand(self.size)
        except RuntimeError:
            print("mask initialized with an incompatible shape.")

    def forward(self, x):
        x = self.mask * x
        return x

In [6]:
torch.Size((4, 8))

torch.Size([4, 8])

In [9]:
mask_config = MaskConfig(
    mode="scalar", value=0.5, size=torch.Size((4, 8))
)
mask_config

MaskConfig {
  "mode": "scalar",
  "size": [
    4,
    8
  ],
  "transformers_version": "4.46.3",
  "value": 0.5
}

In [51]:
class LinearWithMask(nn.Module):
    def __init__(self, linear, mask_config: MaskConfig):
        super().__init__()
        self.linear = linear
        self.mask_config = mask_config
        if linear.weight.shape != mask_config.size:
            print("Mask shape is not imcompatible with linear, reinitializing...")
        self.mask_config.size = linear.weight.shape
        self.mask = Mask(self.mask_config)
        
    def forward(self, x):
        masked_linear = self.mask(self.linear.weight)
        return nn.functional.linear(x, masked_linear, self.linear.bias)

class LinearsWithMasks(nn.Module):
    def __init__(
        self, 
        linears: List[nn.Module], 
        modes: List[str] = ["scalar"], 
        values: List[float] = None
    ):
        super().__init__()
        sizes = [linear.weight.shape for linear in linears]
        if values is None or len(values) != len(linears):
            raise ValueError(f"values for masks: {values} do not match with linear layers: {linears}")
            
        mask_configs = [
            MaskConfig(mode, value, size) 
            for mode, value, size in zip(modes, values, sizes)
        ]
        self.masked_linears = nn.ModuleList(
            [LinearWithMask(linear, mask_config) 
             for linear, mask_config in zip(linears, mask_configs)]
        )
        
    def forward(self, x):
        output = 0.0
        for masked_linear in self.masked_linears:
            output += masked_linear(x)
        return output
            
def replace_linears_with_masked(module):
    """
    Recursively replaces Linear layers in a module with LinearWithMask layers.
    
    Args:
      module: The module in which to replace layers.
    """
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(module, name, LinearWithMask(child))
        else:
            replace_linears_with_masked(child)

In [63]:
# --- Testing ---
def test_multiple_components(input_size: int, output_size: int, num_components_list: List[int]):
    for num_components in num_components_list:
        linears = [nn.Linear(input_size, output_size, bias=False) for _ in range(num_components)]
        x = torch.rand(1, input_size)

        for _ in range(10):  # Reduced number of iterations for faster testing in a notebook
            values = np.random.rand(num_components).tolist() # cast to list
            weights_with_masks = LinearsWithMasks(linears=linears, modes=["scalar"] * num_components, values=values)

            individual_outputs = [linear(x) for linear in linears]
            expected_output = sum(val * out for val, out in zip(values, individual_outputs))
            actual_output = weights_with_masks(x)

            assert torch.allclose(actual_output, expected_output, rtol=1e-5, atol=1e-5), "Outputs do not match!"
        print(f"Test with {num_components} components passed!")

# Set seed for reproducibility
torch.manual_seed(42)

# Define input and output sizes
input_size = 4
output_size = 8

# Run tests
test_multiple_components(input_size, output_size, [i + 1 for i in range(20)])

Test with 1 components passed!
Test with 2 components passed!
Test with 3 components passed!
Test with 4 components passed!
Test with 5 components passed!
Test with 6 components passed!
Test with 7 components passed!
Test with 8 components passed!
Test with 9 components passed!
Test with 10 components passed!
Test with 11 components passed!
Test with 12 components passed!
Test with 13 components passed!
Test with 14 components passed!
Test with 15 components passed!
Test with 16 components passed!
Test with 17 components passed!
Test with 18 components passed!
Test with 19 components passed!
Test with 20 components passed!


In [77]:
class RMSNormWithMask(nn.Module):
    def __init__(self, rms_norm: Qwen2RMSNorm, mask_config: MaskConfig):
        super().__init__()
        self.rms_norm = rms_norm
        self.mask_config = mask_config
        if rms_norm.weight.shape != mask_config.size:
            print("Mask shape is not compatible with RMSNorm, reinitializing...")
        self.mask_config.size = rms_norm.weight.shape
        self.mask = Mask(self.mask_config)

    def forward(self, hidden_states):
        masked_weight = self.mask(self.rms_norm.weight)
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.rms_norm.variance_epsilon)
        return masked_weight * hidden_states.to(input_dtype)

class RMSNormsWithMasks(nn.Module):
    def __init__(
        self,
        rms_norms: List[Qwen2RMSNorm],
        modes: List[str] = ["scalar"],
        values: List[float] = None
    ):
        super().__init__()
        sizes = [rms_norm.weight.shape for rms_norm in rms_norms]
        if values is None or len(values) != len(rms_norms):
            raise ValueError(f"values for masks: {values} do not match with RMSNorm layers: {rms_norms}")

        mask_configs = [
            MaskConfig(mode, value, size)
            for mode, value, size in zip(modes, values, sizes)
        ]
        self.masked_rms_norms = nn.ModuleList(
            [RMSNormWithMask(rms_norm, mask_config)
             for rms_norm, mask_config in zip(rms_norms, mask_configs)]
        )

    def forward(self, hidden_states):
        output = 0.0
        for masked_rms_norm in self.masked_rms_norms:
            output += masked_rms_norm(hidden_states)
        return output

def test_multiple_rms_norm_components(hidden_size: int, num_components_list: List[int]):
    for num_components in num_components_list:
        rms_norms = [Qwen2RMSNorm(hidden_size) for _ in range(num_components)]
        hidden_states = torch.rand(2, 4, hidden_size)

        for _ in range(10):
            values = np.random.rand(num_components).tolist()
            rms_norms_with_masks = RMSNormsWithMasks(rms_norms=rms_norms, modes=["scalar"] * num_components, values=values)

            individual_outputs = [rms_norm(hidden_states) for rms_norm in rms_norms]
            expected_output = sum(val * out for val, out in zip(values, individual_outputs))
            actual_output = rms_norms_with_masks(hidden_states)
            
            assert torch.allclose(actual_output, expected_output, rtol=1e-5, atol=1e-5), "Outputs do not match!"
        print(f"Test with {num_components} RMSNorm components passed!")

# Set seed for reproducibility
torch.manual_seed(42)

# Define input and output sizes
input_size = 4
output_size = 8
hidden_size = 16

# Run tests for RMSNormsWithMasks
test_multiple_rms_norm_components(hidden_size, [i+1 for i in range(20)])

Test with 1 RMSNorm components passed!
Test with 2 RMSNorm components passed!
Test with 3 RMSNorm components passed!
Test with 4 RMSNorm components passed!
Test with 5 RMSNorm components passed!
Test with 6 RMSNorm components passed!
Test with 7 RMSNorm components passed!
Test with 8 RMSNorm components passed!
Test with 9 RMSNorm components passed!
Test with 10 RMSNorm components passed!
Test with 11 RMSNorm components passed!
Test with 12 RMSNorm components passed!
Test with 13 RMSNorm components passed!
Test with 14 RMSNorm components passed!
Test with 15 RMSNorm components passed!
Test with 16 RMSNorm components passed!
Test with 17 RMSNorm components passed!
Test with 18 RMSNorm components passed!
Test with 19 RMSNorm components passed!
Test with 20 RMSNorm components passed!


In [66]:
class MergerConfig(PretrainedConfig):
    def __init__(
        self,
        model_paths: List[str] = None,
        **kwargs,
    ):
        self.model_paths = model_paths
        super().__init__(**kwargs)
        
class Merger(PreTrainedModel):
    def __init__(self, merge_config):
        super().__init__(merge_config)
        """
        Need to check whether models are mergeable (having some sort of the same config)
        """
        self.config = Qwen2Config.from_pretrained(
            merge_config.model_paths[0]
        )
        self.merger = Qwen2ForCausalLM(self.config)
        self.__post_init__(merge_config)
        
    def __post_init__(self, merge_config):
        # self.masks = torch.nn
        pass
        
    def forward(self, tensor, labels=None):
        pass

In [67]:
merge_config = MergerConfig(
    model_paths = [
        "/workspace/models/Arcee-VyLinh/",
        "/workspace/models/Qwen2.5-Coder-3B/"
    ]
)
merge_config

MergerConfig {
  "model_paths": [
    "/workspace/models/Arcee-VyLinh/",
    "/workspace/models/Qwen2.5-Coder-3B/"
  ],
  "transformers_version": "4.46.3"
}

In [68]:
model1 = Qwen2ForCausalLM.from_pretrained(
    merge_config.model_paths[0]
)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [75]:
model1.model.norm

Qwen2RMSNorm((2048,), eps=1e-06)

In [55]:
mlp1 = model1.model.layers[0].mlp

In [64]:
masked_mlp1 = replace_linears_with_masked(mlp1)

In [68]:
mlp1

Qwen2MLP(
  (gate_proj): LinearWithMask(
    (linear): Linear(in_features=2048, out_features=11008, bias=False)
    (mask): Mask()
  )
  (up_proj): LinearWithMask(
    (linear): Linear(in_features=2048, out_features=11008, bias=False)
    (mask): Mask()
  )
  (down_proj): LinearWithMask(
    (linear): Linear(in_features=11008, out_features=2048, bias=False)
    (mask): Mask()
  )
  (act_fn): SiLU()
)

In [11]:
merger = Merger(merge_config)