### Perplexity
https://huggingface.co/docs/transformers/perplexity

In [2]:
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

# Checkpoint used for the walkthrough: a tiny GPT-2 that downloads quickly.
# Swap in "gpt2-large" to evaluate the full-size model instead.
model_id = "sshleifer/tiny-gpt2"

# The tokenizer and the model are fetched from the same checkpoint so that
# the vocabulary matches the model's embedding table.
tokenizer = GPT2TokenizerFast.from_pretrained(model_id)
model = GPT2LMHeadModel.from_pretrained(model_id)

Downloading (…)lve/main/config.json: 100%|██████████| 662/662 [00:00<00:00, 1.91MB/s]
Downloading pytorch_model.bin: 100%|██████████| 2.51M/2.51M [00:00<00:00, 9.83MB/s]
Downloading (…)okenizer_config.json: 100%|██████████| 26.0/26.0 [00:00<00:00, 58.2kB/s]
Downloading (…)olve/main/vocab.json: 100%|██████████| 899k/899k [00:00<00:00, 4.48MB/s]
Downloading (…)olve/main/merges.txt: 100%|██████████| 456k/456k [00:00<00:00, 4.25MB/s]
Downloading (…)cial_tokens_map.json: 100%|██████████| 90.0/90.0 [00:00<00:00, 235kB/s]


In [4]:
from datasets import load_dataset

# WikiText-2 (raw) test split; only the "text" column is used below.
test = load_dataset(
    "wikitext",
    "wikitext-2-raw-v1",
    split="test",
)

# Concatenate every row into one long document (rows separated by a blank
# line) and tokenize it in a single call. The result is far longer than the
# model's context window, so it has to be evaluated window by window.
encodings = tokenizer(
    "\n\n".join(test["text"]),
    return_tensors="pt",
)

Downloading builder script: 100%|██████████| 8.48k/8.48k [00:00<00:00, 17.1MB/s]
Downloading metadata: 100%|██████████| 6.84k/6.84k [00:00<00:00, 21.3MB/s]
Downloading readme: 100%|██████████| 9.62k/9.62k [00:00<00:00, 10.2MB/s]
Downloading data: 100%|██████████| 4.72M/4.72M [00:01<00:00, 4.29MB/s]
Generating test split: 100%|██████████| 4358/4358 [00:00<00:00, 35782.92 examples/s]
Generating train split: 100%|██████████| 36718/36718 [00:00<00:00, 58153.04 examples/s]
Generating validation split: 100%|██████████| 3760/3760 [00:00<00:00, 65861.97 examples/s]
Token indices sequence length is longer than the specified maximum sequence length for this model (287644 > 1024). Running this sequence through the model will result in indexing errors


In [8]:
import torch
from tqdm import tqdm

# Sliding-window perplexity over the whole tokenized corpus.
#
# Each window holds at most `max_length` tokens. Consecutive windows start
# `stride` tokens apart, so they overlap and every scored token (outside the
# first window) is predicted with up to `max_length - stride` tokens of context.
# Only the tokens that were NOT scored by the previous window contribute to
# the loss; the overlapping prefix is masked with -100.
max_length = model.config.n_positions
stride = 128
seq_len = encodings.input_ids.size(1)

nlls = []  # per-window mean NLL, kept for inspection in later cells
nll_sum = 0.0  # sum of token-level NLL over all scored tokens
n_tokens = 0  # number of scored tokens
prev_end_loc = 0
for begin_loc in tqdm(range(0, seq_len, stride)):
    end_loc = min(begin_loc + max_length, seq_len)
    trg_len = end_loc - prev_end_loc  # may be different from stride on last loop
    input_ids = encodings.input_ids[:, begin_loc:end_loc]
    target_ids = input_ids.clone()
    # Mask the context part of the window so it is not scored again.
    target_ids[:, :-trg_len] = -100

    with torch.no_grad():
        outputs = model(input_ids, labels=target_ids)

        # outputs.loss is the mean cross-entropy over the valid (non -100)
        # labels of this window.
        neg_log_likelihood = outputs.loss

    # The model shifts labels left by one internally, so the label at
    # position 0 is never scored. Counting valid labels from position 1 on
    # gives the exact number of tokens that entered the mean above.
    num_loss_tokens = int((target_ids[:, 1:] != -100).sum())
    if num_loss_tokens > 0:
        nlls.append(neg_log_likelihood)
        # Undo the per-window mean so windows are weighted by token count.
        nll_sum = nll_sum + neg_log_likelihood * num_loss_tokens
        n_tokens += num_loss_tokens

    prev_end_loc = end_loc
    if end_loc == seq_len:
        break

if n_tokens == 0:
    raise ValueError("Need at least two tokens to compute perplexity.")

# Perplexity is exp of the average token-level negative log-likelihood.
ppl = torch.exp(nll_sum / n_tokens)

100%|█████████▉| 2240/2248 [07:45<00:01,  4.81it/s]


In [26]:
# Debug view: model output object left over from the last loop iteration
# (its repr is very large).
outputs

CausalLMOutputWithCrossAttentions(loss=tensor(10.8185), logits=tensor([[[ 1.2894e-02,  1.0899e-02, -3.1657e-02,  ...,  3.7258e-02,
          -3.5729e-05, -2.5439e-02],
         [ 1.3005e-02,  1.0992e-02, -3.1928e-02,  ...,  3.7578e-02,
          -3.6035e-05, -2.5657e-02],
         [-8.3406e-03, -7.0500e-03,  2.0477e-02,  ..., -2.4101e-02,
           2.3111e-05,  1.6455e-02],
         ...,
         [-1.2938e-02, -1.0936e-02,  3.1765e-02,  ..., -3.7386e-02,
           3.5851e-05,  2.5526e-02],
         [-1.3036e-02, -1.1019e-02,  3.2006e-02,  ..., -3.7670e-02,
           3.6122e-05,  2.5720e-02],
         [-1.2218e-02, -1.0327e-02,  2.9996e-02,  ..., -3.5304e-02,
           3.3855e-05,  2.4105e-02]]]), past_key_values=((tensor([[[[ 0.0074],
          [ 0.0075],
          [-0.0048],
          ...,
          [-0.0074],
          [-0.0075],
          [-0.0070]],

         [[-0.0064],
          [-0.0064],
          [ 0.0042],
          ...,
          [ 0.0064],
          [ 0.0065],
         

In [24]:
# Debug view: label tensor left over from the last loop iteration.
# Entries equal to -100 are the context positions the loop masked out.
# print(begin_loc, end_loc, encodings.input_ids[:, begin_loc:end_loc])
target_ids

tensor([[ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  

In [18]:
# Value the perplexity loop uses as its window length (`max_length`).
model.config.n_positions

1024

In [12]:
# Loss recorded for the first window of the perplexity loop.
nlls[0]

tensor(10.8245)

In [16]:
# Shape of the tokenized corpus; the second dimension is what the loop reads as `seq_len`.
encodings.input_ids.size()

torch.Size([1, 287644])

In [17]:
# Debug view: every window start offset for a step of 1024 tokens (long output).
list(range(0, seq_len, 1024))

[0,
 1024,
 2048,
 3072,
 4096,
 5120,
 6144,
 7168,
 8192,
 9216,
 10240,
 11264,
 12288,
 13312,
 14336,
 15360,
 16384,
 17408,
 18432,
 19456,
 20480,
 21504,
 22528,
 23552,
 24576,
 25600,
 26624,
 27648,
 28672,
 29696,
 30720,
 31744,
 32768,
 33792,
 34816,
 35840,
 36864,
 37888,
 38912,
 39936,
 40960,
 41984,
 43008,
 44032,
 45056,
 46080,
 47104,
 48128,
 49152,
 50176,
 51200,
 52224,
 53248,
 54272,
 55296,
 56320,
 57344,
 58368,
 59392,
 60416,
 61440,
 62464,
 63488,
 64512,
 65536,
 66560,
 67584,
 68608,
 69632,
 70656,
 71680,
 72704,
 73728,
 74752,
 75776,
 76800,
 77824,
 78848,
 79872,
 80896,
 81920,
 82944,
 83968,
 84992,
 86016,
 87040,
 88064,
 89088,
 90112,
 91136,
 92160,
 93184,
 94208,
 95232,
 96256,
 97280,
 98304,
 99328,
 100352,
 101376,
 102400,
 103424,
 104448,
 105472,
 106496,
 107520,
 108544,
 109568,
 110592,
 111616,
 112640,
 113664,
 114688,
 115712,
 116736,
 117760,
 118784,
 119808,
 120832,
 121856,
 122880,
 123904,
 124928,
 125