Instead of using the complex TRL we code it from scratch, using lighting

https://github.com/rasbt/LLMs-from-scratch/blob/main/ch07/04_preference-tuning-with-dpo/dpo-from-scratch.ipynb

In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
from pathlib import Path

# ML
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
from einops import rearrange, reduce, repeat
from jaxtyping import Float, Int, Bool

# Numeric
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

# Local
from reprpo.helpers.torch import clear_mem
from reprpo.gen import generation_test
import reprpo.silence


In [3]:
torch.set_float32_matmul_precision("medium")

In [4]:
from dataclasses import dataclass

@dataclass
class TrainingArguments:
    # model
    model_name: str = "microsoft/Phi-3-mini-4k-instruct"
    use_bnb: bool = True # this doesn't seem to be able to backprop when using baukit
    use_gradient_checkpointing: bool = False
    use_inputs: bool = True

    # train
    n_epochs: int = 1
    batch_size: int = 16
    lr: float = 1e-4
    weight_decay: float = 0.0

    # dataset
    n_samples: int = 1000
    max_length: int = 128
    max_prompt_length: int=64


args = TrainingArguments()
args

TrainingArguments(model_name='microsoft/Phi-3-mini-4k-instruct', use_bnb=True, use_gradient_checkpointing=False, use_inputs=True, n_epochs=1, batch_size=16, lr=0.0001, weight_decay=0.0, n_samples=1000, max_length=128, max_prompt_length=64)

## Load model

In [5]:
from peft import prepare_model_for_kbit_training
from peft import LoraConfig, get_peft_model

from reprpo.models.load import load_model, print_trainable_parameters

args

model, tokenizer = load_model(args.model_name, bnb=args.use_bnb )

if args.use_gradient_checkpointing:
    model.enable_input_require_grads()
# peft_module_casting_to_bf16(model)
model = prepare_model_for_kbit_training(model, {'use_gradient_checkpointing': args.use_gradient_checkpointing})

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

### Load adapter

In [6]:
from peft.tuners import BOFTConfig, OFTConfig, LoraConfig, IA3Config
adapter_name='ReprPO'
peft_config = LoraConfig(
    lora_alpha=16, 
    r=16,
    # use_rslora=True,
    # use_dora=True,
    task_type="CAUSAL_LM",
    target_modules=[
        # FIXME: I'm not sure we can do LORA on the layer we are targeting?
        "qkv_proj", "gate_up_proj", # in
        "down_proj",  "o_proj", # out
                    ], # PHI3
)
model = get_peft_model(model, peft_config, adapter_name=adapter_name)
print_trainable_parameters(model)

trainable params: 25165824 || all params: 2033980416 || trainable%: 1.2372697299362787


## Load data

In [7]:
from datasets import load_dataset

dataset = load_dataset('Atsunori/HelpSteer2-DPO').map(lambda x: {
    'prompt': x['prompt']+ ' '})
dataset2 = dataset.rename_column('chosen_response', 'chosen').rename_column('rejected_response', 'rejected')

# QC one row
r = dataset2['train'][0]
print(r['prompt'])
print('===')
print(r['chosen'])
print('---')
print(r['rejected'])

c# 
===
C# (pronounced "C sharp") is a modern, object-oriented programming language developed by Microsoft. It is widely used for building various types of applications, including web applications, desktop applications, mobile applications, and games. C# is similar to other programming languages such as Java and C++, and it is known for its simplicity and ease of use. C# is a powerful language that provides a rich set of libraries and frameworks that make it easy to build robust and scalable applications.

Here is a brief overview of some key features of C#:

1. Object-oriented: C# is an object-oriented language, which means it uses the concept of objects to represent real-world entities and their behavior.

2. Cross-platform: C# can be used to build applications for multiple platforms, including Windows, macOS, and Linux.

3. Strongly typed: C# is a strongly typed language, which means that variables must be declared with a specific type, and their type cannot be changed at runtime.



### Data Loader

We use huggingface datasets, which are pretokenized. So that we can stack

In [8]:
def tokenize_row(feature, tokenizer, args: TrainingArguments):
    """
    Tokenize a single row from a DPO specific dataset.

    see https://github.com/huggingface/trl/blob/main/trl/trainer/dpo_trainer.py#L784
    """
    batch = {}
    batch["chosen"] = tokenizer(feature["chosen"])["input_ids"]
    batch["rejected"] = tokenizer(feature["rejected"])["input_ids"]
    batch["prompt"] = tokenizer(feature["prompt"])["input_ids"]
    return batch

In [9]:
dataset3 = dataset2.map(lambda x: tokenize_row(x, tokenizer, args), batched=True, writer_batch_size=10)
dataset3['train'][0].keys()

dict_keys(['prompt', 'chosen', 'rejected'])

In [10]:
# # reuse, some TRL classes for now
# from trl.trainer.dpo_trainer import DPOTrainer, DPODataCollatorWithPadding
# from trl.trainer.utils import pad

In [12]:
custom_collate_fn = DPODataCollatorWithPadding(pad_token_id=tokenizer.pad_token_id, 
                                                  tokenizer=tokenizer,
                                                  max_length=args.max_length,
                                                  mask_prompt_tokens=True,
                                                  max_prompt_length=args.max_prompt_length,
                                                  #label_pad_token_id=-100
                                                  )



In [13]:
from torch.utils.data import DataLoader

ds = dataset3
dl_train = DataLoader(ds['train'], batch_size=args.batch_size, collate_fn=custom_collate_fn)

dl_val = DataLoader(ds['validation'], batch_size=args.batch_size, collate_fn=custom_collate_fn)

# QC
batch = next(iter(dl_train))
batch.keys()

dict_keys(['prompt', 'chosen', 'rejected', 'rejected_mask', 'chosen_mask'])

## Trainer

In [None]:
from reprpo.data.collate import DPODataCollatorWithPadding
from reprpo.train.dpo import compute_dpo_loss_batch, PL_DPO_MODEL

In [17]:
# # QC
# loss, info = compute_dpo_loss_batch(batch, model)

- https://lightning.ai/docs/pytorch/latest/notebooks/lightning_examples/text-transformers.html
- https://gist.github.com/wassname/e29d02b5026a531e13912cf768e6fdc8

In [19]:
import lightning as pl

In [20]:
max_batches = min(len(dl_train), 1000000)
pl_model = PL_DPO_MODEL(model,
                 weight_decay=args.weight_decay,
                lr=args.lr,
                num_iterations=max_batches*args.n_epochs
                )

model_name = type(model).__name__
max_batches = min(len(dl_train), 1000000)

In [21]:
timestamp = pd.Timestamp.now().strftime("%Y-%m-%d_%H-%M-%S")
save_dir = f"../outputs/{timestamp}/{model_name}"
Path(save_dir).mkdir(exist_ok=True, parents=True)
trainer = pl.Trainer(
        max_epochs=args.n_epochs,
#         limit_train_batches=max_batches,
#         limit_val_batches=max_batches//5,
        gradient_clip_val=20,
        precision="bf16",
        log_every_n_steps=1,
#         callbacks=[LearningRateMonitor(logging_interval='step')],
#         logger=[CSVLogger(name=model_name, save_dir=save_dir, flush_logs_every_n_steps=5),],
#         default_root_dir=save_dir,

        # fast_dev_run=True,
        # plugins=[BitsandbytesPrecision(mode="nf4", dtype=torch.bfloat16),],
    )

# train
trainer.fit(pl_model, dl_train, dl_val)

/media/wassname/SGIronWolf/projects5/elk/repr-preference-optimization/.venv/lib/python3.9/site-packages/lightning/fabric/connector.py:571: `precision=bf16` is supported for historical reasons but its usage is discouraged. Please set your precision to bf16-mixed instead!
Trainer will use only 1 of 2 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=2)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name   | Type                 | Params | Mode 
--------------------------------------------------------
0 | _model | PeftModelForCausalLM | 2.0 B  | train
--------------------------------------------------------
25.2 M    Trainable param

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/media/wassname/SGIronWolf/projects5/elk/repr-preference-optimization/.venv/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.
The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.bfloat16.
/media/wassname/SGIronWolf/projects5/elk/repr-preference-optimization/.venv/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.


### Hist

In [23]:
from reprpo.helpers.lightning_hist import read_metrics_csv, plot_hist

# # test
# y_preds = trainer.predict(model, dataloaders=dl_test)
# y_pred = torch.concat(y_preds)[:, 0].numpy()
# y_pred

# Hist



In [24]:
df_hist = read_metrics_csv(trainer.logger.experiment.metrics_file_path).bfill().ffill()
plot_hist(df_hist, ['loss', 'acc', 'auroc'])
display(df_hist)

AttributeError: 'SummaryWriter' object has no attribute 'metrics_file_path'