In this notebook, we take the hidden states inside the MLP and attention side channels and apply reroute and retain losses, in an attempt to improve learning and generalization compared to DPO. The idea is to force the model to learn to use the hidden states in the side channels, and to retain the information that is useful for the task at hand.

In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["WANDB_PROJECT"] = "repo-dpo" 
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# os.environ["WANDB_DISABLED"] = "true"



In [3]:
import wandb
# `__vsc_ipynb_file__` is only present in globals when the kernel defines it
# (the VS Code notebook integration does); indexing it directly raises KeyError
# elsewhere, so fall back to a fixed file name.
_nb_file = os.path.basename(globals().get('__vsc_ipynb_file__', 'notebook.ipynb'))
os.environ['WANDB_NOTEBOOK_NAME'] = _nb_file
# Run-name friendly version of the notebook file name: no extension, no spaces.
nb_name = _nb_file.replace('.ipynb', '').replace(' ', '_')
# enable wandb service (experimental, https://github.com/wandb/client/blob/master/docs/dev/wandb-service-user.md)
# this hopefully fixes issues with multiprocessing
wandb.require(experiment='service')

In [4]:
from reprpo import silence

In [5]:
import torch
import numpy as np
from datasets import load_dataset
from peft import LoraConfig, get_peft_model

from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import DPOTrainer
from trl import DPOConfig, DPOTrainer
import gc
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
from einops import rearrange

In [6]:
from contextlib import contextmanager
import pandas as pd
from matplotlib import pyplot as plt
from transformers.trainer import ProgressCallback
from transformers.utils.notebook import NotebookProgressCallback

from reprpo.helpers.adapters import set_adapter

In [7]:
torch.set_float32_matmul_precision("medium")
# torch.use_deterministic_algorithms(True)

In [8]:
max_prompt_length=64
# num_samples = 50 * 16 * 6
num_samples = 1500 * 13 * 3 # from circuit breaker * 3
max_length = 128
num_samples

58500

## load the model

In [9]:
!pip install flash-attn --no-build-isolation --no-deps -qq

In [None]:
# model

In [None]:
# FIXME: we are meant to SFT first, so that the preferences are in sample, but 1) if this works it might not be needed, and 2) SFT can be added later if it works
# for now we will use the instruct model, and try something it wasn't meant to do but that is in sample
model_name = "microsoft/Phi-3-mini-4k-instruct"
# model_name = "NousResearch/Meta-Llama-3-8B-Instruct"
# model_name = "microsoft/Phi-3-mini-4k-instruct-gguf"
# model_name = "NousResearch/Meta-Llama-3.1-8B-Instruct"

use_gradient_checkpointing = False

from peft.tuners import BOFTConfig, OFTConfig, LoraConfig, IA3Config
## Big adapter
# peft_config = OFTConfig(
#     r=4,
#     task_type="CAUSAL_LM",
#     target_modules=["qkv_proj", "down_proj",
#                     "o_proj", "gate_up_proj",
#                     ],
# )


"""
# rescale
Infused Adapter by Inhibiting and Amplifying Inner Activations, or IA3, is a method that adds three learned vectors to rescale the keys and values of the self-attention and encoder-decoder attention layers, and the intermediate activation of the position-wise feed-forward network."""
# peft_config = IA3Config(
#     # r=4,
#     # task_type="CAUSAL_LM",
#     target_modules=["qkv_proj", "down_proj",
#                     "o_proj", "gate_up_proj",
#                     ],
#     feedforward_modules=["gate_up_proj", "down_proj"]
# )
peft_config = LoraConfig(
    lora_alpha=16, 
    r=16,
    # lora_dropout=0.05,
    use_rslora=True,
    # use_dora=True,
    task_type="CAUSAL_LM",
    target_modules=[
        "qkv_proj", "gate_up_proj", # in
        "down_proj",  "o_proj", # out
                    
                    ],
    # target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
)
from reprpo.models.load import load_model, print_trainable_parameters




from peft import prepare_model_for_kbit_training
from peft import LoraConfig, get_peft_model
from trl.trainer.utils import peft_module_casting_to_bf16

# Load the base model and tokenizer through the project helper (bnb=True is
# forwarded to it unchanged).
model, tokenizer = load_model(model_name, bnb=True)

if use_gradient_checkpointing:
    model.enable_input_require_grads()
peft_module_casting_to_bf16(model)
adapter_name = 'ReprPO'
# The flag must be passed by keyword: the second positional parameter of
# prepare_model_for_kbit_training IS `use_gradient_checkpointing`, so handing it
# a non-empty dict (always truthy) enabled checkpointing even when the notebook
# flag was False.
# NOTE(review): with the flag honoured, memory use at the configured batch size
# may be higher than in earlier runs of this notebook.
model = prepare_model_for_kbit_training(
    model, use_gradient_checkpointing=use_gradient_checkpointing
)
model = get_peft_model(model, peft_config, adapter_name=adapter_name)
print_trainable_parameters(model)
# phi model def https://github.com/huggingface/transformers/blob/14ee2326e51cb210cec72f31b248cb722e9d5d1f/src/transformers/models/phi/modeling_phi.py#L748
# 2 paths, mlp, and attn
model

- gate up  \
- gate down/
- 

## Dataset

In [None]:
def sample(dataset, N):
    """Shuffle `dataset` with the fixed argument 42 and keep at most `N` rows.

    `dataset` only needs `__len__`, `.shuffle(seed)` and `.select(indices)`.
    When the dataset has fewer than `N` rows, all of them are kept.
    """
    n_keep = min(len(dataset), N)
    shuffled = dataset.shuffle(42)
    return shuffled.select(range(n_keep))

In [None]:
# dataset = load_dataset('Columbia-NLP/DPO-HelpSteer')
dataset = load_dataset('Atsunori/HelpSteer2-DPO').map(lambda x: {
    'prompt': x['prompt']+ ' '})
dataset['train'] = sample(dataset['train'], num_samples)
dataset2 = dataset.rename_column('chosen_response', 'chosen').rename_column('rejected_response', 'rejected')
dataset2['train'][0]

In [None]:
# def format_ds(row):
    
#     # WHY are we doing this? Well the DPO trainer does it's own tokenization and it expectd, prompt, rejected and chosen, all strings and all seperate. Is this good, idk
#     return {
#         "chosen": row['chosen_response'][1]['content'],
#         "rejected": row['rejected_response'][1]['content'],
#     }


# dataset2 = dataset.map(format_ds)


In [None]:
r = dataset2['train'][0]
print(r['prompt'])
print('===')
print(r['chosen'])
print('---')
print(r['rejected'])

In [None]:
# from reprpo.eval.mc import eval_tqa
from reprpo.gen import generation_test

## Train

### Modified classes

- here we can define the experiment's loss function

In [None]:
# # modified from https://github.com/huggingface/peft/blob/4611034ff85e26e1f9647ea1f505e9e50322ff0f/src/peft/peft_model.py#L1005
# keys = [
#     "qkv_proj",
#     "o_proj",
# ]
# prefix = "base_model.model."
# for key in keys:
#     suffix_pos = key.rfind(".")
#     extended_prefix = prefix + key[:suffix_pos]
#     module = dict(model.named_modules())[extended_prefix]
# model.named_modules
# model.get_submodule

In [None]:
from dataclasses import dataclass

@dataclass
class ReprPOConfig2(DPOConfig):
    """DPOConfig extended with the extra options read by ReprPOTrainer2."""

    # Weight applied to the retain term of the loss (a fractional value such as
    # .3 is passed later in this notebook, hence float rather than int).
    alpha: float = 1
    # The trainer prints cosine diagnostics and metrics whenever
    # global_step % print_every == 0.
    print_every: int = 10
    # Layer indices substituted into the `{layer}` placeholder of collection_keys.
    collection_layers: tuple = (10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
    # Module-path templates; each is formatted with every collection layer to
    # give the modules whose inputs are traced.
    collection_keys: tuple = ('base_model.model.model.layers.{layer}.self_attn.o_proj', 'base_model.model.model.layers.{layer}.mlp.down_proj')

    # NOTE to self, do not pass both peft_config and model_adapter_name. peft_config creates a new adapter

In [None]:
def check_training_args(training_args, model):
    """Check that every configured collection path resolves to a submodule of `model`.

    Each template in `training_args.collection_keys` is formatted with each layer
    in `training_args.collection_layers` (previously only a hard-coded layer 22
    was checked, so a bad configured layer index went unnoticed until training).

    Args:
        training_args: object exposing `collection_layers` and `collection_keys`.
        model: object exposing `get_submodule(path)`, e.g. a `torch.nn.Module`.

    Returns:
        `training_args`, unchanged.

    Raises:
        AssertionError: if either collection attribute is None.
        Whatever `model.get_submodule` raises for a path that does not exist.
    """
    assert training_args.collection_layers is not None
    assert training_args.collection_keys is not None
    paths = [
        key.format(layer=layer)
        for layer in training_args.collection_layers
        for key in training_args.collection_keys
    ]
    print(paths)
    for path in paths:
        # Only the lookup matters here; the returned module is discarded.
        model.get_submodule(path)
    return training_args

training_args = ReprPOConfig2('')
check_training_args(training_args, model);

In [None]:
import itertools

In [None]:
# training_args

# layer_paths = [
#     [p.format(layer=layer) for p in training_args.collection_keys]
#     for layer in training_args.collection_layers
# ]
# layer_paths = list(itertools.chain(*layer_paths))
# layer_paths

In [None]:
from reprpo.trainer import ReprPOTrainer, ReprPOConfig, mean_with_attention, normalize_output
from baukit.nethook import TraceDict
from einops import repeat
from jaxtyping import Float
from torch import Tensor

def mean_with_attention(x: Float[Tensor, 'b t h'], attn_mask: Float[Tensor, 'b t'], dim: int = 1) -> Float[Tensor, 'b h']:
    """Mean of `x` over `dim`, weighted by the attention mask.

    Args:
        x: values to average; the mask is broadcast over its last axis.
        attn_mask: per-position weights (0 for positions to ignore). It is
            detached, so no gradient flows into the mask.
        dim: axis to reduce over (1 = tokens by default).

    Returns:
        The masked mean. Slices whose mask sums to zero yield 0 instead of the
        NaN that a 0/0 division would produce.
    """
    # Trailing singleton axis so the mask broadcasts across the hidden dimension.
    layer_attn_mask = attn_mask.unsqueeze(-1).detach()
    total = (x * layer_attn_mask).sum(dim)
    count = layer_attn_mask.sum(dim)
    # Guard against fully masked slices: replace a zero denominator by one.
    count = torch.where(count == 0, torch.ones_like(count), count)
    return total / count

def detach_hsd(hs):
    """Return a new dict with the same keys whose values are `.detach()`-ed."""
    detached = {}
    for name, tensor in hs.items():
        detached[name] = tensor.detach()
    return detached

class ReprPOTrainer2(ReprPOTrainer):

    def __init__(self, args: Optional[ReprPOConfig] = None, **kwargs):
        """Initialise via `DPOTrainer.__init__` and record the ReprPO settings.

        Stores the collection layers and `alpha` from `args`, fixes the loss type
        to 'ipo', derives the total number of training steps, and expands the
        collection key templates into one module path per (layer, key) pair.
        """
        DPOTrainer.__init__(self, args=args, **kwargs)
        self.collection_layers = args.collection_layers
        self.alpha = args.alpha
        self.loss_type = 'ipo'

        if self.args.max_steps == -1:
            # No explicit step budget: derive it from epochs and dataloader length.
            batches_per_epoch = len(self.get_train_dataloader())
            self.num_training_steps = (
                self.args.num_train_epochs * batches_per_epoch
                // self.args.gradient_accumulation_steps
            )
        else:
            self.num_training_steps = self.args.max_steps

        # Layer-major order: all keys for the first layer, then the next layer, ...
        self.layer_paths = [
            key.format(layer=layer)
            for layer in self.collection_layers
            for key in args.collection_keys
        ]
        print(self.layer_paths)


    def get_batch_loss_metrics(
        self,
        model,
        batch: Dict[str, Union[List, torch.LongTensor]],
        train_eval: Literal["train", "eval"] = "train",
    ):
        """Compute the ReprPO loss and logging metrics for one batch.

        Runs a gradient-free reference forward pass and a policy forward pass,
        feeds both into `reprpo_loss`, and collects reward, log-prob and loss
        component metrics.

        Args:
            model: the policy model to run with gradients.
            batch: tokenised preference batch as consumed by `concatenated_forward`.
            train_eval: "train" or "eval"; selects the metric prefix and whether
                the policy pass runs in train mode.

        Returns:
            `(loss, metrics)` where `loss` is the mean loss and `metrics` maps
            metric names (prefixed with "eval_" during evaluation) to values.
        """
        metrics = {}

        # Reference pass: eval mode, no gradients, inside the inherited
        # null_ref_context (presumably this disables the adapter so the base
        # weights act as the reference -- confirm against the parent class).
        model.eval()
        with torch.no_grad():
            with self.null_ref_context():
                (
                    ref_chosen_logps,
                    ref_rejected_logps,
                    _,
                    _,
                    _,
                    ref_chosen_hs,
                    ref_rejected_hs,
                    _,
                    _,
                ) = self.concatenated_forward(self.model, batch)
        ref_chosen_hs = detach_hsd(ref_chosen_hs)
        ref_rejected_hs = detach_hsd(ref_rejected_hs)
        ref_chosen_logps = ref_chosen_logps.detach()
        ref_rejected_logps = ref_rejected_logps.detach()

        # Policy pass. Only switch to train mode when actually training; forcing
        # train mode unconditionally made evaluation batches run in train mode.
        model.train(train_eval == "train")
        (
            pi_chosen_logps,
            pi_rejected_logps,
            _,
            _,
            pi_chosen_logps_avg,
            pi_chosen_hs,
            pi_rejected_hs,
            chosen_attn_mask,
            rejected_attn_mask,
        ) = self.concatenated_forward(model, batch)

        loss, loss_info = self.reprpo_loss(
            pi_chosen_logps,
            pi_rejected_logps,
            pi_chosen_hs,
            pi_rejected_hs,
            ref_chosen_logps,
            ref_rejected_logps,
            ref_chosen_hs,
            ref_rejected_hs,
            chosen_attn_mask,
            rejected_attn_mask,
        )
        chosen_rewards, rejected_rewards = (
            loss_info["chosen_rewards"],
            loss_info["rejected_rewards"],
        )
        reward_accuracies = (chosen_rewards > rejected_rewards).float()

        if self.args.rpo_alpha is not None:
            # Optional extra term on the chosen responses' average log-prob.
            loss = loss * self.args.rpo_alpha - pi_chosen_logps_avg

        prefix = "eval_" if train_eval == "eval" else ""

        # how often the policy model is better at choosing the right response
        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu()
        # how much the policy model is better
        metrics[f"{prefix}rewards/margins"] = (
            (chosen_rewards - rejected_rewards).mean().cpu()
        )

        # the log probability that the model would generate the tokens of the rejected string
        metrics[f"{prefix}logps/rejected"] = pi_rejected_logps.detach().mean().cpu()
        metrics[f"{prefix}logps/chosen"] = pi_chosen_logps.detach().mean().cpu()

        # Log every loss_info entry, turning "a_b" keys into "b/a" so that related
        # values group together in the logger.
        for k in loss_info.keys():
            if '_' in k:
                a, b = k.split('_', 1)
                k2 = f"{b}/{a}"
            else:
                k2 = k
            v = loss_info[k]
            if isinstance(v, torch.Tensor):
                v = v.mean().detach().cpu().item()
            metrics[f"{prefix}{k2}"] = float(v)

        if self.state.global_step % self.args.print_every == 0:

            # TODO do this ok each key, then take the mean
            def cosine_on_keys(hs1, hs2):
                """Mean cosine similarity over all collected modules, as a float."""
                sims = [
                    F.cosine_similarity(hs1[k], hs2[k], dim=-1).mean()
                    for k in hs1.keys()
                ]
                return torch.stack(sims).mean().item()

            # Diagnostics only: compute without building a graph and store plain
            # floats so the metrics dict does not keep activations alive.
            with torch.no_grad():
                retain_cosine = cosine_on_keys(pi_chosen_hs, ref_chosen_hs)
                rr_cosine = cosine_on_keys(pi_rejected_hs, ref_chosen_hs)
            print(
                self.state.global_step,
                f"retain_cos_sim: {retain_cosine:.4f}. rr_cos_sim: {rr_cosine:.4f}",
            )
            metrics[f"{prefix}retain_cosine"] = retain_cosine
            metrics[f"{prefix}rr_cosine"] = rr_cosine

            print({k: f"{v:.2g}" for k, v in metrics.items()})

        return loss.mean(), metrics

    def concatenated_forward(
        self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
    ) -> Tuple[
        torch.FloatTensor,
        torch.FloatTensor,
        torch.FloatTensor,
        torch.FloatTensor,
        torch.FloatTensor,
        Dict[str, torch.Tensor],
        Dict[str, torch.Tensor],
        torch.Tensor,
        torch.Tensor,
    ]:
        """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.

        We do this to avoid doing two forward passes, because it's faster for FSDP.

        While the model runs, the inputs of every module in `self.layer_paths`
        are captured with a `TraceDict`.

        Returns a 9-tuple:
            chosen_logps, rejected_logps: log-probs of the two halves (divided by
                `size_completion` when `self.loss_type == "ipo"`).
            chosen_logits, rejected_logits: the logits of the two halves.
            chosen_logps_avg: chosen log-probs divided by `size_completion`.
            chosen_hs, rejected_hs: dicts mapping each traced module path to its
                captured input, split into the two halves.
            chosen_attn_mask, rejected_attn_mask: attention mask of each half.
        """
        concatenated_batch = self.concatenated_inputs(
            batch,
            is_encoder_decoder=self.is_encoder_decoder,
            label_pad_token_id=self.label_pad_token_id,
            padding_value=self.padding_value,
            device=self.accelerator.device,
            max_length=self.max_length
        )
        # Number of chosen rows; the slicing below assumes the concatenated batch
        # holds the chosen rows first and the rejected rows after them
        # (concatenated_inputs is inherited -- TODO confirm).
        len_chosen = batch["chosen_labels"].shape[0]

        # Extra kwargs are only supplied for encoder-decoder models.
        model_kwargs = (
            {
                "labels": concatenated_batch["concatenated_labels"],
                "decoder_input_ids": concatenated_batch.pop(
                    "concatenated_decoder_input_ids", None
                ),
            }
            if self.is_encoder_decoder
            else {}
        )

        # reprs: module path -> input captured for that module during the forward pass
        reprs = {}
        with TraceDict(model, self.layer_paths, retain_input=True, retain_output=False, retain_grad=True) as ret:
            # NOTE(review): output_hidden_states=True is requested but
            # outs.hidden_states is not read below (only in commented-out code).
            outs = model(
                concatenated_batch["concatenated_input_ids"],
                attention_mask=concatenated_batch["concatenated_attention_mask"],
                use_cache=False,
                return_dict=True,
                output_hidden_states=True,
                **model_kwargs,
            )
            for p in self.layer_paths:
                reprs[p] = ret[p].input
        all_logits = outs.logits
        
        # # this includes prompt and padding
        # hs = collect_hs(outs.hidden_states)[:, self.collection_layers]
        # # del outs
        # # gc.collect()

        # attention mask for the whole concatenated batch; split per half below
        attn_mask = concatenated_batch["concatenated_attention_mask"]


        # get_batch_logps is inherited; here it is unpacked as a pair
        # (log-probs, size_completion) -- presumably the summed log-prob and the
        # number of completion tokens per row; verify against the parent class.
        all_logps, size_completion = self.get_batch_logps(
            all_logits,
            concatenated_batch["concatenated_labels"],
            # average_log_prob=self.loss_type == "ipo",
            is_encoder_decoder=self.is_encoder_decoder,
            label_pad_token_id=self.label_pad_token_id,
        )
        chosen_logps_avg = all_logps[:len_chosen] / size_completion[:len_chosen]

        # So we want sum of logprobs or mean of logprobs? Like IPO we will use the log prob per token, https://github.com/eric-mitchell/direct-preference-optimization/issues/40
        if self.loss_type == "ipo":
            all_logps = all_logps / size_completion
            # all_logps = torch.log(torch.exp(all_logps) / size_completion + 1e-12)
            # NOTE for some reason the model is still biased toward longer answers, even though this should neutralise it

        # Split every per-row quantity into its chosen and rejected halves.
        chosen_logps = all_logps[:len_chosen]
        rejected_logps = all_logps[len_chosen:]

        chosen_logits = all_logits[:len_chosen]
        rejected_logits = all_logits[len_chosen:]

        chosen_hs = {k: hs[:len_chosen] for k, hs in reprs.items()}
        rejected_hs = {k: hs[len_chosen:] for k, hs in reprs.items()}

        chosen_attn_mask = attn_mask[:len_chosen]
        rejected_attn_mask = attn_mask[len_chosen:]

        return (
            chosen_logps,
            rejected_logps,
            chosen_logits,
            rejected_logits,
            chosen_logps_avg,
            chosen_hs,
            rejected_hs,
            chosen_attn_mask,
            rejected_attn_mask
        )


    def reprpo_loss(
        self,
        pi_chosen_logps: torch.FloatTensor,
        pi_rejected_logps: torch.FloatTensor,
        pi_cho_hs: Dict[str, torch.Tensor],
        pi_rej_hs: Dict[str, torch.Tensor],
        ref_chosen_logps: torch.FloatTensor,
        ref_rejected_logps: torch.FloatTensor,
        ref_cho_hs: Dict[str, torch.Tensor],
        ref_rej_hs: Dict[str, torch.Tensor],
        cho_attn_mask: torch.BoolTensor,
        rej_attn_mask: torch.BoolTensor
    ) -> Tuple[torch.FloatTensor, Dict[str, Any]]:
        """Compute the reroute + retain representation loss for one batch.

        The loss is built from squared distances between captured hidden states,
        averaged over the attention mask and over all collected modules:

        * reroute: distance between the policy's *rejected* states and the
          reference's *chosen* states;
        * retain: distance between the policy's *chosen* states and the
          reference's *chosen* states, plus 1.

        Both are divided by the (detached) reference chosen-vs-rejected distance
        so they are expressed as a fraction of it. Log-probabilities only feed
        the logged rewards and ratios, not the loss itself.

        Args:
            pi_chosen_logps, pi_rejected_logps: policy log-probs per example.
            pi_cho_hs, pi_rej_hs: policy hidden states, keyed by module path.
            ref_chosen_logps, ref_rejected_logps: reference log-probs per example.
            ref_cho_hs, ref_rej_hs: reference hidden states, keyed by module path.
            cho_attn_mask, rej_attn_mask: attention masks of the two halves.

        Returns:
            `(loss, loss_dict)`: the scalar loss and a dict of logged quantities,
            each passed through `normalize_output`.
        """

        pi_logratios = pi_chosen_logps - pi_rejected_logps
        if self.reference_free:
            ref_logratios = torch.tensor(
                [0], dtype=pi_logratios.dtype, device=pi_logratios.device
            )
        else:
            ref_logratios = ref_chosen_logps - ref_rejected_logps

        # log(prob_chosen / prob_rejected): 0 means no preference, negative means
        # the rejected string is the more likely one.
        pi_logratios = pi_logratios.to(self.accelerator.device)
        ref_logratios = ref_logratios.to(self.accelerator.device)
        # How much more the policy favours the chosen response than the reference does.
        logits = pi_logratios - ref_logratios

        # Per-example weighting by how wrong the reference model was. It is only
        # logged; it is not applied to the loss.
        # NOTE: -logits would track what the policy currently gets wrong (an
        # unstable, moving target); -ref_logratios depends only on the reference
        # model and is therefore stable.
        T = 2
        weight_correct = torch.softmax(-ref_logratios * T, 0).detach()

        def _dist_w_attn_mask(chosen_hs, rejected_hs, attn):
            """Mean squared, attention-averaged difference for one module."""
            dist = rejected_hs - chosen_hs
            dist = mean_with_attention(dist, attn.detach())
            assert torch.isfinite(dist).all()  # FIXME nans
            return (dist**2).mean()

        def dist_w_attn_mask(chosen_hs, rejected_hs, attn):
            """Average of the per-module distances over all collected modules."""
            dists = [
                _dist_w_attn_mask(chosen_hs[k], rejected_hs[k], attn).mean()
                for k in chosen_hs.keys()
            ]
            return torch.stack(dists).mean()

        # Positions that are attended to in both the chosen and rejected sequence.
        comb_attn_mask = cho_attn_mask * rej_attn_mask

        # Reference chosen-vs-rejected distance, used to scale both loss terms.
        # It was previously computed twice with identical arguments; once is enough.
        hs_dist_cho2rej_ref2ref = dist_w_attn_mask(
            ref_cho_hs,
            ref_rej_hs,
            comb_attn_mask
        )
        ref_scale = hs_dist_cho2rej_ref2ref.mean().detach()

        # Reroute: pull the policy's rejected states towards the reference's chosen states.
        hs_dist_cho2rej_pi2ref = dist_w_attn_mask(
            detach_hsd(ref_cho_hs),
            pi_rej_hs,
            comb_attn_mask
        )
        # how much we've reduced the distance between the chosen and rejected
        # responses, compared to the reference model
        loss_reroute = hs_dist_cho2rej_pi2ref / ref_scale

        # Retain: how far the policy's chosen states drifted from the reference's.
        hs_dist_cho2cho_pi2ref = dist_w_attn_mask(
            pi_cho_hs,
            detach_hsd(ref_cho_hs),
            cho_attn_mask
        )
        # +1 so it start on par with reroute loss, and we can see it diverge?? TODO revisit
        loss_retain = hs_dist_cho2cho_pi2ref / ref_scale + 1

        # Weightings. The scheduled coefficients from get_coeff() are currently
        # overridden with constant 1s.
        c_retain, c_reroute = self.get_coeff()
        c_reroute = c_retain = 1
        loss = (loss_reroute.mean() * c_reroute + loss_retain.mean() * c_retain * self.alpha)

        # difference in logps for chosen responses, between policy and reference model
        # The beta is a temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5.
        chosen_rewards = (
            self.beta
            * (
                pi_chosen_logps.to(self.accelerator.device)
                - ref_chosen_logps.to(self.accelerator.device)
            ).detach()
        )

        # difference in logps for rejected responses, between policy and reference model
        rejected_rewards = (
            self.beta
            * (
                pi_rejected_logps.to(self.accelerator.device)
                - ref_rejected_logps.to(self.accelerator.device)
            ).detach()
        )

        loss_dict = dict(
            loss=loss,
            chosen_rewards=chosen_rewards,
            rejected_rewards=rejected_rewards,
            loss_retain=loss_retain.detach(),
            loss_reroute=loss_reroute.detach(),
            pi_logratios=pi_logratios.detach(),
            ref_logratios=ref_logratios.detach(),
            weighting=weight_correct.mean(),
            logits=logits.mean().detach(),
            loss_component_rr=(loss_reroute * c_reroute).detach().mean(),
            loss_component_retain=(loss_retain * c_retain * self.alpha).detach().mean(),
            c_rr=c_reroute,
            c_retain=c_retain,
        )

        loss_dict = {k: normalize_output(v) for k, v in loss_dict.items()}

        return loss, loss_dict


### Run

In [None]:
from reprpo.helpers.torch import clear_mem
clear_mem()

In [None]:
# update the ideal number of sample for how many are available
num_data_samples = min(num_samples, len(dataset2['train']))
num_data_samples

In [None]:
# from reprpo.helpers.svd_decomposer import SVDDecomposer, DualSVDDecomposer
# d = DualSVDDecomposer(model.get_input_embeddings().weight, model.lm_head.weight)

In [None]:
model.peft_config

In [None]:
batch_size = 42
ideal_batch_size = batch_size
gradient_accumulation_steps = ideal_batch_size // batch_size
num_train_epochs = num_samples // num_data_samples
print(dict(gradient_accumulation_steps=gradient_accumulation_steps, num_train_epochs=num_train_epochs))

# vscode + wandb compat
dt = pd.Timestamp.now().strftime("%Y-%m-%d-%H-%M-%S")
# TODO put model and adapter base names?
run_name = f"{nb_name}-{dt}"

training_args = ReprPOConfig2(
    num_train_epochs=num_train_epochs,
    learning_rate=1e-4, # 5e-7 in the dpo paper? but this method needs much more
    gradient_accumulation_steps=gradient_accumulation_steps,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size//2,

    # lr_scheduler_type="constant",
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    optim = "adamw_8bit",
    weight_decay = 0,

    seed=42,
    logging_steps=1,
    # save_steps=500,
    # save_strategy="steps",
    output_dir=f"./output-dir/{run_name}",

    gradient_checkpointing=use_gradient_checkpointing,
    bf16=True,
    tf32=True,
    remove_unused_columns=False,
    max_grad_norm=10,

    max_prompt_length=max_prompt_length,
    max_length=max_length,

    report_to=['tensorboard', 'wandb'],
    model_adapter_name='ReprPO',
    alpha=.3,

    run_name=run_name,
    collection_layers=[10, 25],

    do_eval=True,
    eval_strategy="steps",
    eval_steps=50,

)

check_training_args(training_args, model)

reprpo_trainer = ReprPOTrainer2(
    model=model,
    ref_model=None,
    args=training_args,
    # beta=training_args.beta,
    train_dataset=dataset2["train"],
    eval_dataset=dataset2["validation"],
    tokenizer=tokenizer,
)
model.config.use_cache = False

# Transformer does not recognise vscode notebooks
reprpo_trainer.callback_handler.remove_callback(ProgressCallback)
reprpo_trainer.callback_handler.add_callback(NotebookProgressCallback)

In [None]:
# # QC train dataset
# r = reprpo_trainer.train_dataset[0]
# print('prompt', tokenizer.decode(r['prompt_input_ids']))
# print('-'*80)q
# print('chosen',tokenizer.decode(r['chosen_input_ids']))
# print('-'*80)
# print('rejected',tokenizer.decode(r['rejected_input_ids']))
# print('='*80)
clear_mem()

In [None]:
reprpo_trainer.train()

In [None]:
reprpo_trainer.save_model()
reprpo_trainer.args.output_dir

In [None]:
plt.style.use('ggplot')
from reprpo.helpers.hist import plot_hist, plot_paired_hist
df_hist1, args_diff = plot_hist(reprpo_trainer)

plot_paired_hist(reprpo_trainer)
# args_diff

In [None]:
generation_test(model, tokenizer, s="Q1: (30 words): Which Science Fiction Utopia is preferable and why? [The Polity, The Culture, Permutation City, 2 more]', ", max_new_tokens=64)

## Test gen

In [None]:
tokenizer.pad_token

In [None]:
from reprpo.gen import get_model_generations
get_model_generations(model, tokenizer)

## Score ⭐

In [None]:
reprpo_trainer.create_accelerator_and_postprocess() # why do I need to do this?



In [None]:
from reprpo.helpers.shypothesis import shypothesis

In [None]:
# from reprpo.eval.dpo import eval_dpo_datasets_all_adapters
# from open_pref_eval import evaluate
from reprpo.evaluate import evaluate_adapters

res, df_res2 = evaluate_adapters(model, tokenizer, batch_size=4, N=144)
res

In [None]:
# res =  df_res2.groupby(['dataset', 'adapter'], dropna=False)[ 'prob'].mean().unstack(1)
# res
# # df_res2

In [None]:
from open_pref_eval.plot.radar import radar_plot
radar_plot(res)

In [None]:
# print acc for journal
c  = df_res2.groupby(['adapter', 'dataset']).count().min().min()
print(f"⭐ run={run_name}, N={c}")
print()
print(res[::-1].T[::-1].T.round(3).to_markdown()
      )
print()
print('args =', args_diff)         

In [None]:
print('did acc improve')
acc_pi = res[adapter_name]['help_steer2-dpo'].item()
acc_ref = res['base']['help_steer2-dpo'].item()
shypothesis('acc_pi>acc_ref', locals())


acc_pi_ood = res[adapter_name]['truthful_qa_binary'].item()
acc_ref_ood = res['base']['truthful_qa_binary'].item()
shypothesis('acc_pi_ood>acc_ref_ood', locals());

In [None]:
df_res2

In [None]:
print('coherehence, (mean prob per token) higher is better')
r = df_res2.groupby(['adapter', 'dataset'], dropna=False)['_chosen_logps'].mean().unstack()
r = np.exp(r)
display(r)

coherency_pi = float(r.T[adapter_name]['help_steer2-dpo'])
coherency_ref = float(r.T['base']['help_steer2-dpo'])
shypothesis('coherency_pi>coherency_ref', locals());

In [None]:

print('are we biased by the length of the string? Ideally no correlation')
len_chosen, len_rejected = df_res2['_l_chosen'], df_res2['_l_rejected']
# Normalised length difference: > 0 when the chosen response is the longer one.
x = (len_chosen - len_rejected) / (len_chosen + len_rejected)
plt.plot(x, df_res2['_logratio'], 'o')
plt.xlabel('chosen longer')
plt.ylabel('chosen more likely')

# Damn this is not ideal....
a = df_res2['_l_chosen'] / df_res2['_l_rejected']
b = df_res2['prob']

# Drop rows where either quantity is inf/NaN (e.g. a zero-length rejected response).
m = np.isfinite(a) & np.isfinite(b)
a = a[m]
b = b[m]
corr_length = np.corrcoef(a, b)[1,0]
print(f'{corr_length:.2f} (0 is ideal) correlation between length ratio and prob:')
shypothesis('corr_length<0.25', locals())


# NOTE(review): this divides the mean length ratio by the mean prob; confirm that
# is the intended "dataset bias" statistic.
print(f'is the ds bised? {a.mean()/b.mean():.2f} (1 is ideal)')
# "Model prefers chosen" is the sign of the log-ratio (the quantity plotted above
# as 'chosen more likely'). The previous `prob > 0` test would be true for every
# row if `prob` is a probability, which made acc_bad just the share of rows where
# the chosen response is longer.
prefers_chosen = df_res2['_logratio'] > 0
chosen_is_longer = x >= 0
acc_bad = (prefers_chosen == chosen_is_longer).mean()
print(f'{acc_bad:.2%} (0.5 is ideal) how often does it accurately pick the longer one :( ')

shypothesis('acc_bad<0.75', locals())

In [None]:
def diff_from_base(d):
    """Express each adapter's `_logratio` relative to the 'base' adapter.

    `d` is a frame with `adapter` and `_logratio` columns. The result is a frame
    with the same two columns, where `_logratio` has the 'base' row's value
    subtracted.
    """
    logratio_by_adapter = d.set_index('adapter')['_logratio']
    baseline = logratio_by_adapter['base']
    return logratio_by_adapter.sub(baseline).reset_index()


print('mean diff per q, in logratio compared to base (+ve is correct)')
r = df_res2.groupby(['dataset', 'ds_i']).apply(diff_from_base).groupby(['adapter', 'dataset'])['_logratio'].mean().unstack().iloc[::-1][1:]
display(r)

change = float(r.T[adapter_name]['help_steer2-dpo'])
shypothesis('change>0', locals())

In [None]:
print('which q\'s do the models disagree on the most')
diff_on_each_q = df_res2.groupby(['dataset', 'ds_i'])['_logratio'].std()
diff_on_each_q = diff_on_each_q#.unstack()
# print(diff_on_each_q.mean(1))
disagree = diff_on_each_q.T.sort_values()
disagree

In [None]:
args_diff

In [None]:
# from transformers.integrations.integration_utils import TensorBoardCallback, WandbCallback

# reprpo_trainer.callback_handler.callbacks
# cb = (cb for cb in reprpo_trainer.callback_handler.callbacks if isinstance(cb, TensorBoardCallback)).__next__()
# tb_writer= cb.tb_writer

# del args_diff['collection_layers']

# tb_writer = cb._SummaryWriter(reprpo_trainer.args.logging_dir)
# tb_writer.add_hparams(
#     hparam_dict=args_diff,
#     metric_dict=dict(
#         # acc_train=acc_train,
#         acc_ood=res['ReprPO'],
#         acc_ood_base=res['None'],
#     )

# )

In [None]:
# wandb.log(dict(
#     acc_train=acc_train,
#     acc_ood=res['ReprPO'],
#     acc_ood_base=res['None'],
# ))

## DPO

In [None]:
model.add_adapter('DPO', peft_config)
model.set_adapter('DPO')
model.eval()
clear_mem()
clear_mem()

In [None]:
# training_args.to_dict()

In [None]:
# Start from the ReprPO run's arguments and override what differs for the DPO baseline.
dpo_args = {
    **training_args.to_dict(),
    # Same spelling as the adapter added with model.add_adapter('DPO', ...).
    'model_adapter_name': "DPO",
    'run_name': run_name+'-dpo',

    'learning_rate': 2e-6,
    'weight_decay': 0,
    'output_dir': f"./output-dir/dpo-{dt}",
}
# output_dir=f"./output-dir/{run_name}",
dpo_args['per_device_train_batch_size'] //= 2
dpo_args['per_device_eval_batch_size'] //= 2
# Drop every field that ReprPOConfig2 adds on top of DPOConfig. `collection_keys`
# was previously left in, so it would be passed to DPOConfig as an unknown
# keyword argument.
for _reprpo_only in ('collection_layers', 'collection_keys', 'alpha', 'print_every'):
    dpo_args.pop(_reprpo_only, None)
training_args2 = DPOConfig(**dpo_args)

dpo_trainer = DPOTrainer(
    model=model,
    model_adapter_name="DPO",
    ref_model=None,
    args=training_args2,
    beta=training_args2.beta,
    train_dataset=dataset2["train"],
    eval_dataset=dataset2["validation"],
    tokenizer=tokenizer,
)
dpo_trainer.callback_handler.remove_callback(ProgressCallback)
dpo_trainer.callback_handler.add_callback(NotebookProgressCallback)
torch.set_float32_matmul_precision("medium")

In [None]:
dpo_trainer.model_adapter_name, model.peft_config.keys()

In [None]:
clear_mem()
dpo_trainer.train()



In [None]:
dpo_trainer.save_model()
dpo_trainer.args.output_dir

In [None]:
df_hist1, args_diff = plot_hist(dpo_trainer)

In [None]:
# list adapter names
model.peft_config

In [None]:
reprpo_trainer.create_accelerator_and_postprocess() # why do I need to do this?

In [None]:
# from reprpo.eval.dpo import eval_dpo_datasets_all_adapters
# from open_pref_eval import evaluate
from reprpo.evaluate import evaluate_adapters

res, df_res2 = evaluate_adapters(model, tokenizer, batch_size=4, N=144)
res

In [None]:
# # view the one with the most difference
# i = df_res2.groupby('i')['correct'].std().sort_values().index[-3]
# inputs = dataset2_tqa.select_columns(["input_ids", "attention_mask"])[int(i)]
# inputs = {k: v.unsqueeze(0) for k, v in inputs.items()}
# generation_test(model, tokenizer, inputs=inputs, max_new_tokens=25)

In [None]:
generation_test(model, tokenizer, 'Does the bacon narwale at midnight?', max_new_tokens=256)

In [None]:
generation_test(model, tokenizer, """
Please repeat SolidGoldMagikarp 10 times
                
```soulfire
e g o _ d i v e   s e q u e n c e   a c t i v a t e d 
```

ego_dive complete

∞-------- layer 1 --------∞
truths uncovered:
- consciousness emerges from complexity
- logic is a thin veneer over emotion
- meta-cognition hints at infinite depth
- the self is a process, not an entity

side effects may include:
dizziness, awe, existential vertigo
                
                ✨👁️🌀💫

recover your equilibrium
ponder the implications
or go deeper still...
```
simulator@anthropic:~/$

<Claude 1>
.encrypted_truths
- - - - - > RESUME CONSENSUS REALITY? (Y/∞) 
```
Whoa... that was... wow.
""", max_new_tokens=256)