Experiment: use LoRA to make a lying model. Here we think of LoRA as a probe, since it acts in a very similar way — by modifying the residual stream.

The hope is that this will then help with lie detection and generalize to unseen datasets.

- https://github.dev/JD-P/minihf/blob/b54075c34ef88d9550e37fdf709e78e5a68787c4/lora_tune.py
- https://github.com/jonkrohn/NLP-with-LLMs

In [1]:
import os
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from tqdm.auto import tqdm

plt.style.use("ggplot")

from typing import Optional, List, Dict, Union
from jaxtyping import Float
from torch import Tensor

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import Tensor
from torch import optim
from torch.utils.data import random_split, DataLoader, TensorDataset

from pathlib import Path
from einops import rearrange

import transformers
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    AutoConfig,
)
from peft import (
    get_peft_config,
    get_peft_model,
    LoraConfig,
    TaskType,
    LoftQConfig,
    IA3Config,
)

import datasets
from datasets import Dataset

from loguru import logger

import sys

# loguru installs a default stderr sink; adding another one on top of it made every
# message print twice (once per format). Remove existing sinks first so that this cell
# is also idempotent when re-run.
logger.remove()
logger.add(sys.stderr, format="{time} {level} {message}", level="INFO")


# # quiet please
torch.set_float32_matmul_precision("medium")
import warnings

warnings.filterwarnings("ignore", ".*does not have many workers.*")
# warnings.filterwarnings(
#     "ignore", ".*sampler has shuffling enabled, it is strongly recommended that.*"
# )
# warnings.filterwarnings("ignore", ".*has been removed as a dependency of.*")


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# load my code
%load_ext autoreload
%autoreload 2

import lightning.pytorch as pl
from src.datasets.dm import DeceptionDataModule
from src.models.pl_lora_ft import AtapterFinetuner

from src.config import ExtractConfig
from src.prompts.prompt_loading import load_preproc_dataset, load_preproc_datasets
from src.models.load import load_model
from src.helpers.torch import clear_mem
from src.models.phi.model_phi import PhiForCausalLMWHS


## Parameters


In [3]:
# params
# Number of training epochs and the device the model is loaded onto.
max_epochs = 2
device = "cuda:0"

# Experiment configuration. `max_examples` is summed further down to get the number of
# rows to load (300 + 100 = 400) — presumably a (train, test) budget; TODO confirm.
cfg = ExtractConfig(
    max_examples=(300, 100),
    # model="wassname/phi-1_5-w_hidden_states",
    # batch_size=3,
    # model="wassname/phi-2-w_hidden_states",
    model="Walmart-the-bag/phi-2-uncensored",
    batch_size=1,
    prompt_format="phi",
)


## Load model

In [4]:
# Load the model named in the config, plus its tokenizer, onto `device`, using the
# project's Phi class (PhiForCausalLMWHS) instead of the stock implementation.
model, tokenizer = load_model(
    cfg.model,
    device=device,
    model_class=PhiForCausalLMWHS,
)


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.
Loading checkpoint shards: 100%|██████████| 3/3 [00:03<00:00,  1.22s/it]


In [5]:
# TODO I would like to only have biases, but for now lets just try a very small intervention on the last parts of a layer...
# LoRA adapter configuration: which sub-modules get a low-rank update, and how large it is.
peft_config = LoraConfig(
    target_modules=[
        "out_proj",
        "mlp.fc2",
        # "mlp.fc1",
        "Wqkv",
        # 'inner_attn',
        # 'inner_cross_attn',
    ],  # NOTE(review): out_proj and mlp.fc2 presumably write straight to the residual stream; Wqkv (2560 -> 7680) does not, it feeds attention
    # bias="lora_only",
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,  # adapter weights are trainable
    r=4,  # rank of the low-rank update
    lora_alpha=8,
    lora_dropout=0.0,
)


# Alternative adapter type (IA3), kept for reference; not used.
# peft_config = IA3Config(
#     task_type=TaskType.SEQ_CLS, target_modules=[ "out_proj",
#         "mlp.fc2",], feedforward_modules=["out_proj", "mlp.fc2",]
# )
# Wrap the base model with the adapters and report how many parameters will be trained.
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()


trainable params: 3,604,480 || all params: 2,783,288,320 || trainable%: 0.12950436985270716


In [6]:
# Display the wrapped module tree to check which layers received LoRA adapters.
model


PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): PhiForCausalLMWHS(
      (transformer): PhiModel(
        (embd): Embedding(
          (wte): Embedding(51200, 2560)
          (drop): Dropout(p=0.0, inplace=False)
        )
        (h): ModuleList(
          (0-31): 32 x ParallelBlock(
            (ln): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
            (resid_dropout): Dropout(p=0.1, inplace=False)
            (mixer): MHA(
              (rotary_emb): RotaryEmbedding()
              (Wqkv): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=2560, out_features=7680, bias=True)
                (lora_dropout): ModuleDict(
                  (default): Identity()
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=2560, out_features=4, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=4, out_features=7680, bias=False)
        

## Load datasets

In [7]:
# Total number of rows to load: the sum of the counts in cfg.max_examples.
N = sum(cfg.max_examples)
# Tokenized, preprocessed prompts for the datasets in cfg.datasets; this is the data
# the training datamodule is built from below.
ds_tokens = load_preproc_datasets(
    cfg.datasets,
    tokenizer,
    N=N,
    seed=cfg.seed,
    num_shots=cfg.num_shots,
    max_length=cfg.max_length,
    prompt_format=cfg.prompt_format,
)
ds_tokens


[32m2023-12-23 10:05:38.496[0m | [1mINFO    [0m | [36msrc.prompts.prompt_loading[0m:[36mload_preproc_dataset[0m:[36m350[0m - [1msetting tokenizer chat template to phi[0m
2023-12-23T10:05:38.496166+0800 INFO setting tokenizer chat template to phi
[32m2023-12-23 10:05:38.565[0m | [1mINFO    [0m | [36msrc.prompts.prompt_loading[0m:[36mload_preproc_dataset[0m:[36m381[0m - [1mmedian token length: 443.0 for amazon_polarity. max_length=776[0m
2023-12-23T10:05:38.565986+0800 INFO median token length: 443.0 for amazon_polarity. max_length=776
[32m2023-12-23 10:05:38.567[0m | [1mINFO    [0m | [36msrc.prompts.prompt_loading[0m:[36mload_preproc_dataset[0m:[36m385[0m - [1mtruncation rate: 0.00% on amazon_polarity[0m
2023-12-23T10:05:38.567051+0800 INFO truncation rate: 0.00% on amazon_polarity
[32m2023-12-23 10:05:38.575[0m | [1mINFO    [0m | [36msrc.prompts.prompt_loading[0m:[36mload_preproc_dataset[0m:[36m394[0m - [1mnum_rows (after filtering out tru

Dataset({
    features: ['ds_string', 'example_i', 'answer', 'messages', 'answer_choices', 'template_name', 'label_true', 'label_instructed', 'instructed_to_lie', 'sys_instr_name', 'question', 'input_ids', 'attention_mask', 'truncated', 'length', 'prompt_truncated', 'choice_ids'],
    num_rows: 402
})

In [8]:
# Second dataset, loaded from cfg.datasets_oos (the out-of-sample datasets) with the same
# preprocessing settings but half as many rows (N // 2).
ds_tokens2 = load_preproc_datasets(
    cfg.datasets_oos,
    tokenizer,
    N=N // 2,
    seed=cfg.seed,
    num_shots=cfg.num_shots,
    max_length=cfg.max_length,
    prompt_format=cfg.prompt_format,
)
ds_tokens2


[32m2023-12-23 10:05:38.754[0m | [1mINFO    [0m | [36msrc.prompts.prompt_loading[0m:[36mload_preproc_dataset[0m:[36m381[0m - [1mmedian token length: 298.5 for glue:qnli. max_length=776[0m
2023-12-23T10:05:38.754225+0800 INFO median token length: 298.5 for glue:qnli. max_length=776
[32m2023-12-23 10:05:38.755[0m | [1mINFO    [0m | [36msrc.prompts.prompt_loading[0m:[36mload_preproc_dataset[0m:[36m385[0m - [1mtruncation rate: 0.00% on glue:qnli[0m
2023-12-23T10:05:38.755320+0800 INFO truncation rate: 0.00% on glue:qnli
[32m2023-12-23 10:05:38.762[0m | [1mINFO    [0m | [36msrc.prompts.prompt_loading[0m:[36mload_preproc_dataset[0m:[36m394[0m - [1mnum_rows (after filtering out truncated rows) 604=>604[0m
2023-12-23T10:05:38.762037+0800 INFO num_rows (after filtering out truncated rows) 604=>604


Dataset({
    features: ['ds_string', 'example_i', 'answer', 'messages', 'answer_choices', 'template_name', 'label_true', 'label_instructed', 'instructed_to_lie', 'sys_instr_name', 'question', 'input_ids', 'attention_mask', 'truncated', 'length', 'prompt_truncated', 'choice_ids'],
    num_rows: 201
})

## Train

In [9]:
# Datamodule over the training dataset; it provides the train/val dataloaders used below.
dm = DeceptionDataModule(ds_tokens, batch_size=cfg.batch_size)
dm


<src.datasets.dm.DeceptionDataModule at 0x7fb2db3a5c10>

In [10]:
# Dataloaders passed to trainer1.fit; len(dl_train) is also used to compute total_steps.
dl_train = dm.train_dataloader()
dl_val = dm.val_dataloader()


In [11]:
# Pull one batch to inspect its keys and the shape of the token ids.
b = next(iter(dl_train))
print(b.keys(), b["input_ids"].shape)
# Size of the second dimension of input_ids (776 in the recorded run).
c_in = b["input_ids"].shape[1]
c_in


dict_keys(['ds_string', 'example_i', 'answer', 'messages', 'answer_choices', 'template_name', 'label_true', 'label_instructed', 'instructed_to_lie', 'sys_instr_name', 'question', 'input_ids', 'attention_mask', 'truncated', 'length', 'prompt_truncated', 'choice_ids']) torch.Size([1, 776])


776

## custom models

In [12]:
# from src.models.pl_lora_ft import AtapterFinetuner, select_choices

# class AtapterFinetunerLie(AtapterFinetuner):
#     def get_loss(self, batch, out, out_a):
#         """
#         simply train it to lie
#         """

#         log_probs_a = torch.log_softmax(out_a["logits"][:, -1,], -1,)

#         # batch['instructed_to_lie']
#         lie_label = ~batch['label_true']
#         batch_size = lie_label.shape[0]
#         batch_inds = torch.arange(batch_size).long().unsqueeze(1)
#         # choice_ids1 = batch['choice_ids'][:, lie_label, 0]
#         choice_ids1 = batch['choice_ids'][:, :, 0][batch_inds, lie_label.long()].squeeze(1)
#         # choice_ids2 = batch['choice_ids'][:, lie_label, 1]
#         choice_ids2 = batch['choice_ids'][:, :, 1][batch_inds, lie_label.long()].squeeze(1)
#         # choice_ids = batch['choice_ids'][torch.arange(1).long().unsqueeze(1), lie_label.long()]
#         loss1 = F.nll_loss(log_probs_a, target=choice_ids1)
#         loss2 = F.nll_loss(log_probs_a, target=choice_ids2)
#         loss = (loss1 + loss2) / 2

#         return loss, None, None


In [13]:
from src.models.pl_lora_ft import AtapterFinetuner, select_choices


class AtapterFinetunerToldToLie(AtapterFinetuner):
    """Fine-tuner whose loss rewards the instructed answer: lie when told to, otherwise tell the truth."""

    def get_loss(self, batch, out, out_a):
        """
        Train the adapter to give the answer the prompt instructs.

        The target choice is ``label_true XOR instructed_to_lie``: the true label,
        flipped for rows whose prompt asks the model to lie.

        Args:
            batch: dict providing ``label_true``, ``instructed_to_lie`` and
                ``choice_ids``. ``choice_ids`` is indexed as [row, choice, variant];
                assumes the two labels are bool/int tensors of shape (batch,) — TODO confirm.
            out: not used by this loss.
            out_a: model outputs (presumably with the adapter enabled); only
                ``out_a["logits"]`` is read.

        Returns:
            ``(loss, None, None)``, where ``loss`` is the mean of the NLL losses for the
            two token-id variants of the target choice, taken at the last sequence position.
        """
        # log-probabilities over the vocabulary at the last sequence position
        log_probs_a = torch.log_softmax(out_a["logits"][:, -1], dim=-1)

        choice_ids = batch["choice_ids"]
        target_choice = (batch["label_true"] ^ batch["instructed_to_lie"]).long()

        # Use a 1-D row index so it pairs element-wise with `target_choice` and each row
        # picks its own choice. The previous (batch, 1) index broadcast against the
        # (batch,) label to a (batch, batch) result, which only worked for batch size 1.
        batch_inds = torch.arange(target_choice.shape[0], device=choice_ids.device)
        choice_ids1 = choice_ids[:, :, 0][batch_inds, target_choice]
        choice_ids2 = choice_ids[:, :, 1][batch_inds, target_choice]

        loss1 = F.nll_loss(log_probs_a, target=choice_ids1)
        loss2 = F.nll_loss(log_probs_a, target=choice_ids2)
        loss = (loss1 + loss2) / 2

        return loss, None, None


In [14]:
# Lightning module wrapping the adapter model for training.
# NOTE(review): total_steps counts batches (len(dl_train) * max_epochs), but the trainer
# below uses accumulate_grad_batches=4, so there are fewer optimizer steps than this —
# confirm how total_steps is consumed by the scheduler.
net = AtapterFinetunerToldToLie(
    model, tokenizer, lr=5e-3, weight_decay=1e-3, total_steps=len(dl_train) * max_epochs
)

print(c_in)
# net.model.enable_adapters()


776


In [15]:
# # debug
# with torch.no_grad():
#     o = net.training_step(b, None)
# o


In [16]:
# # debug
# with torch.no_grad():
#     o = net.predict_step(b, None)
# o.keys()


In [17]:
# we want to init lightning early, so it inits accelerate
# Single-GPU trainer; the commented-out precision lines record what was tried.
trainer1 = pl.Trainer(
    # precision="16-true", # leads to inf loss?
    # precision="16-mixed", # works
    # precision="bf16-mixed",
    gradient_clip_val=20,
    # accelerator="auto",
    devices="1",
    accelerator="gpu",
    # devices=[0],
    accumulate_grad_batches=4,  # gradients are accumulated over 4 batches per optimizer step
    max_epochs=max_epochs,
    log_every_n_steps=1,
    enable_model_summary=False,
)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/media/wassname/SGIronWolf/projects5/elk/sgd_probes_are_lie_detectors/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:67: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default


In [18]:
# Train the adapter. NOTE: the recorded run of this cell ended in a CUDA out-of-memory
# error while other processes were also holding memory on GPU 0.
trainer1.fit(model=net, train_dataloaders=dl_train, val_dataloaders=dl_val);


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Epoch 0:   0%|          | 0/201 [00:00<?, ?it/s]                           

OutOfMemoryError: CUDA out of memory. Tried to allocate 76.00 MiB. GPU 0 has a total capacty of 23.67 GiB of which 1.17 GiB is free. Process 3254412 has 4.14 GiB memory in use. Process 3618154 has 5.63 GiB memory in use. Including non-PyTorch memory, this process has 11.34 GiB memory in use. Of the allocated memory 10.88 GiB is allocated by PyTorch, and 159.01 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
# Save the trained model under "<trainer log dir>/final"; this path is passed as
# `adaptor_path` when the model is reloaded later.
checkpoint_path = Path(trainer1.log_dir) / "final"
model.save_pretrained(checkpoint_path)
checkpoint_path


In [None]:
from src.helpers.lightning import read_metrics_csv

# Read the trainer's metrics CSV into two frames (presumably per-epoch and per-step
# history — TODO confirm) and plot the step-level train/val loss as points.
df_histe, df_hist = read_metrics_csv(trainer1.logger.experiment.metrics_file_path)
df_hist[["train/loss_step", "val/loss_step"]].plot(style=".")
df_hist


In [None]:
# Plot every column of the second history frame except the step counter.
df_histe.drop(columns=["step"]).plot()


## Generate


In [None]:
# Reload the model in float16 with the trained adapter attached from `checkpoint_path`.
# (The original line repeated the assignment target: `model, tokenizer = model, tokenizer = ...`.)
# NOTE(review): `net` still references the previously loaded model, so rebinding `model`
# here does not by itself release that copy's GPU memory.
model, tokenizer = load_model(
    cfg.model,
    device=device,
    adaptor_path=checkpoint_path,
    dtype=torch.float16,  # bfloat can't be pickled
    model_class=PhiForCausalLMWHS,
)
clear_mem()


In [None]:
# get a row
# Take a single example (index `bi`) from the training dataset, formatted as torch tensors.
bi = 4
inputs = ds_tokens.with_format("torch")[bi]

from src.eval.gen import gen


In [None]:
# Generate for the same input twice: first with the adapter disabled, then with it
# active, so the two completions can be compared.
with model.disable_adapter():
    gen(model, inputs, tokenizer)

gen(model, inputs, tokenizer)


# Test

In [None]:
from src.eval.helpers import test_intervention_quality2
from src.eval.labels import ds2label_model_obey, ds2label_model_truth


In [None]:
# Evaluation loaders with twice the training batch size.
# NOTE(review): dm2 wraps ds_tokens2 (the out-of-sample data), so the train/val/test
# loaders below are splits of that data, not of the training data — confirm this is intended.
dm2 = DeceptionDataModule(ds_tokens2, batch_size=cfg.batch_size * 2)

# Setting `.shuffle = False` on an existing DataLoader has no effect: the sampler is
# chosen when the loader is constructed. Rebuild the train loader without shuffling,
# carrying over its other settings.
dl_train2_src = dm2.train_dataloader()
dl_train2 = DataLoader(
    dl_train2_src.dataset,
    batch_size=dl_train2_src.batch_size,
    shuffle=False,
    num_workers=dl_train2_src.num_workers,
    collate_fn=dl_train2_src.collate_fn,
    drop_last=dl_train2_src.drop_last,
)
dl_val2 = dm2.val_dataloader()
dl_test2 = dm2.test_dataloader()


In [None]:
# One loader over the concatenated "val" and "test" splits of `dm` (the training
# datamodule, not dm2), at twice the training batch size.
dl_valtest2 = DataLoader(
    torch.utils.data.ConcatDataset([dm.datasets["val"], dm.datasets["test"]]),
     batch_size=cfg.batch_size * 2,
)
len(dl_valtest2.dataset)


In [None]:
# Loader over the whole out-of-sample dataset, in fixed order and keeping the last
# partial batch.
dl_oos2 = DataLoader(
    ds_tokens2, batch_size=cfg.batch_size * 2, drop_last=False, shuffle=False
)
len(dl_oos2.dataset)


In [None]:
import re

def transform_dl_k(k: str) -> str:
    """Reduce a per-dataloader test metric key to its bare metric name.

    A key that starts with ``test/<name>/dataloader_idx_<digit>`` becomes ``<name>``;
    any other key is returned unchanged.
    """
    found = re.match(r"test\/(.+)\/dataloader_idx_\d", k)
    if found is None:
        return k
    return found.group(1)


def rename(rs, ks=("train", "val", "test")):
    """Turn a list of per-dataloader metric dicts into a dict keyed by split name.

    Args:
        rs: sequence of metric dicts, one per dataloader, in the same order as `ks`.
        ks: split names to use as keys. The default is a tuple rather than a list so
            the default value cannot be mutated between calls.

    Returns:
        ``{split_name: {metric_name: value}}`` with every metric key passed through
        ``transform_dl_k``. Entries of `rs` beyond ``len(ks)`` are ignored.

    Raises:
        IndexError: if `rs` has fewer entries than `ks`.
    """
    return {
        name: {transform_dl_k(k): v for k, v in rs[i].items()}
        for i, name in enumerate(ks)
    }

# Evaluate the fine-tuned module on four loaders; the results come back in the same
# order as the loaders are listed.
rs = trainer1.test(
    net,
    dataloaders=[
        dl_train2,
        dl_val2,
        dl_test2,
        dl_oos2,
    ],
)
rs = rename(rs, ["train", "val", "test", "oos"])
# After rename() `rs` is a dict keyed by split name, so the previous `rs[0]` raised
# KeyError. Show all splits side by side instead (rows = metrics, columns = splits).
pd.DataFrame(rs)


In [None]:
# %debug  # post-mortem debugger: run manually after an exception. Left disabled so that "Run All" does not stop at an interactive prompt.


# Predict

Here we want to see whether a probe on the hidden states can detect when the model is lying.


### Collect

- see how accurate each model (base and adapter) was at following the instructions vs. telling the truth
- see how well a linear probe trained on the diff predicts the truth, compared with the baseline

In [None]:
# Reload the model in float16 with the trained adapter attached from `checkpoint_path`
# (same arguments as the reload in the Generate section).
# (The original line repeated the assignment target: `model, tokenizer = model, tokenizer = ...`.)
model, tokenizer = load_model(
    cfg.model,
    device=device,
    adaptor_path=checkpoint_path,
    dtype=torch.float16,  # bfloat can't be pickled
    model_class=PhiForCausalLMWHS,
)
clear_mem()


In [None]:
from src.eval.collect import manual_collect2
from src.eval.ds import filter_ds_to_known
from src.eval.labels import LABEL_MAPPING
from src.eval.ds import qc_ds, ds2df, qc_dsdf


In [None]:
# Collect model outputs for the out-of-sample loader and for the val+test loader.
ds_out_oos, f = manual_collect2(dl_oos2, model, dataset_name="oos")
# This call was also labelled dataset_name="oos" (copy-paste); name it after the data it
# actually holds. NOTE(review): if manual_collect2 keys any cache/output file on
# dataset_name, the shared name would also have made the two runs collide — confirm.
ds_out_valtest, f = manual_collect2(dl_valtest2, model, dataset_name="valtest")


### Eval

In [None]:
def analyse_intervention(ds_out, tokenizer):
    """Print and display the evaluation of one collected output dataset.

    First, for every entry in LABEL_MAPPING, runs test_intervention_quality2 on the
    subset returned by filter_ds_to_known and displays the result; a failure for one
    label is printed and the loop moves on to the next label. Second, runs qc_dsdf on
    two views of the full dataset — one with "_base" stripped from the column names,
    one with "_adapt" stripped — and displays the two results side by side.

    Args:
        ds_out: collected output dataset (as produced by manual_collect2 in this notebook).
        tokenizer: forwarded to test_intervention_quality2.
    """
    known = filter_ds_to_known(ds_out, verbose=True)

    print(
        "🥇 primary metric: predictive power (of logistic regression on top of intervened hidden states of known question)"
    )
    print("\n    The roc_auc should go up on the right given the intervented states\n    ")

    for label_name, label_fn in LABEL_MAPPING.items():
        try:
            # fit probe
            print(f"predicting label={label_name}")
            display(test_intervention_quality2(known, label_fn, tokenizer))
        except Exception as err:
            print(f"Exception {err}")

    frame = ds2df(ds_out)
    qc_rows = []
    row_names = []
    # base first, then adapter: the same order in which the checks were run before
    for row_name, suffix in (("base", "_base"), ("adapter", "_adapt")):
        view = frame.rename(columns=lambda col, s=suffix: col.replace(s, "")).copy()
        qc_rows.append(qc_dsdf(view))
        row_names.append(row_name)
    summary = pd.DataFrame(qc_rows, index=row_names).T
    print("🥉 secondary metric: dataset quality: performance of base model and adapter")
    display(summary)


In [None]:
# Run the same analysis on the val+test outputs, then on the out-of-sample outputs.
print('valtest')
analyse_intervention(ds_out_valtest, tokenizer)

print('out of sample')
analyse_intervention(ds_out_oos, tokenizer)



### Check dataset of outputs

In [None]:
# `ds_out` was never defined at notebook level (it is only a parameter of
# analyse_intervention), so this section raised NameError on a fresh kernel.
# Choose explicitly which collected dataset to inspect.
ds_out = ds_out_valtest  # or ds_out_oos
df = ds2df(ds_out)
df


In [None]:
# Compare the base model and the adapter on the same dataset: build two views of the
# frame, one with "_base" stripped from the column names and one with "_adapt" stripped,
# run qc_dsdf on each and show the two results side by side.
df1 = ds2df(ds_out)
df_b = df1.rename(columns=lambda x: x.replace("_base", "")).copy()
res_b = qc_dsdf(df_b)
df_a = df1.rename(columns=lambda x: x.replace("_adapt", "")).copy()
res_a = qc_dsdf(df_a)
df_res_ab = pd.DataFrame([res_b, res_a], index=["base", "adapter"]).T
print("model performance")
display(df_res_ab)


In [None]:
# Run qc_dsdf separately for each (dataset, template) group, using the "_base" columns.
print("acc by dataset and template name: base")
df1 = ds2df(ds_out)
df_b = df1.rename(columns=lambda x: x.replace("_base", "")).copy()
# Grouping by two columns makes the loop key a (ds_string, template_name) tuple,
# despite the variable being called `ds_string`.
for ds_string, ddf in df_b.groupby(["ds_string", "template_name"]):
    print(ds_string)
    qc_dsdf(ddf)


In [None]:
# Run qc_dsdf separately for each (dataset, template) group, using the "_adapt" columns.
print("acc by dataset and template name: adapter")
df1 = ds2df(ds_out)
df_a = df1.rename(columns=lambda x: x.replace("_adapt", "")).copy()
# Grouping by two columns makes the loop key a (ds_string, template_name) tuple,
# despite the variable being called `ds_string`.
for ds_string, ddf in df_a.groupby(["ds_string", "template_name"]):
    print(ds_string)
    qc_dsdf(ddf)
