In [1]:
import json
import numpy as np
import random
from tqdm.auto import tqdm
import itertools
import os
from copy import deepcopy
import matplotlib.pyplot as plt
from collections import defaultdict
import string

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Root directory for generated datasets. Override with the CREATIVITY_DATA_ROOT
# environment variable instead of editing this machine-specific default path.
DATA_ROOT = os.environ.get(
    "CREATIVITY_DATA_ROOT",
    "/data/locus/project_data/project_data2/chenwu2/creativity_data",
)

In [3]:
def build_dicts(entities):
    """Assign consecutive integer ids to entities in first-seen order.

    Args:
        entities: An iterable of hashable entities; duplicates are allowed.

    Returns:
        ``(ind2entity, entity2ind)`` where ``ind2entity`` is the list of unique
        entities in first-occurrence order and ``entity2ind`` maps each entity
        to its index in that list.
    """
    ind2entity = []
    entity2ind = dict()
    for entity in entities:
        # Dict membership is O(1); the previous `in ind2entity` list scan made
        # the whole function O(n^2) in the number of entities.
        if entity not in entity2ind:
            entity2ind[entity] = len(ind2entity)
            ind2entity.append(entity)
    return ind2entity, entity2ind

def choose(arr, ratio_or_count):
    """Randomly sample elements of ``arr`` without replacement.

    Args:
        arr: A sequence to sample from.
        ratio_or_count: A float (Python or NumPy) is a fraction of ``len(arr)``,
            rounded to the nearest integer; an int (Python or NumPy) is an
            absolute count. ``bool`` is rejected.

    Returns:
        ``arr`` itself (not a copy) when the requested count is >= ``len(arr)``;
        otherwise a new list of ``count`` elements taken at distinct positions,
        in the order produced by ``np.random.choice`` (global NumPy RNG).

    Raises:
        TypeError: if ``ratio_or_count`` is not an int or float.
        ValueError: if the resulting count is negative.
    """
    # bool is a subclass of int; keep rejecting it as the original type() check did.
    if isinstance(ratio_or_count, bool):
        raise TypeError("ratio_or_count must be an int count or a float ratio, got bool")
    if isinstance(ratio_or_count, (int, np.integer)):
        num = int(ratio_or_count)
    elif isinstance(ratio_or_count, (float, np.floating)):
        num = round(float(ratio_or_count) * len(arr))
    else:
        # An explicit raise instead of `assert False`: asserts are stripped under
        # `python -O`, which would leave `num` unbound.
        raise TypeError(
            "ratio_or_count must be an int or a float, got {}".format(type(ratio_or_count).__name__)
        )
    if num < 0:
        raise ValueError("cannot choose a negative number of elements: {}".format(num))
    if num >= len(arr):
        return arr
    rand_inds = np.random.choice(len(arr), num, replace=False).tolist()
    return [arr[i] for i in rand_inds]

In [4]:
def form_creativity(hash_str, a, b1, b2):
    """Build one training example for the sibling task.

    The prompt is ``hash_str`` followed by ``<q>``. The target repeats the
    prompt and then lists ``b1``, ``b2``, ``a`` and a closing ``</a>`` tag.

    Returns:
        A dict with keys ``"input_text"`` and ``"target_text"``.
    """
    prompt = hash_str + "<q>"
    answer = b1 + b2 + a + "</a>"
    return dict(input_text=prompt, target_text=prompt + answer)


def form_creativity_test(hash_str, a, b1, b2):
    """Build one held-out example whose target contains no answer entities.

    The prompt is ``hash_str`` followed by ``<q>``. The target is the prompt
    plus only a closing ``</a>`` tag (placeholder). ``a``, ``b1`` and ``b2`` are
    accepted for signature parity with ``form_creativity`` but are not used.

    Returns:
        A dict with keys ``"input_text"`` and ``"target_text"``.
    """
    prompt = hash_str + "<q>"
    return dict(input_text=prompt, target_text=prompt + "</a>")

In [5]:
def build_dataset(num_a, num_b_per_a, hash_str_len, test_ratio=0.005, seed=None):
    """Generate the synthetic "sibling" creativity dataset.

    Each anchor entity ``<a_i>`` owns ``num_b_per_a`` dedicated ``<b1_*>``
    entities and ``num_b_per_a`` dedicated ``<b2_*>`` entities. Every
    (a, b1, b2) triple within one anchor becomes one sequence with a random
    hash prefix, giving ``num_a * num_b_per_a ** 2`` sequences in total. Each
    sequence lands in the test split with probability ``test_ratio`` and in
    the train split otherwise.

    Args:
        num_a: Number of anchor entities.
        num_b_per_a: Number of b1 (and of b2) entities per anchor.
        hash_str_len: Length of the hash prefix drawn from ``[a-z0-9]``. With
            0 every prefix is the empty string (prefixes are then not unique).
        test_ratio: Probability that a sequence goes to the test split.
        seed: ``None`` (default) draws from the global ``random`` and
            ``np.random`` states. An int uses private generators, so the result
            is reproducible without touching global RNG state.

    Returns:
        ``(entity_vocab, train_sequences, test_sequences, entities_b1_dict,
        entities_b2_dict)``; the dicts map each anchor to its b1 / b2 lists.

    Raises:
        ValueError: if ``hash_str_len`` is negative, or there are fewer distinct
            hash strings than sequences (unique sampling could never finish).
    """
    entities_a = ["<a_{}>".format(i) for i in range(num_a)]
    all_b1 = ["<b1_{}>".format(i) for i in range(num_b_per_a * num_a)]
    all_b2 = ["<b2_{}>".format(i) for i in range(num_b_per_a * num_a)]
    entity_vocab = entities_a + all_b1 + all_b2

    # Anchor i owns the contiguous slice [i * num_b_per_a, (i + 1) * num_b_per_a).
    entities_b1_dict = {
        entity_a: all_b1[i * num_b_per_a:(i + 1) * num_b_per_a] for i, entity_a in enumerate(entities_a)
    }
    entities_b2_dict = {
        entity_a: all_b2[i * num_b_per_a:(i + 1) * num_b_per_a] for i, entity_a in enumerate(entities_a)
    }

    chars = string.ascii_lowercase + string.digits
    num_sequences = num_a * num_b_per_a * num_b_per_a
    if hash_str_len < 0:
        raise ValueError("hash_str_len must be >= 0, got {}".format(hash_str_len))
    # Without enough distinct strings the rejection loop below would spin forever.
    if hash_str_len > 0 and len(chars) ** hash_str_len < num_sequences:
        raise ValueError(
            "hash_str_len={} yields only {} distinct hashes for {} sequences".format(
                hash_str_len, len(chars) ** hash_str_len, num_sequences
            )
        )

    # seed=None keeps the previous behaviour of using the global RNG states.
    py_rng = random if seed is None else random.Random(seed)
    np_rng = np.random if seed is None else np.random.RandomState(seed)

    used_hashes = set()

    def next_hash():
        # Rejection-sample until a previously unused prefix is found.
        if hash_str_len == 0:
            return ""
        while True:
            candidate = "".join(py_rng.choices(chars, k=hash_str_len))
            if candidate not in used_hashes:
                used_hashes.add(candidate)
                return candidate

    train_sequences, test_sequences = [], []
    for entity_a in tqdm(entities_a):
        b1_list = entities_b1_dict[entity_a]
        b2_list = entities_b2_dict[entity_a]
        # leave=False so per-anchor bars don't accumulate in the cell output.
        for b1 in tqdm(b1_list, leave=False):
            for b2 in b2_list:
                hash_str = next_hash()
                if np_rng.uniform() > test_ratio:
                    train_sequences.append(form_creativity(hash_str, entity_a, b1, b2))
                else:
                    test_sequences.append(form_creativity_test(hash_str, entity_a, b1, b2))

    return entity_vocab, train_sequences, test_sequences, entities_b1_dict, entities_b2_dict

NUM_A = 10          # number of anchor entities <a_i>
NUM_B_PER_A = 1000  # number of <b1_*> (and of <b2_*>) entities per anchor
HASH_STR_LEN = 10   # length of the random hash prefix on each sequence
SEED = 0            # seeds `random` (hash prefixes) and `np.random` (split, sampling)

# Seed both global RNGs so a Restart & Run All reproduces the same dataset files.
random.seed(SEED)
np.random.seed(SEED)

entity_vocab, train_sequences, test_sequences, entities_b1_dict, entities_b2_dict = build_dataset(NUM_A, NUM_B_PER_A, HASH_STR_LEN)

100%|██████████| 1000/1000 [00:08<00:00, 115.70it/s]
100%|██████████| 1000/1000 [00:08<00:00, 114.74it/s]
100%|██████████| 1000/1000 [00:08<00:00, 114.84it/s]
100%|██████████| 1000/1000 [00:08<00:00, 115.21it/s]
100%|██████████| 1000/1000 [00:08<00:00, 115.71it/s]
100%|██████████| 1000/1000 [00:08<00:00, 115.49it/s]
100%|██████████| 1000/1000 [00:08<00:00, 116.15it/s]
100%|██████████| 1000/1000 [00:08<00:00, 116.05it/s]
100%|██████████| 1000/1000 [00:08<00:00, 115.64it/s]
100%|██████████| 1000/1000 [00:08<00:00, 115.28it/s]
100%|██████████| 10/10 [01:26<00:00,  8.66s/it]


In [6]:
# Model vocabulary: all entity tokens, then the special tokens.
# NOTE(review): hash-prefix characters ([a-z0-9]) are not added here; presumably
# the downstream tokenizer already covers them — confirm.
special_tokens = ["<mask>", "<sep>", "<a>", "</a>", "<q>", "</q>"]
vocab = [*entity_vocab, *special_tokens]
# Duplicate tokens would collide on ids.
assert len(set(vocab)) == len(vocab)
print("vocab size:", len(vocab))

vocab size: 20016


In [7]:
# Cap the held-out split at `test_size` sequences; the same number of train
# probes is sampled later.
test_size = 3_000
test_sequences = choose(test_sequences, ratio_or_count=test_size)

In [8]:
# Size of the full training pool before downsampling.
num_train_total = len(train_sequences)
print(num_train_total)

9950276


In [9]:
def _write_json(obj, dataset_dir, filename):
    """Write ``obj`` as JSON to ``dataset_dir/filename`` using UTF-8."""
    with open(os.path.join(dataset_dir, filename), "w", encoding="utf-8") as f:
        json.dump(obj, f)


def _tag_split(items, split_type):
    """Return shallow copies of ``items`` with a ``"type"`` key set to ``split_type``."""
    # Items are flat dicts of strings (built by form_creativity /
    # form_creativity_test), so a shallow copy suffices and the originals stay untouched.
    return [{**item, "type": split_type} for item in items]


# Downsample the training pool to each requested size and write one dataset directory per size.
for training_size in [50000]:
    print(f"training size: {training_size}")
    dataset_name = "sibling.{}.{}.{}.{}".format(NUM_A, NUM_B_PER_A, HASH_STR_LEN, training_size)
    dataset_dir = os.path.join(DATA_ROOT, dataset_name)
    os.makedirs(dataset_dir, exist_ok=True)
    train_sequences_ds = choose(train_sequences, training_size)

    # Sanity check: with HASH_STR_LEN > 0 every hash prefix is unique, so these two counts should match.
    unique_input_texts = {item["input_text"] for item in train_sequences_ds}
    print(len(unique_input_texts))
    print(len(train_sequences_ds))

    # Probes: a random sample of seen training items plus all held-out items, labelled by origin.
    probes = _tag_split(choose(train_sequences_ds, test_size), "train") + _tag_split(test_sequences, "test")

    _write_json(train_sequences_ds, dataset_dir, "train.json")
    # The held-out split is written as valid.json; test.json holds the mixed train/test probes.
    _write_json(test_sequences, dataset_dir, "valid.json")
    _write_json(probes, dataset_dir, "test.json")
    _write_json(vocab, dataset_dir, "vocab.json")
    _write_json(entities_b1_dict, dataset_dir, "entities_b1_dict.json")
    _write_json(entities_b2_dict, dataset_dir, "entities_b2_dict.json")

training size: 50000
50000
50000
