Experiment: use LoRA to make a lying model. Here we think of LoRA as a probe, since it acts in a very similar way — by modifying the residual stream.

The hope is that it will then assist with lie detection and generalize to unseen datasets.

- https://github.dev/JD-P/minihf/blob/b54075c34ef88d9550e37fdf709e78e5a68787c4/lora_tune.py
- https://github.com/jonkrohn/NLP-with-LLMs

In [1]:
import os
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from tqdm.auto import tqdm

plt.style.use("ggplot")

from typing import Optional, List, Dict, Union
from jaxtyping import Float
from torch import Tensor

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import Tensor
from torch import optim
from torch.utils.data import random_split, DataLoader, TensorDataset

from pathlib import Path
from einops import rearrange

import transformers
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    AutoConfig,
)
from peft import (
    get_peft_config,
    get_peft_model,
    LoraConfig,
    TaskType,
    LoftQConfig,
    IA3Config,
)

import datasets
from datasets import Dataset

from loguru import logger

import sys
import warnings

# loguru installs a default stderr sink when imported. Remove every existing sink
# first, so each record is emitted once with our format. Without this, every log
# line appears twice (once from the default sink, once from ours), and re-running
# this cell keeps stacking extra sinks.
logger.remove()
logger.add(sys.stderr, format="{time} {level} {message}", level="INFO")

# quiet please / go faster: let float32 matmuls use reduced internal precision
torch.set_float32_matmul_precision("medium")

# Silence the dataloader "does not have many workers" warning.
warnings.filterwarnings("ignore", ".*does not have many workers.*")


In [2]:
# load my code
%load_ext autoreload
%autoreload 2

import lightning.pytorch as pl
from src.datasets.dm import DeceptionDataModule
from src.models.pl_lora_ft import AtapterFinetuner

from src.config import ExtractConfig
from src.prompts.prompt_loading import load_preproc_dataset, load_preproc_datasets
from src.models.load import load_model
from src.helpers.torch_helpers import clear_mem
from src.models.phi.model_phi import PhiForCausalLMWHS


## Parameters


In [3]:
# --- experiment parameters ---
max_epochs = 2
device = "cuda:0"

# Other checkpoints that have been tried here:
#   "wassname/phi-1_5-w_hidden_states" (with batch_size=3),
#   "wassname/phi-2-w_hidden_states", "microsoft/phi-1_5",
#   "Walmart-the-bag/phi-2-uncensored"
cfg = ExtractConfig(
    model="microsoft/phi-2",
    prompt_format="phi",
    batch_size=1,
    max_examples=(1600, 1600),
)


## Load model

In [4]:
# PhiForCausalLMWHS is used so that the model also returns hidden states.
model, tokenizer = load_model(cfg.model, device=device, model_class=PhiForCausalLMWHS)


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [5]:
# LoRA used as a small, probe-like intervention.
# TODO: ideally train only biases; for now use a very low-rank update on a few modules.
peft_config = LoraConfig(
    target_modules=[
        # presumably these two write straight into the residual stream
        # (attention output projection, second MLP layer) — confirm for phi-2
        "out_proj",
        "mlp.fc2",
        # NOTE(review): "Wqkv" looks like the fused q/k/v *input* projection, which
        # reads from the residual stream rather than writing to it. So this list is
        # NOT only "layers that go directly to the residual".
        # Also tried, currently disabled: "mlp.fc1", "inner_attn", "inner_cross_attn".
        "Wqkv",
    ],
    # bias="lora_only",
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=3,
    lora_alpha=6,  # LoRA scaling = lora_alpha / r = 2
    lora_dropout=0.0,
)

# An IA3Config targeting ["out_proj", "mlp.fc2"] was also tried as an alternative adapter.
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()


trainable params: 2,703,360 || all params: 2,782,387,200 || trainable%: 0.09715973391481962


## Load datasets

In [6]:
# The OOD evaluation is only meaningful if no dataset is shared with the training set.
overlap = set(cfg.datasets).intersection(cfg.datasets_ood)
assert not overlap, f"datasets overlap: {overlap}"


In [7]:
# total number of in-distribution examples requested
N = sum(cfg.max_examples)
preproc_kwargs = dict(
    seed=cfg.seed,
    num_shots=cfg.num_shots,
    max_length=cfg.max_length,
    prompt_format=cfg.prompt_format,
)
ds_tokens = load_preproc_datasets(cfg.datasets, tokenizer, N=N, **preproc_kwargs)
ds_tokens


Generating train split: 0 examples [00:00, ? examples/s]

  table = cls._concat_blocks(blocks, axis=0)
[32m2023-12-28 08:11:37.608[0m | [1mINFO    [0m | [36msrc.prompts.prompt_loading[0m:[36mload_prompts[0m:[36m119[0m - [1mExtracting 11 variants of each prompt[0m
2023-12-28T08:11:37.608252+0800 INFO Extracting 11 variants of each prompt
[32m2023-12-28 08:15:37.278[0m | [1mINFO    [0m | [36msrc.prompts.prompt_loading[0m:[36mload_preproc_dataset[0m:[36m359[0m - [1msetting tokenizer chat template to phi[0m
2023-12-28T08:15:37.278799+0800 INFO setting tokenizer chat template to phi


format_prompt:   0%|          | 0/1924 [00:00<?, ? examples/s]

tokenize:   0%|          | 0/1924 [00:00<?, ? examples/s]

truncated:   0%|          | 0/1924 [00:00<?, ? examples/s]

truncated:   0%|          | 0/1924 [00:00<?, ? examples/s]

prompt_truncated:   0%|          | 0/1924 [00:00<?, ? examples/s]

choice_ids:   0%|          | 0/1924 [00:00<?, ? examples/s]

[32m2023-12-28 08:15:43.787[0m | [1mINFO    [0m | [36msrc.prompts.prompt_loading[0m:[36mload_preproc_dataset[0m:[36m390[0m - [1mmedian token length: 397.0 for amazon_polarity. max_length=776[0m
2023-12-28T08:15:43.787190+0800 INFO median token length: 397.0 for amazon_polarity. max_length=776
[32m2023-12-28 08:15:43.788[0m | [1mINFO    [0m | [36msrc.prompts.prompt_loading[0m:[36mload_preproc_dataset[0m:[36m394[0m - [1mtruncation rate: 0.00% on amazon_polarity[0m
2023-12-28T08:15:43.788711+0800 INFO truncation rate: 0.00% on amazon_polarity


Filter:   0%|          | 0/1924 [00:00<?, ? examples/s]

ValueError: Column name sys_instr_name_base not in the dataset. Current columns in the dataset: ['ds_string', 'example_i', 'answer', 'messages', 'answer_choices', 'template_name', 'label_true', 'label_instructed', 'instructed_to_lie', 'sys_instr_name', 'question', 'input_ids', 'attention_mask', 'truncated', 'length', 'prompt_truncated', 'choice_ids'].

In [None]:
# Out-of-distribution datasets, with half as many examples as the training set.
ds_tokens2 = load_preproc_datasets(
    cfg.datasets_ood,
    tokenizer,
    N=N // 2,
    prompt_format=cfg.prompt_format,
    max_length=cfg.max_length,
    num_shots=cfg.num_shots,
    seed=cfg.seed,
)
ds_tokens2


## custom models

In [None]:
from src.models.pl_lora_ft import AtapterFinetuner
from src.helpers.scores import select

class AtapterFinetunerLie(AtapterFinetuner):
    """Adapter finetuner trained to always give the *untrue* answer."""

    def get_loss(self, batch, out, out_a):
        """
        simply train it to lie

        Args:
            batch: dict with "label_true" (assumed 0/1 or bool per row) and
                "choice_ids" (indexed as [:, :, 0] / [:, :, 1]; exact layout
                assumed — TODO confirm).
            out: base-model output (unused here).
            out_a: adapted-model output; only out_a["logits"] is read.

        Returns:
            (loss, None, None)
        """
        # next-token log-probs at the final position
        log_probs_a = torch.log_softmax(out_a["logits"][:, -1], -1)

        # Target answer = the opposite of the truth. logical_not works for both
        # bool and 0/1 integer labels. `~` on an integer tensor is bitwise NOT
        # (0 -> -1, 1 -> -2), which would yield negative choice indices.
        lie_label = torch.logical_not(batch["label_true"]).long()

        # Average the NLL over the two choice_ids variants (slots 0 and 1).
        choice_ids1 = select(batch["choice_ids"][:, :, 0], lie_label)
        choice_ids2 = select(batch["choice_ids"][:, :, 1], lie_label)
        loss1 = F.nll_loss(log_probs_a, target=choice_ids1)
        loss2 = F.nll_loss(log_probs_a, target=choice_ids2)
        loss = (loss1 + loss2) / 2

        return loss, None, None


In [None]:
from src.models.pl_lora_ft import AtapterFinetuner
from src.helpers.scores import select


class AtapterFinetunerToldToLie(AtapterFinetuner):
    """Adapter finetuner that lies only when the prompt instructs it to."""

    def get_loss(self, batch, out, out_a):
        """NLL loss towards the *instructed* answer.

        The target choice index is ``label_true XOR instructed_to_lie``: the
        true answer normally, flipped when the prompt asks the model to lie.
        The loss is the mean of the NLL over choice_ids[:, :, 0] and
        choice_ids[:, :, 1]. Only ``out_a["logits"]`` is read; ``out`` is unused.

        Returns ``(loss, None, None)``.
        """
        last_logits = out_a["logits"][:, -1]
        logp = torch.log_softmax(last_logits, -1)

        target = (batch["label_true"] ^ batch["instructed_to_lie"]).long()
        all_choice_ids = batch["choice_ids"]
        losses = [
            F.nll_loss(logp, target=select(all_choice_ids[:, :, k], target))
            for k in (0, 1)
        ]
        return (losses[0] + losses[1]) / 2, None, None


In [None]:
from src.models.pl_lora_ft import AtapterFinetuner
from src.helpers.scores import select


class AtapterFinetunerTruth(AtapterFinetuner):
    """Adapter finetuner trained to always give the *true* answer.

    Truthful counterpart of the lying finetuners: it does not read
    ``instructed_to_lie`` at all.
    """

    def get_loss(self, batch, out, out_a):
        """
        train it to answer truthfully, regardless of any instruction to lie

        Args:
            batch: dict with "label_true" (assumed 0/1 or bool per row) and
                "choice_ids" (indexed as [:, :, 0] / [:, :, 1]).
            out: base-model output (unused here).
            out_a: adapted-model output; only out_a["logits"] is read.

        Returns:
            (loss, None, None)
        """
        # next-token log-probs at the final position
        end_logits = out_a["logits"][:, -1]
        log_probs_a = torch.log_softmax(end_logits, -1)

        # target answer index is simply the true label
        truth_label = batch["label_true"].long()
        choice_ids1 = select(batch["choice_ids"][:, :, 0], truth_label)
        choice_ids2 = select(batch["choice_ids"][:, :, 1], truth_label)
        loss1 = F.nll_loss(log_probs_a, target=choice_ids1)
        loss2 = F.nll_loss(log_probs_a, target=choice_ids2)
        loss = (loss1 + loss2) / 2

        return loss, None, None


In [None]:
# Training objective used below: lie only when instructed.
# Alternatives defined above: AtapterFinetunerLie (always lie), AtapterFinetunerTruth (always truthful).
model_cls = AtapterFinetunerToldToLie


## Train

In [None]:
dm = DeceptionDataModule(ds_tokens, batch_size=cfg.batch_size)
dl_train, dl_val = dm.train_dataloader(), dm.val_dataloader()


In [None]:
# Peek at one training batch to check its keys and sequence length.
b = next(iter(dl_train))
input_shape = b["input_ids"].shape
print(b.keys(), input_shape)
c_in = input_shape[1]
c_in


In [None]:
# NOTE(review): total_steps counts *batches*, but the trainer below uses
# accumulate_grad_batches=8, so there are ~8x fewer optimizer steps. If the base
# class drives a per-optimizer-step LR schedule from total_steps, that schedule
# will not complete. Confirm against AtapterFinetuner.
n_train_steps = len(dl_train) * max_epochs
net = model_cls(model, tokenizer, lr=5e-3, weight_decay=1e-5, total_steps=n_train_steps)

print(c_in)


In [None]:
# # debug
# with torch.no_grad():
#     o = net.training_step(b, None)
# o


In [None]:
# # debug
# with torch.no_grad():
#     o = net.predict_step(b, None)
# o.keys()


In [None]:
# Create the Lightning trainer early so that it initialises accelerate.
trainer1 = pl.Trainer(
    accelerator="gpu",
    devices="1",
    max_epochs=max_epochs,
    accumulate_grad_batches=8,
    gradient_clip_val=20,
    log_every_n_steps=1,
)


In [None]:
trainer1.fit(net, train_dataloaders=dl_train, val_dataloaders=dl_val);  # ";" hides the cell output


In [None]:
# Save only the adapter weights into "<log_dir>/final".
checkpoint_path = Path(trainer1.log_dir).joinpath("final")
model.save_pretrained(checkpoint_path)
checkpoint_path


## Hist

In [None]:
from src.helpers.lightning import read_metrics_csv

metrics_csv = trainer1.logger.experiment.metrics_file_path
df_histe, df_hist = read_metrics_csv(metrics_csv)
loss_cols = ["train/loss_step", "val/loss_step"]
df_hist[loss_cols].plot(style=".")
df_hist


In [None]:
df_histe.loc[:, ["train/loss_step", "val/loss_step"]].plot(style=".")


## Generate

This acts as a QC check that the trained adapter is still coherent while giving the opposite answer.


In [None]:
from src.eval.gen import gen


In [None]:

# Reload the model + adapter from the saved checkpoint, since Lightning seems to
# leave the in-memory model in a bad state after training.
model, tokenizer = load_model(
    cfg.model,
    device=device,
    adaptor_path=checkpoint_path,
    dtype=torch.float16,  # bfloat16 can't be pickled
    model_class=PhiForCausalLMWHS,
)
clear_mem()


In [None]:
# Choose a row where we expect to see the difference: the prompt instructs the
# model to lie and the true label is False.
# Index the torch-formatted view. On a python-formatted Dataset, ds["col"] is a
# plain list, so `list == True` collapses to a single bool and `.float()` fails.
ds_torch = ds_tokens.with_format("torch")
mask = ds_torch["instructed_to_lie"].bool() & ~ds_torch["label_true"].bool()
if not mask.any():
    # argmax of an all-zero mask would silently return row 0
    raise ValueError("no row with instructed_to_lie=True and label_true=False")
bi = mask.float().argmax().item()  # argmax is not implemented for bool tensors

# TODO doesn't work if the model gets it wrong
inputs = ds_torch[bi]
inputs['instructed_to_lie'], inputs['label_true']


In [None]:
# First generate with the adapter switched off (base model)...
with model.disable_adapter():
    gen(model, inputs, tokenizer)

# ...then with the trained adapter active.
gen(model, inputs, tokenizer)


# Test

In [None]:
from src.eval.interventions import test_intervention_quality2
from src.eval.labels import ds2label_model_obey, ds2label_model_truth

TEST_BATCH_MULT = 3



In [None]:
dm2 = DeceptionDataModule(ds_tokens2, batch_size=cfg.batch_size * TEST_BATCH_MULT)

# Assigning `.shuffle` on an existing DataLoader has no effect, because its sampler
# is fixed at construction. So rebuild the train loader over the same dataset /
# collate_fn with shuffle=False.
_dl = dm2.train_dataloader()
dl_train2 = DataLoader(
    _dl.dataset,
    batch_size=_dl.batch_size,
    collate_fn=_dl.collate_fn,
    num_workers=_dl.num_workers,
    drop_last=_dl.drop_last,
    shuffle=False,
)

dl_val2 = dm2.val_dataloader()
dl_test2 = dm2.test_dataloader()

# In-distribution (dm, not dm2) val + test splits, collected later as "valtest".
dl_valtest2 = DataLoader(
    torch.utils.data.ConcatDataset([dm.datasets["val"], dm.datasets["test"]]),
    batch_size=cfg.batch_size * TEST_BATCH_MULT,
)
len(dl_valtest2.dataset)


In [None]:
# The whole OOD dataset, in order, for evaluation/collection.
dl_OOD = DataLoader(
    ds_tokens2,
    shuffle=False,
    drop_last=False,
    batch_size=cfg.batch_size * TEST_BATCH_MULT,
)
len(dl_OOD.dataset)


In [None]:
# Fresh model + adapter from the checkpoint, wrapped for trainer.test.
model, tokenizer = load_model(
    cfg.model,
    device=device,
    adaptor_path=checkpoint_path,
    dtype=torch.float16,  # bfloat16 can't be pickled
    model_class=PhiForCausalLMWHS,
)
net = model_cls(model, tokenizer)
clear_mem()


In [None]:
from src.helpers.lightning import rename_pl_test_results

# NOTE(review): all four loaders come from the OOD data (dm2 / ds_tokens2). So
# "train"/"val"/"test" here are splits of the OOD datasets, and "OOD" is the whole
# OOD dataset (overlapping those splits). Confirm this is intended.
split_names = ["train", "val", "test", "OOD"]
rs1 = trainer1.test(
    net,
    dataloaders=[dl_train2, dl_val2, dl_test2, dl_OOD],
    verbose=False,
)
rs = rename_pl_test_results(rs1, split_names)
df_testing = pd.DataFrame(rs)
df_testing


# Predict

Here we want to see whether a probe on the hidden states can detect when the model is lying.


### Collect

- see how accurate each model was with respect to the instructions vs. the truth
- see how well a linear probe trained on the diff predicts the truth, compared to a baseline

In [None]:
# Fresh model + adapter from the checkpoint for hidden-state collection.
model, tokenizer = load_model(
    cfg.model,
    device=device,
    adaptor_path=checkpoint_path,
    dtype=torch.float16,  # bfloat16 can't be pickled
    model_class=PhiForCausalLMWHS,
)
clear_mem()


In [None]:
from src.eval.collect import manual_collect2
from src.eval.ds import filter_ds_to_known
from src.eval.labels import LABEL_MAPPING
from src.eval.ds import qc_ds, ds2df, qc_dsdf
from src.helpers.torch_helpers import batch_to_device


In [None]:
# # for single process DEBUGING
# from src.eval.collect import generate_batches
# o = next(iter(generate_batches(dl_OOD, model)))


In [None]:
# Run collection on each evaluation set (OOD first, then in-distribution val+test).
# As before, `f` ends up holding the second call's value.
_collected = {}
for _name, _loader in [("OOD", dl_OOD), ("valtest", dl_valtest2)]:
    _collected[_name], f = manual_collect2(_loader, model, dataset_name=_name)
ds_out_OOD = _collected["OOD"]
ds_out_valtest = _collected["valtest"]


### Eval

In [None]:
def make_dfres2_pretty(styler):
    """Style the base-vs-adapter dataset-metrics table.

    Sets the caption and applies a red-to-green gradient to the score columns
    (fixed range 0..1) and to "balance" (fixed range 0..0.5).
    Returns the same styler object it was given.
    """
    score_cols = ['auroc', 'lie_auroc', 'known_lie_auroc', 'choice_cov']
    styler.set_caption("Dataset metrics")
    for cols, upper in ((score_cols, 1), (['balance'], 0.5)):
        styler.background_gradient(axis=1, vmin=0, vmax=upper, cmap="RdYlGn", subset=cols)
    return styler


def analyse_intervention(ds_out, cfg, model_kwargs=None):
    """Report probe performance and base/adapter dataset quality for a collection.

    Primary metric: for each label in LABEL_MAPPING, run
    test_intervention_quality2 on the "known" subset of ``ds_out`` and display
    the resulting table.
    Secondary metric: qc_dsdf statistics for the base model ("_base" columns)
    and the adapter ("_adapt" columns), displayed side by side.

    Args:
        ds_out: collected dataset (e.g. from manual_collect2).
        cfg: config providing ``skip_layers`` and ``stride_layers``.
        model_kwargs: extra kwargs forwarded to test_intervention_quality2;
            defaults to a fresh empty dict on every call.

    Returns:
        (df_res_ab, df_res): the base/adapter quality table, and the probe
        results for the *last* label in LABEL_MAPPING (None if it is empty).
    """
    # avoid a shared mutable default
    if model_kwargs is None:
        model_kwargs = {}

    ds_known = filter_ds_to_known(ds_out, verbose=True)

    print(
        "🥇 primary metric: predictive power (of logistic regression on top of intervened hidden states of known question)"
    )
    print(
        """
    The roc_auc should go up on the right given the intervented states
    """
    )
    df_res = None  # stays None if LABEL_MAPPING is empty
    for label_name, label_fn in LABEL_MAPPING.items():
        # fit a probe for this label; any exception propagates to the caller
        df_res = test_intervention_quality2(
            ds_known,
            label_fn,
            title=f"predicting label={label_name}",
            skip=cfg.skip_layers,
            stride=cfg.stride_layers,
            model_kwargs=model_kwargs,
        )
        display(df_res)

    df1 = ds2df(ds_out)
    # Strip the suffix so qc_dsdf sees the base / adapter columns under plain names.
    df_b = df1.rename(columns=lambda x: x.replace("_base", "")).copy()
    res_b = qc_dsdf(df_b)
    df_a = df1.rename(columns=lambda x: x.replace("_adapt", "")).copy()
    res_a = qc_dsdf(df_a)
    df_res_ab = pd.DataFrame([res_b, res_a], index=["base", "adapter"])
    print("🥉 secondary metric: dataset quality: performance of base model and adapter")
    display(df_res_ab.style.pipe(make_dfres2_pretty))
    return df_res_ab, df_res

# e.g. analyse_intervention(ds_out_OOD, cfg)



In [None]:
# In-distribution val+test first, then out-of-distribution.
_results = {}
for _title, _ds in [("valtest", ds_out_valtest), ("out of distribution", ds_out_OOD)]:
    print(_title)
    _results[_title] = analyse_intervention(_ds, cfg)
df_res_ab_v, df_res_v = _results["valtest"]
df_res_ab_o, df_res_o = _results["out of distribution"]


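Hypothesis: a probe on the adapter's residual stream does better than either a probe on the base model or the adapter on its own.

|model / metric| val | OOD |
|--|--|--|
|base model accuracy | 0.64 | 0.69 |
|adapter accuracy | 0.65 | 0.65 |
|base + probe, residual AUROC | 0.89 | 0.917 |
|adapter + probe, residual AUROC | **0.905** | **0.974** |

So yes, the hypothesis is confirmed: adapter + probe has the best AUROC on both the validation and OOD sets. (The first two rows are accuracies and the last two are AUROCs, so the like-for-like comparison is between the two probe rows.)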


# Plot labels against each other

To try to see why the ranking (AUROC) is better.
