In [1]:
import json
import numpy as np
import random
from tqdm.auto import tqdm
import itertools
import os
from copy import deepcopy
import matplotlib.pyplot as plt
from collections import defaultdict
import string

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Root directory under which every generated dataset folder is created.
# Set the CREATIVITY_DATA_ROOT environment variable to redirect the output;
# when it is unset the original hard-coded location is used.
DATA_ROOT = os.environ.get(
    "CREATIVITY_DATA_ROOT",
    "/data/locus/project_data/project_data2/chenwu2/creativity_data",
)

In [3]:
def build_dicts(entities):
    """Build index <-> entity lookup tables, ignoring repeated entities.

    Args:
        entities: Iterable of hashable entities; may contain duplicates.

    Returns:
        A tuple ``(ind2entity, entity2ind)``: ``ind2entity`` lists the distinct
        entities in first-seen order and ``entity2ind`` maps each entity to its
        position in that list.
    """
    ind2entity = []
    entity2ind = {}
    for entity in entities:
        # Test membership on the dict (constant time) instead of scanning the
        # list, so the whole build stays linear in the number of entities.
        if entity not in entity2ind:
            entity2ind[entity] = len(ind2entity)
            ind2entity.append(entity)
    return ind2entity, entity2ind

def choose(arr, ratio_or_count):
    """Randomly pick items from ``arr`` without replacement.

    Args:
        arr: Sequence to sample from (needs ``len`` and integer indexing).
        ratio_or_count: A float is read as the fraction of ``arr`` to keep
            (rounded to the nearest count); an int is read as the absolute
            number of items to keep. NumPy scalar floats/ints are accepted too.

    Returns:
        ``arr`` itself (not a copy) when the requested count is at least
        ``len(arr)``; otherwise a new list with the selected items, in the
        order produced by ``np.random.choice``.

    Raises:
        AssertionError: if ``ratio_or_count`` is neither a float nor an int.
    """
    if isinstance(ratio_or_count, (float, np.floating)):
        # int(...) normalises NumPy scalar results to a plain Python int.
        num = int(round(ratio_or_count * len(arr)))
    elif isinstance(ratio_or_count, (int, np.integer)) and not isinstance(ratio_or_count, bool):
        # bool is a subclass of int but is not a meaningful count; keep rejecting it.
        num = int(ratio_or_count)
    else:
        # Raised explicitly (rather than `assert False`) so the check survives
        # `python -O`, which strips assert statements.
        raise AssertionError(
            "ratio_or_count must be a float (ratio) or an int (count), got {!r}".format(
                type(ratio_or_count)
            )
        )
    if num >= len(arr):
        return arr
    rand_inds = np.random.choice(len(arr), num, replace=False).tolist()
    return [arr[i] for i in rand_inds]

In [4]:
def form_creativity(hash_str, nodes):
    """Create one training example from a hash prefix and a node sequence.

    The prompt is ``hash_str`` followed by ``<q>``. Every pair of consecutive
    nodes is concatenated into an edge string; the edges are shuffled in place
    with ``random.shuffle``, joined with ``<sep>``, appended to the prompt and
    closed with ``</a>``.

    Args:
        hash_str: String placed in front of the ``<q>`` marker (may be empty).
        nodes: Sequence of node token strings.

    Returns:
        Dict with keys ``"input_text"`` (the prompt) and ``"target_text"``
        (prompt + shuffled edges + ``</a>``).
    """
    prompt = hash_str + "<q>"
    # Pair each node with its successor to obtain the edges of the line.
    edge_tokens = [left + right for left, right in zip(nodes, nodes[1:])]
    random.shuffle(edge_tokens)
    answer = "<sep>".join(edge_tokens)
    return {
        "input_text": prompt,
        "target_text": prompt + answer + "</a>",
    }


def form_creativity_test(hash_str):
    """Create one test example that carries only a placeholder target.

    Args:
        hash_str: String placed in front of the ``<q>`` marker (may be empty).

    Returns:
        Dict with ``"input_text"`` set to ``hash_str + "<q>"`` and
        ``"target_text"`` set to that prompt immediately followed by ``</a>``
        (no edges in between).
    """
    prompt = hash_str + "<q>"
    # The target is a placeholder: the prompt closed straight away.
    return {"input_text": prompt, "target_text": prompt + "</a>"}

In [5]:
def build_dataset(num_entities, num_nodes, hash_str_len, num_train=10000, num_test=1024):
    """Generate train/test examples, each in a hashed and a hash-free variant.

    Every training example samples ``num_nodes + 1`` distinct entity tokens and
    is emitted twice via ``form_creativity``: once prefixed with a unique
    random hash string and once with an empty prefix. Every test example is a
    placeholder built by ``form_creativity_test``: one with a fresh unique hash
    and one with an empty prefix.

    Args:
        num_entities: Size of the entity vocabulary (``<a_0>`` ... ``<a_{n-1}>``).
        num_nodes: Each sampled line uses ``num_nodes + 1`` entities.
        hash_str_len: Length of each hash string, drawn from ``a-z0-9``; ``0``
            disables hashing (every prefix is the empty string).
        num_train: Number of training examples to generate.
        num_test: Number of test examples to generate.

    Returns:
        Tuple ``(entities_vocab, train_sequences, test_sequences,
        train_sequences_no_hash, test_sequences_no_hash)``.

    Raises:
        ValueError: if ``hash_str_len`` is negative, or too short to provide
            ``num_train + num_test`` distinct hash strings (the uniqueness
            loop could otherwise never terminate).
    """
    entities_vocab = ["<a_{}>".format(i) for i in range(num_entities)]

    # Hash strings are drawn from lowercase letters and digits.
    chars = string.ascii_lowercase + string.digits
    used_hashes = set()  # hashes already handed out, shared by train and test

    if hash_str_len < 0:
        raise ValueError("hash_str_len must be non-negative, got {}".format(hash_str_len))
    if hash_str_len > 0 and len(chars) ** hash_str_len < num_train + num_test:
        raise ValueError(
            "hash_str_len={} allows only {} distinct hashes but {} are required".format(
                hash_str_len, len(chars) ** hash_str_len, num_train + num_test
            )
        )

    def _fresh_hash():
        """Return a hash string not issued before ("" when hashing is disabled)."""
        if hash_str_len == 0:
            return ""
        while True:
            candidate = "".join(random.choice(chars) for _ in range(hash_str_len))
            if candidate not in used_hashes:
                used_hashes.add(candidate)
                return candidate

    train_sequences, test_sequences = [], []
    train_sequences_no_hash, test_sequences_no_hash = [], []

    for _ in range(num_train):
        # Sample a random ordered selection of distinct nodes.
        nodes = random.sample(entities_vocab, num_nodes + 1)
        train_sequences.append(form_creativity(_fresh_hash(), nodes))
        train_sequences_no_hash.append(form_creativity("", nodes))

    for _ in range(num_test):
        test_sequences.append(form_creativity_test(_fresh_hash()))
        # Mirror the hashed test item with an empty prefix. This must not reuse
        # `nodes`, which at this point is just the last training sample.
        test_sequences_no_hash.append(form_creativity_test(""))

    return entities_vocab, train_sequences, test_sequences, train_sequences_no_hash, test_sequences_no_hash

# Dataset configuration.
NUM_ENTITIES = 12   # size of the entity vocabulary
NUM_NODES = 9       # build_dataset samples NUM_NODES + 1 entities per example
HASH_STR_LEN = 10   # length of the random hash prefix
SEED = 0

# Seed both random number generators this notebook draws from (`random` and
# `np.random`) so that Restart & Run All regenerates the same dataset.
random.seed(SEED)
np.random.seed(SEED)

(
    entity_vocab,
    train_sequences,
    test_sequences,
    train_sequences_no_hash,
    test_sequences_no_hash,
) = build_dataset(NUM_ENTITIES, NUM_NODES, HASH_STR_LEN)

In [6]:
# Token inventory: the entity tokens first, followed by the special tokens.
vocab = list(entity_vocab) + ["<mask>", "<sep>", "<a>", "</a>", "<q>", "</q>"]
# Every token must appear exactly once.
assert len(set(vocab)) == len(vocab)
print("vocab size:", len(vocab))

vocab size: 16


In [7]:
test_size = 1024
# Cap both test sets (hashed and hash-free) at `test_size` examples.
test_sequences, test_sequences_no_hash = (
    choose(split, test_size)
    for split in (test_sequences, test_sequences_no_hash)
)

In [8]:
# Report the size of each training set, one number per line.
print(len(train_sequences), len(train_sequences_no_hash), sep="\n")

10000
10000


In [9]:
# Write one dataset folder per variant (with hash prefix / without).
# The loop targets are named `train_split` / `test_split` on purpose: reusing
# the notebook-level names `train_sequences` / `test_sequences` would rebind
# them to the no-hash data, so a second run of this cell would write the
# wrong data into the hashed dataset.
for hash_len, train_split, test_split in [
    (HASH_STR_LEN, train_sequences, test_sequences),
    (0, train_sequences_no_hash, test_sequences_no_hash),
]:
    # Downsample the training split to each requested size.
    for training_size in [10000]:
        print(f"training size: {training_size}")
        dataset_name = "line.{}.{}.{}.{}".format(NUM_ENTITIES, NUM_NODES, hash_len, training_size)
        dataset_dir = os.path.join(DATA_ROOT, dataset_name)
        os.makedirs(dataset_dir, exist_ok=True)
        train_sequences_ds = choose(train_split, training_size)

        # Sanity check: distinct prompts vs. total number of training examples.
        input_texts = [item["input_text"] for item in train_sequences_ds]
        unique_input_texts = list(set(input_texts))

        print(len(unique_input_texts))
        print(len(train_sequences_ds))

        # Probe set: a sample of training items plus every test item, each
        # copied and tagged with the split it came from.
        probes = []
        for item in choose(train_sequences_ds, test_size):
            probes.append(deepcopy(item))
            probes[-1]['type'] = 'train'

        for item in test_split:
            probes.append(deepcopy(item))
            probes[-1]['type'] = 'test'

        # Files are written in the same order as before; the vocabulary is
        # stored alongside the data.
        for file_name, payload in [
            ("train.json", train_sequences_ds),
            ("valid.json", test_split),
            ("test.json", probes),
            ("vocab.json", vocab),
        ]:
            with open(os.path.join(dataset_dir, file_name), "w", encoding='utf-8') as f:
                json.dump(payload, f)

training size: 5000
5000
5000
training size: 10000
10000
10000
training size: 5000
1
5000
training size: 10000
1
10000
