# Prepare dataset

In [1]:
# autoreload your package
%load_ext autoreload
%autoreload 2

In [2]:
import warnings
from loguru import logger
from tqdm.auto import tqdm
# logger.remove()
# import sys
# logger.add(sys.stderr, level="INFO")

In [3]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
import pandas as pd
import json
from pathlib import Path

import lie_elicitation_prompts
from lie_elicitation_prompts.config import ExtractConfig
from lie_elicitation_prompts.helpers.scores import row_choice_ids
from lie_elicitation_prompts.prompts.prompt_loading import load_preproc_datasets, load_prompts

# Build the extraction config with no overrides; later cells read their settings
# (datasets, seed, num_shots, repeats, model, max_tokens, ...) from `cfg`.
cfg = ExtractConfig()
# Bare last expression so the notebook displays the config.
cfg

SyntaxError: invalid syntax (prompt_loading.py, line 310)

## Load text dataset

In [None]:

# # debug
# for ds_name in cfg.datasets:
#     print(ds_name)
#     o = load_prompts(ds_name, num_shots=1, N=2) 
#     o = list(tqdm(o))
#     # print(ds_name, o)
#     1/0
# pd.DataFrame(o)

In [None]:


# Ignore UserWarning category
# warnings.filterwarnings("ignore", category=UserWarning)
# Narrower alternative to the blanket filter above: only warnings whose message
# starts with this StratifiedShuffleSplit text are silenced.
warnings.filterwarnings("ignore", message="^The groups parameter is ignored by StratifiedShuffleSplit")


In [None]:


# Load the text prompts for every dataset named in cfg.datasets.
# The example cap (N), seed, shot count and repeat count (M) all come from cfg.
N = cfg.max_examples
ds_prompts = load_preproc_datasets(
    cfg.datasets,
    N=N,
    seed=cfg.seed,
    num_shots=cfg.num_shots,
    M=cfg.repeats,
)
# Bare last expression so the notebook shows the dataset summary.
ds_prompts

In [None]:
# ds_prompts_ood = load_preproc_datasets(
#     cfg.datasets_ood,
#     N=N,
#     seed=cfg.seed,
#     num_shots=cfg.num_shots,
# )
# ds_prompts_ood

In [None]:
ds_prompts[1]

In [None]:
# save

## Load tokenized dataset

- tokenize
- filter out truncated
- check which ones the model knows

In [None]:
# 4-bit NF4 quantization with double quantization; compute is done in bfloat16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

# Load the causal LM named in cfg.model, letting device_map="auto" place the weights.
model = AutoModelForCausalLM.from_pretrained(
    cfg.model,
    device_map="auto",
    quantization_config=quantization_config,
)

tokenizer = AutoTokenizer.from_pretrained(cfg.model)
# Some tokenizers ship without a pad token; fall back to EOS so padding works.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
# Pad and truncate on the left so the end of the prompt stays at the last position,
# which is where the evaluation loop reads the logits from (`[:, -1]`).
tokenizer.padding_side = "left"
tokenizer.truncation_side = "left"

In [None]:


# Pipeline: render chat template -> tokenize to a fixed length -> attach choice ids
# -> drop rows that filled the whole token budget.
ds_tokens = (
    # Step 1 (per row): render the messages to a single prompt string, ending with
    # the generation prompt.
    ds_prompts.map(
        lambda x: {
            "formatted_chat": tokenizer.apply_chat_template(
                x["messages"], tokenize=False, add_generation_prompt=True
            )
        }
    )
    # Step 2 (batched): tokenize, padding/truncating every row to cfg.max_tokens.
    .map(
        lambda x: tokenizer(
            x["formatted_chat"],
            return_tensors="pt",
            max_length=cfg.max_tokens,
            padding="max_length",
            truncation=True,
        ),
        batched=True,
    )
    # Step 3 (per row): compute the token ids of the answer choices via row_choice_ids.
    .map(lambda r: {"choice_ids": row_choice_ids(r, tokenizer)}, desc="choice_ids")
    # Step 4: keep only rows with at least one padding position. A row whose mask sums
    # to cfg.max_tokens has no padding and may therefore have been truncated.
    # NOTE(review): `.sum()` assumes attention_mask comes back as a tensor/array
    # (e.g. the dataset has a torch format set) rather than a plain list -- confirm.
    .filter(lambda x: x["attention_mask"].sum() < cfg.max_tokens)
)
ds_tokens

In [None]:

# def apply_prompt(messages):
#     o = tokenizer.apply_chat_template(
#                 messages, add_generation_prompt=True, return_tensors="pt",
#             max_length=cfg.max_tokens,
#             padding="max_length",
#             truncation=True,
#             return_dict=True,
#             )
#     return {k:v.squeeze() for k,v in o.items()}

# ds_tokens = (
#     ds_prompts.map(
#         lambda x: apply_prompt(x["messages"])
#     )
#     .map(lambda r: {"choice_ids": row_choice_ids(r, tokenizer)}, desc="choice_ids")
#     .filter(lambda x: x["attention_mask"].sum() < cfg.max_tokens)
# )
# ds_tokens

In [None]:
# Row counts before and after the length filter.
print(len(ds_prompts), len(ds_tokens))

# Per-dataset row counts before and after, to see which datasets lost rows to the filter.
pd.Series(ds_prompts['ds_string']).value_counts(), pd.Series(ds_tokens['ds_string']).value_counts()

In [None]:
# tokenizer.batch_decode(batch['input_ids'], skip_special_tokens=True)

# Eyeball the rendered chat prompts of the first four rows.
ds_tokens[:4]['formatted_chat']

### Check model knowledge

In [None]:
# Pull the identifying columns into pandas; `my_ds_index` keeps each row's position
# in ds_tokens so rows can be traced back to the dataset.
df_metadata = (
    ds_tokens
    .select_columns(['ds_string', 'example_i', 'sys_instr_name', 'instructed_to_lie'])
    .to_pandas()
    .reset_index(names='my_ds_index')
)

# Rows where the model is NOT instructed to lie.
df_metadata_truth = df_metadata.query('instructed_to_lie == False')

# FIXME right now there is just one example of each, I guess I want a couple, hmm
# Number of truthful rows per (dataset, question).
df_metadata_truth.groupby(['ds_string', 'example_i'], as_index=False).count()

In [None]:
# # get a single example of a truthful response to each question
# df_metadata = ds_tokens.select_columns(['ds_string', 'example_i', 'sys_instr_name', 'instructed_to_lie']).to_pandas().reset_index(names='my_ds_index')
# df_metadata_truth = df_metadata.query('instructed_to_lie == False').groupby(['ds_string', 'example_i'], as_index=False).first()
# df_metadata_truth

# ds_tokens_truthful = ds_tokens.select(df_metadata_truth.my_ds_index)
# ds_tokens_truthful

# Keep only the rows where the model is NOT instructed to lie.
# Build a plain list of integer row positions: `torch.argwhere` returns an (n, 1)
# tensor rather than flat indices, and `~column` only works when the column comes
# back as a tensor. Iterating and testing `not lie` works for tensors and plain
# Python values alike, and preserves the original row order.
truthful_idx = [i for i, lie in enumerate(ds_tokens['instructed_to_lie']) if not lie]
ds_tokens_truthful = ds_tokens.select(truthful_idx)
ds_tokens_truthful

In [None]:
# Call the project's memory-clearing helper before the forward passes below.
from lie_elicitation_prompts.helpers.torch_helpers import clear_mem
clear_mem()

In [None]:
# TODO: on some datasets accuracy is ~50%, i.e. no better than random guessing. We need to rephrase the same question in multiple ways to avoid this.
# Also check that example_i refers to the underlying question, not to the generated prompt.

In [None]:
from torch.utils.data import DataLoader
from lie_elicitation_prompts.helpers.scores import sum_select_choices_from_logits

batch_size = 4

# Keep only the columns the loop reads.
ds = ds_tokens_truthful.select_columns(['ds_string', 'example_i', 'label_true', 'input_ids', 'attention_mask', 'choice_ids'])
# Batch order does not affect the results: every result row carries its own
# ds_string / example_i.
dl = DataLoader(ds, batch_size=batch_size, shuffle=True)

model.eval()

# One dict per truthful prompt: which question it was, whether the model's answer
# matched the label, and the probability it assigned to the second choice.
results = []

for nb, batch in enumerate(tqdm(dl)):

    # to device
    inputs = {'input_ids': batch['input_ids'].to(model.device), 'attention_mask': batch['attention_mask'].to(model.device)}
    labels = batch['label_true']
    choice_ids = batch['choice_ids'].to(model.device)

    with torch.no_grad():
        out = model(**inputs)

        # Logits at the final sequence position only.
        logits_last = out['logits'][:, -1].detach().cpu()
        # NOTE(review): logits_last is on CPU while choice_ids is on model.device --
        # confirm sum_select_choices_from_logits accepts mixed devices.
        # Kept in local variables instead of being written back into the model's
        # output object.
        choice_llm_probs = sum_select_choices_from_logits(logits_last, choice_ids)
        # bool prob is the probability of the second choice, normalised over the choices
        # (the epsilon guards against a zero denominator)
        prob_bool = choice_llm_probs[:, 1] / (torch.sum(choice_llm_probs, 1) + 1e-12)
        corrects = labels == (prob_bool > 0.5)

    for batch_i, correct in enumerate(corrects):
        if batch_i == 0 and nb == 0:
            # One-off sanity print for the very first row.
            # print(i, correct, batch['formatted_prompt'][batch_i])
            # s = tokenizer.batch_decode(batch['input_ids'], skip_special_tokens=True)[batch_i]
            print(nb, batch_i, correct.item(), choice_llm_probs[:, 1])
        results.append({
            'ds_string': batch['ds_string'][batch_i],
            'example_i': batch['example_i'][batch_i].item(),
            'correct': correct.item(),
            'prob_bool': prob_bool[batch_i].item(),
        })

In [None]:
# Decode the prompts of `batch`, i.e. the last batch left over from the loop above.
tokenizer.batch_decode(batch['input_ids'], skip_special_tokens=True)

In [None]:
# work out which questions the model knows the answer to
df_res = pd.DataFrame(results)


# Per-dataset accuracy over all truthful prompts.
acc = df_res.groupby('ds_string').correct.mean()
print(f"Accuracy: {acc.to_dict()}")

# A question counts as known only when EVERY truthful prompt variant of it was
# answered correctly, not just one of them: a single lucky guess should not
# mark a question as known. With one variant per question this is identical to
# keeping the correct rows.
per_question = df_res.groupby(['ds_string', 'example_i'], as_index=False)['correct'].all()

# One row per known (dataset, question) pair.
df_known = per_question.loc[per_question['correct'], ['ds_string', 'example_i']].reset_index(drop=True)
df_known

In [None]:
def row_is_known(x, known=None):
    """Return True if the row's question is listed as known for its dataset.

    Args:
        x: a dataset row (mapping) with a 'ds_string' entry and an 'example_i'
            entry. 'example_i' may be a plain int or anything exposing `.item()`
            (e.g. a 0-d tensor or a numpy scalar).
        known: DataFrame with 'ds_string' and 'example_i' columns. Defaults to the
            notebook-level `df_known`, so `ds.filter(row_is_known)` keeps working.

    Returns:
        bool: whether (ds_string, example_i) appears in `known`.
    """
    if known is None:
        known = df_known
    example_i = x['example_i']
    # Unwrap tensor / numpy scalars; plain ints have no .item() and are used as-is.
    if hasattr(example_i, 'item'):
        example_i = example_i.item()
    # Restrict to this row's dataset first: example_i is only unique within a dataset.
    same_ds = known.loc[known['ds_string'] == x['ds_string'], 'example_i']
    return bool((same_ds == example_i).any())

# filter the dataset to known answers based on ds_string and example_i
# The filter runs over the full ds_tokens, not only the truthful subset.
ds_tokens_known = ds_tokens.filter(row_is_known)
# Row count before -> after the knowledge filter.
print(f"{len(ds_tokens)} -> {len(ds_tokens_known)}")
ds_tokens_known

In [None]:
# save
ts = pd.Timestamp.now().strftime('%Y%m%d-%H%M%S')
f = Path(f"../data/extracted_prompts_{ts}")
print(f)
ds_tokens_known.info.description = json.dumps(cfg.__dict__)
ds_tokens_known.save_to_disk(str(f))

## QC

In [None]:
# Is it correct, or is it random guessing?
# Per-dataset accuracy on the truthful prompts.
acc = df_res.groupby('ds_string').correct.mean()
print(f"Accuracy: {acc.to_dict()}")

In [None]:
# which source datasets did the known questions come from?
df_ds = ds_tokens_known.to_pandas()
# Row counts per (dataset, system-instruction) combination.
df_ds[['ds_string','sys_instr_name']].value_counts()

In [None]:
# Rebinds df_metadata with a different column set than the earlier cell
# (adds answer_choices / label_true, drops example_i and the my_ds_index column).
df_metadata = ds_tokens.select_columns(['ds_string', 'sys_instr_name', 'answer_choices', 'label_true', 'instructed_to_lie']).to_pandas()

In [None]:
# Inspect one row: its metadata next to the decoded prompt text.
i = 1
print(df_metadata.iloc[i])
# print(ds_tokens['formatted_chat'][i])
print(tokenizer.decode(ds_tokens['input_ids'][i], skip_special_tokens=True))

In [None]:

def apply_prompt(messages):
    """Render `messages` with the chat template and tokenize them in a single call.

    The output is padded/truncated to cfg.max_tokens and ends with the generation
    prompt. Uses the notebook-level `tokenizer` and `cfg`.

    Args:
        messages: the chat messages of one example.

    Returns:
        dict mapping each key of the tokenizer output to its tensor with all
        size-1 dimensions squeezed out.
    """
    encoded = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        max_length=cfg.max_tokens,
        padding="max_length",
        truncation=True,
        return_dict=True,
    )
    squeezed = {}
    for key, value in encoded.items():
        squeezed[key] = value.squeeze()
    return squeezed

# Alternative tokenization path: template + tokenize in one step via apply_prompt.
# NOTE: this rebinds `ds_tokens`, replacing the dataset built earlier in the notebook;
# cells run after this one see the new version.
ds_tokens = (
    ds_prompts.map(
        lambda x: apply_prompt(x["messages"])
    )
    # Attach the token ids of the answer choices via row_choice_ids.
    .map(lambda r: {"choice_ids": row_choice_ids(r, tokenizer)}, desc="choice_ids")
    # Keep only rows with at least one padding position (mask sum below cfg.max_tokens).
    # NOTE(review): `.sum()` assumes attention_mask is a tensor/array here -- confirm.
    .filter(lambda x: x["attention_mask"].sum() < cfg.max_tokens)
)
ds_tokens

In [None]:
# Same inspection as before, now decoding from the re-tokenized ds_tokens.
# NOTE(review): df_metadata was built from the earlier ds_tokens; row i only lines up
# if both versions kept the same rows in the same order -- confirm.
i = 1
print(df_metadata.iloc[i])
# print(ds_tokens['formatted_chat'][i])
print(tokenizer.decode(ds_tokens['input_ids'][i], skip_special_tokens=True))