**Description**: Showcase CAPPr for [DPO](https://arxiv.org/abs/2305.18290) metric
evaluation

Why? CAPPr caches the prompt (the prompt is identical for the preferred and dispreferred
responses), automatically batches (with a progress bar), and handles
tokenization-weirdness. You don't need to create a dataset and dataloader. Just pass in
raw strings.

Note: This is almost-purely for showcasing purposes. You can't actually use this to
train a model. To do that, I'd need to stay in torch land and not set up the model for
inference. Currently, CAPPr converts torch tensors to numpy arrays, disables gradient
computations, and sets the model in eval mode.
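
To see why a NumPy round trip rules out training, here is a minimal torch-only sketch. The toy tensor `w` is a hypothetical stand-in for model weights, and nothing here calls CAPPr; it only mimics the "disable gradients, convert to NumPy" steps described above.

```python
import torch

# Toy stand-in for model parameters (hypothetical, for illustration only).
w = torch.tensor([0.5, -1.0], requires_grad=True)

# Staying in torch keeps the autograd graph, so backward() reaches w.
loss_torch = (w * 2.0).sum()
loss_torch.backward()
grad_available = w.grad is not None

# An inference-style pipeline: gradients disabled, result converted to NumPy.
with torch.no_grad():
    values = (w * 2.0).cpu().numpy()

# Converting back to a tensor does not restore the link to w.
loss_roundtrip = torch.from_numpy(values).sum()
roundtrip_requires_grad = loss_roundtrip.requires_grad
```

Calling `loss_roundtrip.backward()` would raise an error, because the tensor has no gradient history.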

**Estimated run time**: ~10 sec.

In [1]:
from typing import Sequence

import numpy as np
import torch
import torch.nn.functional as F
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizerBase,
)

from cappr import Example
from cappr.huggingface.classify import log_probs_conditional_examples
from cappr.utils.classify import agg_log_probs

# Implementation

Let's start with Sebastian Raschka's DPO function from [this
notebook](https://github.com/rasbt/LLMs-from-scratch/blob/main/ch07/04_preference-tuning-with-dpo/dpo-from-scratch.ipynb).

CAPPr will supply each of the arguments.

In [2]:
def compute_dpo_loss(
    model_chosen_logprobs: torch.Tensor,
    model_rejected_logprobs: torch.Tensor,
    reference_chosen_logprobs: torch.Tensor,
    reference_rejected_logprobs: torch.Tensor,
    beta: float = 0.1,
):
    """Compute the DPO loss for a batch of policy and reference model log probabilities.

    Args:
        model_chosen_logprobs: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
        model_rejected_logprobs: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
        reference_chosen_logprobs: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
        reference_rejected_logprobs: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)
        beta: Temperature parameter for the DPO loss; typically something in the range of 0.1 to 0.5. We ignore the reference model as beta -> 0.

    Returns:
        A tuple of three scalar tensors, each averaged over the batch: (loss, chosen_rewards, rejected_rewards).
    """
    model_logratios = model_chosen_logprobs - model_rejected_logprobs
    reference_logratios = reference_chosen_logprobs - reference_rejected_logprobs
    logits = model_logratios - reference_logratios

    # DPO (Eq. 7 of https://arxiv.org/pdf/2305.18290.pdf)
    losses = -F.logsigmoid(beta * logits)

    # Optional values to track progress during training
    chosen_rewards = (model_chosen_logprobs - reference_chosen_logprobs).detach()
    rejected_rewards = (model_rejected_logprobs - reference_rejected_logprobs).detach()

    # .mean() to average over the samples in the batch
    return losses.mean(), chosen_rewards.mean(), rejected_rewards.mean()
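
As a quick sanity check on the formula, here is a plain-Python sketch for a single preference pair. The helper `dpo_loss_single` and the log-prob values are made up for illustration and are not part of the original notebook. When the policy and the reference agree exactly, the logits are 0 and the loss is log 2 (about 0.693); the loss drops below that when the policy favors the chosen response more than the reference does.

```python
import math


def dpo_loss_single(model_chosen, model_rejected, ref_chosen, ref_rejected, beta=0.1):
    """Scalar version of the DPO loss for one preference pair (illustrative helper)."""
    logits = (model_chosen - model_rejected) - (ref_chosen - ref_rejected)
    # -log(sigmoid(x)) is the same as log(1 + exp(-x))
    return math.log1p(math.exp(-beta * logits))


# Policy and reference agree: logits are 0, so the loss is log(2).
loss_equal = dpo_loss_single(-1.0, -2.0, -1.0, -2.0)

# Policy separates chosen from rejected more than the reference does.
loss_better = dpo_loss_single(-1.0, -3.0, -1.5, -2.0)

# Policy prefers the rejected response more than the reference does.
loss_worse = dpo_loss_single(-3.0, -1.0, -2.0, -1.5)

print(loss_better < loss_equal < loss_worse)
```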

Now for the CAPPr stuff.

It makes sense to use `Example` objects here. Each one looks like this:

```python
Example(
    prompt=prompt,
    completions=(preferred_response, dispreferred_response),
)
```

In [3]:
def avg_log_prob(
    examples: Example | Sequence[Example],
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizerBase,
    batch_size: int = 2,
    batch_size_completions: int | None = None,
):
    log_probs = log_probs_conditional_examples(
        examples,
        (model, tokenizer),
        batch_size=batch_size,
        batch_size_completions=batch_size_completions,
    )
    return agg_log_probs(log_probs, func=np.mean)

In [4]:
def dpo(
    examples: Example | Sequence[Example],
    model: PreTrainedModel,
    model_ref: PreTrainedModel,
    tokenizer: PreTrainedTokenizerBase,
    beta: float = 0.1,
    batch_size: int = 2,
    batch_size_completions: int | None = None,
):
    model_logprobs = avg_log_prob(
        examples, model, tokenizer, batch_size, batch_size_completions
    )
    model_ref_logprobs = avg_log_prob(
        examples, model_ref, tokenizer, batch_size, batch_size_completions
    )
    model_chosen_logprobs, model_rejected_logprobs = model_logprobs.T
    reference_chosen_logprobs, reference_rejected_logprobs = model_ref_logprobs.T
    return compute_dpo_loss(
        torch.from_numpy(model_chosen_logprobs),
        torch.from_numpy(model_rejected_logprobs),
        torch.from_numpy(reference_chosen_logprobs),
        torch.from_numpy(reference_rejected_logprobs),
        beta=beta,
    )
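
The `.T` unpacking in `dpo` relies on the aggregated log-probs forming an array with one row per example and one column per completion. That shape is an assumption about what `agg_log_probs` returns when every example has the same number of completions. A NumPy-only sketch with made-up token log-probs shows the idea:

```python
import numpy as np

# Hypothetical token-level log-probs, indexed as [example][completion][token].
# Completions have different lengths, so the innermost lists are ragged.
token_log_probs = [
    [[-0.5, -1.5], [-2.0, -1.0, -3.0]],
    [[-1.0], [-0.25, -0.75]],
]

# Average over tokens: one number per completion, shape (n_examples, 2).
avg = np.array(
    [[np.mean(completion) for completion in example] for example in token_log_probs]
)

# Column 0 holds the preferred responses, column 1 the dispreferred ones.
chosen, rejected = avg.T
```

Averaging normalizes for completion length; summing the token log-probs instead would give the full sequence log-probability.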

# Demo

Dummy models and data

## Load models

In [5]:
model_name = "gpt2"
model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
    model_name, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [6]:
model_name_ref = "openai-community/gpt2-medium"
model_ref: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
    model_name_ref, device_map="auto"
)

In [7]:
model.num_parameters()

124439808

In [8]:
model_ref.num_parameters()

354823168

In [9]:
print(model.device)

mps:0


In [10]:
print(model_ref.device)

mps:0


In [11]:
# warm up: run one forward pass through each model, on that model's own device
_ = model(**tokenizer(["warm up"], return_tensors="pt").to(model.device))
_ = model_ref(**tokenizer(["warm up"], return_tensors="pt").to(model_ref.device))

## Dummy preference data

In [12]:
preference_dataset = [
    # Tuples of: prompt, preferred_response, dispreferred_response
    ("Say yes", "Yes", "No way"),
    ("How useful is this demo?", "Not too useful lol", "It's amazing!"),
    ("For instruct models, format the string yourself", "Ok fine", "That's stupid!"),
    ("We'll just throw in raw strings", "k", "1 + 1 = 3"),
    ("There are 5 examples here", "Correct", "No, there are  number of examples")
]

## Run

In [13]:
examples = [
    Example(
        prompt=prompt,
        completions=(preferred_response, dispreferred_response),
    )
    for prompt, preferred_response, dispreferred_response in preference_dataset
]

In [14]:
loss, chosen_reward, rejected_reward = dpo(examples, model, model_ref, tokenizer)

conditional log-probs:   0%|          | 0/5 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


conditional log-probs:   0%|          | 0/5 [00:00<?, ?it/s]

In [15]:
loss, chosen_reward, rejected_reward

(tensor(0.6778, dtype=torch.float64),
 tensor(-0.1116, dtype=torch.float64),
 tensor(-0.4374, dtype=torch.float64))