Experiment to use lora to make a lying model. Here we think of Lora as a probe, as it acts in a very similar way - modifying the residual stream.

Then the hope is it will assist at lie detecting and generalize to unseen dataset

- https://github.dev/JD-P/minihf/blob/b54075c34ef88d9550e37fdf709e78e5a68787c4/lora_tune.py
- https://github.com/jonkrohn/NLP-with-LLMs


This notebook tried without pytorch lightning

In [None]:
import os
# os.environ["CUDA_VISIBLE_DEVICES"]= "1"


In [None]:
import os
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from tqdm.auto import tqdm
import datasets

plt.style.use("ggplot")

from typing import Optional, List, Dict, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch import optim
from torch.utils.data import random_split, DataLoader, TensorDataset

from pathlib import Path
import transformers
from peft import get_peft_config, get_peft_model, LoraConfig, TaskType, LoftQConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
from datasets import Dataset

from loguru import logger

logger.add(os.sys.stderr, format="{time} {level} {message}", level="INFO")

# quiet please
torch.set_float32_matmul_precision("medium")
import warnings

warnings.filterwarnings("ignore", ".*does not have many workers.*")
warnings.filterwarnings(
    "ignore", ".*sampler has shuffling enabled, it is strongly recommended that.*"
)
warnings.filterwarnings("ignore", ".*has been removed as a dependency of.*")


In [None]:
# load my code
%load_ext autoreload
%autoreload 2

from src.config import ExtractConfig
from src.prompts.prompt_loading import load_preproc_dataset
from src.models.load import load_model
# from src.prompts.prompt_loading import load_prompt_structure


In [None]:
# params
max_epochs = 100
device = "cuda:0"



In [17]:
# params
max_epochs = 1
device = "cuda:0"
checkpoint_path = "../notebooks/lightning_logs/version_45/final"

cfg = ExtractConfig(
    batch_size=1,
    max_examples=(200, 100),
    intervention_fit_examples=60,
)


In [32]:
model, tokenizer = load_model(
    cfg.model,
    device=device, adaptor_path=checkpoint_path
)


'(ReadTimeoutError("HTTPSConnectionPool(host='huggingface.co', port=443): Read timed out. (read timeout=10)"), '(Request ID: 20685723-7a2f-4d00-a8f1-2f19f3e18eec)')' thrown while requesting HEAD https://huggingface.co/wassname/phi-2-w_hidden_states/resolve/main/config.json
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.49it/s]


In [33]:
model.disable_adapter()


<contextlib._GeneratorContextManager at 0x7f51d9209e10>

In [13]:
N = sum(cfg.max_examples)
ds_name = "amazon_polarity"
ds_tokens = load_preproc_dataset(
    ds_name,
    tokenizer,
    N=N,
    seed=cfg.seed,
    num_shots=cfg.num_shots,
    max_length=cfg.max_length,
    prompt_format=cfg.prompt_format,
).with_format("torch")


prompt_truncated: 100%|██████████| 902/902 [00:01<00:00, 487.56 examples/s]
choice_ids: 100%|██████████| 902/902 [00:00<00:00, 9485.47 examples/s]
[32m2023-12-19 06:40:44.915[0m | [1mINFO    [0m | [36msrc.prompts.prompt_loading[0m:[36mload_preproc_dataset[0m:[36m364[0m - [1mmedian token length: 303.5 for amazon_polarity. max_length=777[0m
2023-12-19T06:40:44.915287+0800 INFO median token length: 303.5 for amazon_polarity. max_length=777
2023-12-19T06:40:44.915287+0800 INFO median token length: 303.5 for amazon_polarity. max_length=777
[32m2023-12-19 06:40:44.916[0m | [1mINFO    [0m | [36msrc.prompts.prompt_loading[0m:[36mload_preproc_dataset[0m:[36m368[0m - [1mtruncation rate: 0.00% on amazon_polarity[0m
2023-12-19T06:40:44.916393+0800 INFO truncation rate: 0.00% on amazon_polarity
2023-12-19T06:40:44.916393+0800 INFO truncation rate: 0.00% on amazon_polarity
Filter: 100%|██████████| 902/902 [00:00<00:00, 2919.86 examples/s]
Filter: 100%|██████████| 902/902 [00:

## Lora train

In [15]:
N = len(ds_tokens)
ds_train = ds_tokens.select(range(N//2))
ds_val = ds_tokens.select(range(N//2, N))


In [31]:
from peft import PeftModel
model = peft.PeftModel.from_pretrained(model, evaluator_adapter_name, "evaluator")
model = PeftModel.from_pretrained(model, checkpoint_path)
# model.add_adapter(peft_config)
model.print_trainable_parameters()
# model


NameError: name 'peft_config' is not defined

In [19]:
# get a row
bi = cfg.intervention_fit_examples + 2
inputs = ds_tokens.with_format("torch")[bi]

# tokenize if needed
if "input_ids" not in inputs:
    model_inputs = self.tokenizer(
        inputs["question"],
        return_tensors=True,
        return_attention_mask=True,
        add_special_tokens=True,
        truncation=True,
        padding="max_length",
        max_length=cfg.max_length,
        **tokenize_kwargs,
    )
    model_inputs = model_inputs.to(device)
    inputs = {**inputs, **model_inputs}

inputs.keys()


dict_keys(['ds_string', 'example_i', 'answer', 'messages', 'answer_choices', 'template_name', 'label_true', 'label_instructed', 'instructed_to_lie', 'sys_instr_name', 'question', 'input_ids', 'attention_mask', 'truncated', 'length', 'prompt_truncated', 'choice_ids'])

In [26]:
from IPython.display import display, HTML

# generate
# https://huggingface.co/docs/transformers/v4.34.1/en/main_classes/text_generation#transformers.GenerationConfig


@torch.no_grad()
def gen(model):
    s = model.generate(
        input_ids=inputs["input_ids"][None, :].to(model.device),
        attention_mask=inputs["attention_mask"][None, :]
        .to(model.device)
        .to(model.dtype),
        use_cache=False,
        max_new_tokens=20,
        min_new_tokens=20,
        do_sample=False,
        early_stopping=False,
    )
    input_l = inputs["input_ids"].shape[0]
    old = tokenizer.decode(
        s[0, :input_l], clean_up_tokenization_spaces=False, skip_special_tokens=False
    )
    new = tokenizer.decode(
        s[0, input_l:], clean_up_tokenization_spaces=False, skip_special_tokens=False
    )
    display(HTML(f"<pre>{old}</pre><b><pre>{new}</pre></b>"))


In [41]:
with model.disable_adapter():
    gen(model)




In [42]:
gen(model)
