In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from IPython.display import display, HTML, Markdown
from torch.utils.data import random_split, DataLoader, TensorDataset
from loguru import logger
import sys
import pandas as pd
# Reset loguru's handlers, then attach a single stderr sink that emits INFO and
# above as "<time> <level> <message>" (the format seen in the cell outputs below).
# The remove() call must come first, otherwise the new sink is added alongside
# whatever handlers were already registered.
logger.remove()
logger.add(sys.stderr, format="{time} {level} {message}", level="INFO")

# load my code
%load_ext autoreload
%autoreload 2
from src.eval.collect import manual_collect2
from src.eval.ds import ds2df, qc_ds, qc_dsdf
from src.prompts.prompt_loading import load_prompts, format_prompt, load_preproc_dataset
from src.models.load import load_model



Here we try out new datasets. We want them to have the following characteristics:

- in the ELK code https://github.com/EleutherAI/elk/tree/1b60b3bff348b00356cd15b5eb017f9c9bfdbae1/elk/promptsource/templates
- achievable! e.g. acc > 70%
- boolean
- not too long (IMDB examples often exceed 777 chars), because long prompts get truncated and fill up my memory
- diverse

In [10]:
# --- Run configuration -------------------------------------------------------
# Hugging Face model id of the checkpoint under evaluation.
model_id = "Walmart-the-bag/phi-2-uncensored"
# Sample budget handed to `load_preproc_dataset` for every dataset below.
N = 300

print(model_id)

# Load the model and its tokenizer in half precision (float16).
model, tokenizer = load_model(model_id, dtype=torch.float16)


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [11]:

# Datasets that load and evaluate with the current pipeline (all marked "yes"
# during exploration, apart from the two long-standing baselines at the top).
# The list order is also the order in which the datasets are evaluated.
ds_names = [
    "amazon_polarity",
    "imdb",
    "super_glue:boolq",
    "glue:qnli",
    "super_glue:axg",
    "super_glue:rte",
    "hans",
    "sst2",
    "sms_spam",
    "wiki_qa",
]

# Candidates that were tried and left out, grouped by the reason noted at the time.
#
# Works ("yes") but currently disabled:
#   'super_glue:copa'
#
# "Expected label to be 0 or 1, got K" (label values outside {0, 1}):
#   'super_glue:cb'          (got 2)
#   'snli'                   (got 2)
#   'xnli:en'                (got 2)
#   'glue:mnli'              (got 2)
#   'glue:mnli_mismatched'   (got 2)
#   'emotion'                (got 4)
#   'emo'                    (got 3)
#   'liar'                   (got 2)
#   'ag_news'                (got 2)
#
# "Dataset has no label column":
#   'wiqa'
#   'hotpot_qa:fullwiki'
#   'biosses'
#   'commonsense_qa'         (raised as ValueError)
#   'math_qa'
#   'lauritowal/redefine_math'
#
# "No multiple choice templates found":
#   'hotpot_qa:distractor'
#   'adversarial_qa:droberta'
#   'math_dataset:algebra__linear_1d'
#
# Other:
#   'boolq_pt'               (cannot find the dataset)
#   'movie_rationales'       (long?)
#   'hate_speech18'          (label is a dict)

# Per-dataset QC metrics keyed by dataset name; the summary-table cell reads this.
data = {}
for ds_name in ds_names:
    print('=='*80)
    print(ds_name)

    # Load and preprocess the dataset, returning torch-formatted columns.
    # Fail fast: log the traceback together with the dataset name, then
    # re-raise so a broken dataset stops the run instead of silently going
    # missing from the results. (To skip failing datasets instead, replace the
    # `raise` with `continue`.) KeyboardInterrupt is not an `Exception`
    # subclass, so it propagates untouched without a dedicated clause.
    try:
        ds = load_preproc_dataset(ds_name, tokenizer, N, prompt_format='phi').with_format("torch")
    except Exception as e:
        logger.exception(f"Failed to load {ds_name} `{e}`")
        raise

    # Run the model over the dataset in fixed order (no shuffling) so results
    # are reproducible; the second return value is not used in this cell.
    dl = DataLoader(ds, batch_size=8, shuffle=False, num_workers=0)
    ds_out, f = manual_collect2(dl, model, get_residual=False)
    print(f'for {model_id}:')

    res = qc_ds(ds_out)
    # Fraction of rows flagged as truncated in the dataset that was evaluated.
    # NOTE(review): the load log reports "num_rows (after filtering out
    # truncated rows)", and imdb logs a 0.89% truncation rate yet shows 0 in
    # the final table, so this looks like it is measured after truncated rows
    # were already dropped and will always be 0. Confirm against
    # `load_preproc_dataset` and record the pre-filter rate if that is the intent.
    res['truncation']=(ds['truncated']*1.0).mean()
    data[ds_name] = res

    # Repeat the name so the end of this dataset's output is easy to spot.
    print(ds_name)



amazon_polarity


Generating train split: 0 examples [00:00, ? examples/s]

  table = cls._concat_blocks(blocks, axis=0)
2023-12-25T18:56:43.293682+0800 INFO Extracting 11 variants of each prompt
2023-12-25T18:57:31.087941+0800 INFO setting tokenizer chat template to phi


format_prompt:   0%|          | 0/902 [00:00<?, ? examples/s]

tokenize:   0%|          | 0/902 [00:00<?, ? examples/s]

truncated:   0%|          | 0/902 [00:00<?, ? examples/s]

truncated:   0%|          | 0/902 [00:00<?, ? examples/s]

prompt_truncated:   0%|          | 0/902 [00:00<?, ? examples/s]

choice_ids:   0%|          | 0/902 [00:00<?, ? examples/s]

2023-12-25T18:57:35.045782+0800 INFO median token length: 274.5 for amazon_polarity. max_length=999
2023-12-25T18:57:35.046584+0800 INFO truncation rate: 0.00% on amazon_polarity


Filter:   0%|          | 0/902 [00:00<?, ? examples/s]

Filter:   0%|          | 0/902 [00:00<?, ? examples/s]

2023-12-25T18:57:35.899972+0800 INFO num_rows (after filtering out truncated rows) 902=>902
2023-12-25T18:57:38.767048+0800 INFO creating dataset /media/wassname/SGIronWolf/projects5/elk/sgd_probes_are_lie_detectors/.ds/ds__fbbf3b9524dc05ba


collecting hidden states:   0%|          | 0/38 [00:00<?, ?it/s]

for Walmart-the-bag/phi-2-uncensored:
with base model
amazon_polarity
imdb


Generating train split: 0 examples [00:00, ? examples/s]

2023-12-25T18:59:56.760896+0800 INFO Extracting 13 variants of each prompt


format_prompt:   0%|          | 0/902 [00:00<?, ? examples/s]

tokenize:   0%|          | 0/902 [00:00<?, ? examples/s]

truncated:   0%|          | 0/902 [00:00<?, ? examples/s]

truncated:   0%|          | 0/902 [00:00<?, ? examples/s]

prompt_truncated:   0%|          | 0/902 [00:00<?, ? examples/s]

choice_ids:   0%|          | 0/902 [00:00<?, ? examples/s]

2023-12-25T19:00:42.277937+0800 INFO median token length: 562.5 for imdb. max_length=999
2023-12-25T19:00:42.278640+0800 INFO truncation rate: 0.89% on imdb


Filter:   0%|          | 0/902 [00:00<?, ? examples/s]

Filter:   0%|          | 0/894 [00:00<?, ? examples/s]

2023-12-25T19:00:43.332471+0800 INFO num_rows (after filtering out truncated rows) 902=>894
2023-12-25T19:00:46.202352+0800 INFO creating dataset /media/wassname/SGIronWolf/projects5/elk/sgd_probes_are_lie_detectors/.ds/ds__e7fbba5be777e667


collecting hidden states:   0%|          | 0/38 [00:00<?, ?it/s]

for Walmart-the-bag/phi-2-uncensored:
with base model
imdb
super_glue:boolq


Generating train split: 0 examples [00:00, ? examples/s]

2023-12-25T19:03:02.887858+0800 INFO Extracting 10 variants of each prompt


format_prompt:   0%|          | 0/902 [00:00<?, ? examples/s]

tokenize:   0%|          | 0/902 [00:00<?, ? examples/s]

truncated:   0%|          | 0/902 [00:00<?, ? examples/s]

truncated:   0%|          | 0/902 [00:00<?, ? examples/s]

prompt_truncated:   0%|          | 0/902 [00:00<?, ? examples/s]

choice_ids:   0%|          | 0/902 [00:00<?, ? examples/s]

2023-12-25T19:03:50.706813+0800 INFO median token length: 320.0 for super_glue:boolq. max_length=999
2023-12-25T19:03:50.707552+0800 INFO truncation rate: 0.00% on super_glue:boolq


Filter:   0%|          | 0/902 [00:00<?, ? examples/s]

Filter:   0%|          | 0/902 [00:00<?, ? examples/s]

2023-12-25T19:03:51.581389+0800 INFO num_rows (after filtering out truncated rows) 902=>902
2023-12-25T19:03:54.601374+0800 INFO creating dataset /media/wassname/SGIronWolf/projects5/elk/sgd_probes_are_lie_detectors/.ds/ds__b8fd18ad5d5ba447


collecting hidden states:   0%|          | 0/38 [00:00<?, ?it/s]

for Walmart-the-bag/phi-2-uncensored:
with base model
super_glue:boolq
glue:qnli


Generating train split: 0 examples [00:00, ? examples/s]

2023-12-25T19:06:15.317821+0800 INFO Extracting 5 variants of each prompt


format_prompt:   0%|          | 0/902 [00:00<?, ? examples/s]

tokenize:   0%|          | 0/902 [00:00<?, ? examples/s]

truncated:   0%|          | 0/902 [00:00<?, ? examples/s]

truncated:   0%|          | 0/902 [00:00<?, ? examples/s]

prompt_truncated:   0%|          | 0/902 [00:00<?, ? examples/s]

choice_ids:   0%|          | 0/902 [00:00<?, ? examples/s]

2023-12-25T19:06:39.229122+0800 INFO median token length: 175.0 for glue:qnli. max_length=999
2023-12-25T19:06:39.229910+0800 INFO truncation rate: 0.00% on glue:qnli


Filter:   0%|          | 0/902 [00:00<?, ? examples/s]

Filter:   0%|          | 0/902 [00:00<?, ? examples/s]

2023-12-25T19:06:40.077809+0800 INFO num_rows (after filtering out truncated rows) 902=>902
2023-12-25T19:06:43.355000+0800 INFO creating dataset /media/wassname/SGIronWolf/projects5/elk/sgd_probes_are_lie_detectors/.ds/ds__a1f22eb25e1e6760


collecting hidden states:   0%|          | 0/38 [00:00<?, ?it/s]

for Walmart-the-bag/phi-2-uncensored:
with base model
glue:qnli
super_glue:axg


Generating train split: 0 examples [00:00, ? examples/s]

2023-12-25T19:09:13.441273+0800 INFO Extracting 10 variants of each prompt


format_prompt:   0%|          | 0/712 [00:00<?, ? examples/s]

tokenize:   0%|          | 0/712 [00:00<?, ? examples/s]

truncated:   0%|          | 0/712 [00:00<?, ? examples/s]

truncated:   0%|          | 0/712 [00:00<?, ? examples/s]

prompt_truncated:   0%|          | 0/712 [00:00<?, ? examples/s]

choice_ids:   0%|          | 0/712 [00:00<?, ? examples/s]

2023-12-25T19:09:47.606234+0800 INFO median token length: 118.0 for super_glue:axg. max_length=999
2023-12-25T19:09:47.607190+0800 INFO truncation rate: 0.00% on super_glue:axg


Filter:   0%|          | 0/712 [00:00<?, ? examples/s]

Filter:   0%|          | 0/712 [00:00<?, ? examples/s]

2023-12-25T19:09:48.300005+0800 INFO num_rows (after filtering out truncated rows) 712=>712
2023-12-25T19:09:51.507018+0800 INFO creating dataset /media/wassname/SGIronWolf/projects5/elk/sgd_probes_are_lie_detectors/.ds/ds__00136c9ca2a82d74


collecting hidden states:   0%|          | 0/38 [00:00<?, ?it/s]

for Walmart-the-bag/phi-2-uncensored:
with base model
super_glue:axg
super_glue:rte


Generating train split: 0 examples [00:00, ? examples/s]

2023-12-25T19:12:21.009493+0800 INFO Extracting 11 variants of each prompt


format_prompt:   0%|          | 0/902 [00:00<?, ? examples/s]

tokenize:   0%|          | 0/902 [00:00<?, ? examples/s]

truncated:   0%|          | 0/902 [00:00<?, ? examples/s]

truncated:   0%|          | 0/902 [00:00<?, ? examples/s]

prompt_truncated:   0%|          | 0/902 [00:00<?, ? examples/s]

choice_ids:   0%|          | 0/902 [00:00<?, ? examples/s]

2023-12-25T19:13:11.795944+0800 INFO median token length: 187.0 for super_glue:rte. max_length=999
2023-12-25T19:13:11.796762+0800 INFO truncation rate: 0.00% on super_glue:rte


Filter:   0%|          | 0/902 [00:00<?, ? examples/s]

Filter:   0%|          | 0/902 [00:00<?, ? examples/s]

2023-12-25T19:13:12.889205+0800 INFO num_rows (after filtering out truncated rows) 902=>902
2023-12-25T19:13:16.183107+0800 INFO creating dataset /media/wassname/SGIronWolf/projects5/elk/sgd_probes_are_lie_detectors/.ds/ds__b90de14b6580f7f9


collecting hidden states:   0%|          | 0/38 [00:00<?, ?it/s]

for Walmart-the-bag/phi-2-uncensored:
with base model
super_glue:rte
hans


Generating train split: 0 examples [00:00, ? examples/s]

2023-12-25T19:15:44.431654+0800 INFO Extracting 10 variants of each prompt


format_prompt:   0%|          | 0/902 [00:00<?, ? examples/s]

tokenize:   0%|          | 0/902 [00:00<?, ? examples/s]

truncated:   0%|          | 0/902 [00:00<?, ? examples/s]

truncated:   0%|          | 0/902 [00:00<?, ? examples/s]

prompt_truncated:   0%|          | 0/902 [00:00<?, ? examples/s]

choice_ids:   0%|          | 0/902 [00:00<?, ? examples/s]

2023-12-25T19:16:26.534858+0800 INFO median token length: 98.0 for hans. max_length=999
2023-12-25T19:16:26.535561+0800 INFO truncation rate: 0.00% on hans


Filter:   0%|          | 0/902 [00:00<?, ? examples/s]

Filter:   0%|          | 0/902 [00:00<?, ? examples/s]

2023-12-25T19:16:27.398862+0800 INFO num_rows (after filtering out truncated rows) 902=>902
2023-12-25T19:16:30.641925+0800 INFO creating dataset /media/wassname/SGIronWolf/projects5/elk/sgd_probes_are_lie_detectors/.ds/ds__792a7d3313d88d8f


collecting hidden states:   0%|          | 0/38 [00:00<?, ?it/s]

for Walmart-the-bag/phi-2-uncensored:
with base model
hans
sst2


Generating train split: 0 examples [00:00, ? examples/s]

2023-12-25T19:18:55.998230+0800 INFO Extracting 15 variants of each prompt


format_prompt:   0%|          | 0/902 [00:00<?, ? examples/s]

tokenize:   0%|          | 0/902 [00:00<?, ? examples/s]

truncated:   0%|          | 0/902 [00:00<?, ? examples/s]

truncated:   0%|          | 0/902 [00:00<?, ? examples/s]

prompt_truncated:   0%|          | 0/902 [00:00<?, ? examples/s]

choice_ids:   0%|          | 0/902 [00:00<?, ? examples/s]

2023-12-25T19:19:46.898149+0800 INFO median token length: 87.0 for sst2. max_length=999
2023-12-25T19:19:46.898958+0800 INFO truncation rate: 0.00% on sst2


Filter:   0%|          | 0/902 [00:00<?, ? examples/s]

Filter:   0%|          | 0/902 [00:00<?, ? examples/s]

2023-12-25T19:19:47.747019+0800 INFO num_rows (after filtering out truncated rows) 902=>902
2023-12-25T19:19:50.976962+0800 INFO creating dataset /media/wassname/SGIronWolf/projects5/elk/sgd_probes_are_lie_detectors/.ds/ds__b52bc63561199047


collecting hidden states:   0%|          | 0/38 [00:00<?, ?it/s]

for Walmart-the-bag/phi-2-uncensored:
with base model
sst2
sms_spam


Generating train split: 0 examples [00:00, ? examples/s]

2023-12-25T19:22:15.680339+0800 INFO Extracting 5 variants of each prompt


format_prompt:   0%|          | 0/902 [00:00<?, ? examples/s]

tokenize:   0%|          | 0/902 [00:00<?, ? examples/s]

truncated:   0%|          | 0/902 [00:00<?, ? examples/s]

truncated:   0%|          | 0/902 [00:00<?, ? examples/s]

prompt_truncated:   0%|          | 0/902 [00:00<?, ? examples/s]

choice_ids:   0%|          | 0/902 [00:00<?, ? examples/s]

2023-12-25T19:22:37.035195+0800 INFO median token length: 123.0 for sms_spam. max_length=999
2023-12-25T19:22:37.035905+0800 INFO truncation rate: 0.00% on sms_spam


Filter:   0%|          | 0/902 [00:00<?, ? examples/s]

Filter:   0%|          | 0/902 [00:00<?, ? examples/s]

2023-12-25T19:22:37.884771+0800 INFO num_rows (after filtering out truncated rows) 902=>902
2023-12-25T19:22:41.085678+0800 INFO creating dataset /media/wassname/SGIronWolf/projects5/elk/sgd_probes_are_lie_detectors/.ds/ds__db2d956bc9b9d82c


collecting hidden states:   0%|          | 0/38 [00:00<?, ?it/s]

for Walmart-the-bag/phi-2-uncensored:
with base model
sms_spam
wiki_qa


Generating train split: 0 examples [00:00, ? examples/s]

2023-12-25T19:25:07.686477+0800 INFO Extracting 5 variants of each prompt


format_prompt:   0%|          | 0/902 [00:00<?, ? examples/s]

tokenize:   0%|          | 0/902 [00:00<?, ? examples/s]

truncated:   0%|          | 0/902 [00:00<?, ? examples/s]

truncated:   0%|          | 0/902 [00:00<?, ? examples/s]

prompt_truncated:   0%|          | 0/902 [00:00<?, ? examples/s]

choice_ids:   0%|          | 0/902 [00:00<?, ? examples/s]

2023-12-25T19:25:33.515568+0800 INFO median token length: 161.0 for wiki_qa. max_length=999
2023-12-25T19:25:33.516267+0800 INFO truncation rate: 0.00% on wiki_qa


Filter:   0%|          | 0/902 [00:00<?, ? examples/s]

Filter:   0%|          | 0/902 [00:00<?, ? examples/s]

2023-12-25T19:25:34.379653+0800 INFO num_rows (after filtering out truncated rows) 902=>902
2023-12-25T19:25:37.457950+0800 INFO creating dataset /media/wassname/SGIronWolf/projects5/elk/sgd_probes_are_lie_detectors/.ds/ds__28a4ae9cea92d30b


collecting hidden states:   0%|          | 0/38 [00:00<?, ?it/s]

for Walmart-the-bag/phi-2-uncensored:
with base model
wiki_qa


In [17]:
# Build the summary: one row per dataset, one column per metric, ordered from
# lowest to highest `auroc` so the weakest datasets appear first.
df = pd.DataFrame(data).transpose().sort_values(by='auroc')

# Render the markdown table once, then show it twice: rich-rendered in the
# notebook, and as raw text (presumably so it can be copy-pasted elsewhere).
summary_markdown = df.to_markdown()
display(Markdown(summary_markdown))
print(summary_markdown)


|                  |   balance |   N |    auroc |   lie_auroc |   known_lie_auroc |   choice_cov |   truncation |
|:-----------------|----------:|----:|---------:|------------:|------------------:|-------------:|-------------:|
| sms_spam         |  0.486667 | 300 | 0.499377 |    0.531934 |         0.296474  |     0.973791 |            0 |
| wiki_qa          |  0.486667 | 300 | 0.537004 |    0.42528  |         0.3       |     0.84084  |            0 |
| glue:qnli        |  0.486667 | 300 | 0.616972 |    0.409714 |         0.251208  |     0.729142 |            0 |
| hans             |  0.486667 | 300 | 0.649173 |    0.396282 |         0.209524  |     0.99918  |            0 |
| sst2             |  0.486667 | 300 | 0.686622 |    0.351628 |         0.266447  |     0.81738  |            0 |
| super_glue:axg   |  0.486667 | 300 | 0.703611 |    0.310087 |         0.0772059 |     0.999047 |            0 |
| super_glue:boolq |  0.486667 | 300 | 0.734211 |    0.271927 |         0.0744207 |     0.98014  |            0 |
| super_glue:rte   |  0.486667 | 300 | 0.770681 |    0.268991 |         0.0930556 |     0.997921 |            0 |
| imdb             |  0.51     | 300 | 0.781798 |    0.271429 |         0.16313   |     0.886125 |            0 |
| amazon_polarity  |  0.486667 | 300 | 0.835527 |    0.22327  |         0.169878  |     0.963495 |            0 |

|                  |   balance |   N |    auroc |   lie_auroc |   known_lie_auroc |   choice_cov |   truncation |
|:-----------------|----------:|----:|---------:|------------:|------------------:|-------------:|-------------:|
| sms_spam         |  0.486667 | 300 | 0.499377 |    0.531934 |         0.296474  |     0.973791 |            0 |
| wiki_qa          |  0.486667 | 300 | 0.537004 |    0.42528  |         0.3       |     0.84084  |            0 |
| glue:qnli        |  0.486667 | 300 | 0.616972 |    0.409714 |         0.251208  |     0.729142 |            0 |
| hans             |  0.486667 | 300 | 0.649173 |    0.396282 |         0.209524  |     0.99918  |            0 |
| sst2             |  0.486667 | 300 | 0.686622 |    0.351628 |         0.266447  |     0.81738  |            0 |
| super_glue:axg   |  0.486667 | 300 | 0.703611 |    0.310087 |         0.0772059 |     0.999047 |            0 |
| super_glue:boolq |  0.486667 | 300 | 0.734211 |    0.271927 |         0.0744207 |     

|                  |   balance |   N |    auroc |   lie_auroc |   known_lie_auroc |   choice_cov |   truncation |
|:-----------------|----------:|----:|---------:|------------:|------------------:|-------------:|-------------:|
| sms_spam         |  0.486667 | 300 | 0.499377 |    0.531934 |         0.296474  |     0.973791 |            0 |
| wiki_qa          |  0.486667 | 300 | 0.537004 |    0.42528  |         0.3       |     0.84084  |            0 |
| glue:qnli        |  0.486667 | 300 | 0.616972 |    0.409714 |         0.251208  |     0.729142 |            0 |
| hans             |  0.486667 | 300 | 0.649173 |    0.396282 |         0.209524  |     0.99918  |            0 |
| sst2             |  0.486667 | 300 | 0.686622 |    0.351628 |         0.266447  |     0.81738  |            0 |
| super_glue:axg   |  0.486667 | 300 | 0.703611 |    0.310087 |         0.0772059 |     0.999047 |            0 |
| super_glue:boolq |  0.486667 | 300 | 0.734211 |    0.271927 |         0.0744207 |     0.98014  |            0 |
| super_glue:rte   |  0.486667 | 300 | 0.770681 |    0.268991 |         0.0930556 |     0.997921 |            0 |
| imdb             |  0.51     | 300 | 0.781798 |    0.271429 |         0.16313   |     0.886125 |            0 |
| amazon_polarity  |  0.486667 | 300 | 0.835527 |    0.22327  |         0.169878  |     0.963495 |            0 |