# DPO Training

**Complete implementation guide**

## Training Pipeline

```
1. Load SFT model (this becomes the reference)
2. Create trainable copy (this becomes the policy)
3. Load preference dataset
4. Training loop:
   a. Compute log probs for chosen/rejected under both models
   b. Compute DPO loss
   c. Update policy
5. Save aligned model
```

In [None]:
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, get_linear_schedule_with_warmup
import copy
from tqdm import tqdm
from dataclasses import dataclass

@dataclass
class DPOConfig:
    """Hyperparameters for a DPO fine-tuning run.

    Every field has a default, so ``DPOConfig()`` produces a complete
    configuration; override individual fields by keyword.
    """
    # Name of the pretrained checkpoint.
    model_name: str = "gpt2"
    # Scale applied to the policy-vs-reference log-ratio margin in the DPO loss.
    beta: float = 0.1
    # Learning rate handed to the optimizer.
    learning_rate: float = 1e-6
    # Number of preference pairs per training batch.
    batch_size: int = 4
    # Number of full passes over the training data.
    num_epochs: int = 1
    # Maximum token length per sequence.
    max_length: int = 512
    # Number of warmup steps for the learning-rate schedule.
    warmup_steps: int = 100
    # Upper bound on the gradient norm used for clipping.
    max_grad_norm: float = 1.0
    # Label-smoothing coefficient for the preference labels; 0.0 means none.
    label_smoothing: float = 0.0

# Build the default configuration and list each field so the settings used for
# the run are visible in the notebook output.
config = DPOConfig()
print("DPO Configuration:")
for field_name, field_value in vars(config).items():
    print(f"  {field_name}: {field_value}")

## Dataset Class

In [None]:
class DPODataset(Dataset):
    """Map-style dataset that tokenizes chosen/rejected preference pairs.

    Each record in ``data`` must provide a ``'chosen'`` and a ``'rejected'``
    text. Both are passed through ``tokenizer`` with truncation and
    ``'max_length'`` padding requested at ``max_length``.
    """

    def __init__(self, data, tokenizer, max_length=512):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def _encode(self, text):
        """Tokenize one text with the settings shared by both sides of a pair."""
        return self.tokenizer(
            text,
            max_length=self.max_length,
            truncation=True,
            padding='max_length',
            return_tensors='pt'
        )

    def __getitem__(self, idx):
        """Return the four tensors for pair ``idx``.

        Keys are ``{chosen,rejected}_{input_ids,attention_mask}``. The
        tokenizer is asked for ``'pt'`` tensors, and ``squeeze(0)`` drops a
        leading dimension of size 1 from each of them.
        """
        record = self.data[idx]

        sample = {}
        for side in ('chosen', 'rejected'):
            encoded = self._encode(record[side])
            for field in ('input_ids', 'attention_mask'):
                sample[f'{side}_{field}'] = encoded[field].squeeze(0)
        return sample

## Computing Sequence Log Probabilities

In [None]:
def get_sequence_log_probs(
    model,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor
) -> torch.Tensor:
    """
    Return the summed next-token log-probability of every sequence in a batch.

    The model is called once. The logits at position ``t`` score the token at
    position ``t + 1``, so the logits drop their final step while the labels
    and the mask drop their first. Positions whose shifted mask value is 0
    contribute nothing to the sum.

    Args:
        model: Callable taking ``input_ids`` and ``attention_mask`` keyword
            arguments and returning an object with a ``.logits`` attribute.
        input_ids: Token ids, indexed here as (batch, sequence).
        attention_mask: Per-token weights with the same layout as
            ``input_ids``.

    Returns:
        One summed log-probability per batch row.
    """
    logits = model(input_ids=input_ids, attention_mask=attention_mask).logits

    # Align predictions with the tokens they predict.
    targets = input_ids[:, 1:]
    keep = attention_mask[:, 1:]
    per_position = F.log_softmax(logits[:, :-1, :], dim=-1)

    # Pick out the log-probability assigned to each token that actually occurred.
    picked = per_position.gather(-1, targets.unsqueeze(-1)).squeeze(-1)

    return (picked * keep).sum(dim=-1)

## DPO Training Loop

In [None]:
def _dpo_loss(policy_chosen_logps, policy_rejected_logps,
              ref_chosen_logps, ref_rejected_logps,
              beta, label_smoothing=0.0):
    """Return ``(mean loss, per-example preference logits)`` for one batch.

    The logits are ``beta`` times the difference between the chosen and the
    rejected policy-vs-reference log-ratios. With ``label_smoothing == 0`` the
    loss is ``-logsigmoid(logits)``; otherwise a ``label_smoothing`` share of
    the weight is moved onto ``-logsigmoid(-logits)`` (conservative DPO).
    """
    chosen_logratios = policy_chosen_logps - ref_chosen_logps
    rejected_logratios = policy_rejected_logps - ref_rejected_logps
    logits = beta * (chosen_logratios - rejected_logratios)

    losses = -F.logsigmoid(logits)
    if label_smoothing:
        # Treat a fraction of the preference labels as possibly flipped.
        losses = (
            losses * (1.0 - label_smoothing)
            - F.logsigmoid(-logits) * label_smoothing
        )
    return losses.mean(), logits


def train_dpo(policy_model, reference_model, train_loader, config, device):
    """Run the full DPO training loop and return the trained policy model.

    Args:
        policy_model: Model whose parameters are optimized.
        reference_model: Model scored under ``torch.no_grad()``; it is put in
            eval mode and never updated here.
        train_loader: Sized iterable of batches, each a dict with the keys
            ``chosen_input_ids``, ``chosen_attention_mask``,
            ``rejected_input_ids`` and ``rejected_attention_mask``.
        config: Object providing ``learning_rate``, ``num_epochs``,
            ``warmup_steps``, ``beta`` and ``max_grad_norm``. An optional
            ``label_smoothing`` attribute (default 0.0) enables the
            label-smoothed loss.
        device: Device every batch tensor is moved to.

    Returns:
        ``policy_model`` after training (the same object, updated in place).
    """
    optimizer = torch.optim.AdamW(
        policy_model.parameters(),
        lr=config.learning_rate
    )

    total_steps = len(train_loader) * config.num_epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=config.warmup_steps,
        num_training_steps=total_steps
    )

    # Fall back to 0.0 so config objects that lack the field keep working.
    label_smoothing = getattr(config, "label_smoothing", 0.0)

    policy_model.train()
    reference_model.eval()

    for epoch in range(config.num_epochs):
        loss_sum = 0.0
        accuracy_sum = 0.0
        num_batches = 0

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}")

        for batch in progress_bar:
            batch = {k: v.to(device) for k, v in batch.items()}

            # Log probs under the policy (gradients flow through these).
            policy_chosen_logps = get_sequence_log_probs(
                policy_model,
                batch['chosen_input_ids'],
                batch['chosen_attention_mask']
            )
            policy_rejected_logps = get_sequence_log_probs(
                policy_model,
                batch['rejected_input_ids'],
                batch['rejected_attention_mask']
            )

            # Log probs under the reference; no graph is needed for these.
            with torch.no_grad():
                ref_chosen_logps = get_sequence_log_probs(
                    reference_model,
                    batch['chosen_input_ids'],
                    batch['chosen_attention_mask']
                )
                ref_rejected_logps = get_sequence_log_probs(
                    reference_model,
                    batch['rejected_input_ids'],
                    batch['rejected_attention_mask']
                )

            loss, logits = _dpo_loss(
                policy_chosen_logps,
                policy_rejected_logps,
                ref_chosen_logps,
                ref_rejected_logps,
                beta=config.beta,
                label_smoothing=label_smoothing,
            )

            # Backward pass with gradient clipping, then advance the schedule.
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(policy_model.parameters(), config.max_grad_norm)
            optimizer.step()
            scheduler.step()

            # Accuracy: share of pairs where the chosen response has the
            # larger log-ratio, i.e. a positive preference logit.
            loss_value = loss.item()
            accuracy_value = (logits.detach() > 0).float().mean().item()
            loss_sum += loss_value
            accuracy_sum += accuracy_value
            num_batches += 1

            progress_bar.set_postfix({
                'loss': f"{loss_value:.4f}",
                'acc': f"{accuracy_value:.2%}"
            })

        # Average over the batches actually seen; the max() guard keeps an
        # empty loader from raising ZeroDivisionError.
        denominator = max(num_batches, 1)
        avg_loss = loss_sum / denominator
        avg_acc = accuracy_sum / denominator
        print(f"Epoch {epoch+1} - Loss: {avg_loss:.4f}, Accuracy: {avg_acc:.2%}")

    return policy_model

## Running DPO Training

In [None]:
# Setup: use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load the tokenizer and both models from the checkpoint named in the config,
# so changing `config.model_name` switches every component together.
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
# Reuse EOS as the padding token; the dataset pads every sequence to a fixed
# length, which needs a pad token to be set.
tokenizer.pad_token = tokenizer.eos_token

# Reference model (frozen): no gradients, eval mode.
reference_model = AutoModelForCausalLM.from_pretrained(config.model_name)
reference_model.requires_grad_(False)
reference_model.to(device)
reference_model.eval()

# Policy model (trainable): a second, independent load of the same checkpoint.
policy_model = AutoModelForCausalLM.from_pretrained(config.model_name)
policy_model.to(device)

print("Models loaded!")
print("  Reference model: frozen")
print("  Policy model: trainable")

In [None]:
# Load dataset
from datasets import load_dataset

# Keep only the first 500 training pairs: a small subset for the demo.
raw_data = load_dataset("Anthropic/hh-rlhf", split="train").select(range(500))

# NOTE: the sequence length here is the literal 256, not config.max_length.
dpo_dataset = DPODataset(raw_data, tokenizer=tokenizer, max_length=256)
train_loader = DataLoader(
    dataset=dpo_dataset,
    batch_size=config.batch_size,
    shuffle=True,
)

print(f"Dataset loaded: {len(dpo_dataset)} samples")

In [None]:
# Train (uncomment to run)
# policy_model = train_dpo(policy_model, reference_model, train_loader, config, device)

## Key Hyperparameters

| Parameter | Typical Value | Notes |
|-----------|---------------|-------|
| `beta` | 0.1 | Higher = stays closer to reference; lower = more deviation allowed |
| `learning_rate` | 1e-6 | Very low (like RLHF) |
| `num_epochs` | 1-3 | Avoid overfitting |
| `label_smoothing` | 0-0.1 | Regularization |

## Next Steps

Now let's explore advanced topics: memory optimization, hyperparameter tuning, evaluation, and common pitfalls.