In [4]:
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import torch
from tqdm import tqdm
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
device = "cuda"

In [5]:
# model_id = "meta-llama/Meta-Llama-3.1-8B"

model_id = "openai-community/gpt2-large"

# model_id = "meta-llama/Llama-2-7b-hf"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
  model_id,
  device_map='auto',
  torch_dtype=torch.float32,
)

In [6]:
test = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
encodings = tokenizer("\n\n".join(test["text"]), return_tensors="pt")

Token indices sequence length is longer than the specified maximum sequence length for this model (287644 > 1024). Running this sequence through the model will result in indexing errors


In [7]:
max_length = model.config.max_position_embeddings

stride = 512
seq_len = encodings.input_ids.size(1)

nlls = []
texts = []
tokens = []

char_norm_nlls = []

prev_end_loc = 0
for begin_loc in tqdm(range(0, seq_len, stride)):
    end_loc = min(begin_loc + max_length, seq_len)
    trg_len = end_loc - prev_end_loc  # may be different from stride on last loop
    input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
    target_ids = input_ids.clone()
    target_ids[:, :-trg_len] = -100

    texts.append(tokenizer.decode(input_ids[0]))
    tokens.append(tokenizer.convert_ids_to_tokens(input_ids[0]))

    with torch.no_grad():
        outputs = model(input_ids, labels=target_ids)

        # loss is calculated using CrossEntropyLoss which averages over valid labels
        # N.B. the model only calculates loss over trg_len - 1 labels, because it internally shifts the labels
        # to the left by 1.
        neg_log_likelihood = outputs.loss


    nlls.append(neg_log_likelihood)

    char_norm_nlls.append((trg_len-1) * neg_log_likelihood / len(texts[-1]))

    prev_end_loc = end_loc
    if end_loc == seq_len:
        break


100%|█████████▉| 560/562 [01:08<00:00,  8.21it/s]


In [13]:
torch.stack(nlls).mean()

tensor(2.8006, device='cuda:0')

In [14]:
torch.exp(torch.stack(nlls).mean())

tensor(16.4541, device='cuda:0')

In [10]:
nlls

[tensor(2.4563, device='cuda:0'),
 tensor(3.1668, device='cuda:0'),
 tensor(3.0118, device='cuda:0'),
 tensor(2.7525, device='cuda:0'),
 tensor(3.0365, device='cuda:0'),
 tensor(2.9613, device='cuda:0'),
 tensor(3.1425, device='cuda:0'),
 tensor(3.3600, device='cuda:0'),
 tensor(3.1691, device='cuda:0'),
 tensor(2.9309, device='cuda:0'),
 tensor(3.1865, device='cuda:0'),
 tensor(3.2540, device='cuda:0'),
 tensor(3.1617, device='cuda:0'),
 tensor(2.5418, device='cuda:0'),
 tensor(3.0972, device='cuda:0'),
 tensor(2.3650, device='cuda:0'),
 tensor(3.0863, device='cuda:0'),
 tensor(2.9568, device='cuda:0'),
 tensor(2.8328, device='cuda:0'),
 tensor(2.6102, device='cuda:0'),
 tensor(2.7252, device='cuda:0'),
 tensor(2.4229, device='cuda:0'),
 tensor(2.1946, device='cuda:0'),
 tensor(2.3087, device='cuda:0'),
 tensor(2.2461, device='cuda:0'),
 tensor(2.7089, device='cuda:0'),
 tensor(2.8490, device='cuda:0'),
 tensor(2.9128, device='cuda:0'),
 tensor(2.7280, device='cuda:0'),
 tensor(2.8926

In [11]:
char_norm_nlls

[tensor(0.5650, device='cuda:0'),
 tensor(0.3705, device='cuda:0'),
 tensor(0.3407, device='cuda:0'),
 tensor(0.3012, device='cuda:0'),
 tensor(0.3410, device='cuda:0'),
 tensor(0.3461, device='cuda:0'),
 tensor(0.3761, device='cuda:0'),
 tensor(0.3804, device='cuda:0'),
 tensor(0.3458, device='cuda:0'),
 tensor(0.3213, device='cuda:0'),
 tensor(0.3776, device='cuda:0'),
 tensor(0.3998, device='cuda:0'),
 tensor(0.3601, device='cuda:0'),
 tensor(0.2942, device='cuda:0'),
 tensor(0.3661, device='cuda:0'),
 tensor(0.2715, device='cuda:0'),
 tensor(0.3603, device='cuda:0'),
 tensor(0.3687, device='cuda:0'),
 tensor(0.3198, device='cuda:0'),
 tensor(0.2850, device='cuda:0'),
 tensor(0.3305, device='cuda:0'),
 tensor(0.3195, device='cuda:0'),
 tensor(0.3164, device='cuda:0'),
 tensor(0.3248, device='cuda:0'),
 tensor(0.3005, device='cuda:0'),
 tensor(0.3323, device='cuda:0'),
 tensor(0.3423, device='cuda:0'),
 tensor(0.3395, device='cuda:0'),
 tensor(0.3103, device='cuda:0'),
 tensor(0.3471

In [12]:
torch.stack(char_norm_nlls).mean()

tensor(0.3168, device='cuda:0')