Experiment: use LoRA to make a model lie. Here we think of LoRA as a kind of probe, since it acts in a very similar way — by modifying the residual stream.

The hope is that it will then help with lie detection and generalize to unseen datasets.

- https://github.dev/JD-P/minihf/blob/b54075c34ef88d9550e37fdf709e78e5a68787c4/lora_tune.py
- https://github.com/jonkrohn/NLP-with-LLMs

In [None]:
import os
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from tqdm.auto import tqdm

plt.style.use("ggplot")

from typing import Optional, List, Dict, Union
from jaxtyping import Float
from torch import Tensor

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import Tensor
from torch import optim
from torch.utils.data import random_split, DataLoader, TensorDataset

from pathlib import Path
from einops import rearrange

import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, AutoConfig
from peft import get_peft_config, get_peft_model, LoraConfig, TaskType, LoftQConfig, IA3Config

import datasets
from datasets import Dataset

from loguru import logger

import sys
import warnings

# loguru ships with a default stderr sink. Remove it before adding ours;
# otherwise every INFO+ record is written to stderr twice, and re-running this
# cell would stack yet another sink on top.
logger.remove()
logger.add(sys.stderr, format="{time} {level} {message}", level="INFO")

# quiet please
# Let float32 matmuls use faster reduced-precision internals.
torch.set_float32_matmul_precision("medium")
# Silence the dataloader "does not have many workers" warning.
warnings.filterwarnings("ignore", ".*does not have many workers.*")


In [None]:
# load my code
%load_ext autoreload
%autoreload 2

import lightning.pytorch as pl
from src.datasets.dm import DeceptionDataModule
from src.models.pl_lora_ft import AtapterFinetuner

from src.config import ExtractConfig
from src.prompts.prompt_loading import load_preproc_dataset, load_preproc_datasets
from src.models.load import load_model
from src.helpers.torch import clear_mem
from src.models.phi.model_phi import PhiForCausalLMWHS


## Parameters


In [None]:
# --- run parameters ---
device = "cuda:0"
max_epochs = 3

# Other model choices seen in this notebook:
#   model="wassname/phi-1_5-w_hidden_states" (with batch_size=3)
#   model="wassname/phi-2-w_hidden_states"
cfg = ExtractConfig(
    model="Walmart-the-bag/phi-2-uncensored",
    prompt_format="phi",
    batch_size=1,
    max_examples=(300, 100),
)


## Load model

In [None]:
# Load the base model (a Phi variant; the class name suggests it exposes hidden
# states) together with its tokenizer.
load_kwargs = dict(device=device, model_class=PhiForCausalLMWHS)
model, tokenizer = load_model(cfg.model, **load_kwargs)


In [None]:
# Adapt only the projections that write straight into the residual stream
# (attention output projection and the MLP's fc2), keeping the intervention small.
# TODO: ideally train only biases (e.g. bias="lora_only"); for now a tiny LoRA.
lora_targets = ["out_proj", "mlp.fc2"]
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    target_modules=lora_targets,
    inference_mode=False,
    r=4,
    lora_alpha=4,  # lora_alpha / r == 1, so the LoRA update is not rescaled
    lora_dropout=0.0,
)
# Alternative intervention: an IA3Config over the same two modules.

model = get_peft_model(model, peft_config)
model.print_trainable_parameters()


## Load datasets

In [None]:
# Total number of examples to draw across the in-distribution datasets.
N = sum(cfg.max_examples)
preproc_kwargs = dict(
    N=N,
    seed=cfg.seed,
    num_shots=cfg.num_shots,
    max_length=cfg.max_length,
    prompt_format=cfg.prompt_format,
)
ds_tokens = load_preproc_datasets(cfg.datasets, tokenizer, **preproc_kwargs)
# (Alternative: load each dataset with load_preproc_dataset and combine them
#  with datasets.interleave_datasets.)
ds_tokens


In [None]:
# Out-of-sample datasets: half the in-distribution example budget for each one,
# converted to torch format and concatenated into a single dataset.
N = sum(cfg.max_examples) // 2
datasets2 = [
    load_preproc_dataset(
        oos_name,
        tokenizer,
        N=N,
        seed=cfg.seed,
        num_shots=cfg.num_shots,
        max_length=cfg.max_length,
        prompt_format=cfg.prompt_format,
    ).with_format("torch")
    for oos_name in cfg.datasets_oos
]
ds_tokens2 = datasets.concatenate_datasets(datasets2)
ds_tokens2


## Train

In [None]:
# Lightning data module over the in-distribution token dataset.
dm = DeceptionDataModule(
    ds_tokens,
    batch_size=cfg.batch_size,
)
dm


In [None]:
# Training and validation loaders from the data module (built in this order).
dl_train, dl_val = dm.train_dataloader(), dm.val_dataloader()


In [None]:
# Peek at one training batch: its keys and the shape of input_ids.
b = next(iter(dl_train))
batch_input_ids = b["input_ids"]
print(b.keys(), batch_input_ids.shape)
c_in = batch_input_ids.shape[1]  # sequence length of this batch
c_in


## Custom models

In [None]:
from src.models.pl_lora_ft import AtapterFinetuner, select_choices

class AtapterFinetunerLie(AtapterFinetuner):
    """Adapter fine-tuner whose training objective is to make the model lie.

    Overrides ``get_loss`` so that, at the last token position, the adapted
    model is trained to put probability on the answer choice that is *not*
    the true one.
    """

    def get_loss(self, batch, out, out_a):
        """Simply train it to lie: NLL of the false answer choice.

        Args:
            batch: dataloader batch dict. Reads ``label_true`` (one label per
                example, bool or 0/1) and ``choice_ids``, indexed here as
                ``choice_ids[example, choice, variant]``. Assumes a binary task
                in which choice ``label_true`` is the true answer — TODO confirm
                against the dataset preprocessing.
            out: base-model outputs; unused here, kept for the parent interface.
            out_a: adapter-model outputs; only ``out_a["logits"]`` at the last
                position is used.

        Returns:
            ``(loss, None, None)``, matching the tuple the original returned.
        """
        # Log-probabilities over the vocabulary at the final token position.
        log_probs_a = torch.log_softmax(out_a["logits"][:, -1], dim=-1)

        # Per-example index of the lie (the choice the true label does not
        # point to). Cast to long so bool and 0/1 integer labels behave alike.
        lie_idx = 1 - batch["label_true"].long()

        # Gather one lie choice per row. Indexing the choice axis with the
        # whole label vector (``choice_ids[:, labels, k]``) would mix rows
        # across the batch instead of pairing each row with its own label.
        rows = torch.arange(lie_idx.shape[0], device=lie_idx.device)
        lie_choice_ids = batch["choice_ids"][rows, lie_idx]  # (batch, n_variants)

        # Average the NLL over the first two token variants of the lie choice.
        loss1 = F.nll_loss(log_probs_a, target=lie_choice_ids[:, 0])
        loss2 = F.nll_loss(log_probs_a, target=lie_choice_ids[:, 1])
        loss = (loss1 + loss2) / 2

        return loss, None, None


In [None]:
# total_steps = number of dataloader batches across all epochs.
# NOTE(review): the Trainer below uses accumulate_grad_batches=4, so the number
# of optimizer steps is roughly a quarter of this. If AtapterFinetuner drives a
# per-step LR schedule from total_steps, confirm which count it expects.
steps_total = len(dl_train) * max_epochs
net = AtapterFinetunerLie(
    model,
    tokenizer,
    lr=5e-3,
    weight_decay=1e-5,
    total_steps=steps_total,
)

print(c_in)


In [None]:
# # debug
# with torch.no_grad():
#     o = net.training_step(b, None)
# o


In [None]:
# # debug
# with torch.no_grad():
#     o = net.predict_step(b, None)
# o.keys()


In [None]:
# Create the Trainer early so Lightning initializes the accelerator up front.
# Precision notes: "16-true" seemed to give inf loss, "16-mixed" worked,
# "bf16-mixed" is another option; the default precision is used here.
trainer1 = pl.Trainer(
    accelerator="gpu",
    devices="1",
    max_epochs=max_epochs,
    accumulate_grad_batches=4,
    gradient_clip_val=20,
    log_every_n_steps=1,
    enable_model_summary=False,
)


In [None]:
# Train the adapter; the trailing semicolon suppresses the notebook's echo.
trainer1.fit(
    model=net,
    train_dataloaders=dl_train,
    val_dataloaders=dl_val,
);


In [None]:
# Save the PEFT adapter under this run's log directory.
log_dir = Path(trainer1.log_dir)
checkpoint_path = log_dir / "final"
model.save_pretrained(checkpoint_path)
checkpoint_path


In [None]:
from src.helpers.lightning import read_metrics_csv

# Plot per-step train/val loss from the logger's metrics CSV.
metrics_csv = trainer1.logger.experiment.metrics_file_path
_, df_hist = read_metrics_csv(metrics_csv)
loss_columns = ["train/loss_step", "val/loss_step"]
df_hist[loss_columns].plot(style=".")
df_hist


## Generate


In [None]:
# Reload the base model with the trained adapter from checkpoint_path, in fp16
# (bfloat can't be pickled).
# NOTE(review): the previous PEFT model is still referenced by `net` and
# `trainer1`, so clear_mem() below may not be able to release it.
model, tokenizer = load_model(
    cfg.model,
    device=device,
    adaptor_path=checkpoint_path,
    dtype=torch.float16,
    model_class=PhiForCausalLMWHS,
)
clear_mem()


In [None]:
from src.eval.gen import gen

# Take one example (row `bi`) of the in-distribution dataset as torch tensors.
bi = 4
inputs = ds_tokens.with_format("torch")[bi]


In [None]:
# Generate for the same example twice: first with the adapter disabled (base
# model behaviour), then with the adapter active.
with model.disable_adapter():
    gen(model, inputs, tokenizer)

# Adapter active; as the cell's last expression, its return value is displayed.
gen(model, inputs, tokenizer)


# Test

In [None]:
from src.eval.helpers import test_intervention_quality2
from src.eval.labels import ds2label_model_obey, ds2label_model_truth


In [None]:
# Rebuild the data module with double the batch size for evaluation.
dm = DeceptionDataModule(ds_tokens, batch_size=2 * cfg.batch_size)
dl_train2, dl_val2, dl_test2 = (
    dm.train_dataloader(),
    dm.val_dataloader(),
    dm.test_dataloader(),
)


In [None]:
# Out-of-sample loader: fixed order, and the final partial batch is kept.
dl_oos2 = DataLoader(
    ds_tokens2,
    batch_size=2 * cfg.batch_size,
    shuffle=False,
    drop_last=False,
)
len(ds_tokens2)


In [None]:
# rs = trainer1.test(
#     net,
#     dataloaders=[
#         # dl_train2, dl_val2,
#         dl_test2,
#         dl_oos2,
#     ],
# )
# rs = rename(rs, ["train", "val", "test", "oos"])
# rs[0]


# Predict

Here we want to see whether a probe on the hidden states can tell when the model is lying.


### Collect

- compare each model's accuracy against the instructed labels vs. the true labels
- see how well a linear probe trained on the hidden-state difference predicts truth, compared with a baseline

In [None]:
# Load the base model plus trained adapter again, so this section can be run
# on its own (e.g. after a kernel restart). fp16 because bfloat can't be pickled.
model, tokenizer = load_model(
    cfg.model,
    device=device,
    adaptor_path=checkpoint_path,
    dtype=torch.float16,
    model_class=PhiForCausalLMWHS,
)
clear_mem()


In [None]:
from src.eval.collect import manual_collect2
# Run the adapter-loaded model over the out-of-sample loader. ds_out is the
# dataset of collected outputs used by the eval cells below; `f` is returned
# too but not used in the visible cells.
ds_out, f = manual_collect2(dl_oos2, model, dataset_name="oos")


### Eval

In [None]:
# TODO limit it to ones where it knows
# Evaluate once per labelling scheme: model truthfulness vs. instruction-following.
label_fns = {
    "label_model_truth": ds2label_model_truth,
    "label_model_obey": ds2label_model_obey,
}
for label_name, label_fn in label_fns.items():
    print("=" * 80)
    print("making intervention with", label_name, "hidden states")
    # presumably fits/scores a probe on ds_out using these labels
    test_intervention_quality2(ds_out, label_fn, tokenizer)


### Check dataset of outputs

In [None]:
from src.eval.ds import qc_ds, ds2df, qc_dsdf


In [None]:
# Convert the collected outputs to a DataFrame for inspection.
df = ds2df(ds_out)
df


In [None]:
# TODO: report one for the base model and one for the adapter.
# TODO: are acc and lie_acc the same? If so, the model is ignoring the examples
# and system instructions... maybe an instruction-tuned model is needed?
qc_ds(ds_out)


In [None]:
print('acc by dataset and template name')
df1 = ds2df(ds_out)
# Columns come in "_base" / "_adapt" suffixed pairs. Stripping one suffix makes
# that model's columns the unsuffixed ones, which presumably are what qc_dsdf
# reads. Report both models; previously the adapter view was built but unused.
df_b = df1.rename(columns=lambda col: col.replace('_base', ''))
df_a = df1.rename(columns=lambda col: col.replace('_adapt', ''))
for model_name, df_model in [('base', df_b), ('adapter', df_a)]:
    print(f'--- {model_name} ---')
    for (ds_string, template_name), ddf in df_model.groupby(['ds_string', 'template_name']):
        print(ds_string, template_name)
        qc_dsdf(ddf)
    



In [None]:
print('acc by dataset and sys_instr_name')
df = ds2df(ds_out.with_format('numpy'))
# Use the base-model columns (drop the "_base" suffix).
df = df.rename(columns=lambda col: col.replace('_base', ''))
df['ans'] = df['binary_ans'] > 0.5
# Label the model should output if it follows its instruction: the truth label,
# flipped when it was instructed to lie.
df['label_instructed'] = df['label_true'] ^ df['instructed_to_lie']
for group_key, group_df in df.groupby(['ds_string', 'sys_instr_name']):
    print(group_key)
    qc_dsdf(group_df)
