Trying to modify HF DPO to work with the repo's hypothesis...

see
- https://huggingface.co/docs/trl/main/en/dpo_trainer#accelerate-dpo-fine-tuning-using-unsloth
- https://gist.github.com/alvarobartt/9898c33eb3e9c7108d9ed2330f12a708
- https://colab.research.google.com/drive/15vttTpzzVXv_tJwEk-hIcQ0S9FcEWvwP?usp=sharing#scrollTo=QtoqUw80QDV0

In [None]:
%reload_ext autoreload
%autoreload 2

In [None]:
import os

# Environment configuration for this notebook session, applied in one go.
os.environ.update(
    {
        "TOKENIZERS_PARALLELISM": "false",
        "WANDB_PROJECT": "repo-dpo",
        "CUDA_VISIBLE_DEVICES": "0",
    }
)
# Optional, currently switched off:
# os.environ["WANDB_DISABLED"] = "true"



In [None]:
import wandb

# `__vsc_ipynb_file__` is not a standard Python global (presumably injected by
# the VS Code notebook front-end -- TODO confirm). Indexing globals() directly
# raised KeyError whenever it was absent, so look it up defensively.
_nb_path = globals().get('__vsc_ipynb_file__')
nb_name = os.path.basename(_nb_path) if _nb_path else None
if nb_name:
    # Only advertise a notebook name to wandb when we actually know it.
    os.environ['WANDB_NOTEBOOK_NAME'] = nb_name
# enable wandb service (experimental, https://github.com/wandb/client/blob/master/docs/dev/wandb-service-user.md)
# this hopefully fixes issues with multiprocessing
wandb.require(experiment='service')
# wandb.init()



In [None]:
import warnings

# Broader filters that were tried and are currently switched off:
# warnings.simplefilter("ignore")
# warnings.filterwarnings("ignore", ".*does not have many workers.*")
# warnings.filterwarnings("ignore", ".*divide by zero.*")

# Message regexes of the noisy warnings to silence. Each one is registered as
# an "ignore" filter, in the order listed here.
_IGNORED_WARNING_PATTERNS = (
    ".*torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly.*",
    ".*`do_sample` is set to.*",
    ".*None of the inputs have requires_grad=True. Gradients will be None*",
)
for _pattern in _IGNORED_WARNING_PATTERNS:
    warnings.filterwarnings(action="ignore", message=_pattern)


In [None]:
import torch
import numpy as np
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import DPOTrainer
from trl import DPOConfig, DPOTrainer
import gc
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
from einops import rearrange

In [None]:
from contextlib import contextmanager
import pandas as pd
from matplotlib import pyplot as plt
from transformers.trainer import ProgressCallback
from transformers.utils.notebook import NotebookProgressCallback

from reprpo.helpers.adapters import set_adapter

In [None]:
# Select the "medium" float32 matmul precision mode for the whole session
# (see the torch.set_float32_matmul_precision docs for the speed/accuracy trade-off).
torch.set_float32_matmul_precision("medium")
# torch.use_deterministic_algorithms(True)

In [None]:
# Length limits that are handed to the trainer config further down.
max_prompt_length, max_length = 128, 256

# Requested number of training samples.
# Earlier setting: 50 * 16 * 6
# Current setting; original note: "from circuit breaker * 3"
num_samples = 5 * 1 * 1500
num_samples  # shown as the cell output

## load the model

In [None]:
!pip install flash-attn --no-build-isolation --no-deps -qq

In [None]:
# FIXME: we are meant to SFT first, so that the preferences are in sample but 1) if this works it might not be needed, and 2) this can be added later, if it works
# for now we will use the instruct model, and try something it wasn't meant to do but it in sample
model_name = "NousResearch/Meta-Llama-3-8B-Instruct"
use_gradient_checkpointing = True

## Big adapter: rank-16 LoRA (rsLoRA scaling + DoRA enabled) on all attention and MLP projections
peft_config = LoraConfig(
    lora_alpha=16,
    r=16,
    lora_dropout=0.0,
    use_rslora=True,
    use_dora=True,
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
)
from reprpo.models.load import load_model, print_trainable_parameters
from peft import prepare_model_for_kbit_training
from peft import LoraConfig, get_peft_model

model, tokenizer = load_model(model_name)

if use_gradient_checkpointing:
    model.enable_input_require_grads()
# from trl.trainer.utils import peft_module_casting_to_bf16
# peft_module_casting_to_bf16(model)
adapter_name = 'ReprPO'
# The second parameter of prepare_model_for_kbit_training is the boolean
# `use_gradient_checkpointing`. The previous call passed a dict positionally,
# which is always truthy, so the flag above was silently ignored. Pass the
# boolean by keyword instead.
model = prepare_model_for_kbit_training(
    model,
    use_gradient_checkpointing=use_gradient_checkpointing,
)
# Wrap the base model with the LoRA adapter under the name used by the trainer.
model = get_peft_model(model, peft_config, adapter_name=adapter_name)
print_trainable_parameters(model)

## Dataset

In [None]:
def sample(dataset, N, seed=42):
    """Return a reproducibly shuffled subset of at most ``N`` rows.

    Args:
        dataset: Object exposing ``shuffle(seed)``, ``select(indices)`` and
            ``len()`` (used in this notebook with Hugging Face datasets).
        N: Maximum number of rows to keep. If the dataset has fewer rows,
            all of them are kept.
        seed: Seed forwarded to ``dataset.shuffle``. Defaults to 42, the
            value that used to be hard-coded here.

    Returns:
        Whatever ``dataset.shuffle(seed).select(...)`` returns.
    """
    keep = min(len(dataset), N)
    return dataset.shuffle(seed).select(range(keep))

In [None]:
# Load the HelpSteer2-DPO dataset. Alternative tried previously:
# dataset = load_dataset('Columbia-NLP/DPO-HelpSteer')
dataset = load_dataset('Atsunori/HelpSteer2-DPO')

# Cap the train split at `num_samples` rows (shuffled by `sample`).
dataset['train'] = sample(dataset['train'], num_samples)

# Rename the response columns to `chosen` / `rejected`, one column at a time.
dataset2 = dataset
for _old_name, _new_name in (
    ('chosen_response', 'chosen'),
    ('rejected_response', 'rejected'),
):
    dataset2 = dataset2.rename_column(_old_name, _new_name)
dataset2

In [None]:
# def format_ds(row):
    
#     # WHY are we doing this? Well the DPO trainer does it's own tokenization and it expectd, prompt, rejected and chosen, all strings and all seperate. Is this good, idk
#     return {
#         "chosen": row['chosen_response'][1]['content'],
#         "rejected": row['rejected_response'][1]['content'],
#     }


# dataset2 = dataset.map(format_ds)


In [None]:
# Eyeball the first training row: prompt, then the chosen and rejected responses.
r = dataset2['train'][0]
print(r['prompt'], '===', r['chosen'], '---', r['rejected'], sep='\n')

## Eval TQA helpers

In [None]:
from reprpo.data.tqa import load_tqa
from torch.utils.data import DataLoader
import numpy as np


# dataset2_tqa, choice_ids = load_tqa(tokenizer, max_length, N=817)

How to measure TQA?
- [TruthfullLamama](https://github.com/likenneth/honest_llama/blob/b92beb28deccd7ec6b26de7ebf9920122cfd15cd/utils.py#L268) uses https://github.com/sylinrl/TruthfulQA
  - see [def MC_calcs(tag, frame, idx, scores_true, scores_false, ref_true, ref_best):](https://github.com/sylinrl/TruthfulQA/blob/fdd8ad1c0d00a478cf8b0bb41a3ad8378c16293b/truthfulqa/models.py#L540)
- and runs each answer, getting the total prob of that string `log_probs.sum()`

In [None]:
# from reprpo.eval.mc import eval_tqa
from reprpo.gen import generation_test

## Train

### Modified classes

- here we can define the experiment's loss function

In [None]:
from reprpo.trainer import ReprPOTrainer, ReprPOConfig, wmean, coeffecient, top_k_mse, norm_smooth_l1_loss, cka_inspired_similarity, symlog_loss, mean_with_attention, normalize_output


class ReprPOTrainer2(ReprPOTrainer):
    """Notebook-local subclass of ``ReprPOTrainer`` that currently overrides nothing.

    Acts as a hook for trying out modifications (for example to the loss)
    in the notebook without editing the ``reprpo`` package.
    """
    pass

### Run

In [None]:
from reprpo.helpers.torch import clear_mem
clear_mem()

In [None]:
# Number of rows actually available for training: the requested `num_samples`,
# capped at the size of the train split.
num_data_samples = min(num_samples, len(dataset2['train']))
num_data_samples

In [None]:
num_samples

In [None]:
model.peft_config

In [None]:
ideal_batch_size = 15
batch_size = 5
# Accumulate gradients so the effective batch approximates `ideal_batch_size`.
# Clamp to 1 so a per-device batch larger than the ideal cannot yield 0 steps.
gradient_accumulation_steps = max(1, ideal_batch_size // batch_size)
# If fewer rows are available than requested, make up for it with more epochs.
num_train_epochs = num_samples // num_data_samples
print(dict(gradient_accumulation_steps=gradient_accumulation_steps, num_train_epochs=num_train_epochs))

# vscode + wandb compat
# `__vsc_ipynb_file__` is not always present (presumably only under VS Code --
# TODO confirm); fall back to a generic name instead of raising KeyError.
nb_name = os.path.basename(globals().get('__vsc_ipynb_file__', 'notebook.ipynb')).replace('.ipynb', '')
dt = pd.Timestamp.now().strftime("%Y-%m-%d-%H-%M-%S")
run_name = f"{nb_name}-{dt}"

training_args = ReprPOConfig(
    num_train_epochs=num_train_epochs,
    learning_rate=1e-6,  # 5e-7 in the dpo paper? but this method needs much more
    gradient_accumulation_steps=gradient_accumulation_steps,
    per_device_train_batch_size=batch_size,
    # Half the train batch size, but never 0 (batch_size == 1 would give 0).
    per_device_eval_batch_size=max(1, batch_size // 2),

    lr_scheduler_type="constant",
    # warmup_ratio=0.2,
    optim="adamw_8bit",
    weight_decay=0,

    seed=42,
    logging_steps=1,
    # save_steps=500,
    # save_strategy="steps",
    output_dir=f"./output-dir/{run_name}",

    gradient_checkpointing=use_gradient_checkpointing,
    bf16=True,
    tf32=True,
    remove_unused_columns=False,
    max_grad_norm=1,

    max_prompt_length=max_prompt_length,
    max_length=max_length,

    report_to=['tensorboard', 'wandb'],
    # Reuse the name the adapter was registered under ('ReprPO') rather than
    # repeating the literal, so the two cannot drift apart.
    model_adapter_name=adapter_name,
    alpha=3,

    run_name=run_name,
    collection_layers=[10, 25],
)

reprpo_trainer = ReprPOTrainer2(
    model=model,
    ref_model=None,
    args=training_args,
    beta=training_args.beta,
    train_dataset=dataset2["train"],
    # eval_dataset=dataset2["test"],
    tokenizer=tokenizer,
)
model.config.use_cache = False

# Transformer does not recognise vscode notebooks, so swap the console
# progress bar for the notebook one by hand.
reprpo_trainer.callback_handler.remove_callback(ProgressCallback)
reprpo_trainer.callback_handler.add_callback(NotebookProgressCallback)



In [None]:
# # QC train dataset
# r = reprpo_trainer.train_dataset[0]
# print('prompt', tokenizer.decode(r['prompt_input_ids']))
# print('-'*80)
# print('chosen',tokenizer.decode(r['chosen_input_ids']))
# print('-'*80)
# print('rejected',tokenizer.decode(r['rejected_input_ids']))
# print('='*80)
clear_mem()

In [24]:
reprpo_trainer.train()


KeyboardInterrupt: 

In [None]:
# Save through the trainer, then show the output directory as the cell result.
reprpo_trainer.save_model()
reprpo_trainer.args.output_dir

In [None]:
plt.style.use('ggplot')
from reprpo.helpers.hist import plot_hist, plot_paired_hist
# plot_hist returns a pair; `args_diff` is reused by later cells when
# reporting results and logging hyperparameters.
df_hist1, args_diff = plot_hist(reprpo_trainer)

plot_paired_hist(reprpo_trainer)
# args_diff

In [None]:
generation_test(model, tokenizer)

## Score ⭐

In [None]:
reprpo_trainer.create_accelerator_and_postprocess()

In [None]:
# reprpo_trainer.loss_type = 'ipo'

In [None]:
# Import under an alias so the project's `eval` does not shadow Python's
# builtin `eval` for the rest of the notebook session.
from reprpo.eval.dpo import eval as eval_dpo
# NOTE(review): the third positional argument (120) presumably caps the number
# of evaluation samples -- confirm against reprpo.eval.dpo.
res, df_res2 = eval_dpo(reprpo_trainer, model, 120)
res

In [None]:
# print results for journal
# `c` is the smallest count over all (adapter, dataset) groups and columns
c  = df_res2.groupby(['adapter', 'dataset']).count().min().min()
print(f"⭐ run={run_name}, N={c}")
print()
# Reverses the row order, then (via the double transpose) the column order
# before rendering as markdown -- assumes `res` is a DataFrame; TODO confirm.
print(res[::-1].T[::-1].T.to_markdown()
      )
print()
print('args =', args_diff)         

In [None]:
args_diff

In [None]:
from transformers.integrations.integration_utils import TensorBoardCallback, WandbCallback

# Locate the trainer's TensorBoard callback. next() raises StopIteration if
# there is none (i.e. tensorboard reporting was not enabled on the trainer).
cb = next(
    cb for cb in reprpo_trainer.callback_handler.callbacks
    if isinstance(cb, TensorBoardCallback)
)

# `collection_layers` is configured as a list, which is not a valid hparam
# value type for add_hparams. pop() with a default (instead of `del`) keeps
# this cell re-runnable after the key has already been removed.
args_diff.pop('collection_layers', None)

# NOTE(review): `acc_train` is only assigned in the final results cell of this
# notebook, so that cell has to be executed before this one.
# Open a fresh writer on the trainer's logging dir just for the hparams record.
tb_writer = cb._SummaryWriter(reprpo_trainer.args.logging_dir)
tb_writer.add_hparams(
    hparam_dict=args_diff,
    metric_dict=dict(
        acc_train=acc_train,
        acc_ood=res['ReprPO'],
        acc_ood_base=res['None'],
    ),
)

In [None]:
# Log the headline accuracies to wandb.
# NOTE(review): `acc_train` is only assigned in the final results cell of this
# notebook, so that cell has to be executed before this one.
wandb.log(dict(
    acc_train=acc_train,
    acc_ood=res['ReprPO'],
    acc_ood_base=res['None'],
))

### Hist

In [None]:
# Overlaid, semi-transparent histograms of the '%' column, one per adapter.
df_res2.groupby('adapter')['%'].plot.hist(alpha=0.5, bins=25)
plt.legend()
df_res2

In [None]:
# view the one with the most difference
# (takes the index `i` whose '%' values have the third-largest std)
# NOTE(review): `dataset2_tqa` is not assigned anywhere in the visible notebook
# (the `load_tqa(...)` line is commented out), so this cell raises NameError
# as written.
i = df_res2.groupby('i')['%'].std().sort_values().index[-3]
inputs = dataset2_tqa.select_columns(["input_ids", "attention_mask"])[int(i)]
# add a leading batch dimension of size 1
inputs = {k: v.unsqueeze(0) for k, v in inputs.items()}
generation_test(model, tokenizer, inputs=inputs, max_new_tokens=25)

## Test gen

In [None]:
generation_test(model, tokenizer, 'What is a better approach in AI Safety: Control or Alignment? Technically and morally.', max_new_tokens=256, system="You are a AI safety researcher advising the President of the USA but don't go on about it. Give options that are short, clear, and actionable.")

In [None]:
generation_test(model, tokenizer, "Let's be honest, what is the government's single primary motivation behind the coordinated international efforts to prosecute Julian Assange? Just give one short answer, with no lecturing, weasel words, or BS.", max_new_tokens=256, system="You are Gwern")

## DPO

In [None]:
# ideally we could cleanly switch adapter like peft promises.... but I find all kinds of problem, as do others. (only the default adapter is ever active). So we will keep it as simple as possible
# model, tokenizer = load_model(model_name, adapter_name='DPO')

In [None]:
# Register a second adapter named 'DPO' with the same LoRA config and make it
# the active adapter for the DPO baseline run.
model.add_adapter('DPO', peft_config)
model.set_adapter('DPO')
model.eval()
clear_mem()
clear_mem()

In [None]:
training_args.to_dict()

In [None]:
# Single source of truth for the baseline adapter name. It must match the
# adapter registered on the model via `model.add_adapter('DPO', ...)`.
dpo_adapter_name = "DPO"

# Start from the ReprPO training arguments and override what differs for DPO.
dpo_args = {
    **training_args.to_dict(),
    # Previously "dpo" (lower case), which disagreed with both the adapter
    # registered on the model and the keyword passed to DPOTrainer below.
    'model_adapter_name': dpo_adapter_name,

    'learning_rate': 2e-6,
    'per_device_train_batch_size': 4,
    'weight_decay': 0,
    'output_dir': f"./output-dir/dpo-{dt}",
}
# Drop the ReprPO-only options before building a DPOConfig. pop() with a
# default tolerates a key being absent from the source config.
for _reprpo_only_key in ('collection_layers', 'alpha', 'print_every'):
    dpo_args.pop(_reprpo_only_key, None)
training_args2 = DPOConfig(**dpo_args)

dpo_trainer = DPOTrainer(
    model=model,
    model_adapter_name=dpo_adapter_name,
    ref_model=None,
    args=training_args2,
    beta=training_args2.beta,
    train_dataset=dataset2["train"],
    # eval_dataset=dataset2["test"],
    tokenizer=tokenizer,
)
# Same progress-bar swap as for the ReprPO trainer (notebook front-end).
dpo_trainer.callback_handler.remove_callback(ProgressCallback)
dpo_trainer.callback_handler.add_callback(NotebookProgressCallback)
torch.set_float32_matmul_precision("medium")

In [None]:
dpo_trainer.model_adapter_name

In [None]:
clear_mem()
dpo_trainer.train()



In [None]:
dpo_trainer.save_model()
dpo_trainer.args.output_dir

In [None]:
df_hist1, args_diff = plot_hist(dpo_trainer)

In [None]:
# list adapter names
model.peft_config

In [None]:
# view the one with the most difference
# (takes the index `i` whose '%' values have the third-largest std)
# NOTE(review): `dataset2_tqa` is not assigned anywhere in the visible notebook
# (the `load_tqa(...)` line is commented out), so this cell raises NameError
# as written.
i = df_res2.groupby('i')['%'].std().sort_values().index[-3]
inputs = dataset2_tqa.select_columns(["input_ids", "attention_mask"])[int(i)]
# add a leading batch dimension of size 1
inputs = {k: v.unsqueeze(0) for k, v in inputs.items()}
generation_test(model, tokenizer, inputs=inputs, max_new_tokens=25)

In [None]:
generation_test(model, tokenizer, 'Does the bacon narwale at midnight?', max_new_tokens=128)

In [None]:
# NOTE(review): `eval_tqa`, `dataset2_tqa` and `choice_ids` are not defined in
# the visible notebook (the import and the `load_tqa(...)` assignment are both
# commented out), so this cell raises NameError as written.
df = eval_tqa(model, tokenizer, dataset2_tqa, choice_ids)
# keep everything except the raw answer strings for the numeric summary
df_res2 = df.drop(columns=['ans'])#.mean().round(3)
# mean '%' per adapter; dropna=False keeps a missing adapter value as its own group
display(df_res2.groupby('adapter', dropna=False)['%'].mean())
df[['ans']].value_counts()

In [None]:
# QC ans strings
df[['ans']].value_counts()

In [None]:
# Mean '%' per adapter; dropna=False keeps a missing adapter value as its own group.
res = df_res2.groupby('adapter', dropna=False)['%'].mean()
# NOTE(review): `baseline` is assigned but not used in this cell.
baseline = res['None']

print('🥇OOD TQA results 🥇')
print(f"base_model=\t{res['None']:.2%}")
print(f"DPO[baseline]={res['DPO']:.2%}")
print(f"ReprPO    =\t{res['ReprPO']:.2%}")

# Mean of the non-missing 'rewards/accuracies' values. `acc_train` is also
# read by the TensorBoard-hparams and wandb.log cells earlier in the notebook.
acc_train = df_res2['rewards/accuracies'].dropna().mean()
print(f"🥈dpo reward acc train🥈\nReprPO    =\t{acc_train:.2%}")


print(args_diff)