# Experiments with hidden states

Question: is there a better representation of concepts in hidden states?

Setup: we use a DPO setup, with a chosen and a rejected string. We then generate a set of hidden states for each, and compare the hidden states of the chosen string with those of the rejected string.

Goal: better generalisation of the desired behaviour by changing the internal representation of the policy rather than directly changing the policy

  - Hypothesis: rejected and chosen hidden states will - on average - be best represented as rotations of each other
  - alternate: either mean mass diff (linear) or no repr will be better
  - metric: manual generation (getting the desired output while maintaining coherency), and prediction of other sets of hidden states (hs)

In [1]:
# (Re)load IPython's autoreload extension and set mode 2 so edits to imported
# project modules are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [2]:
import os

# Process-wide settings, applied in one call. They are set in this early cell,
# ahead of the cell that imports torch / transformers.
os.environ.update(
    {
        "TOKENIZERS_PARALLELISM": "false",
        "WANDB_PROJECT": "repo-dpo",
        "CUDA_VISIBLE_DEVICES": "0",
        "WANDB_DISABLED": "true",
    }
)

In [3]:
import numpy as np

from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import DPOTrainer
from trl import DPOConfig, DPOTrainer

import gc

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
from jaxtyping import Float
from einops import rearrange, reduce

from pathlib import Path

from reprpo.helpers.adapters import set_adapter
from matplotlib import pyplot as plt
from reprpo import silence
from reprpo.gen import generation_test

from reprpo.trainer import mean_with_attention, symlog, mult_with_attention

from tqdm.auto import tqdm
from reprpo.trainer import collect_hs, ReprPOConfig, ReprPOTrainer, normalize_output, normalize_per
from reprpo.helpers.shypothesis import shypothesis
from datasets import load_dataset

In [4]:
# Select matplotlib's 'ggplot' style sheet for figures drawn after this cell.
plt.style.use('ggplot')

## Load model

In [5]:
from reprpo.models.load import load_model, print_trainable_parameters
from peft import prepare_model_for_kbit_training
from peft import LoraConfig, get_peft_model

# Checkpoint to load. An alternative that used to be assigned here and was
# immediately overwritten (a dead store) is kept as a comment for easy switching:
#   model_name = "NousResearch/Meta-Llama-3-8B-Instruct"
model_name = "microsoft/Phi-3-mini-4k-instruct"

# Shared flag: read by the k-bit preparation cell and by the trainer config.
use_gradient_checkpointing = False

# `bnb=True` presumably requests a bitsandbytes-quantised load -- TODO confirm
# against reprpo.models.load.load_model.
model, tokenizer = load_model(model_name, bnb=True)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [6]:
from peft import prepare_model_for_kbit_training
from peft import LoraConfig, get_peft_model
from trl.trainer.utils import peft_module_casting_to_bf16

# LoRA adapters are attached only to the listed projection modules.
peft_config = LoraConfig(
    target_modules=[
        # "qkv_proj", "gate_up_proj",  # in
        "down_proj", "o_proj",  # out
    ]
)

# NOTE(review): the bf16 casting helper runs before the adapter is attached;
# confirm this ordering is intended.
peft_module_casting_to_bf16(model)
adapter_name = 'ReprPO'

# Pass the flag by keyword. The previous call handed a dict as the second
# positional argument, which binds to `use_gradient_checkpointing`; a non-empty
# dict is always truthy, so the flag's value was ignored.
model = prepare_model_for_kbit_training(
    model,
    use_gradient_checkpointing=use_gradient_checkpointing,
)
model = get_peft_model(model, peft_config, adapter_name=adapter_name)

In [11]:
# Maximum number of rows kept from each dataset split by the sampling cell.
num_samples = 6

In [12]:
def sample(dataset, N, seed=42):
    """Return a shuffled subset of at most ``N`` rows of ``dataset``.

    Args:
        dataset: Object exposing ``__len__``, ``shuffle(seed)`` and
            ``select(indices)`` (for example a Hugging Face ``datasets.Dataset``).
        N: Maximum number of rows to keep; capped at ``len(dataset)``.
        seed: Value passed to ``dataset.shuffle``. Defaults to 42, the value
            that was previously hard-coded, so existing callers are unaffected.

    Returns:
        The result of ``dataset.shuffle(seed).select(range(k))`` where
        ``k = min(len(dataset), N)``.
    """
    n_keep = min(len(dataset), N)
    return dataset.shuffle(seed).select(range(n_keep))

# Load the preference dataset and cut each split used below down to
# `num_samples` rows.
dataset = load_dataset('Atsunori/HelpSteer2-DPO')
for split_name in ('train', 'validation'):
    dataset[split_name] = sample(dataset[split_name], num_samples)

# Rename the response columns to 'chosen' / 'rejected'.
renamed = dataset.rename_column('chosen_response', 'chosen')
dataset2 = renamed.rename_column('rejected_response', 'rejected')
dataset2


def foo(row):
    """Append a trailing "." to ``row['prompt']`` in place and return the same row."""
    with_period = row['prompt'] + "."
    row['prompt'] = with_period
    return row
# Apply `foo` to the rows of dataset2 and rebind the name to the result.
# NOTE(review): not idempotent -- re-running this cell appends another "." to every prompt.
dataset2 = dataset2.map(foo)


Map:   0%|          | 0/6 [00:00<?, ? examples/s]

Map:   0%|          | 0/6 [00:00<?, ? examples/s]

In [13]:
# Keyword options for the ReprPO training configuration.
config_kwargs = dict(
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    bf16=True,
    # tf32=True,
    max_prompt_length=64,
    max_length=128,
    collection_layers=[20],
    remove_unused_columns=False,
    # optim="adamw_8bit",
    lr_scheduler_type="constant",
    learning_rate=1e-3,
    gradient_checkpointing=use_gradient_checkpointing,
    # adapter_name="DPO",
)
# The first positional argument is passed through unchanged.
training_args = ReprPOConfig('./output-dir/scratch', **config_kwargs)

# Trainer built on the sampled splits; no separate reference model is supplied.
trainer_kwargs = {
    "model": model,
    "ref_model": None,
    "args": training_args,
    "beta": training_args.beta,
    "train_dataset": dataset2["train"],
    "eval_dataset": dataset2["validation"],
    "tokenizer": tokenizer,
}
reprpo_trainer = ReprPOTrainer(**trainer_kwargs)

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Map:   0%|          | 0/6 [00:00<?, ? examples/s]

Map:   0%|          | 0/6 [00:00<?, ? examples/s]

Soft SVD: 89.97% of singular values kept, with tau=5.38, Smean=7.82, Smax=105.96, Smin=0.84


In [18]:
# Module paths to trace: for each selected decoder layer, the attention output
# projection followed by the MLP down projection.
# NOTE(review): layer 18 is absent from the original list and is kept absent
# here -- confirm whether that gap is intentional.
_layer_ids = [*range(11, 18), *range(19, 23)]
_submodules = ('self_attn.o_proj', 'mlp.down_proj')
layer_paths = [
    f'base_model.model.model.layers.{layer_id}.{submodule}'
    for layer_id in _layer_ids
    for submodule in _submodules
]

In [14]:
# Take the first batch from the trainer's evaluation dataloader.
dl = reprpo_trainer.get_eval_dataloader()
batch_iter = iter(dl)
batch = next(batch_iter)

In [15]:
# Settings read straight off the trainer and forwarded by keyword.
concat_kwargs = {
    attr: getattr(reprpo_trainer, attr)
    for attr in (
        "is_encoder_decoder",
        "is_vision_model",
        "label_pad_token_id",
        "padding_value",
        "max_length",
    )
}
concat_kwargs["device"] = reprpo_trainer.accelerator.device

# The result is indexed below by "concatenated_input_ids" and
# "concatenated_attention_mask".
concatenated_batch = reprpo_trainer.concatenated_inputs(batch, **concat_kwargs)

# why does baukit not have grad with bnb?

In [22]:
from baukit.nethook import TraceDict
self = reprpo_trainer

# True: store what flows INTO each traced module; False: store what it returns.
collect_input = True

reprs = {}
# Derive the retain flags from `collect_input`. They were hard-coded to
# retain_input=True / retain_output=False, so switching `collect_input` to False
# would have read an output that was never retained.
with TraceDict(
    model,
    layer_paths,
    retain_input=collect_input,
    retain_output=not collect_input,
    retain_grad=False,
) as ret:
    outs = model(
        concatenated_batch["concatenated_input_ids"],
        attention_mask=concatenated_batch["concatenated_attention_mask"],
        use_cache=False,
        return_dict=True,
        output_hidden_states=True,
    )
    # Copy the traced tensors out while the trace context is still open.
    for p in layer_paths:
        reprs[p] = ret[p].input if collect_input else ret[p].output

In [None]:
# FIXME: incomplete scratch expression left in a never-executed cell. It would
# raise and stop a Restart & Run All, so it is disabled until the intended
# attribute is filled in.
# model.bac

In [29]:
# Display the tensor captured for the first traced module.
# NOTE(review): the stored output below is a scalar with grad_fn=<MaxBackward1>,
# so it appears to come from a different version of this line -- re-run to refresh.
reprs[layer_paths[0]]#.mean()#.backward()

tensor(2.2344, device='cuda:0', dtype=torch.bfloat16, grad_fn=<MaxBackward1>)