In [1]:
import csv
import os
import random

from transformers import AutoTokenizer


In [2]:
# Tokenizer shared by the token-count checks defined in this notebook
# (is_single_token and are_token_lengths_equal).
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-1.4b-deduped")

def is_single_token(text):
    """Return True when `str(text)` is split into exactly one token.

    `text` may be any object; it is converted with `str()` before being
    passed to the module-level `tokenizer`.
    """
    # NOTE(review): the value is tokenized on its own, while the prompts built
    # in generate_data place numbers after ", " -- confirm the tokenizer gives
    # the same token count in that context.
    return len(tokenizer.tokenize(str(text))) == 1

In [3]:
def generate_arithmetic_or_geometric_sequence():
    """Draw a random five-term arithmetic or geometric progression.

    Returns:
        A `(sequence, label)` pair where `sequence` is the list of the first
        four terms and `label` is the fifth term, or `(None, None)` when any
        of the five terms fails `is_single_token`.
    """
    kind = random.choice(["arithmetic", "geometric"])

    if kind == "arithmetic":
        first = random.randint(1, 500)
        difference = random.randint(1, 100)
        # first, first + d, ..., first + 4d
        terms = list(range(first, first + 5 * difference, difference))
    else:
        first = random.randint(1, 10)
        ratio = random.randint(2, 10)
        terms = [first * ratio ** exponent for exponent in range(5)]

    # Reject the whole draw as soon as one term needs more than one token.
    if any(not is_single_token(term) for term in terms):
        return None, None

    # The final term is what the model should predict; the rest is the prompt.
    *prefix, label = terms
    return prefix, label

def corrupt_sequence(sequence, is_valid=None):
    """Return a copy of `sequence` with one of its last two entries replaced.

    One of the positions -1 / -2 is picked at random and its value is
    increased by a random offset in 1..100. A replacement is accepted only
    if it passes `is_valid` and does not already occur in the sequence.

    Unlike an open-ended retry loop, the search is bounded: every offset is
    tried at most once (in random order), and if the first position has no
    acceptable replacement the other position is tried before giving up.

    Args:
        sequence: sequence of integers to corrupt; it is not modified.
        is_valid: optional predicate applied to a candidate replacement.
            Defaults to `is_single_token`.

    Returns:
        A new list equal to `sequence` except for one replaced entry.

    Raises:
        ValueError: if no acceptable replacement exists for any candidate
            position (including when `sequence` is empty).
    """
    if is_valid is None:
        is_valid = is_single_token

    corrupted_sequence = list(sequence)

    # Only the last two positions may be corrupted; shorter inputs expose fewer.
    candidate_indices = [idx for idx in (-1, -2) if -idx <= len(corrupted_sequence)]
    random.shuffle(candidate_indices)

    for index_to_corrupt in candidate_indices:
        original_value = corrupted_sequence[index_to_corrupt]
        # Visiting each offset once in random order picks uniformly among the
        # acceptable replacements and guarantees termination.
        for offset in random.sample(range(1, 101), 100):
            new_value = original_value + offset
            if new_value not in corrupted_sequence and is_valid(new_value):
                corrupted_sequence[index_to_corrupt] = new_value
                return corrupted_sequence

    raise ValueError(
        f"no valid replacement within +1..+100 for the last entries of {list(sequence)!r}"
    )

def are_token_lengths_equal(clean, corrupted):
    """Return True when both strings tokenize to the same number of tokens.

    Both arguments are passed to the module-level `tokenizer`; only the
    token counts are compared, not the tokens themselves.
    """
    return len(tokenizer.tokenize(clean)) == len(tokenizer.tokenize(corrupted))

def generate_data(num_samples, max_attempts=None):
    """Build `num_samples` clean/corrupted prompt pairs over unique sequences.

    Each sample is a dict with the keys "clean" (prompt for the original
    sequence), "corrupted" (prompt for the corrupted sequence) and "label"
    (the term following the clean sequence). A candidate is kept only when
    both prompts have the same token count.

    Args:
        num_samples: number of samples to produce.
        max_attempts: optional cap on how many candidate sequences are drawn.
            With the default `None` the loop keeps drawing until enough
            samples exist, which never terminates if fewer than
            `num_samples` distinct usable sequences can be generated.

    Returns:
        A list of `num_samples` sample dicts.

    Raises:
        RuntimeError: if `max_attempts` candidates were drawn before
            `num_samples` samples were collected.
    """
    def render_prompt(values):
        return f"Derive the following sequence: {', '.join(map(str, values))},"

    data = []
    accepted_sequences = set()
    attempts = 0

    while len(data) < num_samples:
        if max_attempts is not None and attempts >= max_attempts:
            raise RuntimeError(
                f"generated only {len(data)} of {num_samples} samples "
                f"after {attempts} attempts"
            )
        attempts += 1

        sequence, label = generate_arithmetic_or_geometric_sequence()
        if sequence is None or label is None:
            continue

        sequence_key = tuple(sequence)
        if sequence_key in accepted_sequences:
            continue

        try:
            corrupted_sequence = corrupt_sequence(sequence)
        except ValueError:
            # Treat a ValueError from the corruption step as "this sequence
            # cannot be corrupted" and move on instead of aborting the run.
            continue

        clean_input = render_prompt(sequence)
        corrupt_input = render_prompt(corrupted_sequence)

        if not are_token_lengths_equal(clean_input, corrupt_input):
            continue

        # Record the sequence only once it is accepted, so a candidate that
        # was rejected because of one particular corruption can be retried.
        accepted_sequences.add(sequence_key)
        data.append({"clean": clean_input, "corrupted": corrupt_input, "label": label})

    return data

In [4]:
# Generate the full dataset; it is shuffled and split into train/test below.
data = generate_data(num_samples=6000)

KeyboardInterrupt: 

In [None]:
import os

# Randomise the sample order in place before splitting.
random.shuffle(data)

# The first 90% of the shuffled samples form the training set, the rest the test set.
split_idx = int(0.9 * len(data))
train_data, test_data = data[:split_idx], data[split_idx:]

def save_to_csv(data, filename):
    """Write a list of row dicts to `filename` as CSV with a header row.

    The column names and their order are taken from the keys of the first
    row. Missing parent directories of `filename` are created, and the file
    is written as UTF-8 regardless of the platform default.

    Args:
        data: non-empty list of dicts that share the same keys.
        filename: destination path of the CSV file.

    Raises:
        ValueError: if `data` is empty, since there is no row to derive the
            header from.
    """
    if not data:
        raise ValueError("save_to_csv needs at least one row to derive the CSV header")

    fieldnames = list(data[0].keys())

    # Create the target directory up front so a fresh checkout does not fail
    # with FileNotFoundError.
    parent_dir = os.path.dirname(filename)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(filename, 'w', newline='', encoding='utf-8') as f:
        dict_writer = csv.DictWriter(f, fieldnames=fieldnames)
        dict_writer.writeheader()
        dict_writer.writerows(data)

# Persist both splits, then report how many rows each file received.
save_to_csv(data=train_data, filename='seq/datasets_csv/train.csv')
save_to_csv(data=test_data, filename='seq/datasets_csv/test.csv')

print(f"Saved {len(train_data)} rows to train.csv and {len(test_data)} rows to test.csv")

Saved 5400 rows to train.csv and 600 rows to test.csv
