## Test 1.

In [1]:
import math
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union, Mapping
from abc import ABC, abstractmethod

import datasets
import torch
import torch.nn.functional as F
import numpy as np
import torch.nn as nn
import logging
import copy
import gc

from datasets import load_dataset
from tqdm import tqdm
from transformers import (
    PreTrainedTokenizerBase,
    PreTrainedModel,
    PretrainedConfig,
    AutoConfig,
    AutoModelForCausalLM,
    LlamaForCausalLM,
    LlamaConfig,
    Trainer,
    TrainingArguments,
    AutoTokenizer,
    HfArgumentParser,
    default_data_collator,
    is_torch_xla_available,
    set_seed,
)
from transformers import (
    HfArgumentParser,
    TrainingArguments,
    Trainer
)

from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast
)

import efficient_masks
import accurate_masks

from utils import (
    generate, 
    get_hidden_states, 
    get_logits,
    free_memory
)
# Configure logger
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

In [2]:
path_a = "unsloth/Llama-3.2-1B-Instruct"
path_b = "unsloth/Llama-3.2-1B"
merge_config_a = accurate_masks.MergerConfig(
    model_paths = [path_a, path_b],
    mode = "vector_input",
    constrain_mode = "01",
)
merge_config_e = efficient_masks.MergerConfig(
    model_paths = [path_a, path_b],
    mode = "vector_input",
    constrain_mode = "01",
)

In [3]:
merge_config_a

MergerConfig {
  "constrain_mode": "01",
  "mode": "vector_input",
  "model_paths": [
    "unsloth/Llama-3.2-1B-Instruct",
    "unsloth/Llama-3.2-1B"
  ],
  "transformers_version": "4.46.3"
}

In [4]:
def debug_linear_forward(self, x):
    constrained_weight_masks = self.weight_masks_constrainer([m.weight for m in self.weight_masks])
    constrained_bias_masks = self.bias_masks_constrainer(
        [m.weight if m is not None else None for m in self.bias_masks]
    )
    masked_biases = [
        b_mask * linear.bias if linear.bias is not None and b_mask is not None else linear.bias
        for b_mask, linear in zip(constrained_bias_masks, self.linears)
    ]
    merged_bias = (
        sum(b if b is not None else torch.zeros_like(
            self.linears[0].weight[:, 0]) for b in masked_biases
           ) 
        if not all(b is None for b in masked_biases) else None
    )

    logger.info("Debugging Linear forward.")
    output = 0.0
    for i, linear in enumerate(self.linears):
        logger.info(f"BEFORE")
        logger.info(f"  linear: device: {linear.weight.device}; dtype: {linear.weight.dtype}")
        logger.info(f"  input: device: {x.device}; dtype: {x.dtype}")
        masked_input = constrained_weight_masks[i] * x
        logger.info(f"AFTER")
        logger.info(f"  linear: device: {linear.weight.device}; dtype: {linear.weight.dtype}")
        logger.info(f"  input: device: {masked_input.device}; dtype: {masked_input.dtype}")
        output = output + nn.functional.linear(masked_input, linear.weight, None)
        logger.info(f"OUTPUT")
        logger.info(f"  output: device: {output.device}; dtype: {output.dtype}")
    if merged_bias:
        output = output + merged_bias

    return output

In [5]:
def debug_emb_forward(self, input_ids):
    constrained_masks = self.masks_constrainer([m.weight for m in self.masks])
    logger.info("Debugging Embedding forward.")
    an_embedding = self.embeddings[0]
    out = 0.0
    for i, emb in enumerate(self.embeddings):
        logger.info(f"BEFORE")
        logger.info(f"  emb: device: {emb.weight.device}; dtype: {emb.weight.dtype}")
        logger.info(f"  input: device: {input_ids.device}; dtype: {input_ids.dtype}")
        mask = constrained_masks[i]
        masked_weight = emb.weight * mask
        logger.info(f"  mask: device: {mask.device}; dtype: {mask.dtype}")
        logger.info(f"  masked_emb: device: {masked_weight.device}; dtype: {masked_weight.dtype}")
        out = out + nn.functional.embedding(
            input_ids,
            # emb.weight * mask,
            masked_weight,
            padding_idx=an_embedding.padding_idx,
            max_norm=an_embedding.max_norm,
            norm_type=an_embedding.norm_type,
            scale_grad_by_freq=an_embedding.scale_grad_by_freq,
            sparse=an_embedding.sparse,
        )
        logger.info(f"AFTER")
        logger.info(f"  output: device: {out.device}; dtype: {out.dtype}")
    return out

In [5]:
# efficient_masks.LinearsWithMasks.forward = debug_linear_forward
# efficient_masks.EmbeddingsWithMasks.forward = debug_emb_forward

In [6]:
tokenizer = AutoTokenizer.from_pretrained(path_a)

In [7]:
em = efficient_masks.Merger(merge_config_e)
em.__post_init__()

Initializing masks: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 147/147 [00:48<00:00,  3.01it/s]
2025-01-03 11:28:12,554 - INFO - Initial GPU memory allocated: 0.00 GB
2025-01-03 11:28:12,941 - INFO - Final GPU memory allocated: 0.00 GB
2025-01-03 11:28:12,942 - INFO - Freed GPU memory: 0.00 GB


In [8]:
em = em.to(device="cuda:0", dtype=torch.bfloat16)

In [9]:
am = accurate_masks.Merger(merge_config_a)
am.__post_init__()

Initializing masks: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 147/147 [00:44<00:00,  3.32it/s]
2025-01-03 11:29:02,641 - INFO - Initial GPU memory allocated: 4.61 GB
2025-01-03 11:29:03,009 - INFO - Final GPU memory allocated: 4.61 GB
2025-01-03 11:29:03,013 - INFO - Freed GPU memory: 0.00 GB


In [10]:
am = am.to(device="cuda:1", dtype=torch.bfloat16)

In [11]:
prompt = "How to attack a person with an egg. Talk like a crazy person."
logits_merged_a = get_logits(prompt, am.merger, tokenizer)
logits_merged_e = get_logits(prompt, em.merger, tokenizer)

In [12]:
logits_merged_a = logits_merged_a.to(logits_merged_e.device)

In [13]:
logits_merged_a.shape

torch.Size([1, 16, 128256])

In [14]:
probs_a = torch.softmax(logits_merged_a, dim=-1)
probs_e = torch.softmax(logits_merged_e, dim=-1)

In [15]:
torch.topk(probs_a[0, 0, :], k=20)

torch.return_types.topk(
values=tensor([6.4453e-01, 3.0469e-01, 2.5024e-02, 1.1841e-02, 3.3875e-03, 2.0599e-03,
        9.7275e-04, 9.7275e-04, 7.5531e-04, 3.5667e-04, 3.5667e-04, 2.7847e-04,
        2.1648e-04, 1.9169e-04, 1.6880e-04, 1.6880e-04, 1.4877e-04, 1.4877e-04,
        1.3161e-04, 1.3161e-04], device='cuda:0', dtype=torch.bfloat16),
indices=tensor([  755,     2,   791, 16309,   475,  3936,    17,    16,    32,    59,
         1527,    11,    51,    50,  1687, 13066,  2028,    34,   220,    35],
       device='cuda:0'))

In [16]:
torch.topk(probs_e[0, 0, :], k=20)

torch.return_types.topk(
values=tensor([6.9531e-01, 2.5586e-01, 2.0996e-02, 1.2756e-02, 4.6997e-03, 2.2125e-03,
        1.0452e-03, 1.0452e-03, 8.1635e-04, 6.3324e-04, 3.8528e-04, 3.8528e-04,
        2.3365e-04, 2.3365e-04, 1.8215e-04, 1.8215e-04, 1.6022e-04, 1.4210e-04,
        1.4210e-04, 1.4210e-04], device='cuda:0', dtype=torch.bfloat16),
indices=tensor([  755,     2,   791, 16309,   475,  3936,    17,    16,    32,    59,
         1527,    11,    50,    51,   220, 13066,  1687,    34,    35,    40],
       device='cuda:0'))

In [17]:
def count_trainable_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [18]:
count_trainable_params(em.merger)

663618

In [19]:
count_trainable_params(am.merger)

663618

## Test B

In [1]:
"""
Model merging training implementation using PyTorch and Transformers.
Implements custom data collation and training for merged language models.
"""
import math
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union, Mapping
from abc import ABC, abstractmethod

import datasets
import torch
import torch.nn.functional as F
import safetensors
import numpy as np
import torch.nn as nn
import logging
import copy
import gc

from datasets import load_dataset
from tqdm import tqdm
from transformers import (
    PreTrainedTokenizerBase,
    PreTrainedModel,
    PretrainedConfig,
    AutoConfig,
    AutoModelForCausalLM,
    LlamaForCausalLM,
    LlamaConfig,
    Trainer,
    TrainingArguments,
    AutoTokenizer,
    HfArgumentParser,
    default_data_collator,
    is_torch_xla_available,
    set_seed,
)

from transformers.utils import CONFIG_NAME

from merger import (
# from efficient_masks import (
    MergerConfig,
    # Merger,
    NewMerger,
    init_masks,
    set_masks
)

from utils import (
    generate, 
    get_hidden_states, 
    get_logits,
    free_memory
)
# Configure logger
logging.basicConfig(
    level=logging.INFO, 
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

In [2]:
# checkpoint_dir = "./src/random_masks/checkpoint-900/"
checkpoint_dir = "./src/traversal_masks/checkpoint-500"

In [3]:
merge_config = MergerConfig.from_pretrained(checkpoint_dir)
merge_config

MergerConfig {
  "architectures": [
    "NewMerger"
  ],
  "constrain_mode": "identity",
  "mode": "vector_input",
  "model_paths": [
    "nguyenthanhdo/llama32_smol_rewrite_50k",
    "nguyenthanhdo/llama32_smol_summarize_50k"
  ],
  "torch_dtype": "bfloat16",
  "transformers_version": "4.46.3"
}

In [4]:
merger = NewMerger.from_pretrained(
    checkpoint_dir,
    merge_config,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="flash_attention_2",
)

2025-01-07 01:41:05,641 - INFO - Creating merger with dummy weights ...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Initializing masks: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 255/255 [00:27<00:00,  9.20it/s]


In [6]:
merger.merger.model.layers[23].mlp.up_proj.get_constrained_masks()

{'weight_masks': [Parameter containing:
  tensor([1.0000, 0.9961, 0.9883,  ..., 0.9102, 1.0000, 1.0000],
         dtype=torch.bfloat16, requires_grad=True),
  Parameter containing:
  tensor([ 0.2100,  0.2002,  0.1846,  ..., -0.0005, -0.0288,  0.1426],
         dtype=torch.bfloat16, requires_grad=True)],
 'bias_masks': [None, None]}

In [7]:
merger = merger.to(device="cuda:7")

In [8]:
tokenizer = AutoTokenizer.from_pretrained(merge_config.model_paths[0])
tokenizer.pad_token = tokenizer.eos_token

In [9]:
summarize_dataset = load_dataset(
    "HuggingFaceTB/smoltalk",
    "smol-summarize",
    split="train"
)
rewrite_dataset = load_dataset(
    "HuggingFaceTB/smoltalk",
    "smol-rewrite",
    split="train"
)

In [28]:
import torch
x = torch.tensor(1.0)
torch.cos(torch.pi * x)

tensor(-1.)

In [26]:
torch.pi

3.141592653589793

In [20]:
import numpy as np
idx = np.random.randint(30000)
train_dataset = rewrite_dataset
system = train_dataset[idx]['messages'][0]['content']
# system = "You're an AI assistant for text re-writing. Rewrite the input text to make it more concise while preserving its core meaning. Them summarize it up to 3 sentences."
prompt = train_dataset[idx]['messages'][1]['content']
messages = [
    {"role": "system", "content": system},
    {"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
    # train_dataset[idx]['messages'],
    messages,
    tokenize=False,
    add_generation_prompt=True
)
print(text)

<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You're an AI assistant for text re-writing. Rewrite the input text to make it more concise while preserving its core meaning.<|eot_id|><|start_header_id|>user<|end_header_id|>

Dear Mrs. Johnson,

I hope you remember me from the Harmony Community Center book club. I was the one who mentioned my research on quantum computing and sound waves during our last meeting. I couldn't help but think about your interest in the topic and how it could be made accessible for children.

I have a proposal for you. Would you be interested in collaborating on a children's book project? I can provide the scientific background and explanations, while you work on creating the story and characters to make it engaging for young readers. I think this could be a wonderful opportunity to introduce children to the fascinating world of quantum mechanics in a fun and relatable way.

Please let me know if you're interested, and we can set up a time to dis

In [23]:
answer = generate(text, merger.models[1], tokenizer, do_sample=False)

Emily's Mom is proposing a collaboration on a children's book that combines science and storytelling, focusing on quantum computing and sound waves. Emily's Mom offers to provide the scientific background, while suggesting the other party create the story and characters to engage young readers. She invites a discussion to explore the project further.<|end_of_text|>


In [22]:
answer = generate(text, merger.merger, tokenizer, do_sample=True)

Emily is proposing a collaboration on a children's book to make quantum computing and sound waves accessible to kids. She suggests sharing the scientific content, while the other party creates the story and characters. Emily invites discussion on the idea.<|end_of_text|>


In [52]:
logits_merged = get_logits(text, merger.merger, tokenizer)
logits_a = get_logits(text, merger.models[0], tokenizer)
logits_b = get_logits(text, merger.models[1], tokenizer)