<a href="https://colab.research.google.com/github/ritwikraha/nanoRL/blob/experiments/notebooks/dpo_from_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [32]:
!pip install transformers datasets accelerate -qq

In [33]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch.nn.functional as F
from torch.optim import AdamW
from tqdm import tqdm
import json

In [34]:
# TODO-aritra: replace this with reasonibg dataset
sample_data = [
    {
        "prompt": "What is the capital of France?",
        "chosen": " The capital of France is Paris.",
        "rejected": " The capital of France is Berlin.",
    },
    {
        "prompt": "Explain Newton's second law.",
        "chosen": " Newton's second law states that force equals mass times acceleration.",
        "rejected": " Newton's second law says gravity makes things fall.",
    },
]

with open("dpo_data.json", "w") as f:
    json.dump(sample_data, f, indent=2)

In [35]:
class PreferenceDataset(Dataset):
    def __init__(self, data, tokenizer, max_length=512):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        prompt = sample["prompt"]
        chosen = sample["chosen"]
        rejected = sample["rejected"]

        chosen_enc = self.tokenizer(
            prompt + chosen,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
        )
        rejected_enc = self.tokenizer(
            prompt + rejected,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
        )

        return {
            "chosen_input_ids": chosen_enc["input_ids"].squeeze(0),
            "chosen_attention_mask": chosen_enc["attention_mask"].squeeze(0),
            "rejected_input_ids": rejected_enc["input_ids"].squeeze(0),
            "rejected_attention_mask": rejected_enc["attention_mask"].squeeze(0),
        }

In [36]:
def get_logps(model, input_ids, attention_mask):
    """
    Computes the sum of log-probabilities of the predicted tokens,
    ignoring padding tokens.

    Parameters:
    - model: Causal language model (AutoModelForCausalLM)
    - input_ids: token ids [batch_size, seq_len]
    - attention_mask: binary mask for padded tokens

    Returns:
    - Sum of log-probs per sequence: [batch_size]
    """
    # Get logits from model forward pass
    with torch.no_grad() if not model.training else torch.enable_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits  # shape: [batch_size, seq_len, vocab_size]

    # shift inputs and logits to align predictions with correct labels
    shift_logits = logits[..., :-1, :].contiguous()  # [B, L-1, V]
    shift_labels = input_ids[..., 1:].contiguous()  # [B, L-1]
    shift_mask = attention_mask[..., 1:].contiguous()  # [B, L-1]

    # log-probs over vocabulary
    log_probs = torch.nn.functional.log_softmax(shift_logits, dim=-1)

    # gather log-probs of the correct labels
    selected_log_probs = torch.gather(
        log_probs, -1, shift_labels.unsqueeze(-1)
    ).squeeze(-1)

    # mask out padding tokens
    selected_log_probs = selected_log_probs * shift_mask

    # sum all thelog-probs per example
    return selected_log_probs.sum(dim=-1)  # [batch_size]

# Direct Preference Optimization (DPO) Loss

**Paper**: [https://arxiv.org/abs/2305.18290](https://arxiv.org/abs/2305.18290)

**Title**: "Direct Preference Optimization: Your Language Model is Secretly a Reward Model"

##  Formula

$L_{\text{DPO}}(\pi_\theta; \pi_{\text{ref}}) = -\mathbb{E}_{(x,y_w,y_l) \sim D} \left[ \log \sigma \left( \beta \log \frac{\pi_\theta(y_w | x)}{\pi_{\text{ref}}(y_w | x)} - \beta \log \frac{\pi_\theta(y_l | x)}{\pi_{\text{ref}}(y_l | x)} \right) \right]$

Where:
- $\pi_\theta$ is the policy model being trained
- $\pi_{\text{ref}}$ is the reference model (frozen)
- $y_w$ is the chosen (winning) response
- $y_l$ is the rejected (losing) response
- $\beta$ is the temperature parameter
- $\sigma$ is the sigmoid function

## Alternative Formula (log space)

$L_{\text{DPO}} = -\mathbb{E} \left[ \log \sigma \left( \beta \left( \log \pi_\theta(y_w | x) - \log \pi_\theta(y_l | x) - \log \pi_{\text{ref}}(y_w | x) + \log \pi_{\text{ref}}(y_l | x) \right) \right) \right]$

DPO directly optimizes the policy to prefer chosen responses over rejected ones, while staying close to the reference model through the implicit KL regularization term (controlled by β).

In [37]:
def dpo_loss(
    policy_chosen_logps,
    policy_rejected_logps,
    ref_chosen_logps,
    ref_rejected_logps,
    beta=0.1,
):
    """
    dpo loss formula:
    L = -E[log σ(β(log π(y_w|x) - log π(y_l|x) - log π_ref(y_w|x) + log π_ref(y_l|x)))]

    where y_w is chosen (preferred) and y_l is rejected
    σ is sigmoid, β controls kl regularization strength
    """
    # compute preference gaps
    policy_logits = policy_chosen_logps - policy_rejected_logps
    ref_logits = ref_chosen_logps - ref_rejected_logps

    # dpo loss: encourage policy_logits > ref_logits
    logits = beta * (policy_logits - ref_logits)
    loss = -F.logsigmoid(logits).mean()

    # kl divergence: kl(π || π_ref) = E[log π - log π_ref]
    kl_chosen = policy_chosen_logps - ref_chosen_logps
    kl_rejected = policy_rejected_logps - ref_rejected_logps
    kl_div = 0.5 * (kl_chosen.mean() + kl_rejected.mean())

    # additional metrics for monitoring training
    with torch.no_grad():
        # how often does policy prefer chosen over rejected?
        rewards_accuracy = (policy_logits > 0).float().mean()

        # average preference margins
        rewards_margin = policy_logits.mean()
        ref_rewards_margin = ref_logits.mean()

        # implicit reward from dpo paper: r(x,y) = β * log(π(y|x) / π_ref(y|x))
        implicit_rewards_chosen = beta * (policy_chosen_logps - ref_chosen_logps)
        implicit_rewards_rejected = beta * (policy_rejected_logps - ref_rejected_logps)
        implicit_rewards_margin = (
            implicit_rewards_chosen - implicit_rewards_rejected
        ).mean()

    metrics = {
        "rewards_accuracy": rewards_accuracy,
        "rewards_margin": rewards_margin,
        "ref_rewards_margin": ref_rewards_margin,
        "implicit_rewards_margin": implicit_rewards_margin,
        "kl_div": kl_div,
    }

    return loss, kl_div, metrics

In [38]:
def training_step(batch, model, ref_model, beta=0.1):
    # log-probs for policy model
    chosen_logps = get_logps(
        model, batch["chosen_input_ids"], batch["chosen_attention_mask"]
    )
    rejected_logps = get_logps(
        model, batch["rejected_input_ids"], batch["rejected_attention_mask"]
    )

    # log-probs for reference model (frozen)
    with torch.no_grad():
        ref_chosen_logps = get_logps(
            ref_model, batch["chosen_input_ids"], batch["chosen_attention_mask"]
        )
        ref_rejected_logps = get_logps(
            ref_model, batch["rejected_input_ids"], batch["rejected_attention_mask"]
        )

    # compute the dpo loss and metrics
    loss, kl, metrics = dpo_loss(
        policy_chosen_logps=chosen_logps,
        policy_rejected_logps=rejected_logps,
        ref_chosen_logps=ref_chosen_logps,
        ref_rejected_logps=ref_rejected_logps,
        beta=beta,
    )

    return loss, kl, metrics

In [39]:
# load models and respective tokenizer
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# gpt2 doesn't have pad_token by default
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(model_name)
ref_model = AutoModelForCausalLM.from_pretrained(model_name)
ref_model.eval()
for p in ref_model.parameters():
    p.requires_grad = False

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
ref_model.to(device)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [40]:
# load dataset
with open("dpo_data.json") as f:
    data = json.load(f)

dataset = PreferenceDataset(data, tokenizer)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

optimizer = AdamW(model.parameters(), lr=5e-6)
model.train()

for epoch in range(3):
    print(f"Epoch {epoch}")
    for batch in tqdm(dataloader):
        batch = {k: v.to(device) for k, v in batch.items()}
        loss, kl, metrics = training_step(batch, model, ref_model, beta=0.1)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # log all metrics:
        # - Loss: How much the model needs to improve - should decrease

        # - KL Divergence: How much the model has changed from the original should stay stable

        # - Accuracy: Percentage of time model prefers chosen over rejected responses - should increase

        # - Policy Margin: How much current model prefers good responses over bad ones - should increase

        # - Reference Margin: How much original model prefers good responses over bad ones - should stay constant

        # - Implicit Reward: Internal reward score the model learns for chosen vs rejected - should increase

        print(
            f"Loss: {loss.item():.4f} | KL: {kl.item():.4f} | "
            f"Acc: {metrics['rewards_accuracy'].item():.3f} | "
            f"Policy Margin: {metrics['rewards_margin'].item():.4f} | "
            f"Ref Margin: {metrics['ref_rewards_margin'].item():.4f} | "
            f"Implicit Reward: {metrics['implicit_rewards_margin'].item():.4f}"
        )

Epoch 0


100%|██████████| 1/1 [01:27<00:00, 87.05s/it]


Loss: 0.8778 | KL: -9.8716 | Acc: 0.500 | Policy Margin: -2.1212 | Ref Margin: 1.0552 | Implicit Reward: -0.3176
Epoch 1


100%|██████████| 1/1 [00:51<00:00, 51.50s/it]


Loss: 0.6609 | KL: -8.4320 | Acc: 0.500 | Policy Margin: 1.7704 | Ref Margin: 1.0552 | Implicit Reward: 0.0715
Epoch 2


100%|██████████| 1/1 [00:50<00:00, 50.30s/it]

Loss: 0.9314 | KL: -9.3805 | Acc: 0.500 | Policy Margin: -2.2133 | Ref Margin: 1.0552 | Implicit Reward: -0.3268





In [31]:
model.save_pretrained("dpo-model")
tokenizer.save_pretrained("dpo-model")

('dpo-model/tokenizer_config.json',
 'dpo-model/special_tokens_map.json',
 'dpo-model/vocab.json',
 'dpo-model/merges.txt',
 'dpo-model/added_tokens.json',
 'dpo-model/tokenizer.json')