This notebook, is looking at way to compare the logits, logprobs, and logs of chosen completions. We hope to find a way this is more information rich than accuracy and can differentiate between the models better

it should
- be like accuracy (0 to 1)
- be able to differentiate between models well

In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
from dataclasses import dataclass, field
from datasets import Dataset, features
import tempfile
from trl import DPOConfig, DPOTrainer
from typing import Optional
from transformers import AutoTokenizer, AutoModelForCausalLM

import pandas as pd
import numpy as np
from matplotlib import pyplot as plt


from open_pref_eval.trainer import dummy_dataset, OPEConfig
from open_pref_eval.evaluation import eval_dpo_dataset

In [3]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from typing import List, Tuple


In [4]:
# custom trainer that return full logits, so I can practice....
class OPETrainer2(DPOTrainer):

    @staticmethod
    def get_batch_logps(
        logits: torch.FloatTensor,
        labels: torch.LongTensor,
        label_pad_token_id: int = -100,
        is_encoder_decoder: bool = False,
    ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
        if logits.shape[:-1] != labels.shape:
            raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")

        if not is_encoder_decoder:
            labels = labels[:, 1:].clone()
            logits = logits[:, :-1, :]
        loss_mask = labels != label_pad_token_id

        # dummy token; we'll ignore the losses on these tokens later
        labels[labels == label_pad_token_id] = 0

        per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)

        return per_token_logps, loss_mask


    def concatenated_forward(
        self, model, batch):

        concatenated_batch = self.concatenated_inputs(
            batch,
            is_encoder_decoder=self.is_encoder_decoder,
            is_vision_model=self.is_vision_model,
            label_pad_token_id=self.label_pad_token_id,
            padding_value=self.padding_value,
            device=self.accelerator.device,
        )
        len_chosen = batch["chosen_labels"].shape[0]

        model_kwargs = {}

        if self.is_encoder_decoder:
            model_kwargs["labels"] = concatenated_batch["concatenated_labels"]
            model_kwargs["decoder_input_ids"] = concatenated_batch.pop("concatenated_decoder_input_ids", None)

        if self.is_vision_model:
            model_kwargs["pixel_values"] = concatenated_batch["pixel_values"]
            model_kwargs["pixel_attention_mask"] = concatenated_batch["pixel_attention_mask"]

        if self.aux_loss_enabled:
            model_kwargs["output_router_logits"] = True

        outputs = model(
            concatenated_batch["concatenated_input_ids"],
            attention_mask=concatenated_batch["concatenated_attention_mask"],
            use_cache=False,
            **model_kwargs,
        )
        all_logits = outputs.logits

        all_logps, mask = self.get_batch_logps(
            all_logits,
            concatenated_batch["concatenated_labels"],
            # average_log_prob=self.loss_type == "ipo",
            is_encoder_decoder=self.is_encoder_decoder,
            label_pad_token_id=self.label_pad_token_id,
        )

        chosen_logps = all_logps[:len_chosen]
        rejected_logps = all_logps[len_chosen:]

        chosen_mask = mask[:len_chosen]
        rejected_mask = mask[len_chosen:]

        chosen_logits = all_logits[:len_chosen]
        rejected_logits = all_logits[len_chosen:]

        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_mask, rejected_mask)
    # concatenated_forward

def get_dummy_trainer2(model=None, tokenizer=None, model_name:Optional[str]=None, per_device_eval_batch_size=8, **kwargs):
    """
    Make a dummy trainer, 

    For keyword arguments, see 
    - [transformers.TrainingArguments](https://huggingface.co/docs/transformers/v4.43.3/en/main_classes/trainer#transformers.TrainingArguments)
    - [trl.DPOConfig](https://huggingface.co/docs/trl/main/en/dpo_trainer#trl.DPOConfig)

    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        training_args = OPEConfig(
            output_dir=tmp_dir,
            per_device_eval_batch_size=per_device_eval_batch_size,
            loss_type='dpo',
            **kwargs
        )

    if model_name is not None:
        model = AutoModelForCausalLM.from_pretrained(model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

    if model is None:
        raise ValueError('model or model_name must be provided')

    # we rse a TRL class
    trainer = OPETrainer2(
        model=model,
        ref_model=None,
        args=training_args,
        tokenizer=tokenizer,
        train_dataset=dummy_dataset,
        eval_dataset=dummy_dataset,
    )
    return trainer




In [5]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import AutoPeftModelForCausalLM, get_peft_model, PeftConfig, PeftModelForCausalLM
model_name = "gepardzik/LLama-3-8b-rogue-lora"
peft_config = PeftConfig.from_pretrained(model_name)
base_model = AutoModelForCausalLM.from_pretrained(model_name)
model = PeftModelForCausalLM.from_pretrained(
    base_model,
    model_name, config=peft_config)
tokenizer = AutoTokenizer.from_pretrained(model_name)

Unused kwargs: ['_load_in_4bit', '_load_in_8bit', 'quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.
`low_cpu_mem_usage` was None, now set to True since model is quantized.


In [6]:
trainer = get_dummy_trainer2(model=model, tokenizer=tokenizer, per_device_eval_batch_size=6)
trainer

Map:   0%|          | 0/9 [00:00<?, ? examples/s]

Map:   0%|          | 0/9 [00:00<?, ? examples/s]

<__main__.OPETrainer2 at 0x7ce044f2c890>

In [7]:
from open_pref_eval.evaluation import eval_dpo_dataset, load_dataset, ds2name, is_peft_model, is_peft_model, set_adapter
from trl import DPOTrainer
from typing import Optional, List, Union
from datasets import Dataset
from contextlib import contextmanager, nullcontext
from tqdm.auto import tqdm


def extract_logps(trainer, model, batch, step):
    bs = batch['chosen_input_ids'].shape[0]
    i = bs * step + torch.arange(bs)
    
    (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_mask, rejected_mask) = trainer.concatenated_forward(model, batch)

    # Note: if we are using ipo or reprpo this will be adjusted for length, but otherwise not which would bias the results
    # logratio = chosen_logps-rejected_logps

    # turn into list of dicts
    n = dict(
        # logps
        # _logratio=logratio.detach().cpu().float().numpy(),
        _chosen_logps=chosen_logps.detach().cpu().float(),
        _rejected_logps=rejected_logps.detach().cpu().float(),

        # masks
        _chosen_mask=chosen_mask.detach().cpu().float(),
        _rejected_mask=rejected_mask.detach().cpu().float(),

        # completion length, for checking if the model is biased
        # _l_chosen=(batch['chosen_labels']>0).sum(-1).detach().cpu().float().numpy(),
        # _l_rejected=(batch['rejected_labels']>0).sum(-1).detach().cpu().float().numpy(),

        # metadata
        ds_i=i.numpy(),

                    
    )
    return [dict(
        model=trainer.model.config._name_or_path,
        # arrays
        **{k:v[i] for k,v in n.items()}
    ) for i in range(bs)]

@torch.no_grad()
def eval_dpo_dataset(trainer: DPOTrainer, dataset: Union[Dataset,str]):
    """
    We eval the prob_chosen/prob_rejected for each sample in the dataset (per token)

    Must have cols chosen and rejected just like the trl dpotrainer

    see trainer.evaluation_loop
    """
    if isinstance(dataset, str):
        dataset_name, split = dataset.split('#')
        dataset = load_dataset(dataset_name, split=split, keep_in_memory=False)

    model = trainer.model
    model.eval()
    model.config.use_cache = False


    data = []
    # use hf dpo trainer to tokenizer, and make loader
    dataset2 = dataset.map(trainer.tokenize_row, num_proc=trainer.dataset_num_proc, writer_batch_size=10)
    eval_dataloader = trainer.get_eval_dataloader(dataset2)

    # HACK
    # assert trainer.loss_type == 'ipo', 'only ipo is supported, since it gives us the avg of logps, and is not biased by response length'
    
    compte_ref_context_manager = torch.cuda.amp.autocast if trainer._peft_has_been_casted_to_bf16 else nullcontext
    
    with compte_ref_context_manager():
        for step, batch in enumerate(tqdm(eval_dataloader, desc=f"Eval {ds2name(dataset)}")):
            # FIXME test
            # batch = trainer._prepare_inputs(batch)

            if is_peft_model(model):
                # if model has peft adapters loop through them
                adapters = [None] +list(model.peft_config.keys())
                for adapter_name in adapters:
                    with set_adapter(model, adapter_name):
                        d = extract_logps(trainer, model, batch, step)
                        for dd in d:
                            dd['adapter'] = adapter_name if adapter_name is not None else 'base'
                            data.append(dd)
            else:
                data += extract_logps(trainer, model, batch, step)

    df = pd.DataFrame(data)

    df['dataset'] = ds2name(dataset)
    return df

In [14]:
# dataset2 = dummy_dataset.map(trainer.tokenize_row, num_proc=trainer.dataset_num_proc, writer_batch_size=10)
# eval_dataloader = trainer.get_eval_dataloader(dataset2)
# next(iter(eval_dataloader))

In [9]:
# get dummy trainer
trainer

df_r = eval_dpo_dataset(trainer, dummy_dataset)
# concatenated_forward

Map:   0%|          | 0/9 [00:00<?, ? examples/s]

Eval None None :   0%|          | 0/1 [00:00<?, ?it/s]

In [15]:
df_r['adapter']

0        base
1        base
2        base
3        base
4        base
5        base
6        base
7        base
8        base
9     default
10    default
11    default
12    default
13    default
14    default
15    default
16    default
17    default
Name: adapter, dtype: object

In [18]:
def agg(logp, mask):
    return (logp * mask).sum(-1)/ mask.sum(-1)

def score(logp_c, logp_r, mask_c, mask_r):
    logp_c_agg = agg(logp_c, mask_c)
    logp_r_agg = agg(logp_r, mask_r)
    return (logp_c_agg / (logp_c_agg + logp_r_agg)).mean()

# so now we have the logps for each token, 
for adapter, g, in df_r.groupby('adapter'):
    logp_c = torch.stack(list(g['_chosen_logps'].values))
    logp_r = torch.stack(list(g['_rejected_logps'].values))
    mask_c = torch.stack(list(g['_chosen_mask'].values))
    mask_r = torch.stack(list(g['_rejected_mask'].values))
    s = score(logp_c, logp_r, mask_c, mask_r)
    print(f'{adapter}: {s:.4f}')

base: 0.4928
default: 0.4897
