In [1]:
"""
Model merging training implementation using PyTorch and Transformers.
Implements custom data collation and training for merged language models.
"""
import math
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union, Mapping
from abc import ABC, abstractmethod

import datasets
import torch
import torch.nn.functional as F
import numpy as np
import torch.nn as nn
import logging
import copy
import gc

from datasets import load_dataset
from tqdm import tqdm
from transformers import (
    PreTrainedTokenizerBase,
    PreTrainedModel,
    PretrainedConfig,
    AutoConfig,
    AutoModelForCausalLM,
    LlamaForCausalLM,
    LlamaConfig,
    Trainer,
    TrainingArguments,
    AutoTokenizer,
    HfArgumentParser,
    default_data_collator,
    is_torch_xla_available,
    set_seed,
)

from accurate_masks import (
# from efficient_masks import (
    MergerConfig,
    # Merger,
    NewMerger,
    init_masks,
    set_masks
)

from utils import (
    generate, 
    get_hidden_states, 
    get_logits,
    free_memory
)
# Configure logger
logging.basicConfig(
    level=logging.INFO, 
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

In [2]:
import os
# Option 1: Set specific GPU devices
os.environ["CUDA_VISIBLE_DEVICES"] = "0" 

In [3]:
class DataProcessor:
    """Handles data loading and preprocessing."""
    
    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        self.tokenizer = tokenizer
    
    def load_dataset(self):
        """Load and prepare the training dataset."""
        summarize_train = load_dataset(
            "HuggingFaceTB/smoltalk",
            "smol-summarize",
            split="train"
        )
        summarize_train = summarize_train.add_column(
            name="data_source",
            column=[1 for _ in summarize_train]
        )
        return summarize_train.shuffle(seed=42).select(range(30000))
    
    def tokenize(self, element):
        """Tokenize a single element from the dataset."""
        templated = self.tokenizer.apply_chat_template(
            element["messages"],
            tokenize=False,
            add_generation_prompt=False
        )
        return self.tokenizer(
            templated,
            truncation=True,
            max_length=2048,
            add_special_tokens=False
        )

In [4]:
def pad_without_fast_tokenizer_warning(tokenizer, *pad_args, **pad_kwargs):
    """
    Pads without triggering the warning about how using the pad function is sub-optimal when using a fast tokenizer.
    """

    # To avoid errors when using Feature extractors
    if not hasattr(tokenizer, "deprecation_warnings"):
        return tokenizer.pad(*pad_args, **pad_kwargs)

    # Save the state of the warning, then disable it
    warning_state = tokenizer.deprecation_warnings.get("Asking-to-pad-a-fast-tokenizer", False)
    tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True

    try:
        padded = tokenizer.pad(*pad_args, **pad_kwargs)
    finally:
        # Restore the state of the warning.
        tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = warning_state

    return padded

@dataclass
class MergerDataCollator:
    tokenizer: PreTrainedTokenizerBase
    pad_to_multiple_of: Optional[int] = None
    return_tensors: str = "pt"

    def __call__(self, examples):
        """
        copied from DataCollatorForLanguageModeling
        examples: List[Union[List[int], Any, Dict[str, Any]]]
        """
        # Handle dict or lists with proper padding and conversion to tensor.
        if not isinstance(examples[0], Mapping):
            raise ValueError("Data collator only processes list of dictionaries.")

        inputs_ids = []
        data_sources = []
        for i in range(len(examples)):
            _ = examples[i].pop("attention_mask")
            inputs_ids.append({"input_ids": examples[i].pop("input_ids")})
            data_sources.append(examples[i].pop("data_source"))
            
        batch = pad_without_fast_tokenizer_warning(
            self.tokenizer, inputs_ids, return_tensors="pt", 
            pad_to_multiple_of=self.pad_to_multiple_of
        )

        labels = batch["input_ids"].clone()
        if self.tokenizer.pad_token_id is not None:
            labels[labels == self.tokenizer.pad_token_id] = -100
        batch["labels"] = labels

        # Handle data_source - convert to tensor
        batch["data_source"] = torch.tensor(
            [src for src in data_sources], dtype=torch.long
        )
        
        for key in examples[0]:
            if key in batch:
                raise ValueError(
                    f"`{key}` feature is collated. "
                    "Overriding it with its initial values is prohibitted."
                )
            else:
                batch[key] = [x[key] for x in examples]
        logger.info_once(f">>> Collator output keys: {batch.keys()}")
        return batch

In [5]:
def selective_logits_target(logits_components, data_source):
    """Select appropriate logits based on data source."""
    stacked_logits = torch.stack(logits_components)
    indices = data_source.unsqueeze(-1).unsqueeze(-1)
    return stacked_logits[indices]

class MergerTrainer(Trainer):
    """Custom trainer for merged model training."""
    
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        labels = inputs.pop("labels")
        data_source = inputs.pop("data_source")
        effective_idxs = (labels != -100).float().unsqueeze(dim=-1)
        
        outputs = model(**inputs)
        logits_merged = outputs["merger_outputs"].logits
        logits_components = [x.logits for x in outputs["components_outputs"]]

        # Compute target logits and KL divergence
        logits_target = selective_logits_target(logits_components, data_source)
        temperature = 1.0
        kl_fct = nn.KLDivLoss(reduction="none")
        diff = (
            kl_fct(
                F.log_softmax(logits_target / temperature, dim=-1),
                F.softmax(logits_merged / temperature, dim=-1)
            )
            * (temperature) ** 2
        )
        
        # Calculate final loss
        loss = (diff * effective_idxs).sum(dim=-1)
        loss = (loss / effective_idxs.sum(dim=1)).mean()

        return (loss, outputs) if return_outputs else loss

In [6]:
@dataclass
class Args:
    model_name: str = "..."  # You can replace this with any causal language model from HuggingFace
    dataset_name: str = "..."  # Replace with your dataset name (e.g., "your_username/your_dataset")
    train_split: str = "train"  # e.g., "train[:80%]" for an 80/20 train/validation split
    validation_split: str = None  # e.g., "train[80%:]"
    output_dir: str = "./trained_masks"
    per_device_train_batch_size: int = 2
    per_device_eval_batch_size: int = 8
    gradient_accumulation_steps: int = 16
    learning_rate: float = 3e-2
    num_train_epochs: int = 1
    save_steps: int = 50
    eval_steps: int = 5000
    logging_steps: int = 10
    logging_dir: str = "./trained_masks/logs"
    evaluation_strategy: str = "steps"
    report_to: str = None
    remove_unused_columns: bool = False
    logging_first_step: bool = True

In [7]:
# Initialize configuration
merge_config = MergerConfig(
    model_paths=[
        "nguyenthanhdo/llama32_smol_rewrite_50k",
        "nguyenthanhdo/llama32_smol_summarize_50k",
    ],
    mode="vector_input",
    constrain_mode="identity"
)

# Setup tokenizer and data processing
tokenizer = AutoTokenizer.from_pretrained(merge_config.model_paths[0])
tokenizer.pad_token = tokenizer.eos_token
data_processor = DataProcessor(tokenizer)
train_dataset = data_processor.load_dataset()
tokenized_dataset = train_dataset.map(
    data_processor.tokenize,
    remove_columns=["messages"]
)

# Initialize merger model
merger = NewMerger.from_pretrained(
    None,
    merge_config,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="flash_attention_2",
)
set_masks(merger.merger, strategy="uniform", factors=[0.95, 0.05])

2025-01-06 04:43:13,644 - INFO - Creating merger with dummy weights ...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Initializing masks: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 255/255 [00:33<00:00,  7.68it/s]
Setting up masks: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 255/255 [00:00<00:00, 36882.22it/s]


In [9]:
from safetensors.torch import load_file as safe_load_file
state_dict = safe_load_file("./trained_masks/model.safetensors")

In [15]:
trainable_count = 0
for k, v in state_dict.items():
    trainable_count += v.numel()
trainable_count

1503346

In [16]:
merger = NewMerger.from_pretrained(
    "./trained_masks",
    merge_config,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="flash_attention_2",
)

2025-01-06 04:48:59,510 - INFO - Creating merger with dummy weights ...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Initializing masks: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 255/255 [00:28<00:00,  9.02it/s]
2025-01-06 04:49:32,217 - INFO - Missing keys: ['models.0.model.embed_tokens.weight', 'models.0.model.layers.0.self_attn.q_proj.weight', 'models.0.model.layers.0.self_attn.k_proj.weight', 'models.0.model.layers.0.self_attn.v_proj.weight', 'models.0.model.layers.0.self_attn.o_proj.weight', 'models.0.model.layers.0.mlp.gate_proj.weight', 'models.0.model.layers.0.mlp.up_proj.weight', 'models.0.model.layers.0.mlp.down_proj.weight', 'models.0.model.layers.0.input_layernorm.weight', 'models.0.model.layers.0.post_attention_layernorm.weight', 'models.0.model.layers.1.self_attn.q_proj.weight', 'models.0.model.layers.1.self_attn.k_proj.weight', 'models.0.model.layers.1.self_attn.v_proj.weight', 'models.0.model.layers.1.self_attn.o_proj.weight', 'models.0.model.layers.

In [8]:
merger.save_pretrained("./trained_masks")

In [8]:
# Monitor memory usage
initial_memory = torch.cuda.memory_allocated()
logger.info(f"Initial GPU memory allocated: {initial_memory / 1024**3:.2f} GB")

gc.collect()
torch.cuda.empty_cache()

final_memory = torch.cuda.memory_allocated()
logger.info(f"Final GPU memory allocated: {final_memory / 1024**3:.2f} GB")
logger.info(f"Freed GPU memory: {(initial_memory - final_memory) / 1024**3:.2f} GB")

2025-01-06 03:16:05,405 - INFO - Initial GPU memory allocated: 0.00 GB
2025-01-06 03:16:05,549 - INFO - Final GPU memory allocated: 0.00 GB
2025-01-06 03:16:05,550 - INFO - Freed GPU memory: 0.00 GB


In [9]:
# Setup training arguments and data collator
args = Args()
training_args = TrainingArguments(
    output_dir=args.output_dir,
    per_device_train_batch_size=args.per_device_train_batch_size,
    per_device_eval_batch_size=args.per_device_eval_batch_size,
    gradient_accumulation_steps=args.gradient_accumulation_steps,
    learning_rate=args.learning_rate,
    num_train_epochs=args.num_train_epochs,
    save_steps=args.save_steps,
    evaluation_strategy=args.evaluation_strategy if args.validation_split else "no",
    eval_steps=args.eval_steps if args.validation_split else None,
    logging_steps=args.logging_steps,
    logging_dir=args.logging_dir,
    report_to=args.report_to,  # Enable TensorBoard logging
    remove_unused_columns=args.remove_unused_columns,
    logging_first_step=args.logging_first_step,
    save_safetensors=False
)

data_collator = MergerDataCollator(
    tokenizer,
    pad_to_multiple_of=8,
    return_tensors="pt"
)

# Initialize and start training
trainer = MergerTrainer(
    model=merger,
    args=training_args,
    train_dataset=tokenized_dataset,
    eval_dataset=None,
    data_collator=data_collator,
)



In [None]:
def _save(self, output_dir: Optional[str] = None, state_dict=None):
    """
    Every save calls point back to this function.
    _save_checkpoint -> save_model -> _save.
    This function also calls back to .save_pretrained(), so basically
    I only have to customize .save_pretrained()
    """
    # If we are executing this function, we are the process zero, so we don't check for that.
    output_dir = output_dir if output_dir is not None else self.args.output_dir
    os.makedirs(output_dir, exist_ok=True)
    logger.info(f"Saving model checkpoint to {output_dir}")

    supported_classes = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)
    # Save a trained model and configuration using `save_pretrained()`.
    # They can then be reloaded using `from_pretrained()`
    if not isinstance(self.model, supported_classes):
        if state_dict is None:
            state_dict = self.model.state_dict()

        if isinstance(self.accelerator.unwrap_model(self.model), supported_classes):
            self.accelerator.unwrap_model(self.model).save_pretrained(
                output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
            )
        else:
            raise ValueError("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
    else:
        self.model.save_pretrained(
            output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
        )

    if self.processing_class is not None:
        self.processing_class.save_pretrained(output_dir)

    # Good practice: save your training arguments together with the trained model
    torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))


def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
    if model is None:
        model = self.model

    config_file = os.path.join(resume_from_checkpoint, CONFIG_NAME)
    weights_file = os.path.join(resume_from_checkpoint, WEIGHTS_NAME)
    weights_index_file = os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
    safe_weights_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_NAME)
    safe_weights_index_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_INDEX_NAME)
    is_fsdp_ckpt = False
    
    if not (
        any(
            os.path.isfile(f)
            for f in [
                weights_file,
                safe_weights_file,
                weights_index_file,
                safe_weights_index_file,
                adapter_weights_file,
                adapter_safe_weights_file,
            ]
        )
    ):
        raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")

    logger.info(f"Loading model from {resume_from_checkpoint}.")

    if os.path.isfile(config_file):
        config = PretrainedConfig.from_json_file(config_file)
        checkpoint_version = config.transformers_version
        if checkpoint_version is not None and checkpoint_version != __version__:
            logger.warning(
                f"You are resuming training from a checkpoint trained with {checkpoint_version} of "
                f"Transformers but your current version is {__version__}. This is not recommended and could "
                "yield to errors or unwanted behaviors."
            )

    if os.path.isfile(weights_file) or os.path.isfile(safe_weights_file) or is_fsdp_ckpt:
        weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
        # If the model is on the GPU, it still works!
        # We load the model state dict on the CPU to avoid an OOM error.
        if self.args.save_safetensors and os.path.isfile(safe_weights_file):
            state_dict = safetensors.torch.load_file(safe_weights_file, device="cpu")
        else:
            state_dict = torch.load(
                weights_file,
                map_location="cpu",
                **weights_only_kwarg,
            )

        # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
        # which takes *args instead of **kwargs
        load_result = model.load_state_dict(state_dict, False)
        # release memory
        del state_dict
        self._issue_warnings_after_load(load_result)
        
    else:
        # We load the sharded checkpoint
        load_result = load_sharded_checkpoint(
            model, resume_from_checkpoint, 
            strict=is_sagemaker_mp_enabled(), 
            prefer_safe=self.args.save_safetensors
        )
        if not is_sagemaker_mp_enabled():
            self._issue_warnings_after_load(load_result)


In [10]:
trainer.train()

2025-01-06 03:16:16,773 - INFO - >>> Collator output keys: dict_keys(['input_ids', 'attention_mask', 'labels', 'data_source'])


Step,Training Loss
1,0.0028
10,0.0017
20,0.0012
30,0.0012
40,0.0011
50,0.0011
60,0.0011
70,0.0011
80,0.001
90,0.0011


KeyboardInterrupt: 

In [55]:
trainer.model.merger.model.layers[8].mlp.up_proj.get_constrained_masks()

{'weight_masks': [Parameter containing:
  tensor([1.0078, 0.9219, 1.2891,  ..., 0.9766, 1.0781, 1.4062], device='cuda:0',
         dtype=torch.bfloat16, requires_grad=True),
  Parameter containing:
  tensor([0.5273, 0.2832, 0.5859,  ..., 0.1475, 0.4023, 0.7812], device='cuda:0',
         dtype=torch.bfloat16, requires_grad=True)],
 'bias_masks': [None, None]}

In [23]:
train_dataset[0]['messages']

[{'content': 'Provide a concise, objective summary of the input text in up to three sentences, focusing on key actions and intentions without using second or third person pronouns.',
  'role': 'system'},
 {'content': "By . Ted Thornhill . A mom may have saved her teenage son's life by spying on his Facebook page, as she found death threats on it and alerted police. The concerned parent, from Salt Lake City, called the authorities when she found threats to shoot her son - who attends West High School – had been posted on his profile page. The threats were allegedly made by two male teenagers, 16 and 17, who police arrested when they were found waiting in a car near the school on Friday. Potential life-saver: A mother of a West High School pupil alerted police after she saw threats to her son's life had been made on his Facebook page . Police said they found a gun, loaded magazine, ammunition, cash, marijuana and a bong inside the car. Salt Lake police detective Greg Wilking told Deseret

In [26]:
idx = 4
system = train_dataset[idx]['messages'][0]['content']
prompt = train_dataset[idx]['messages'][1]['content']
messages = [
    {"role": "system", "content": system},
    {"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)

In [27]:
answer = generate(text, trainer.model.merger, tokenizer)

Emily met Tom at the parent-teacher conference and wants to discuss teaching methods and offers to give a guest lecture on the history of Pine Grove School. She also invites Tom to meet for coffee to talk more. Emily is working on a book about the history of education in colonial America.<|end_of_text|>


In [28]:
rewrite_train = load_dataset("HuggingFaceTB/smoltalk", "smol-rewrite", split="train")

In [33]:
idx = 300
system = rewrite_train[idx]['messages'][0]['content']
prompt = rewrite_train[idx]['messages'][1]['content']
messages = [
    {"role": "system", "content": system},
    {"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
print(text)

<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You're an AI assistant for text re-writing. Rewrite the input text to make it more concise while preserving its core meaning.<|eot_id|><|start_header_id|>user<|end_header_id|>

Hey Michael,

I hope you're doing well! I wanted to touch base with you regarding the progress on our nutrition curriculum project. I've been working on the lesson plans for grades 3-5 and have made some great strides. I'd love to get your feedback on what I've put together so far.

Also, I wanted to share a new recipe I tried out this weekend - a quinoa and black bean salad that was a hit with my family. I thought you might enjoy it too, given our mutual love for healthy cooking. I'll attach the recipe below.

Finally, I've been giving some thought to pursuing a Master's degree in Nutrition Education. I know you've been in the field for a while now, and I was hoping to get your advice on the best path forward. If you have any insights or recommendatio

In [39]:
answer = generate(text, trainer.model.merger, tokenizer)

Emily is sharing updates on the nutrition curriculum project and seeking feedback on the lesson plans for grades 3-5. She also shares a new quinoa and black bean salad recipe and asks for advice on pursuing a Master's degree in Nutrition Education. Emily looks forward to catching up soon.<|end_of_text|>


In [52]:
trainer.model.merger.save_pretrained("./trained_masks", safe_serialization=False)

In [54]:
trainer.model

NewMerger(
  (models): ModuleList(
    (0-1): 2 x LlamaForCausalLM(
      (model): LlamaModel(
        (embed_tokens): Embedding(128256, 3072)
        (layers): ModuleList(
          (0-27): 28 x LlamaDecoderLayer(
            (self_attn): LlamaFlashAttention2(
              (q_proj): Linear(in_features=3072, out_features=3072, bias=False)
              (k_proj): Linear(in_features=3072, out_features=1024, bias=False)
              (v_proj): Linear(in_features=3072, out_features=1024, bias=False)
              (o_proj): Linear(in_features=3072, out_features=3072, bias=False)
              (rotary_emb): LlamaRotaryEmbedding()
            )
            (mlp): LlamaMLP(
              (gate_proj): Linear(in_features=3072, out_features=8192, bias=False)
              (up_proj): Linear(in_features=3072, out_features=8192, bias=False)
              (down_proj): Linear(in_features=8192, out_features=3072, bias=False)
              (act_fn): SiLU()
            )
            (input_layernorm): L