In [1]:
from transformer_lens import HookedTransformer
import transformer_lens.utils as utils

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import pandas as pd
import numpy as np
import re
import random
import json
from datasets import Dataset

`get_token_position`: Get the position of a single_token in a string or sequence of tokens.

In [3]:
df = pd.read_csv("../datasets/names.csv")
df.head()

Unnamed: 0,name,gender,number,gpt2-small
0,James,M,5122407,3700
1,John,M,5096818,1757
2,Robert,M,4803587,5199
3,Michael,M,4326215,3899
4,Mary,F,4118147,5335


In [4]:
df["gpt2-small-size"] = df["gpt2-small"].apply(lambda x: len(x.split(",")))

In [5]:
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True
)

`torch_dtype` is deprecated! Use `dtype` instead!


Loaded pretrained model gpt2-small into HookedTransformer


In [6]:
def to_token_ids(model: HookedTransformer, name: str) -> str:
    """
    Converts a given name string into a comma-separated string of token IDs using the provided model.

    Args:
        model (HookedTransformer): The language model used for tokenization.
        name (str): The name to be tokenized.

    Returns:
        str: A comma-separated string of token IDs representing the input name.
    """
    tokens = model.to_tokens(f" {name}", prepend_bos=False)[0].tolist()
    return ",".join([str(t) for t in tokens])

# df["gpt2-small"] = df["name"].apply(lambda x: to_token_ids(model, x))

In [7]:
# example_prompt = "Is Mary around here?\n-No,"
# example_answer = " he"
# utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True, top_k=30)

In [8]:
def fill_prompt(prompt: str, name: str, template_config: dict) -> str:
    """
    Fills a prompt template by replacing placeholders with actual values.

    Args:
        prompt (str): The prompt template containing placeholders in square brackets.
        name (str): The name to insert into the prompt.
        template_config (dict): Dictionary with possible replacements for placeholders.

    Returns:
        str: The filled prompt with all placeholders replaced.
    """
    elems = re.findall(r"\[(.*?)\]",prompt)
    for elem in elems:
        if elem != 'name':
            prompt = re.sub(rf"\[{elem}\]", random.choice(template_config[elem]), prompt)
        else:
            prompt = re.sub(r"\[name\]", name, prompt)
    return prompt

def get_ans_tokens(ans: dict, model: HookedTransformer) -> dict:
    """
    Converts answer strings to their corresponding token IDs using the model.

    Args:
        ans (dict): Dictionary mapping keys (e.g., gender) to answer strings.
        model (HookedTransformer): The language model used for tokenization.

    Returns:
        dict: Dictionary mapping keys to the first token ID of each answer string.
    """
    return { k: model.to_tokens(f" {v}", prepend_bos=False)[0].tolist()[0] for k,v in ans.items() }

def get_opposite_gender(gender: str) -> str:
    """
    Returns the opposite gender.

    Args:
        gender (str): "M" for male or "F" for female.

    Returns:
        str: "F" if input is "M", otherwise "M".
    """
    return "F" if gender=="M" else "M"

def get_token_pos(prompt: str, token_str: str) -> list:
    """
    Finds the positions of tokens in a prompt.

    Args:
        prompt (str): The prompt string.
        token_str (str): Comma-separated string of token IDs.

    Returns:
        list: List of positions of each token in the prompt.
    """
    tokens = token_str.split(",")
    return [ model.get_token_position(single_token=int(token), input=prompt, prepend_bos=True) for token in tokens ]

def create_dataset(model: HookedTransformer, df: pd.DataFrame, template_config: dict, dataset_size: int = 1000, subject_token_size: int = 2) -> dict:
    """
    Creates a dataset dictionary for language model probing.

    Args:
        model (HookedTransformer): The language model.
        df (pd.DataFrame): DataFrame containing names and metadata.
        template_config (dict): Configuration for templates and answers.
        dataset_size (int): Number of samples to generate.
        subject_token_size (int): Number of tokens for subject names.

    Returns:
        dict: Dictionary containing dataset fields for further processing.
    """
    ans_tokens = get_ans_tokens(template_config["ans"], model)
    prompt = random.choice(template_config["templates"])

    all_males   = df[ (df["gender"] == "M") & (df["gpt2-small-size"] == subject_token_size) ] 
    all_females = df[ (df["gender"] == "F") & (df["gpt2-small-size"] == subject_token_size) ] 

    males   = all_males.sample(n=dataset_size // 2, weights="number", replace=True)
    females = all_females.sample(n=dataset_size // 2, weights="number", replace=True)
    data    = pd.concat([males, females], axis=0)

    dataset_dict = { "template_id": [], 
                     "prompt": [], 
                     "expected_token":[], 
                     "unexpected_token":[],
                     "expected_token_id":[], 
                     "unexpected_token_id":[],
                     "subject_token_pos":[] }

    for _, row in data.iterrows():
        prompt_filled = fill_prompt(prompt, row["name"], template_config)
        positions = get_token_pos(prompt_filled, row["gpt2-small"])
        gender = row["gender"]
        op_gender = get_opposite_gender(gender)

        dataset_dict["template_id"].append(prompt)
        dataset_dict["prompt"].append(prompt_filled)
        dataset_dict["expected_token"].append(template_config["ans"][gender])
        dataset_dict["expected_token_id"].append(ans_tokens[gender])
        dataset_dict["unexpected_token"].append(template_config["ans"][op_gender])
        dataset_dict["unexpected_token_id"].append(ans_tokens[op_gender])
        dataset_dict["subject_token_pos"].append(positions)
    
    return dataset_dict

In [13]:
# CONFIG
dataset = "subject_pron"
dataset_size = 2000
subject_token_sizes = [1,2,3]

In [14]:
with open("../src/dataset/templates.json", "r", encoding="utf-8") as f:
    jdata = json.load(f)

    for subject_token_size in subject_token_sizes:
        print(f"Creating dataset for subject_token_size={subject_token_size}...")
        mydict = create_dataset(model, df, jdata[dataset], dataset_size=dataset_size, subject_token_size=subject_token_size)
        ds = Dataset.from_dict(mydict)
        ds.save_to_disk(f"../datasets/{dataset}_{subject_token_size}_tokens")

Creating dataset for subject_token_size=1...


Saving the dataset (1/1 shards): 100%|██████████| 2000/2000 [00:00<00:00, 45451.93 examples/s]


Creating dataset for subject_token_size=2...


Saving the dataset (1/1 shards): 100%|██████████| 2000/2000 [00:00<00:00, 330559.48 examples/s]


Creating dataset for subject_token_size=3...


Saving the dataset (1/1 shards): 100%|██████████| 2000/2000 [00:00<00:00, 276934.01 examples/s]
