In [None]:
from names_dataset import NameDataset, NameWrapper
import random
from transformers import AutoTokenizer
nd = NameDataset()
import os
import csv
from tqdm import tqdm
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/pythia-1.4b-deduped')

In [None]:
def get_one_token_names_list(tokenizer, nd):
    names = list(nd.first_names.keys())
    one_token_names = [name for name in names if len(tokenizer.tokenize(name)) == 1]
    return one_token_names

In [None]:
template_4_flips_clean = lambda Name1, Name2, Name3, Name4, flip1, flip2, flip3, flip4: (
    f"A coin is heads up. "
    f"{Name1} {'flips coin' if flip1 else 'does not flip'}. "
    f"{Name2} {'flips coin' if flip2 else 'does not flip'}. "
    f"{Name3} {'flips coin' if flip3 else 'does not flip'}. "
    f"{Name4} {'flips coin' if flip4 else 'does not flip'}. "
    "Is the coin still heads up? "
)

template_4_flips_corrupted = lambda Name1, Name2, Name3, Name4, flip1, flip2, flip3, flip4: (
    f"A coin is heads up. "
    f"{Name1} {'flips coin' if not flip1 else 'does not flip'}. "
    f"{Name2} {'flips coin' if flip2 else 'does not flip'}. "
    f"{Name3} {'flips coin' if flip3 else 'does not flip'}. "
    f"{Name4} {'flips coin' if flip4 else 'does not flip'}. "
    "Is the coin still heads up? "
)

##### Check that the prompt lengths are the same length!

In [63]:
def made_up_heuristics(flips):
    #the flips location in the array, determines how many flips are done. e.g for position 1, one flip, postion 2, two flips.
    total_flips = 0
    for i, flip in enumerate(flips):
        if flip == 1:
            total_flips += (i + 1) 
    return total_flips

In [64]:
one_token_names = get_one_token_names_list(tokenizer, nd)

['Ab', 'Abs', 'Ac', 'Accept', 'Accessor', 'According', 'Act', 'Active', 'Ad', 'Ada', 'Adam', 'Add', 'Addition', 'Admin', 'Adv', 'Advertisement', 'Af', 'Aff', 'African', 'After', 'Ag', 'Age', 'Agg', 'Ah', 'Air', 'Ak', 'Akt', 'Al', 'Ala', 'Alan', 'Ald', 'Aldrich', 'Ale', 'Alex', 'Alexander', 'Ali', 'Alias', 'Align', 'All', 'Allen', 'Allow', 'Along', 'Alpha', 'Alright', 'Also', 'Alt', 'Am', 'Ama', 'Amazon', 'Amb', 'Americ', 'America', 'Among', 'An', 'Analy', 'Analysis', 'Anchor', 'Anderson', 'Andrew', 'Ang', 'Angel', 'Angle', 'Anim', 'Ann', 'Anna', 'Anne', 'Answer', 'Ant', 'Anth', 'Anthony', 'Anti', 'Anton', 'Any', 'Anyway', 'Ap', 'Apart', 'Api', 'App', 'Appe', 'Appeal', 'Apple', 'April', 'Ar', 'Arab', 'Arc', 'Arch', 'Are', 'Area', 'Arg', 'Ari', 'Arm', 'Arn', 'Around', 'Arr', 'Art', 'Arthur', 'Article', 'As', 'Asc', 'Ash', 'Asia', 'Ask', 'Asp', 'Ass', 'Asset', 'Ast', 'At', 'Ath', 'Att', 'Au', 'Aud', 'Aug', 'August', 'Aust', 'Austin', 'Austral', 'Australia', 'Aut', 'Author', 'Av', 'Aw', 'A

In [65]:
names = [random.choice(one_token_names) for _ in range(4)]
flips = [random.choice([0, 1]) for _ in range(4)]
args = names + flips
clean_entry = template_4_flips_clean(*args)
corrupted_entry = template_4_flips_corrupted(*args)

In [67]:
train_num = 5000
test_num = 1000

In [68]:
data_train = [['clean', 'corrupted','answer']]
data_test = [['clean', 'corrupted','answer']]

for i in range(train_num):
    names = [random.choice(one_token_names) for _ in range(4)]
    flips = [random.choice([0, 1]) for _ in range(4)]
    args = names + flips
    clean_entry = template_4_flips_clean(*args)
    corrupted_entry = template_4_flips_corrupted(*args)
    total_flips = sum(flips)
    answer = total_flips % 2 == 0
    data_train.append([clean_entry, corrupted_entry, answer])

for i in range(test_num):
    names = [random.choice(one_token_names) for _ in range(4)]
    flips = [random.choice([0, 1]) for _ in range(4)]
    args = names + flips
    clean_entry = template_4_flips_clean(*args)
    corrupted_entry = template_4_flips_corrupted(*args)
    total_flips = sum(flips)
    answer = total_flips % 2 == 0
    data_test.append([clean_entry, corrupted_entry, answer])

folder_path = 'datasets_csv/prompts_id_0'

os.makedirs(folder_path, exist_ok=True)
with open('datasets_csv/prompts_id_0/train.csv', 'w', newline='', encoding='utf-8') as f_train:
    writer = csv.writer(f_train)
    writer.writerows(data_train)

with open('datasets_csv/prompts_id_0/test.csv', 'w', newline='', encoding='utf-8') as f_test:
    writer = csv.writer(f_test)
    writer.writerows(data_test)

In [69]:
data_train = [['clean', 'corrupted','answer']]
data_test = [['clean', 'corrupted','answer']]

for i in range(train_num):
    names = [random.choice(one_token_names) for _ in range(4)]
    flips = [random.choice([0, 1]) for _ in range(4)]
    args = names + flips
    clean_entry = template_4_flips_clean(*args)
    corrupted_entry = template_4_flips_corrupted(*args)
    total_flips = sum(flips)
    answer = total_flips % 2 != 0
    data_train.append([clean_entry, corrupted_entry, answer])

for i in range(test_num):
    names = [random.choice(one_token_names) for _ in range(4)]
    flips = [random.choice([0, 1]) for _ in range(4)]
    args = names + flips
    clean_entry = template_4_flips_clean(*args)
    corrupted_entry = template_4_flips_corrupted(*args)
    total_flips = sum(flips)
    answer = total_flips % 2 != 0
    data_test.append([clean_entry, corrupted_entry, answer])

folder_path = 'datasets_csv/prompts_id_1'

os.makedirs(folder_path, exist_ok=True)
with open('datasets_csv/prompts_id_1/train.csv', 'w', newline='', encoding='utf-8') as f_train:
    writer = csv.writer(f_train)
    writer.writerows(data_train)

with open('datasets_csv/prompts_id_1/test.csv', 'w', newline='', encoding='utf-8') as f_test:
    writer = csv.writer(f_test)
    writer.writerows(data_test)

In [70]:
data_train = [['clean', 'corrupted','answer']]
data_test = [['clean', 'corrupted','answer']]

for i in range(train_num):
    names = [random.choice(one_token_names) for _ in range(4)]
    flips = [random.choice([0, 1]) for _ in range(4)]
    args = names + flips
    clean_entry = template_4_flips_clean(*args)
    corrupted_entry = template_4_flips_corrupted(*args)
    total_flips = made_up_heuristics(flips)
    answer = total_flips % 2 == 0
    data_train.append([clean_entry, corrupted_entry, answer])

for i in range(test_num):
    names = [random.choice(one_token_names) for _ in range(4)]
    flips = [random.choice([0, 1]) for _ in range(4)]
    args = names + flips
    clean_entry = template_4_flips_clean(*args)
    corrupted_entry = template_4_flips_corrupted(*args)
    total_flips = made_up_heuristics(flips)
    answer = total_flips % 2 == 0
    data_test.append([clean_entry, corrupted_entry, answer])

folder_path = 'datasets_csv/prompts_id_2'

os.makedirs(folder_path, exist_ok=True)
with open('datasets_csv/prompts_id_2/train.csv', 'w', newline='', encoding='utf-8') as f_train:
    writer = csv.writer(f_train)
    writer.writerows(data_train)

with open('datasets_csv/prompts_id_2/test.csv', 'w', newline='', encoding='utf-8') as f_test:
    writer = csv.writer(f_test)
    writer.writerows(data_test)