In [1]:
import math
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union, Mapping
from abc import ABC, abstractmethod

import datasets
import torch
import torch.nn.functional as F
import numpy as np
import torch.nn as nn
import logging
import copy
import gc

from datasets import load_dataset
from tqdm import tqdm
from transformers import (
    PreTrainedTokenizerBase,
    PreTrainedModel,
    PretrainedConfig,
    AutoConfig,
    AutoModelForCausalLM,
    LlamaForCausalLM,
    LlamaConfig,
    Trainer,
    TrainingArguments,
    AutoTokenizer,
    HfArgumentParser,
    default_data_collator,
    is_torch_xla_available,
    set_seed,
)

from accurate_masks import (
# from efficient_masks import (
    MergerConfig,
    Merger,
    NewMerger,
    init_masks,
    set_masks
)

from utils import (
    generate, 
    get_hidden_states, 
    get_logits,
    free_memory
)
# Configure logger
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

In [2]:
import os
# Option 1: Set specific GPU devices
os.environ["CUDA_VISIBLE_DEVICES"] = "0" 

In [3]:
# rewrite_train = load_dataset("HuggingFaceTB/smoltalk", "smol-rewrite", split="train")
summarize_train = load_dataset("HuggingFaceTB/smoltalk", "smol-summarize", split="train")
# rewrite_train = rewrite_train.add_column(name="data_source", column=[0 for _ in rewrite_train])
summarize_train = summarize_train.add_column(name="data_source", column=[1 for _ in summarize_train])

# train_dataset = datasets.concatenate_datasets([rewrite_train, summarize_train])
# train_mini = train_dataset.shuffle().select(range(5000))
# train_mini.to_json("train_mini.jsonl")

In [8]:
# train_mini = load_dataset("json", data_files=["train_mini.jsonl"], split="train")

Generating train split: 0 examples [00:00, ? examples/s]

In [4]:
train_mini = summarize_train.shuffle(seed=42).select(range(5000))

In [5]:
train_mini

Dataset({
    features: ['messages', 'data_source'],
    num_rows: 5000
})

In [6]:
def pad_without_fast_tokenizer_warning(tokenizer, *pad_args, **pad_kwargs):
    """
    Pads without triggering the warning about how using the pad function is sub-optimal when using a fast tokenizer.
    """

    # To avoid errors when using Feature extractors
    if not hasattr(tokenizer, "deprecation_warnings"):
        return tokenizer.pad(*pad_args, **pad_kwargs)

    # Save the state of the warning, then disable it
    warning_state = tokenizer.deprecation_warnings.get("Asking-to-pad-a-fast-tokenizer", False)
    tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True

    try:
        padded = tokenizer.pad(*pad_args, **pad_kwargs)
    finally:
        # Restore the state of the warning.
        tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = warning_state

    return padded

@dataclass
class MergerDataCollator:
    tokenizer: PreTrainedTokenizerBase
    pad_to_multiple_of: Optional[int] = None
    return_tensors: str = "pt"

    def __call__(self, examples):
        """
        copied from DataCollatorForLanguageModeling
        examples: List[Union[List[int], Any, Dict[str, Any]]]
        """
        # Handle dict or lists with proper padding and conversion to tensor.
        if not isinstance(examples[0], Mapping):
            raise ValueError("Data collator only processes list of dictionaries.")

        inputs_ids = []
        data_sources = []
        for i in range(len(examples)):
            _ = examples[i].pop("attention_mask")
            inputs_ids.append({"input_ids": examples[i].pop("input_ids")})
            data_sources.append(examples[i].pop("data_source"))
            
        batch = pad_without_fast_tokenizer_warning(
            self.tokenizer, inputs_ids, return_tensors="pt", 
            pad_to_multiple_of=self.pad_to_multiple_of
        )

        labels = batch["input_ids"].clone()
        if self.tokenizer.pad_token_id is not None:
            labels[labels == self.tokenizer.pad_token_id] = -100
        batch["labels"] = labels

        # Handle data_source - convert to tensor
        batch["data_source"] = torch.tensor(
            [src for src in data_sources], dtype=torch.long
        )
        
        for key in examples[0]:
            if key in batch:
                raise ValueError(f"`{key}` feature is collated. Overriding it with its initial values is prohibitted.")
            else:
                batch[key] = [x[key] for x in examples]
        logger.info_once(f">>> Collator output keys: {batch.keys()}")
        return batch

In [7]:
merge_config = MergerConfig(
    model_paths = [
        "nguyenthanhdo/llama32_smol_rewrite_50k",
        "nguyenthanhdo/llama32_smol_summarize_50k",
        # "/workspace/HUB_LLM/Llama-3.2-3B-Instruct",
    ],
    mode = "vector_input",
    # mode = "scalar",
    constrain_mode = "01"
)
merge_config

MergerConfig {
  "constrain_mode": "01",
  "mode": "vector_input",
  "model_paths": [
    "nguyenthanhdo/llama32_smol_rewrite_50k",
    "nguyenthanhdo/llama32_smol_summarize_50k"
  ],
  "transformers_version": "4.46.3"
}

In [8]:
tokenizer = AutoTokenizer.from_pretrained(merge_config.model_paths[0])
tokenizer.pad_token = tokenizer.eos_token
data_collator = MergerDataCollator(
    tokenizer, pad_to_multiple_of=8, return_tensors="pt"
)

In [9]:
def tokenize(element):
    templated = tokenizer.apply_chat_template(
        element["messages"], tokenize=False, add_generation_prompt=False
    )
    outputs = tokenizer(
        templated,
        truncation=True,
        max_length=2048,
        add_special_tokens=False,
        # return_tensors="pt"
    )
    return outputs

In [10]:
tokenized_mini = train_mini.map(tokenize, remove_columns=["messages"])

In [11]:
tokenized_mini

Dataset({
    features: ['data_source', 'input_ids', 'attention_mask'],
    num_rows: 5000
})

In [12]:
merger = NewMerger.from_pretrained(
    None, 
    merge_config,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="flash_attention_2",
)

2025-01-04 15:02:00,564 - INFO - Creating merger with dummy weights ...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Initializing masks: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 255/255 [00:27<00:00,  9.11it/s]


In [33]:
# merger = merger.to(device="cuda:0", dtype=torch.bfloat16)

In [13]:
set_masks(merger.merger, strategy="uniform", factors=[0.95, 0.05])

Setting up masks: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 255/255 [00:00<00:00, 17887.50it/s]


In [14]:
def selective_logits_target(logits_components, data_source):
    # Shape: [num_components, batch_size, seq_len, vocab_size]
    stacked_logits = torch.stack(logits_components)

    # Expand data_source to broadcast properly: [batch_size] -> [batch_size, 1, 1]
    indices = data_source.unsqueeze(-1).unsqueeze(-1)

    # Select logits using advanced indexing
    # Result shape: [batch_size, seq_len, vocab_size]
    selected_logits = stacked_logits[indices]
    
    return selected_logits
    
class TestTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):

        ## Read inputs. Compute a simple mask `effective_idxs`
        ## to exclude non-trainable tokens (e.g., PAD tokens).
        labels = inputs.pop("labels")
        data_source = inputs.pop("data_source")

        effective_idxs = (labels != -100).float().unsqueeze(dim=-1)
        
        outputs = model(**inputs)
        logits_merged = outputs["merger_outputs"].logits
        logits_components = [x.logits for x in outputs["components_outputs"]]

        ## Compute targt logits
        logits_target = selective_logits_target(logits_components, data_source)

        ## Compute KL divergence
        temperature = 1.0
        kl_fct = nn.KLDivLoss(reduction="none")
        diff = (
            kl_fct(
                F.log_softmax(logits_target / temperature, dim=-1),
                F.softmax(logits_merged / temperature, dim=-1)
            )
            * (temperature) ** 2
        )
        
        ### Exclude non-trainable tokens from taking loss
        loss = (diff * effective_idxs).sum(dim=-1)
        loss = (loss / effective_idxs.sum(dim=1)).mean()

        return (loss, outputs) if return_outputs else loss

In [15]:
@dataclass
class Args:
    model_name: str = "..."  # You can replace this with any causal language model from HuggingFace
    dataset_name: str = "..."  # Replace with your dataset name (e.g., "your_username/your_dataset")
    train_split: str = "train"  # e.g., "train[:80%]" for an 80/20 train/validation split
    validation_split: str = None  # e.g., "train[80%:]"
    output_dir: str = "./trained_masks"
    per_device_train_batch_size: int = 2
    per_device_eval_batch_size: int = 8
    gradient_accumulation_steps: int = 16
    learning_rate: float = 3e-5
    num_train_epochs: int = 10
    save_steps: int = 1000
    eval_steps: int = 5000
    logging_steps: int = 10
    logging_dir: str = "./trained_masks/logs"
    evaluation_strategy: str = "steps"
    report_to: str = None
    remove_unused_columns: bool = False
    # label_names: list = None

args = Args()

# --- Training Arguments ---
training_args = TrainingArguments(
    output_dir=args.output_dir,
    per_device_train_batch_size=args.per_device_train_batch_size,
    per_device_eval_batch_size=args.per_device_eval_batch_size,
    gradient_accumulation_steps=args.gradient_accumulation_steps,
    learning_rate=args.learning_rate,
    num_train_epochs=args.num_train_epochs,
    save_steps=args.save_steps,
    evaluation_strategy=args.evaluation_strategy if args.validation_split else "no",
    eval_steps=args.eval_steps if args.validation_split else None,
    logging_steps=args.logging_steps,
    logging_dir=args.logging_dir,
    report_to=args.report_to,  # Enable TensorBoard logging
    remove_unused_columns=args.remove_unused_columns,
)

# --- Data Collator ---
data_collator = MergerDataCollator(tokenizer, pad_to_multiple_of=8, return_tensors="pt")

# --- Initialize Trainer ---
trainer = TestTrainer(
    model=merger,
    args=training_args,
    train_dataset=tokenized_mini,
    eval_dataset=None,
    data_collator=data_collator,
)



In [16]:
initial_memory = torch.cuda.memory_allocated()
logger.info(f"Initial GPU memory allocated: {initial_memory / 1024**3:.2f} GB")
gc.collect()
torch.cuda.empty_cache()

final_memory = torch.cuda.memory_allocated()
logger.info(f"Final GPU memory allocated: {final_memory / 1024**3:.2f} GB")

freed_memory = initial_memory - final_memory
logger.info(f"Freed GPU memory: {freed_memory / 1024**3:.2f} GB")

2025-01-04 15:02:59,533 - INFO - Initial GPU memory allocated: 11.97 GB
2025-01-04 15:02:59,840 - INFO - Final GPU memory allocated: 11.97 GB
2025-01-04 15:02:59,842 - INFO - Freed GPU memory: 0.00 GB


In [None]:
# --- Train the Model ---
trainer.train()

2025-01-04 15:03:03,370 - INFO - >>> Collator output keys: dict_keys(['input_ids', 'attention_mask', 'labels', 'data_source'])


Step,Training Loss
10,0.0033
20,0.0035
30,0.0037
40,0.0039
50,0.0036
60,0.0035
70,0.0031
80,0.0037
90,0.0033
100,0.0035
