In [2]:
# example script to generate boosted predictions
import csv
from tqdm import tqdm
from datasets import load_dataset
from cb import CB
from util import accuracy

# load dataset: only the raw text of the test split is fed to the models
lambada = load_dataset("EleutherAI/lambada_openai")
X = lambada["test"]["text"]

# boosting params: entry j of `alphas` and `ks` belongs to entry j of `models`
models = ["gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl"]
alphas = [-0.6, -0.5, -0.5, -0.5]
ks = [10, 11, 10, 9]

# one record per model:
#   "fmax" -> accuracy of out["preds_fmax"], "cb" -> accuracy of out["preds_cb"]
accs = []
# the zip is materialised into a list so tqdm can still report progress out of 4
for model, alpha, k in tqdm(list(zip(models, alphas, ks))):
    cb_model = CB(alpha, k, model_id=model, device="cuda")
    out = cb_model.boosted_batched_generate(X, fmax_score=True, batch_size=64)
    targets = out["targets"]
    accs.append(
        {
            "model": model,
            "fmax": accuracy(targets, out["preds_fmax"]),
            "cb": accuracy(targets, out["preds_cb"]),
        }
    )

print(accs)

Downloading builder script:   0%|          | 0.00/4.82k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/4.99k [00:00<?, ?B/s]



Downloading and preparing dataset lambada_openai/default to /root/.cache/huggingface/datasets/EleutherAI___lambada_openai/default/1.0.0/57baddecfa09d1790541ef07274c5666abfbe9d2ccd0cd46013cd557b0343095...


Downloading data:   0%|          | 0.00/1.82M [00:00<?, ?B/s]

Generating test split:   0%|          | 0/5153 [00:00<?, ? examples/s]

Dataset lambada_openai downloaded and prepared to /root/.cache/huggingface/datasets/EleutherAI___lambada_openai/default/1.0.0/57baddecfa09d1790541ef07274c5666abfbe9d2ccd0cd46013cd557b0343095. Subsequent calls will reuse this data.


  0%|          | 0/1 [00:00<?, ?it/s]


  0%|          | 0/4 [00:00<?, ?it/s][A

Downloading (…)lve/main/config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/548M [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Map:   0%|          | 0/5153 [00:00<?, ? examples/s]

100%|██████████| 81/81 [00:27<00:00,  2.94it/s]

 25%|██▌       | 1/4 [00:45<02:15, 45.11s/it][A

Downloading (…)lve/main/config.json:   0%|          | 0.00/718 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.52G [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Map:   0%|          | 0/5153 [00:00<?, ? examples/s]

100%|██████████| 81/81 [01:00<00:00,  1.34it/s]

 50%|█████     | 2/4 [02:10<02:17, 68.63s/it][A

Downloading (…)lve/main/config.json:   0%|          | 0.00/666 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/3.25G [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Map:   0%|          | 0/5153 [00:00<?, ? examples/s]

100%|██████████| 81/81 [02:04<00:00,  1.54s/it]

 75%|███████▌  | 3/4 [05:02<01:55, 115.85s/it][A

Downloading (…)lve/main/config.json:   0%|          | 0.00/689 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/6.43G [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Map:   0%|          | 0/5153 [00:00<?, ? examples/s]

100%|██████████| 81/81 [04:01<00:00,  2.99s/it]

100%|██████████| 4/4 [10:04<00:00, 151.13s/it]

[{'model': 'gpt2', 'fmax': tensor(0.4667), 'cb': tensor(0.6470)}, {'model': 'gpt2-medium', 'fmax': tensor(0.5500), 'cb': tensor(0.7141)}, {'model': 'gpt2-large', 'fmax': tensor(0.5876), 'cb': tensor(0.7413)}, {'model': 'gpt2-xl', 'fmax': tensor(0.6142), 'cb': tensor(0.7475)}]





In [16]:
# Write the per-model accuracies collected in `accs` to a CSV file:
# one header line, then one row per model with the columns model / fmax / cb.
file_name = "boosted_lambada.csv"
# newline="" is required by the csv module for the file object it writes to;
# without it the writer's "\r\n" row terminator is translated a second time on
# Windows and every record is followed by a blank line.
with open(file_name, "w", newline="") as f:
    # The header is spelled out instead of being read from accs[0], so an
    # empty `accs` produces a header-only file rather than an IndexError.
    w = csv.DictWriter(f, fieldnames=["model", "fmax", "cb"])
    w.writeheader()
    for row in accs:
        # Scores are looked up by key (not by their position in the dict) and
        # unwrapped from 0-d tensors into plain Python numbers via .item().
        w.writerow(
            {
                "model": row["model"],
                "fmax": row["fmax"].item(),
                "cb": row["cb"].item(),
            }
        )