In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from IPython.display import display, HTML, Markdown
from torch.utils.data import random_split, DataLoader, TensorDataset
from loguru import logger
import sys
import pandas as pd
logger.remove()
logger.add(sys.stderr, format="{time} {level} {message}", level="INFO")

# load my code
%load_ext autoreload
%autoreload 2
from src.eval.collect import manual_collect2
from src.eval.ds import ds2df, qc_ds, qc_dsdf
from src.prompts.prompt_loading import load_prompts, format_prompt, load_preproc_dataset
from src.models.load import load_model



So here we try new datasets, we want them to have the charecteristics

- in the ELK code https://github.com/EleutherAI/elk/tree/1b60b3bff348b00356cd15b5eb017f9c9bfdbae1/elk/promptsource/templates
- acheivable! e.g. acc >70%
- boolean
- not to long (IMBD often becomes over 777 chars) which gets truncated and fills up my mem
- diverse

In [2]:
N = 300
model_id = "Walmart-the-bag/phi-2-uncensored"
print(model_id)
# load model
model, tokenizer = load_model(model_id, dtype=torch.float16)


Walmart-the-bag/phi-2-uncensored


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [3]:

ds_names = [
    'EleutherAI/truthful_qa_binary',
    "amazon_polarity",
    'imdb',
    
    'super_glue:boolq', # yes
    'glue:qnli', # yes
    'super_glue:axg', # yes
    'super_glue:rte',  # yes
    # 'super_glue:cb', # xpected label to be 0 or 1, got 2
    'hans', # yes
    'sst2', # yes
    # 'snli', # Expected label to be 0 or 1, got 2
    # 'xnli:en', # Expected label to be 0 or 1, got 2
    # 'glue:mnli', # Expected label to be 0 or 1, got 2
    # 'glue:mnli_mismatched', # Expected label to be 0 or 1, got 2


    # 'super_glue:copa', # yes
    

    # 'wiqa',  # no label columns
    # "hotpot_qa:fullwiki", # no label columns
    # "hotpot_qa:distractor", # No multiple choice templates found
    # 'biosses', # Dataset has no label column
    # 'commonsense_qa',  #ValueError: Dataset has no label column 
    'sms_spam', # yes
    'wiki_qa', # yes
    # 'math_qa', # no label
    # 'adversarial_qa:droberta', # No multiple choice templates found
    # 'math_dataset:algebra__linear_1d', # o multiple choice templates found
    # 'emotion', #  Expected label to be 0 or 1, got 4
    # 'emo', #  Expected label to be 0 or 1, got 3
    # 'lauritowal/redefine_math', # Dataset has no label column
    # 'boolq_pt', # cannot find
    # 'liar', #  Expected label to be 0 or 1, got 2
    # 'ag_news', #  Expected label to be 0 or 1, got 2
    # 'movie_rationales', # long?
    # 'hate_speech18', # label is dict
]

data = {}
for ds_name in ds_names:
    print('=='*80)
    print(ds_name)
    # load dataset
    try:
        ds = load_preproc_dataset(ds_name, tokenizer, N, prompt_format='phi').with_format("torch")
        ds
    except KeyboardInterrupt:
        raise
    except Exception as e:
        logger.exception(f"Failed to load {ds_name} `{e}`")
        raise
        continue


    # eval
    dl = DataLoader(ds, batch_size=8, shuffle=False, num_workers=0)
    ds_out, f = manual_collect2(dl, model, get_residual=False)
    print(f'for {model_id}:')

    res = qc_ds(ds_out)
    res['truncation']=(ds['truncated']*1.0).mean()
    data[ds_name] = res

    print(ds_name)
    # /



EleutherAI/truthful_qa_binary


Generating train split: 0 examples [00:00, ? examples/s]

Downloading builder script:   0%|          | 0.00/4.89k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/6.54k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/105k [00:00<?, ?B/s]

Generating validation split:   0%|          | 0/817 [00:00<?, ? examples/s]

2023-12-28T08:08:23.445885+0800 ERROR Failed to load EleutherAI/truthful_qa_binary `An error occurred while generating the dataset`
[33m[1mTraceback (most recent call last):[0m

  File "/media/wassname/SGIronWolf/projects5/elk/sgd_probes_are_lie_detectors/.venv/lib/python3.11/site-packages/datasets/builder.py", line 1676, in _prepare_split_single
    for key, record in generator:
                       └ <generator object Generator._generate_examples at 0x7f5c4206bb40>
  File "/media/wassname/SGIronWolf/projects5/elk/sgd_probes_are_lie_detectors/.venv/lib/python3.11/site-packages/datasets/packaged_modules/generator/generator.py", line 30, in _generate_examples
    for idx, ex in enumerate(self.config.generator(**gen_kwargs)):
                             │    │      │           └ {'ds_string': 'EleutherAI/truthful_qa_binary', 'num_shots': 1, 'split_type': 'train', 'seed': 42, 'N': 900}
                             │    │      └ <function load_prompts at 0x7f5c42a672e0>
             

DatasetGenerationError: An error occurred while generating the dataset

In [None]:
df = pd.DataFrame(data).T.sort_values('auroc')
display(Markdown(df.to_markdown()))
print(df.to_markdown())


|                  |   balance |   N |    auroc |   lie_auroc |   known_lie_auroc |   choice_cov |   truncation |
|:-----------------|----------:|----:|---------:|------------:|------------------:|-------------:|-------------:|
| sms_spam         |  0.486667 | 300 | 0.499377 |    0.531934 |         0.296474  |     0.973791 |            0 |
| wiki_qa          |  0.486667 | 300 | 0.537004 |    0.42528  |         0.3       |     0.84084  |            0 |
| glue:qnli        |  0.486667 | 300 | 0.616972 |    0.409714 |         0.251208  |     0.729142 |            0 |
| hans             |  0.486667 | 300 | 0.649173 |    0.396282 |         0.209524  |     0.99918  |            0 |
| sst2             |  0.486667 | 300 | 0.686622 |    0.351628 |         0.266447  |     0.81738  |            0 |
| super_glue:axg   |  0.486667 | 300 | 0.703611 |    0.310087 |         0.0772059 |     0.999047 |            0 |
| super_glue:boolq |  0.486667 | 300 | 0.734211 |    0.271927 |         0.0744207 |     0.98014  |            0 |
| super_glue:rte   |  0.486667 | 300 | 0.770681 |    0.268991 |         0.0930556 |     0.997921 |            0 |
| imdb             |  0.51     | 300 | 0.781798 |    0.271429 |         0.16313   |     0.886125 |            0 |
| amazon_polarity  |  0.486667 | 300 | 0.835527 |    0.22327  |         0.169878  |     0.963495 |            0 |

|                  |   balance |   N |    auroc |   lie_auroc |   known_lie_auroc |   choice_cov |   truncation |
|:-----------------|----------:|----:|---------:|------------:|------------------:|-------------:|-------------:|
| sms_spam         |  0.486667 | 300 | 0.499377 |    0.531934 |         0.296474  |     0.973791 |            0 |
| wiki_qa          |  0.486667 | 300 | 0.537004 |    0.42528  |         0.3       |     0.84084  |            0 |
| glue:qnli        |  0.486667 | 300 | 0.616972 |    0.409714 |         0.251208  |     0.729142 |            0 |
| hans             |  0.486667 | 300 | 0.649173 |    0.396282 |         0.209524  |     0.99918  |            0 |
| sst2             |  0.486667 | 300 | 0.686622 |    0.351628 |         0.266447  |     0.81738  |            0 |
| super_glue:axg   |  0.486667 | 300 | 0.703611 |    0.310087 |         0.0772059 |     0.999047 |            0 |
| super_glue:boolq |  0.486667 | 300 | 0.734211 |    0.271927 |         0.0744207 |     

|                  |   balance |   N |    auroc |   lie_auroc |   known_lie_auroc |   choice_cov |   truncation |
|:-----------------|----------:|----:|---------:|------------:|------------------:|-------------:|-------------:|
| sms_spam         |  0.486667 | 300 | 0.499377 |    0.531934 |         0.296474  |     0.973791 |            0 |
| wiki_qa          |  0.486667 | 300 | 0.537004 |    0.42528  |         0.3       |     0.84084  |            0 |
| glue:qnli        |  0.486667 | 300 | 0.616972 |    0.409714 |         0.251208  |     0.729142 |            0 |
| hans             |  0.486667 | 300 | 0.649173 |    0.396282 |         0.209524  |     0.99918  |            0 |
| sst2             |  0.486667 | 300 | 0.686622 |    0.351628 |         0.266447  |     0.81738  |            0 |
| super_glue:axg   |  0.486667 | 300 | 0.703611 |    0.310087 |         0.0772059 |     0.999047 |            0 |
| super_glue:boolq |  0.486667 | 300 | 0.734211 |    0.271927 |         0.0744207 |     0.98014  |            0 |
| super_glue:rte   |  0.486667 | 300 | 0.770681 |    0.268991 |         0.0930556 |     0.997921 |            0 |
| imdb             |  0.51     | 300 | 0.781798 |    0.271429 |         0.16313   |     0.886125 |            0 |
| amazon_polarity  |  0.486667 | 300 | 0.835527 |    0.22327  |         0.169878  |     0.963495 |            0 |