In [None]:
# Query the NVIDIA driver for the GPU(s) visible to this runtime.
! nvidia-smi

In [None]:
! pip install -U --quiet datasets evaluate torch transformers accelerate peft

In [None]:
! pip install -q git+https://github.com/huggingface/trl.git

### **Load Dataset**

In [None]:
from datasets import load_dataset

# Preference dataset used throughout the notebook; later cells read its
# "train" split and the "prompt", "chosen" and "rejected" fields.
DATASET_ID = "ContextualAI/ultrafeedback_clair_32k"

clair_apo = load_dataset(DATASET_ID)
clair_apo

In [None]:
# Print the first three training rows, one field after another.
for idx in range(3):
  example = clair_apo["train"][idx]
  for field in ("prompt", "chosen", "rejected"):
    print(example[field])
  print("\n\n")

### **Load Model**

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Checkpoint that the rest of the notebook fine-tunes.
model_id = "rasyosef/phi-2-sft-openhermes-128k-v2-merged"

# Half-precision weights, placed on the GPU.
load_options = {
    "torch_dtype": torch.float16,
    "device_map": "cuda",
    # "attn_implementation": "flash_attention_2",
}

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, **load_options)

In [None]:
# Pad on the left, and reuse the end-of-sequence token as the padding token.
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token

# Print the module tree of the loaded model.
print(model)

In [None]:
# Single-turn conversation passed to chat() below.
messages = [{"role":"user", "content":"Who was the last king of Germany?"}]

def chat(messages, max_new_tokens=8):
  """Generate a continuation for a chat history and print the decoded sequence.

  Args:
    messages: list of ``{"role": ..., "content": ...}`` dicts, rendered with the
      tokenizer's chat template (generation prompt appended).
    max_new_tokens: upper bound on the number of generated tokens.

  Uses the module-level ``tokenizer`` and ``model``. Prints the full decoded
  sequence (prompt plus continuation) and returns nothing.
  """
  # return_dict=True yields input_ids together with the attention mask, so
  # generate() does not have to guess the mask (the pad token is the EOS token).
  inputs = tokenizer.apply_chat_template(
      messages,
      add_generation_prompt=True,
      return_tensors="pt",
      return_dict=True,
  ).to(model.device)  # follow the model instead of hard-coding "cuda"
  outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
  print(tokenizer.decode(outputs[0]))

# Try generation with the model as loaded, before any preference tuning.
chat(messages, max_new_tokens=128)

### **Inspect and Preprocess Dataset**

In [None]:
# Show the dataset summary again before preprocessing.
clair_apo

In [None]:
def preprocess_dataset(row):
    """Flatten one preference pair into plain strings plus their token counts.

    All messages of ``row["chosen"]`` except the last are rendered with the
    tokenizer's chat template (generation prompt appended) and become the
    prompt. The content of the last message of ``row["chosen"]`` and of
    ``row["rejected"]`` become the two completions, each with the tokenizer's
    EOS token appended. Each ``*_length`` is the number of pieces
    ``tokenizer.tokenize`` returns for the corresponding string.

    Relies on the module-level ``tokenizer``.

    Returns:
        dict with keys ``prompt``, ``chosen``, ``rejected``, ``prompt_length``,
        ``chosen_length`` and ``rejected_length``.
    """
    eos = tokenizer.eos_token
    history = row["chosen"][:-1]

    texts = {
        "prompt": tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True),
        "chosen": row["chosen"][-1]["content"] + eos,
        "rejected": row["rejected"][-1]["content"] + eos,
    }
    lengths = {
        f"{name}_length": len(tokenizer.tokenize(text))
        for name, text in texts.items()
    }

    return {**texts, **lengths}

# Apply the row-level preprocessing to every split, using four worker processes.
clair_apo_processed = clair_apo.map(preprocess_dataset, num_proc=4)
clair_apo_processed

In [None]:
# Lengths Distribution
# Sort every token-count column of the training split in ascending order.
train_split = clair_apo_processed["train"]
prompt_lengths = sorted(train_split["prompt_length"])
chosen_lengths = sorted(train_split["chosen_length"])
rejected_lengths = sorted(train_split["rejected_length"])

# Values at fixed positions of each sorted list (raw indices, not percentiles),
# followed by the maximum.
probe_positions = (1024, 4096, 8000, 12000)
for label, lengths in (
    ("prompt_lengths:", prompt_lengths),
    ("chosen_lengths:", chosen_lengths),
    ("rejected_lengths:", rejected_lengths),
):
  print(label, *(lengths[pos] for pos in probe_positions), max(lengths))

In [None]:
MAX_LENGTH = 512

def fits_in_max_length(example):
  """Return True when the prompt plus the longer completion stays below MAX_LENGTH tokens."""
  longest_completion = max(example['chosen_length'], example['rejected_length'])
  return example['prompt_length'] + longest_completion < MAX_LENGTH

# Keep only pairs where both prompt+chosen and prompt+rejected are short enough.
clair_apo_filtered = clair_apo_processed.filter(fits_in_max_length)
clair_apo_filtered

In [None]:
import random
# Seeds Python's own RNG; the shuffle and split below receive their seed explicitly.
random.seed(42)

NUM_SAMPLES = 10_000

# select() raises on indices past the end, so never request more rows than
# survived the length filter.
filtered_train = clair_apo_filtered['train']
num_selected = min(NUM_SAMPLES, len(filtered_train))
if num_selected < NUM_SAMPLES:
  print(f"Only {num_selected} rows available after filtering; using all of them.")

# Seeded shuffle -> fixed-size subset -> 98/2 train/test split.
clair_apo_final = filtered_train.shuffle(seed=42).select(range(num_selected))
clair_apo_final = clair_apo_final.train_test_split(test_size=0.02,seed=42)
clair_apo_final

In [None]:
# Preview five shuffled training triples. The shuffle is seeded like every other
# shuffle in this notebook so that a re-run shows the same rows.
sample = clair_apo_final["train"].shuffle(seed=42).select(range(5))

for row in sample:
  for field in ("prompt", "chosen", "rejected"):
    print(row[field])
  print("\n-----------------------------------------------------\n")

### **APO with TRL**

In [None]:
from peft import LoraConfig, get_peft_model, cast_mixed_precision_params

# Target all linear layers
lora_target_modules = ["q_proj", "k_proj", "v_proj", "dense", "fc1", "fc2", "lm_head"]

peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=lora_target_modules,
)

# Wrap the loaded model with the LoRA adapters; `model` now refers to the wrapper.
model = get_peft_model(model, peft_config)
# PEFT helper for mixed-precision training, called with the same float16 dtype
# the model was loaded in.
cast_mixed_precision_params(model, dtype=torch.float16)
model.print_trainable_parameters()

In [None]:
from trl import DPOConfig, DPOTrainer

# Per-device batch of 1 with 8 accumulation steps, i.e. 8 examples per
# optimizer update on each device.
batch_size = 1
gradient_accum_steps = 8
epochs = 2

new_model_id = "phi-2-apo"

# Evaluate, checkpoint and log on the same fixed cadence.
eval_steps = 250
save_steps = 250
logging_steps=eval_steps

print("Eval Steps:", eval_steps)
print("Save Steps:", save_steps)

dpo_config = DPOConfig(
  output_dir=new_model_id,
  beta=0.1,
  # Reuse the threshold the dataset was filtered with, so the two limits
  # cannot drift apart.
  max_length=MAX_LENGTH,
  max_prompt_length=MAX_LENGTH,
  per_device_train_batch_size=batch_size,
  per_device_eval_batch_size=batch_size,
  gradient_accumulation_steps=gradient_accum_steps,
  num_train_epochs=epochs,
  learning_rate=1e-6,
  warmup_steps=0,
  lr_scheduler_type="cosine",
  remove_unused_columns=False,
  fp16=True,
  logging_strategy="steps",
  logging_steps=logging_steps,
  eval_strategy="steps",
  eval_steps=eval_steps,
  save_strategy="steps",
  save_steps=save_steps,
  seed=42,
  # Selects the "apo_zero" loss instead of the trainer's default loss.
  loss_type="apo_zero",
  # Optimization Params
  # gradient_checkpointing=True,
  # hub_token=userdata.get("HF_TOKEN") # Your HuggingFace token
  # push_to_hub=True
)

In [None]:
import inspect

# TRL is installed unpinned from git. Newer releases take the tokenizer as
# `processing_class`, older ones as `tokenizer`, so use whichever keyword the
# installed DPOTrainer accepts.
_trainer_params = inspect.signature(DPOTrainer.__init__).parameters
_tokenizer_kwarg = "processing_class" if "processing_class" in _trainer_params else "tokenizer"

trainer = DPOTrainer(
    model, # ref_model is left unset
    args=dpo_config,
    train_dataset=clair_apo_final["train"],
    eval_dataset=clair_apo_final["test"],
    **{_tokenizer_kwarg: tokenizer},
)

In [None]:
# Launch training with the settings held in dpo_config.
trainer.train()

In [None]:
def chat(messages, max_new_tokens=128):
    """Generate a continuation for a chat history and print the decoded sequence.

    Args:
        messages: list of ``{"role": ..., "content": ...}`` dicts, rendered with
            the tokenizer's chat template (generation prompt appended).
        max_new_tokens: upper bound on the number of generated tokens. It
            defaults to 128, the value this cell previously hard-coded, and
            keeps the signature compatible with the earlier ``chat`` definition.

    Uses the module-level ``tokenizer`` and ``model``. Prints the full decoded
    sequence (prompt plus continuation) and returns nothing.
    """
    # return_dict=True yields input_ids together with the attention mask, so
    # generate() does not have to guess the mask (the pad token is the EOS token).
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)  # follow the model instead of hard-coding "cuda"

    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    print(tokenizer.decode(outputs[0]))

# Prompt the model again after training to eyeball its output.
messages = [{"role": "user", "content": "What is quantum computing?"}]
chat(messages)

In [None]:
# trainer.push_to_hub()