In [1]:
import math
from typing import List, Optional, Tuple, Union
from abc import ABC, abstractmethod

import datasets
import torch
import torch.nn.functional as F
import numpy as np
import torch.nn as nn
import logging
import copy
import gc

from datasets import load_dataset
from tqdm import tqdm
from transformers import (
    PreTrainedModel,
    PretrainedConfig,
    AutoConfig,
    AutoModelForCausalLM,
    LlamaForCausalLM,
    LlamaConfig,
    Trainer,
    TrainingArguments,
    AutoTokenizer,
    HfArgumentParser,
    default_data_collator,
    is_torch_xla_available,
    set_seed,
)
from transformers import (
    HfArgumentParser,
    TrainingArguments,
    Trainer
)

from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast
)

from constrained_masks import (
    MergerConfig,
    Merger,
    init_masks,
    set_masks
)

from utils import (
    generate, 
    get_hidden_states, 
    get_logits,
    free_memory
)
# Root logging setup: INFO and above, each record prefixed with timestamp and level.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Notebook-level logger, named after the current module.
logger = logging.getLogger(__name__)

In [2]:
def merger_forward(text, model, tokenizer):
    """Run one gradient-free forward pass of ``model`` over ``text``.

    Args:
        text: Input passed to ``tokenizer`` as its first positional argument.
        model: Callable model exposing ``.device`` and ``.eval()``; it is
            switched to eval mode and left in that mode afterwards.
        tokenizer: Callable invoked as ``tokenizer(text, return_tensors="pt")``;
            its result must support ``.to(device)`` and ``**`` unpacking.

    Returns:
        Whatever ``model(**encoding)`` returns.
    """
    # The tokenizer result holds every model input, not just the token ids.
    encoding = tokenizer(text, return_tensors="pt")
    encoding = encoding.to(model.device)

    model.eval()
    with torch.no_grad():
        return model(**encoding)

In [3]:
# Merge configuration over two locally stored checkpoints. Judging by the
# directory names these are Llama 3.2 fine-tunes for rewriting and for
# summarization — TODO confirm. The paths are absolute and machine-specific.
merge_config = MergerConfig(
    model_paths=[
        "/workspace/dont15/models/llama32_smol_rewrite_50k/",
        "/workspace/dont15/models/llama32_smol_summarize_50k/",
        # "/workspace/HUB_LLM/Llama-3.2-3B-Instruct",  # optional third entry, currently disabled
    ],
    mode="vector_input",  # alternative tried earlier: "scalar"
    constrain_mode="01",
)
merge_config  # bare last expression so the notebook displays the resolved config

MergerConfig {
  "constrain_mode": "01",
  "mode": "vector_input",
  "model_paths": [
    "/workspace/dont15/models/llama32_smol_rewrite_50k/",
    "/workspace/dont15/models/llama32_smol_summarize_50k/"
  ],
  "transformers_version": "4.47.1"
}

In [4]:
# Load the tokenizer from the first checkpoint listed in the merge config.
# NOTE(review): assumes every checkpoint in model_paths shares this tokenizer — confirm.
tokenizer = AutoTokenizer.from_pretrained(merge_config.model_paths[0])

In [5]:
# Training splits of the two smoltalk subsets used in this notebook, loaded
# in the same order as before: smol-rewrite first, then smol-summarize.
SMOLTALK_REPO = "HuggingFaceTB/smoltalk"
rewrite_train, summarize_train = (
    load_dataset(SMOLTALK_REPO, subset_name, split="train")
    for subset_name in ("smol-rewrite", "smol-summarize")
)

In [6]:
def _prompt_without_last_turn(messages):
    """Render all messages except the final one through the chat template.

    Returns the template output as text (``tokenize=False``) with
    ``add_generation_prompt=True``. Uses the notebook-level ``tokenizer``.
    """
    return tokenizer.apply_chat_template(
        messages[:-1], tokenize=False, add_generation_prompt=True
    )


# One fixed sample conversation per task. The final message is dropped before
# templating (presumably it is the reference reply — TODO confirm).
rewrite_example = rewrite_train[19]
rewrite_messages = rewrite_example["messages"]
rewrite_text = _prompt_without_last_turn(rewrite_messages)

summarize_example = summarize_train[20]
summarize_messages = summarize_example["messages"]
summarize_text = _prompt_without_last_turn(summarize_messages)

In [7]:
# Build the merger from the config, then invoke its __post_init__ hook by hand.
# NOTE(review): calling a dunder-style hook explicitly is unusual — confirm that
# Merger.__init__ does not already run it, otherwise the setup happens twice.
# The two commented-out `.to(...)` lines record earlier attempts to move the
# merger to a GPU before / after that hook; with both disabled, this cell's
# code performs no explicit device move.
merger = Merger(merge_config)
# merger = merger.to(device="cuda:7", dtype=torch.bfloat16)
merger.__post_init__()
# merger = merger.to(device="cuda:7", dtype=torch.bfloat16)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Initializing masks: 100%|█████████████████████████████████████████████████████████████████████████████████████████████| 255/255 [01:23<00:00,  3.06it/s]
2024-12-26 09:51:34,580 - INFO - Initial GPU memory allocated: 0.00 GB
2024-12-26 09:51:34,781 - INFO - Final GPU memory allocated: 0.00 GB
2024-12-26 09:51:34,782 - INFO - Freed GPU memory: 0.00 GB


In [8]:
# Move the merger to GPU index 4 and cast it to bfloat16.
# NOTE(review): the device index is hard-coded; this fails on machines with fewer GPUs.
merger = merger.to(device="cuda:4", dtype=torch.bfloat16)

In [9]:
# Gradient-free forward pass of the merger over the summarization prompt.
outputs = merger_forward(summarize_text, merger, tokenizer)



In [16]:
# Logits with the last sequence position dropped — the same slice a
# next-token loss applies before comparing against the labels.
outputs["merger_outputs"].logits[..., :-1, :].contiguous()

tensor([[[ 15.8750,  19.5000,  39.2500,  ..., -24.6250, -24.6250, -24.6250],
         [ 15.8750,  19.5000,  39.2500,  ..., -24.6250, -24.6250, -24.6250],
         [ -1.2578,   6.5625,   3.9844,  ...,   0.8789,   0.8789,   0.8750],
         ...,
         [ -0.4434,   0.1533,  -9.6250,  ...,   1.3828,   1.3828,   1.3828],
         [ -1.3281,  -2.1094, -10.1875,  ...,   0.6875,   0.6875,   0.6875],
         [ -4.5625,   1.0078,  -6.3750,  ...,   3.1406,   3.1406,   3.1250]]],
       device='cuda:4', dtype=torch.bfloat16)

In [17]:
# Full, unshifted logits, for comparison with the sliced view.
outputs["merger_outputs"].logits

tensor([[[ 15.8750,  19.5000,  39.2500,  ..., -24.6250, -24.6250, -24.6250],
         [ 15.8750,  19.5000,  39.2500,  ..., -24.6250, -24.6250, -24.6250],
         [ -1.2578,   6.5625,   3.9844,  ...,   0.8789,   0.8789,   0.8750],
         ...,
         [ -1.3281,  -2.1094, -10.1875,  ...,   0.6875,   0.6875,   0.6875],
         [ -4.5625,   1.0078,  -6.3750,  ...,   3.1406,   3.1406,   3.1250],
         [  4.3438,  -2.7031,  -2.9375,  ...,   0.3398,   0.3418,   0.3398]]],
       device='cuda:4', dtype=torch.bfloat16)

In [None]:
# Label value excluded from the loss; matches CrossEntropyLoss's default ignore_index.
IGNORE_INDEX = -100

def clm_loss(model, inputs, return_outputs=False):
    """
    Custom compute_loss function for Causal Language Modeling.

    The caller's ``inputs`` mapping is left untouched, so the same batch can be
    passed again (e.g. for logging or a second forward pass).

    Args:
        model: The model to compute the loss for. Called as ``model(**inputs)``
            with the ``"labels"`` entry removed.
        inputs: A dictionary of inputs as produced by the `collate_fn`. Must
            contain a ``"labels"`` entry aligned with the input tokens.
        return_outputs: Whether or not to return the model outputs in addition to the loss.

    Returns:
        The loss and the model outputs (if `return_outputs=True`). The loss is
        the mean cross entropy over label positions not equal to
        ``IGNORE_INDEX``, computed in float32.

    Raises:
        KeyError: If ``inputs`` has no ``"labels"`` entry.
    """
    # Read labels without popping: popping would mutate the caller's batch and
    # make a second call on the same dict fail with KeyError.
    labels = inputs["labels"]
    model_inputs = {key: value for key, value in inputs.items() if key != "labels"}
    outputs = model(**model_inputs)

    logits = getattr(outputs, "logits", None)
    if logits is None:
        # The merger in this notebook exposes its logits under the
        # "merger_outputs" entry rather than as a top-level attribute.
        logits = outputs["merger_outputs"].logits

    # Position t predicts token t + 1: drop the last logit (it has no target)
    # and the first label (nothing predicts it) so the two line up.
    # Upcast to float32 so the softmax/log is not computed in bfloat16/float16.
    shift_logits = logits[..., :-1, :].float().contiguous()
    shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)

    # Flatten the tokens
    loss_fct = torch.nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)
    loss = loss_fct(
        shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
    )

    return (loss, outputs) if return_outputs else loss


In [18]:
# Exploratory IPython help lookup (checks the default ignore_index / reduction).
# NOTE(review): leftover from development — remove before sharing the notebook.
torch.nn.CrossEntropyLoss?

[0;31mInit signature:[0m
[0mtorch[0m[0;34m.[0m[0mnn[0m[0;34m.[0m[0mCrossEntropyLoss[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mweight[0m[0;34m:[0m [0mOptional[0m[0;34m[[0m[0mtorch[0m[0;34m.[0m[0mTensor[0m[0;34m][0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0msize_average[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mignore_index[0m[0;34m:[0m [0mint[0m [0;34m=[0m [0;34m-[0m[0;36m100[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mreduce[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mreduction[0m[0;34m:[0m [0mstr[0m [0;34m=[0m [0;34m'mean'[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mlabel_smoothing[0m[0;34m:[0m [0mfloat[0m [0;34m=[0m [0;36m0.0[0m[0;34m,[0m[0;34m[0m
[0;34m[0m[0;34m)[0m [0;34m->[0m [0;32mNone[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m     
This criterion computes the cross entropy loss between input logits
and target.



In [15]:
# Exploratory IPython source lookup for the Trainer.compute_loss signature.
# NOTE(review): leftover from development — remove before sharing the notebook.
Trainer.compute_loss??

[0;31mSignature:[0m
[0mTrainer[0m[0;34m.[0m[0mcompute_loss[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mself[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mmodel[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0minputs[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mreturn_outputs[0m[0;34m=[0m[0;32mFalse[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mnum_items_in_batch[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mSource:[0m   
    [0;32mdef[0m [0mcompute_loss[0m[0;34m([0m[0mself[0m[0;34m,[0m [0mmodel[0m[0;34m,[0m [0minputs[0m[0;34m,[0m [0mreturn_outputs[0m[0;34m=[0m[0;32mFalse[0m[0;34m,[0m [0mnum_items_in_batch[0m[0;34m=[0m[0;32mNone[0m[0;34m)[0m[0;34m:[0m[0;34m[0m
[0;34m[0m        [0;34m"""[0m
[0;34m        How the loss is computed by Trainer. By default, all models return the loss in the first element.[0m
[0;34m[0m
[0;34m        Subclass and override for custom behavi