In [1]:
matches = {
    'boy': 'boys',
    'man': 'men',
    'actor': 'actors',
    'actress': 'actresses',
    'mother': 'mothers',
    'student': 'students',
    'father': 'fathers',
    'daughter': 'daughters',
    'son': 'sons',
    'woman': 'women',
    'friend': 'friends',
    'sister': 'sisters',
    'girl': 'girls',
    'farmer': 'farmers',
    'brother': 'brothers'
}

links = ['before the', 'beside the', 'near the', 'behind the']

target_pairs = {
    "defend": "defends",
    "interrupt": "interrupts",
    "block": "blocks",
    "avoid": "avoids",
    "watch": "watches",
    "greet": "greets",
    "meet": "meets",
    "attract": "attracts",
    "know": "knows",
    "stop": "stops",
    "welcome": "welcomes",
    "meet": "meets",
    "ignore": "ignores",
    "observe": "observes"
}

In [2]:
from itertools import product

options = ['singular', 'plural']
combos = ['_'.join(p) for p in product(options, repeat=3)]

print(combos)

['singular_singular_singular', 'singular_singular_plural', 'singular_plural_singular', 'singular_plural_plural', 'plural_singular_singular', 'plural_singular_plural', 'plural_plural_singular', 'plural_plural_plural']


In [3]:
template = lambda noun1, noun2, noun3, link: f"The {noun1} that the {noun2} {link} {noun3}"

### Dataset Gen

In [4]:
import random
def pick_noun(number_label):    
    singulars = list(matches.keys())
    plurals = list(matches.values())

    if number_label == 'singular':
        chosen_singular = random.choice(singulars)
        chosen_plural = matches[chosen_singular]
        return chosen_singular, chosen_plural

    else:
        chosen_plural = random.choice(plurals)
        reverse_matches = {v: k for k, v in matches.items()}
        chosen_singular = reverse_matches[chosen_plural]
        return chosen_plural, chosen_singular

In [5]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-1.4b-deduped")
def count_hf_tokens(text):
    tokens = tokenizer.tokenize(text)
    token_ids = tokenizer.convert_tokens_to_ids(tokens)
    return len(tokens)

### Original

In [9]:
import random
import csv

number_gen = 6000
data = []
seen = set()

for _ in range(number_gen):
    type_combo = random.choice(combos)
    parts = type_combo.split('_')

    noun1, _ = pick_noun(parts[0])
    noun2, _ = pick_noun(parts[1])
    noun3, _ = pick_noun(parts[2])
    link = random.choice(links)

    clean_prompt = template(noun1, noun2, noun3, link)

    key, value = random.choice(list(target_pairs.items()))
    correct_word = value if parts[1] == 'singular' else key

    options = [key, value]
    random.shuffle(options)
    label_map = {'A': options[0], 'B': options[1]}
    correct_label = [k for k, v in label_map.items() if v == correct_word][0]

    full_prompt = (
        "What word is missing? "
        f"{clean_prompt} _. "
        f"A: {label_map['A']}, B: {label_map['B']}. "
        'answer='
    )

    full_corrupted_prompt = (
        "What word is missing? "
        f"{clean_prompt} _. "
        f"A: {label_map['B']}, B: {label_map['A']}. "
        'answer='
    )

    key_tuple = (full_prompt.strip(), full_corrupted_prompt.strip(), correct_label)
    if key_tuple not in seen:
        data.append([full_prompt.strip(), full_corrupted_prompt.strip(), correct_label])
        seen.add(key_tuple)

random.shuffle(data)
split_idx = int(0.9 * len(data))
train_data = data[:split_idx]
val_data = data[split_idx:]

with open('datasets_csv/train.csv', mode='w', newline='') as train_file:
    writer = csv.writer(train_file)
    writer.writerow(['clean', 'corrupted', 'answer'])
    writer.writerows(train_data)

with open('datasets_csv/validation.csv', mode='w', newline='') as val_file:
    writer = csv.writer(val_file)
    writer.writerow(['clean', 'corrupted', 'answer'])
    writer.writerows(val_data)


FileNotFoundError: [Errno 2] No such file or directory: 'datasets_csv/train.csv'

### Counter-logic

In [10]:
import random
import csv

number_gen = 6000
data = []
seen = set()

for _ in range(number_gen):
    type_combo = random.choice(combos)
    parts = type_combo.split('_')

    noun1, _ = pick_noun(parts[0])
    noun2, _ = pick_noun(parts[1])
    noun3, _ = pick_noun(parts[2])
    link = random.choice(links)

    clean_prompt = template(noun1, noun2, noun3, link)

    key, value = random.choice(list(target_pairs.items()))
    correct_word = value if parts[1] == 'singular' else key

    options = [key, value]
    random.shuffle(options)
    label_map = {'A': options[0], 'B': options[1]}
    correct_label = [k for k, v in label_map.items() if v == correct_word][0]

    full_prompt = (
        "What word is missing? "
        f"{clean_prompt} _. "
        f"A: {label_map['B']}, B: {label_map['A']}. "
        'answer='
    )

    full_corrupted_prompt = (
        "What word is missing? "
        f"{clean_prompt} _. "
        f"A: {label_map['A']}, B: {label_map['B']}. "
        'answer='
    )

    key_tuple = (full_prompt.strip(), full_corrupted_prompt.strip(), correct_label)
    if key_tuple not in seen:
        data.append([full_prompt.strip(), full_corrupted_prompt.strip(), correct_label])
        seen.add(key_tuple)

random.shuffle(data)
split_idx = int(0.9 * len(data))
train_data = data[:split_idx]
val_data = data[split_idx:]

with open('sva-counter/datasets_csv/train.csv', mode='w', newline='') as train_file:
    writer = csv.writer(train_file)
    writer.writerow(['clean', 'corrupted', 'answer'])
    writer.writerows(train_data)

with open('sva-counter/datasets_csv/validation.csv', mode='w', newline='') as val_file:
    writer = csv.writer(val_file)
    writer.writerow(['clean', 'corrupted', 'answer'])
    writer.writerows(val_data)


### New logic

In [11]:
import random
import csv

number_gen = 6000
data = []
seen = set()

for _ in range(number_gen):
    type_combo = random.choice(combos)
    parts = type_combo.split('_')

    noun1, _ = pick_noun(parts[0])
    noun2, _ = pick_noun(parts[1])
    noun3, _ = pick_noun(parts[2])
    link = random.choice(links)

    clean_prompt = template(noun1, noun2, noun3, link)

    key, value = random.choice(list(target_pairs.items()))
    # correct_word = value if parts[1] == 'singular' else key

    if parts[0] == 'singular' and parts[2] == 'singular':
        correct_word = value

    if parts[0] == 'plural' and parts[2] == 'plural':
        correct_word = key

    if parts[0] == 'singular' and parts[2] == 'plural':
        correct_word = key

    if parts[0] == 'plural' and parts[2] == 'singular':
        correct_word = value


    options = [key, value]
    random.shuffle(options)
    label_map = {'A': options[0], 'B': options[1]}
    correct_label = [k for k, v in label_map.items() if v == correct_word][0]

    full_prompt = (
        "What word is missing? "
        f"{clean_prompt} _. "
        f"A: {label_map['B']}, B: {label_map['A']}. "
        'answer='
    )

    full_corrupted_prompt = (
        "What word is missing? "
        f"{clean_prompt} _. "
        f"A: {label_map['A']}, B: {label_map['B']}. "
        'answer='
    )

    key_tuple = (full_prompt.strip(), full_corrupted_prompt.strip(), correct_label)
    if key_tuple not in seen:
        data.append([full_prompt.strip(), full_corrupted_prompt.strip(), correct_label])
        seen.add(key_tuple)

random.shuffle(data)
split_idx = int(0.9 * len(data))
train_data = data[:split_idx]
val_data = data[split_idx:]

with open('sva-new/datasets_csv/train.csv', mode='w', newline='') as train_file:
    writer = csv.writer(train_file)
    writer.writerow(['clean', 'corrupted', 'answer'])
    writer.writerows(train_data)

with open('sva-new/datasets_csv/validation.csv', mode='w', newline='') as val_file:
    writer = csv.writer(val_file)
    writer.writerow(['clean', 'corrupted', 'answer'])
    writer.writerows(val_data)
