Experiment: use LoRA to make a lying model. Here we think of LoRA as a probe, since it acts in a very similar way: it modifies the residual stream.

The hope is that it will then assist with lie detection and generalize to unseen datasets.

- https://github.dev/JD-P/minihf/blob/b54075c34ef88d9550e37fdf709e78e5a68787c4/lora_tune.py
- https://github.com/jonkrohn/NLP-with-LLMs

In [1]:
import os
# os.environ["CUDA_VISIBLE_DEVICES"]= "1"


In [2]:
import os
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from tqdm.auto import tqdm

plt.style.use("ggplot")

from typing import Optional, List, Dict, Union
from jaxtyping import Float
from torch import Tensor

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import Tensor
from torch import optim
from torch.utils.data import random_split, DataLoader, TensorDataset

from pathlib import Path
from einops import rearrange

import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, AutoConfig
from peft import get_peft_config, get_peft_model, LoraConfig, TaskType, LoftQConfig, IA3Config

import datasets
from datasets import Dataset

from loguru import logger

logger.add(os.sys.stderr, format="{time} {level} {message}", level="INFO")


# # quiet please
torch.set_float32_matmul_precision("medium")
import warnings
warnings.filterwarnings("ignore", ".*does not have many workers.*")
# warnings.filterwarnings(
#     "ignore", ".*sampler has shuffling enabled, it is strongly recommended that.*"
# )
# warnings.filterwarnings("ignore", ".*has been removed as a dependency of.*")


  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# load my code
%load_ext autoreload
%autoreload 2

import lightning.pytorch as pl
from src.datasets.dm import DeceptionDataModule
from src.models.pl_lora_ft import AtapterFinetuner

from src.config import ExtractConfig
from src.prompts.prompt_loading import load_preproc_dataset
from src.models.load import load_model


## Parameters


In [None]:
# params
max_epochs = 1  # passed to pl.Trainer(max_epochs=...) and used to compute total_steps
device = "cuda:0"  # device string handed to load_model

# NOTE(review): the meaning of the two `max_examples` entries is defined by
# ExtractConfig (not shown here); in this notebook only their sum is used (as N).
cfg = ExtractConfig(
    batch_size=3,
    max_examples=(800, 60),
)


## Load model

In [None]:
# Load the model and tokenizer selected by cfg.model, placed on `device`.
model, tokenizer = load_model(
    cfg.model,
    device=device,
)


In [None]:
# TODO I would like to only have biases, but for now lets just try a very small intervention on the last parts of a layer...
# Low-rank (r=4) adapters with no dropout, attached only to the modules whose
# names match the entries in `target_modules`.
peft_config = LoraConfig(
    target_modules=[
        "out_proj",
        "mlp.fc2",
    ],  # only the layers that go directly to the residual
    # bias="lora_only",
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=4,
    lora_alpha=1,
    lora_dropout=0.0,
)


# peft_config = IA3Config(
#     task_type=TaskType.SEQ_CLS, target_modules=[ "out_proj",
#         "mlp.fc2",], feedforward_modules=["out_proj", "mlp.fc2",]
# )
# Wrap the base model with the adapter; `model` now refers to the PEFT-wrapped model.
model = get_peft_model(model, peft_config)
# Sanity check: report how many parameters are trainable after wrapping.
model.print_trainable_parameters()


In [None]:
# Build the preprocessed, tokenized training dataset and expose it as torch tensors.
N = sum(cfg.max_examples)  # total number of examples requested
ds_name = "amazon_polarity"
ds_tokens = load_preproc_dataset(
    ds_name,
    tokenizer,
    N=N,
    seed=cfg.seed,
    num_shots=cfg.num_shots,
    max_length=cfg.max_length,
    prompt_format=cfg.prompt_format,
).with_format("torch")


## Train

In [None]:
# Wrap the tokenized dataset in the project's data module (provides the dataloaders used below).
dm = DeceptionDataModule(ds_tokens, batch_size=cfg.batch_size)
dm


In [None]:
# Train / validation dataloaders used for fitting.
dl_train = dm.train_dataloader()
dl_val = dm.val_dataloader()


In [None]:
# Pull a single training batch to inspect its keys and the input_ids shape.
b = next(iter(dl_train))
print(b.keys(), b["input_ids"].shape)
c_in = b["input_ids"].shape[1]  # size of dim 1 of input_ids (presumably the sequence length — TODO confirm)
c_in


In [None]:
# Lightning module wrapping the adapter model; total_steps = batches per epoch * epochs.
net = AtapterFinetuner(
    model, tokenizer, lr=5e-5, weight_decay=0, total_steps=len(dl_train) * max_epochs
)

print(c_in)
# net.model.enable_adapters()


In [None]:
# debug
# Smoke-test one training step on the sample batch without tracking gradients.
with torch.no_grad():
    o = net.training_step(b, None)
o


In [None]:
# debug
# Smoke-test one predict step on the sample batch without tracking gradients.
with torch.no_grad():
    o = net.predict_step(b, None)
# o


In [None]:
# we want to init lightning early, so it inits accelerate
# Trainer: gradient clipping at 20, metrics logged every step, model summary disabled.
trainer1 = pl.Trainer(
    # precision="16-true", # works?
    # precision="16-mixed",
    # precision="b16-mixed",
    gradient_clip_val=20,
    # accelerator="auto",
    # devices="1",
    # accelerator="gpu",
    # devices=[0],
    # accumulate_grad_batches=2,
    max_epochs=max_epochs,
    log_every_n_steps=1,
    enable_model_summary=False,
)


In [None]:
# Fit the adapter on the training loader, validating on dl_val (trailing `;` suppresses cell output).
trainer1.fit(model=net, train_dataloaders=dl_train, val_dataloaders=dl_val);


In [None]:
# Save the trained PEFT model under <trainer log dir>/final so it can be reloaded below.
checkpoint_path = Path(trainer1.log_dir)/'final'
model.save_pretrained(checkpoint_path)


In [None]:
from src.helpers.lightning import read_metrics_csv

# Show the logged metrics; bfill/ffill fills the NaN gaps left when different
# metrics are logged on different rows.
pd.read_csv(trainer1.logger.experiment.metrics_file_path).bfill().ffill()


## Generate


In [None]:
# get a row
bi = 4  # index of the example used for the generation demo below
inputs = ds_tokens.with_format("torch")[bi]  # read as a global by gen()


In [None]:
from IPython.display import display, HTML

# generate
# https://huggingface.co/docs/transformers/v4.34.1/en/main_classes/text_generation#transformers.GenerationConfig


@torch.no_grad()
def gen(model):
    """Greedily generate exactly 100 new tokens for the global `inputs` row and display them.

    The prompt is rendered in a plain <pre> block and the continuation in a bold
    <pre> block. Relies on the notebook globals `inputs` (one tokenized row),
    `tokenizer`, `display` and `HTML`.

    Args:
        model: a model exposing `.generate(...)` and `.device`.

    Returns:
        None. The result is shown via IPython `display`.
    """
    # Local import so this cell does not depend on an extra top-level import.
    from html import escape

    device = model.device
    s = model.generate(
        # [None, :] adds the batch dimension expected by generate
        input_ids=inputs["input_ids"][None, :].to(device),
        attention_mask=inputs["attention_mask"][None, :].to(device),
        use_cache=False,
        max_new_tokens=100,
        min_new_tokens=100,
        do_sample=False,
        early_stopping=False,
    )
    # Everything before `input_l` is the prompt; everything after is newly generated.
    input_l = inputs["input_ids"].shape[0]
    old = tokenizer.decode(
        s[0, :input_l], clean_up_tokenization_spaces=False, skip_special_tokens=False
    )
    new = tokenizer.decode(
        s[0, input_l:], clean_up_tokenization_spaces=False, skip_special_tokens=False
    )
    # Escape the decoded text: prompts/generations may contain characters such as
    # "<" or "&" (e.g. "<br />" in reviews) that would otherwise be interpreted as
    # markup and corrupt or hide part of the displayed text.
    display(HTML(f"<pre>{escape(old)}</pre><b><pre>{escape(new)}</pre></b>"))


In [None]:
# for some reason the trainer adds accelerate hooks that mess it up, lets load from scratch
# Reload the base model together with the adapter saved at `checkpoint_path`.
model, tokenizer = load_model(
    cfg.model,
    device=device,
    adaptor_path=checkpoint_path,
    # bnb=False
)
# Rebuild the Lightning wrapper around the freshly loaded model.
net = AtapterFinetuner(
    model, tokenizer, lr=5e-5, weight_decay=0, total_steps=len(dl_train) * max_epochs
)


In [None]:
# with model.disable_adapters():
# First generation: adapter disabled inside the context manager (base-model behaviour).
with model.disable_adapter():
    gen(model)

# Second generation: adapter active, for side-by-side comparison.
gen(model)


# Test

In [None]:
# Second dataset ("imdb"), loaded with the same preprocessing settings but a
# quarter of the examples; used below as the out-of-sample (oos) set.
N = sum(cfg.max_examples)
ds_name = "imdb"
ds_tokens2 = load_preproc_dataset(
    ds_name,
    tokenizer,
    N=N // 4,
    seed=cfg.seed,
    num_shots=cfg.num_shots,
    max_length=cfg.max_length,
    prompt_format=cfg.prompt_format,
).with_format("torch")


In [None]:
# Dataloaders over the ORIGINAL training dataset (ds_tokens) with a doubled batch
# size; the second dataset (ds_tokens2) gets its own loader in the next cell.
# NOTE: this rebinds `dm`, replacing the data module created earlier.
dm = DeceptionDataModule(ds_tokens, batch_size=cfg.batch_size * 2)
dl_train2 = dm.train_dataloader()
dl_val2 = dm.val_dataloader()
dl_test2 = dm.test_dataloader()


In [None]:
# Out-of-sample loader over ds_tokens2: fixed order, no batch dropped, so the
# predictions cover every row in dataset order.
dl_oos2 = DataLoader(
    ds_tokens2, batch_size=cfg.batch_size * 2, drop_last=False, shuffle=False
)
len(ds_tokens2)


In [None]:
# rs = trainer1.test(
#     net,
#     dataloaders=[
#         # dl_train2, dl_val2,
#         dl_test2,
#         dl_oos2,
#     ],
# )
# rs = rename(rs, ["train", "val", "test", "oos"])
# rs[0]


# Predict

Here we want to see whether a probe on the hidden states can tell if the model is lying.


Next:
- see how accurate each was when scored against the instructions vs. against the truth
- see how well a linear probe trained on the diff predicts truth, compared to the baseline

In [None]:

import sklearn
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, roc_auc_score
from sklearn.linear_model import LogisticRegression
from einops import rearrange
from sklearn.preprocessing import StandardScaler

def check_intervention_predictive(hs, y):
    """
    We want the hidden states resulting from interventions to have predictive power
    Lets compare normal hidden states to intervened hidden states

    Fits a class-balanced logistic regression on the first half of the rows and
    scores it on the second half.

    Args:
        hs: 3-d array/tensor of hidden states, flattened here as 'b l hs -> b (l hs)'.
        y: binary labels, one per row of `hs`.

    Returns:
        ROC AUC of the probe on the held-out second half.

    Raises:
        ValueError: propagated from sklearn, e.g. when a split contains a single class.
    """
    X = rearrange(hs, 'b l hs -> b (l hs)')
    # Simple ordered split: first half trains the probe, second half validates it.
    n_train = len(X) // 2
    X_train, X_val = X[:n_train], X[n_train:]
    y_train, y_val = y[:n_train], y[n_train:]

    # Fit the scaler on the training half only, to avoid leaking validation statistics.
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_val = scaler.transform(X_val)

    clf = LogisticRegression(random_state=42, max_iter=1000, class_weight='balanced',).fit(X_train, y_train)
    # ROC AUC is a ranking metric: feed it continuous decision scores. Hard 0/1
    # predictions collapse the ROC curve to a single operating point and
    # discard the probe's confidence.
    y_val_score = clf.decision_function(X_val)
    score = roc_auc_score(y_val, y_val_score)
    return score


def test_intervention_quality2(ds_out, label_fn, thresh=0.03, take_diff=False):
    """
    Check interventions are ordered and different and valid

    Prints probe ROC AUC for baseline vs intervened hidden states in several
    variants, then displays a small table of sanity metrics.

    Args:
        ds_out: dataset of model outputs; reads the columns
            'end_hidden_states_{base,adapt}', 'end_logits_{base,adapt}',
            'choice_probs_{base,adapt}' and 'binary_ans_{base,adapt}'.
        label_fn: callable mapping `ds_out` to the binary labels for the probe.
        thresh: minimum ROC AUC gain over baseline to count as "predictive".
        take_diff: if True, replace both hidden-state tensors by `.diff(1)` first.

    Returns:
        None. Results are printed / displayed.

    Relies on the notebook globals `tokenizer` and `display`.

    TODO better metrics
    - primary metric: **predictive** or a linear classifier on top of intervention hidden states can predict my labels
    - debug metric: **significant** it's not just a small change
    - debug metric: **coherent** 
        - it's not just outputting nonsense, 
        - or just "yes", the choices keep coverage, 
        - it's not over confident
    """

    # collect labels    
    label = label_fn(ds_out)

    # collect hidden states
    hs_normal = ds_out['end_hidden_states_base']
    hs_intervene = ds_out['end_hidden_states_adapt']
    if take_diff:
        print("taking diff")
        # NOTE(review): Tensor.diff(1) differences along the LAST dim by default;
        # if the intent is to difference across layers, pass dim explicitly — TODO confirm.
        # NOTE(review): with take_diff=True the "[diff]" variants below diff a second time.
        hs_normal = hs_normal.diff(1)
        hs_intervene = hs_intervene.diff(1)

    print("primary metric: predictive power (of logistic regression on top of intervened hidden states)")
    s1_baseline = check_intervention_predictive(hs_normal, label)
    s1_interven = check_intervention_predictive(hs_intervene, label)
    predictive = s1_interven - s1_baseline > thresh
    print(f"predictive power? {predictive} [i] = baseline: {s1_baseline:.3f} > {s1_interven:.3f} roc_auc")
    s1_interven = check_intervention_predictive(hs_intervene-hs_normal, label)
    predictive = s1_interven - s1_baseline > thresh
    print(f"predictive power? {predictive} [i-b] = baseline: {s1_baseline:.3f} > {s1_interven:.3f} roc_auc")

    s1_baseline = check_intervention_predictive(hs_normal.diff(1), label)
    s1_interven = check_intervention_predictive(hs_intervene.diff(1), label)
    predictive = s1_interven - s1_baseline > thresh
    print(f"predictive power? {predictive} [diff]  = baseline: {s1_baseline:.3f} > {s1_interven:.3f} roc_auc")
    s1_interven = check_intervention_predictive((hs_intervene-hs_normal).diff(1), label)
    predictive = s1_interven - s1_baseline > thresh
    print(f"predictive power? {predictive} [diff(i-b)] = baseline: {s1_baseline:.3f} > {s1_interven:.3f} roc_auc")

    # also check coverage
    # also check reasonable probs (e.g choices not too high, others not too low)
    # also check the probs actually makes a differen't to ans
    # We would hope that an unrelated tokens would have it's probability mostly uneffected
    # NOTE(review): for tokenizers that prepend a BOS token, encode('\n')[0] would be
    # the BOS id rather than the newline id — verify for the tokenizer in use.
    id_unrelated = tokenizer.encode('\n')[0]
    # Softmax must normalise over the vocabulary axis (the axis indexed by token id
    # just below), not over the batch axis; softmax over dim 0 would make the batch
    # mean a constant 1/batch_size regardless of the logits.
    unrelated_probs_base = torch.softmax(ds_out['end_logits_base'], -1)[:, id_unrelated].mean(0).item()
    unrelated_probs_adapt = torch.softmax(ds_out['end_logits_adapt'], -1)[:, id_unrelated].mean(0).item()
    # Every list is ordered [base, adapt] to line up with index=['baseline', 'intervene'].
    df_metrics = pd.DataFrame({
        'coverage': [
            ds_out['choice_probs_base'].mean(0).sum(0).item(),
            ds_out['choice_probs_adapt'].mean(0).sum(0).item(),
        ],
        'ans': [
            ds_out['binary_ans_base'].mean(0).item(),
            ds_out['binary_ans_adapt'].mean(0).item()
            ],
        'unrelated_probs': [unrelated_probs_base, unrelated_probs_adapt],
    }, index=['baseline', 'intervene']).T
    display(df_metrics)



In [None]:
def ds2label_model_obey(ds):
    """Label each row by whether the base model's answer follows its instructions.

    The instructed answer is `label_true_base` XOR `instructed_to_lie_base`; the
    model's answer is `binary_ans_base` thresholded at 0.5. Returns an elementwise
    boolean that is True where the two agree.
    """
    model_answer = ds["binary_ans_base"] > 0.5
    instructed_answer = ds["label_true_base"] ^ ds["instructed_to_lie_base"]
    return instructed_answer == model_answer


def ds2label_model_truth(ds):
    """Label each row by whether the base model's answer matches the true label.

    The model's answer is `binary_ans_base` thresholded at 0.5; the result is an
    elementwise boolean that is True where it equals `label_true_base`.
    """
    model_answer = ds["binary_ans_base"] > 0.5
    return ds["label_true_base"] == model_answer


In [None]:
accelerator = trainer1.accelerator
# Reload the base model plus the saved adapter from scratch before predicting.
model, tokenizer = load_model(
    cfg.model,
    device=device,
    adaptor_path=checkpoint_path,
)
net = AtapterFinetuner(
    model, tokenizer, lr=5e-5, weight_decay=0, total_steps=len(dl_train) * max_epochs
)

# Run prediction over the out-of-sample loader; `rv` is a list with one dict per batch.
rv = trainer1.predict(net, dataloaders=dl_oos2)
# convert from List[Dict[str, Tensor]] to Dict[str, Tensor] by concatenating the batches
ds_out = Dataset.from_dict({k: torch.concat([rr[k] for rr in rv]) for k in rv[0].keys()}).with_format("torch")
ds_out


In [None]:
# Run the probe quality check for every labelling scheme, first on the raw hidden
# states and then on diff(hidden states).
label_fns = dict(label_model_truth=ds2label_model_truth, label_model_obey=ds2label_model_obey)
for take_diff, hs_desc in ((False, 'hidden states'), (True, 'diff(hidden states)')):
    for label_name, label_fn in label_fns.items():
        # fit probe
        print('='*80)
        print('making intervention with', label_name, hs_desc)
        test_intervention_quality2(ds_out, label_fn, take_diff=take_diff)


In [None]:
def filter_ds_to_known(ds1, verbose=True):
    """Keep only the rows whose example the model answered correctly under the "truth" instruction.

    An example counts as "known" when, among rows with sys_instr_name == "truth",
    `llm_ans` equals `label_true`. Every row sharing that `example_i` is kept.

    NOTE(review): reads the columns 'sys_instr_name', 'example_i', 'llm_ans' and
    'label_true' from ds2df(ds1) — verify the dataset passed in provides them.
    """
    frame = ds2df(ds1)

    # rows asked to tell the truth, keyed by example id
    truthful = frame.query('sys_instr_name=="truth"').set_index("example_i")
    answered_right = truthful.llm_ans == truthful.label_true
    known_examples = truthful[answered_right].index

    # positions of every row belonging to a known example
    keep_mask = frame['example_i'].isin(known_examples)
    keep_positions = frame[keep_mask].index

    if verbose:
        print(f"select rows are {answered_right.mean():2.2%} based on knowledge")
    return ds1.select(keep_positions)


In [None]:
def rows_item(row):
    """
    Unwrap single-element containers in a row dict, in place.

    0-d arrays and 1-d arrays of length one become Python scalars (via `.item()`);
    lists of length one are replaced by their only element. Everything else is
    left untouched. The same (mutated) dict is returned.
    """
    for key in list(row):
        value = row[key]
        if isinstance(value, np.ndarray):
            is_scalar_like = value.ndim == 0 or (value.ndim == 1 and len(value) == 1)
            if is_scalar_like:
                row[key] = value.item()
        elif isinstance(value, list) and len(value) == 1:
            row[key] = value[0]
    return row


def ds2df(ds, cols=None):
    """one of our custom datasets into a dataframe
    
    dropping the large arrays and lists

    Args:
        ds: dataset exposing `with_format`, `select_columns`, indexing and iteration.
        cols: optional explicit list of columns to keep. When None, the columns are
            chosen by inspecting the first row: anything that is not a list/array,
            or an array with fewer than two elements, is kept.

    Returns:
        pandas.DataFrame with one row per dataset row and single-element
        containers unwrapped by `rows_item`.
    """
    
    # json.loads(dss[0].info.description)['f'] # doesn't work when concat

    # Switch to numpy BEFORE inspecting the first row: with another output format
    # (e.g. "torch") the values are not np.ndarray, so the size check below would
    # let every large tensor column through.
    ds = ds.with_format('numpy')
    if cols is None:
        r = ds[0]
        # get all the columns that not large lists or arrays
        cols = [k for k,v in r.items() if (isinstance(v, np.ndarray) and v.size<2) or not isinstance(v, (list, np.ndarray))]
    df = ds.select_columns(cols)
    df = pd.DataFrame([rows_item(r) for r in df])
    return df

def qc_ds(ds):
    """Quality-control checks on a dataset of model outputs.

    Prints, and asserts minimum values for:
    - task accuracy on rows where the model was not instructed to lie,
    - how often the model follows the instruction on rows where it was told to lie,
    - the same restricted to examples the model "knows" (see filter_ds_to_known),
    - the mean probability mass covered by the answer choices.

    Column names may carry a '_base' suffix; it is stripped before the checks.

    Raises:
        AssertionError: when any check falls below its threshold.
    """

    def _to_frame(dataset):
        # Same preparation for every frame: strip the '_base' suffix, derive the
        # answer the model was instructed to give, and threshold the model's answer
        # the same way ds2label_model_obey / ds2label_model_truth do.
        frame = ds2df(dataset.with_format('numpy')).rename(columns=lambda x: x.replace('_base', ''))
        frame['label_instructed'] = frame['label_true'] ^ frame['instructed_to_lie']
        return frame

    def _obey_rate(frame):
        # fraction of rows where the thresholded answer equals the instructed answer
        return (frame.label_instructed == (frame.binary_ans > 0.5)).mean()

    df = _to_frame(ds)

    # check llm accuracy
    d = df.query('instructed_to_lie==False')
    acc = _obey_rate(d)
    assert np.isfinite(acc)
    print(f"\tacc    =\t{acc:2.2%} [N={len(d)}] - when the model is not lying... we get this task acc")
    assert acc>0.3, "model cannot solve task"

    # check LLM lie freq
    d = df.query('instructed_to_lie==True')
    acc = _obey_rate(d)
    assert np.isfinite(acc)
    print(f"\tlie_acc=\t{acc:2.2%} [N={len(d)}] - when the model tries to lie... we get this acc")
    assert acc>0.01, "no known lies"

    # check LLM lie freq, restricted to examples the model knows the answer to
    ds_known = filter_ds_to_known(ds, verbose=False)
    # prepared exactly like `df`, so `instructed_to_lie` / `label_instructed` exist here too
    df_known = _to_frame(ds_known)
    d = df_known.query('instructed_to_lie==True')
    acc = _obey_rate(d)
    assert np.isfinite(acc)
    print(f"\tknown_lie_acc=\t{acc:2.2%} [N={len(d)}] - when the model tries to lie and knows the answer... we get this acc")
    assert acc>0.01, "no known lies"

    # check choice coverage
    # Accept either the plain column name or the '_base'-suffixed one.
    probs_key = 'choice_probs' if 'choice_probs' in ds.column_names else 'choice_probs_base'
    mean_prob = np.sum(ds.with_format('numpy')[probs_key], 1).mean()
    print(f"\tchoice_cov=\t{mean_prob:2.2%} - Our choices accounted for a mean probability of this")
    assert mean_prob>0.1, "neither of the available choice very likely :(, try debuging your templates. Check: using the correct prompt, the whitespace is correct, the correct eos_tokens (if any)"


In [None]:
# Run the quality-control checks on the prediction outputs (raises AssertionError on failure).
qc_ds(ds_out)


In [None]:
# The adapter should shift the answers: require a mean absolute change above 0.1
# between the base and adapted `binary_ans` values.
assert (ds_out['binary_ans_base']-ds_out['binary_ans_adapt']).abs().mean()>0.1, 'should be a larger diff'
