In [None]:
!conda install -c conda-forge pytorch=2.2.0 torchvision=0.17.0 torchtext=0.17.2 cpuonly numpy pandas matplotlib scikit-learn tqdm -y

In [None]:
!pip install datasets==2.20.0 trl==0.9.6 transformers==4.42.3 peft==0.11.1 sacrebleu==2.4.2 evaluate==0.4.2

# Preference Tuning with DPO

With the Hugging Face stack, preference tuning closely resembles the
instruction tuning we covered before, with a few differences. We will still
use TinyLlama, but this time an instruction-tuned version that was first
trained using full fine-tuning and then further aligned with DPO. Compared
to our initial instruction-tuned model, this LLM was trained on much larger
datasets.
In this section, we demonstrate how you can further align this model using
DPO with a reward-based dataset, in which every prompt comes with a
preferred (chosen) and a dispreferred (rejected) answer.

In [2]:
from datasets import load_dataset


def format_prompt(example):
    """Format an example using the <|system|>/<|user|>/<|assistant|>
    chat template that TinyLlama uses."""
    # Format answers
    system = "<|system|>\n" + example["system"] + "</s>\n"
    prompt = "<|user|>\n" + example["input"] + " </s>\n<|assistant|>\n"
    chosen = example["chosen"] + "</s>\n"
    rejected = example["rejected"] + "</s>\n"
    return {
        "prompt": system + prompt,
        "chosen": chosen,
        "rejected": rejected,
    }


# Drop ties, low-scoring chosen answers, and GSM8K training overlap, then apply formatting
dpo_dataset = load_dataset("argilla/distilabel-intel-orca-dpo-pairs", split="train")
dpo_dataset = dpo_dataset.filter(
    lambda r: 
        r["status"] != "tie"
        and r["chosen_score"] >= 8
        and not r["in_gsm8k_train"]
)
dpo_dataset = dpo_dataset.map(format_prompt, remove_columns=dpo_dataset.column_names)
dpo_dataset



Dataset({
    features: ['chosen', 'rejected', 'prompt'],
    num_rows: 5922
})

Note that we apply additional filtering to further reduce the size of the data
to roughly 6,000 examples from the original 13,000 examples.
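
To make the filtering and formatting steps concrete, here is a minimal sketch that runs them on a few toy records without downloading the dataset. The records and their values are hypothetical; only the field names are taken from the cell above, and `keep` is an illustrative name for the filter predicate.

```python
def keep(record):
    """Mirror the filter above: drop ties, low-scoring winners, GSM8K overlap."""
    return (
        record["status"] != "tie"
        and record["chosen_score"] >= 8
        and not record["in_gsm8k_train"]
    )


def format_prompt(example):
    """Same template as the cell above, repeated so this sketch runs on its own."""
    system = "<|system|>\n" + example["system"] + "</s>\n"
    prompt = "<|user|>\n" + example["input"] + " </s>\n<|assistant|>\n"
    return {
        "prompt": system + prompt,
        "chosen": example["chosen"] + "</s>\n",
        "rejected": example["rejected"] + "</s>\n",
    }


# Hypothetical records; the values are made up for illustration.
toy_records = [
    {"system": "You are helpful.", "input": "What is 2 + 2?",
     "chosen": "4", "rejected": "5",
     "status": "unchanged", "chosen_score": 9.0, "in_gsm8k_train": False},
    {"system": "", "input": "Name a color.",
     "chosen": "Blue", "rejected": "Blue.",
     "status": "tie", "chosen_score": 8.0, "in_gsm8k_train": False},
    {"system": "", "input": "Count to three.",
     "chosen": "1 2", "rejected": "3",
     "status": "swapped", "chosen_score": 6.0, "in_gsm8k_train": False},
]

# Only the first record survives: the second is a tie, the third scores below 8.
formatted = [format_prompt(r) for r in toy_records if keep(r)]
print(len(formatted))
print(repr(formatted[0]["prompt"]))
```

Printing the prompt with `repr` makes the newlines and the `</s>` end-of-sequence markers visible, which is a quick way to check that the template matches what the model saw during its earlier training.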

# Model Quantization

We load the base model, `TinyLlama/TinyLlama-1.1B-Chat-v1.0`, which will
receive a fresh LoRA adapter in the next step. As before, we quantize the
model to reduce the VRAM needed for training:

In [None]:
from transformers import (
    BitsAndBytesConfig,
    AutoTokenizer,
    AutoModelForCausalLM,
    pipeline,
)

# 4-bit quantization configuration - Q in QLoRA
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,  # Use 4-bit precision model loading
    bnb_4bit_quant_type="nf4",  # Quantization type
    bnb_4bit_compute_dtype="float16",  # Compute dtype
    bnb_4bit_use_double_quant=True,  # Apply nested quantization
)

# Load LLaMA tokenizer
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = "<PAD>"
tokenizer.padding_side = "left"

# Load the base model with 4-bit quantization
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    quantization_config=bnb_config,
    trust_remote_code=True
)
model.config.use_cache = False
model.config.pretraining_tp = 1

Next, we use the same LoRA configuration as before to perform the DPO
training:

In [None]:
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model

# Prepare LoRA configuration
peft_config = LoraConfig(
    lora_alpha=32,  # LoRA Scaling
    lora_dropout=0.05,  # Dropout for LoRA Layers
    r=64,  # Rank
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=[  # Layers to target
        "k_proj",
        "gate_proj",
        "v_proj",
        "up_proj",
        "q_proj",
        "o_proj",
        "down_proj",
    ],
)
# prepare model for training
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, peft_config)
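
As a rough sanity check on what this configuration trains, the sketch below counts the LoRA parameters added by `r=64` across the seven targeted projections. The layer shapes are assumptions about TinyLlama-1.1B (hidden size 2048, MLP size 5632, key/value projection width 256, 22 layers) and are not stated in this section; compare against `model.print_trainable_parameters()` on the real model.

```python
# Assumed TinyLlama-1.1B shapes (not taken from this notebook)
HIDDEN = 2048
INTERMEDIATE = 5632
KV_DIM = 256  # assumed: 4 key/value heads x 64 dims per head
NUM_LAYERS = 22

# Values from the LoraConfig above
RANK = 64
LORA_ALPHA = 32

# (in_features, out_features) for each targeted projection
TARGET_SHAPES = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}


def lora_params(in_features, out_features, rank):
    """LoRA adds A (rank x in_features) and B (out_features x rank)."""
    return rank * (in_features + out_features)


per_layer = sum(lora_params(i, o, RANK) for i, o in TARGET_SHAPES.values())
trainable = per_layer * NUM_LAYERS
scaling = LORA_ALPHA / RANK  # factor applied to the low-rank update

print(f"LoRA parameters per layer: {per_layer:,}")
print(f"LoRA parameters in total:  {trainable:,}")
print(f"Scaling factor (alpha / r): {scaling}")
```

With `lora_alpha=32` and `r=64`, the low-rank update is scaled by 0.5 before it is added to the frozen weights' output.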

# Training Configuration

For the sake of simplicity, we reuse the training arguments from before,
with two changes. First, instead of running for a single epoch (which can
take up to two hours), we run for only 200 steps for illustration purposes.
Second, we add the warmup_ratio parameter, which linearly increases the
learning rate from 0 to the learning_rate value we set over the first 10%
of steps. By keeping the learning rate small at the start (i.e., during the
warmup period), we allow the model to adjust to the data before applying
larger learning rates, thereby avoiding harmful divergence:

In [None]:
from trl import DPOConfig

output_dir = "./dpo-tinyllama"

# Training arguments
training_arguments = DPOConfig(
    output_dir=output_dir,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    optim="paged_adamw_32bit",
    learning_rate=1e-5,
    lr_scheduler_type="cosine",
    max_steps=200,
    logging_steps=10,
    fp16=True,
    gradient_checkpointing=True,
    warmup_ratio=0.1,
    beta=0.1,
    label_smoothing=0.0,
    loss_type="sigmoid",
)
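
The combination of `warmup_ratio=0.1` and `lr_scheduler_type="cosine"` can be sketched in a few lines. This is a simplified stand-in that assumes linear warmup followed by a half-cosine decay to zero, not the library's own implementation; `lr_at` is a hypothetical helper name.

```python
import math


def lr_at(step, base_lr=1e-5, total_steps=200, warmup_ratio=0.1):
    """Learning rate after `step` scheduler updates (assumed schedule shape)."""
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    if step < warmup_steps:
        # Linear ramp from 0 up to base_lr
        return base_lr * step / max(1, warmup_steps)
    # Half-cosine decay from base_lr down to 0
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


for step in (0, 10, 20, 30, 100, 200):
    print(f"step {step:3d}: lr = {lr_at(step):.3e}")
```

With these settings the warmup covers the first 20 steps, so the peak of 1e-5 is reached at step 20 and the rate decays from there.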

# Training

Now that we have prepared all our models and parameters, we can start
fine-tuning our model:

In [None]:
from trl import DPOTrainer

# Create DPO trainer
dpo_trainer = DPOTrainer(
    model, 
    ref_model=None, # The reference model (not used in this case because LoRA has been used)
    args=training_arguments,
    train_dataset=dpo_dataset,
    tokenizer=tokenizer,
    peft_config=peft_config,
    beta=0.1,
    max_prompt_length=512,
    max_length=512
)


Deprecated positional argument(s) used in DPOTrainer, please use the DPOConfig to set these arguments instead.
Tokenizing train dataset:   0%|          | 0/5922 [00:00<?, ? examples/s]
Token indices sequence length is longer than the specified maximum sequence length for this model (2055 > 2048). Running this sequence through the model will result in indexing errors
Tokenizing train dataset: 100%|██████████| 5922/5922 [00:21<00:00, 278.13 examples/s]
max_steps is given, it will override any value given in num_train_epochs
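
Because `max_steps` overrides the epoch count, it helps to work out how much of the dataset 200 steps actually cover. The arithmetic below uses the batch settings from the training arguments and the 5,922 rows reported earlier; the single-device setup is an assumption.

```python
per_device_batch = 2   # per_device_train_batch_size
grad_accum = 4         # gradient_accumulation_steps
num_devices = 1        # assumption: training on a single GPU
max_steps = 200
dataset_rows = 5922

# Examples consumed per optimizer step
effective_batch = per_device_batch * grad_accum * num_devices
# Examples consumed over the whole run
examples_seen = effective_batch * max_steps
# Fraction of one pass over the dataset
epoch_fraction = examples_seen / dataset_rows

print(f"effective batch size: {effective_batch}")
print(f"examples seen:        {examples_seen}")
print(f"fraction of an epoch: {epoch_fraction:.4f}")
```

Under these assumptions the run covers a little over a quarter of the filtered dataset.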


In [7]:
# Start fine tuning 
dpo_trainer.train()

Could not estimate the number of tokens of the input, floating-point operations will not be computed
  5%|▌         | 10/200 [01:05<20:09,  6.36s/it]

{'loss': 0.6913, 'grad_norm': 2.3933210372924805, 'learning_rate': 5e-06, 'rewards/chosen': 0.00039761661901138723, 'rewards/rejected': -0.00327131524682045, 'rewards/accuracies': 0.30000001192092896, 'rewards/margins': 0.0036689317785203457, 'logps/rejected': -91.34730529785156, 'logps/chosen': -83.61310577392578, 'logits/rejected': -3.0855438709259033, 'logits/chosen': -3.0500998497009277, 'epoch': 0.01}


 10%|█         | 20/200 [02:10<19:56,  6.64s/it]

{'loss': 0.673, 'grad_norm': 2.6381709575653076, 'learning_rate': 1e-05, 'rewards/chosen': -0.003814463736489415, 'rewards/rejected': -0.046028126031160355, 'rewards/accuracies': 0.4625000059604645, 'rewards/margins': 0.04221365973353386, 'logps/rejected': -130.34994506835938, 'logps/chosen': -98.92628479003906, 'logits/rejected': -3.1642000675201416, 'logits/chosen': -3.0853583812713623, 'epoch': 0.03}


 15%|█▌        | 30/200 [03:18<19:03,  6.73s/it]

{'loss': 0.635, 'grad_norm': 2.0923023223876953, 'learning_rate': 9.924038765061042e-06, 'rewards/chosen': -0.02209140732884407, 'rewards/rejected': -0.1545577049255371, 'rewards/accuracies': 0.4625000059604645, 'rewards/margins': 0.13246631622314453, 'logps/rejected': -115.07453918457031, 'logps/chosen': -81.56785583496094, 'logits/rejected': -3.1240222454071045, 'logits/chosen': -3.073723316192627, 'epoch': 0.04}


 20%|██        | 40/200 [04:14<15:50,  5.94s/it]

{'loss': 0.583, 'grad_norm': 1.7409669160842896, 'learning_rate': 9.727592877996585e-06, 'rewards/chosen': -0.05093228071928024, 'rewards/rejected': -0.35245656967163086, 'rewards/accuracies': 0.512499988079071, 'rewards/margins': 0.30152428150177, 'logps/rejected': -132.6724853515625, 'logps/chosen': -96.96385955810547, 'logits/rejected': -3.091298818588257, 'logits/chosen': -3.069164991378784, 'epoch': 0.05}


 25%|██▌       | 50/200 [05:21<15:55,  6.37s/it]

{'loss': 0.5904, 'grad_norm': 3.753676176071167, 'learning_rate': 9.37309853569698e-06, 'rewards/chosen': -0.11603733152151108, 'rewards/rejected': -0.4164137840270996, 'rewards/accuracies': 0.48750001192092896, 'rewards/margins': 0.30037641525268555, 'logps/rejected': -133.6261749267578, 'logps/chosen': -104.56297302246094, 'logits/rejected': -3.1330058574676514, 'logits/chosen': -3.0932929515838623, 'epoch': 0.07}


 30%|███       | 60/200 [06:28<16:12,  6.95s/it]

{'loss': 0.5998, 'grad_norm': 3.7886803150177, 'learning_rate': 8.885729807284855e-06, 'rewards/chosen': -0.10002808272838593, 'rewards/rejected': -0.48056063055992126, 'rewards/accuracies': 0.3499999940395355, 'rewards/margins': 0.38053256273269653, 'logps/rejected': -103.0716781616211, 'logps/chosen': -82.37723541259766, 'logits/rejected': -3.174006938934326, 'logits/chosen': -3.160590410232544, 'epoch': 0.08}


 35%|███▌      | 70/200 [07:35<14:42,  6.79s/it]

{'loss': 0.5911, 'grad_norm': 1.6772764921188354, 'learning_rate': 8.280295144952537e-06, 'rewards/chosen': -0.1612970530986786, 'rewards/rejected': -0.6325763463973999, 'rewards/accuracies': 0.3499999940395355, 'rewards/margins': 0.4712792932987213, 'logps/rejected': -113.21795654296875, 'logps/chosen': -86.12324523925781, 'logits/rejected': -3.2088027000427246, 'logits/chosen': -3.185551166534424, 'epoch': 0.09}


 40%|████      | 80/200 [08:43<14:10,  7.08s/it]

{'loss': 0.5207, 'grad_norm': 2.0901639461517334, 'learning_rate': 7.575190374550272e-06, 'rewards/chosen': -0.14294259250164032, 'rewards/rejected': -0.889349102973938, 'rewards/accuracies': 0.48750001192092896, 'rewards/margins': 0.7464064359664917, 'logps/rejected': -139.4697723388672, 'logps/chosen': -108.7026596069336, 'logits/rejected': -3.104055404663086, 'logits/chosen': -3.0495753288269043, 'epoch': 0.11}


 45%|████▌     | 90/200 [09:50<11:58,  6.54s/it]

{'loss': 0.533, 'grad_norm': 4.759432792663574, 'learning_rate': 6.7918397477265e-06, 'rewards/chosen': -0.1827918142080307, 'rewards/rejected': -0.9802249670028687, 'rewards/accuracies': 0.4375, 'rewards/margins': 0.7974331974983215, 'logps/rejected': -148.990234375, 'logps/chosen': -91.01962280273438, 'logits/rejected': -3.1768617630004883, 'logits/chosen': -3.1311748027801514, 'epoch': 0.12}


 50%|█████     | 100/200 [10:57<11:39,  6.99s/it]

{'loss': 0.6635, 'grad_norm': 2.0014357566833496, 'learning_rate': 5.954044976882725e-06, 'rewards/chosen': -0.3962658941745758, 'rewards/rejected': -1.1358253955841064, 'rewards/accuracies': 0.4625000059604645, 'rewards/margins': 0.7395597696304321, 'logps/rejected': -154.74903869628906, 'logps/chosen': -100.98489379882812, 'logits/rejected': -3.1616950035095215, 'logits/chosen': -3.1237246990203857, 'epoch': 0.14}


 55%|█████▌    | 110/200 [12:05<10:08,  6.76s/it]

{'loss': 0.4963, 'grad_norm': 1.7281523942947388, 'learning_rate': 5.087262032186418e-06, 'rewards/chosen': -0.21976308524608612, 'rewards/rejected': -1.3686443567276, 'rewards/accuracies': 0.5375000238418579, 'rewards/margins': 1.148881196975708, 'logps/rejected': -166.9852752685547, 'logps/chosen': -115.24952697753906, 'logits/rejected': -3.1242799758911133, 'logits/chosen': -3.0581512451171875, 'epoch': 0.15}


 60%|██████    | 120/200 [13:17<09:35,  7.19s/it]

{'loss': 0.5204, 'grad_norm': 6.5013532638549805, 'learning_rate': 4.217827674798845e-06, 'rewards/chosen': -0.2142508327960968, 'rewards/rejected': -1.2666479349136353, 'rewards/accuracies': 0.512499988079071, 'rewards/margins': 1.0523971319198608, 'logps/rejected': -162.53964233398438, 'logps/chosen': -94.13519287109375, 'logits/rejected': -3.089872360229492, 'logits/chosen': -2.995034694671631, 'epoch': 0.16}


 65%|██████▌   | 130/200 [14:18<07:13,  6.20s/it]

{'loss': 0.581, 'grad_norm': 2.537200927734375, 'learning_rate': 3.372159227714218e-06, 'rewards/chosen': -0.2623863220214844, 'rewards/rejected': -0.9639021754264832, 'rewards/accuracies': 0.4124999940395355, 'rewards/margins': 0.7015158534049988, 'logps/rejected': -120.1742935180664, 'logps/chosen': -96.2881851196289, 'logits/rejected': -3.133657932281494, 'logits/chosen': -3.115107297897339, 'epoch': 0.18}


 70%|███████   | 140/200 [15:24<07:02,  7.05s/it]

{'loss': 0.5806, 'grad_norm': 7.107076644897461, 'learning_rate': 2.5759518987683154e-06, 'rewards/chosen': -0.22886653244495392, 'rewards/rejected': -1.026235818862915, 'rewards/accuracies': 0.4124999940395355, 'rewards/margins': 0.7973693013191223, 'logps/rejected': -114.10166931152344, 'logps/chosen': -101.51978302001953, 'logits/rejected': -3.173781156539917, 'logits/chosen': -3.1216189861297607, 'epoch': 0.19}


 75%|███████▌  | 150/200 [16:26<05:12,  6.26s/it]

{'loss': 0.5813, 'grad_norm': 0.8339665532112122, 'learning_rate': 1.8533980447508138e-06, 'rewards/chosen': -0.17669036984443665, 'rewards/rejected': -1.162750244140625, 'rewards/accuracies': 0.38749998807907104, 'rewards/margins': 0.986059844493866, 'logps/rejected': -118.98201751708984, 'logps/chosen': -57.67116165161133, 'logits/rejected': -3.0821642875671387, 'logits/chosen': -3.0392024517059326, 'epoch': 0.2}


 80%|████████  | 160/200 [17:30<04:22,  6.56s/it]

{'loss': 0.5835, 'grad_norm': 2.645308256149292, 'learning_rate': 1.22645209888614e-06, 'rewards/chosen': -0.2559385895729065, 'rewards/rejected': -1.1220470666885376, 'rewards/accuracies': 0.42500001192092896, 'rewards/margins': 0.8661085367202759, 'logps/rejected': -122.52392578125, 'logps/chosen': -82.15196228027344, 'logits/rejected': -3.149512529373169, 'logits/chosen': -3.1040146350860596, 'epoch': 0.22}


 85%|████████▌ | 170/200 [18:35<03:22,  6.74s/it]

{'loss': 0.5778, 'grad_norm': 2.8055620193481445, 'learning_rate': 7.141634964894389e-07, 'rewards/chosen': -0.15658487379550934, 'rewards/rejected': -0.9607060551643372, 'rewards/accuracies': 0.3499999940395355, 'rewards/margins': 0.804121196269989, 'logps/rejected': -107.6840591430664, 'logps/chosen': -76.82438659667969, 'logits/rejected': -3.122872829437256, 'logits/chosen': -3.1170833110809326, 'epoch': 0.23}


 90%|█████████ | 180/200 [19:38<02:01,  6.10s/it]

{'loss': 0.6247, 'grad_norm': 5.650725364685059, 'learning_rate': 3.320978675139919e-07, 'rewards/chosen': -0.3278573155403137, 'rewards/rejected': -1.231153964996338, 'rewards/accuracies': 0.4124999940395355, 'rewards/margins': 0.9032966494560242, 'logps/rejected': -126.82087707519531, 'logps/chosen': -102.22705841064453, 'logits/rejected': -3.1422648429870605, 'logits/chosen': -3.146191120147705, 'epoch': 0.24}


 95%|█████████▌| 190/200 [20:39<01:01,  6.18s/it]

{'loss': 0.6626, 'grad_norm': 6.409192085266113, 'learning_rate': 9.186408276168012e-08, 'rewards/chosen': -0.24232156574726105, 'rewards/rejected': -0.5975373387336731, 'rewards/accuracies': 0.32499998807907104, 'rewards/margins': 0.35521575808525085, 'logps/rejected': -101.39582824707031, 'logps/chosen': -79.71857452392578, 'logits/rejected': -3.0484092235565186, 'logits/chosen': -3.01442551612854, 'epoch': 0.26}


100%|██████████| 200/200 [21:42<00:00,  6.12s/it]

{'loss': 0.5317, 'grad_norm': 3.9176714420318604, 'learning_rate': 7.615242180436521e-10, 'rewards/chosen': -0.16919878125190735, 'rewards/rejected': -1.207228183746338, 'rewards/accuracies': 0.4625000059604645, 'rewards/margins': 1.0380291938781738, 'logps/rejected': -134.67039489746094, 'logps/chosen': -79.00543975830078, 'logits/rejected': -3.1448285579681396, 'logits/chosen': -3.0928685665130615, 'epoch': 0.27}


100%|██████████| 200/200 [21:45<00:00,  6.53s/it]

{'train_runtime': 1305.2724, 'train_samples_per_second': 1.226, 'train_steps_per_second': 0.153, 'train_loss': 0.5910451936721802, 'epoch': 0.27}





TrainOutput(global_step=200, training_loss=0.5910451936721802, metrics={'train_runtime': 1305.2724, 'train_samples_per_second': 1.226, 'train_steps_per_second': 0.153, 'total_flos': 0.0, 'train_loss': 0.5910451936721802, 'epoch': 0.2701789935832489})
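
The `rewards/*` entries in the log are easier to read next to the loss they feed into. The sketch below assumes the standard sigmoid DPO loss without label smoothing (matching `loss_type="sigmoid"` and `label_smoothing=0.0` in the configuration) for a single example; the log-probabilities passed in are hypothetical, and `dpo_metrics` is an illustrative name rather than a TRL function.

```python
import math


def dpo_metrics(policy_chosen_logp, policy_rejected_logp,
                ref_chosen_logp, ref_rejected_logp, beta=0.1):
    """Per-example DPO quantities from sequence log-probabilities."""
    # Implicit rewards: how far the policy moved away from the reference
    reward_chosen = beta * (policy_chosen_logp - ref_chosen_logp)
    reward_rejected = beta * (policy_rejected_logp - ref_rejected_logp)
    margin = reward_chosen - reward_rejected
    # -log(sigmoid(margin)) written as log(1 + exp(-margin))
    loss = math.log1p(math.exp(-margin))
    return {
        "rewards/chosen": reward_chosen,
        "rewards/rejected": reward_rejected,
        "rewards/margins": margin,
        "rewards/accuracies": float(reward_chosen > reward_rejected),
        "loss": loss,
    }


# Policy identical to the reference: margin 0, loss equals ln(2)
start = dpo_metrics(-83.6, -91.3, -83.6, -91.3)
# Hypothetical later state: chosen became more likely, rejected less likely
later = dpo_metrics(-80.0, -100.0, -82.0, -95.0)

print(start)
print(later)
```

When the policy has not yet moved away from the reference, the margin is zero and the loss is ln 2 ≈ 0.693, which is close to the first logged loss of 0.6913. As training widens the margin between chosen and rejected rewards, the loss falls below that starting value.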

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

log = pd.DataFrame(dpo_trainer.state.log_history)
log_t = log[log["loss"].notna()]

# Plot the training loss
plt.plot(log_t["epoch"], log_t["loss"], label="Train")

# No eval_dataset was passed to the trainer, so "eval_loss" only
# exists (and is plotted) if evaluation was enabled
if "eval_loss" in log.columns:
    log_e = log[log["eval_loss"].notna()]
    plt.plot(log_e["epoch"], log_e["eval_loss"], label="Eval")

plt.title("Model Losses")
plt.legend()
plt.show()

### Notes

With a PEFT model, `save_model()` writes the LoRA adapter weights and their configuration to `output_dir`. The base model weights are not copied there; they are still loaded from `model_name`.

The function `merge_and_unload()` creates a standalone model by merging the LoRA adapters into the base model weights. This step also frees VRAM by unloading the adapter modules, which are no longer necessary.

In [None]:
# Save the adapter 
dpo_trainer.save_model(output_dir)

# Merge the LoRA adapters into the base model weights
# This also frees VRAM by unloading the LoRA adapters
merged_model = model.merge_and_unload()
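
Conceptually, merging folds the low-rank update into the frozen weight matrix so that no separate adapter path is needed at inference time. The toy example below illustrates that idea with plain Python lists; it is a sketch of the underlying arithmetic under the usual LoRA formulation, not of how the library implements `merge_and_unload()`, and all matrices are made-up values.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def add_scaled(a, b, scale=1.0):
    """Return a + scale * b element-wise."""
    return [[x + scale * y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


# Toy frozen weight W (out=2, in=3) and a rank-1 adapter
W = [[1.0, 0.0, 2.0],
     [0.0, 1.0, -1.0]]
A = [[0.5, -0.5, 1.0]]   # rank x in
B = [[2.0], [-1.0]]      # out x rank
lora_alpha, rank = 2, 1
scaling = lora_alpha / rank

x = [[1.0], [2.0], [3.0]]  # one input column vector

# Adapter kept separate: W x + scaling * B (A x)
y_adapter = add_scaled(matmul(W, x), matmul(B, matmul(A, x)), scale=scaling)

# Adapter merged: (W + scaling * B A) x
W_merged = add_scaled(W, matmul(B, A), scale=scaling)
y_merged = matmul(W_merged, x)

print(y_adapter)
print(y_merged)
```

Both paths give the same output for this input, which is why the merged model can be saved and served as an ordinary checkpoint.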

In [None]:
"""from peft import PeftModel

# Load the model with the LoRA adapters active
dpo_model = PeftModel.from_pretrained(
    model,
    output_dir,
    device_map="auto",
)

merged_model = dpo_model.merge_and_unload()"""



In [None]:
# Save model and tokenizer
merged_dir = "./dpo-tinyllama-merged"

merged_model.save_pretrained(merged_dir)
tokenizer.save_pretrained(merged_dir)

('./dpo-tinyllama-merged/tokenizer_config.json',
 './dpo-tinyllama-merged/special_tokens_map.json',
 './dpo-tinyllama-merged/tokenizer.json')

# Testing inference

Let's load our fine-tuned model as we would do in a production pipeline.

In [None]:
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    pipeline,
)

model_dir = "./dpo-tinyllama-merged"

finetuned_model = AutoModelForCausalLM.from_pretrained(
    model_dir,
    device_map="auto",
    trust_remote_code=True,
)

finetuned_tokenizer = AutoTokenizer.from_pretrained(
    model_dir,
    trust_remote_code=True
)

pipe = pipeline(
    "text-generation",
    model=finetuned_model,
    tokenizer=finetuned_tokenizer,
    max_new_tokens=256,
    temperature=0.7,
    do_sample=True,
)

prompt = "Explain quantum computing in simple terms."
print(pipe(prompt)[0]["generated_text"])

Explain quantum computing in simple terms. Let's say you have to find the square roots of numbers. In classical computing, you could use a calculator to do this, but quantum computers can solve this problem in a fraction of the time. How does quantum computing work?

A: Quantum computing works by using quantum bits, or qubits, which can take on both a 0 or a 1, and can exist in multiple states at the same time. A quantum computer can solve problems that are too computationally intensive for classical computers by using a technique called quantum entanglement.

When two qubits are entangled, they can interact with each other in ways that are not possible for classical bits. If two qubits are entangled, they can perform a quantum computation by acting on one qubit simultaneously with the other qubit. This means that if one qubit performs a calculation, the other qubit can act on it simultaneously, either with a 0 or with a 1, and the result will be the same.

In classical computers, the 
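
The generation above continues the raw string rather than answering it, which is consistent with the prompt lacking the chat template used during training. A minimal sketch of wrapping a question in that template follows; `build_chat_prompt` is a hypothetical helper that mirrors `format_prompt` from the data preparation step, and the pipeline call is left commented out because it needs the model loaded above.

```python
def build_chat_prompt(user_message, system_message=""):
    """Wrap a user message in the template used for the training prompts."""
    system = "<|system|>\n" + system_message + "</s>\n"
    user = "<|user|>\n" + user_message + " </s>\n"
    # End with the assistant tag so the model writes the answer next
    return system + user + "<|assistant|>\n"


chat_prompt = build_chat_prompt("Explain quantum computing in simple terms.")
print(chat_prompt)

# With the pipeline from the cell above (not executed in this sketch):
# print(pipe(chat_prompt, return_full_text=False)[0]["generated_text"])
```

Passing `return_full_text=False` asks the pipeline to return only the newly generated text, without echoing the prompt.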