In [3]:
import pickle
import random

import networkx as nx
# Configuration: `num_graphs` random G(n, p) graphs with `num_nodes_per_graph` nodes each.
SEED = 0  # fixes graph generation and all downstream sampling, so re-runs reproduce the dataset
num_graphs = 3
num_nodes_per_graph = 10
edge_probability = 0.1
num_of_distractors = 7

# Seed the global `random` module; the sampling code below draws from it.
random.seed(SEED)

# NOTE: with n=10 and p=0.1 the expected edge count is 0.1 * C(10, 2) = 4.5, so these
# graphs are usually disconnected and many sampled node pairs have no path between them.
graphs = [
    nx.gnp_random_graph(num_nodes_per_graph, edge_probability, seed=SEED + i)
    for i in range(num_graphs)
]

# Give every (graph_index, node) pair a unique integer token, numbered consecutively
# graph by graph in node-iteration order. After the loop, `ix` equals the vocabulary size.
token2node = {}
node2token = {}
ix = 0
for graph_idx, g in enumerate(graphs):
    for node in g.nodes:
        token2node[ix] = (graph_idx, node)
        node2token[(graph_idx, node)] = ix
        ix += 1

# The two maps must be exact inverses of each other.
assert len(token2node) == len(node2token) == ix



In [5]:


def get_data(graphs, num_distractors=None):
    """Sample one next-hop example plus distractor nodes.

    Picks a random graph and two distinct random nodes in it. It then computes a
    shortest path between them with ``nx.shortest_path`` and uses the node right
    after the start node as the target. Distractors are random nodes drawn from
    graphs *other* than the chosen one.

    Args:
        graphs: sequence of networkx graphs. At least two are required, so that
            distractors can come from a different graph.
        num_distractors: how many distractors to draw. Defaults to the
            module-level ``num_of_distractors``, read at call time.

    Returns:
        dict with keys ``'start_node'``, ``'end_node'`` and ``'target_node'``
        (each a ``(graph_index, node)`` tuple), plus ``'distractors'`` (a list of
        such tuples).

    Raises:
        ValueError: if fewer than two graphs are given, or if the chosen graph
            has fewer than two nodes (raised by ``random.sample``).
        nx.NetworkXNoPath: if the two sampled nodes are not connected. Callers
            are expected to catch this and resample.
    """
    if num_distractors is None:
        num_distractors = num_of_distractors
    if len(graphs) < 2:
        # Without a second graph there is nowhere to draw distractors from.
        raise ValueError("get_data needs at least two graphs to draw distractors from")

    g_idx = random.randrange(len(graphs))
    g = graphs[g_idx]
    # Draw two *distinct* nodes. A pair with node1 == node2 has a one-element path
    # and therefore no next hop, so it could never produce a valid example.
    node1, node2 = random.sample(list(g.nodes), 2)

    # Get the shortest path between the two nodes and the first hop after node1.
    shortest_path = nx.shortest_path(g, node1, node2)
    first_node_after_node1 = shortest_path[1]

    # Choose distractor graphs directly from the other indices (no rejection loop).
    other_idxs = [j for j in range(len(graphs)) if j != g_idx]
    distractors = []
    for _ in range(num_distractors):
        g_idx2 = random.choice(other_idxs)
        distractor = random.choice(list(graphs[g_idx2].nodes))
        distractors.append((g_idx2, distractor))
    return {'start_node': (g_idx, node1),
            'end_node': (g_idx, node2),
            'target_node': (g_idx, first_node_after_node1),
            'distractors': distractors}

def get_example(item, shuffle=False, token_map=None):
    """Convert a raw item from `get_data` into a token-level example.

    The input sequence is ``[start] + distractors + [end]`` as token ids. The
    target is the token of the node that follows the start node on the path.

    Args:
        item: dict with keys 'start_node', 'end_node', 'target_node' and
            'distractors', in the format returned by `get_data`.
        shuffle: if True, shuffle the start token together with the distractor
            tokens (using the global `random` module) before appending the end
            token. The default, False, keeps the start token first.
        token_map: mapping from ``(graph_index, node)`` to token id. Defaults to
            the module-level `node2token`.

    Returns:
        dict ``{'input': list of token ids, 'target': token id}``.
    """
    if token_map is None:
        token_map = node2token
    start_token = token_map[item['start_node']]
    end_token = token_map[item['end_node']]
    target_token = token_map[item['target_node']]
    distractors_tokens = [token_map[distractor] for distractor in item['distractors']]
    input_tokens = [start_token] + distractors_tokens
    if shuffle:
        random.shuffle(input_tokens)
    # The end token always comes last, whether or not the rest was shuffled.
    input_tokens.append(end_token)
    return {'input': input_tokens, 'target': target_token}


def get_sample(graphs):
    """Draw one raw item with `get_data` and turn it into a token example with `get_example`.

    Exceptions from either step (for example, an unreachable node pair) propagate
    to the caller unchanged.
    """
    return get_example(get_data(graphs))

num_train_samples = 100000
num_test_samples = 5000
# Stop after this many failed draws in a row instead of looping forever
# (e.g. if no graph had any edge, no sample could ever succeed).
max_consecutive_failures = 100_000


def _generate_samples(num_samples):
    """Collect `num_samples` examples from `get_sample(graphs)`, retrying failed draws.

    A draw fails when the sampled nodes are unreachable (`nx.NetworkXNoPath`) or
    when the path has no next hop (`IndexError` from indexing the path). Such
    draws are retried. Any other exception propagates instead of being silently
    swallowed, so bugs and Ctrl-C are no longer hidden.

    Raises:
        RuntimeError: if `max_consecutive_failures` draws fail in a row.
    """
    samples = []
    consecutive_failures = 0
    while len(samples) < num_samples:
        try:
            sample = get_sample(graphs)
        except (nx.NetworkXNoPath, IndexError):
            consecutive_failures += 1
            if consecutive_failures >= max_consecutive_failures:
                raise RuntimeError(
                    f"{consecutive_failures} consecutive failed draws; "
                    "the graphs may be too sparse to sample paths from"
                )
            continue
        samples.append(sample)
        consecutive_failures = 0
    return samples


# NOTE(review): train and test are drawn independently from the same small pool of
# (graph, start, end) pairs, so test start/end pairs will overlap with training ones.
train_data = _generate_samples(num_train_samples)
test_data = _generate_samples(num_test_samples)
data = {'train': train_data, 'test': test_data, 'vocab': list(token2node.keys())}

# Save to pickle. Only load this file from trusted sources: pickle.load can execute code.
with open('data_graphs.pkl', 'wb') as f:
    pickle.dump(data, f)

