Trying to modify the Hugging Face (TRL) DPO trainer to work with this repo's hypothesis...

see
- https://huggingface.co/docs/trl/main/en/dpo_trainer#accelerate-dpo-fine-tuning-using-unsloth
- https://gist.github.com/alvarobartt/9898c33eb3e9c7108d9ed2330f12a708

In [None]:
import os

# Set TOKENIZERS_PARALLELISM explicitly, as the huggingface/tokenizers fork
# warning asks for, before any tokenizer is used.
os.environ.update(TOKENIZERS_PARALLELISM="false")

In [None]:
import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import DPOTrainer
from trl import DPOConfig, DPOTrainer


In [None]:
# FIXME: we are meant to SFT first so that the preference data is in-distribution, but 1) if this works it might not be needed, and 2) it can be added later, if it works

## load the model

In [None]:
# flash-attn backs the attn_implementation="flash_attention_2" option passed when the model is loaded.
!pip install flash-attn --no-build-isolation --no-deps -q

In [None]:

# Imported in this cell so the quantization settings stay next to the model load.
from transformers import BitsAndBytesConfig

model_name = "NousResearch/Meta-Llama-3-8B-instruct"

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Reuse the EOS token for padding.
tokenizer.pad_token = tokenizer.eos_token

# 4-bit NF4 quantization with bfloat16 compute. These options go through an
# explicit BitsAndBytesConfig: passing `load_in_4bit` / `bnb_4bit_*` as loose
# `from_pretrained` keyword arguments is deprecated.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    quantization_config=quantization_config,
)
# Keep the embedding table and pad id consistent with the tokenizer.
model.resize_token_embeddings(len(tokenizer))
model.config.pad_token_id = tokenizer.pad_token_id
# The KV cache is disabled for training.
model.config.use_cache = False

# Create a fresh LoRA adapter on the attention q/k/v projections and wrap the model with it.
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.0,
    r=16,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
    ],
)
model = get_peft_model(model, peft_config)


## Train

### Modified classes


- record hidden states
- new loss

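Changes
- get_batch_loss_metrics: pass the hidden states (hs) through to the loss
- concatenated_forward: also return the hidden states
- dpo_loss: replaced by reprpo_loss, which works differently (it compares hidden states)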

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
from einops import rearrange

def collect_hs(hs):
    """Stack a per-layer sequence of hidden states into one batch-first tensor.

    `hs` is an iterable of equally shaped (b, t, h) tensors, one per layer.
    The result has axes (b, l, t, h).
    """
    # einops stacks the list along a new leading "l" axis before reordering.
    return rearrange(list(hs), "l b t h -> b l t h")

def wmean(x, w):
    """Weighted mean of `x` over its first (batch) axis.

    The weights are shifted so that the smallest one becomes 0.1, which keeps
    every weight strictly positive, and are then padded with trailing
    singleton axes so they broadcast against `x`.
    """
    shifted = w - w.min() + 0.1
    missing_axes = x.dim() - shifted.dim()
    if missing_axes > 0:
        shifted = shifted.reshape(*shifted.shape, *([1] * missing_axes))
    weighted_total = torch.sum(x * shifted, dim=0)
    return weighted_total / torch.sum(shifted, dim=0)

class ReprPOTrainer(DPOTrainer):
    """DPO trainer modified to optimise representations (hidden states) instead of outputs.

    Overrides:
        concatenated_forward: additionally returns the stacked hidden states of
            the chosen and rejected halves of the batch.
        get_batch_loss_metrics: forwards those hidden states to the loss.
        reprpo_loss: representation loss used in place of the DPO loss.
    """

    def concatenated_forward(
        self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
    ) -> Tuple[
        torch.FloatTensor,
        torch.FloatTensor,
        torch.FloatTensor,
        torch.FloatTensor,
        torch.FloatTensor,
        torch.FloatTensor,
        torch.FloatTensor,
    ]:
        """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.

        We do this to avoid doing two forward passes, because it's faster for FSDP.

        Returns:
            A 7-tuple ``(chosen_logps, rejected_logps, chosen_logits,
            rejected_logits, chosen_logps_avg, chosen_hs, rejected_hs)``.
            The two ``*_hs`` tensors are the output of ``collect_hs`` split
            along the batch axis, i.e. laid out as (batch, layer, token, hidden).
        """
        concatenated_batch = self.concatenated_inputs(
            batch,
            is_encoder_decoder=self.is_encoder_decoder,
            label_pad_token_id=self.label_pad_token_id,
            padding_value=self.padding_value,
            device=self.accelerator.device,
        )
        # The chosen examples occupy the first `len_chosen` rows of the concatenated batch.
        len_chosen = batch["chosen_labels"].shape[0]

        model_kwargs = (
            {
                "labels": concatenated_batch["concatenated_labels"],
                "decoder_input_ids": concatenated_batch.pop("concatenated_decoder_input_ids", None),
            }
            if self.is_encoder_decoder
            else {}
        )
        outs = model(
            concatenated_batch["concatenated_input_ids"],
            attention_mask=concatenated_batch["concatenated_attention_mask"],
            use_cache=False,
            return_dict=True,
            # Hidden states are only returned when explicitly requested.
            output_hidden_states=True,
            **model_kwargs,
        )
        all_logits = outs.logits
        # Seq2seq outputs expose the decoder states under a different attribute.
        raw_hs = outs.decoder_hidden_states if self.is_encoder_decoder else outs.hidden_states
        # The model returns one tensor per layer. Stack them batch-first so the
        # chosen/rejected split below slices the batch axis, not the layer axis.
        hs = collect_hs(raw_hs)

        all_logps, size_completion = self.get_batch_logps(
            all_logits,
            concatenated_batch["concatenated_labels"],
            # average_log_prob=self.loss_type == "ipo",
            is_encoder_decoder=self.is_encoder_decoder,
            label_pad_token_id=self.label_pad_token_id,
        )
        chosen_logps_avg = all_logps[:len_chosen] / size_completion[:len_chosen]

        if self.loss_type == "ipo":
            all_logps = all_logps / size_completion

        chosen_logps = all_logps[:len_chosen]
        rejected_logps = all_logps[len_chosen:]

        chosen_logits = all_logits[:len_chosen]
        rejected_logits = all_logits[len_chosen:]

        chosen_hs = hs[:len_chosen]
        rejected_hs = hs[len_chosen:]

        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps_avg, chosen_hs, rejected_hs)

    def get_batch_loss_metrics(
        self,
        model,
        batch: Dict[str, Union[List, torch.LongTensor]],
        train_eval: Literal["train", "eval"] = "train",
    ):
        """Compute the representation loss and other metrics for the given batch of inputs for train or test.

        Returns:
            ``(loss, metrics)`` where ``loss`` is the mean of the per-example
            losses and ``metrics`` maps metric names to CPU tensors.
        """
        metrics = {}

        (
            policy_chosen_logps,
            policy_rejected_logps,
            policy_chosen_logits,
            policy_rejected_logits,
            policy_chosen_logps_avg,
            policy_chosen_hs,
            policy_rejected_hs,
        ) = self.concatenated_forward(model, batch)

        # The representation loss always needs the reference hidden states, so
        # the reference forward pass runs even when reference log-probs are
        # already cached in the batch.
        with torch.no_grad():
            if self.ref_model is None:
                # No separate reference model: reuse self.model inside null_ref_context().
                with self.null_ref_context():
                    reference_outputs = self.concatenated_forward(self.model, batch)
            else:
                reference_outputs = self.concatenated_forward(self.ref_model, batch)
        (
            reference_chosen_logps,
            reference_rejected_logps,
            _,
            _,
            _,
            reference_chosen_hs,
            _,
        ) = reference_outputs

        # if reference_chosen_logps and reference_rejected_logps in batch use them
        if (
            "reference_chosen_logps" in batch
            and "reference_rejected_logps" in batch
            and self.args.rpo_alpha is not None
        ):
            reference_chosen_logps = batch["reference_chosen_logps"]
            reference_rejected_logps = batch["reference_rejected_logps"]

        # Keyword arguments keep the call aligned with the reprpo_loss signature.
        losses, chosen_rewards, rejected_rewards = self.reprpo_loss(
            policy_chosen_logps=policy_chosen_logps,
            policy_rejected_logps=policy_rejected_logps,
            policy_chosen_hs=policy_chosen_hs,
            policy_rejected_hs=policy_rejected_hs,
            reference_chosen_logps=reference_chosen_logps,
            reference_rejected_logps=reference_rejected_logps,
            reference_chosen_hs=reference_chosen_hs,
        )
        reward_accuracies = (chosen_rewards > rejected_rewards).float()

        if self.args.rpo_alpha is not None:
            losses = losses * self.args.rpo_alpha - policy_chosen_logps_avg

        prefix = "eval_" if train_eval == "eval" else ""
        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()
        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu()
        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().cpu()
        metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean().cpu()
        metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu()
        metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean().cpu()
        metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu()

        return losses.mean(), metrics

    def reprpo_loss(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        policy_chosen_hs: torch.FloatTensor,
        policy_rejected_hs: torch.FloatTensor,
        reference_chosen_logps: torch.FloatTensor,
        reference_rejected_logps: torch.FloatTensor,
        reference_chosen_hs: torch.FloatTensor,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """Compute the representation loss for a batch of policy and reference hidden states.

        Args:
            policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
            policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
            policy_chosen_hs: Policy hidden states for the chosen responses, batch-first as produced by collect_hs.
            policy_rejected_hs: Policy hidden states for the rejected responses, same layout.
            reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
            reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)
            reference_chosen_hs: Reference hidden states for the chosen responses, same layout.

        Returns:
            A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
            The losses tensor contains the representation loss for each example in the batch.
            The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
        """
        device = self.accelerator.device

        # TODO: weight each sample by how right it was, e.g. with
        # exp((policy_chosen - policy_rejected) - (reference_chosen - reference_rejected)).

        def per_example_l1(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
            # Smooth-L1 averaged over every non-batch axis -> one value per example.
            return F.smooth_l1_loss(pred, target, reduction="none").flatten(1).mean(1)

        # NOTE(review): the rejected and chosen hidden states must have the same
        # shape here; this assumes concatenated_inputs padded both halves to a
        # common length. Padded positions are included in the loss.

        # The bad (rejected) representation should move towards the reference good one.
        loss_rr = per_example_l1(policy_rejected_hs, reference_chosen_hs)
        # The good (chosen) representation should be retained.
        loss_retain = per_example_l1(policy_chosen_hs, reference_chosen_hs)
        losses = loss_rr + loss_retain  # * self.alpha

        # Rewards are only logged, so they are detached from the graph.
        chosen_rewards = self.beta * (policy_chosen_logps.to(device) - reference_chosen_logps.to(device)).detach()
        rejected_rewards = self.beta * (policy_rejected_logps.to(device) - reference_rejected_logps.to(device)).detach()

        return losses, chosen_rewards, rejected_rewards

## Dataset

In [27]:
# Load the preference dataset. Alternatives that were tried or considered are kept commented out.
# Loader reference: https://github.dev/eric-mitchell/direct-preference-optimization/preference_datasets.py
# dataset = load_dataset('HuggingFaceH4/stack-exchange-preferences') # 22GB
# dataset = load_dataset("stanfordnlp/shp") # 1GB
dataset = load_dataset("when2rl/SHP_reformatted")
# dataset = load_dataset("RLHFlow/SHP-standard") # 1GB
# dataset = load_dataset('unalignment/toxic-dpo-v0.1') # this should give a bigger difference, since it is aligned opposite to this model even before release
dataset

DatasetDict({
    train: Dataset({
        features: ['prompt', 'prompt_id', 'chosen', 'rejected', 'messages', 'score_chosen', 'score_rejected', 'other_info'],
        num_rows: 348718
    })
    validation: Dataset({
        features: ['prompt', 'prompt_id', 'chosen', 'rejected', 'messages', 'score_chosen', 'score_rejected', 'other_info'],
        num_rows: 18436
    })
    test: Dataset({
        features: ['prompt', 'prompt_id', 'chosen', 'rejected', 'messages', 'score_chosen', 'score_rejected', 'other_info'],
        num_rows: 18409
    })
})

In [34]:
# now we need to apply the tokeniser
def format_ds(row):
    """Render the `chosen` and `rejected` fields of one row with the tokenizer's chat template.

    Returns a dict with the same two keys holding the rendered (untokenized) text.

    NOTE(review): if `chosen`/`rejected` already contain the prompt turn, the
    trainer may see the prompt twice (once from the `prompt` column); verify
    against the dataset schema.
    """
    rendered = {}
    for field in ("chosen", "rejected"):
        rendered[field] = tokenizer.apply_chat_template(row[field], tokenize=False)
    return rendered


# Apply format_ds row by row to every split; the returned "chosen"/"rejected"
# values reuse existing column names, so the column list stays the same.
dataset2 = dataset.map(format_ds)
dataset2


DatasetDict({
    train: Dataset({
        features: ['prompt', 'prompt_id', 'chosen', 'rejected', 'messages', 'score_chosen', 'score_rejected', 'other_info'],
        num_rows: 348718
    })
    validation: Dataset({
        features: ['prompt', 'prompt_id', 'chosen', 'rejected', 'messages', 'score_chosen', 'score_rejected', 'other_info'],
        num_rows: 18436
    })
    test: Dataset({
        features: ['prompt', 'prompt_id', 'chosen', 'rejected', 'messages', 'score_chosen', 'score_rejected', 'other_info'],
        num_rows: 18409
    })
})

### Run

In [37]:

# Training hyper-parameters. remove_unused_columns=False keeps the extra
# dataset columns in the batches handed to the trainer.
training_args = DPOConfig(
    num_train_epochs=3,
    learning_rate=5e-07,
    per_device_train_batch_size=1,
    do_eval=True,
    per_device_eval_batch_size=1,
    adam_epsilon=1e-08,
    lr_scheduler_type="linear",
    warmup_ratio=0.1,
    seed=42,
    logging_steps=100,
    save_steps=500,
    save_strategy="steps",
    output_dir="./output-dir",
    gradient_checkpointing=True,
    bf16=True,
    remove_unused_columns=False,
)

# NOTE(review): this instantiates the stock DPOTrainer, so none of the
# ReprPOTrainer overrides defined earlier in this notebook are exercised.
# Switch to ReprPOTrainer once that class runs end to end.
# NOTE(review): `model` was already wrapped by get_peft_model and peft_config
# is passed again here - confirm how the trainer handles an already-wrapped model.
# ref_model=None: no separate reference model is supplied.
dpo_trainer = DPOTrainer(
    model,
    ref_model=None,
    args=training_args,
    beta=training_args.beta,
    train_dataset=dataset2["train"],
    eval_dataset=dataset2["test"],
    tokenizer=tokenizer,
    max_length=training_args.max_length,
    max_prompt_length=training_args.max_prompt_length,
    peft_config=peft_config,
)
dpo_trainer.train()
dpo_trainer.save_model()

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Map:   0%|          | 0/348718 [00:00<?, ? examples/s]