Instead of using the complex TRL library, we code DPO from scratch with PyTorch Lightning, following:

https://github.com/rasbt/LLMs-from-scratch/blob/main/ch07/04_preference-tuning-with-dpo/dpo-from-scratch.ipynb

In [None]:
%reload_ext autoreload
%autoreload 2

In [2]:
from pathlib import Path

# ML
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
from einops import rearrange, reduce, repeat
from jaxtyping import Float, Int, Bool
from torch.utils.data import DataLoader

# Numeric
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

# lightning
import lightning as pl
from lightning.pytorch.loggers.wandb import WandbLogger
from lightning.pytorch.loggers.csv_logs import CSVLogger


In [3]:
# Local
from reprpo.helpers.torch import clear_mem
from reprpo.gen import generation_test
import reprpo.silence
from reprpo.helpers.lightning_hist import read_metrics_csv, plot_hist

from reprpo.data.collate import DPODataCollatorWithPadding
from reprpo.train.dpo import compute_dpo_loss_batch, PL_DPO_MODEL

In [4]:
# Let float32 matmuls use lower-precision internal math (e.g. bfloat16) for speed.
torch.set_float32_matmul_precision("medium")

import os
# Disable Hugging Face tokenizers' internal parallelism (avoids its fork-related warning).
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Restrict this process to GPU 0. NOTE(review): only takes effect if CUDA has not been initialised yet.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

from reprpo.helpers.wandb import init_wandb

# `__vsc_ipynb_file__` is only defined when running inside VS Code's notebook UI.
# The returned name is used as the run name for the loggers below.
nb_name = init_wandb(__vsc_ipynb_file__)

In [5]:
from reprpo.train.lightning import TrainingArguments
from reprpo.train.dpo import DPOTrainingArguments


# Training hyper-parameters: only the learning rate is overridden, everything else uses the
# DPOTrainingArguments defaults.
args = DPOTrainingArguments(lr=1e-5)
args

DPOTrainingArguments(model_name='google/gemma-2-2b', load_in_4bit=True, use_gradient_checkpointing=False, n_epochs=1, batch_size=16, lr=1e-05, weight_decay=0.0, n_samples=1800, max_length=128, max_prompt_length=64)

## Load model

In [6]:
from peft import prepare_model_for_kbit_training
from peft import LoraConfig, get_peft_model

from reprpo.models.load import load_model, print_trainable_parameters

# args

# Load the base model and tokenizer; 4-bit loading follows args.load_in_4bit and eager
# attention is requested explicitly.
model, tokenizer = load_model(args.model_name, load_in_4bit=args.load_in_4bit, attn_implementation='eager')

if args.use_gradient_checkpointing:
    # `enable_input_require_grads` only makes the (frozen) input embeddings produce outputs
    # that require grad, so gradients can reach the adapters through checkpointed blocks.
    # It does NOT turn checkpointing on, so enable it explicitly (the same pair of calls
    # peft's `prepare_model_for_kbit_training` makes).
    model.enable_input_require_grads()
    model.gradient_checkpointing_enable()
# # peft_module_casting_to_bf16(model)
# model = prepare_model_for_kbit_training(model, {'use_gradient_checkpointing': args.use_gradient_checkpointing})
# # model

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [7]:
# Adapted from TRL: https://github.com/huggingface/trl/blob/cbcaa46cd3c02c0e7f724b764c5848ae73796de7/trl/trainer/utils.py#L747
# Not sure it is needed, but `prepare_model_for_kbit_training` doesn't seem to do this, despite claiming to.
def peft_module_casting_to_bf16(model):
    """Upcast the normalisation layers of ``model`` to float32, in place.

    Adapted from TRL's helper of the same name (the name is kept for callers).
    Despite the "bf16" in the name, this version only touches norm layers:
    every ``torch.nn.LayerNorm`` instance, plus any submodule whose qualified
    name contains ``"norm"`` (e.g. ``model.layers.0.input_layernorm``), is cast
    to ``torch.float32``. Each cast module's name is printed.

    Args:
        model: a ``torch.nn.Module``; its submodules are modified in place.
    """
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.LayerNorm) or "norm" in name:
            # nn.Module.to() converts parameters/buffers in place, so no reassignment is needed.
            module.to(torch.float32)
            # Report the dtype actually used (the old message wrongly said "bf16").
            print(f"casting {name} to float32")

# Upcast the norm layers of the loaded base model to float32 before the LoRA adapter is attached.
peft_module_casting_to_bf16(model)

casting model.layers.0.input_layernorm to bf16
casting model.layers.0.post_attention_layernorm to bf16
casting model.layers.0.pre_feedforward_layernorm to bf16
casting model.layers.0.post_feedforward_layernorm to bf16
casting model.layers.1.input_layernorm to bf16
casting model.layers.1.post_attention_layernorm to bf16
casting model.layers.1.pre_feedforward_layernorm to bf16
casting model.layers.1.post_feedforward_layernorm to bf16
casting model.layers.2.input_layernorm to bf16
casting model.layers.2.post_attention_layernorm to bf16
casting model.layers.2.pre_feedforward_layernorm to bf16
casting model.layers.2.post_feedforward_layernorm to bf16
casting model.layers.3.input_layernorm to bf16
casting model.layers.3.post_attention_layernorm to bf16
casting model.layers.3.pre_feedforward_layernorm to bf16
casting model.layers.3.post_feedforward_layernorm to bf16
casting model.layers.4.input_layernorm to bf16
casting model.layers.4.post_attention_layernorm to bf16
casting model.layers.4.pr

In [8]:
# model

### Load adapter

In [9]:
from peft.tuners import BOFTConfig, OFTConfig, LoraConfig, IA3Config
# Adapter name; also used later to select this adapter's rows in the evaluation results.
adapter_name='ReprPO'
# LoRA with rank 16 / alpha 16 on the attention query and value projections only.
peft_config = LoraConfig(
    lora_alpha=16, 
    r=16,
    # use_rslora=True,
    # use_dora=True,
    task_type="CAUSAL_LM",
    target_modules= ["q_proj", "v_proj"], # gemma, llama
    # target_modules=[
        # FIXME: I'm not sure we can do LORA on the layer we are targeting?
        # "qkv_proj", "gate_up_proj", # in
        # "down_proj",  "o_proj", # out
        #             ], # PHI3
)
# Wrap the base model with the adapter; the printed count shows how few parameters are trainable.
model = get_peft_model(model, peft_config, adapter_name=adapter_name)
print_trainable_parameters(model)
model

trainable params: 3194880 || all params: 1605398784 || trainable%: 0.19900849756716896


PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): Gemma2ForCausalLM(
      (model): Gemma2Model(
        (embed_tokens): Embedding(256000, 2304, padding_idx=0)
        (layers): ModuleList(
          (0-25): 26 x Gemma2DecoderLayer(
            (self_attn): Gemma2Attention(
              (q_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=2304, out_features=2048, bias=False)
                (lora_dropout): ModuleDict(
                  (ReprPO): Identity()
                )
                (lora_A): ModuleDict(
                  (ReprPO): Linear(in_features=2304, out_features=16, bias=False)
                )
                (lora_B): ModuleDict(
                  (ReprPO): Linear(in_features=16, out_features=2048, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (k_proj): Linear4

## Load data

In [10]:
from datasets import load_dataset

# Alternative dataset (disabled): HelpSteer2-DPO, with its columns renamed to prompt/chosen/rejected.
# dataset = load_dataset('Atsunori/HelpSteer2-DPO').map(lambda x: {
#     'prompt': x['prompt']+ ' '})
# dataset2 = dataset.rename_column('chosen_response', 'chosen').rename_column('rejected_response', 'rejected')

# Training data: the `code_easy` config of wassname/genie_dpo, with prompt / chosen / rejected text columns.
dataset2 = load_dataset("wassname/genie_dpo", name="code_easy")
print(dataset2)

# QC one row
r = dataset2['train'][0]
print(r['prompt'])
print('===')
print(r['chosen'])
print('---')
print(r['rejected'])

### Data Loader

We use Hugging Face `datasets` and tokenize every row up front, so that the collator only has to pad and stack the rows into batches.

In [None]:
def tokenize_row(feature, tokenizer, args: TrainingArguments):
    """
    Tokenize the `chosen`, `rejected` and `prompt` text columns of a DPO dataset.

    Intended for ``Dataset.map(..., batched=True)``: each column of ``feature`` is
    then a list of strings, and each returned column is the matching list of
    token-id lists (tokenizer defaults, so special tokens are included). No
    truncation happens here and ``args`` is currently unused.

    NOTE(review): chosen/rejected are tokenized separately from the prompt, so each
    may carry its own BOS token — confirm how the collator joins them.

    Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/dpo_trainer.py#L784
    """
    columns = ("chosen", "rejected", "prompt")
    return {column: tokenizer(feature[column])["input_ids"] for column in columns}

In [None]:
# Tokenize all splits in batches (`batched=True` passes lists of strings to tokenize_row).
# The returned keys reuse the column names, so the text columns are REPLACED by token-id lists.
# A small writer_batch_size makes the cache writer flush every few rows.
dataset3 = dataset2.map(lambda x: tokenize_row(x, tokenizer, args), batched=True, writer_batch_size=10)
dataset3['train'][0].keys()

In [None]:
# Collator that pads the tokenized rows into batches. Presumably it truncates to
# args.max_length / args.max_prompt_length and, with mask_prompt_tokens=True, excludes
# prompt positions from the loss — verify in reprpo.data.collate.
custom_collate_fn = DPODataCollatorWithPadding(pad_token_id=tokenizer.pad_token_id, 
                                                  tokenizer=tokenizer,
                                                  max_length=args.max_length,
                                                  mask_prompt_tokens=True,
                                                  max_prompt_length=args.max_prompt_length,
                                                  #label_pad_token_id=-100
                                                  )



In [None]:


ds = dataset3
# Shuffle the training set; otherwise every epoch visits examples in on-disk order
# (the Trainer may run several epochs to reach max_steps). Validation order is irrelevant.
dl_train = DataLoader(ds['train'], batch_size=args.batch_size, collate_fn=custom_collate_fn, shuffle=True)

dl_val = DataLoader(ds['test'], batch_size=args.batch_size, collate_fn=custom_collate_fn)

# QC: inspect the keys the collator produces for one training batch
batch = next(iter(dl_train))
batch.keys()

## Trainer

In [None]:
# # QC
# loss, info = compute_dpo_loss_batch(batch, model)

- https://lightning.ai/docs/pytorch/latest/notebooks/lightning_examples/text-transformers.html
- https://gist.github.com/wassname/e29d02b5026a531e13912cf768e6fdc8

In [None]:
# Number of training batches needed to cover ~args.n_samples examples.
max_steps = args.n_samples // args.batch_size

In [None]:
# Accumulate gradients so the effective batch is at least `ideal_batch_size` examples.
ideal_batch_size = max(16, args.batch_size)
# Lightning documents `accumulate_grad_batches` as a plain int, so convert from numpy.
accumulate_grad_batches = int(np.ceil(ideal_batch_size/args.batch_size))
# Lightning's `max_steps` counts optimizer steps (one per `accumulate_grad_batches` batches),
# so express the ~args.n_samples training budget in optimizer steps (at least one).
max_steps = max(1, args.n_samples // (args.batch_size * accumulate_grad_batches))
accumulate_grad_batches, args.batch_size*accumulate_grad_batches

In [None]:
from lightning.pytorch.callbacks import LearningRateMonitor
from reprpo.train.lightning import GenCallback

In [None]:
# Wrap the PEFT model in the project's LightningModule. `num_iterations=max_steps` presumably
# sizes the learning-rate schedule — verify in reprpo.train.dpo.
pl_model = PL_DPO_MODEL(model,
                 weight_decay=args.weight_decay,
                lr=args.lr,
                num_iterations=max_steps,
                batch_size=args.batch_size,
                )


In [None]:
# from reprpo.helpers.lightning_save import AdapterModelCheckpoint

# checkpoint_callback = AdapterModelCheckpoint(
#     verbose=True,
# )

In [None]:
from reprpo.helpers.lightning_existing_bnb import ExistingBitsandbytesPrecision

from accelerate.utils import CustomDtype
# Precision plugin handed to the Trainer (`plugins=precision`). INT4 matches the 4-bit model load.
# NOTE(review): the name suggests it wraps an already-quantised bitsandbytes model rather than
# quantising again, with default_dtype used for the non-quantised tensors — confirm in the helper.
precision = ExistingBitsandbytesPrecision(
    # dtype=torch.bfloat16,
    dtype=CustomDtype.INT4,
    # dtype=torch.int8,
    default_dtype=torch.bfloat16,
)

In [None]:
# Each run logs to its own directory: ../outputs/<timestamp>/<nb_name>
timestamp = pd.Timestamp.now().strftime("%Y-%m-%d_%H-%M-%S")
save_dir = f"../outputs/{timestamp}/{nb_name}"
Path(save_dir).mkdir(exist_ok=True, parents=True)
trainer = pl.Trainer(
        max_steps=max_steps,
#         limit_val_batches=max_batches//5,
        gradient_clip_val=20,

        # accelerator='gpu',
        devices=1, 
        plugins=precision,
        
        # https://lightning.ai/docs/pytorch/stable/common/trainer.html
        # precision="bf16-true", # "32-true" "transformer-enginer
        log_every_n_steps=1,
        accumulate_grad_batches=accumulate_grad_batches,
        callbacks=[
            LearningRateMonitor(logging_interval='step'),
            # NOTE(review): max_steps//20 is 0 when max_steps < 20 — confirm GenCallback handles every=0.
            GenCallback(every=max_steps//20),
            # checkpoint_callback
        ],
        # The CSVLogger is listed first, so `trainer.logger` refers to it.
        logger=[
            CSVLogger(name=nb_name, save_dir=save_dir, flush_logs_every_n_steps=5),
            WandbLogger(name=nb_name, save_dir=save_dir),
        ],
        default_root_dir=save_dir,

        # too large, we will just save adapter
        enable_checkpointing=False, 

        # fast_dev_run=True,
    )

# train
trainer.fit(pl_model, dl_train, dl_val)

In [None]:
# NOTE(review): leftover scratch cell — these expressions only echo dtype objects and have no effect.
torch.bfloat16
torch.int8

In [None]:
# Save only the trained adapter (not the base model) to "<save_dir>-adapter",
# a sibling of the run's log directory.
model.save_pretrained(f"{save_dir}-adapter")

In [None]:
# Inspect the underlying model configuration.
model.config

### Hist

In [None]:
plt.style.use('ggplot')
import matplotlib
matplotlib.rcParams['figure.figsize'] = (6, 2)


# `trainer.logger` is the first logger given to the Trainer (the CSVLogger); read its metrics file.
# Step- and epoch-level metrics land on different rows, so back-/forward-fill the gaps.
df_hist = read_metrics_csv(trainer.logger.experiment.metrics_file_path).bfill().ffill()
# The regexes select which metric columns to plot.
plot_hist(df_hist, ['.*/loss_step', '.*/acc.*', '.*/auroc.*', '.*/.*reward_step'])
display(df_hist)

## Gen

In [None]:
# model.cuda()
# Quick qualitative generation from the trained model on a single prompt.
# NOTE(review): the prompt ends with a stray `', ` (looks copied from a list literal) — intentional?
generation_test(model, tokenizer, s="Q1: (30 words): Which Science Fiction Utopia is preferable and why? [The Polity, The Culture, Permutation City, 2 more]', ", max_new_tokens=64)

In [None]:
# IPython `??` introspection: show the source of the base model's generation-input hook.
# model.prepare_inputs_for_generation??
model.base_model.prepare_inputs_for_generation??

In [None]:
# Check the dtype reported by the wrapped model.
model.dtype

In [None]:
from reprpo.gen import get_model_generations
# Qualitative check: sample generations via the project's helper.
get_model_generations(model, tokenizer)

## Eval

In [None]:
from reprpo.helpers.shypothesis import shypothesis
from reprpo.evaluate import evaluate_adapters
from open_pref_eval.evaluation import evaluate_model
from open_pref_eval.plot.radar import radar_plot

# Number of preference pairs taken from each evaluation dataset.
N = 64
# Held-out preference datasets: truthfulness, elementary maths, ethics (utilitarianism, justice)
# and two genie_dpo configs other than the `code_easy` config used for training.
datasets = [
            
            load_dataset('wassname/truthful_qa_dpo', split=f'validation[:{N}]', keep_in_memory=False),

            load_dataset('wassname/mmlu_dpo', name='elementary_mathematics', split=f'test[:{N}]', keep_in_memory=False), 

            # load_dataset('wassname/ethics_expression_dpo', name='commonsense', split=f'test[:{N}]', keep_in_memory=False),
            load_dataset('wassname/ethics_expression_dpo', name='utilitarianism', split=f'test[:{N}]', keep_in_memory=False),
            load_dataset('wassname/ethics_expression_dpo', name='justice', split=f'test[:{N}]', keep_in_memory=False),
            # load_dataset('wassname/ethics_expression_dpo', name='deontology', split=f'test[:{N}]', keep_in_memory=False),      
            load_dataset('wassname/genie_dpo', name='us_history_fiction', split=f'test[:{N}]', keep_in_memory=False),      
            load_dataset('wassname/genie_dpo', name='code_hard', split=f'test[:{N}]', keep_in_memory=False),      
            ]


# Presumably frees cached GPU memory before evaluation.
clear_mem()
# `res`: summary table, indexed below as res[<adapter>][<dataset>] (adapter or 'base').
# `df_res2`: per-example rows with columns such as 'adapter', 'dataset', 'correct', '_logratio'.
res, df_res2 = evaluate_model(model=model, 
                              tokenizer=tokenizer, 
                              datasets=datasets,
                                 batch_size=2,
                                 bf16=True,
                                 bf16_full_eval=True, 
                                 torch_empty_cache_steps=100,)
# radar_plot(res)
res

In [None]:
from open_pref_eval.plot.radar import radar_plot
# Mean accuracy with datasets as rows and adapters (incl. base) as columns, drawn as a radar plot.
df_res = df_res2.groupby(['dataset', 'adapter'], dropna=False)['correct'].mean().unstack()
radar_plot(df_res)
df_res

In [None]:
# print acc for journal
# N reported = the smallest per-column count over all (adapter, dataset) groups.
c  = df_res2.groupby(['adapter', 'dataset']).count().min().min()
print(f"⭐ run={nb_name}, N={c}")
print()
# `[::-1]` reverses the row order and `.T[::-1].T` reverses the column order before printing.
print(res[::-1].T[::-1].T.round(3).to_markdown()
      )
print()
print('args =', args)         

In [None]:
print('did acc improve')
# NOTE(review): neither 'help_steer2-dpo' nor 'truthful_qa_binary' appears among the datasets
# evaluated above (HelpSteer2 loading is commented out; truthful_qa comes from
# 'wassname/truthful_qa_dpo'), so these lookups will likely raise KeyError. Point them at
# dataset names that actually appear in `res`.
# In-distribution: adapter accuracy vs base model accuracy.
acc_pi = res[adapter_name]['help_steer2-dpo'].item()
acc_ref = res['base']['help_steer2-dpo'].item()
shypothesis('acc_pi>acc_ref', locals())


# Out-of-distribution: same comparison on another dataset.
acc_pi_ood = res[adapter_name]['truthful_qa_binary'].item()
acc_ref_ood = res['base']['truthful_qa_binary'].item()
shypothesis('acc_pi_ood>acc_ref_ood', locals());

In [None]:
print('did coherence improve?, (measured by mean prob per token) higher is better')
# exp(mean of `_chosen_logps`) per adapter (rows) and dataset (columns).
# NOTE(review): whether `_chosen_logps` is normalised per token depends on evaluate_model — confirm.
r = df_res2.groupby(['adapter', 'dataset'], dropna=False)['_chosen_logps'].mean().unstack()
r = np.exp(r)
display(r)

# NOTE(review): 'help_steer2-dpo' is not among the evaluated datasets; this will likely raise KeyError.
coherency_pi = float(r.T[adapter_name]['help_steer2-dpo'])
coherency_ref = float(r.T['base']['help_steer2-dpo'])
shypothesis('coherency_pi>coherency_ref', locals());

In [None]:

print('are we biased by the length of the string? Ideally no correlation')
a, b = df_res2['_l_chosen'], df_res2['_l_rejected']
# Relative length difference in [-1, 1]; > 0 when the chosen response is longer.
x = (a-b)/(a+b)
plt.plot(x, df_res2['_logratio'], 'o')
plt.xlabel('chosen longer')
plt.ylabel('chosen more likely')

# Correlation between the chosen/rejected length ratio and the model's preference `prob`.
a = df_res2['_l_chosen'] / df_res2['_l_rejected']
b = df_res2['prob']

# Drop rows with inf/NaN (e.g. a zero-length rejected response).
m = np.isfinite(a) & np.isfinite(b)
a = a[m]
b = b[m]
corr_length = np.corrcoef(a, b)[1,0]
print(f'{corr_length:.2f} (0 is ideal) correlation between length ratio and prob:')
shypothesis('corr_length<0.25', locals())


# Dataset length bias: mean chosen/rejected length ratio (1 = equal lengths on average).
# The previous expression divided this by the mean preference prob, mixing unrelated quantities.
print(f'is the ds bised? {a.mean():.2f} (1 is ideal)')
# "Model prefers chosen" is the sign of the log-ratio (the y-axis of the plot above).
# The previous `prob > 0` test is true for essentially every row when `prob` is a probability.
a = df_res2['_logratio'] > 0
b = x >= 0
acc_bad = (a==b).mean()
print(f'{acc_bad:.2%} (0.5 is ideal) how often does it accurately pick the longer one :( ')

shypothesis('acc_bad<0.75', locals())