In [18]:
import math
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union, Mapping
from abc import ABC, abstractmethod

import datasets
import torch
import torch.nn.functional as F
import numpy as np
import torch.nn as nn
import logging
import copy
import gc

from datasets import load_dataset
from tqdm import tqdm
from transformers import (
    PreTrainedTokenizerBase,
    PreTrainedModel,
    PretrainedConfig,
    AutoConfig,
    AutoModelForCausalLM,
    LlamaForCausalLM,
    LlamaConfig,
    Trainer,
    TrainingArguments,
    AutoTokenizer,
    HfArgumentParser,
    default_data_collator,
    is_torch_xla_available,
    set_seed,
)
from transformers import (
    HfArgumentParser,
    TrainingArguments,
    Trainer
)

from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast
)

from efficient_masks import (
    MergerConfig,
    Merger,
    init_masks,
    set_masks
)

from utils import (
    generate, 
    get_hidden_states, 
    get_logits,
    free_memory
)
# Configure logger
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

In [7]:
import torch
batch_size = 4
seq_len = 8
hidden_dim = 100

X = torch.rand(batch_size, seq_len, hidden_dim)
mask = torch.rand(batch_size, seq_len)

In [2]:
selective_mask = torch.tensor([1.0, 0.0, 0.0, 1.0])

In [12]:
selected_logits = []
for i in range(batch_size):
    selected = X[i, ...]
    selected_logits.append(selected)

In [18]:
torch.stack(selected_logits, dim=0).shape

torch.Size([4, 8, 100])

In [11]:
mask = torch.rand(batch_size, seq_len) * torch.tensor([[1.0], [0.0]])
mask[mask != 0] = 1.0

In [16]:
torch.tensor([[1.0], [0.0]]).shape

torch.Size([2, 1])

In [17]:
torch.rand(batch_size, seq_len).shape

torch.Size([2, 8])

In [2]:
# rewrite_train = load_dataset("HuggingFaceTB/smoltalk", "smol-rewrite", split="train")
# summarize_train = load_dataset("HuggingFaceTB/smoltalk", "smol-summarize", split="train")
# rewrite_train = rewrite_train.add_column(name="data_source", column=["A" for _ in rewrite_train])
# summarize_train = summarize_train.add_column(name="data_source", column=["B" for _ in summarize_train])

# train_dataset = datasets.concatenate_datasets([rewrite_train, summarize_train])
# train_mini = train_dataset.shuffle().select(range(5000))
# train_mini.to_json("train_mini.jsonl")

In [3]:
train_mini = load_dataset("json", data_files=["train_mini.jsonl"], split="train")

In [4]:
def merger_forward(texts, model, tokenizer, data_collator):
    tokenized = [tokenizer(
        text,
        truncation=True,
        max_length=4096,
        add_special_tokens=False,
    ) for text in texts]
    inputs = data_collator(tokenized).to(model.device)
    
    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs

In [5]:
def pad_without_fast_tokenizer_warning(tokenizer, *pad_args, **pad_kwargs):
    """
    Pads without triggering the warning about how using the pad function is sub-optimal when using a fast tokenizer.
    """

    # To avoid errors when using Feature extractors
    if not hasattr(tokenizer, "deprecation_warnings"):
        return tokenizer.pad(*pad_args, **pad_kwargs)

    # Save the state of the warning, then disable it
    warning_state = tokenizer.deprecation_warnings.get("Asking-to-pad-a-fast-tokenizer", False)
    tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True

    try:
        padded = tokenizer.pad(*pad_args, **pad_kwargs)
    finally:
        # Restore the state of the warning.
        tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = warning_state

    return padded

@dataclass
class MergerDataCollator:
    tokenizer: PreTrainedTokenizerBase
    pad_to_multiple_of: Optional[int] = None
    return_tensors: str = "pt"

    def __call__(self, examples):
        """
        copied from DataCollatorForLanguageModeling
        examples: List[Union[List[int], Any, Dict[str, Any]]]
        """
        # Handle dict or lists with proper padding and conversion to tensor.
        if not isinstance(examples[0], Mapping):
            raise ValueError("Data collator only processes list of dictionaries.")

        inputs_ids = []
        for i in range(len(examples)):
            _ = examples[i].pop("attention_mask")
            inputs_ids.append(
                {"input_ids": examples[i].pop("input_ids")}
            )
            
        batch = pad_without_fast_tokenizer_warning(
            self.tokenizer, inputs_ids, return_tensors="pt", 
            pad_to_multiple_of=self.pad_to_multiple_of
        )

        labels = batch["input_ids"].clone()
        if self.tokenizer.pad_token_id is not None:
            labels[labels == self.tokenizer.pad_token_id] = -100
        batch["labels"] = labels
        
        for key in examples[0]:
            if key in batch:
                raise ValueError(f"`{key}` feature is collated. Overriding it with its initial values is prohibitted.")
            else:
                batch[key] = [x[key] for x in examples]
        return batch

In [6]:
merge_config = MergerConfig(
    model_paths = [
        "nguyenthanhdo/llama32_smol_rewrite_50k",
        "nguyenthanhdo/llama32_smol_summarize_50k",
        # "/workspace/HUB_LLM/Llama-3.2-3B-Instruct",
    ],
    mode = "vector_input",
    # mode = "scalar",
    constrain_mode = "01"
)
merge_config

MergerConfig {
  "constrain_mode": "01",
  "mode": "vector_input",
  "model_paths": [
    "nguyenthanhdo/llama32_smol_rewrite_50k",
    "nguyenthanhdo/llama32_smol_summarize_50k"
  ],
  "transformers_version": "4.46.3"
}

In [7]:
tokenizer = AutoTokenizer.from_pretrained(merge_config.model_paths[0])
tokenizer.pad_token = tokenizer.eos_token
data_collator = MergerDataCollator(
    tokenizer, pad_to_multiple_of=8, return_tensors="pt"
)

In [8]:
def tokenize(element):
    templated = tokenizer.apply_chat_template(
        element["messages"], tokenize=False, add_generation_prompt=False
    )
    outputs = tokenizer(
        templated,
        truncation=True,
        max_length=4096,
        add_special_tokens=False,
        # return_tensors="pt"
    )
    return outputs

In [9]:
tokenized_mini = train_mini.map(tokenize, remove_columns=["messages"])

In [10]:
tokenized_mini

Dataset({
    features: ['data_source', 'input_ids', 'attention_mask'],
    num_rows: 5000
})

In [11]:
merger = Merger(merge_config)
merger.__post_init__()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Initializing masks: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 255/255 [00:50<00:00,  5.04it/s]
2024-12-30 10:05:23,898 - INFO - Initial GPU memory allocated: 0.00 GB
2024-12-30 10:05:24,168 - INFO - Final GPU memory allocated: 0.00 GB
2024-12-30 10:05:24,168 - INFO - Freed GPU memory: 0.00 GB


In [12]:
merger.dtype

torch.bfloat16

In [13]:
merger = merger.to(device="cuda:0", dtype=torch.bfloat16)

In [14]:
set_masks(merger.merger, strategy="uniform", factors=[0.5, 0.5])

Setting up masks: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 255/255 [00:00<00:00, 12809.87it/s]


In [15]:
IGNORE_INDEX = -100

def compute_clm_loss(model, inputs, return_outputs=False):
    """
    Custom compute_loss function for Causal Language Modeling.

    Args:
        model: The model to compute the loss for.
        inputs: A dictionary of inputs as produced by the `collate_fn`.
        return_outputs: Whether or not to return the model outputs in addition to the loss.

    Returns:
        The loss and the model outputs (if `return_outputs=True`).
    """

    labels = inputs.pop("labels")
    outputs = model(**inputs)
    logits = outputs.logits

    # Shift logits and labels for next token prediction
    # We shift the logits to the left and the labels to the right to align them for loss computation
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    # Flatten the tokens
    loss_fct = torch.nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)
    loss = loss_fct(
        shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
    )

    return (loss, outputs) if return_outputs else loss

In [57]:
def compute_kl_loss(logits_merged, logits_target):

    probs_merged = F.softmax(logits_merged, -1)

    d = probs_merged.log() - F.log_softmax(logits_target, -1)

    ## TODO: need to take sum only on not -100 tokens.
    return (probs_merged * d).sum(-1).mean()
    # return probs_merged * d

def kl_torch(s_logits, t_logits, idx_masks=None, temperature=1.0):
    """
    Computes KL(s || t) with temperature scaling.
    
    Args:
        s_logits, t_logits: (batch_size, seq_len, vocab)
        temperature: float
        indices: effective indices that exclude pad tokens.

    Returns:
        KL divergence that 
        sum over vocab, then divided by batch_size * seq_len.
        Also do not take into account pad tokens.
    """
    kl_fct = nn.KLDivLoss(reduction="none")
    diff = (
        kl_fct(
            F.log_softmax(t_logits / temperature, dim=-1),
            F.softmax(s_logits / temperature, dim=-1)
        )
        * (temperature) ** 2
    )
    loss = diff.sum(-1).mean()
    # loss = diff.sum(-1)
    return loss

In [18]:
texts = [tokenizer.apply_chat_template(
    train_mini[i]["messages"], tokenize=False, add_generation_prompt=False
) for i in range(8)]
texts[0]

"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou're an AI assistant for text re-writing. Rewrite the input text to make it more friendly and approachable while maintaining its main points.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nMr. Thompson,\n\nI have been waiting for the finalized schedule for the community outreach program, which was supposed to be ready by last week. I understand that you have a lot on your plate, but your lack of communication is making it very difficult for me to plan effectively. I need this information as soon as possible to ensure we meet our grant requirements.\n\nBest regards,\nEmily Carter<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nHi Mr. Thompson,\n\nI hope you're doing well! I wanted to check in about the finalized schedule for the community outreach program. I know things can get really busy, and I appreciate all the work you're doing. However, I'm finding it a bit challenging to plan without the schedule, w

In [38]:
collated = data_collator([tokenized_mini[i] for i in range(2)])

In [39]:
collated = collated.to(merger.device)

In [21]:
# with torch.no_grad():
outputs_1 = merger.models[0](**collated)

In [22]:
# with torch.no_grad():
outputs_2 = merger.models[1](**collated)

In [23]:
# with torch.no_grad():
merger_outputs = merger.merger(**collated)

In [24]:
# with torch.no_grad():
#     outputs = merger_forward(texts, merger, tokenizer, data_collator)

In [25]:
if not torch.cuda.is_available():
    logger.info("CUDA is not available. No GPU memory to free.")
    
initial_memory = torch.cuda.memory_allocated()
logger.info(f"Initial GPU memory allocated: {initial_memory / 1024**3:.2f} GB")
gc.collect()
torch.cuda.empty_cache()

final_memory = torch.cuda.memory_allocated()
logger.info(f"Final GPU memory allocated: {final_memory / 1024**3:.2f} GB")

freed_memory = initial_memory - final_memory
logger.info(f"Freed GPU memory: {freed_memory / 1024**3:.2f} GB")

2024-12-30 10:05:28,317 - INFO - Initial GPU memory allocated: 14.31 GB
2024-12-30 10:05:28,671 - INFO - Final GPU memory allocated: 14.31 GB
2024-12-30 10:05:28,672 - INFO - Freed GPU memory: 0.00 GB


In [71]:
test_kl_out = kl_torch(outputs_1.logits, outputs_2.logits)

In [73]:
effective_idxs = collated['labels'].clone()
effective_idxs[effective_idxs != -100] = 1.0
effective_idxs[effective_idxs == -100] = 0.0

In [74]:
(test_kl_out * effective_idxs).sum(dim=-1) / effective_idxs.sum(dim=-1)

tensor([2.3594, 1.7500], device='cuda:0', dtype=torch.bfloat16)

In [77]:
(test_kl_out * effective_idxs).sum(dim=-1)

tensor([528., 362.], device='cuda:0', dtype=torch.bfloat16)

In [76]:
effective_idxs.sum(dim=-1)

tensor([224, 207], device='cuda:0')

In [25]:
merger_outputs

CausalLMOutputWithPast(loss=tensor(2.9416, device='cuda:0', grad_fn=<NllLossBackward0>), logits=tensor([[[ 5.2812,  7.5625, 13.3125,  ..., -6.1875, -6.1875, -6.1875],
         [ 4.5625, -0.0386,  0.2432,  ...,  1.7500,  1.7500,  1.7500],
         [ 3.8125,  3.1719,  2.7500,  ..., -0.8438, -0.8438, -0.8438],
         ...,
         [ 6.0000,  5.3438,  3.9375,  ..., -8.1250, -8.1250, -8.1250],
         [ 2.3906,  3.9062,  2.3281,  ..., -1.8438, -1.8438, -1.8438],
         [ 0.4688, -1.8281, -0.2285,  ...,  3.2812,  3.2812,  3.2812]],

        [[ 5.2812,  7.5625, 13.3125,  ..., -6.1875, -6.1875, -6.1875],
         [ 4.5625, -0.0386,  0.2432,  ...,  1.7500,  1.7500,  1.7500],
         [ 3.8125,  3.1719,  2.7500,  ..., -0.8438, -0.8438, -0.8438],
         ...,
         [ 1.4062,  2.7188,  3.5000,  ..., -1.7188, -1.7188, -1.7188],
         [ 1.4453,  2.7969,  3.4219,  ..., -1.7344, -1.7344, -1.7344],
         [ 1.4219,  2.7969,  3.3438,  ..., -1.7656, -1.7656, -1.7656]]],
       device='cuda:

In [33]:
collated["attention_mask"][0] == collated["attention_mask"][1]

tensor([True, True, True,  ..., True, True, True])

In [32]:
collated.keys()

dict_keys(['input_ids', 'attention_mask', 'labels', 'data_source'])

In [153]:
outputs["merger_outputs"].logits.shape

torch.Size([8, 1336, 128256])

In [28]:
merger_logits = outputs["merger_outputs"].logits
components_logits = [x.logits for x in outputs["component_outputs"]]

In [47]:
components_logits[0].shape

torch.Size([1, 320, 128256])

In [88]:
kl_torch(components_logits[1].float(), components_logits[1].float())

tensor(-1.2190e-10, device='cuda:4')

In [91]:
compute_kl_loss(components_logits[1].float(), components_logits[1].float())

tensor(-4.3787e-10, device='cuda:4')

In [28]:
compute_kl_loss(*components_logits) == kl_torch(*components_logits)

NameError: name 'compute_kl_loss' is not defined

tensor(1.8359, device='cuda:4', dtype=torch.bfloat16)

In [None]:
# --- Custom Trainer with Custom compute_loss ---
def get_entropy_weights(logits_a, logits_b, epsilon=1e-8):
    """
    Calculates entropy-based weights for merging two sets of logits.

    This function efficiently computes the weights for logits_a and logits_b
    based on their respective entropies. It combines the functionality of
    calculating entropy, normalizing weights, and handling potential
    division-by-zero issues.

    Args:
        logits_a: A PyTorch tensor representing the first set of logits.
        logits_b: A PyTorch tensor representing the second set of logits.
        epsilon: A small value to prevent division by zero.

    Returns:
        A tuple containing two tensors: (weight_a, weight_b), representing the
        normalized entropy-based weights for logits_a and logits_b, respectively.
    """

    # Calculate probabilities
    probs_a = F.softmax(logits_a, dim=-1)
    probs_b = F.softmax(logits_b, dim=-1)

    # Calculate entropies with epsilon for numerical stability
    entropy_a = -(probs_a * probs_a.log()).sum(dim=-1, keepdim=True)
    entropy_b = -(probs_b * probs_b.log()).sum(dim=-1, keepdim=True)

    # Calculate inverse entropies (weights)
    inv_entropy_a = 1.0 / (entropy_a + epsilon)
    inv_entropy_b = 1.0 / (entropy_b + epsilon)

    # Normalize weights
    total_inv_entropy = inv_entropy_a + inv_entropy_b
    weight_a = inv_entropy_a / (total_inv_entropy + epsilon)  
    weight_b = inv_entropy_b / (total_inv_entropy + epsilon) 

    return weight_a, weight_b
    
def compute_logits_target(logits_components):
    assert len(logits_components) == 2
    logits_a, logits_b = logits_components
    weight_a, weight_b = get_entropy_weights(logits_a, logits_b)

    logits_target = weight_a * logits_a + weight_b * logits_b

    return logits_target
    
class Mergerrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):

        ## Read inputs. Compute a simple mask `effective_idxs`
        ## to exclude non-trainable tokens (e.g., PAD tokens).
        labels = inputs.pop("labels")
        
        effective_idxs = labels.clone()
        effective_idxs[effective_idxs != -100] = 1.0
        effective_idxs[effective_idxs == -100] = 0.0

        ## Forward pass
        outputs = model(**inputs)
        logits_merged = outputs["merger_outputs"].logits
        logits_components = [x.logits for x in outputs["components_outputs"]]

        ## Compute targt logits
        logits_target = compute_logits_target(logits_components)

        ## Compute KL divergence
        temperature = 1.0
        kl_fct = nn.KLDivLoss(reduction="none")
        diff = (
            kl_fct(
                F.log_softmax(logits_target / temperature, dim=-1),
                F.softmax(logits_merged / temperature, dim=-1)
            )
            * (temperature) ** 2
        )
        
        ### Exclude non-trainable tokens from taking loss
        loss = (diff * effective_idxs).sum(dim=-1)
        loss = (loss / effective_idxs.sum(dim=-1)).mean()
        
        return (loss, outputs) if return_outputs else loss

In [None]:
from dataclasses import dataclass

@dataclass
class Args:
    model_name: str = "..."  # You can replace this with any causal language model from HuggingFace
    dataset_name: str = "..."  # Replace with your dataset name (e.g., "your_username/your_dataset")
    train_split: str = "train"  # e.g., "train[:80%]" for an 80/20 train/validation split
    validation_split: str = None  # e.g., "train[80%:]"
    output_dir: str = "./trained_masks"
    per_device_train_batch_size: int = 8
    per_device_eval_batch_size: int = 8
    gradient_accumulation_steps: int = 2
    learning_rate: float = 2e-5
    num_train_epochs: int = 3
    save_steps: int = 500
    eval_steps: int = 500
    logging_steps: int = 100
    logging_dir: str = "./trained_masks/logs"
    evaluation_strategy: str = "steps"
    report_to: str = "tensorboard"
    remove_unused_columns: bool = False

args = Args()


# --- Load Tokenizer and Model ---
# tokenizer = AutoTokenizer.from_pretrained(args.model_name)
# model = AutoModelForCausalLM.from_pretrained(args.model_name)

# if tokenizer.pad_token is None:
#     tokenizer.pad_token = tokenizer.eos_token

# --- Training Arguments ---
training_args = TrainingArguments(
    output_dir=args.output_dir,
    per_device_train_batch_size=args.per_device_train_batch_size,
    per_device_eval_batch_size=args.per_device_eval_batch_size,
    gradient_accumulation_steps=args.gradient_accumulation_steps,
    learning_rate=args.learning_rate,
    num_train_epochs=args.num_train_epochs,
    save_steps=args.save_steps,
    evaluation_strategy=args.evaluation_strategy if args.validation_split else "no",
    eval_steps=args.eval_steps if args.validation_split else None,
    logging_steps=args.logging_steps,
    logging_dir=args.logging_dir,
    report_to=args.report_to,  # Enable TensorBoard logging
    remove_unused_columns=args.remove_unused_columns
)

# --- Data Collator ---
data_collator = MergerDataCollator(tokenizer, pad_to_multiple_of=8, return_tensors="pt")

# --- Initialize Trainer ---
trainer = Mergerrainer(
    model=merger,
    args=training_args,
    train_dataset=tokenized_mini,
    eval_dataset=None,
    data_collator=data_collator,
)


In [None]:
# --- Train the Model ---
trainer.train()



In [None]:
# --- Save the Model and Tokenizer ---
trainer.save_model()
tokenizer.save_pretrained(OUTPUT_DIR)
print(f"Model and tokenizer saved to {OUTPUT_DIR}")