In [1]:
"""
Model merging training implementation using PyTorch and Transformers.
Implements custom data collation and training for merged language models.
"""
import math
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union, Mapping
from abc import ABC, abstractmethod

import datasets
import torch
import torch.nn.functional as F
import safetensors
import numpy as np
import torch.nn as nn
import logging
import copy
import gc

from datasets import load_dataset
from tqdm import tqdm
from transformers import (
    PreTrainedTokenizerBase,
    PreTrainedModel,
    PretrainedConfig,
    AutoConfig,
    AutoModelForCausalLM,
    LlamaForCausalLM,
    LlamaConfig,
    Trainer,
    TrainingArguments,
    AutoTokenizer,
    HfArgumentParser,
    default_data_collator,
    is_torch_xla_available,
    set_seed,
)

from transformers.utils import CONFIG_NAME

# from accurate_masks import (
# from efficient_masks import (
from merger import (
    MergerConfig,
    # Merger,
    NewMerger,
    init_masks,
    set_masks
)

from utils import (
    generate, 
    get_hidden_states, 
    get_logits,
    free_memory
)
# Configure logger
logging.basicConfig(
    level=logging.INFO, 
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

In [2]:
import os
# Option 1: Set specific GPU devices
os.environ["CUDA_VISIBLE_DEVICES"] = "0" 

In [3]:
class DataProcessor:
    """Handles data loading and preprocessing."""
    
    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        self.tokenizer = tokenizer
    
    def load_dataset(self):
        """Load and prepare the training dataset."""
        summarize_train = load_dataset(
            "HuggingFaceTB/smoltalk",
            "smol-summarize",
            split="train"
        )
        summarize_train = summarize_train.add_column(
            name="data_source",
            column=[1 for _ in summarize_train]
        )
        return summarize_train.shuffle(seed=42).select(range(30000))
    
    def tokenize(self, element):
        """Tokenize a single element from the dataset."""
        templated = self.tokenizer.apply_chat_template(
            element["messages"],
            tokenize=False,
            add_generation_prompt=False
        )
        return self.tokenizer(
            templated,
            truncation=True,
            max_length=2048,
            add_special_tokens=False
        )

In [4]:
def pad_without_fast_tokenizer_warning(tokenizer, *pad_args, **pad_kwargs):
    """
    Pads without triggering the warning about how using the pad function is sub-optimal when using a fast tokenizer.
    """

    # To avoid errors when using Feature extractors
    if not hasattr(tokenizer, "deprecation_warnings"):
        return tokenizer.pad(*pad_args, **pad_kwargs)

    # Save the state of the warning, then disable it
    warning_state = tokenizer.deprecation_warnings.get("Asking-to-pad-a-fast-tokenizer", False)
    tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True

    try:
        padded = tokenizer.pad(*pad_args, **pad_kwargs)
    finally:
        # Restore the state of the warning.
        tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = warning_state

    return padded

@dataclass
class MergerDataCollator:
    tokenizer: PreTrainedTokenizerBase
    pad_to_multiple_of: Optional[int] = None
    return_tensors: str = "pt"

    def __call__(self, examples):
        """
        copied from DataCollatorForLanguageModeling
        examples: List[Union[List[int], Any, Dict[str, Any]]]
        """
        # Handle dict or lists with proper padding and conversion to tensor.
        if not isinstance(examples[0], Mapping):
            raise ValueError("Data collator only processes list of dictionaries.")

        inputs_ids = []
        data_sources = []
        for i in range(len(examples)):
            _ = examples[i].pop("attention_mask")
            inputs_ids.append({"input_ids": examples[i].pop("input_ids")})
            data_sources.append(examples[i].pop("data_source"))
            
        batch = pad_without_fast_tokenizer_warning(
            self.tokenizer, inputs_ids, return_tensors="pt", 
            pad_to_multiple_of=self.pad_to_multiple_of
        )

        labels = batch["input_ids"].clone()
        if self.tokenizer.pad_token_id is not None:
            labels[labels == self.tokenizer.pad_token_id] = -100
        batch["labels"] = labels

        # Handle data_source - convert to tensor
        batch["data_source"] = torch.tensor(
            [src for src in data_sources], dtype=torch.long
        )
        
        for key in examples[0]:
            if key in batch:
                raise ValueError(
                    f"`{key}` feature is collated. "
                    "Overriding it with its initial values is prohibitted."
                )
            else:
                batch[key] = [x[key] for x in examples]
        logger.info_once(f">>> Collator output keys: {batch.keys()}")
        return batch

In [14]:
def selective_logits_target(logits_components, data_source):
    """Select appropriate logits based on data source."""
    # logits_components ~ [(batch_size, seq_len, vocab_size) * n_components]
    # stacked_logits.shape = (n_components, batch_size, seq_len, vocab_size)
    # data_source.shape == (batch_size,)
    # indices.shape == (batch_size, 1, 1)
    stacked_logits = torch.stack(logits_components)
    indices = data_source.unsqueeze(-1).unsqueeze(-1)
    return stacked_logits[indices]

def selective_logits_target(logits_components, data_source):
    """Select appropriate logits based on data source."""

    logits_target = torch.empty_like(logits_components[0])
    for idx, source in enumerate(data_source):
        logits_target[idx] = logits_components[source][idx]

    return logits_target

def masked_kl_div(logits_a, logits_b, mask, temperature=1.0):
    # (batch_size, seq_len, vocab_size) -> (batch_size * seq_len, vocab_size)
    logits_a = logits_a.view(-1, logits_a.size(-1)) / temperature
    logits_b = logits_b.view(-1, logits_b.size(-1)) / temperature

    # (batch_size * seq_len,)
    mask = mask.view(-1)

    assert mask.size(0) == logits_a.size(0)

    log_probs_a = nn.functional.log_softmax(logits_a, dim=-1)
    log_probs_b = nn.functional.log_softmax(logits_b, dim=-1)

    # (batch_size * seq_len, vocab_size) -> (batch_size * seq_len)
    div = log_probs_a.exp() * (log_probs_a - log_probs_b)
    div = div.sum(-1)

    ## taking average on effective tokens.
    div = (div * mask).sum() / mask.sum() * (temperature ** 2)
    return div

class MergerTrainer(Trainer):
    """Custom trainer for merged model training."""
    
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        labels = inputs.pop("labels")
        data_source = inputs.pop("data_source")
        effective_idxs = (labels != -100).float().unsqueeze(dim=-1)
        
        outputs = model(**inputs)
        logits_merged = outputs["merger_outputs"].logits
        logits_components = [x.logits for x in outputs["components_outputs"]]

        # Compute target logits and KL divergence
        logits_target = selective_logits_target(logits_components, data_source)
        
        # temperature = 1.0
        # kl_fct = nn.KLDivLoss(reduction="none")
        # diff = (
        #     kl_fct(
        #         F.log_softmax(logits_target / temperature, dim=-1),
        #         F.softmax(logits_merged / temperature, dim=-1)
        #     )
        #     * (temperature) ** 2
        # )
        
        # # Calculate final loss
        # loss = (diff * effective_idxs).sum(dim=-1)
        # loss = (loss / effective_idxs.sum(dim=1)).mean()


        loss = masked_kl_div(logits_merged, logits_target, effective_idxs)
        # logger.info(f">>> mdfk: {loss}")
        return (loss, outputs) if return_outputs else loss

    def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
        assert model is not None, (
            "Must pass an initialized model to trainer instead of model path."
        )
        # Look for trainable parameters file
        masks_file = os.path.join(resume_from_checkpoint, "masks.safetensors")
        if not os.path.isfile(trainable_params_file):
            masks_file = os.path.join(resume_from_checkpoint, "masks.bin")
        
        if not os.path.isfile(masks_file):
            raise ValueError(
                f"Can't find trainable parameters file in {resume_from_checkpoint}. "
                "Expected either masks.safetensors or masks.bin"
            )
    
        config_file = os.path.join(resume_from_checkpoint, CONFIG_NAME)
        if os.path.isfile(config_file):
            config = PretrainedConfig.from_json_file(config_file)
            checkpoint_version = config.transformers_version
            if checkpoint_version is not None and checkpoint_version != __version__:
                logger.warning(
                    f"You are resuming training from a checkpoint trained with {checkpoint_version} of "
                    f"Transformers but your current version is {__version__}. This is not recommended and could "
                    "yield to errors or unwanted behaviors."
                )
    
        if os.path.isfile(masks_file):
            weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
            # If the model is on the GPU, it still works!
            # We load the model state dict on the CPU to avoid an OOM error.
            if self.args.save_safetensors and masks_file.endswith(".safetensors"):
                state_dict = safetensors.torch.load_file(safe_weights_file, device="cpu")
            else:
                state_dict = torch.load(
                    masks_file,
                    map_location="cpu",
                    **weights_only_kwarg,
                )
    
            # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
            # which takes *args instead of **kwargs
            load_result = model.load_state_dict(state_dict, False)
            if len(load_result.missing_keys) != 0:
                logger.info(
                    "There were missing keys in the checkpoint model loaded. "
                    "However, this is an expected behavior since we are only "
                    "loading partial weights (masks)."
                )
            # release memory
            del state_dict
            gc.collect()

In [15]:
@dataclass
class Args:
    model_name: str = "..."  # You can replace this with any causal language model from HuggingFace
    dataset_name: str = "..."  # Replace with your dataset name (e.g., "your_username/your_dataset")
    train_split: str = "train"  # e.g., "train[:80%]" for an 80/20 train/validation split
    validation_split: str = None  # e.g., "train[80%:]"
    output_dir: str = "./trained_masks"
    per_device_train_batch_size: int = 1
    per_device_eval_batch_size: int = 8
    gradient_accumulation_steps: int = 32
    learning_rate: float = 5e-3
    num_train_epochs: int = 1
    save_steps: int = 100
    eval_steps: int = 5000
    logging_steps: int = 10
    logging_dir: str = "./trained_masks/logs"
    evaluation_strategy: str = "steps"
    report_to: str = None
    remove_unused_columns: bool = False
    logging_first_step: bool = True
    gradient_checkpointing: bool = False

In [7]:
# Initialize configuration
merge_config = MergerConfig(
    model_paths=[
        "nguyenthanhdo/llama32_smol_rewrite_50k",
        "nguyenthanhdo/llama32_smol_summarize_50k",
    ],
    mode="vector_input",
    constrain_mode="identity"
)

# Setup tokenizer and data processing
tokenizer = AutoTokenizer.from_pretrained(merge_config.model_paths[0])
tokenizer.pad_token = tokenizer.eos_token
data_processor = DataProcessor(tokenizer)
train_dataset = data_processor.load_dataset()
tokenized_dataset = train_dataset.map(
    data_processor.tokenize,
    remove_columns=["messages"]
)

In [8]:
# Initialize merger model
merger = NewMerger.from_pretrained(
    None,
    merge_config,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="flash_attention_2",
)
set_masks(merger.merger, strategy="uniform", factors=[0.99, 0.01])

2025-01-06 09:58:12,428 - INFO - Creating merger with dummy weights ...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Initializing masks: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 255/255 [00:27<00:00,  9.19it/s]
Setting up masks: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 255/255 [00:00<00:00, 34482.62it/s]


In [9]:
# from safetensors.torch import load_file as safe_load_file
# state_dict = safe_load_file("./trained_masks/model.safetensors")

In [15]:
# trainable_count = 0
# for k, v in state_dict.items():
#     trainable_count += v.numel()
# trainable_count

1503346

In [10]:
# set_masks(merger.merger, strategy="uniform", factors=[0.8, 0.2])

In [17]:
# merger.load_masks("./trained_masks")

In [9]:
# merger.merger.model.embed_tokens.get_raw_masks()

In [12]:
# merger = NewMerger.from_pretrained(
#     "./trained_masks",
#     merge_config,
#     torch_dtype=torch.bfloat16,
#     device_map="auto",
#     attn_implementation="flash_attention_2",
# )

In [16]:
# merger.save_pretrained("./trained_masks")

In [20]:
# Monitor memory usage
initial_memory = torch.cuda.memory_allocated()
logger.info(f"Initial GPU memory allocated: {initial_memory / 1024**3:.2f} GB")

gc.collect()
torch.cuda.empty_cache()

final_memory = torch.cuda.memory_allocated()
logger.info(f"Final GPU memory allocated: {final_memory / 1024**3:.2f} GB")
logger.info(f"Freed GPU memory: {(initial_memory - final_memory) / 1024**3:.2f} GB")

2025-01-06 10:00:04,970 - INFO - Initial GPU memory allocated: 19.37 GB
2025-01-06 10:00:05,128 - INFO - Final GPU memory allocated: 19.37 GB
2025-01-06 10:00:05,130 - INFO - Freed GPU memory: 0.00 GB


In [17]:
# Setup training arguments and data collator
args = Args()
training_args = TrainingArguments(
    output_dir=args.output_dir,
    per_device_train_batch_size=args.per_device_train_batch_size,
    per_device_eval_batch_size=args.per_device_eval_batch_size,
    gradient_accumulation_steps=args.gradient_accumulation_steps,
    learning_rate=args.learning_rate,
    num_train_epochs=args.num_train_epochs,
    save_steps=args.save_steps,
    evaluation_strategy=args.evaluation_strategy if args.validation_split else "no",
    eval_steps=args.eval_steps if args.validation_split else None,
    logging_steps=args.logging_steps,
    logging_dir=args.logging_dir,
    report_to=args.report_to,  # Enable TensorBoard logging
    remove_unused_columns=args.remove_unused_columns,
    logging_first_step=args.logging_first_step,
    gradient_checkpointing=args.gradient_checkpointing,
    save_safetensors=True
)

data_collator = MergerDataCollator(
    tokenizer,
    pad_to_multiple_of=8,
    return_tensors="pt"
)

# Initialize and start training
trainer = MergerTrainer(
    model=merger,
    args=training_args,
    train_dataset=tokenized_dataset,
    eval_dataset=None,
    data_collator=data_collator,
)



In [None]:
trainer.train()

Step,Training Loss
1,1.3537
10,0.7213
20,0.4422
30,0.3734
40,0.3411
50,0.3242
60,0.3091
70,0.2931
80,0.2838


In [11]:
inputs = data_collator([x for x in tokenized_dataset.select(range(2))])
print(inputs)

2025-01-06 09:48:24,520 - INFO - >>> Collator output keys: dict_keys(['input_ids', 'attention_mask', 'labels', 'data_source'])


{'input_ids': tensor([[128000, 128006,   9125,  ..., 128001, 128001, 128001],
        [128000, 128006,   9125,  ..., 128001, 128001, 128001]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]]), 'labels': tensor([[128000, 128006,   9125,  ...,   -100,   -100,   -100],
        [128000, 128006,   9125,  ...,   -100,   -100,   -100]]), 'data_source': tensor([1, 1])}


In [12]:
inputs = {k: v.to(device=merger.device) for k, v in inputs.items()}
inputs

{'input_ids': tensor([[128000, 128006,   9125,  ..., 128001, 128001, 128001],
         [128000, 128006,   9125,  ..., 128001, 128001, 128001]],
        device='cuda:0'),
 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]], device='cuda:0'),
 'labels': tensor([[128000, 128006,   9125,  ...,   -100,   -100,   -100],
         [128000, 128006,   9125,  ...,   -100,   -100,   -100]],
        device='cuda:0'),
 'data_source': tensor([1, 1], device='cuda:0')}

In [13]:
labels = inputs["labels"].clone()
effective_idxs = (labels != -100).float().unsqueeze(dim=-1)

In [59]:
trainer.compute_loss(merger, inputs)

tensor(0.0024, device='cuda:0', grad_fn=<MeanBackward0>)

In [14]:
outputs = merger(**inputs)
logits_merged = outputs["merger_outputs"].logits
logits_components = [x.logits for x in outputs["components_outputs"]]

In [15]:
logits_merged

tensor([[[ 5.3125,  7.7188, 13.1875,  ..., -6.2188, -6.2188, -6.2188],
         [ 1.0156,  0.2188, -0.8008,  ...,  3.3438,  3.3438,  3.3438],
         [ 3.1094,  4.7812,  3.5781,  ..., -1.0781, -1.0781, -1.0781],
         ...,
         [ 0.4160,  1.6172,  2.3750,  ..., -2.0312, -2.0312, -2.0312],
         [ 0.6211,  1.8828,  2.5156,  ..., -1.9609, -1.9609, -1.9609],
         [ 0.7734,  1.8828,  2.5781,  ..., -1.9141, -1.9141, -1.9141]],

        [[ 5.3125,  7.7188, 13.1875,  ..., -6.2188, -6.2188, -6.2188],
         [ 1.0156,  0.2188, -0.8008,  ...,  3.3438,  3.3438,  3.3438],
         [ 3.1094,  4.7812,  3.5781,  ..., -1.0781, -1.0781, -1.0781],
         ...,
         [ 0.9570,  2.0625,  3.1562,  ..., -1.1562, -1.1562, -1.1562],
         [ 0.9492,  2.0938,  3.1406,  ..., -1.2109, -1.2109, -1.2109],
         [ 1.0156,  2.1406,  3.1719,  ..., -1.2344, -1.2344, -1.2344]]],
       device='cuda:0', dtype=torch.bfloat16, grad_fn=<UnsafeViewBackward0>)

In [16]:
logits_components[0]

tensor([[[ 5.2812,  7.6875, 13.1250,  ..., -6.1562, -6.1562, -6.1562],
         [ 0.8477,  0.1689, -0.8867,  ...,  3.3594,  3.3594,  3.3594],
         [ 3.0781,  4.7812,  3.6094,  ..., -1.0234, -1.0234, -1.0234],
         ...,
         [ 3.3594,  3.5156,  5.2812,  ..., -3.0156, -3.0156, -3.0156],
         [ 3.3594,  3.5156,  5.2812,  ..., -3.0156, -3.0156, -3.0156],
         [ 3.3594,  3.5156,  5.2812,  ..., -3.0156, -3.0156, -3.0156]],

        [[ 5.2812,  7.6875, 13.1250,  ..., -6.1562, -6.1562, -6.1562],
         [ 0.8477,  0.1689, -0.8867,  ...,  3.3594,  3.3594,  3.3594],
         [ 3.0781,  4.7812,  3.6094,  ..., -1.0234, -1.0234, -1.0234],
         ...,
         [ 3.3594,  3.5156,  5.2812,  ..., -3.0156, -3.0156, -3.0156],
         [ 3.3594,  3.5156,  5.2812,  ..., -3.0156, -3.0156, -3.0156],
         [ 3.3594,  3.5156,  5.2812,  ..., -3.0156, -3.0156, -3.0156]]],
       device='cuda:0', dtype=torch.bfloat16)

In [29]:
trainer.compute_loss??

[0;31mSignature:[0m
[0mtrainer[0m[0;34m.[0m[0mcompute_loss[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mmodel[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0minputs[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mreturn_outputs[0m[0;34m=[0m[0;32mFalse[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mnum_items_in_batch[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m
How the loss is computed by Trainer. By default, all models return the loss in the first element.

Subclass and override for custom behavior.
[0;31mSource:[0m   
    [0;32mdef[0m [0mcompute_loss[0m[0;34m([0m[0mself[0m[0;34m,[0m [0mmodel[0m[0;34m,[0m [0minputs[0m[0;34m,[0m [0mreturn_outputs[0m[0;34m=[0m[0;32mFalse[0m[0;34m,[0m [0mnum_items_in_batch[0m[0;34m=[0m[0;32mNone[0m[0;34m)[0m[0;34m:[0m[0;34m[0m
[0;34m[0m        [0mlabels[0m [0;34m=[0m [0minputs[0m[0;34m.[0m[0mpop[0m[0;34m([0m[0;34m"

In [31]:
def masked_kl_div(logits_a, logits_b, mask, temperature=1.0):
    logits_a = logits_a.view(-1, logits_a.size(-1)) / temperature
    logits_b = logits_b.view(-1, logits_b.size(-1)) / temperature
    mask = mask.view(-1)

    assert mask.size(0) == logits_a.size(0)

    log_probs_a = nn.functional.log_softmax(logits_a, dim=-1)
    log_probs_b = nn.functional.log_softmax(logits_b, dim=-1)

    div = log_probs_a.exp() * (log_probs_a - log_probs_b)
    div = div.sum(-1)

    div = (div * mask).sum() / mask.sum() * (temperature ** 2)
    return div

In [35]:
masked_kl_div(logits_components[1], logits_components[0], effective_idxs)

tensor(1.0889, device='cuda:0')

In [62]:
temperature = 1.0
kl_fct = nn.KLDivLoss(reduction="none")
diff = (
    kl_fct(
        F.log_softmax(logits_components[1] / temperature, dim=-1),
        F.softmax(logits_components[0] / temperature, dim=-1)
    )
    * (temperature) ** 2
)

In [63]:
inputs

{'input_ids': tensor([[128000, 128006,   9125,  ..., 128001, 128001, 128001],
         [128000, 128006,   9125,  ..., 128001, 128001, 128001]],
        device='cuda:0'),
 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]], device='cuda:0')}

In [64]:
# Calculate final loss
loss = (diff * effective_idxs).sum(dim=-1)
loss = (loss / effective_idxs.sum(dim=1)).mean()

In [68]:
logits_components[1]

tensor([[[ 5.1562,  7.2812, 13.3125,  ..., -6.1250, -6.1250, -6.1250],
         [ 2.1406,  0.4219,  2.6562,  ...,  0.3164,  0.3164,  0.3164],
         [ 4.0938,  1.6641,  2.5156,  ..., -2.3594, -2.3594, -2.3594],
         ...,
         [ 3.5312,  4.6250,  5.9375,  ..., -2.5156, -2.5156, -2.5156],
         [ 3.5312,  4.6250,  5.9375,  ..., -2.5156, -2.5156, -2.5156],
         [ 3.5312,  4.6250,  5.9375,  ..., -2.5156, -2.5156, -2.5156]],

        [[ 5.1562,  7.2812, 13.3125,  ..., -6.1250, -6.1250, -6.1250],
         [ 2.1406,  0.4219,  2.6562,  ...,  0.3164,  0.3164,  0.3164],
         [ 4.0938,  1.6641,  2.5156,  ..., -2.3594, -2.3594, -2.3594],
         ...,
         [ 3.5312,  4.6250,  5.9375,  ..., -2.5156, -2.5156, -2.5156],
         [ 3.5312,  4.6250,  5.9375,  ..., -2.5156, -2.5156, -2.5156],
         [ 3.5312,  4.6250,  5.9375,  ..., -2.5156, -2.5156, -2.5156]]],
       device='cuda:0', dtype=torch.bfloat16)

In [73]:
(logits_merged - logits_components[1]).pow(2).mean()

tensor(1.8594, device='cuda:0', dtype=torch.bfloat16, grad_fn=<MeanBackward0>)

In [78]:
kl_fct(
    F.log_softmax(logits_components[0] / temperature, dim=-1),
    F.softmax(logits_components[1] / temperature, dim=-1)
).sum(-1).mean()

OutOfMemoryError: CUDA out of memory. Tried to allocate 274.00 MiB. GPU 0 has a total capacity of 79.21 GiB of which 81.25 MiB is free. Process 3192013 has 79.11 GiB memory in use. Of the allocated memory 75.54 GiB is allocated by PyTorch, and 2.90 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [38]:
trainer.compute_loss??

[0;31mSignature:[0m
[0mtrainer[0m[0;34m.[0m[0mcompute_loss[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mmodel[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0minputs[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mreturn_outputs[0m[0;34m=[0m[0;32mFalse[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mnum_items_in_batch[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m
How the loss is computed by Trainer. By default, all models return the loss in the first element.

Subclass and override for custom behavior.
[0;31mSource:[0m   
    [0;32mdef[0m [0mcompute_loss[0m[0;34m([0m[0mself[0m[0;34m,[0m [0mmodel[0m[0;34m,[0m [0minputs[0m[0;34m,[0m [0mreturn_outputs[0m[0;34m=[0m[0;32mFalse[0m[0;34m,[0m [0mnum_items_in_batch[0m[0;34m=[0m[0;32mNone[0m[0;34m)[0m[0;34m:[0m[0;34m[0m
[0;34m[0m        [0mlabels[0m [0;34m=[0m [0minputs[0m[0;34m.[0m[0mpop[0m[0;34m([0m[0;34m"

In [55]:
trainer.model.merger.model.layers[8].mlp.up_proj.get_constrained_masks()

{'weight_masks': [Parameter containing:
  tensor([1.0078, 0.9219, 1.2891,  ..., 0.9766, 1.0781, 1.4062], device='cuda:0',
         dtype=torch.bfloat16, requires_grad=True),
  Parameter containing:
  tensor([0.5273, 0.2832, 0.5859,  ..., 0.1475, 0.4023, 0.7812], device='cuda:0',
         dtype=torch.bfloat16, requires_grad=True)],
 'bias_masks': [None, None]}

In [23]:
train_dataset[0]['messages']

[{'content': 'Provide a concise, objective summary of the input text in up to three sentences, focusing on key actions and intentions without using second or third person pronouns.',
  'role': 'system'},
 {'content': "By . Ted Thornhill . A mom may have saved her teenage son's life by spying on his Facebook page, as she found death threats on it and alerted police. The concerned parent, from Salt Lake City, called the authorities when she found threats to shoot her son - who attends West High School – had been posted on his profile page. The threats were allegedly made by two male teenagers, 16 and 17, who police arrested when they were found waiting in a car near the school on Friday. Potential life-saver: A mother of a West High School pupil alerted police after she saw threats to her son's life had been made on his Facebook page . Police said they found a gun, loaded magazine, ammunition, cash, marijuana and a bong inside the car. Salt Lake police detective Greg Wilking told Deseret

In [26]:
idx = 4
system = train_dataset[idx]['messages'][0]['content']
prompt = train_dataset[idx]['messages'][1]['content']
messages = [
    {"role": "system", "content": system},
    {"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)

In [27]:
answer = generate(text, trainer.model.merger, tokenizer)

Emily met Tom at the parent-teacher conference and wants to discuss teaching methods and offers to give a guest lecture on the history of Pine Grove School. She also invites Tom to meet for coffee to talk more. Emily is working on a book about the history of education in colonial America.<|end_of_text|>


In [28]:
rewrite_train = load_dataset("HuggingFaceTB/smoltalk", "smol-rewrite", split="train")

In [33]:
idx = 300
system = rewrite_train[idx]['messages'][0]['content']
prompt = rewrite_train[idx]['messages'][1]['content']
messages = [
    {"role": "system", "content": system},
    {"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
print(text)

<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You're an AI assistant for text re-writing. Rewrite the input text to make it more concise while preserving its core meaning.<|eot_id|><|start_header_id|>user<|end_header_id|>

Hey Michael,

I hope you're doing well! I wanted to touch base with you regarding the progress on our nutrition curriculum project. I've been working on the lesson plans for grades 3-5 and have made some great strides. I'd love to get your feedback on what I've put together so far.

Also, I wanted to share a new recipe I tried out this weekend - a quinoa and black bean salad that was a hit with my family. I thought you might enjoy it too, given our mutual love for healthy cooking. I'll attach the recipe below.

Finally, I've been giving some thought to pursuing a Master's degree in Nutrition Education. I know you've been in the field for a while now, and I was hoping to get your advice on the best path forward. If you have any insights or recommendatio

In [39]:
answer = generate(text, trainer.model.merger, tokenizer)

Emily is sharing updates on the nutrition curriculum project and seeking feedback on the lesson plans for grades 3-5. She also shares a new quinoa and black bean salad recipe and asks for advice on pursuing a Master's degree in Nutrition Education. Emily looks forward to catching up soon.<|end_of_text|>


In [52]:
trainer.model.merger.save_pretrained("./trained_masks", safe_serialization=False)