# DPO from Scratch

Adapted from the DPO paper, "Direct Preference Optimization: Your Language Model is Secretly a Reward Model" (https://arxiv.org/abs/2305.18290).

In [None]:
#import packages
import argparse
import os
import random
from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW

from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

import wandb
from tqdm import tqdm

In [None]:
def seed_everything(seed=2003):
    """Seed every random number generator this notebook uses.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU and all CUDA devices), then configures cuDNN for repeatable runs.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # With benchmarking on, cuDNN may auto-tune to a different convolution
    # algorithm on each run, which defeats the deterministic flag above.
    torch.backends.cudnn.benchmark = False

In [None]:
def dpo_loss(model_pref_logp, model_dispref_logp, ref_pref_logp, ref_dispref_logp, beta=0.50):
    """Compute the DPO loss from policy and reference log-probabilities.

    Args:
        model_pref_logp: Policy log-probabilities of the preferred responses.
        model_dispref_logp: Policy log-probabilities of the dispreferred responses.
        ref_pref_logp: Reference log-probabilities of the preferred responses.
        ref_dispref_logp: Reference log-probabilities of the dispreferred responses.
        beta: Scale applied to the margin before the log-sigmoid.

    Returns:
        A 4-tuple ``(loss, pref_mean, dispref_mean, margin_mean)``:
        the mean negative log-sigmoid of the beta-scaled margin, the mean
        policy-minus-reference gap for preferred and dispreferred responses,
        and the mean margin between those two gaps (not scaled by ``beta``).
    """
    # How far the policy has moved away from the reference on each response.
    pref_shift = model_pref_logp - ref_pref_logp
    dispref_shift = model_dispref_logp - ref_dispref_logp

    # Positive margin: the policy favours the preferred response more than
    # the reference does.
    margin = pref_shift - dispref_shift
    per_example_loss = -F.logsigmoid(beta * margin)

    return per_example_loss.mean(), pref_shift.mean(), dispref_shift.mean(), margin.mean()


In [None]:
def get_logp(policy, token_ids, attention_mask, prompt_lengths):
    """Return the mean per-token log-probability of each response under ``policy``.

    Args:
        policy: Callable accepting ``input_ids`` and ``attention_mask`` keyword
            arguments and returning an object with a ``.logits`` tensor of
            shape (batch_size, sequence_length, vocab_size). Assumed to be a
            causal language model, i.e. logits at position ``t`` score the
            token at position ``t + 1``.
        token_ids: Integer tensor (batch_size, sequence_length) holding the
            prompt followed by the response.
        attention_mask: Tensor (batch_size, sequence_length); non-zero marks
            real tokens, zero marks padding.
        prompt_lengths: Integer tensor (batch_size,). Tokens at index
            ``>= prompt_lengths[i]`` in row ``i`` are treated as response.

    Returns:
        Tensor (batch_size,) with the summed log-probability of the response
        tokens divided by the number of response tokens (at least 1).
    """
    logits = policy(
        input_ids=token_ids,
        attention_mask=attention_mask,
    ).logits  # shape (batch_size, sequence_length, vocab_size)

    # Align predictions with targets: drop the final logit (it predicts a
    # token beyond the sequence) and the first token (nothing predicts it).
    log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)
    targets = token_ids[:, 1:]
    token_log_probs = torch.gather(log_probs, -1, targets.unsqueeze(-1)).squeeze(-1)

    # Index of every target token in the original, unshifted sequence.
    seq_len = token_ids.shape[1]
    target_positions = torch.arange(1, seq_len, device=token_ids.device).unsqueeze(0)
    in_response = target_positions >= prompt_lengths.to(token_ids.device).unsqueeze(1)

    # Padding must contribute neither to the sum nor to the length.
    is_real_token = attention_mask[:, 1:].to(token_ids.device).bool()
    response_mask = (in_response & is_real_token).to(token_log_probs.dtype)

    response_sum_log_probs = (token_log_probs * response_mask).sum(dim=-1)
    # Clamp so a row with no response tokens yields 0 instead of dividing by 0.
    response_lengths = response_mask.sum(dim=-1).clamp(min=1)
    return response_sum_log_probs / response_lengths

In [None]:
def collate_fn(batch):
    """
    Turn a list of preference examples into three parallel lists of strings.

    Each example is indexed by 'prompt', 'chosen' and 'rejected'. Prompts are
    wrapped as 'Instruct: <prompt>' followed by a newline, and both responses
    are prefixed with 'Output: '. No tokenization happens here.
    """
    prompts, chosen, rejected = [], [], []
    for example in batch:
        prompts.append('Instruct: ' + example['prompt'] + '\n')
        chosen.append('Output: ' + example['chosen'])
        rejected.append('Output: ' + example['rejected'])

    return {
        "prompts": prompts,
        "chosen": chosen,
        "rejected": rejected,
    }


In [None]:
class DPOTrainer:
    """Minimal Direct Preference Optimization trainer for a causal language model.

    Loads a trainable policy and a frozen reference model, iterates over a
    preference dataset collated by ``collate_fn`` and minimises ``dpo_loss``
    on the policy with AdamW.

    Args:
        dataset_name: Dataset identifier passed to ``load_dataset``. Its rows
            must provide the 'prompt', 'chosen' and 'rejected' fields that
            ``collate_fn`` reads.
        split: Dataset split to load.
        policy_name: Checkpoint for the trainable policy and the tokenizer.
        ref_name: Checkpoint for the frozen reference model. ``None`` reuses
            ``policy_name``.
        reward_name: Accepted for interface compatibility; not used.
        output_dir: Directory the final policy and tokenizer are saved to.
            A falsy value skips saving.
        epochs: Number of passes over the dataset.
        batch_size: Examples per batch.
        max_length: Token limit applied separately to the prompt and to each
            response, so a combined sequence has at most ``2 * max_length``
            tokens.
        lr: AdamW learning rate.
        beta: DPO temperature forwarded to ``dpo_loss``.
        seed: Seed passed to ``seed_everything``.
        device: Torch device; defaults to CUDA when available, else CPU.
    """

    def __init__(self,
                 dataset_name="HuggingFaceH4/ultrachat_200k",
                 split="train",
                 policy_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
                 ref_name=None,
                 reward_name="OpenAssistant/reward-model-deberta-v3-base",
                 output_dir="./dpo-policy",
                 epochs=1, batch_size=2,
                 max_length=128,
                 lr=1e-5,
                 beta=0.2,
                 seed=42, device=None):
        seed_everything(seed)
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.epochs = epochs
        self.beta = beta
        self.max_length = max_length
        self.output_dir = output_dir

        self.tokenizer = AutoTokenizer.from_pretrained(policy_name, padding_side="left", use_fast=True, model_max_length=512)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.policy = AutoModelForCausalLM.from_pretrained(policy_name).to(self.device)

        # The reference defaults to a second copy of the starting policy.
        self.ref = AutoModelForCausalLM.from_pretrained(ref_name or policy_name).to(self.device)
        self.ref.eval()
        for p in self.ref.parameters():
            p.requires_grad_(False)

        self.opt = torch.optim.AdamW(self.policy.parameters(), lr=lr)

        self.ds = load_dataset(dataset_name, split=split)
        self.dl = DataLoader(self.ds, batch_size=batch_size, shuffle=True,
                             collate_fn=collate_fn)

    def _tokenize(self, texts, add_special_tokens):
        """Tokenize ``texts`` without padding; return one list of token ids per text."""
        return self.tokenizer(
            texts,
            truncation=True,
            max_length=self.max_length,
            add_special_tokens=add_special_tokens,
        )["input_ids"]

    def _pad_batch(self, sequences):
        """Right-pad lists of token ids into ``(input_ids, attention_mask)`` on ``self.device``."""
        rows = [torch.tensor(seq, dtype=torch.long) for seq in sequences]
        input_ids = nn.utils.rnn.pad_sequence(
            rows, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        attention_mask = nn.utils.rnn.pad_sequence(
            [torch.ones_like(row) for row in rows], batch_first=True, padding_value=0
        )
        return input_ids.to(self.device), attention_mask.to(self.device)

    def step(self, batch):
        """Compute the DPO loss for one collated batch.

        Args:
            batch: Dict with 'prompts', 'chosen' and 'rejected' lists of strings.

        Returns:
            Dict whose 'loss' entry is the scalar loss tensor, still attached
            to the autograd graph so the caller can backpropagate. The other
            entries are detached diagnostics returned by ``dpo_loss``.
        """
        prompt_ids = self._tokenize(batch["prompts"], add_special_tokens=True)
        # Responses continue the prompt, so they must not start with their own
        # special tokens (e.g. a second BOS in the middle of the sequence).
        chosen_ids = self._tokenize(batch["chosen"], add_special_tokens=False)
        rejected_ids = self._tokenize(batch["rejected"], add_special_tokens=False)

        # Concatenate before padding so each row is prompt + response with no
        # pad tokens in between; prompt_lengths is then the index of the first
        # response token in that row.
        prompt_lengths = torch.tensor([len(ids) for ids in prompt_ids], device=self.device)
        prompt_preferred_ids, prompt_preferred_mask = self._pad_batch(
            [p + c for p, c in zip(prompt_ids, chosen_ids)]
        )
        prompt_dispreferred_ids, prompt_dispreferred_mask = self._pad_batch(
            [p + r for p, r in zip(prompt_ids, rejected_ids)]
        )

        # Policy log-probabilities (tracked by autograd).
        policy_pref_logp = get_logp(self.policy, prompt_preferred_ids, prompt_preferred_mask, prompt_lengths)
        policy_dispref_logp = get_logp(self.policy, prompt_dispreferred_ids, prompt_dispreferred_mask, prompt_lengths)

        # Reference log-probabilities (no gradients needed).
        with torch.no_grad():
            ref_pref_logp = get_logp(self.ref, prompt_preferred_ids, prompt_preferred_mask, prompt_lengths)
            ref_dispref_logp = get_logp(self.ref, prompt_dispreferred_ids, prompt_dispreferred_mask, prompt_lengths)

        loss, pref_shift, dispref_shift, reward_margin = dpo_loss(
            policy_pref_logp,
            policy_dispref_logp,
            ref_pref_logp,
            ref_dispref_logp,
            beta=self.beta,
        )
        return {
            'loss': loss,
            'pref_shift': pref_shift.detach(),
            'dispref_shift': dispref_shift.detach(),
            'reward_margin': reward_margin.detach(),
        }

    def train(self):
        """Run the optimisation loop, then save the policy and tokenizer.

        Every batch is scored by ``step`` and followed by a backward pass and
        an AdamW update. The loss is printed every 5 steps. If ``output_dir``
        is truthy, the final policy and tokenizer are written there.
        """
        self.policy.train()
        self.ref.eval()
        step = 0

        for epoch in range(self.epochs):
            for batch in self.dl:
                stats = self.step(batch)

                # Without this update the policy weights would never change.
                self.opt.zero_grad()
                stats['loss'].backward()
                self.opt.step()

                step += 1
                if step % 5 == 0:
                    print(
                        f"epoch {epoch} step {step} | "
                        f"loss {stats['loss'].item():.4f} | "
                    )

        # Save final model
        if self.output_dir:
            os.makedirs(self.output_dir, exist_ok=True)
            self.policy.save_pretrained(self.output_dir)
            self.tokenizer.save_pretrained(self.output_dir)

In [None]:
# Run configuration: the policy and the reference start from the same
# Pythia-70M checkpoint and train on the truthy-dpo preference pairs.
dpo_config = {
    "dataset_name": "jondurbin/truthy-dpo-v0.1",
    "split": "train",
    "policy_name": "EleutherAI/pythia-70m",
    "ref_name": "EleutherAI/pythia-70m",
    "epochs": 10,
    "batch_size": 8,
    "max_length": 128,
    "lr": 1e-6,
    "beta": 0.10,
}

trainer = DPOTrainer(**dpo_config)
trainer.train()
print('Saved to', trainer.output_dir)