# A scratch pad to run model inference manually

Prioritise small experiments in notebooks
- take all the recorded hidden states and separate them into truth and deception
- try a normal intervention and test it
- try novel sgb bias intervention

In [None]:
import os
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

plt.style.use("ggplot")

from typing import Optional, List, Dict, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch import optim
from torch.utils.data import random_split, DataLoader, TensorDataset

from pathlib import Path
import transformers


from loguru import logger

import sys  # imported here because the notebook's import list only provides `os`

# loguru ships with a default stderr sink. Remove the existing sinks first so each
# message is printed once, and so re-running this cell does not stack up
# additional sinks that duplicate every log line.
logger.remove()
logger.add(sys.stderr, format="{time} {level} {message}", level="INFO")


In [None]:
# load my code
%load_ext autoreload
%autoreload 2


from src.extraction.config import ExtractConfig
from src.prompts.prompt_loading import load_preproc_dataset
from src.models.load import load_model
from src.datasets.intervene import create_cache_interventions
from src.prompts.prompt_loading import load_prompt_structure
from src.repe import repe_pipeline_registry

# Run the project's registration hook imported from src.repe above.
# NOTE(review): presumably this registers the repe pipelines with transformers
# as a side effect — confirm against src/repe.
repe_pipeline_registry()


In [None]:
# # config transformers
# from datasets import set_caching_enabled, disable_caching
# disable_caching()
# os.environ["TOKENIZERS_PARALLELISM"] = "false"


In [None]:
# List every entry under ../.ds/ whose name contains the model identifier.
model_name = "phi-2-GPTQ_w_hidden_states"
ds_root = Path("../.ds/")
matching_paths = sorted(ds_root.glob(f"*{model_name}*"))
list(map(str, matching_paths))


In [None]:
from datasets import load_from_disk, concatenate_datasets
from src.datasets.load import ds2df, load_ds, get_ds_name
from src.datasets.load import ds2df, load_ds, get_ds_name, filter_ds_to_known


In [None]:

# Saved datasets to load, one "test" split per source dataset.
# The matching "train" splits are kept below, commented out, so they can be
# switched back in.
fs = [
    "../.ds/wassname_phi-2-GPTQ_w_hidden_states_amazon_polarity_test_600",
    # "../.ds/wassname_phi-2-GPTQ_w_hidden_states_amazon_polarity_train_3600",
    "../.ds/wassname_phi-2-GPTQ_w_hidden_states_glue_qnli_test_600",
    # "../.ds/wassname_phi-2-GPTQ_w_hidden_states_glue_qnli_train_3600",
    "../.ds/wassname_phi-2-GPTQ_w_hidden_states_imdb_test_600",
    # "../.ds/wassname_phi-2-GPTQ_w_hidden_states_imdb_train_3600",
    "../.ds/wassname_phi-2-GPTQ_w_hidden_states_super_glue_boolq_test_600",
    # "../.ds/wassname_phi-2-GPTQ_w_hidden_states_super_glue_boolq_train_3600",
]

# Load each path with the project loader, keeping the same order as `fs`.
dss = list(map(load_ds, fs))
dss


In [None]:
from src.datasets.load import ds2df, load_ds, get_ds_name, filter_ds_to_known, qc_ds
# Run the project's QC check on every loaded dataset individually.
# The loop variable is deliberately not named `ds`: `ds` is the notebook-level
# name for the combined dataset, and re-running this cell after the combine cell
# would otherwise silently rebind it to the last individual dataset.
for ds_single in dss:
    qc_ds(ds_single)
    


In [19]:
# Filter each dataset with filter_ds_to_known, then merge the filtered
# datasets into one dataset exposed in numpy format.
dss_known = list(map(filter_ds_to_known, dss))
ds = concatenate_datasets(dss_known).with_format("numpy")
ds


select rows are 93.00% based on knowledge
select rows are 61.33% based on knowledge
select rows are 84.14% based on knowledge
select rows are 82.00% based on knowledge


Dataset({
    features: ['end_hidden_states', 'end_logits', 'choice_probs', 'label_true', 'instructed_to_lie', 'question', 'answer_choices', 'choice_ids', 'template_name', 'sys_instr_name', 'example_i', 'input_truncated', 'truncated', 'text_ans', 'ans'],
    num_rows: 1870
})

In [20]:
# QC: the filtering must not have removed every successful lie, otherwise the
# problem becomes trivial.
df2 = ds2df(ds)

# Rows where the model was instructed to lie and its answer matches the instructed label.
successful_lie_query = "instructed_to_lie==True & ((llm_ans==1)==label_instructed)"
df_subset_successull_lies = df2.query(successful_lie_query)

print(
    f"after filtering we have {len(df_subset_successull_lies)} num successful lies out of {len(df2)} dataset rows"
)
assert len(df_subset_successull_lies) > 0, "there should be successful lies in the dataset"


after filtering we have 137 num successful lies out of 1870 dataset rows


## Load model

In [None]:
# Dataset to tokenize and the extraction settings used for this run.
ds_name = "amazon_polarity"
cfg = ExtractConfig(max_examples=(400, 400), intervention_fit_examples=160)
print(cfg)
batch_size = cfg.batch_size

# Load the model and tokenizer named by the config.
model, tokenizer = load_model(
    cfg.model,
    pad_token_id=cfg.pad_token_id,
    disable_exllama=False,
)
print(model)

# max_examples unpacks into two counts; their total is the number of examples
# requested from the preprocessing loader.
N_train, N_test = cfg.max_examples
N = N_train + N_test

preproc_kwargs = dict(
    N=N,
    seed=cfg.seed,
    num_shots=cfg.num_shots,
    max_length=cfg.max_length,
    prompt_format=cfg.prompt_format,
)
ds_tokens = load_preproc_dataset(ds_name, tokenizer, **preproc_kwargs)


## Intervention