In [None]:
%load_ext autoreload
%autoreload 2

import os

# Project layout is resolved relative to the kernel's current working directory.
# NOTE(review): this assumes the notebook kernel is started from the repository root.
REPO_DIR = os.getcwd()
SRC_DIR = os.path.join(REPO_DIR, "src")
MODEL_DIR = os.path.join(REPO_DIR, "models")
DATA_DIR = os.path.join(REPO_DIR, "data")

# exist_ok avoids the check-then-create race and keeps this cell idempotent.
for d in (MODEL_DIR, DATA_DIR):
    os.makedirs(d, exist_ok=True)


import sys

# Make repo-local modules importable. Paths already on sys.path are skipped so
# re-running this cell does not keep appending duplicates.
for p in (REPO_DIR, SRC_DIR):
    if p not in sys.path:
        sys.path.append(p)

import random

import numpy as np
import torch
from nnsight import NNsight
from transformers import AutoModelForCausalLM, AutoTokenizer


def set_seed(seed):
    """Seed every random number generator used by this notebook.

    Applies ``seed`` to Python's ``random`` module, NumPy's global generator,
    torch's default generator, and the generators of all CUDA devices, in that
    order.

    Args:
        seed: integer seed shared by all generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)


set_seed(0)

# Choose the compute device: Apple MPS if available, otherwise CUDA,
# otherwise fall back to the CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Model

In [None]:
# Load model


def load_hf_token(path="../../auth/hf_token.txt", env_var="HF_TOKEN"):
    """Return a Hugging Face access token, or None if none is configured.

    Lookup order:
      1. the contents of ``path`` (stripped), resolved relative to the kernel's
         working directory;
      2. the environment variable ``env_var`` (stripped);
      3. ``None``, so ``from_pretrained(token=None)`` can fall back to any
         credentials cached by ``huggingface-cli login``.

    An empty file or empty variable is treated as "not configured".

    Args:
        path: file holding the token on its own line.
        env_var: name of the environment variable consulted when the file is
            missing or empty.

    Returns:
        The token string, or None.
    """
    try:
        with open(path) as f:
            token = f.read().strip()
    except FileNotFoundError:
        token = ""
    if not token:
        token = os.environ.get(env_var, "").strip()
    return token or None


hf_token = load_hf_token()

model_id = "google/gemma-2-2b"
# Short name = the repo part of the hub id, so the two can never disagree.
model_name = model_id.rsplit("/", 1)[-1]

# Inference only: turn autograd off globally so no graph/activations are kept
# for a backward pass (avoids blowing up memory).
torch.set_grad_enabled(False)

hf_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    # where to fetch from / cache to
    token=hf_token,
    cache_dir=MODEL_DIR,
    # where and how to materialise the weights
    device_map=device,
    low_cpu_mem_usage=True,
    # Explicit "eager" attention rather than the library default.
    # NOTE(review): presumably chosen for Gemma-2 compatibility — confirm.
    attn_implementation="eager",
)

tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    cache_dir=MODEL_DIR,
    token=hf_token,
)
# Pad with EOS and pad on the left, so in a padded batch the final position of
# every row is a real prompt token.
# NOTE(review): if the tokenizer already defines its own pad token, this
# override replaces it — confirm downstream code expects pad == eos.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

# VOCAB[i] is the token string whose id is i (tokens ordered by id).
# The vocab dict is fetched once: `tokenizer.vocab` rebuilds it on every access
# for fast tokenizers, and the original expression accessed it twice.
VOCAB = [
    token
    for token, _ in sorted(tokenizer.get_vocab().items(), key=lambda item: item[1])
]

# Wrap the Hugging Face model with nnsight so its internals can be traced.
nnsight_model = NNsight(hf_model)

# Layer index used by later cells (not referenced in the cells shown here).
# NOTE(review): 10 is an unexplained choice — document why this layer.
layer_idx = 10
# Keyword options meant to be passed along when tracing `nnsight_model`.
# NOTE(review): presumably "scan"/"validate" are nnsight tracer flags and
# "use_cache"/"output_attentions" are forwarded to the HF forward() — confirm
# against the nnsight version in use.
nnsight_tracer_kwargs = dict(
    scan=True,
    validate=False,
    use_cache=False,
    output_attentions=False,
)

# Dataset Generation

In [None]:
from ravel_dataset_builder import RAVELEntityPromptData

# Build the full prompt dataset for the "city" entity type.
# NOTE(review): the literal "data" is passed rather than DATA_DIR from the setup
# cell — confirm which form the loader expects.
full_entity_dataset = RAVELEntityPromptData.from_files(
    "city",
    "data",
    tokenizer,
)

# Displayed: number of prompts before any downsampling.
len(full_entity_dataset)

In [None]:
# Work on a subset of the full prompt set to keep generation affordable.
n_sampled_prompts = 1000
sampled_entity_dataset = full_entity_dataset.downsample(n_sampled_prompts)
print(f"Number of sampled prompts: {len(sampled_entity_dataset)}")

# Prompts are capped at `prompt_max_length` tokens, and the generation budget
# leaves room for `n_completion_tokens` extra tokens.
# NOTE(review): assumes max_length counts prompt + generated tokens — confirm in
# ravel_dataset_builder.
prompt_max_length = 48
n_completion_tokens = 8
sampled_entity_dataset.generate_completions(
    nnsight_model,
    tokenizer,
    max_length=prompt_max_length + n_completion_tokens,
    prompt_max_length=prompt_max_length,
)

sampled_entity_dataset.evaluate_correctness()

# Compute accuracy before any filtering, so the reported number covers the whole
# sample even if a filter method turns out to modify the dataset in place.
accuracy = sampled_entity_dataset.calculate_average_accuracy()

# Filter correct completions.
correct_data = sampled_entity_dataset.filter_correct()

# Filter top entities and templates (ranking criterion defined in
# ravel_dataset_builder).
# NOTE(review): `filtered_data` is not used by the visible cells; the Wikipedia
# step extends `correct_data` instead — confirm which set later analyses need.
filtered_data = correct_data.filter_top_entities_and_templates(
    top_n_entities=400, top_n_templates_per_attribute=12
)

print(f"Average accuracy: {accuracy:.2%}")
print(f"Number of correct prompts: {len(correct_data)}")

In [None]:
# Add Wikipedia-derived prompts (presumably, going by the method name) for the
# "city" entity type. The return value is discarded, so this presumably
# modifies `correct_data` in place.
# NOTE(review): this extends `correct_data`, not `filtered_data` from the
# previous cell — confirm that is intended.
correct_data.add_wikipedia_prompts("city", "data", tokenizer, nnsight_model)

# Experimental Interventions