Instead of using the complex TRL we code it from scratch, using lighting

https://github.com/rasbt/LLMs-from-scratch/blob/main/ch07/04_preference-tuning-with-dpo/dpo-from-scratch.ipynb

In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
from pathlib import Path

# ML
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
from einops import rearrange, reduce, repeat
from jaxtyping import Float, Int, Bool

# Numeric
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

# Local
from reprpo.helpers.torch import clear_mem
from reprpo.gen import generation_test


In [3]:
torch.set_float32_matmul_precision("medium")

In [4]:
from dataclasses import dataclass

@dataclass
class TrainingArguments:
    # model
    model_name: str = "microsoft/Phi-3-mini-4k-instruct"
    use_bnb: bool = True # this doesn't seem to be able to backprop when using baukit
    use_gradient_checkpointing: bool = False
    use_inputs: bool = True

    # train
    n_epochs: int = 1
    batch_size: int = 16
    lr: float = 1e-4
    weight_decay: float = 0.0

    # dataset
    n_samples: int = 1000
    max_length: int = 128
    max_prompt_length: int=64


args = TrainingArguments()
args

TrainingArguments(model_name='microsoft/Phi-3-mini-4k-instruct', use_bnb=True, use_gradient_checkpointing=False, use_inputs=True, n_epochs=1, batch_size=16, lr=0.0001, weight_decay=0.0, n_samples=1000, max_length=128, max_prompt_length=64)

## Load model

In [5]:
from peft import prepare_model_for_kbit_training
from peft import LoraConfig, get_peft_model

from reprpo.models.load import load_model, print_trainable_parameters

args

model, tokenizer = load_model(args.model_name, bnb=args.use_bnb )

if args.use_gradient_checkpointing:
    model.enable_input_require_grads()
# peft_module_casting_to_bf16(model)
model = prepare_model_for_kbit_training(model, {'use_gradient_checkpointing': args.use_gradient_checkpointing})

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

### Load adapter

In [6]:
from peft.tuners import BOFTConfig, OFTConfig, LoraConfig, IA3Config
adapter_name='ReprPO'
peft_config = LoraConfig(
    lora_alpha=16, 
    r=16,
    # use_rslora=True,
    # use_dora=True,
    task_type="CAUSAL_LM",
    target_modules=[
        # FIXME: I'm not sure we can do LORA on the layer we are targeting?
        "qkv_proj", "gate_up_proj", # in
        "down_proj",  "o_proj", # out
                    ], # PHI3
)
model = get_peft_model(model, peft_config, adapter_name=adapter_name)
print_trainable_parameters(model)

trainable params: 25165824 || all params: 2033980416 || trainable%: 1.2372697299362787


## Load data

In [7]:
from datasets import load_dataset

dataset = load_dataset('Atsunori/HelpSteer2-DPO').map(lambda x: {
    'prompt': x['prompt']+ ' '})
dataset2 = dataset.rename_column('chosen_response', 'chosen').rename_column('rejected_response', 'rejected')

# QC one row
r = dataset2['train'][0]
print(r['prompt'])
print('===')
print(r['chosen'])
print('---')
print(r['rejected'])

c# 
===
C# (pronounced "C sharp") is a modern, object-oriented programming language developed by Microsoft. It is widely used for building various types of applications, including web applications, desktop applications, mobile applications, and games. C# is similar to other programming languages such as Java and C++, and it is known for its simplicity and ease of use. C# is a powerful language that provides a rich set of libraries and frameworks that make it easy to build robust and scalable applications.

Here is a brief overview of some key features of C#:

1. Object-oriented: C# is an object-oriented language, which means it uses the concept of objects to represent real-world entities and their behavior.

2. Cross-platform: C# can be used to build applications for multiple platforms, including Windows, macOS, and Linux.

3. Strongly typed: C# is a strongly typed language, which means that variables must be declared with a specific type, and their type cannot be changed at runtime.



### Data Loader

We use huggingface datasets, which are pretokenized. So that we can stack

In [8]:
def tokenize_row(feature, tokenizer, args: TrainingArguments):
    """
    Tokenize a single row from a DPO specific dataset.

    see https://github.com/huggingface/trl/blob/main/trl/trainer/dpo_trainer.py#L784
    """
    batch = {}
    batch["chosen"] = tokenizer(feature["chosen"])["input_ids"]
    batch["rejected"] = tokenizer(feature["rejected"])["input_ids"]
    batch["prompt"] = tokenizer(feature["prompt"])["input_ids"]
    return batch

In [9]:
dataset3 = dataset2.map(lambda x: tokenize_row(x, tokenizer, args), batched=True, writer_batch_size=10)
dataset3['train'][0].keys()

dict_keys(['prompt', 'chosen', 'rejected'])

In [10]:
# # reuse, some TRL classes for now
# from trl.trainer.dpo_trainer import DPOTrainer, DPODataCollatorWithPadding
# from trl.trainer.utils import pad

In [11]:
"""
see 
- https://github.com/huggingface/trl/blob/main/trl/trainer/utils.py#L349
- https://github.com/rasbt/LLMs-from-scratch/blob/main/ch07/04_preference-tuning-with-dpo/dpo-from-scratch.ipynb
"""

# TODO just make this always max length?

@dataclass
class DPODataCollatorWithPadding:
    tokenizer: Any
    max_length: int
    pad_token_id: int
    mask_prompt_tokens:bool=True
    max_prompt_length: int=64
    device: str = "cpu"

    def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
        """use to pad the sequences in each batch to an equal length so that we can assemble them in batches

        takes in batch of tokenized
        """
        # Initialize lists to hold batch data
        batch_data = {
            "prompt": [],
            "chosen": [],
            "rejected": [],
            "rejected_mask": [],
            "chosen_mask": []

        }

        # Determine the longest sequence to set a common padding length
        max_length_common = 0
        if batch:
            for key in ["chosen", "rejected"]:
                current_max = max(len(item[key])+1 for item in batch) + self.max_prompt_length
                max_length_common = max(max_length_common, current_max)


        # Process each item in the batch
        for item in batch:
            prompt = item["prompt"]
            batch_data["prompt"].append(torch.tensor(item["prompt"]))

            prompt = prompt[:self.max_prompt_length]

            for key in ["chosen", "rejected"]:
                # Adjust padding according to the common maximum length
                sequence = prompt + item[key]
                padded = sequence + [self.pad_token_id] * (max_length_common - len(sequence))
                mask = torch.ones(len(padded)).bool()

                # Set mask for all padding tokens to False
                mask[len(sequence):] = False

                # # Set mask for all input tokens to False
                # # +2 sets the 2 newline ("\n") tokens before "### Response" to False
                # if self.mask_prompt_tokens:
                #     mask[:prompt.shape[0]+2] = False

                batch_data[key].append(torch.tensor(padded))
                batch_data[f"{key}_mask"].append(mask)

        # Final processing
        for key in ["chosen", "rejected", "chosen_mask", "rejected_mask"]:
            # Stack all sequences into a tensor for the given key
            tensor_stack = torch.stack(batch_data[key])

            # Optionally truncate to maximum sequence length
            if self.max_length is not None:
                tensor_stack = tensor_stack[:, :self.max_length]

            # Move to the specified device
            batch_data[key] = tensor_stack.to(self.device)

        return batch_data


In [12]:
custom_collate_fn = DPODataCollatorWithPadding(pad_token_id=tokenizer.pad_token_id, 
                                                  tokenizer=tokenizer,
                                                  max_length=args.max_length,
                                                  mask_prompt_tokens=True,
                                                  max_prompt_length=args.max_prompt_length,
                                                  #label_pad_token_id=-100
                                                  )



In [13]:
from torch.utils.data import DataLoader

ds = dataset3
dl_train = DataLoader(ds['train'], batch_size=args.batch_size, collate_fn=custom_collate_fn)

dl_val = DataLoader(ds['validation'], batch_size=args.batch_size, collate_fn=custom_collate_fn)

# QC
batch = next(iter(dl_train))
batch.keys()

dict_keys(['prompt', 'chosen', 'rejected', 'rejected_mask', 'chosen_mask'])

## Trainer

In [14]:
import torch.nn.functional as F

def compute_dpo_loss(
      model_chosen_logprobs,
      model_rejected_logprobs,
      reference_chosen_logprobs,
      reference_rejected_logprobs,
      beta=0.1,
    ):
    """Compute the DPO loss for a batch of policy and reference model log probabilities.

    Args:
        policy_chosen_logprobs: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
        policy_rejected_logprobs: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
        reference_chosen_logprobs: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
        reference_rejected_logprobs: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)
        beta: Temperature parameter for the DPO loss; typically something in the range of 0.1 to 0.5. We ignore the reference model as beta -> 0.
        label_smoothing: conservativeness for DPO loss.

    Returns:
        A tuple of three tensors: (loss, chosen_rewards, rejected_rewards).
    """

    model_logratios = model_chosen_logprobs - model_rejected_logprobs
    reference_logratios = reference_chosen_logprobs - reference_rejected_logprobs
    logits = model_logratios - reference_logratios

    # DPO (Eq. 7 of https://arxiv.org/pdf/2305.18290.pdf)
    losses = -F.logsigmoid(beta * logits)

    # Optional values to track progress during training
    chosen_rewards = (model_chosen_logprobs - reference_chosen_logprobs).detach()
    rejected_rewards = (model_rejected_logprobs - reference_rejected_logprobs).detach()

    # .mean() to average over the samples in the batch
    return losses.mean(), dict(
        chosen_rewards=chosen_rewards.mean(), rejected_reward=rejected_rewards.mean()
    )


In [15]:
     

def compute_logprobs(logits, labels, selection_mask=None):
    """
    Compute log probabilities.

    Args:
      logits: Tensor of shape (batch_size, num_tokens, vocab_size)
      labels: Tensor of shape (batch_size, num_tokens)
      selection_mask: Tensor for shape (batch_size, num_tokens)

    Returns:
      mean_log_prob: Mean log probability excluding padding tokens.
    """

    # Labels are the inputs shifted by one
    labels = labels[:, 1:].clone()

    # Truncate logits to match the labels num_tokens
    logits = logits[:, :-1, :]

    log_probs = F.log_softmax(logits, dim=-1)

    # Gather the log probabilities for the actual labels
    selected_log_probs = torch.gather(
        input=log_probs,
        dim=-1,
        index=labels.unsqueeze(-1)
    ).squeeze(-1)

    if selection_mask is not None:
        mask = selection_mask[:, 1:].clone()

        # Apply the mask to filter out padding tokens
        selected_log_probs = selected_log_probs * mask

        # Calculate the average log probability excluding padding tokens
        # This averages over the tokens, so the shape is (batch_size, num_tokens)
        avg_log_prob = selected_log_probs.sum(-1) / mask.sum(-1)

        return avg_log_prob

    else:
        return selected_log_probs.mean(-1)



In [16]:
def compute_dpo_loss_batch(batch, model, beta=0.1):
    """Compute the DPO loss on an input batch"""

    model.eval()
    with torch.no_grad():
        with model.disable_adapter():
            ref_chosen_log_probas = compute_logprobs(
                logits=model(batch["chosen"]).logits,
                labels=batch["chosen"],
                selection_mask=batch["chosen_mask"]
            )
            ref_rejected_log_probas = compute_logprobs(
                logits=model(batch["rejected"]).logits,
                labels=batch["rejected"],
                selection_mask=batch["rejected_mask"]
            )
    
    model.train()
    # where policy_model(batch["chosen"]) are the logits
    # FIXME: need to tracedict, and deal with dict outputs?
    policy_chosen_log_probas = compute_logprobs(
        logits=model(batch["chosen"]).logits,
        labels=batch["chosen"],
        selection_mask=batch["chosen_mask"]
    )
    policy_rejected_log_probas = compute_logprobs(
        logits=model(batch["rejected"]).logits,
        labels=batch["rejected"],
        selection_mask=batch["rejected_mask"]
    )

    loss, info = compute_dpo_loss(
        model_chosen_logprobs=policy_chosen_log_probas,
        model_rejected_logprobs=policy_rejected_log_probas,
        reference_chosen_logprobs=ref_chosen_log_probas,
        reference_rejected_logprobs=ref_rejected_log_probas,
        beta=beta
    )
    return loss, info

In [17]:
# # QC
# loss, info = compute_dpo_loss_batch(batch, model)

- https://lightning.ai/docs/pytorch/latest/notebooks/lightning_examples/text-transformers.html
- https://gist.github.com/wassname/e29d02b5026a531e13912cf768e6fdc8

In [18]:
# from lightning.pytorch.plugins.precision.bitsandbytes import BitsandbytesPrecision

In [19]:
from torch import optim
import lightning as pl
from torchmetrics.functional import accuracy
import bitsandbytes as bnb

class PL_MODEL(pl.LightningModule):
    def __init__(self, model, num_iterations, lr=3e-4, weight_decay=0, ):
        super().__init__()
        self._model = model
        self.save_hyperparameters(ignore=['model'])
    
    def forward(self, x):
        return self._model(x)

    def _shared_step(self, batch, batch_idx, phase='train'):        

        loss, info = compute_dpo_loss_batch(batch, model)

        self.log(f"{phase}/loss", loss, on_epoch=True, on_step=True, prog_bar=True, batch_size=args.batch_size)

        # TODO log info
        self.log_dict({
            f"{phase}/{k}": v for k, v in info.items()}, on_epoch=True, on_step=False, batch_size=args.batch_size),
        return loss
    
    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, batch_idx, phase='train')

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, batch_idx, phase='val')
    
    def test_step(self, batch, batch_idx, dataloader_idx=0):
        return self._shared_step(batch, batch_idx, phase='test')
    
    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        return self._shared_step(batch, batch_idx, phase='pred').float().cpu().detach()
    
    def configure_optimizers(self):
        """simple vanilla torch optim"""
        # https://lightning.ai/docs/fabric/stable/fundamentals/precision.html
        optimizer = bnb.optim.AdamW8bit(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
        # optimizer = optim.AdamW(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
        # https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers
        scheduler = optim.lr_scheduler.OneCycleLR(
            optimizer, self.hparams.lr, total_steps=self.hparams.num_iterations, verbose=True,
        )
        lr_scheduler = {'scheduler': scheduler, 'interval': 'step'}
        return [optimizer], [lr_scheduler]

In [20]:
max_batches = min(len(dl_train), 1000000)
pl_model = PL_MODEL(model,
                 weight_decay=args.weight_decay,
                lr=args.lr,
                num_iterations=max_batches*args.n_epochs
                )

model_name = type(model).__name__
max_batches = min(len(dl_train), 1000000)

In [21]:
timestamp = pd.Timestamp.now().strftime("%Y-%m-%d_%H-%M-%S")
save_dir = f"../outputs/{timestamp}/{model_name}"
Path(save_dir).mkdir(exist_ok=True, parents=True)
trainer = pl.Trainer(
        max_epochs=args.n_epochs,
#         limit_train_batches=max_batches,
#         limit_val_batches=max_batches//5,
        gradient_clip_val=20,
        precision="bf16",
        log_every_n_steps=1,
#         callbacks=[LearningRateMonitor(logging_interval='step')],
#         logger=[CSVLogger(name=model_name, save_dir=save_dir, flush_logs_every_n_steps=5),],
#         default_root_dir=save_dir,

        # fast_dev_run=True,
        # plugins=[BitsandbytesPrecision(mode="nf4", dtype=torch.bfloat16),],
    )

# train
trainer.fit(pl_model, dl_train, dl_val)

/media/wassname/SGIronWolf/projects5/elk/repr-preference-optimization/.venv/lib/python3.9/site-packages/lightning/fabric/connector.py:571: `precision=bf16` is supported for historical reasons but its usage is discouraged. Please set your precision to bf16-mixed instead!
Trainer will use only 1 of 2 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=2)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name   | Type                 | Params | Mode 
--------------------------------------------------------
0 | _model | PeftModelForCausalLM | 2.0 B  | train
--------------------------------------------------------
25.2 M    Trainable param

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/media/wassname/SGIronWolf/projects5/elk/repr-preference-optimization/.venv/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.
The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.bfloat16.
/media/wassname/SGIronWolf/projects5/elk/repr-preference-optimization/.venv/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.


### Hist

In [23]:


# # test
# y_preds = trainer.predict(model, dataloaders=dl_test)
# y_pred = torch.concat(y_preds)[:, 0].numpy()
# y_pred

# Hist

def read_metrics_csv(metrics_file_path):
    df_hist = pd.read_csv(metrics_file_path)
    df_hist["epoch"] = df_hist["epoch"].ffill()
    df_histe = df_hist.set_index("epoch").groupby("epoch").last().ffill().bfill()
    return df_histe

def plot_hist(df_hist, allowlist=None, logy=False):
    """plot groups of suffixes together"""
    suffixes = list(set([c.split('/')[-1] for c in df_hist.columns if '/' in c]))
    for suffix in suffixes:
        if allowlist and suffix not in allowlist: continue
        df_hist[[c for c in df_hist.columns if c.endswith(suffix) and '/' in c]].plot(title=suffix, style='.', logy=logy)
        plt.title(suffix)   
        plt.show()

In [24]:
df_hist = read_metrics_csv(trainer.logger.experiment.metrics_file_path).bfill().ffill()
plot_hist(df_hist, ['loss', 'acc', 'auroc'])
display(df_hist)

AttributeError: 'SummaryWriter' object has no attribute 'metrics_file_path'