In [4]:
import networkx as nx
import matplotlib.pyplot as plt



class SubAutomaton:
    """An automaton whose state graph is a full rooted tree (an undirected nx.Graph).

    Node labels are strings: the root is "0" and the i-th child of node ``n`` is
    ``f"{n}-{i}"``, so every label spells out its path from the root.
    """

    def __init__(self, id, depth=2, branching_factor=2):
        """Create the automaton and build its tree.

        Args:
            id: identifier stored as ``self.id``.
            depth: number of levels below the root (default 2, the previously
                hard-coded value).
            branching_factor: number of children of every internal node (default 2).
        """
        self.id = id
        self.depth = depth
        self.branching_factor = branching_factor
        self.tree = self.generate_tree(depth, branching_factor)

    def generate_tree(self, depth, branching_factor):
        """Return a full tree as an ``nx.Graph``.

        The tree has ``sum(branching_factor**k for k in range(depth + 1))`` nodes,
        e.g. 7 for depth=2, branching_factor=2.
        """
        G = nx.Graph()
        root = "0"
        # Add the root explicitly so a depth-0 tree still has one node.
        G.add_node(root)

        def add_children(node, current_depth):
            if current_depth >= depth:
                return
            for i in range(branching_factor):
                child = f"{node}-{i}"
                G.add_edge(node, child)
                add_children(child, current_depth + 1)

        add_children(root, 0)
        return G

    def vis(self):
        """Draw the tree with node labels on a fresh figure and show it."""
        fig, ax = plt.subplots()
        nx.draw(self.tree, ax=ax, with_labels=True)
        ax.set_title(f"SubAutomaton {self.id}")
        plt.show()
    
        

In [5]:
from collections import defaultdict
# Build `num_modules` modules of `num_automata_per_module` automata each, with
# globally unique integer ids and lookup tables in both directions.
num_modules = 1
num_automata_per_module = 8
num_tokens_in_module = 16  # NOTE(review): not referenced in any visible cell

automata_dict = {}             # automaton id -> SubAutomaton
id2module = {}                 # automaton id -> module index
module2id = defaultdict(list)  # module index -> list of automaton ids
for module in range(num_modules):
    for automata in range(num_automata_per_module):
        # Named `automaton_id` (not `id`) so the builtin id() is not shadowed
        # for the rest of the notebook.
        automaton_id = module * num_automata_per_module + automata
        automata_dict[automaton_id] = SubAutomaton(id=automaton_id)
        id2module[automaton_id] = module
        module2id[module].append(automaton_id)


In [6]:
# Sanity check: how many nodes automaton 0's tree has (len of an nx.Graph counts nodes).
len(automata_dict[0].tree)

7

In [None]:
# Save automata_dict to saved_models. At this point the trees carry no
# 'enabling' edge attributes yet; those are attached in a later cell.
import pickle
from pathlib import Path

automata_path = Path('Cube/saved_models/automata_dict.pkl')
# Create the directory on a fresh checkout; otherwise open() raises FileNotFoundError.
automata_path.parent.mkdir(parents=True, exist_ok=True)
with open(automata_path, 'wb') as f:
    pickle.dump(automata_dict, f)


In [None]:
# Load automata_dict back from saved_models, replacing the in-memory copy.
# NOTE(review): pickle.load can execute arbitrary code; only load files this
# notebook wrote itself.
import pickle
with open('Cube/saved_models/automata_dict.pkl', 'rb') as f:
    automata_dict = pickle.load(f)

In [None]:
# Draw the tree of automaton 0.
first_automaton = automata_dict[0]
first_automaton.vis()

In [None]:
# Give every (automaton id, tree node) pair a consecutive integer token id,
# walking automata in dict order and nodes in graph order; keep the inverse map too.
pairs = [
    (automata, node)
    for automata, sub in automata_dict.items()
    for node in sub.tree.nodes
]
node2token = {pair: token for token, pair in enumerate(pairs)}
token2node = dict(enumerate(pairs))
print(len(node2token))
len(automata_dict)

In [None]:
import random
# Attach an 'enabling' list to every tree edge: each edge of automaton A gets
# `num_enabling` {'automata_id', 'node_id'} entries drawn from automata other than A.
num_enabling = 1
for automata in automata_dict:
    current_id = automata
    # Sample directly from the other ids instead of resampling from a fixed
    # 0..num_automata_per_module-1 range: ids from every module stay eligible,
    # and the loop cannot spin forever.
    other_ids = [aid for aid in automata_dict if aid != current_id]
    if not other_ids:
        raise ValueError("need at least two automata to sample enabling nodes")
    for edge in automata_dict[automata].tree.edges:
        enabling = []
        for i in range(num_enabling):
            automata_id = random.choice(other_ids)
            # sample random node from the chosen automaton
            node = random.choice(list(automata_dict[automata_id].tree.nodes))
            enabling.append({'automata_id': automata_id, 'node_id': node})
        automata_dict[automata].tree.edges[edge]['enabling'] = enabling
            

In [None]:
import random
def get_tripples(automata=None, tokens=None):
    """Collect (start, subgoal, end, enabling) token triples from every automaton.

    For each ordered pair of distinct nodes (start, end) in an automaton's tree
    whose shortest path has at least three nodes, record:
      - 'start' / 'end': tokens of the two endpoints,
      - 'subgoal': token of the node right after 'start' on the path,
      - 'enabling': token of the first entry in the 'enabling' list of the path's
        first edge,
      - 'automata': the automaton id.
    Adjacent pairs (two-node paths) are skipped, so 'subgoal' never equals 'end'.

    Args:
        automata: automaton id -> object whose ``.tree`` is a graph with an
            'enabling' list on each edge; defaults to the global ``automata_dict``.
        tokens: (automaton id, node) -> token id; defaults to the global ``node2token``.

    Returns:
        list of triple dicts, ordered by automaton, then start node, then end node.
    """
    if automata is None:
        automata = automata_dict
    if tokens is None:
        tokens = node2token

    triples = []
    for automaton_id, sub in automata.items():
        tree = sub.tree
        for node in tree.nodes:
            # One BFS per source yields every shortest path from `node`, instead of
            # a separate shortest_path search per (node, node2) pair. In a tree the
            # path between two nodes is unique, so the result is the same.
            paths = nx.single_source_shortest_path(tree, node)
            for node2 in tree.nodes:
                if node2 == node:
                    continue
                shortest_path = paths[node2]
                if len(shortest_path) <= 2:
                    continue
                enabling = tree.edges[(shortest_path[0], shortest_path[1])]['enabling'][0]
                triples.append({'start': tokens[(automaton_id, shortest_path[0])],
                                'end': tokens[(automaton_id, shortest_path[-1])],
                                'subgoal': tokens[(automaton_id, shortest_path[1])],
                                'enabling': tokens[(enabling['automata_id'], enabling['node_id'])],
                                'automata': automaton_id})
    return triples
triples = get_tripples()
# Group triples by their end token: a triple whose enabling token is X can be
# preceded by any triple that ends at X.
endpoint2triple = defaultdict(list)
for triple in triples:
    end_token = triple['end']
    endpoint2triple[end_token].append(triple)

def get_composed_examples(triples, endpoint2triple, num_prev=3):
    """Chain each triple with `num_prev` randomly chosen predecessor triples.

    Starting from a triple, the next link is a random triple whose 'end' equals
    the current link's 'enabling' token. Level k (1-based) is stored under keys
    prefixed with 'prev_' repeated k times, e.g. 'prev_prev_start' for k=2.

    Args:
        triples: list of triple dicts ('start', 'end', 'subgoal', 'enabling', 'automata').
        endpoint2triple: end token -> list of triples ending there (dict or defaultdict).
        num_prev: number of predecessor levels to chain (default 3).

    Returns:
        list of composed example dicts. Triples whose chain cannot be completed
        (some enabling token is not the end of any triple) are skipped.
    """
    examples = []
    for triple in triples:
        chain = [triple]
        for _ in range(num_prev):
            # .get avoids inserting empty lists into a defaultdict.
            candidates = endpoint2triple.get(chain[-1]['enabling'])
            if not candidates:
                break
            chain.append(random.choice(candidates))
        if len(chain) <= num_prev:
            continue  # incomplete chain: no triple ends at a required enabling token
        example = {'end': triple['end'],
                   'start': triple['start'],
                   'automata': triple['automata'],
                   'enabling': triple['enabling'],
                   'subgoal': triple['subgoal']}
        for level, prev in enumerate(chain[1:], start=1):
            prefix = 'prev_' * level
            example[prefix + 'automata'] = prev['automata']
            example[prefix + 'start'] = prev['start']
            example[prefix + 'subgoal'] = prev['subgoal']
            example[prefix + 'enabling'] = prev['enabling']
        examples.append(example)
    return examples

examples = get_composed_examples(triples, endpoint2triple)
random.shuffle(examples)
# 80/20 train/test split of the shuffled composed examples.
num_train_triples = int(0.8 * len(examples))
train_triples, test_triples = examples[:num_train_triples], examples[num_train_triples:]

def get_sample_from_example(example, num_distractors=10, automata=None, tokens=None):
    """Turn one composed example into a model input sequence plus label tokens.

    The input is `num_distractors` distractor tokens, shuffled together with the
    start tokens of the example and its first two predecessor levels, followed
    by the example's end token. Each distractor is a random node of a random
    automaton that is not used at those three levels.

    Args:
        example: dict produced by get_composed_examples.
        num_distractors: number of distractor tokens (default 10).
        automata: automaton id -> object with ``.tree.nodes``; defaults to the
            global ``automata_dict``.
        tokens: (automaton id, node) -> token id; defaults to the global ``node2token``.

    Returns:
        dict with 'input' and the label tokens, or None when distractors are
        requested but no unused automaton is left to draw them from.
    """
    if automata is None:
        automata = automata_dict
    if tokens is None:
        tokens = node2token

    # Read start/end from `example` itself (previously these came from a stale
    # module-level `triple`, giving every sample the same start/end tokens).
    start_token = example['start']
    end_token = example['end']
    prev_start_token = example['prev_start']
    prev_prev_start_token = example['prev_prev_start']

    used_automata_ids = [example['automata'], example['prev_automata'], example['prev_prev_automata']]
    # NOTE(review): a non-empty list always has >= 1 distinct id, so this guard
    # never fires. Confirm the intended threshold (e.g. require distinct automata?).
    if len(set(used_automata_ids)) < 1:
        return None

    # Same candidate set on every draw, so compute it once.
    distractor_pool = list(set(automata.keys()) - set(used_automata_ids))
    if num_distractors > 0 and not distractor_pool:
        return None
    distractors = []
    while len(distractors) < num_distractors:
        distractor_automata_id = random.choice(distractor_pool)
        distractor_node = random.choice(list(automata[distractor_automata_id].tree.nodes))
        distractors.append(tokens[(distractor_automata_id, distractor_node)])

    sequence = distractors + [start_token, prev_start_token, prev_prev_start_token]
    random.shuffle(sequence)
    sequence += [end_token]
    return {'input': sequence,
            'start': start_token,
            'end': end_token,
            'subgoal': example['subgoal'],
            'enabling': example['enabling'],
            'prev_start': prev_start_token,
            'prev_subgoal': example['prev_subgoal'],
            'prev_enabling': example['prev_enabling'],
            'prev_prev_start': prev_prev_start_token,
            'prev_prev_subgoal': example['prev_prev_subgoal'],
            'prev_prev_enabling': example['prev_prev_enabling']}
    


In [None]:
from tqdm import tqdm
import pickle

num_test_examples = 5000     # NOTE(review): not referenced in any visible cell
num_train_examples = 200000  # NOTE(review): not referenced in any visible cell
num_duplicates = 1           # number of passes over train_triples

# Each pass draws a fresh sample per composed example; None results are dropped.
train_examples = [
    sample
    for _ in range(num_duplicates)
    for sample in map(get_sample_from_example, tqdm(train_triples))
    if sample is not None
]
test_examples = [
    sample
    for sample in map(get_sample_from_example, tqdm(test_triples))
    if sample is not None
]

data = {'train': train_examples, 'test': test_examples, 'vocab': list(token2node.keys())}

# Persist the train/test splits and the token vocabulary.
with open('data.pkl', 'wb') as f:
    pickle.dump(data, f)



In [None]:
# Report how many train and test samples were produced (train first).
for split_examples in (train_examples, test_examples):
    print(len(split_examples))