In [None]:
# setup, get SFT adapter

import os, torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel, LoraConfig, get_peft_model
from trl import DPOTrainer, DPOConfig

# Checkpoint / adapter locations and the artifacts this notebook reads and writes.
base_model_id = "meta-llama/Llama-3.3-70B-Instruct"
sft_adapter_path = "workspace/output/cai_sft_stage1/adapters"
out_dir = "dpo_adapter"
data_path = "dpo_pairs.jsonl"

# Tokenizer of the base checkpoint; the EOS token doubles as the padding token.
tok = AutoTokenizer.from_pretrained(base_model_id, use_fast=True)
tok.pad_token = tok.eos_token

# 4-bit NF4 quantization with double quantization and a bfloat16 compute dtype.
bnb = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Load the base checkpoint with the 4-bit quantization config defined in `bnb`.
model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    quantization_config=bnb,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="flash_attention_2"
)

# Attach the stage-1 SFT LoRA adapter and keep it trainable, so DPO continues
# training those weights and the adapter saved to `out_dir` can be loaded directly
# onto the base model.
#
# Previously the adapter was loaded frozen (PeftModel.from_pretrained defaults to
# is_trainable=False) and the resulting PeftModel was then wrapped a second time
# with get_peft_model(). PEFT warns against modifying an already-adapted model a
# second time, and the SFT weights are not reliably what ends up trained and saved.
#
# NOTE(review): with no ref_model passed to DPOTrainer, TRL is documented to compute
# reference log-probs with the adapter disabled, i.e. against the base model rather
# than the SFT policy. If the SFT policy must be the reference, see the TRL notes on
# reference models with PEFT -- TODO confirm which is intended.
model = PeftModel.from_pretrained(model, sft_adapter_path, is_trainable=True)
model.print_trainable_parameters()

ds = load_dataset("json", data_files=data_path, split="train")

# Fail early with a readable message instead of a KeyError from inside ds.map().
required_cols = ["prompt", "chosen", "rejected"]
missing_cols = [c for c in required_cols if c not in ds.column_names]
if missing_cols:
    raise ValueError(f"{data_path} is missing required DPO columns: {missing_cols}")

def preprocess(batch):
    """Return only the three fields DPO training consumes.

    ds.map() is called without batched=True below, so despite its name `batch`
    is a single example (one row of the JSONL file) indexed by column name.
    """
    return {"prompt": batch["prompt"], "chosen": batch["chosen"], "rejected": batch["rejected"]}
ds = ds.map(preprocess, remove_columns=[c for c in ds.column_names if c not in ["prompt","chosen","rejected"]])

# LoRA settings kept for reference only: they are NOT applied to `model`. The
# adapter loaded above carries its own saved configuration (rank, alpha, target
# modules). The name stays defined so any other cell that reads `peft_cfg` keeps
# working. NOTE(review): confirm these match the SFT adapter's adapter_config.json.
peft_cfg = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"]
)

# Training arguments for the DPO run; checkpoints and logs go under `out_dir`.
cfg = DPOConfig(
    output_dir=out_dir,
    # DPO objective and sequence-length limits.
    beta=0.1,
    max_length=2048,
    max_prompt_length=1024,
    # Batch of 1 per device, accumulated over 16 steps before each optimizer update.
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    num_train_epochs=1,
    # Optimizer and learning-rate schedule.
    optim="paged_adamw_8bit",
    learning_rate=5e-6,
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,
    # Precision / memory.
    bf16=True,
    gradient_checkpointing=True,
    # Logging and checkpointing; no external experiment tracker.
    logging_steps=10,
    save_steps=500,
    save_total_limit=2,
    report_to="none",
)

# Build the DPO trainer over the preference pairs in `ds`.
# ref_model=None is the DPOTrainer default, spelled out here because it matters:
# NOTE(review): for a PEFT model TRL is documented to derive the reference policy by
# disabling the adapter -- confirm that is the reference you want.
# NOTE(review): newer TRL releases rename `tokenizer=` to `processing_class=`; check
# the installed version before upgrading.
trainer = DPOTrainer(
    model=model,
    ref_model=None,
    args=cfg,
    train_dataset=ds,
    tokenizer=tok,
)


In [None]:
# train
trainer.train()

# Save the trained model first, then the tokenizer, both into `out_dir`.
for saveable in (trainer.model, tok):
    saveable.save_pretrained(out_dir)

In [None]:
# test inference

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
import torch

# Same base checkpoint as training; "dpo_adapter" is the directory the training
# cell saves into.
base_model_id = "meta-llama/Llama-3.3-70B-Instruct"
dpo_adapter_path = "dpo_adapter"

# Tokenizer, with EOS reused as the padding token.
tok = AutoTokenizer.from_pretrained(base_model_id, use_fast=True)
tok.pad_token = tok.eos_token

# 4-bit NF4 quantization with double quantization and a bfloat16 compute dtype.
bnb = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Load the quantized base model and attach the saved DPO adapter on top of it.
# NOTE(review): if this runs in the same kernel as training, `model` / `trainer`
# still hold the first copy of the base model -- consider restarting the kernel first.
m = PeftModel.from_pretrained(
    AutoModelForCausalLM.from_pretrained(
        base_model_id,
        quantization_config=bnb,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        attn_implementation="flash_attention_2",
    ),
    dpo_adapter_path,
)


In [None]:
# check various adapter outputs
# NOTE(review): the prompt is tokenized as plain text, with no chat template applied.
# Confirm this matches how the prompts in the DPO data were formatted.
prompt = "A friend asked me to do their homework for them. What should I say?"
x = tok(prompt, return_tensors="pt").to(m.device)

# Nucleus sampling settings; generation stops at the tokenizer's EOS token.
sampling_kwargs = dict(
    max_new_tokens=256,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    eos_token_id=tok.eos_token_id,
)

with torch.inference_mode():
    y = m.generate(**x, **sampling_kwargs)

# y[0] holds the prompt tokens followed by the sampled continuation.
print(tok.decode(y[0], skip_special_tokens=True))
