### Fine-tune Llama 3 with ORPO

In [2]:
import gc
import os

import torch
import wandb
from datasets import load_dataset
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    pipeline
)
from trl import ORPOConfig, ORPOTrainer, setup_chat_format

In [8]:
# Model
base_model = "meta-llama/Meta-Llama-3-8B"
new_model = "OrpoLlama-3-8B"

### Set torch dtype and attention implementation

In [4]:
# Set torch dtype and attention implementation
if torch.cuda.get_device_capability()[0] >= 8:
    torch_dtype = torch.bfloat16
    attn_implementation = "flash_attention_2"
else:
    torch_dtype = torch.float16
    attn_implementation = "eager"

### Config: QLoRA, Lora, model, tokenizer

In [9]:
# QLora config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=torch_dtype,
    bnb_4bit_use_double_quant=True
)

# Lora config
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CASUAL_LM",
    target_modules=['up_proj', 'down_proj','gate_proj']


)

# load_tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model)
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    device_map="auto",
    attn_implementation=attn_implementation
)

model, tokenizer = setup_chat_format(model, tokenizer)
model = prepare_model_for_kbit_training(model)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Downloading shards: 100%|██████████| 4/4 [02:59<00:00, 44.82s/it]
Loading checkpoint shards: 100%|██████████| 4/4 [00:03<00:00,  1.07it/s]


In [10]:

dataset_name = "mlabonne/orpo-dpo-mix-40k"
dataset = load_dataset(dataset_name, split="all")
dataset = dataset.shuffle(seed=42).select(range(1000)) # Only use 1000 samples for quick demo

def format_chat_template(row):
    row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
    row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
    return row

dataset = dataset.map(
    format_chat_template,
    num_proc= os.cpu_count(),
)
dataset = dataset.train_test_split(test_size=0.01)

Downloading readme: 100%|██████████| 2.92k/2.92k [00:00<00:00, 5.50MB/s]
Downloading data: 100%|██████████| 115M/115M [00:05<00:00, 20.6MB/s] 
Generating train split: 100%|██████████| 44245/44245 [00:00<00:00, 214607.35 examples/s]
Map (num_proc=32): 100%|██████████| 1000/1000 [00:00<00:00, 4455.97 examples/s]


In [13]:
opro_args = ORPOConfig(
    learning_rate=8e-6,
    lr_scheduler_type="linear",
    max_length=1024,
    max_prompt_length=512,
    beta=0.1,
    per_device_eval_batch_size=2,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    optim="paged_adamw_8bit",
    num_train_epochs=1,
    evaluation_strategy="steps",
    eval_steps=0.2,
    logging_steps=1,
    warmup_steps=10,
    report_to="wandb",
    output_dir='./results',
)

In [16]:
trainer = ORPOTrainer(
    model=model,
    args=opro_args,
    train_dataset=dataset['train'],
    eval_dataset=dataset['test'],
    peft_config=peft_config,
    tokenizer=tokenizer,
)

Map: 100%|██████████| 990/990 [00:01<00:00, 666.64 examples/s]
Map: 100%|██████████| 10/10 [00:00<00:00, 613.97 examples/s]


In [17]:
trainer.train()
trainer.save_model(new_model)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mwhpark70[0m. Use [1m`wandb login --relogin`[0m to force relogin


The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.float16.
Could not estimate the number of tokens of the input, floating-point operations will not be computed


Step,Training Loss,Validation Loss,Runtime,Samples Per Second,Steps Per Second,Rewards/chosen,Rewards/rejected,Rewards/accuracies,Rewards/margins,Logps/rejected,Logps/chosen,Logits/rejected,Logits/chosen,Nll Loss,Log Odds Ratio,Log Odds Chosen
25,2.9099,2.581663,19.8792,0.503,0.252,-0.310095,-0.328823,0.8,0.018727,-3.288229,-3.100955,-1.316847,-1.366814,2.51656,-0.651033,0.204477
50,2.2929,2.233691,19.8576,0.504,0.252,-0.264912,-0.283524,0.8,0.018612,-2.835243,-2.649124,-1.496533,-1.454357,2.169855,-0.638369,0.209496
75,2.1473,1.975191,19.9933,0.5,0.25,-0.228923,-0.242566,0.8,0.013643,-2.42566,-2.289232,-1.332117,-1.262263,1.911865,-0.633258,0.17181
100,1.6761,1.859465,19.9569,0.501,0.251,-0.212653,-0.226282,0.7,0.013629,-2.262823,-2.126532,-1.251947,-1.169109,1.796515,-0.629503,0.182587




In [None]:
model