In [1]:
import math
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union, Mapping
from abc import ABC, abstractmethod

import datasets
import torch
import torch.nn.functional as F
import numpy as np
import torch.nn as nn
import logging
import copy
import gc

from datasets import load_dataset
from tqdm import tqdm
from transformers import (
    PreTrainedTokenizerBase,
    PreTrainedModel,
    PretrainedConfig,
    AutoConfig,
    AutoModelForCausalLM,
    LlamaForCausalLM,
    LlamaConfig,
    Trainer,
    TrainingArguments,
    AutoTokenizer,
    HfArgumentParser,
    default_data_collator,
    is_torch_xla_available,
    set_seed,
)
from transformers import (
    HfArgumentParser,
    TrainingArguments,
    Trainer
)

from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast
)

from efficient_masks import (
    MergerConfig,
    Merger,
    init_masks,
    set_masks
)

from utils import (
    generate, 
    get_hidden_states, 
    get_logits,
    free_memory
)
# Configure logger
# Root logging is set to INFO with a "timestamp - level - message" layout; the
# module-level `logger` below is what the rest of the notebook logs through.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

In [88]:
# rewrite_train = load_dataset("HuggingFaceTB/smoltalk", "smol-rewrite", split="train")
# summarize_train = load_dataset("HuggingFaceTB/smoltalk", "smol-summarize", split="train")
# rewrite_train = rewrite_train.add_column(name="data_source", column=[0 for _ in rewrite_train])
# summarize_train = summarize_train.add_column(name="data_source", column=[1 for _ in summarize_train])

# train_dataset = datasets.concatenate_datasets([rewrite_train, summarize_train])
# train_mini = train_dataset.shuffle().select(range(5000))
# train_mini.to_json("train_mini.jsonl")

Creating json from Arrow format:   0%|          | 0/5 [00:00<?, ?ba/s]

10973663

In [2]:
# Load the local JSONL sample as a single "train" split.
# NOTE(review): the file is presumably the one written by the commented-out
# cell above (`train_mini.to_json("train_mini.jsonl")`) -- confirm it exists.
train_mini = load_dataset("json", data_files=["train_mini.jsonl"], split="train")

In [26]:
def pad_without_fast_tokenizer_warning(tokenizer, *pad_args, **pad_kwargs):
    """
    Call ``tokenizer.pad`` while muting the "fast tokenizer" padding advisory.

    Args:
        tokenizer: Any object exposing ``pad``. When it also exposes a
            ``deprecation_warnings`` mapping, the advisory flag is switched on
            for the duration of the call and put back afterwards.
        *pad_args, **pad_kwargs: Forwarded untouched to ``tokenizer.pad``.

    Returns:
        Whatever ``tokenizer.pad`` returns.
    """
    flag = "Asking-to-pad-a-fast-tokenizer"

    if hasattr(tokenizer, "deprecation_warnings"):
        # Remember the current flag value, then mark the warning as already shown.
        previous = tokenizer.deprecation_warnings.get(flag, False)
        tokenizer.deprecation_warnings[flag] = True
        try:
            return tokenizer.pad(*pad_args, **pad_kwargs)
        finally:
            # Put the flag back even if padding raised.
            tokenizer.deprecation_warnings[flag] = previous

    # Objects without the warning registry (e.g. feature extractors) are padded directly.
    return tokenizer.pad(*pad_args, **pad_kwargs)

@dataclass
class MergerDataCollator:
    """
    Collate tokenized examples into a padded batch for merger training.

    Each example must be a mapping with ``input_ids`` and ``data_source``; an
    ``attention_mask`` entry is accepted and ignored (the mask is rebuilt by
    the tokenizer while padding). Any other keys are passed through as plain
    Python lists.

    Attributes are documented inline below.
    """

    # Tokenizer used for padding; its ``pad`` method is called on the input ids.
    tokenizer: PreTrainedTokenizerBase
    # Forwarded to ``tokenizer.pad`` so sequence lengths are rounded up to a multiple.
    pad_to_multiple_of: Optional[int] = None
    # NOTE: kept for signature parity with the Hugging Face collators. Padding
    # below always requests PyTorch tensors because the label handling uses torch.
    return_tensors: str = "pt"

    # Keys that this collator consumes itself and therefore never passes through.
    _CONSUMED_KEYS = ("input_ids", "attention_mask", "data_source")

    def __call__(self, examples):
        """
        Build a batch from a list of example mappings.

        Args:
            examples: Non-empty list of mappings, each holding ``input_ids``
                and ``data_source``.

        Returns:
            The padded batch returned by ``tokenizer.pad`` with two extra
            entries: ``labels`` (a copy of ``input_ids`` with padded positions
            set to -100) and ``data_source`` (a ``torch.long`` tensor).

        Raises:
            ValueError: If ``examples`` is empty, is not a list of mappings,
                or carries an extra key that clashes with a collated one.
        """
        if not examples:
            raise ValueError("Data collator received an empty list of examples.")
        if not isinstance(examples[0], Mapping):
            raise ValueError("Data collator only processes list of dictionaries.")

        # Read the fields without popping so the caller's dicts stay intact.
        features = [{"input_ids": example["input_ids"]} for example in examples]
        data_sources = [example["data_source"] for example in examples]

        batch = pad_without_fast_tokenizer_warning(
            self.tokenizer, features, return_tensors="pt",
            pad_to_multiple_of=self.pad_to_multiple_of
        )

        labels = batch["input_ids"].clone()
        if "attention_mask" in batch:
            # Mask by position rather than by token id: when the pad token is
            # the same id as EOS, an id comparison would also hide real EOS tokens.
            labels[batch["attention_mask"] == 0] = -100
        elif self.tokenizer.pad_token_id is not None:
            # Fallback for tokenizers that do not return an attention mask.
            labels[labels == self.tokenizer.pad_token_id] = -100
        batch["labels"] = labels

        # Handle data_source - convert to tensor
        batch["data_source"] = torch.tensor(data_sources, dtype=torch.long)

        # Pass every remaining feature through untouched, as a list per key.
        for key in examples[0]:
            if key in self._CONSUMED_KEYS:
                continue
            if key in batch:
                raise ValueError(f"`{key}` feature is collated. Overriding it with its initial values is prohibitted.")
            batch[key] = [example[key] for example in examples]
        logger.info(f">>> Collator output keys: {batch.keys()}")
        return batch

In [4]:
# Merge configuration: two fine-tuned checkpoints (one trained on rewriting,
# one on summarisation, going by their names) are the components to merge.
# NOTE(review): the meaning of `mode` and `constrain_mode` is defined in
# `efficient_masks.MergerConfig`, which is not visible here -- confirm there.
merge_config = MergerConfig(
    model_paths = [
        "nguyenthanhdo/llama32_smol_rewrite_50k",
        "nguyenthanhdo/llama32_smol_summarize_50k",
        # "/workspace/HUB_LLM/Llama-3.2-3B-Instruct",
    ],
    mode = "vector_input",
    # mode = "scalar",
    constrain_mode = "01"
)
# Bare expression so the notebook displays the resulting config.
merge_config

MergerConfig {
  "constrain_mode": "01",
  "mode": "vector_input",
  "model_paths": [
    "nguyenthanhdo/llama32_smol_rewrite_50k",
    "nguyenthanhdo/llama32_smol_summarize_50k"
  ],
  "transformers_version": "4.46.3"
}

In [27]:
# The tokenizer is taken from the first component checkpoint.
tokenizer = AutoTokenizer.from_pretrained(merge_config.model_paths[0])
# Reuse EOS as the pad token. After this the pad id equals the EOS id, so any
# code that recognises padding by token id will also match genuine EOS tokens.
tokenizer.pad_token = tokenizer.eos_token
# Collator instance used for the smoke test further down.
data_collator = MergerDataCollator(
    tokenizer, pad_to_multiple_of=8, return_tensors="pt"
)

In [6]:
def tokenize(element):
    """
    Render one chat example through the chat template and tokenize the text.

    Relies on the notebook-level ``tokenizer``.

    Args:
        element: Mapping with a ``"messages"`` entry in the format expected by
            ``tokenizer.apply_chat_template``.

    Returns:
        The tokenizer output for the rendered conversation, truncated to at
        most 4096 tokens.
    """
    chat_text = tokenizer.apply_chat_template(
        element["messages"],
        tokenize=False,
        add_generation_prompt=False,
    )
    # add_special_tokens is off here -- presumably because the rendered
    # template already contains them; verify against the template in use.
    return tokenizer(
        chat_text,
        truncation=True,
        max_length=4096,
        add_special_tokens=False,
    )

In [7]:
# Tokenize every row and drop the raw "messages" column; the remaining columns
# (see the dataset repr below) are data_source, input_ids and attention_mask.
tokenized_mini = train_mini.map(tokenize, remove_columns=["messages"])

Map:   0%|          | 0/5000 [00:00<?, ? examples/s]

In [8]:
tokenized_mini

Dataset({
    features: ['data_source', 'input_ids', 'attention_mask'],
    num_rows: 5000
})

In [9]:
# Build the merger model from the config.
merger = Merger(merge_config)
# NOTE(review): __post_init__ is invoked by hand here; confirm that Merger's
# constructor does not already run it, otherwise initialisation happens twice.
merger.__post_init__()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Initializing masks: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 255/255 [01:31<00:00,  2.79it/s]
2025-01-02 10:38:38,568 - INFO - Initial GPU memory allocated: 0.00 GB
2025-01-02 10:38:38,968 - INFO - Final GPU memory allocated: 0.00 GB
2025-01-02 10:38:38,972 - INFO - Freed GPU memory: 0.00 GB


In [13]:
merger.dtype

torch.bfloat16

In [11]:
# Move the merger to the first GPU in bfloat16. The device is hard-coded to
# "cuda:0", so this cell requires a CUDA machine.
merger = merger.to(device="cuda:0", dtype=torch.bfloat16)

In [12]:
# Initialise the masks with the "uniform" strategy and one factor per component.
# NOTE(review): presumably this weights both components equally at 0.5 --
# confirm against `efficient_masks.set_masks`.
set_masks(merger.merger, strategy="uniform", factors=[0.5, 0.5])

Setting up masks: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 255/255 [00:00<00:00, 6013.09it/s]


In [19]:
def kl_torch(s_logits, t_logits, idx_masks=None, temperature=1.0):
    """
    Computes KL(s || t) with temperature scaling.

    Args:
        s_logits, t_logits: (batch_size, seq_len, vocab)
        idx_masks: optional (batch_size, seq_len) mask; non-zero entries mark
            the positions that count towards the loss (e.g. non-pad tokens).
            When omitted, every position counts.
        temperature: float

    Returns:
        Scalar KL divergence: summed over the vocabulary, then averaged over
        the positions that count. With ``idx_masks=None`` this is the mean
        over all batch_size * seq_len positions. If the mask selects no
        position at all the result is 0.

    Note:  This is only for reference. It is unused.
    """
    # KLDivLoss(input=log q, target=p) yields p * (log p - log q) elementwise,
    # so with target = softmax(s) and input = log_softmax(t) this is KL(s || t).
    kl_fct = nn.KLDivLoss(reduction="none")
    diff = (
        kl_fct(
            F.log_softmax(t_logits / temperature, dim=-1),
            F.softmax(s_logits / temperature, dim=-1)
        )
        * (temperature) ** 2
    )
    per_token = diff.sum(-1)
    if idx_masks is None:
        return per_token.mean()

    # Boolean mask: excluded positions contribute nothing to the numerator and
    # are not counted in the denominator. clamp_min avoids a 0/0 on empty masks.
    keep = idx_masks != 0
    return (per_token * keep).sum() / keep.sum().clamp_min(1)

In [14]:
# Report how much GPU memory a garbage-collection pass plus a CUDA cache flush
# gives back. The CUDA calls only run when a GPU is actually present, so the
# "not available" message is no longer followed by meaningless 0.00 GB figures.
if torch.cuda.is_available():
    initial_memory = torch.cuda.memory_allocated()
    logger.info(f"Initial GPU memory allocated: {initial_memory / 1024**3:.2f} GB")

    # Drop unreachable Python objects first so their tensors can be released,
    # then hand cached blocks back to the driver.
    gc.collect()
    torch.cuda.empty_cache()

    final_memory = torch.cuda.memory_allocated()
    logger.info(f"Final GPU memory allocated: {final_memory / 1024**3:.2f} GB")

    freed_memory = initial_memory - final_memory
    logger.info(f"Freed GPU memory: {freed_memory / 1024**3:.2f} GB")
else:
    logger.info("CUDA is not available. No GPU memory to free.")

2025-01-02 10:38:48,670 - INFO - Initial GPU memory allocated: 11.97 GB
2025-01-02 10:38:49,057 - INFO - Final GPU memory allocated: 11.97 GB
2025-01-02 10:38:49,059 - INFO - Freed GPU memory: 0.00 GB


In [28]:
# Smoke test: collate the first four tokenized rows and inspect the result.
collated = data_collator([tokenized_mini[i] for i in range(4)])

2025-01-02 10:50:26,966 - INFO - >>> Collator output keys: dict_keys(['input_ids', 'attention_mask', 'labels', 'data_source'])


In [44]:
collated.keys()

dict_keys(['input_ids', 'attention_mask', 'labels', 'data_source'])

In [29]:
# --- Custom Trainer with Custom compute_loss ---
def get_entropy_weights(logits_a, logits_b, epsilon=1e-8):
    """
    Calculates entropy-based weights for merging two sets of logits.

    Each set of logits is turned into a distribution over the last dimension;
    the weight of a set is proportional to the inverse of that distribution's
    entropy, so the more confident (lower-entropy) set gets the larger weight.

    Entropy is computed from ``log_softmax`` rather than ``softmax(...).log()``.
    When a probability underflows to exactly zero the latter evaluates
    ``0 * log(0) = 0 * -inf = NaN`` and poisons both weights; with
    ``log_softmax`` the log term stays finite for finite logits and the
    product is 0.

    Args:
        logits_a: A PyTorch tensor representing the first set of logits.
        logits_b: A PyTorch tensor representing the second set of logits.
        epsilon: A small value to prevent division by zero.

    Returns:
        A tuple containing two tensors: (weight_a, weight_b), representing the
        normalized entropy-based weights for logits_a and logits_b, respectively.
        Each keeps the input's leading dimensions with a trailing dimension of
        size 1, so it broadcasts against the logits.
    """

    def _entropy(logits):
        # H(p) = -sum p * log p, with log p taken directly from log_softmax.
        log_probs = F.log_softmax(logits, dim=-1)
        return -(log_probs.exp() * log_probs).sum(dim=-1, keepdim=True)

    entropy_a = _entropy(logits_a)
    entropy_b = _entropy(logits_b)

    # Calculate inverse entropies (weights); epsilon guards zero entropy.
    inv_entropy_a = 1.0 / (entropy_a + epsilon)
    inv_entropy_b = 1.0 / (entropy_b + epsilon)

    # Normalize weights
    total_inv_entropy = inv_entropy_a + inv_entropy_b
    weight_a = inv_entropy_a / (total_inv_entropy + epsilon)
    weight_b = inv_entropy_b / (total_inv_entropy + epsilon)

    return weight_a, weight_b
    
def compute_logits_target(logits_components):
    """
    Blend the logits of exactly two component models into one target.

    The two logit tensors are combined as a weighted sum whose weights come
    from ``get_entropy_weights``.

    Args:
        logits_components: Sequence holding exactly two logit tensors.

    Returns:
        The weighted sum of the two logit tensors.
    """
    assert len(logits_components) == 2
    first, second = logits_components
    w_first, w_second = get_entropy_weights(first, second)
    return w_first * first + w_second * second
    
class MergerTrainer(Trainer):
    """
    Trainer variant whose loss pulls the merged model's distribution towards an
    entropy-weighted blend of the component models' logits.

    NOTE(review): the loss lives in ``compute_loss_general`` rather than
    ``compute_loss``, so it does not override the base implementation as
    written -- confirm this is intentional.
    """

    def compute_loss_general(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        KL-based distillation loss between merged and blended component logits.

        Args:
            model: Callable returning a mapping with ``"merger_outputs"`` (an
                object with ``.logits``) and ``"components_outputs"`` (an
                iterable of objects with ``.logits``).
            inputs: Batch mapping. ``labels`` supplies the mask of trainable
                positions (anything other than -100); ``labels`` and, if
                present, ``data_source`` are not forwarded to the model.
            return_outputs: When true, return ``(loss, outputs)``.
            num_items_in_batch: Accepted for signature compatibility; unused.

        Returns:
            The scalar loss, or ``(loss, outputs)`` when ``return_outputs``.
        """
        ## Read inputs. `valid` marks the trainable tokens (non -100 labels),
        ## i.e. it excludes PAD positions from the loss.
        labels = inputs["labels"]
        valid = labels != -100

        ## Forward pass without the supervision-only keys. The batch mapping
        ## itself is left untouched.
        model_inputs = {
            key: value for key, value in inputs.items()
            if key not in ("labels", "data_source")
        }
        outputs = model(**model_inputs)
        logits_merged = outputs["merger_outputs"].logits
        logits_components = [x.logits for x in outputs["components_outputs"]]

        ## Compute target logits
        logits_target = compute_logits_target(logits_components)

        loss = self._masked_sequence_kl(logits_merged, logits_target, valid)

        return (loss, outputs) if return_outputs else loss

    @staticmethod
    def _masked_sequence_kl(logits_merged, logits_target, valid, temperature=1.0):
        """Per-sequence mean of the token-level KL over valid tokens, averaged over the batch."""
        kl_fct = nn.KLDivLoss(reduction="none")
        diff = (
            kl_fct(
                F.log_softmax(logits_target / temperature, dim=-1),
                F.softmax(logits_merged / temperature, dim=-1)
            )
            * (temperature) ** 2
        )
        # Reduce over the vocabulary first, so masking works on a small
        # (batch, seq) tensor; float32 keeps the reductions below accurate
        # when the logits are in half precision.
        per_token = diff.sum(dim=-1).float()
        valid = valid.to(per_token.device)

        # Both terms are reduced over the sequence axis before dividing, so
        # the shapes are (batch,) / (batch,). clamp_min(1) keeps a sequence
        # without any valid token from producing 0/0.
        token_counts = valid.sum(dim=-1).clamp_min(1)
        per_sequence = (per_token * valid).sum(dim=-1) / token_counts
        return per_sequence.mean()

In [43]:
12

12

In [39]:
def selective_logits_target(logits_components, data_source):
    """
    Pick, for every example in the batch, the logits of one component model.

    Args:
        logits_components: Sequence of tensors, one per component model, all
            of shape (batch_size, seq_len, vocab_size).
        data_source: Integer tensor (or sequence) of shape (batch_size,);
            entry ``i`` is the index of the component whose logits are used
            for example ``i``.

    Returns:
        Tensor of shape (batch_size, seq_len, vocab_size) where row ``i``
        equals ``logits_components[data_source[i]][i]``.

    Raises:
        ValueError: If ``data_source`` does not hold one index per example.
    """
    # Shape: [num_components, batch_size, seq_len, vocab_size]
    stacked_logits = torch.stack(list(logits_components))
    batch_size = stacked_logits.shape[1]

    component_idx = torch.as_tensor(
        data_source, dtype=torch.long, device=stacked_logits.device
    )
    if component_idx.shape != (batch_size,):
        raise ValueError(
            f"data_source must have shape ({batch_size},), got {tuple(component_idx.shape)}."
        )

    # Index the component axis and the batch axis together: pairing
    # component_idx[i] with i selects a single (seq_len, vocab_size) slice per
    # example. Indexing the component axis alone would return the whole batch
    # for every example.
    batch_idx = torch.arange(batch_size, device=stacked_logits.device)
    selected_logits = stacked_logits[component_idx, batch_idx]

    return selected_logits
    
class TestTrainer(Trainer):
    """
    Trainer whose loss pulls the merged model's distribution towards the logits
    of the component model selected by each example's ``data_source``.
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        KL-based distillation loss between merged and per-example target logits.

        Args:
            model: Callable returning a mapping with ``"merger_outputs"`` (an
                object with ``.logits``) and ``"components_outputs"`` (an
                iterable of objects with ``.logits``).
            inputs: Batch mapping. ``labels`` supplies the mask of trainable
                positions (anything other than -100) and ``data_source`` the
                component index per example; neither is forwarded to the model.
            return_outputs: When true, return ``(loss, outputs)``.
            num_items_in_batch: Accepted for signature compatibility; unused.

        Returns:
            The scalar loss, or ``(loss, outputs)`` when ``return_outputs``.
        """
        ## Read inputs. `valid` marks the trainable tokens (non -100 labels),
        ## i.e. it excludes PAD positions from the loss.
        labels = inputs["labels"]
        data_source = inputs["data_source"]
        valid = labels != -100

        ## Forward pass without the supervision-only keys. The batch mapping
        ## itself is left untouched.
        model_inputs = {
            key: value for key, value in inputs.items()
            if key not in ("labels", "data_source")
        }
        outputs = model(**model_inputs)
        logits_merged = outputs["merger_outputs"].logits
        logits_components = [x.logits for x in outputs["components_outputs"]]

        ## Compute target logits
        logits_target = selective_logits_target(logits_components, data_source)

        loss = self._masked_sequence_kl(logits_merged, logits_target, valid)

        return (loss, outputs) if return_outputs else loss

    @staticmethod
    def _masked_sequence_kl(logits_merged, logits_target, valid, temperature=1.0):
        """Per-sequence mean of the token-level KL over valid tokens, averaged over the batch."""
        kl_fct = nn.KLDivLoss(reduction="none")
        diff = (
            kl_fct(
                F.log_softmax(logits_target / temperature, dim=-1),
                F.softmax(logits_merged / temperature, dim=-1)
            )
            * (temperature) ** 2
        )
        # Reduce over the vocabulary before masking: the mask is then applied
        # to a (batch, seq) tensor instead of promoting the whole
        # (batch, seq, vocab) tensor to float32.
        per_token = diff.sum(dim=-1).float()
        valid = valid.to(per_token.device)

        # Both terms are reduced over the sequence axis before dividing, so
        # the shapes are (batch,) / (batch,). clamp_min(1) keeps a sequence
        # without any valid token from producing 0/0.
        token_counts = valid.sum(dim=-1).clamp_min(1)
        per_sequence = (per_token * valid).sum(dim=-1) / token_counts
        return per_sequence.mean()

In [40]:
@dataclass
class Args:
    """Plain container for the hyper-parameters and paths used to build `TrainingArguments` below."""

    model_name: str = "..."  # You can replace this with any causal language model from HuggingFace
    dataset_name: str = "..."  # Replace with your dataset name (e.g., "your_username/your_dataset")
    train_split: str = "train"  # e.g., "train[:80%]" for an 80/20 train/validation split
    validation_split: Optional[str] = None  # e.g., "train[80%:]"; None means no evaluation split
    output_dir: str = "./trained_masks"
    per_device_train_batch_size: int = 8
    per_device_eval_batch_size: int = 8
    gradient_accumulation_steps: int = 2
    learning_rate: float = 2e-5
    num_train_epochs: int = 3
    save_steps: int = 500
    eval_steps: int = 500
    logging_steps: int = 100
    logging_dir: str = "./trained_masks/logs"
    evaluation_strategy: str = "steps"  # only applied when validation_split is set (see the TrainingArguments call)
    report_to: Optional[str] = None  # passed straight to TrainingArguments(report_to=...)
    remove_unused_columns: bool = False  # keep extra dataset columns such as data_source
    # label_names: list = None

args = Args()

# --- Training Arguments ---
# Evaluation is switched on only when a validation split is configured.
# `eval_strategy` is the current name of the deprecated `evaluation_strategy`
# keyword of TrainingArguments.
training_args = TrainingArguments(
    output_dir=args.output_dir,
    per_device_train_batch_size=args.per_device_train_batch_size,
    per_device_eval_batch_size=args.per_device_eval_batch_size,
    gradient_accumulation_steps=args.gradient_accumulation_steps,
    learning_rate=args.learning_rate,
    num_train_epochs=args.num_train_epochs,
    save_steps=args.save_steps,
    eval_strategy=args.evaluation_strategy if args.validation_split else "no",
    eval_steps=args.eval_steps if args.validation_split else None,
    logging_steps=args.logging_steps,
    logging_dir=args.logging_dir,
    # NOTE(review): Args.report_to defaults to None, which does not by itself
    # select TensorBoard; set it to "tensorboard" or "none" explicitly to make
    # the reporting target unambiguous.
    report_to=args.report_to,
    remove_unused_columns=args.remove_unused_columns,
    # Both keys are treated as labels by the Trainer rather than model inputs.
    label_names=["labels", "data_source"]
)

# --- Data Collator ---
data_collator = MergerDataCollator(tokenizer, pad_to_multiple_of=8, return_tensors="pt")

# --- Initialize Trainer ---
trainer = TestTrainer(
    model=merger,
    args=training_args,
    train_dataset=tokenized_mini,
    eval_dataset=None,
    data_collator=data_collator,
)




In [41]:
def custom_prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
    """
    Run the batch through ``self._prepare_input`` and log its keys before and after.

    Written to be bound onto a trainer object (it takes ``self``).

    Args:
        inputs: The batch mapping as produced by the data collator.

    Returns:
        The prepared batch, with ``"mems"`` added when ``self.args.past_index``
        is non-negative and ``self._past`` is set.

    Raises:
        ValueError: If the prepared batch is empty.
    """
    incoming_keys = list(inputs.keys())
    logger.info(f">>> data input keys: {incoming_keys}")

    inputs = self._prepare_input(inputs)
    if not len(inputs):
        expected_columns = ','.join(self._signature_columns)
        raise ValueError(
            "The batch received was empty, your model won't be able to train on it. Double-check that your "
            f"training dataset contains keys expected by the model: {expected_columns}."
        )

    # Carry the cached state over only when the feature is enabled and a state exists.
    carry_past = self.args.past_index >= 0 and self._past is not None
    if carry_past:
        inputs["mems"] = self._past

    outgoing_keys = list(inputs.keys())
    logger.info(f">>> data output keys: {outgoing_keys}")
    return inputs

def no_remove_unused_columns(
    dataset: 'datasets.Dataset',
    description: Optional[str] = None,
):
    """
    Pass-through replacement for a column-pruning hook.

    Args:
        dataset: The dataset to hand back unchanged.
        description: Accepted for signature compatibility; not used.

    Returns:
        ``dataset`` itself, with every column intact.
    """
    notice = "Not removing any columns."
    logger.info(notice)
    return dataset

# Patch this trainer instance only. `__get__` binds the function so it receives
# the trainer as `self`; the column-removal hook is assigned unbound because it
# takes no `self`.
trainer._prepare_inputs = custom_prepare_inputs.__get__(trainer)
trainer._remove_unused_columns = no_remove_unused_columns

In [42]:
# --- Train the Model ---
# NOTE: the recorded run of this cell ended in a CUDA OutOfMemoryError during
# the forward pass (see the traceback below).
trainer.train()

2025-01-02 10:54:44,395 - INFO - Not removing any columns.
2025-01-02 10:54:44,570 - INFO - >>> Collator output keys: dict_keys(['input_ids', 'attention_mask', 'labels', 'data_source'])
2025-01-02 10:54:44,632 - INFO - >>> Collator output keys: dict_keys(['input_ids', 'attention_mask', 'labels', 'data_source'])
2025-01-02 10:54:44,761 - INFO - >>> Collator output keys: dict_keys(['input_ids', 'attention_mask', 'labels', 'data_source'])
2025-01-02 10:54:44,796 - INFO - >>> data input keys: ['input_ids', 'attention_mask', 'labels', 'data_source']
2025-01-02 10:54:44,798 - INFO - >>> data output keys: ['input_ids', 'attention_mask', 'labels', 'data_source']


OutOfMemoryError: Caught OutOfMemoryError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py", line 96, in _worker
    output = module(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/merge/efficient_masks.py", line 662, in forward
    merger_outputs = self.merger(**inputs)
                     ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 1190, in forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 945, in forward
    layer_outputs = decoder_layer(
                    ^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 673, in forward
    hidden_states = self.input_layernorm(hidden_states)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/merge/efficient_masks.py", line 412, in forward
    out = out + norm.weight * masked_input
          ~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 58.00 MiB. GPU 0 has a total capacity of 23.68 GiB of which 24.31 MiB is free. Process 149619 has 23.62 GiB memory in use. Of the allocated memory 22.77 GiB is allocated by PyTorch, and 464.24 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)


In [45]:
Trainer.compute_loss??

[0;31mSignature:[0m
[0mTrainer[0m[0;34m.[0m[0mcompute_loss[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mself[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mmodel[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0minputs[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mreturn_outputs[0m[0;34m=[0m[0;32mFalse[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mnum_items_in_batch[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mSource:[0m   
    [0;32mdef[0m [0mcompute_loss[0m[0;34m([0m[0mself[0m[0;34m,[0m [0mmodel[0m[0;34m,[0m [0minputs[0m[0;34m,[0m [0mreturn_outputs[0m[0;34m=[0m[0;32mFalse[0m[0;34m,[0m [0mnum_items_in_batch[0m[0;34m=[0m[0;32mNone[0m[0;34m)[0m[0;34m:[0m[0;34m[0m
[0;34m[0m        [0;34m"""[0m
[0;34m        How the loss is computed by Trainer. By default, all models return the loss in the first element.[0m
[0;34m[0m
[0;34m        Subclass and override for custom behavi

In [None]:
# --- Save the Model and Tokenizer ---
# The tokenizer is written to args.output_dir; save_model() is called without a
# path, so it uses the trainer's own configured output directory.
trainer.save_model()
tokenizer.save_pretrained(args.output_dir)
print(f"Model and tokenizer saved to {args.output_dir}")