<a href="https://colab.research.google.com/github/ritwikraha/nanoRL/blob/experiments/notebooks/dpo_from_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [13]:
!pip install transformers datasets accelerate -qq

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m1.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m90.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m71.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m42.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m1.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m4.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m13.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.9/127.9 MB[0m [31m8.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [14]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.optim import AdamW
from tqdm import tqdm
import json

In [15]:
# TODO-aritra: replace this with reasoning dataset
# Two toy preference pairs. Each record holds a prompt, a preferred
# ("chosen") completion and a dispreferred ("rejected") one. Completions
# begin with a space because they are concatenated directly onto the prompt.
sample_data = [
    dict(
        prompt="What is the capital of France?",
        chosen=" The capital of France is Paris.",
        rejected=" The capital of France is Berlin.",
    ),
    dict(
        prompt="Explain Newton's second law.",
        chosen=" Newton's second law states that force equals mass times acceleration.",
        rejected=" Newton's second law says gravity makes things fall.",
    ),
]

# Persist to disk so the training cell can reload it like an external dataset.
with open("dpo_data.json", "w") as out_file:
    out_file.write(json.dumps(sample_data, indent=2))


In [16]:
class PreferenceDataset(Dataset):
    """Map-style dataset of tokenized (prompt+chosen, prompt+rejected) pairs.

    Each element of ``data`` must be a mapping with "prompt", "chosen" and
    "rejected" strings. The prompt is concatenated verbatim with each
    completion (no separator is inserted). Each resulting text is tokenized
    on its own with truncation and padding to ``max_length``.
    """

    def __init__(self, data, tokenizer, max_length=512):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def _encode(self, text):
        """Tokenize one string to a fixed length and drop the batch axis."""
        encoded = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
        )
        return encoded["input_ids"].squeeze(0), encoded["attention_mask"].squeeze(0)

    def __getitem__(self, idx):
        record = self.data[idx]
        prompt, chosen, rejected = record["prompt"], record["chosen"], record["rejected"]

        chosen_ids, chosen_mask = self._encode(prompt + chosen)
        rejected_ids, rejected_mask = self._encode(prompt + rejected)

        return dict(
            chosen_input_ids=chosen_ids,
            chosen_attention_mask=chosen_mask,
            rejected_input_ids=rejected_ids,
            rejected_attention_mask=rejected_mask,
        )


In [17]:
def get_logps(model, input_ids, attention_mask):
    """
    Computes the sum of log-probabilities of each sequence's own tokens under
    ``model``, ignoring padding tokens.

    Position t predicts token t+1, so the first token of every sequence is
    never scored. Every later non-padded token is scored, including any
    prompt tokens it contains.

    Autograd is active only when ``model.training`` is True *and* gradients
    are enabled in the caller's context. An outer ``torch.no_grad()`` is
    therefore always respected, even for a model left in training mode.

    Parameters:
    - model: causal language model (AutoModelForCausalLM); its forward must
      return an object with a ``.logits`` tensor [batch_size, seq_len, vocab]
    - input_ids: token ids [batch_size, seq_len]
    - attention_mask: [batch_size, seq_len], nonzero for real tokens and
      0 for padding

    Returns:
    - Sum of log-probs per sequence: [batch_size]
    """
    # Never force grad on: only allow it if the caller has it enabled.
    grad_on = model.training and torch.is_grad_enabled()
    with torch.set_grad_enabled(grad_on):
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits

    # shift so that the logits at position t line up with label t+1
    shift_logits = logits[..., :-1, :]           # [B, L-1, V]
    shift_labels = input_ids[..., 1:]            # [B, L-1]
    shift_mask = attention_mask[..., 1:].bool()  # [B, L-1]

    # log_softmax over a large vocabulary is fragile in half precision;
    # upcast only fp16/bf16 (float32/float64 logits are left untouched).
    if shift_logits.dtype in (torch.float16, torch.bfloat16):
        shift_logits = shift_logits.float()

    log_probs = torch.nn.functional.log_softmax(shift_logits, dim=-1)

    # log-prob the model assigned to the actual next token
    token_logps = torch.gather(log_probs, -1, shift_labels.unsqueeze(-1)).squeeze(-1)

    # Zero padded positions with masked_fill instead of multiplying by the
    # mask: 0 * -inf is NaN, so one underflowed log-prob at a pad position
    # would otherwise poison the whole sequence sum.
    token_logps = token_logps.masked_fill(~shift_mask, 0.0)

    return token_logps.sum(dim=-1)  # [batch_size]

In [None]:
def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             ref_chosen_logps, ref_rejected_logps, beta=0.1):
    """
    Computes the DPO loss (Rafailov et al.), averaged over the batch:

        L = -log σ( β · [ (log π(y_c|x) - log π(y_r|x))
                          - (log π_ref(y_c|x) - log π_ref(y_r|x)) ] )

    Also returns a drift diagnostic:

        approx_kl = 0.5 · mean[ (log π(y_c|x) - log π_ref(y_c|x))²
                               + (log π(y_r|x) - log π_ref(y_r|x))² ]

    This is NOT an exact KL divergence. It is half the squared sequence
    log-ratio, evaluated on the dataset's responses rather than on samples
    from the policy. It is zero when the policy matches the reference on
    these sequences and grows as the two drift apart.

    Parameters:
    - policy_chosen_logps: log π(y_c | x), one value per example
    - policy_rejected_logps: log π(y_r | x)
    - ref_chosen_logps: log π_ref(y_c | x)
    - ref_rejected_logps: log π_ref(y_r | x)
    - beta: scale applied to the preference margin (β in the DPO paper)

    Returns:
    - loss: scalar DPO loss
    - approx_kl: scalar drift diagnostic described above
    """
    log_sigmoid = torch.nn.functional.logsigmoid

    # how much more the policy prefers y_c over y_r than the reference does
    margin = (policy_chosen_logps - policy_rejected_logps) - (ref_chosen_logps - ref_rejected_logps)
    loss = -log_sigmoid(beta * margin).mean()

    chosen_shift = (policy_chosen_logps - ref_chosen_logps) ** 2
    rejected_shift = (policy_rejected_logps - ref_rejected_logps) ** 2
    approx_kl = 0.5 * (chosen_shift + rejected_shift).mean()

    return loss, approx_kl

In [18]:
def training_step(batch, model, ref_model, beta=0.1):
    """
    Scores one batch with the policy and the reference model, and returns the
    DPO loss together with the drift diagnostic, as ``(loss, kl)``.

    No backward pass or optimizer step happens here. The reference model is
    scored inside ``torch.no_grad()``. The policy is scored without that
    wrapper, so ``get_logps`` decides whether autograd is active for it.

    Parameters:
    - batch: dict with "chosen_input_ids", "chosen_attention_mask",
      "rejected_input_ids" and "rejected_attention_mask" tensors
    - model: policy being trained
    - ref_model: frozen reference model
    - beta: forwarded to ``dpo_loss``
    """
    chosen = (batch["chosen_input_ids"], batch["chosen_attention_mask"])
    rejected = (batch["rejected_input_ids"], batch["rejected_attention_mask"])

    # policy: chosen first, then rejected
    policy_chosen = get_logps(model, *chosen)
    policy_rejected = get_logps(model, *rejected)

    # reference model: frozen, so no autograd graph is needed
    with torch.no_grad():
        reference_chosen = get_logps(ref_model, *chosen)
        reference_rejected = get_logps(ref_model, *rejected)

    return dpo_loss(
        policy_chosen_logps=policy_chosen,
        policy_rejected_logps=policy_rejected,
        ref_chosen_logps=reference_chosen,
        ref_rejected_logps=reference_rejected,
        beta=beta,
    )


In [None]:
# load the trainable policy, a frozen reference copy, and their tokenizer
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# GPT-2 has no pad token by default, so reuse EOS for padding. Only do this
# when the tokenizer lacks one, so that models which already define a pad
# token keep it.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# TODO-ritwik: make this generic
model = AutoModelForCausalLM.from_pretrained(model_name)
ref_model = AutoModelForCausalLM.from_pretrained(model_name)
# The reference model only scores sequences: put it in eval mode and freeze
# all of its parameters.
ref_model.eval()
ref_model.requires_grad_(False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
ref_model.to(device)


In [None]:
# hyperparameters for this toy run
NUM_EPOCHS = 3
BATCH_SIZE = 2
LEARNING_RATE = 5e-6
BETA = 0.1

# load dataset
with open("dpo_data.json", encoding="utf-8") as f:
    data = json.load(f)

dataset = PreferenceDataset(data, tokenizer)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
model.train()

for epoch in range(NUM_EPOCHS):
    print(f"Epoch {epoch}")
    for batch in tqdm(dataloader):
        batch = {k: v.to(device) for k, v in batch.items()}
        loss, kl = training_step(batch, model, ref_model, beta=BETA)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # tqdm.write prints above the live progress bar; a plain print()
        # inside the loop would break the bar's rendering.
        tqdm.write(f"Loss: {loss.item():.4f} | KL: {kl.item():.4f}")


In [None]:
# Save the fine-tuned policy and its tokenizer into the same directory
# (policy first, then tokenizer).
for artifact in (model, tokenizer):
    artifact.save_pretrained("dpo-model")