In [5]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-1.4b-deduped")

def count_hf_tokens(text):
    tokens = tokenizer.tokenize(text)
    return len(tokens)

In [1]:
import random
import os
import csv

In [2]:
total_examples = 6000
train_size = int(0.9 * total_examples)  # 5400
test_size = total_examples - train_size  # 600

In [16]:
prefix_prompt = 'Solve the following and respond with only the final answer:'

In [31]:
function = lambda a, b: a+b
correct_prompt_templates = lambda n1, n2, operand: f'{n1}{operand}{n2}='

In [None]:
import random

all_combos = [(a, b) for a in range(100, 1000) for b in range(100, 1000)]
random.shuffle(all_combos)

results_train = [['clean', 'corrupted', 'answer']]
results_test = [['clean', 'corrupted', 'answer']]

train_plus = train_minus = 0
test_plus = test_minus = 0

for a, b in all_combos:
    if (train_plus + train_minus >= train_size) and (test_plus + test_minus >= test_size):
        break

    op = random.choice(['+', '-'])

    if op == '+':
        answer = a - b
    else:
        answer = a + b
        if answer <= 0:
            continue 
    if count_hf_tokens(str(answer)) != 1:
        continue

    c = random.randint(100, 1000)
    while c == b:
        c = random.randint(100, 1000)

    clean_prompt = prefix_prompt + ' ' + correct_prompt_templates(a, b, op)
    corrupted_prompt = prefix_prompt + ' ' + correct_prompt_templates(a, c, op)

    if op == '+' and train_plus < train_size // 2:
        results_train.append([clean_prompt, corrupted_prompt, answer])
        train_plus += 1
    
    elif op == '-' and train_minus < train_size // 2:
        results_train.append([clean_prompt, corrupted_prompt, answer])
        train_minus += 1

    elif op == '+' and test_plus < test_size // 2:
        results_test.append([clean_prompt, corrupted_prompt, answer])
        test_plus += 1
    elif op == '-' and test_minus < test_size // 2:
        results_test.append([clean_prompt, corrupted_prompt, answer])
        test_minus += 1


In [35]:
output_directory = 'three-digit-arithmetic/AddSubInv/datasets_csv/'

os.makedirs(output_directory, exist_ok=True)

filename_train = os.path.join(output_directory, "train.csv")
filename_test = os.path.join(output_directory, "validation.csv")

with open(filename_train, 'w', newline='', encoding='utf-8') as file:
    writer = csv.writer(file)
    writer.writerows(results_train)

with open(filename_test, 'w', newline='', encoding='utf-8') as file:
    writer = csv.writer(file)
    writer.writerows(results_test)

In [24]:
function = lambda a, b: a-b
correct_prompt_templates = lambda n1, n2: f'{n1}+{n2}='

In [27]:
function = lambda a, b: a - b if a > b else a + b
correct_prompt_templates = lambda n1, n2: f'{n1}◇{n2}='

In [28]:
all_combos = [(a, b) for a in range(100, 1000) for b in range(100, 1000)]
random.shuffle(all_combos)

results_train = [['clean', 'corrupted', 'answer']]
results_test = [['clean', 'corrupted', 'answer']]

train_count = 0
test_count = 0


for a, b in all_combos:
    if train_count >= train_size and test_count >= test_size:
        break 

    answer = function(a, b)
    if answer <= 0:
        continue

    c = random.randint(100, 1000)
    while c == b:
        c = random.randint(100, 1000)

    if count_hf_tokens(str(answer)) != 1:
        print(answer)
        continue
    

    prompt = correct_prompt_templates(a, b)
    corrupted_prompt = correct_prompt_templates(a, c)

    clean_prompt = prefix_prompt + ' ' + prompt
    corrupted_prompt = prefix_prompt + ' ' + corrupted_prompt

    if train_count < train_size:
        results_train.append([clean_prompt, corrupted_prompt, answer])
        train_count += 1
    elif test_count < test_size:
        results_test.append([clean_prompt, corrupted_prompt, answer])
        test_count += 1


1076
1233
1903
1037
1577
1277
1349
1029
1121
1781
1803
1614
976
742
1477
1441
1334
1549
1633
1042
1906
1255
1207
1187
1145
1537
1065
1722
749
1451
1402
1124
849
928
719
1030
1136
1255
1812
1243
1010
914
1048
815
766
1498
1807
939
1191
977
938
1273
1644
1141
1070
1640
1310
1198
1670
1527
1506
1587
826
986
1003
731
824
1208
1483
868
1646
836
743
835
899
1334
1381
845
1085
1533
1465
1187
1461
1215
1374
1346
1266
1742
1166
1612
1468
1056
1395
1098
1149
799
1178
1432
1151
1465
1333
1785
1769
1245
829
1382
982
1573
669
1180
1293
1312
1070
1068
1084
1058
1431
1816
1299
1445
961
1236
1219
1166
1235
1148
1078
982
1514
1536
1195
1233
1506
1012
928
1147
977
1118
1313
1083
1450
1320
1203
1136
738
1164
788
742
868
1015
582
1857
1048
1305
1253
1031
1765
1462
788
1453
1096
1014
1396
1262
1089
968
1099
1026
1317
1221
1093
1067
1232
1039
1595
1367
1192
1396
1192
1448
1729
815
1881
1289
1071
1482
1718
1497
1061
1407
934
1378
1068
1185
1414
1071
1470
1599
1040
1250
1425
1790
1085
1697
1075
1031
1372
1516

In [29]:
output_directory = 'three-digit-arithmetic/CondAddSub/datasets_csv/'

os.makedirs(output_directory, exist_ok=True)

filename_train = os.path.join(output_directory, "train.csv")
filename_test = os.path.join(output_directory, "validation.csv")

with open(filename_train, 'w', newline='', encoding='utf-8') as file:
    writer = csv.writer(file)
    writer.writerows(results_train)

with open(filename_test, 'w', newline='', encoding='utf-8') as file:
    writer = csv.writer(file)
    writer.writerows(results_test)