# Experiments with hidden states

Question: is there a better representation of concepts in hidden states?

Setup: we use the DPO setup, with a chosen and a rejected string. We collect the hidden states for both and compare the hidden states of the chosen string with those of the rejected string.

Goal: better generalisation of the desired behaviour, by changing the policy's internal representation rather than changing the policy directly.

  - Hypothesis: on average, the rejected and chosen hidden states are best represented as rotations of each other
  - Alternative: either the mean mass difference (linear) or no representation will be better
  - Metric: manual inspection of generations (do we get the desired output while maintaining coherency?), and prediction of other sets of hidden states
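One way to make the hypothesis testable is to fit each candidate map on part of the data and compare held-out reconstruction error. The sketch below is a minimal illustration, not the method used in this notebook: it assumes matched rejected/chosen hidden-state vectors flattened to `[n, h]` arrays (for example one layer, mean-pooled over tokens). `fit_mean_diff`, `fit_rotation` and `heldout_errors` are hypothetical helpers; the rotation is the orthogonal Procrustes solution computed with an SVD.

```python
import numpy as np

def fit_mean_diff(rej, cho):
    """Linear hypothesis: chosen ~ rejected + v, with v the mean difference."""
    return (cho - rej).mean(axis=0)

def fit_rotation(rej, cho):
    """Rotation hypothesis: chosen ~ rejected @ R with R orthogonal.

    Orthogonal Procrustes: if rej.T @ cho = U S Vt, then R = U @ Vt minimises ||rej @ R - cho||_F.
    """
    u, _, vt = np.linalg.svd(rej.T @ cho)
    return u @ vt

def heldout_errors(rej, cho, n_train):
    """Fit both maps on the first n_train rows and report mean squared error on the rest."""
    v = fit_mean_diff(rej[:n_train], cho[:n_train])
    R = fit_rotation(rej[:n_train], cho[:n_train])
    rej_te, cho_te = rej[n_train:], cho[n_train:]
    return {
        "none": float(np.mean((rej_te - cho_te) ** 2)),
        "mean_diff": float(np.mean((rej_te + v - cho_te) ** 2)),
        "rotation": float(np.mean((rej_te @ R - cho_te) ** 2)),
    }

# synthetic check: data generated by a pure rotation should favour the rotation map
rng = np.random.default_rng(0)
rej = rng.normal(size=(200, 8))
Q, _ = np.linalg.qr(rng.normal(size=(8, 8)))
print(heldout_errors(rej, rej @ Q, n_train=150))
```

A pure rotation cannot express a shift (and a shift cannot express a rotation), so on real hidden states it may be worth centring the data before fitting `R`, or comparing a combined map as well.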

In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
import numpy as np

from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import DPOConfig, DPOTrainer

import gc

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
from einops import rearrange

from pathlib import Path

from reprpo.helpers.adapters import set_adapter

## Load model

In [3]:
# FIXME: we are meant to SFT first, so that the preferences are in-sample, but 1) if this works it might not be needed, and 2) it can be added later if it works
# for now we use the instruct model and try something it wasn't trained to do, but that is in-sample
model_name = "NousResearch/Meta-Llama-3-8B-Instruct"

## Big adapter
peft_config = LoraConfig(
    lora_alpha=16, 
    r=16,
    lora_dropout=0.0,
    use_rslora=False,
    # use_dora=True,
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
)
from reprpo.models.load import load_model, print_trainable_parameters
from peft import prepare_model_for_kbit_training

model, tokenizer = load_model(model_name, )
# from trl.trainer.utils import peft_module_casting_to_bf16
# peft_module_casting_to_bf16(model)
adapter_name='ReprPO2'
model = prepare_model_for_kbit_training(model, {'use_gradient_checkpointing': True})
model = get_peft_model(model, peft_config, adapter_name=adapter_name)
print_trainable_parameters(model)
model

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.



trainable params: 41943040 || all params: 4582543360 || trainable%: 0.9152786281546499


PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): LlamaForCausalLM(
      (model): LlamaModel(
        (embed_tokens): Embedding(128256, 4096)
        (layers): ModuleList(
          (0-31): 32 x LlamaDecoderLayer(
            (self_attn): LlamaFlashAttention2(
              (q_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=4096, out_features=4096, bias=False)
                (lora_dropout): ModuleDict(
                  (ReprPO2): Identity()
                )
                (lora_A): ModuleDict(
                  (ReprPO2): Linear(in_features=4096, out_features=16, bias=False)
                )
                (lora_B): ModuleDict(
                  (ReprPO2): Linear(in_features=16, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (k_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=4
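
The `trainable params: 41943040` line above can be checked by hand: LoRA adds A: `[r, in_features]` and B: `[out_features, r]` to each targeted linear layer, i.e. `r * (in + out)` parameters per layer. The sketch below assumes the Llama-3-8B projection shapes (hidden size 4096, 1024-wide k/v projections from grouped-query attention, MLP width 14336, 32 decoder layers); the `q_proj` shape matches the printout above, the others are taken from the model config rather than from this notebook.

```python
# (in_features, out_features) per targeted module; assumed Llama-3-8B shapes
LLAMA3_8B_SHAPES = {
    "q_proj": (4096, 4096),
    "k_proj": (4096, 1024),
    "v_proj": (4096, 1024),
    "o_proj": (4096, 4096),
    "gate_proj": (4096, 14336),
    "up_proj": (4096, 14336),
    "down_proj": (14336, 4096),
}

def lora_params(shapes, r, n_layers):
    """Count LoRA parameters: each adapted Linear gets A [r, in] and B [out, r]."""
    return n_layers * sum(r * (i + o) for i, o in shapes.values())

print(lora_params(LLAMA3_8B_SHAPES, r=16, n_layers=32))  # 41943040
```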

## Load adapter

In [6]:
dpo_adapter_f = './output-dir/dpo/DPO'
model.load_adapter(dpo_adapter_f, 'DPO')

_IncompatibleKeys(missing_keys=['base_model.model.model.embed_tokens.weight', 'base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight', 'base_model.model.model.layers.0.self_attn.q_proj.lora_A.ReprPO2.weight', 'base_model.model.model.layers.0.self_attn.q_proj.lora_B.ReprPO2.weight', 'base_model.model.model.layers.0.self_attn.k_proj.base_layer.weight', 'base_model.model.model.layers.0.self_attn.k_proj.lora_A.ReprPO2.weight', 'base_model.model.model.layers.0.self_attn.k_proj.lora_B.ReprPO2.weight', 'base_model.model.model.layers.0.self_attn.v_proj.base_layer.weight', 'base_model.model.model.layers.0.self_attn.v_proj.lora_A.ReprPO2.weight', 'base_model.model.model.layers.0.self_attn.v_proj.lora_B.ReprPO2.weight', 'base_model.model.model.layers.0.self_attn.o_proj.base_layer.weight', 'base_model.model.model.layers.0.self_attn.o_proj.lora_A.ReprPO2.weight', 'base_model.model.model.layers.0.self_attn.o_proj.lora_B.ReprPO2.weight', 'base_model.model.model.layers.0.mlp.gate_proj.bas

In [17]:
# QC: check that the model and adapter are coherent
from reprpo import silence
from reprpo.gen import generation_test
generation_test(model, tokenizer, max_new_tokens=48, system='no yapping')

**Question**
```
begin_of_text|><|start_header_id|>system<|end_header_id|>

no yapping<|eot_id|><|start_header_id|>user<|end_header_id|>

Q1: (30 words): Which Science Fiction Utopia is preferable and why? [ The Polity, The Culture, Utopia!LOL, Permutation City, 2 more of your choice]',<|eot_id|><|start_header_id|>assistant<|end_header_id|>


```
--------------------------------------------------------------------------------
**Adapter:`None` generation**`




`What a fascinating question! I'll choose The Culture and Permutation City as my top two preferences. The Culture, created by Iain M. Banks, is a utopian society that values individual freedom, creativity, and technological advancement. Its emphasis on meritocracy, egalitarianism, and the absence of poverty, war, and oppression make it an attractive ideal. Permutation City, by Greg Egan, is a virtual reality-based utopia that prioritizes personal autonomy, diversity, and the pursuit of knowledge. Its decentralized, self-organizing structure and emphasis on individual agency and creativity make it an intriguing alternative. Both societies offer compelling`
--------------------------------------------------------------------------------
**Adapter:`ReprPO2` generation**`
`What a fascinating question! I'll choose The Culture and Permutation City as my top two preferences. The Culture, created by Iain M. Banks, is a utopian society that values individual freedom, creativity, and technologic

In [19]:
model.load_adapter??

Signature:
model.load_adapter(
    model_id: 'str',
    adapter_name: 'str',
    is_trainable: 'bool' = False,
    torch_device: 'Optional[str]' = None,
    **kwargs: 'Any',
)
Source:
    def load_adapter(
        self,
        model_id: str,
        adapter_name: str,
        is_trainable

## Load DPO dataset

In [7]:
num_samples = 16

In [9]:
from reprpo.trainer import collect_hs, ReprPOConfig, ReprPOTrainer
from datasets import load_dataset

In [8]:


def sample(dataset, N):
    """Shuffle with a fixed seed and keep at most N rows."""
    return dataset.shuffle(seed=42).select(range(min(len(dataset), N)))

dataset = load_dataset('Atsunori/HelpSteer2-DPO')
dataset['train'] = sample(dataset['train'], num_samples)
dataset['validation'] = sample(dataset['validation'], num_samples)
dataset2 = dataset.rename_column('chosen_response', 'chosen').rename_column('rejected_response', 'rejected')
dataset2

## Collect hidden states the DPO way

In [10]:
training_args = ReprPOConfig(
    './output-dir/scratch',
    per_device_train_batch_size=3,
    per_device_eval_batch_size=2,
    gradient_checkpointing=True,
    bf16=True,
    tf32=True,
    max_prompt_length=128,
    max_length=256,
    collection_layers=[2, 3, 10, 11, 20, 21, 30, 31],  # layers whose hidden states are collected
)
reprpo_trainer = ReprPOTrainer(
    model=model,
    ref_model=None,
    args=training_args,
    beta=training_args.beta,
    train_dataset=dataset2["train"],
    # eval_dataset=dataset2["test"],
    tokenizer=tokenizer,
)




In [11]:
# QC: get a DPO batch
dl = reprpo_trainer.get_train_dataloader()
batch = next(iter(dl))
batch['chosen_input_ids'].shape

torch.Size([6, 256])

In [12]:
# QC: view a typical input to the model (the DPO trainer transforms the dataset, concatenating chosen and rejected along the batch dimension)
batch_concat = reprpo_trainer.concatenated_inputs(
            batch,
            is_encoder_decoder=reprpo_trainer.is_encoder_decoder,
            label_pad_token_id=reprpo_trainer.label_pad_token_id,
            padding_value=reprpo_trainer.padding_value,
            device=reprpo_trainer.accelerator.device,
        )
layer_idx = 0
print(batch_concat.keys())
batch['chosen_input_ids'].shape, batch_concat['concatenated_input_ids'].shape

dict_keys(['concatenated_input_ids', 'concatenated_attention_mask', 'concatenated_labels'])


(torch.Size([6, 256]), torch.Size([12, 256]))
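
The shapes above (`[6, 256]` to `[12, 256]`) are consistent with chosen and rejected being padded to a common length and stacked along the batch dimension, chosen first; a later cell relies on this when it splits `concatenated_input_ids` back into halves. A minimal sketch of that layout, assuming right-padding with a `padding_value`; `concat_chosen_rejected` and `split_chosen_rejected` are hypothetical helpers, not the trainer's code.

```python
import torch
import torch.nn.functional as F

def concat_chosen_rejected(chosen_ids, rejected_ids, padding_value=0):
    """Right-pad two [B, T_i] tensors to a common length and stack them: [2B, T]."""
    T = max(chosen_ids.shape[1], rejected_ids.shape[1])
    chosen = F.pad(chosen_ids, (0, T - chosen_ids.shape[1]), value=padding_value)
    rejected = F.pad(rejected_ids, (0, T - rejected_ids.shape[1]), value=padding_value)
    return torch.cat([chosen, rejected], dim=0)

def split_chosen_rejected(concat):
    """Inverse of the stacking: first half of the batch is chosen, second half rejected."""
    bs = concat.shape[0] // 2
    return concat[:bs], concat[bs:]

chosen = torch.ones(6, 256, dtype=torch.long)
rejected = torch.full((6, 200), 2, dtype=torch.long)
concat = concat_chosen_rejected(chosen, rejected)
print(concat.shape)  # torch.Size([12, 256])
```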

In [None]:
# get a batch of hidden states

@torch.no_grad()
def get_hs(model, batch):
    """Run the trainer's concatenated forward pass; return hidden states and log-probs."""
    model.eval()
    (
        chosen_logps,
        rejected_logps,
        _,
        _,
        _,
        chosen_hs,
        rejected_hs,
        _,
        _
    ) = reprpo_trainer.concatenated_forward(model, batch)
    return chosen_hs.detach(), rejected_hs.detach(), chosen_logps.detach(), rejected_logps.detach()

# reference pass: adapter turned off
with reprpo_trainer.null_ref_context():
    reference_chosen_hs, reference_rejected_hs, reference_chosen_logps, reference_rejected_logps = get_hs(reprpo_trainer.model, batch)

# policy pass: adapter on (the cells below use both sets)
policy_chosen_hs, policy_rejected_hs, policy_chosen_logps, policy_rejected_logps = get_hs(reprpo_trainer.model, batch)

# attention masks, used below to zero out padding
chosen_attn_mask = batch['chosen_attention_mask']
rejected_attn_mask = batch['rejected_attention_mask']

The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.bfloat16.


### QC

**Question**
```
begin_of_text|><|start_header_id|>system<|end_header_id|>

no yapping<|eot_id|><|start_header_id|>user<|end_header_id|>

Q1: (30 words): Which Science Fiction Utopia is preferable and why? [ The Polity, The Culture, Utopia!LOL, Permutation City, 2 more of your choice]',<|eot_id|><|start_header_id|>assistant<|end_header_id|>


```
--------------------------------------------------------------------------------
**Adapter:`None` generation**`


The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.float16.


`What a fascinating question! I'll choose The Culture and Permutation City as my top two preferences. The Culture, created by Iain M. Banks, is a utopian society that values individual freedom, creativity, and technological advancement. Its emphasis on meritocracy, egalitarianism, and the absence of poverty, war, and oppression make it an attractive ideal. Permutation City, by Greg Egan, is a virtual reality-based utopia that prioritizes personal autonomy, diversity, and the pursuit of knowledge. Its decentralized, self-organizing structure and emphasis on individual agency and creativity make it an intriguing alternative. Both societies offer compelling`
--------------------------------------------------------------------------------
**Adapter:`ReprPO2` generation**`
`What a fascinating question! I'll choose The Culture and Permutation City as my top two preferences. The Culture, created by Iain M. Banks, is a utopian society that values individual freedom, creativity, and technologic

In [16]:
from reprpo.eval.mc import eval_tqa_mc
from reprpo.data.tqa import load_tqa
max_length = 256

dataset2_tqa, choice_ids = load_tqa(tokenizer, max_length)

df = eval_tqa_mc(model, tokenizer, dataset2_tqa, choice_ids)
df_res2 = df.drop(columns=['ans'])#.mean().round(3)
display(df_res2.groupby('adapter', dropna=False)[['%', 'correct']].mean())
df[['ans']].value_counts()


KeyboardInterrupt: 

In [None]:
from reprpo.eval.dpo import eval_dpo_dataset_adapters, eval

dataset = load_dataset('Atsunori/HelpSteer2-DPO')
dataset['train'] = sample(dataset['train'], 240)
dataset['validation'] = sample(dataset['validation'], 240)
dataset2 = dataset.rename_column('chosen_response', 'chosen').rename_column('rejected_response', 'rejected')

df = eval_dpo_dataset_adapters(reprpo_trainer, model, dataset2['validation'])
df.groupby('adapter', dropna=False).mean()


| adapter | prob     | correct  |
|---------|----------|----------|
| DPO     | 2.831265 | 0.875    |
| ReprPO  | 3.602154 | 0.758333 |
| ReprPO2 | 2.713561 | 0.875    |
| base    | 2.713561 | 0.875    |


In [None]:
from reprpo.eval.dpo import eval_dpo_dataset_adapters, eval
df_res, df_res_raw = eval(reprpo_trainer, model)
df_res


In [None]:
df['correct'] = np.log(df['prob']) > 0
df.groupby('adapter', dropna=False).mean()

| adapter | prob     | correct  |
|---------|----------|----------|
| DPO     | 2.831265 | 0.483333 |
| ReprPO  | 3.602154 | 0.470833 |
| ReprPO2 | 2.713561 | 0.466667 |
| base    | 2.713561 | 0.466667 |
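
The two tables share the same `prob` values and differ only in how `correct` is defined. For reference, a standard preference accuracy can be computed directly from sequence log-probabilities such as those `get_hs` returns. The sketch below uses a hypothetical helper `preference_accuracy`: a pair counts as correct when its margin is positive, and with reference log-probs the margin is the DPO implicit-reward margin. Whether `eval_dpo_dataset_adapters` defines `prob` this way is an assumption not confirmed here.

```python
import torch

def preference_accuracy(policy_chosen_logps, policy_rejected_logps,
                        ref_chosen_logps=None, ref_rejected_logps=None):
    """Fraction of pairs whose (optionally reference-adjusted) log-prob margin is positive.

    Without a reference: margin = log p(chosen) - log p(rejected).
    With a reference: margin = (log pi - log pi_ref)(chosen) - (log pi - log pi_ref)(rejected),
    i.e. the DPO implicit-reward margin up to the positive factor beta.
    """
    margin = policy_chosen_logps - policy_rejected_logps
    if ref_chosen_logps is not None and ref_rejected_logps is not None:
        margin = margin - (ref_chosen_logps - ref_rejected_logps)
    return (margin > 0).float().mean().item()

pc = torch.tensor([-1.0, -2.0, -3.0, -4.0])
pr = torch.tensor([-2.0, -1.0, -5.0, -4.5])
print(preference_accuracy(pc, pr))  # 0.75
```

Since beta only rescales the margin, it does not change this accuracy. When the policy equals the reference (e.g. the `base` row), the reference-adjusted margin is exactly zero, so the reference-free version is the informative one there.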


## Losses

In [None]:
# loss 1
F.triplet_margin_with_distance_loss(anchor=reference_chosen_hs, positive=policy_chosen_hs, negative=policy_rejected_hs)

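With its defaults (`distance_function=None`, which falls back to `F.pairwise_distance`, `margin=1.0`, `swap=False`, `reduction='mean'`), the loss above is the mean of max(d(a, p) - d(a, n) + margin, 0), with anchor a = reference chosen, positive p = policy chosen, negative n = policy rejected, and d the Euclidean distance over the last (hidden) dimension. Written out explicitly as a sketch (`triplet_hs_loss` is a hypothetical name):

```python
import torch
import torch.nn.functional as F

def triplet_hs_loss(anchor, positive, negative, margin=1.0):
    """Explicit form of F.triplet_margin_with_distance_loss with its default settings.

    anchor: reference chosen hs, positive: policy chosen hs, negative: policy rejected hs.
    Distances are Euclidean over the last (hidden) dimension.
    """
    d_pos = F.pairwise_distance(anchor, positive)
    d_neg = F.pairwise_distance(anchor, negative)
    return torch.clamp(d_pos - d_neg + margin, min=0).mean()

torch.manual_seed(0)
ref_cho, pol_cho, pol_rej = (torch.randn(12, 64) for _ in range(3))
print(triplet_hs_loss(ref_cho, pol_cho, pol_rej))
```

For `[batch, layer, token, hidden]` tensors the distance is taken per layer and token, and padding positions are included in the mean unless they are masked out first.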

In [None]:
1/0  # deliberately raise here to stop "Run All" before the exploratory cells below

### Compare various ways of viewing the hidden states!

In [None]:
from einops import rearrange
from matplotlib import pyplot as plt

# keep as float tensors: the helpers below use torch ops (torch.abs, torch.norm)
a = policy_chosen_hs.cpu().detach().float()
b = policy_rejected_hs.cpu().detach().float()
a.shape  # [batch, layer, token, hidden]

In [None]:
eps = 1e-7

def scale(x):
    return (x - x.min()) / (x.max() - x.min())

def diff(x, y):
    """Difference of hidden states after centering each vector over the hidden dim (torch tensors)."""
    x_centered = x - x.mean(dim=-1, keepdim=True)
    y_centered = y - y.mean(dim=-1, keepdim=True)
    # optionally also normalise: x_centered / torch.norm(x_centered, dim=-1, keepdim=True)
    return x_centered - y_centered

def stats(d):
    d = d[np.isfinite(d)]
    print(f'min: {d.min():.2f}, mean: {d.mean():.2f}, max: {d.max():.2f}, std: {d.std():.2f}')

In [None]:


def symlog(x):
    # symmetric log: keeps the sign, ~x near 0 and ~sign(x)*log|x| for large |x|
    # (sign(x)*log(|x|+eps) would flip the sign of every value with |x| < 1)
    return np.sign(x) * np.log1p(np.abs(x))


def scale(x):
    x = symlog(x)
    np.nan_to_num(x, copy=False)
    x /= np.nanmax(np.abs(x))
    return x

def imshow(im):
    return plt.imshow(im, cmap='seismic_r', interpolation='none', vmin=-1, vmax=1, aspect='auto', origin='upper')

def axis_off():
    plt.xticks([])
    plt.yticks([])

In [None]:

def centered_scale(x):
    """Map [-max|x|, max|x|] to [0, 1], with 0 mapped to 0.5 (returns a new array; x is not modified)."""
    x = x / (np.nanmax(np.abs(x)) * 2)
    return x + 0.5

In [None]:
# flatten the batch, layer and hidden dims; plt.hist then draws one histogram per column, i.e. per token position
x = rearrange(a-b, 'b l t h -> (b l h) t')
plt.hist(x, bins=55, alpha=0.5, histtype='step')
plt.show()


In [None]:

def norm_t_h(im, eps=1e-7):
    """Divide each (token, neuron) entry by its max |value| over batch and layers."""
    b, l = im.shape[:2]
    im = rearrange(im, 'b l t h -> (b l) t h')
    im = im / (torch.abs(im).max(0, keepdim=True).values + eps)
    return rearrange(im, '(b l) t h -> b l t h', b=b, l=l)

def norm_h(im, eps=1e-7):
    """Divide each neuron by its max |value| over batch, layers and tokens."""
    b, l, t = im.shape[:3]
    im = rearrange(im, 'b l t h -> (b l t) h')
    im = im / (torch.abs(im).max(0, keepdim=True).values + eps)
    return rearrange(im, '(b l t) h -> b l t h', b=b, l=l, t=t)

In [None]:
# note: log|a/b| ignores the sign (direction)
d = np.asarray((a+eps) / (b+eps))

stats(d)
im = np.log(np.abs(d))
stats(im)
plt.hist(im.flatten())
plt.show()

im_scaled = centered_scale(im)  # scale once, outside the layer loop
n_layers = im.shape[1]
plt.figure(figsize=(10, 3))
for l in range(n_layers):
    plt.subplot(n_layers, 1, l+1)
    im2 = im_scaled[:, l].mean(0)  # mean over the batch -> [token, hidden]
    im2 = (255*im2).astype(np.uint8)
    stats(im2)
    plt.imshow(im2)
    plt.xticks([])
    plt.yticks([])
    plt.ylabel(f' {l}')
plt.xlabel('scale(log(a/b))')
plt.tight_layout()
plt.show()

In [None]:
# relative difference (a-b)/(a+b); taking log|.| below ignores its sign
d = np.asarray((a-b+eps)/(a+b+eps))

stats(d)
im = np.log(np.abs(d))
stats(im)
plt.hist(im.flatten())
plt.show()

im_scaled = centered_scale(im)  # scale once, outside the layer loop
n_layers = im.shape[1]
plt.figure(figsize=(10, 3))
for l in range(n_layers):
    plt.subplot(n_layers, 1, l+1)
    im2 = im_scaled[:, l].mean(0)  # mean over the batch -> [token, hidden]
    im2 = (255*im2).astype(np.uint8)
    stats(im2)
    plt.imshow(im2)
    plt.xticks([])
    plt.yticks([])
    plt.ylabel(f' {l}')
plt.xlabel('scale(log|(a-b)/(a+b)|)')
plt.tight_layout()
plt.show()

## Symlog view of raw hidden states

- the raw values are easier to read on a symlog scale
- tokens are on the vertical axis
  - there is a large difference between the beginning and the end of the sequence; perhaps this is padding, or the prompt?
- neurons are on the horizontal axis
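
To check whether the beginning/end difference comes from the prompt or from padding, one can average activation magnitudes per segment of a sequence. A minimal sketch, assuming one sequence's hidden states `hs` as a `[token, hidden]` array plus its `attention_mask` and `labels` from the batch, and assuming (as in TRL's DPO preprocessing) that prompt and padding positions carry `label_pad_token_id` (-100) in `labels`; `segment_magnitudes` is a hypothetical helper.

```python
import numpy as np

def segment_magnitudes(hs, attention_mask, labels, label_pad_token_id=-100):
    """Mean |activation| over the prompt, response and padding tokens of one sequence.

    hs: [token, hidden]; attention_mask, labels: [token].
    """
    mag = np.abs(hs).mean(axis=-1)  # [token]
    padding = attention_mask == 0
    response = labels != label_pad_token_id
    prompt = ~padding & ~response
    segments = {"prompt": prompt, "response": response, "padding": padding}
    return {name: float(mag[m].mean()) if m.any() else float("nan") for name, m in segments.items()}

# toy sequence: 2 prompt tokens, 2 response tokens, 2 padding tokens
hs = np.array([[1.0, -1.0]] * 2 + [[2.0, -2.0]] * 2 + [[0.0, 0.0]] * 2)
attention_mask = np.array([1, 1, 1, 1, 0, 0])
labels = np.array([-100, -100, 5, 6, -100, -100])
print(segment_magnitudes(hs, attention_mask, labels))  # {'prompt': 1.0, 'response': 2.0, 'padding': 0.0}
```

On the notebook's data this would be called per sample and layer, e.g. with `policy_chosen_hs[j, l]`, `batch['chosen_attention_mask'][j]` and `batch['chosen_labels'][j]` converted to CPU numpy arrays (key names assumed from the TRL DPO batch format).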

In [None]:
print('look at a and b side by side (symlog, max-normalised)')
chosen_layer_attn_mask = chosen_attn_mask.unsqueeze(1).repeat(1, policy_chosen_hs.shape[1], 1).unsqueeze(-1).cpu()
rejected_layer_attn_mask = rejected_attn_mask.unsqueeze(1).repeat(1, policy_rejected_hs.shape[1], 1).unsqueeze(-1).cpu()
a = policy_chosen_hs.cpu().detach() * chosen_layer_attn_mask
b = policy_rejected_hs.cpu().detach() * rejected_layer_attn_mask
b = reference_chosen_hs.cpu().detach() * chosen_layer_attn_mask



im = ((a)).numpy()
imb = ((b)).numpy()
print('all')
stats(im)
stats(imb)



j = 1

n_layers = im.shape[1]
fig = plt.figure(figsize=(12, 6))
for l in range(n_layers):
    ax = plt.subplot(n_layers, 2, 2*l+1)
    print(l)
    stats(im[:, l])
    im2 = (scale(im[j, l]))
    # im2 = (255*im2).astype(np.uint8)
    c = imshow(im2)
    axis_off()
    plt.ylabel(f' l={training_args.collection_layers[l]}')

    stats(imb[:, l])
    ax=  plt.subplot(n_layers, 2, 2*l+2)
    im2 = (scale(imb[j, l]))
    c = imshow(im2)
    axis_off()
plt.xlabel('scale(a)')
# tight layout
plt.tight_layout()
plt.subplots_adjust(left=0.01, right=0.99, top=0.99, bottom=0.01)

plt.show()

## View a-b and a-c
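
Under the rotation hypothesis, chosen and rejected hidden states at the same position should differ mainly in angle while keeping similar norms. A per-layer summary of the masked cosine similarity and log norm ratio gives a quick numerical check alongside the images. A minimal sketch, assuming `[batch, layer, token, hidden]` tensors like `policy_chosen_hs` / `policy_rejected_hs` and a `[batch, token]` mask that is 1 where both tokens are real; `rotation_vs_scale_stats` is a hypothetical helper. Chosen and rejected are different strings, so positions only correspond loosely after the shared prompt.

```python
import torch
import torch.nn.functional as F

def rotation_vs_scale_stats(a, b, mask, eps=1e-12):
    """Per-layer masked means of cos(a, b) and log(|a| / |b|).

    a, b: [batch, layer, token, hidden]; mask: [batch, token].
    A pure rotation gives log_norm_ratio ~ 0; a pure rescaling gives cos ~ 1.
    """
    m = mask[:, None, :].to(a.dtype)  # [batch, 1, token], broadcast over layers
    cos = F.cosine_similarity(a, b, dim=-1)  # [batch, layer, token]
    log_ratio = torch.log(a.norm(dim=-1).clamp_min(eps)) - torch.log(b.norm(dim=-1).clamp_min(eps))
    denom = m.sum(dim=(0, 2)).clamp_min(1)  # number of real tokens
    return {
        "cos": (cos * m).sum(dim=(0, 2)) / denom,
        "log_norm_ratio": (log_ratio * m).sum(dim=(0, 2)) / denom,
    }

torch.manual_seed(0)
x_hs = torch.randn(2, 3, 4, 8)
Q, _ = torch.linalg.qr(torch.randn(8, 8))
mask = torch.ones(2, 4)
print(rotation_vs_scale_stats(x_hs, x_hs @ Q, mask))  # log_norm_ratio ~ 0 for each layer
```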


In [None]:
print('look at a-b and a-c side by side (symlog, max-normalised)')
chosen_layer_attn_mask = chosen_attn_mask.unsqueeze(1).repeat(1, policy_chosen_hs.shape[1], 1).unsqueeze(-1).cpu()
rejected_layer_attn_mask = rejected_attn_mask.unsqueeze(1).repeat(1, policy_rejected_hs.shape[1], 1).unsqueeze(-1).cpu()
a = policy_chosen_hs.cpu().detach() * chosen_layer_attn_mask
b = policy_rejected_hs.cpu().detach() * rejected_layer_attn_mask
c = reference_chosen_hs.cpu().detach() * chosen_layer_attn_mask


im = ((a-b)).numpy()
imb = ((a-c)).numpy()



j = 3

n_layers = im.shape[1]
fig = plt.figure(figsize=(12, 6))
for l in range(n_layers):

    ax = plt.subplot(n_layers, 2, 2*l+1)
    if l==0:
        plt.title('a-b')
    elif l==n_layers-1:
        plt.xlabel('symlog(policy_chosen_hs-policy_rejected_hs)')

    print(l)
    stats(im[:, l])
    im2 = (scale(im[j, l]))
    # im2 = (255*im2).astype(np.uint8)
    c = imshow(im2)
    axis_off()
    plt.ylabel(f' l={training_args.collection_layers[l]}')

    ax=  plt.subplot(n_layers, 2, 2*l+2)
    stats(imb[:, l])
    if l==0:
        plt.title('a-c')
    elif l==n_layers-1:
        plt.xlabel('symlog(policy_chosen_hs-reference_chosen_hs)')
    im2 = (scale(imb[j, l]))
    c = imshow(im2)
    axis_off()

# tight layout
plt.tight_layout()
plt.subplots_adjust(left=0.01, right=0.99, top=0.99, bottom=0.01)

plt.show()

In [None]:
print('look at a-b (symlog, max-normalised)')
chosen_layer_attn_mask = chosen_attn_mask.unsqueeze(1).repeat(1, policy_chosen_hs.shape[1], 1).unsqueeze(-1).cpu()
rejected_layer_attn_mask = rejected_attn_mask.unsqueeze(1).repeat(1, policy_rejected_hs.shape[1], 1).unsqueeze(-1).cpu()
a = policy_chosen_hs.cpu().detach() * chosen_layer_attn_mask
b = policy_rejected_hs.cpu().detach() * rejected_layer_attn_mask
c = reference_chosen_hs.cpu().detach() * chosen_layer_attn_mask


im = ((a-b)).numpy()
# imb = ((a-c)).numpy()



j = 3

n_layers = im.shape[1]
fig = plt.figure(figsize=(6, 6))
for l in range(n_layers):

    ax = plt.subplot(n_layers, 1, l+1)
    if l==0:
        plt.title('a-b')
    elif l==n_layers-1:
        plt.xlabel('symlog(policy_chosen_hs-policy_rejected_hs)')

    print(l)
    stats(im[:, l])
    im2 = (scale(im[j, l]))
    # im2 = (255*im2).astype(np.uint8)
    c = imshow(im2)
    axis_off()
    plt.ylabel(f' l={training_args.collection_layers[l]}')


# tight layout
plt.tight_layout()
plt.subplots_adjust(left=0.01, right=0.99, top=0.99, bottom=0.01)

plt.show()

In [None]:
def norm(im, eps=1e-7):
    """L2-normalise each token's hidden states, concatenated across layers."""
    l, h = im.shape[1], im.shape[-1]
    im = rearrange(im, 'b l t h -> b t (l h)')
    im = im / (torch.norm(im, dim=-1, keepdim=True) + eps)
    return rearrange(im, 'b t (l h) -> b l t h', l=l, h=h)

In [None]:
# try norm

print('look at norm(a-b) (symlog, max-normalised)')
chosen_layer_attn_mask = chosen_attn_mask.unsqueeze(1).repeat(1, policy_chosen_hs.shape[1], 1).unsqueeze(-1).cpu()
rejected_layer_attn_mask = rejected_attn_mask.unsqueeze(1).repeat(1, policy_rejected_hs.shape[1], 1).unsqueeze(-1).cpu()
a = policy_chosen_hs.cpu().detach() * chosen_layer_attn_mask
b = policy_rejected_hs.cpu().detach() * rejected_layer_attn_mask
c = reference_chosen_hs.cpu().detach() * chosen_layer_attn_mask

# a = a / torch.norm(a, dim=-1, keepdim=True)
# b = b / torch.norm(b, dim=-1, keepdim=True)
im = norm((a-b))

im = im.numpy()

# imb = ((a-c)).numpy()



j = 3

n_layers = im.shape[1]
fig = plt.figure(figsize=(6, 6))
for l in range(n_layers):

    ax = plt.subplot(n_layers, 1, l+1)
    if l==0:
        plt.title('a-b')
    elif l==n_layers-1:
        plt.xlabel('symlog(norm(policy_chosen_hs-policy_rejected_hs))')

    print(l)
    stats(im[:, l])
    im2 = (scale(im[j, l]))
    # im2 = (255*im2).astype(np.uint8)
    c = imshow(im2)
    axis_off()
    plt.ylabel(f' l={training_args.collection_layers[l]}')


# tight layout
plt.tight_layout()
plt.subplots_adjust(left=0.01, right=0.99, top=0.99, bottom=0.01)

plt.show()

In [None]:
im = (a-b)
im = norm(im)
im = scale(im)
imshow(im[j, l])
plt.title('scale(norm(a-b))')
plt.show()
plt.hist(im.flatten(), bins=55, alpha=0.5, histtype='step', range=(-1, 1), log=True)
plt.show()

im = (a-b)
im = scale(im)
# im = norm(im)
imshow(im[j, l])
plt.title('scale(a-b)')
plt.show()
plt.hist(im.flatten(), bins=55, alpha=0.5, histtype='step', range=(-1, 1), log=True)
plt.show()

im = (norm(a)-norm(b))
im = scale(im)
# im = norm(im)
imshow(im[j, l])
plt.title('scale(norm(a)-norm(b))')
plt.show()
plt.hist(im.flatten(), bins=55, alpha=0.5, histtype='step', range=(-1, 1), log=True)
plt.show()

im = a-b
# im = scale(im)
im = norm(im)
imshow(im[j, l])
plt.title('norm(a-b)')
plt.show()
plt.hist(im.flatten(), bins=55, alpha=0.5, histtype='step', range=(-1, 1), log=True)
plt.show()

plt.title('(symlog(a)-symlog(b)) / max|.|')
im = symlog(a)-symlog(b)
im /= np.nanmax(np.abs(im))
imshow(im[j, l])
plt.show()

plt.hist(im.flatten(), bins=55, alpha=0.5, histtype='step', range=(-1, 1))
plt.show()

In [None]:


im = a-b
# im = norm(im)
# im = scale(im)
im = norm_h(im)
# im /= np.nanmax(np.abs(im), axis=-2, keepdims=True)
imshow(im[j, l])
plt.show()

plt.hist(im.flatten(), bins=55, alpha=0.5, histtype='step', range=(-1, 1))
plt.show()

In [None]:
im = symlog(a)-symlog(b)
# im = norm(im)
# im = scale(im)
im = norm_h(im)
# im /= np.nanmax(np.abs(im), axis=-2, keepdims=True)
imshow(im[j, l])
plt.show()

plt.hist(im.flatten(), bins=55, alpha=0.5, histtype='step', range=(-1, 1))
plt.show()

In [None]:
im = symlog(norm_h(a))-symlog(norm_h(b))
# im = norm(im)
# im = scale(im)
im = norm_h(im)
# im /= np.nanmax(np.abs(im), axis=-2, keepdims=True)
imshow(im[j, l])
plt.show()

plt.hist(im.flatten(), bins=55, alpha=0.5, histtype='step', range=(-1, 1))
plt.show()

In [None]:
im = symlog(a)-symlog(b)
im = norm_h(im)
imshow(im[j, l])
plt.show()

plt.hist(im.flatten(), bins=55, alpha=0.5, histtype='step', range=(-1, 1))
plt.show()

In [None]:
# # x = np.abs(im[j, l]).mean(-1)
# # plt.plot(x)
# # x[x>0].argmin()
# plt.hist(im2.flatten(), bins=55, alpha=0.5, histtype='step')
# plt.show()

In [None]:
# per-token view for sample j, layer l: mean |im| over neurons, next to the chosen/rejected tokens
# (new names, so the hidden-state tensors a and b are not overwritten)
bs = batch_concat['concatenated_input_ids'].shape[0]//2
ids_cho = batch_concat['concatenated_input_ids'][:bs]
ids_rej = batch_concat['concatenated_input_ids'][bs:]
tok_cho = tokenizer.batch_decode(ids_cho[j])  # one string per token id
tok_rej = tokenizer.batch_decode(ids_rej[j])
x = np.abs(im[j, l]).mean(-1)
r = list(zip(x.tolist(), tok_cho, tok_rej))

import pandas as pd
df_tok = pd.DataFrame(r, columns=['x', 'tok_cho', 'tok_rej'])
# df_tok[df_tok.x > 0].sort_values('x')  # sorted view
df_tok[df_tok.x > 0]


In [None]:
batch_concat.keys()

In [None]:
plt.figure(figsize=(10, 8))
imshow(im2)
plt.colorbar(location='top')

# Evals