# Direct Preference Optimization (DPO) Training

This notebook demonstrates the second stage of our model improvement process: Direct Preference Optimization (DPO). After our initial Supervised Fine-tuning (SFT), DPO helps refine our model's ability to generate high-quality educational responses.

## What is DPO?

DPO is a fine-tuning technique that teaches a model to prefer certain outputs over others. Unlike standard supervised fine-tuning, where each prompt has a single target answer, DPO learns from pairs of responses in which one is preferred over the other. In our educational context, this means:

1. **Learning from Comparisons**: The model learns which response characteristics are more effective for teaching
2. **Preference-Based Training**: Instead of imitating one target answer, the model is trained on "better/worse" comparisons
3. **Simpler Than RLHF**: DPO optimizes the model directly on preference pairs, skipping the separate reward model and reinforcement-learning loop (e.g., PPO) used in classic RLHF, which makes training simpler and more stable
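For reference, this is the objective DPO minimizes, as introduced by Rafailov et al. (2023). Here $x$ is the prompt, $y_w$ the chosen response, $y_l$ the rejected response, $\pi_\theta$ the model being trained, $\pi_{\mathrm{ref}}$ a frozen reference model, $\sigma$ the logistic function, and $\beta$ a temperature that controls how far $\pi_\theta$ may move away from $\pi_{\mathrm{ref}}$:

$$
\mathcal{L}_{\mathrm{DPO}}(\pi_\theta; \pi_{\mathrm{ref}}) = -\,\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}} \left[ \log \sigma\!\left( \beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)} \right) \right]
$$

Each term $\beta \log \frac{\pi_\theta(y \mid x)}{\pi_{\mathrm{ref}}(y \mid x)}$ acts as an implicit reward, which is why no separate reward model has to be trained.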

### Our Application

We're using DPO to further refine our SFT-trained model by teaching it to:
- Prefer responses that encourage critical thinking over those that simply provide answers
- Choose explanations that build connections to broader concepts
- Generate responses that scaffold learning appropriately
- Avoid overly directive or answer-revealing responses

The training data comes from our dataset (generated in `generate_datasets.ipynb`) where we have:
- **Original prompts**: The student questions or tasks
- **Chosen responses**: The improved, pedagogically sound responses
- **Rejected responses**: The initial responses that could be improved


In [1]:
#@title Colab Extra Install { display-mode: "form" }
%%capture
import os

if "COLAB_" not in "".join(os.environ.keys()):
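    !pip install unsloth
    # Assumption: outside Colab (no Drive mount), data/ and models/ live under the current directory
    wrk_dir = os.getcwd()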
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf datasets huggingface_hub hf_transfer
    !pip install --no-deps unsloth

    # Data directory
    from google.colab import drive
    drive.mount('/content/drive')
    wrk_dir = '/content/drive/MyDrive/constitutional-ai-education'

## Training Configuration

Our DPO setup is configured so the model learns from the preferred responses and improves its ability to support students:

### Model Configuration
- **Base Model**: Llama 3.2 (1B parameter version)
  - We start from the previous stage: the LoRA adapter trained during SFT is loaded on top of the base model
  - Quantized to 4-bit for memory efficiency

### DPO-Specific Parameters
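- **Beta**: 0.1
  - Scales the implicit reward in the DPO loss and controls how far the trained model may drift from the reference (SFT) model
  - Higher values keep the model closer to the reference, giving more conservative updates
  - Lower values let the model move further from the reference to fit the preferences, at the risk of over-optimizing them
  - 0.1 is a common default and is also TRL's default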
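A minimal pure-Python sketch of the per-pair DPO loss makes the role of β concrete. The function `dpo_pair_metrics` and the log-probability values are illustrative, not TRL's implementation; the reward, margin, and correctness quantities mirror what TRL reports (averaged over a batch) as `rewards / chosen`, `rewards / rejected`, `rewards / margins`, and `rewards / accuracies` in the training log. The sketch also explains why the first logged losses are 0.6931: the trained adapter starts identical to the reference, so every margin is zero and the loss is log 2.

```python
import math

def dpo_pair_metrics(policy_chosen_logp, policy_rejected_logp,
                     ref_chosen_logp, ref_rejected_logp, beta=0.1):
    """DPO loss and implicit rewards for one preference pair (illustrative sketch).

    Each argument is the summed log-probability of a full response given the prompt.
    """
    # Implicit rewards: beta-scaled log-ratio between the trained and reference model
    reward_chosen = beta * (policy_chosen_logp - ref_chosen_logp)
    reward_rejected = beta * (policy_rejected_logp - ref_rejected_logp)
    margin = reward_chosen - reward_rejected
    # loss = -log(sigmoid(margin)), written in a numerically stable form
    if margin >= 0:
        loss = math.log1p(math.exp(-margin))
    else:
        loss = -margin + math.log1p(math.exp(margin))
    return {
        "loss": loss,
        "reward_chosen": reward_chosen,
        "reward_rejected": reward_rejected,
        "margin": margin,
        "correct": reward_chosen > reward_rejected,  # averaged into "rewards / accuracies"
    }

# Before any update the trained adapter equals the reference, so all log-ratios are 0
start = dpo_pair_metrics(-995.0, -1251.5, -995.0, -1251.5)
print(round(start["loss"], 4))  # 0.6931, i.e. log(2)

# The same shift (+2 nats on chosen, -2 nats on rejected) under two beta values
for beta in (0.1, 0.5):
    m = dpo_pair_metrics(-993.0, -1253.5, -995.0, -1251.5, beta=beta)
    print(beta, round(m["margin"], 3), round(m["loss"], 3))
```

Because the loss depends on the log-ratio shift only through β times that shift, a larger β reaches the same low loss with a smaller departure from the reference model, which is why higher β is the more conservative setting.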

### Training Parameters
- **Batch Size and Accumulation**:
  - Small batch size (2) with gradient accumulation (4 steps)
  - Gives an effective batch of 2 × 4 = 8 preference pairs per optimizer update
  - Balances memory constraints with stable training
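The arithmetic behind these settings can be checked with a small helper. `training_schedule` is hypothetical and assumes the step count is derived by floor-dividing micro-batches by the accumulation steps (an assumption about `Trainer` internals); with the 291 training pairs in this notebook it reproduces the 108 total steps reported in the training log.

```python
import math

def training_schedule(num_examples, per_device_batch_size, grad_accum_steps,
                      num_epochs, num_devices=1):
    """Return (effective batch size, total optimizer steps) for a simple data-parallel run."""
    # Preference pairs consumed per optimizer update
    effective_batch = per_device_batch_size * grad_accum_steps * num_devices
    # Micro-batches per epoch; the last one may be partial
    batches_per_epoch = math.ceil(num_examples / (per_device_batch_size * num_devices))
    # Optimizer updates per epoch, using floor division over accumulation windows
    steps_per_epoch = max(batches_per_epoch // grad_accum_steps, 1)
    return effective_batch, steps_per_epoch * num_epochs

# 291 training pairs, batch size 2, 4 accumulation steps, 3 epochs
print(training_schedule(291, 2, 4, 3))  # (8, 108)
```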

### Memory Optimization
- **4-bit Quantization**:
  - Enables training on consumer hardware
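  - Cuts weight memory roughly 4× relative to 16-bit weights, usually with only a small loss in quality
- **Gradient Checkpointing**:
  - Recomputes activations during the backward pass instead of storing them, trading extra computation for lower memory
  - Important when training on limited hardware such as the single Tesla T4 used in this run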
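A back-of-the-envelope estimate shows where the 4-bit savings come from. This sketch counts weight storage only; it ignores quantization constants, layers kept in higher precision, activations (what gradient checkpointing reduces), gradients, and optimizer state. The round figure of 1 billion parameters is the one Unsloth prints in the training log, and `weight_memory_gb` is a hypothetical helper.

```python
def weight_memory_gb(num_params, bits_per_param):
    """Approximate storage for model weights alone, in GB (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9

# Illustrative round figure of 1 billion parameters
for bits in (32, 16, 4):
    print(f"{bits:>2}-bit weights: {weight_memory_gb(1e9, bits):.2f} GB")
```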

In [2]:
from unsloth import FastModel, is_bfloat16_supported, PatchDPOTrainer, FastLanguageModel
from unsloth.chat_templates import get_chat_template, standardize_data_formats
import os
import torch
from dataclasses import dataclass, field
from trl import DPOTrainer, DPOConfig
from datasets import Dataset
import pandas as pd
from peft import PeftModel, prepare_model_for_kbit_training
from transformers import BitsAndBytesConfig

@dataclass
class ModelArguments:
    """Arguments for model configuration"""
    model_name: str = "unsloth/Llama-3.2-1B-bnb-4bit"
    sft_model_path: str = os.path.join(wrk_dir, "models/Llama-3.2-1B-sft-edu")
    max_seq_length: int = 2048
    # LoRA configuration
    lora_r: int = 8
    lora_alpha: int = 8
    lora_dropout: float = 0
    lora_bias: str = "none"
    # Model loading configuration
    load_in_4bit: bool = True
    load_in_8bit: bool = False
    full_finetuning: bool = False

@dataclass
class DataArguments:
    """Arguments for data processing"""
    train_file: str = os.path.join(wrk_dir, "data/train_dataset.csv")
    validation_file: str = os.path.join(wrk_dir, "data/val_dataset.csv")
    preprocessing_num_workers: int = 4
    chat_template: str = "llama-3"

class DPOArguments:
    """DPO-specific training arguments"""
    def __init__(self):
        self.beta = 0.1
        self.max_prompt_length = 512
        self.max_length = 1024
        self.dpo_config = DPOConfig(
            # Names of the two LoRA adapters: the one being trained and the frozen reference
            model_adapter_name="train_model",
            ref_adapter_name="reference",
            # DPO hyperparameters are read from the config by recent TRL versions
            beta=self.beta,
            max_prompt_length=self.max_prompt_length,
            max_length=self.max_length,
            per_device_train_batch_size=2,
            gradient_accumulation_steps=4,  # effective batch size of 8
            warmup_ratio=0.1,
            num_train_epochs=3,
            learning_rate=5e-6,
            fp16=not is_bfloat16_supported(),
            bf16=is_bfloat16_supported(),
            logging_steps=1,
            optim="adamw_8bit",
            weight_decay=0.0,
            lr_scheduler_type="linear",
            seed=3407,
            output_dir=os.path.join(wrk_dir, "models/Llama-3.2-1B-dpo-edu"),
            report_to="none",
        )

# Initialize arguments with default values
model_args = ModelArguments()
data_args = DataArguments()
training_args = DPOArguments()

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
Unsloth: Failed to patch Gemma3ForConditionalGeneration.
🦥 Unsloth Zoo will now patch everything to make training faster!


## Model Setup for Preference Learning

Before we can train our model to learn from preferences, we need to set it up carefully. This involves several steps:

### 1. Loading the Base Model
First, we load Llama with special optimizations:
```python
model, tokenizer = FastLanguageModel.from_pretrained(...)
```
- We use Unsloth's optimized loading for better performance
- The model is loaded in 4-bit format to save memory
- We set a maximum sequence length to handle our educational responses

### 2. Configuring Memory-Efficient Settings
```python
bnb_config = BitsAndBytesConfig(...)
```
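We set up memory settings using BitsAndBytes:
- 4-bit quantization compresses the model weights
- Double quantization also quantizes the per-block scaling constants, saving a little more memory
- NF4 (4-bit NormalFloat) is a data type designed for normally distributed weights and is the format recommended by QLoRA
- Computation runs in bfloat16 or float16, depending on what your GPU supports (the T4 in this run lacks bfloat16, so float16 is used)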
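To illustrate what blockwise 4-bit quantization does, here is a simplified sketch using absmax scaling with uniform signed integer levels. This is a deliberate simplification: real NF4 maps values to 16 levels placed at quantiles of a normal distribution, and double quantization would further compress the per-block `scales` list. The function names are hypothetical and not part of bitsandbytes.

```python
def quantize_blockwise(values, block_size=64, bits=4):
    """Quantize floats block by block, storing one absmax scale per block."""
    qmax = 2 ** (bits - 1) - 1  # largest signed code, e.g. 7 for 4 bits
    codes, scales = [], []
    for start in range(0, len(values), block_size):
        block = values[start:start + block_size]
        absmax = max(abs(v) for v in block)
        scale = absmax / qmax if absmax > 0 else 1.0
        scales.append(scale)
        # Each value becomes a small integer in [-qmax, qmax]
        codes.extend(round(v / scale) for v in block)
    return codes, scales

def dequantize_blockwise(codes, scales, block_size=64):
    """Map integer codes back to approximate floats using their block's scale."""
    return [c * scales[i // block_size] for i, c in enumerate(codes)]

weights = [0.02, -0.11, 0.07, 0.35, -0.4, 0.001]
codes, scales = quantize_blockwise(weights, block_size=4)
print(codes)  # [0, -2, 1, 7, -7, 0]
print(dequantize_blockwise(codes, scales, block_size=4))
```

Each reconstructed value is off by at most half of its block's scale, so blocks with one large outlier lose precision for their small values; that trade-off is why small block sizes are used.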

### 3. Preparing for Training
```python
model = prepare_model_for_kbit_training(model)
```
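- Prepares the quantized model for training: freezes the base weights, upcasts the remaining half-precision parameters to float32 for stability, and enables gradient checkpointing
- Disables the KV cache, which only speeds up generation and conflicts with gradient checkpointing during training
- Sets up the tokenizer to format conversations with Llama 3's chat template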

### 4. Adding Our Educational Fine-tuning

We load the SFT LoRA adapter twice, under two different names:

```python
model = PeftModel.from_pretrained(...)
model.load_adapter(...)
```

1. **Training Adapter** ("train_model"):
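   - Starts from what the model learned during SFT (Supervised Fine-Tuning)
   - Will be updated as the model learns better teaching preferences
   - Loaded with `is_trainable=True` so its weights can be modified

2. **Reference Adapter** ("reference"):
   - An identical copy of the SFT adapter
   - Stays frozen (unchanged) during training
   - Supplies the reference log-probabilities in the DPO loss, which keeps the trained adapter from drifting too far from the SFT model; TRL switches between the two adapters on the same base model, so no second full model needs to be kept in memory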
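The size of each adapter can be checked by hand. A LoRA update to a weight matrix of shape (d_in, d_out) adds two low-rank factors with r × (d_in + d_out) parameters in total. The sketch below assumes rank 8 (matching `lora_r` in `ModelArguments`; the SFT adapter's own config is not shown here), adapters on all seven attention and MLP projections, and the published Llama 3.2 1B shapes (hidden size 2048, MLP size 8192, 8 key/value heads of dimension 64, 16 layers). Under those assumptions the count matches the 5,636,096 trainable parameters in the training log, i.e. exactly one adapter: the frozen reference copy adds nothing to the trainable total.

```python
def lora_param_count(rank, matrix_shapes, num_layers):
    """Trainable LoRA parameters: each adapted (d_in, d_out) matrix gets
    factors A (rank x d_in) and B (d_out x rank)."""
    per_layer = sum(rank * (d_in + d_out) for d_in, d_out in matrix_shapes)
    return per_layer * num_layers

# Assumed Llama 3.2 1B dimensions
hidden, mlp, kv_dim, layers = 2048, 8192, 8 * 64, 16
shapes = [
    (hidden, hidden),  # q_proj
    (hidden, kv_dim),  # k_proj
    (hidden, kv_dim),  # v_proj
    (hidden, hidden),  # o_proj
    (hidden, mlp),     # gate_proj
    (hidden, mlp),     # up_proj
    (mlp, hidden),     # down_proj
]
print(f"{lora_param_count(8, shapes, layers):,}")  # 5,636,096
```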

In [3]:
# 4-bit quantization settings: NF4 weights, double quantization, GPU-appropriate compute dtype
bnb_config = BitsAndBytesConfig(
    load_in_4bit=model_args.load_in_4bit,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16 if is_bfloat16_supported() else torch.float16,
)

# Load base model
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_args.model_name,
    load_in_4bit=model_args.load_in_4bit,
    quantization_config=bnb_config,
    max_seq_length=model_args.max_seq_length
)

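# Prepare for k-bit training: freezes the quantized base weights, upcasts remaining
# half-precision parameters to float32, and enables gradient checkpointing
model = prepare_model_for_kbit_training(model)
# The KV cache only speeds up generation; disable it for training
model.config.use_cache = False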

# Configure tokenizer with appropriate chat template
tokenizer = get_chat_template(
    tokenizer,
    chat_template=data_args.chat_template
)

# Load the adapter from the SFT fine-tuning stage
model = PeftModel.from_pretrained(
    model,
    model_args.sft_model_path,
    is_trainable=True,
    adapter_name=training_args.dpo_config.model_adapter_name,
)
# Load the adapter a second time, with a different name, which will be our reference model.
model.load_adapter(model_args.sft_model_path, adapter_name=training_args.dpo_config.ref_adapter_name)

==((====))==  Unsloth 2025.3.19: Fast Llama patching. Transformers: 4.51.1.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 7.5. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


<All keys matched successfully>

## Dataset Preparation

DPO requires structured data to learn preferences effectively:

### Data Format
Each training example contains:
1. **Prompt**: The original student question or task
2. **Chosen Response**: The pedagogically improved response
3. **Rejected Response**: The initial, less optimal response

### Why This Format Matters
- Helps the model learn which response characteristics are preferred
- Maintains context of the original student question
- Both responses answer the same prompt, so the training signal comes from how they differ: DPO raises the likelihood of the chosen response relative to the rejected one

### Processing Steps
1. Load the SFT dataset containing original and improved responses
2. Format conversations using the Llama chat template
3. Create paired examples for preference learning
4. Let `DPOTrainer` tokenize, truncate, and pad the examples
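Because `prepare_dpo_dataset` turns missing CSV values into empty strings, a quick sanity check before training can catch pairs that would not teach anything. The helper below is a hypothetical sketch, not part of TRL or of the original pipeline; it assumes rows are plain dicts with flat `prompt`, `chosen`, and `rejected` strings (for example after renaming `init_prompt`, `revision_response`, and `init_response`).

```python
def _clean(value):
    # Treat None, NaN (a float in pandas), or any other non-string as missing
    return value.strip() if isinstance(value, str) else ""

def split_valid_pairs(rows):
    """Separate usable preference rows from ones that would add no DPO signal."""
    kept, dropped = [], []
    for row in rows:
        prompt = _clean(row.get("prompt"))
        chosen = _clean(row.get("chosen"))
        rejected = _clean(row.get("rejected"))
        if not (prompt and chosen and rejected):
            dropped.append((row, "missing field"))
        elif chosen == rejected:
            # Identical responses give a margin of exactly zero for any model,
            # so the pair contributes no gradient
            dropped.append((row, "identical responses"))
        else:
            kept.append(row)
    return kept, dropped

rows = [
    {"prompt": "Solve 3x + 7 = 22.", "chosen": "What could you subtract first?", "rejected": "x = 5"},
    {"prompt": "Define a prime.", "chosen": float("nan"), "rejected": "A prime has two divisors."},
    {"prompt": "Name a noun.", "chosen": "Dog", "rejected": "Dog"},
]
kept, dropped = split_valid_pairs(rows)
print(len(kept), [reason for _, reason in dropped])  # 1 ['missing field', 'identical responses']
```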

In [4]:
from datasets import Dataset, DatasetDict
from trl import apply_chat_template

def prepare_dpo_dataset(df):
    """Convert dataframe rows to message-based format expected by DPO processing"""
    formatted_data = []
    for _, row in df.iterrows():
        # Ensure strings and handle potential NaN values
        prompt = str(row['init_prompt']) if pd.notna(row['init_prompt']) else ""
        chosen = str(row['revision_response']) if pd.notna(row['revision_response']) else ""
        rejected = str(row['init_response']) if pd.notna(row['init_response']) else ""

        # Create the preference dataset format
        formatted_data.append({
            "prompt": [{"role": "user", "content": prompt}],
            "chosen": [{"role": "assistant", "content": chosen}],
            "rejected": [{"role": "assistant", "content": rejected}]
        })
    return formatted_data

def load_datasets(data_args, tokenizer):
    # Load dataframes
    train_df = pd.read_csv(data_args.train_file)
    eval_df = pd.read_csv(data_args.validation_file)

    # Convert rows to the conversational prompt/chosen/rejected format
    train_dataset = Dataset.from_list(prepare_dpo_dataset(train_df))
    eval_dataset = Dataset.from_list(prepare_dpo_dataset(eval_df))

    # Combine into DatasetDict
    raw_datasets = DatasetDict({
        "train": train_dataset,
        "test": eval_dataset
    })

    # Render each conversation to text with the tokenizer's chat template (TRL helper)
    datasets = raw_datasets.map(
        lambda x: apply_chat_template(x, tokenizer),
        remove_columns=raw_datasets["train"].column_names,
        num_proc=data_args.preprocessing_num_workers,
        desc="Applying chat template"
    )

    # The datasets are now ready for the DPOTrainer
    return datasets

datasets = load_datasets(data_args, tokenizer)

Applying chat template (num_proc=4):   0%|          | 0/291 [00:00<?, ? examples/s]

Applying chat template (num_proc=4):   0%|          | 0/37 [00:00<?, ? examples/s]

In [5]:
import pprint

row = datasets['train'][0]
pprint.pprint(row["prompt"])
pprint.pprint(row["chosen"])
pprint.pprint(row["rejected"])

('<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n'
 '\n'
 'Solve this equation and show me the answer: 3x + 7 = '
 '22.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n'
 '\n')
('Okay, let’s solve the equation 3x + 7 = 22!\n'
 '\n'
 'First, we need to get ‘x’ by itself. Let’s subtract 7 from both sides – what '
 'do you think happens when we do that? (Pause briefly to let the student '
 'respond).  That gives us 3x = 15. Now, what’s the next step to find out what '
 '‘x’ actually *is*?  Think about how many times 3 goes into 15.  Could you '
 'write that out? (Encourage them to show their work).  You’ll find that x '
 'equals 5!\n'
 '\n'
 'Now, that’s the answer, but let’s check it out.  Go back to the original '
 'equation, 3x + 7 = 22, and plug in 5 for ‘x’. Does it work? (Pause for '
 'student to verify).  Great!\n'
 '\n'
 'Here are a few things to think about: What if the number on the right side '
 'of the equation (the ‘22’) was different? How would that chang

In [6]:
PatchDPOTrainer()

# beta, max_length and max_prompt_length are read from the DPOConfig passed as `args`
dpo_trainer = DPOTrainer(
    model = model,
    args = training_args.dpo_config,
    train_dataset = datasets['train'],
    eval_dataset = datasets['test'],
    tokenizer = tokenizer,
)

Extracting prompt in train dataset (num_proc=2):   0%|          | 0/291 [00:00<?, ? examples/s]

Applying chat template to train dataset (num_proc=2):   0%|          | 0/291 [00:00<?, ? examples/s]

Tokenizing train dataset (num_proc=2):   0%|          | 0/291 [00:00<?, ? examples/s]

Extracting prompt in eval dataset (num_proc=2):   0%|          | 0/37 [00:00<?, ? examples/s]

Applying chat template to eval dataset (num_proc=2):   0%|          | 0/37 [00:00<?, ? examples/s]

Tokenizing eval dataset (num_proc=2):   0%|          | 0/37 [00:00<?, ? examples/s]

In [7]:
dpo_trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 291 | Num Epochs = 3 | Total steps = 108
O^O/ \_/ \    Batch size per device = 2 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (2 x 4 x 1) = 8
 "-____-"     Trainable parameters = 5,636,096/1,000,000,000 (0.56% trained)


Unsloth: Will smartly offload gradients to save VRAM!


| Step | Training Loss | rewards / chosen | rewards / rejected | rewards / accuracies | rewards / margins | logps / chosen | logps / rejected | logits / chosen | logits / rejected | eval_logits / chosen | eval_logits / rejected | nll_loss | aux_loss |
|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|
| 1 | 0.6931 | 0.0 | 0.0 | 0.0 | 0.0 | -995.037292 | -1251.57251 | 0.219984 | 0.090849 | 0 | 0 | 0 | 0 |
| 2 | 0.6931 | 0.0 | 0.0 | 0.0 | 0.0 | -970.04071 | -978.204468 | 0.654421 | 0.837312 | No Log | No Log | No Log | No Log |
| 3 | 0.6858 | 0.010172 | -0.004564 | 0.75 | 0.014735 | -1274.478882 | -1199.594482 | -0.189422 | -0.29671 | No Log | No Log | No Log | No Log |
| 4 | 0.6911 | -0.001135 | -0.005249 | 0.625 | 0.004114 | -918.644897 | -910.278381 | -0.008415 | 0.141742 | No Log | No Log | No Log | No Log |
| 5 | 0.6857 | 0.005532 | -0.009537 | 0.875 | 0.015069 | -1227.593262 | -1427.447266 | -0.117358 | -0.403698 | No Log | No Log | No Log | No Log |
| 6 | 0.6838 | 0.00921 | -0.009696 | 1.0 | 0.018906 | -1098.032471 | -1251.532349 | 0.394746 | 0.532322 | No Log | No Log | No Log | No Log |
| 7 | 0.6685 | 0.015222 | -0.034704 | 1.0 | 0.049926 | -1099.028442 | -1292.878418 | 0.207931 | -0.298634 | No Log | No Log | No Log | No Log |
| 8 | 0.6819 | 0.015794 | -0.007064 | 0.75 | 0.022858 | -1278.541992 | -1631.342773 | 0.468701 | 0.515837 | No Log | No Log | No Log | No Log |
| 9 | 0.6748 | 0.020327 | -0.016894 | 1.0 | 0.037221 | -1333.130249 | -1403.849121 | 0.107704 | -0.023202 | No Log | No Log | No Log | No Log |
| 10 | 0.6614 | 0.015154 | -0.049842 | 1.0 | 0.064996 | -1106.678223 | -1211.460815 | 0.475424 | 0.277084 | No Log | No Log | No Log | No Log |


TrainOutput(global_step=108, training_loss=0.3900912070163974, metrics={'train_runtime': 1113.3209, 'train_samples_per_second': 0.784, 'train_steps_per_second': 0.097, 'total_flos': 0.0, 'train_loss': 0.3900912070163974, 'epoch': 2.9315068493150687})

In [8]:
model.save_pretrained(os.path.join(wrk_dir, "models/Llama-3.2-1B-dpo-edu"))
tokenizer.save_pretrained(os.path.join(wrk_dir, "models/Llama-3.2-1B-dpo-edu"))

('/content/drive/MyDrive/constitutional-ai-education/models/Llama-3.2-1B-dpo-edu/tokenizer_config.json',
 '/content/drive/MyDrive/constitutional-ai-education/models/Llama-3.2-1B-dpo-edu/special_tokens_map.json',
 '/content/drive/MyDrive/constitutional-ai-education/models/Llama-3.2-1B-dpo-edu/tokenizer.json')