In [3]:
plan = """
1. prepare data
2. define model
    - model a, mask a
    - model b, mask b
make only masks trainable.
make sure it has correct inference.
3. define loss function in Trainer.
"""

## prepare data

## modeling

In [1]:
import math
from typing import List, Optional, Tuple, Union

import datasets
import torch
import torch.nn as nn
from datasets import load_dataset

from transformers import (
    PreTrainedModel,
    PretrainedConfig,
    AutoModelForCausalLM,
    LlamaForCausalLM,
    LlamaConfig,
    Trainer,
    TrainingArguments,
    AutoTokenizer,
    HfArgumentParser
)

from modeling_qwen2 import (
    Qwen2RMSNorm, 
    Qwen2RotaryEmbedding, 
    Qwen2MLP, 
    Qwen2Attention, 
    Qwen2FlashAttention2, 
    Qwen2SdpaAttention, 
    Qwen2DecoderLayer, 
    Qwen2PreTrainedModel, 
    Qwen2Model, 
    Qwen2ForCausalLM
)

from configuration_qwen2 import Qwen2Config

In [8]:
qwen2_config = Qwen2Config.from_pretrained(
    "/workspace/models/Qwen2.5-Coder-3B/"
)
qwen2_attn = Qwen2Attention(qwen2_config, layer_idx=4)

In [2]:
class MergerConfig(PretrainedConfig):
    def __init__(
        self,
        model_paths: List[str] = None,
        **kwargs,
    ):
        self.model_paths = model_paths
        super().__init__(**kwargs)

In [3]:
merge_config = MergerConfig(
    model_paths = [
        "/workspace/models/Arcee-VyLinh/",
        "/workspace/models/Qwen2.5-Coder-3B/"
    ]
)

merge_config

MergerConfig {
  "model_paths": [
    "/workspace/models/Arcee-VyLinh/",
    "/workspace/models/Qwen2.5-Coder-3B/"
  ],
  "transformers_version": "4.46.3"
}

In [4]:
merge_config.model_paths

['/workspace/models/Arcee-VyLinh/', '/workspace/models/Qwen2.5-Coder-3B/']

In [67]:
class Merger(PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.tokenizer = AutoTokenizer.from_pretrained(
            config.model_paths[0]
        )
        
        self.models = nn.ModuleList([
            AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch.bfloat16,
                device_map={"":0}
            ) for model_path in config.model_paths
        ])
        self.__post_init__()
        
    def __post_init__(self):
        # self.masks = torch.nn
        pass
        
    def forward(self, tensor, labels=None):
        """
        activations = []
        for i in range(num_layers):
            L1 = models[0].layers[i]
            L2 = models[1].layers[i]
            Lm = alpha * L1 + beta * L2
            h1 = L1(h)
            h2 = L2(h)
            h = Lm(h)
            activations.append({
                "1": h1, "2": h2, "merged": copy(h)
            })
        """

        """
        - embed_tokens
        - norm
        - layers
            - input_layernorm
            - self_attn
            - mlp
            - post_attention_norm
        - lm_head
        """
        pass

In [10]:
merger = Merger(merge_config

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [65]:
merger.models[0]

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 2048)
    (layers): ModuleList(
      (0-35): 36 x Qwen2DecoderLayer(
        (self_attn): Qwen2SdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (k_proj): Linear(in_features=2048, out_features=256, bias=True)
          (v_proj): Linear(in_features=2048, out_features=256, bias=True)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): Qwen2RotaryEmbedding()
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=2048, out_features=11008, bias=False)
          (up_proj): Linear(in_features=2048, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((2048,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((2048,), eps=1e-06)
      )
    )
    (norm):

In [13]:
import torch.nn as nn
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden_size = 16
        self.intermediate_size = 32
        self.gate_proj = nn.Linear(16, 32, bias=False)
        self.up_proj = nn.Linear(16, 32, bias=False)
        self.down_proj = nn.Linear(32, 16, bias=False)

    def forward(self, x):
        result = self.down_proj(self.gate_proj(x) * self.up_proj(x))
        return result

mlp = MLP()

In [52]:
attn1 = merger.models[0].model.layers[0].self_attn
attn2 = merger.models[1].model.layers[0].self_attn
mlp1 = merger.models[0].model.layers[0].mlp

In [64]:
for x in attn1.parameters():
    print(x)

Parameter containing:
tensor([[-0.0159, -0.0432, -0.0080,  ...,  0.0081,  0.0096,  0.0132],
        [-0.0330,  0.0110,  0.0085,  ...,  0.0226, -0.0082,  0.0457],
        [-0.0092,  0.0111, -0.0134,  ...,  0.0298,  0.0113, -0.0038],
        ...,
        [-0.0085,  0.0601, -0.0325,  ...,  0.0525, -0.0222,  0.0403],
        [-0.0374, -0.0325,  0.0620,  ..., -0.0206,  0.0806,  0.0376],
        [ 0.0356,  0.0151,  0.0087,  ..., -0.0306, -0.0072,  0.0378]],
       device='cuda:0', dtype=torch.bfloat16, requires_grad=True)
Parameter containing:
tensor([-0.0811, -2.3125,  2.1406,  ...,  1.0938, -0.2275,  0.6250],
       device='cuda:0', dtype=torch.bfloat16, requires_grad=True)
Parameter containing:
tensor([[ 0.0016, -0.0310, -0.0079,  ..., -0.0417,  0.0113,  0.0325],
        [ 0.0315, -0.0189,  0.0791,  ...,  0.0364, -0.0352,  0.0344],
        [-0.0115,  0.0466, -0.0225,  ...,  0.0625,  0.0304,  0.0137],
        ...,
        [-0.0605,  0.0025,  0.0486,  ...,  0.0220, -0.0708, -0.0137],
      

In [None]:
from copy import deepcopy
def merge_mini(net1, net2):
    net = deepcopy(net1)
    for 

In [38]:
import torch
device = "cuda:0"
h = torch.rand(1, 4, 2048, dtype=torch.bfloat16).to(device)
p = torch.arange(4, dtype=torch.bfloat16, device=device).unsqueeze(0)
attn1.forward(h, position_ids=p)

(tensor([[[ 0.3750, -0.3301, -0.0016,  ..., -0.0093, -0.2695,  0.0986],
          [ 0.2930, -0.2949,  0.0933,  ...,  0.0483, -0.2949, -0.0649],
          [ 0.2676, -0.3789,  0.1973,  ...,  0.0134, -0.7227,  0.1367],
          [ 0.2715, -0.2988, -0.0547,  ..., -0.1777, -0.7695,  0.0850]]],
        device='cuda:0', dtype=torch.bfloat16, grad_fn=<UnsafeViewBackward0>),
 None,
 None)

In [50]:
def count_parameters(model, param_bits):
    total_params = 0
    trainable_params = 0
    non_trainable_params = 0
    
    for param in model.parameters():
        num_params = param.numel()  # Get the number of elements in the parameter
        total_params += num_params
        if param.requires_grad:
            trainable_params += num_params
        else:
            non_trainable_params += num_params

    total_gigabytes = total_params * (param_bits / 8) / (1024**3)
    memory = f"{total_gigabytes:.2f} GB"
    
    return total_params, memory

In [55]:
count_parameters(attn1, 16), count_parameters(mlp1, 16)

((9439744, '0.02 GB'), (67633152, '0.13 GB'))