In [51]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
from dotenv import load_dotenv
from tqdm.notebook import tqdm
import numpy as np
import torch
import os

# Load variables from the project's .env file into the process environment.
load_dotenv("../agent-scaling-laws/metr-standard/workbench/.env")

# Mirror HF_TOKEN under HUGGINGFACE_TOKEN. os.environ only accepts strings, so
# assigning a missing (None) token would raise an opaque TypeError; skip the
# copy and print a warning instead.
hf_token = os.getenv("HF_TOKEN")
if hf_token is None:
    print("Warning: HF_TOKEN is not set; HUGGINGFACE_TOKEN was not exported.")
else:
    os.environ["HUGGINGFACE_TOKEN"] = hf_token

In [None]:
# Change model_name to use a different model. Choices:
# meta-llama/Meta-Llama-3.1-8B
# meta-llama/Meta-Llama-3.1-8B-Instruct
# meta-llama/Llama-2-7b-hf
# meta-llama/Llama-2-7b-chat-hf
# The same model_name is later passed to prompt_format() to select the matching
# prompt template, so keep it to one of the names listed above.
model_name = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

In [53]:
def get_nll_loss(text: str):
    """Return the model's mean negative log-likelihood for `text`.

    Uses the notebook-level `tokenizer` and `model`. The token ids are passed
    as their own labels, so the value returned is the loss the model reports
    for predicting `text`.

    Args:
        text: The string to score.

    Returns:
        `outputs.loss` exactly as produced by the model (not converted to a
        Python float).
    """
    # Place the encoded inputs on the model's device so the forward pass also
    # works when the model has been moved off the CPU.
    inputs = tokenizer(text, return_tensors="pt").to(model.device)

    with torch.no_grad():  # scoring only; no gradients are needed
        outputs = model(**inputs, labels=inputs["input_ids"])

    return outputs.loss  # loss = mean NLL

def get_text_samples(
    dataset_name="HuggingFaceFW/fineweb",
    n_samples=1000,
    max_sample_char_length=500,
):
    """Stream the first `n_samples` records of a dataset and return their text.

    The dataset's "default" configuration is opened in streaming mode and the
    records are taken from its "train" split. Each record's "text" field is
    truncated to at most `max_sample_char_length` characters.

    Returns:
        A list of truncated text strings, in the order they were streamed.
    """
    streamed = load_dataset(dataset_name, "default", streaming=True)

    truncated_texts = []
    for record in streamed["train"].take(n_samples):
        truncated_texts.append(record["text"][:max_sample_char_length])
    return truncated_texts

def prompt_format(text: str, model_name: str):
    """Wrap `text` in the prompt template that matches `model_name`.

    Base models receive the stripped text unchanged; the instruct/chat models
    receive it wrapped as a single user turn.

    Prompt formats:
    https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1/#supported-roles
    https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-2/

    Args:
        text: Raw sample text; leading and trailing whitespace is removed.
        model_name: One of the supported Hugging Face model identifiers.

    Returns:
        The formatted prompt string.

    Raises:
        ValueError: If `model_name` has no known prompt format. Without this
            the function would fall through and silently return None.
    """
    # We .strip() to be able to compare with together.ai
    stripped = text.strip()

    if model_name in ("meta-llama/Meta-Llama-3.1-8B", "meta-llama/Llama-2-7b-hf"):
        return stripped
    if model_name == "meta-llama/Meta-Llama-3.1-8B-Instruct":
        return f"<|start_header_id|>user<|end_header_id|>\n\n{stripped}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    if model_name == "meta-llama/Llama-2-7b-chat-hf":
        return f"[INST] {stripped} [/INST]"

    raise ValueError(f"No prompt format defined for model: {model_name!r}")

In [None]:
# Keep the raw samples under their own name so the unformatted text stays
# available and `text_samples` only ever holds model-ready prompts.
raw_text_samples = get_text_samples()
text_samples = [prompt_format(sample, model_name) for sample in raw_text_samples]
print(text_samples[0])

In [55]:
nlls = []    # per-sample mean NLL, stored as plain Python floats
total = 0.0  # running sum of the values in `nlls`

with tqdm(text_samples, desc="Calculating NLL") as pbar:
    for i, text in enumerate(pbar, start=1):
        # Convert the loss to a Python float immediately: the list then holds
        # plain numbers, so np.mean below does not depend on implicit
        # tensor-to-array conversion (which fails for non-CPU tensors).
        nll = float(get_nll_loss(text))
        nlls.append(nll)
        total += nll
        current_mean = total / i
        pbar.set_postfix({'mean NLL': f'{current_mean:.4f}'})

print(f"Final mean NLL: {np.mean(nlls):.4f}")

Calculating NLL:   0%|          | 0/1000 [00:00<?, ?it/s]

KeyboardInterrupt: 