In [1]:
import copy
import gc
import logging
import math
from abc import ABC, abstractmethod
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union

import datasets
import torch
import torch.nn.functional as F
import numpy as np
import torch.nn as nn

from datasets import load_dataset
from tqdm import tqdm
from transformers import (
    PreTrainedTokenizerBase,
    PreTrainedModel,
    PretrainedConfig,
    AutoConfig,
    AutoModelForCausalLM,
    LlamaForCausalLM,
    LlamaConfig,
    Trainer,
    TrainingArguments,
    AutoTokenizer,
    HfArgumentParser,
    default_data_collator,
    is_torch_xla_available,
    set_seed,
)
from transformers import (
    HfArgumentParser,
    TrainingArguments,
    Trainer
)

from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast
)

from constrained_masks import (
    MergerConfig,
    Merger,
    init_masks,
    set_masks
)

from utils import (
    generate,
    get_hidden_states,
    get_logits,
    free_memory
)
# Notebook-wide logging: INFO and above, each line prefixed with timestamp and level.
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)

In [4]:
def merger_forward(text, model, tokenizer):
    """
    Run one gradient-free forward pass of ``model`` on ``text``.

    The text is tokenized to PyTorch tensors, moved to ``model.device`` and fed to the
    model as keyword arguments. Side effect: ``model`` is switched to eval mode and left there.

    Returns:
        Whatever ``model(**encoded)`` returns.
    """
    encoded = tokenizer(text, return_tensors="pt").to(model.device)
    model.eval()
    no_grad_ctx = torch.no_grad()
    with no_grad_ctx:
        return model(**encoded)

In [5]:
def pad_without_fast_tokenizer_warning(tokenizer, *pad_args, **pad_kwargs):
    """
    Call ``tokenizer.pad(*pad_args, **pad_kwargs)`` with the fast-tokenizer padding warning silenced.

    The "Asking-to-pad-a-fast-tokenizer" entry of ``tokenizer.deprecation_warnings`` is
    flagged as True for the duration of the call, and its previous value (False when absent)
    is written back afterwards, even if ``pad`` raises. Objects without a
    ``deprecation_warnings`` attribute (e.g. feature extractors) are padded directly.
    """
    warning_key = "Asking-to-pad-a-fast-tokenizer"

    if hasattr(tokenizer, "deprecation_warnings"):
        previous_flag = tokenizer.deprecation_warnings.get(warning_key, False)
        tokenizer.deprecation_warnings[warning_key] = True
        try:
            return tokenizer.pad(*pad_args, **pad_kwargs)
        finally:
            # Put the flag back exactly as it was.
            tokenizer.deprecation_warnings[warning_key] = previous_flag

    # No warning registry (e.g. a feature extractor): nothing to silence.
    return tokenizer.pad(*pad_args, **pad_kwargs)

@dataclass
class MergerDataCollator:
    """
    Collate tokenized examples into a padded batch with causal-LM ``labels``.

    Adapted from ``transformers.DataCollatorForLanguageModeling`` (non-MLM path).
    """

    # Tokenizer whose ``pad`` method builds the batch.
    tokenizer: PreTrainedTokenizerBase
    # If set, the padded sequence length is rounded up to a multiple of this value.
    pad_to_multiple_of: Optional[int] = None
    # Tensor type requested from ``tokenizer.pad``; the label logic below uses torch ops,
    # so "pt" is the only value known to work.
    return_tensors: str = "pt"

    def __call__(self, examples):
        """
        Pad ``examples`` and attach ``labels``.

        Args:
            examples: list of mappings such as the output of ``tokenizer(...)``
                (at least "input_ids"; usually "attention_mask" as well).

        Returns:
            The padded batch plus "labels": a copy of "input_ids" in which padded
            positions are set to -100.
        """
        assert isinstance(examples[0], Mapping), "MergerDataCollator expects a list of dict-like examples"
        batch = pad_without_fast_tokenizer_warning(
            self.tokenizer,
            examples,
            return_tensors=self.return_tensors,  # honour the configured field
            pad_to_multiple_of=self.pad_to_multiple_of,
        )

        labels = batch["input_ids"].clone()
        attention_mask = batch.get("attention_mask")
        if attention_mask is not None:
            # Mask by position, not by token id: when pad_token == eos_token, matching
            # pad_token_id would also hide every genuine EOS token from the labels.
            labels[attention_mask == 0] = -100
        elif self.tokenizer.pad_token_id is not None:
            labels[labels == self.tokenizer.pad_token_id] = -100
        batch["labels"] = labels
        return batch

In [6]:
# Merge the two fine-tuned component checkpoints (optionally also the base model,
# e.g. "/workspace/HUB_LLM/Llama-3.2-3B-Instruct").
merge_config = MergerConfig(
    model_paths=[
        "/workspace/dont15/models/llama32_smol_rewrite_50k/",
        "/workspace/dont15/models/llama32_smol_summarize_50k/",
    ],
    mode="vector_input",  # alternative tried before: "scalar"
    constrain_mode="01",
)
merge_config

MergerConfig {
  "constrain_mode": "01",
  "mode": "vector_input",
  "model_paths": [
    "/workspace/dont15/models/llama32_smol_rewrite_50k/",
    "/workspace/dont15/models/llama32_smol_summarize_50k/"
  ],
  "transformers_version": "4.47.1"
}

In [7]:
# Tokenizer of the first component model; both components are assumed to share it — TODO confirm.
tokenizer = AutoTokenizer.from_pretrained(merge_config.model_paths[0])
# Reuse EOS as the padding token (presumably none is defined for this checkpoint).
tokenizer.pad_token = tokenizer.eos_token
data_collator = MergerDataCollator(tokenizer=tokenizer, pad_to_multiple_of=8, return_tensors="pt")

In [8]:
# Load the two SmolTalk subsets (presumably those the component models were fine-tuned on)
# and tag every row with its origin: "A" = smol-rewrite, "B" = smol-summarize.
SEED = 42
N_TRAIN_SAMPLES = 1000

rewrite_train = load_dataset("HuggingFaceTB/smoltalk", "smol-rewrite", split="train")
summarize_train = load_dataset("HuggingFaceTB/smoltalk", "smol-summarize", split="train")
# A constant list of the right length: iterating the dataset only to count rows
# would read and decode every example.
rewrite_train = rewrite_train.add_column(name="data_source", column=["A"] * len(rewrite_train))
summarize_train = summarize_train.add_column(name="data_source", column=["B"] * len(summarize_train))

train_dataset = datasets.concatenate_datasets([rewrite_train, summarize_train])
# Seeded shuffle so the training subset is identical across kernel restarts.
train_mini = train_dataset.shuffle(seed=SEED).select(range(N_TRAIN_SAMPLES))

In [9]:
def tokenize(element, max_length=4096):
    """
    Render one chat example with the tokenizer's chat template and tokenize it.

    Uses the module-level ``tokenizer``.

    Args:
        element: a dataset row with a "messages" entry (the chat turns).
        max_length: truncate the tokenized sequence to at most this many tokens.

    Returns:
        The tokenizer's encoding for the rendered text. ``return_tensors`` is not set,
        so values are plain lists that ``Dataset.map`` can store.
    """
    templated = tokenizer.apply_chat_template(
        element["messages"], tokenize=False, add_generation_prompt=False
    )
    # The rendered template presumably already contains the special tokens (e.g. BOS),
    # so don't let the tokenizer add them a second time.
    return tokenizer(
        templated,
        truncation=True,
        max_length=max_length,
        add_special_tokens=False,
    )

In [10]:
# Drop every original column (messages, data_source, ...). MergerDataCollator hands all
# remaining keys to tokenizer.pad, and with remove_unused_columns=False in the training
# arguments a string column such as data_source would reach it and fail tensor conversion.
tokenized_mini = train_mini.map(tokenize, remove_columns=train_mini.column_names)

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [11]:
# Build the merger from merge_config; the cell output shows checkpoint loading and mask
# initialization happening while this cell runs.
merger = Merger(merge_config)
# NOTE(review): __post_init__ is invoked explicitly. If Merger.__init__ (or a dataclass
# decorator) already calls it, initialization would run twice — confirm in constrained_masks.
merger.__post_init__()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Initializing masks: 100%|█████████████████████████████████████████████████████████████████████████████████████████████| 255/255 [01:23<00:00,  3.05it/s]
2024-12-27 04:03:12,665 - INFO - Initial GPU memory allocated: 0.00 GB
2024-12-27 04:03:13,002 - INFO - Final GPU memory allocated: 0.00 GB
2024-12-27 04:03:13,003 - INFO - Freed GPU memory: 0.00 GB


In [12]:
# Cast the merger to bfloat16 and place it on GPU index 4 in a single .to() call.
merger = merger.to(dtype=torch.bfloat16, device="cuda:4")

In [13]:
# Initialize the masks with the "uniform" strategy and equal 0.5 / 0.5 factors
# (presumably one factor per component model).
set_masks(merger.merger, factors=[0.5, 0.5], strategy="uniform")

Setting up masks: 100%|█████████████████████████████████████████████████████████████████████████████████████████████| 255/255 [00:00<00:00, 7234.25it/s]


In [14]:
# outputs = merger_forward(summarize_text, merger, tokenizer)

In [15]:
# outputs["merger_outputs"].logits[..., :-1, :].contiguous()

In [16]:
# outputs["merger_outputs"].logits == outputs["component_outputs"][0].logits

In [None]:
# Label value skipped by the loss (CrossEntropyLoss ignore_index below); MergerDataCollator
# writes -100 into padded label positions.
IGNORE_INDEX = -100

def entropy(logits):
    """
    Shannon entropy (in nats) of softmax(logits) over the last dimension.

    Args:
        logits: tensor of unnormalized log-probabilities; the last dim is the class/vocab dim.

    Returns:
        Tensor shaped like ``logits`` with the last dim reduced to size 1 (keepdim=True).
    """
    # Work in log space: softmax(...).log() turns probabilities that underflow to 0 into
    # -inf, and 0 * -inf = NaN. log_softmax stays finite for finite logits, and
    # exp(very negative) * (very negative) evaluates to 0 instead.
    log_probs = F.log_softmax(logits, -1)
    return -(log_probs.exp() * log_probs).sum(-1, keepdim=True)


def merge_op(logits_a, logits_b, weight_a, weight_b):
    """
    Linear merge of two logit tensors: ``weight_a * logits_a + weight_b * logits_b``.

    Works for any operands supporting ``*`` and ``+`` (tensors broadcast as usual).
    """
    scaled_a = weight_a * logits_a
    scaled_b = weight_b * logits_b
    return scaled_a + scaled_b


def get_entropy_weights(logits_a, logits_b, eps=1e-6):
    """
    Inverse-entropy mixing weights for two distributions, normalized to sum to 1.

    Mathematically weight_a = (1/H_a) / (1/H_a + 1/H_b) = H_b / (H_a + H_b), and
    symmetrically for b, so the lower-entropy (more confident) input gets the larger weight.

    The direct 1/H form produced inf/inf = NaN whenever one entropy was 0 (a one-hot
    prediction). The equivalent H_b / (H_a + H_b) form with a small ``eps`` stays finite,
    and falls back to 0.5 / 0.5 when both entropies are 0.

    Args:
        logits_a, logits_b: logits whose softmax over the last dim defines each distribution.
        eps: small smoothing constant added to each entropy.

    Returns:
        (weight_a, weight_b), each shaped like ``entropy(...)``'s output so they broadcast
        against the logits.
    """
    entropy_a = entropy(logits_a)
    entropy_b = entropy(logits_b)

    den = entropy_a + entropy_b + 2 * eps
    weight_a = (entropy_b + eps) / den
    weight_b = (entropy_a + eps) / den

    return weight_a, weight_b
    

def compute_kl_loss(logits_merged, logits_components, mask=None):
    """
    KL(merged || target), where the target is the entropy-weighted mix of two component logits.

    Args:
        logits_merged: logits of the merged model; softmax is taken over the last dim.
        logits_components: sequence of exactly two logit tensors shaped like ``logits_merged``
            (assumed (batch, seq, vocab) — TODO confirm).
        mask: optional tensor shaped like ``logits_merged`` without its last dim; nonzero/True
            marks positions that count (e.g. ``attention_mask`` or ``labels != IGNORE_INDEX``).
            When None every position counts, as before.

    Returns:
        Scalar tensor: the per-position KL averaged over the (masked) positions.
    """
    assert len(logits_components) == 2
    logits_a, logits_b = logits_components
    weight_a, weight_b = get_entropy_weights(logits_a, logits_b)

    logits_target = merge_op(logits_a, logits_b, weight_a, weight_b)

    # log_softmax rather than softmax(...).log(): a probability that underflows to 0 would
    # give log(0) = -inf and 0 * -inf = NaN, poisoning the whole loss.
    log_probs_merged = F.log_softmax(logits_merged, -1)
    log_probs_target = F.log_softmax(logits_target, -1)
    kl = (log_probs_merged.exp() * (log_probs_merged - log_probs_target)).sum(-1)

    if mask is None:
        return kl.mean()
    # Accumulate in float32 so the token sum/count stay accurate when the logits are bf16.
    weights = mask.to(device=kl.device, dtype=torch.float32)
    return (kl.float() * weights).sum() / weights.sum().clamp_min(1.0)


def compute_clm_loss(model, inputs, return_outputs=False):
    """
    Causal-LM (next-token) cross-entropy loss.

    Args:
        model: callable returning an object with a ``.logits`` tensor
            (assumed (batch, seq, vocab) — TODO confirm for the model passed in).
        inputs: mapping of model inputs that also contains "labels", where positions equal
            to IGNORE_INDEX are skipped (e.g. the output of MergerDataCollator).
            The mapping is not modified.
        return_outputs: if True, also return the raw model outputs.

    Returns:
        ``loss``, or ``(loss, outputs)`` when ``return_outputs=True``.

    Raises:
        KeyError: if ``inputs`` has no "labels".
    """
    labels = inputs["labels"]
    # Build a filtered copy instead of inputs.pop(...), which removed "labels" from the
    # caller's dict as a side effect.
    model_inputs = {key: value for key, value in inputs.items() if key != "labels"}
    outputs = model(**model_inputs)
    logits = outputs.logits

    # Position t predicts token t+1: drop the last logit and the first label so they line up.
    shift_logits = logits[..., :-1, :]
    shift_labels = labels[..., 1:]

    loss = F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
        ignore_index=IGNORE_INDEX,
    )

    return (loss, outputs) if return_outputs else loss

# --- Custom Trainer with Custom compute_loss ---
class Mergerrainer(Trainer):
    """
    Trainer that optimizes the merger with a KL objective instead of the usual LM loss.

    (The name is kept as-is because the trainer-construction cell refers to it.)
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        KL loss between the merged model and the entropy-weighted mix of its components.

        Args:
            model: the merger; its outputs must expose "merger_outputs" (with ``.logits``)
                and a list of per-component outputs (each with ``.logits``).
            inputs: collated batch, forwarded unchanged to ``model``.
            return_outputs: also return the model outputs.
            num_items_in_batch: passed by ``Trainer.training_step`` in transformers >= 4.46;
                accepted so that call does not raise TypeError. Unused here because the KL
                loss is already a per-position mean.

        Returns:
            ``loss``, or ``(loss, outputs)`` when ``return_outputs=True``.
        """
        outputs = model(**inputs)
        logits_merged = outputs["merger_outputs"].logits
        logits_components = [x.logits for x in self._get_component_outputs(outputs)]

        loss = compute_kl_loss(logits_merged, logits_components)
        return (loss, outputs) if return_outputs else loss

    @staticmethod
    def _get_component_outputs(outputs):
        """Return the per-component outputs from the merger's output mapping."""
        # NOTE(review): an earlier exploratory cell indexes "component_outputs" while this
        # trainer used "components_outputs"; accept either and fail loudly if neither exists.
        for key in ("components_outputs", "component_outputs"):
            try:
                return outputs[key]
            except KeyError:
                continue
        raise KeyError(
            "merger outputs contain neither 'components_outputs' nor 'component_outputs'"
        )

In [None]:
# --- Configuration ---
# NOTE: MODEL_NAME / DATASET_NAME / TRAIN_SPLIT / VALIDATION_SPLIT are leftovers from a
# generic template and are not used below; the merger and data come from earlier cells.
MODEL_NAME = "..."
DATASET_NAME = "..."
TRAIN_SPLIT = "train"
VALIDATION_SPLIT = "validation"
OUTPUT_DIR = "./trained_masks"
PER_DEVICE_TRAIN_BATCH_SIZE = 8
PER_DEVICE_EVAL_BATCH_SIZE = 8
GRADIENT_ACCUMULATION_STEPS = 2
LEARNING_RATE = 2e-5
NUM_TRAIN_EPOCHS = 3
SAVE_STEPS = 500
EVAL_STEPS = 500
LOGGING_STEPS = 100

# No held-out set is built in this notebook yet; assign a tokenized dataset here to enable eval.
EVAL_DATASET = None
DO_EVAL = EVAL_DATASET is not None

# --- Training Arguments ---
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,
    per_device_eval_batch_size=PER_DEVICE_EVAL_BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    learning_rate=LEARNING_RATE,
    num_train_epochs=NUM_TRAIN_EPOCHS,
    save_steps=SAVE_STEPS,
    # Evaluation must follow the eval set actually passed to the Trainer: enabling it with
    # eval_dataset=None makes Trainer raise at construction. `eval_strategy` replaces the
    # deprecated `evaluation_strategy` argument.
    eval_strategy="steps" if DO_EVAL else "no",
    eval_steps=EVAL_STEPS if DO_EVAL else None,
    logging_steps=LOGGING_STEPS,
    logging_dir=f"{OUTPUT_DIR}/logs",
    report_to="tensorboard",  # Enable TensorBoard logging
    # Keep every dataset column; the default would drop columns the model's forward()
    # does not accept.
    remove_unused_columns=False,
)

# --- Data Collator ---
data_collator = MergerDataCollator(tokenizer, pad_to_multiple_of=8, return_tensors="pt")

# --- Initialize Trainer ---
trainer = Mergerrainer(
    model=merger,
    args=training_args,
    train_dataset=tokenized_mini,
    eval_dataset=EVAL_DATASET,
    data_collator=data_collator,
)


In [None]:
# --- Train the Model ---
# Run the training loop; the optimized loss is the KL objective returned by
# Mergerrainer.compute_loss.
trainer.train()



In [None]:
# --- Save the Model and Tokenizer ---
# Persist the trained merger (Trainer.save_model defaults to training_args.output_dir,
# i.e. OUTPUT_DIR) and write the tokenizer alongside it.
trainer.save_model()
tokenizer.save_pretrained(save_directory=OUTPUT_DIR)
print(f"Model and tokenizer saved to {OUTPUT_DIR}")