In [1]:
import os
import pickle
import numpy as np
import pandas as pd
from collections import defaultdict
import random
import time
from copy import deepcopy

In [2]:
train_num_dict = {'FB15k': 273710, "FB15k-237": 149689, "NELL": 107982, "YAGO3": 10000}
valid_num_dict = {'FB15k': 8000, "FB15k-237": 5000, "NELL": 4000, "YAGO3": 1000}
test_num_dict = {'FB15k': 8000, "FB15k-237": 5000, "NELL": 4000, "YAGO3": 1000}

In [3]:
dataset = "YAGO3"
seed = 0
gen_train_num = 0
gen_valid_num = 0
gen_test_num = 0
max_ans_num = 1e6
reindex = True
gen_train = True
gen_valid = True
gen_test = True
gen_id = 0
save_name = True
index_only = False

In [4]:
if gen_train and gen_train_num == 0:
    if 'FB15k-237' in dataset:
        gen_train_num = 149689
    elif 'FB15k' in dataset:
        gen_train_num = 273710
    elif 'NELL' in dataset:
        gen_train_num = 107982
    else:
        gen_train_num = train_num_dict[dataset]

print(gen_train_num)

10000


In [5]:
if gen_valid and gen_valid_num == 0:
    if 'FB15k-237' in dataset:
        gen_valid_num = 5000
    elif 'FB15k' in dataset:
        gen_valid_num = 8000
    elif 'NELL' in dataset:
        gen_valid_num = 4000
    else:
        gen_valid_num = valid_num_dict[dataset]

print(gen_valid_num)

1000


In [6]:
 if gen_test and gen_test_num == 0:
    if 'FB15k-237' in dataset:
        gen_test_num = 5000
    elif 'FB15k' in dataset:
        gen_test_num = 8000
    elif 'NELL' in dataset:
        gen_test_num = 4000
    else:
        gen_test_num = test_num_dict[dataset]
        
print(gen_test_num)

1000


### index_dataset

In [7]:
def index_dataset(dataset_name, force=False):
    
    print('Indexing dataset {0}'.format(dataset_name))
    base_path = 'data/{0}/'.format(dataset_name)
    
    files = ['train.txt', 'valid.txt', 'test.txt']
    indexified_files = [x.split('.')[0]+'_indexified.txt' for x in files]
    
    for indexied in indexified_files:
        if not os.path.exists(os.path.join(base_path, indexied)):
            return_flag = False
            break
            
#     if return_flag and not force:
#         print("index file exists")
#         return

    entities, relations = set(), set()
    for file in files:
        temp_data = pd.read_csv(os.path.join(base_path, file), sep='\t', header=None, names=['h','r','t'])
        temp_entity = set(temp_data['h']).union(set(temp_data['t']))
        temp_relation = set(temp_data['r'])
        entities = entities.union(temp_entity)
        relations = relations.union(temp_relation)
        
        # for adding reserve relation
        temp_reverse_relation = {'-' + r for r in temp_relation}
        relations = relations.union(temp_reverse_relation)
        
    
    id2ent = {i+1:e for i,e in enumerate(entities)}  # 0 for'unknown'
    id2rel = {i+1:r for i,r in enumerate(relations)}
    id2ent[0] = 'unknown'
    id2rel[0] = 'unknown'
    
    ent2id = {e:i for i,e in id2ent.items()}
    rel2id = {r:i for i,r in id2rel.items()}
    
    for file in files:
        temp_data = pd.read_csv(os.path.join(base_path, file), sep='\t', header=None, names=['h','r','t'])
        temp_data['h'] = [ent2id[e] for e in temp_data['h']]
        temp_data['t'] = [ent2id[e] for e in temp_data['t']]
        temp_data['r'] = [rel2id[r] for r in temp_data['r']]
        temp_data.to_csv(os.path.join(base_path, file.split('.')[0]+'_indexified.txt'), 
                                      sep='\t', header=False, index=False)
        
    with open(os.path.join(base_path, "stats.txt"), "w") as fw:
        fw.write("numentity: "+ str(len(ent2id)) + '\n')
        fw.write("numrelations: " + str(len(rel2id)))

    with open(os.path.join(base_path, 'ent2id.pkl'), 'wb') as handle:
        pickle.dump(ent2id, handle, protocol=pickle.HIGHEST_PROTOCOL)

    with open(os.path.join(base_path, 'rel2id.pkl'), 'wb') as handle:
        pickle.dump(rel2id, handle, protocol=pickle.HIGHEST_PROTOCOL)

    with open(os.path.join(base_path, 'id2ent.pkl'), 'wb') as handle:
        pickle.dump(id2ent, handle, protocol=pickle.HIGHEST_PROTOCOL)

    with open(os.path.join(base_path, 'id2rel.pkl'), 'wb') as handle:
        pickle.dump(id2rel, handle, protocol=pickle.HIGHEST_PROTOCOL)

    print ('num entity: %d, num relation: %d'%(len(ent2id), len(rel2id)))
    print ("Indexing finished!!")

In [8]:
index_dataset('YAGO3')

Indexing dataset YAGO3


UnboundLocalError: local variable 'return_flag' referenced before assignment

In [8]:
e = 'e'
r = 'r'
n = 'n'
u = 'u'

query_structures = [
                    [e, [r]],
                    [e, [r, r]],
                    [e, [r, r, r]],
                    [[e, [r]], [e, [r]]],
                    [[e, [r]], [e, [r]], [e, [r]]],
                    [[e, [r, r]], [e, [r]]],
                    [[[e, [r]], [e, [r]]], [r]],
                    
                    # negation
                    [[e, [r]], [e, [r, n]]],
                    [[e, [r]], [e, [r]], [e, [r, n]]],
                    [[e, [r, r]], [e, [r, n]]],
                    [[e, [r, r, n]], [e, [r]]],
                    [[[e, [r]], [e, [r, n]]], [r]],
                        
                    # union
                    [[e, [r]], [e, [r]], [u]],
                    [[[e, [r]], [e, [r]], [u]], [r]]
                ]

query_names = ['1p', '2p', '3p', '2i', '3i', 'pi', 'ip', '2in', '3in', 'pin', 'pni', 'inp', '2u', 'up']

    generate_queries(dataset, query_structures[gen_id:gen_id+1],[gen_train_num, gen_valid_num, gen_test_num], max_ans_num, gen_train, gen_valid, gen_test, query_names[gen_id:gen_id+1], save_name)

### construct_graph

In [9]:
dataset = 'YAGO3'
base_path = f'./data/{dataset}'
indexified_files = ['train_indexified.txt', 'valid_indexified.txt', 'test_indexified.txt']

In [10]:
def construct_graph(base_path, indexified_files):
    ent_in, ent_out = defaultdict(lambda: defaultdict(set)), defaultdict(lambda: defaultdict(set))
    for indexified_p in indexified_files:
        with open(os.path.join(base_path, indexified_p)) as f:
            for i, line in enumerate(f):
                if len(line) == 0:
                    continue
                e1, rel, e2 = line.split('\t')
                e1 = int(e1.strip())
                e2 = int(e2.strip())
                rel = int(rel.strip())
                ent_out[e1][rel].add(e2)
                ent_in[e2][rel].add(e1)

    return ent_in, ent_out

In [11]:
train_ent_backward, train_ent_forward = construct_graph(base_path, indexified_files[:1])

valid_ent_backward, valid_ent_forward = construct_graph(base_path, indexified_files[:2])
valid_only_ent_backward, valid_only_ent_forward = construct_graph(base_path, indexified_files[1:2])

test_ent_backward, test_ent_forward = construct_graph(base_path, indexified_files[:3])
test_only_ent_backward, test_only_ent_forward = construct_graph(base_path, indexified_files[2:3])

In [320]:
ent2id = pickle.load(open(os.path.join(base_path, "ent2id.pkl"), 'rb'))
rel2id = pickle.load(open(os.path.join(base_path, "rel2id.pkl"), 'rb'))
id2ent = pickle.load(open(os.path.join(base_path, "id2ent.pkl"), 'rb'))
id2rel = pickle.load(open(os.path.join(base_path, "id2rel.pkl"), 'rb'))

In [87]:
def list2tuple(l):
    return tuple(list2tuple(x) if type(x) == list else x for x in l)


def tuple2list(t):
    return list(tuple2list(x) if type(x) == tuple else x for x in t)

In [335]:
def fill_query(query_structure, ent_backward, ent_forward, answer, ent2id, rel2id):
    assert type(query_structure[-1]) == list
    
    all_relation_flag = True
    for ele in query_structure[-1]:
        if ele not in ['r', 'n']:
            all_relation_flag = False
            break

    if all_relation_flag:
        r = -1
        for i in range(len(query_structure[-1]))[::-1]:
            if query_structure[-1][i] == 'n':
                query_structure[-1][i] = -2
                continue
            found = False
            for j in range(40):
                r_tmp = random.sample(ent_backward[answer].keys(), 1)[0]
                if r_tmp // 2 != r // 2 or r_tmp == r:
                    r = r_tmp
                    found = True
                    break
            if not found:
                print("broken")
                return True
            query_structure[-1][i] = r
            answer = random.sample(ent_backward[answer][r], 1)[0]
            # print(query_structure)
        if query_structure[0] == 'e':
            query_structure[0] = answer
        else:
            return fill_query(query_structure[0], ent_backward, ent_forward, answer, ent2id, rel2id)
    else:
        same_structure = defaultdict(list)
        for i in range(len(query_structure)):
            same_structure[list2tuple(query_structure[i])].append(i)
            
        for i in range(len(query_structure)):
            if len(query_structure[i]) == 1 and query_structure[i][0] == 'u':
                assert i == len(query_structure) - 1
                query_structure[i][0] = -1
                continue
            broken_flag = fill_query(query_structure[i], ent_backward, ent_forward, answer, ent2id, rel2id)
            if broken_flag:
                print("broken")
                return True
#         for structure in same_structure:
#             print("why?", same_structure, structure)
#             if len(same_structure[structure]) != 1:     # 중복된 형태의 query structure i.e. (e,r) (e,r)
#                 structure_set = set()
#                 for i in same_structure[structure]:
#                     structure_set.add(list2tuple(query_structure[i]))
#                     # print(structure_set)
#                     if len(structure_set) < len(same_structure[structure]):
#                         print(structure_set, len(structure_set))
#                         print(same_structure[structure], len(same_structure[structure]))
#                         # print("broken")
#                         #return True

In [414]:
def achieve_answer(query, ent_in, ent_out, print_flag = True):
    assert type(query[-1]) == list
    all_relation_flag = True
    for ele in query[-1]:
        if (type(ele) != int) or (ele == -1):
            all_relation_flag = False
            break
    if all_relation_flag:
        if type(query[0]) == int:
            ent_set = set([query[0]])
        else:
            ent_set = achieve_answer(query[0], ent_in, ent_out)
            
        for i in range(len(query[-1])):
            if query[-1][i] == -2:
                ent_set = set(range(len(ent_in))) - ent_set
            else:
                ent_set_traverse = set()
                for ent in ent_set:
                    ent_set_traverse = ent_set_traverse.union(ent_out[ent][query[-1][i]])
                    if print_flag:
                        print(f"{id2ent[ent]} --> {id2rel[query[-1][i]]} --> ")
                        print([id2ent[e] for e in ent_set_traverse])
                        print("----------------------------------------------------------------------------------------")
                ent_set = ent_set_traverse
    else:
        ent_set = achieve_answer(query[0], ent_in, ent_out)
        union_flag = False
        if len(query[-1]) == 1 and query[-1][0] == -1:
            union_flag = True
        for i in range(1, len(query)):
            if not union_flag:
                # print(f"previous answers : {len(ent_set)}")
                ent_set = ent_set.intersection(achieve_answer(query[i], ent_in, ent_out))
                # if print_flag:  
                    # print(f"current answers: {len(ent_set)}")
                    # print([id2ent[e] for e in ent_set])
                    # print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
            else:
                if i == len(query) - 1:
                    continue
                ent_set = ent_set.union(achieve_answer(query[i], ent_in, ent_out))
    return ent_set

In [444]:
# random.seed(4331)
idx = 7
query_structure = deepcopy(query_structures[idx])
query_name = query_names[idx]
print('general structure is', query_structure, "with name", query_name)

fill_query(query_structure, valid_ent_backward, valid_ent_forward, answer, ent2id, rel2id)
query_structure

general structure is [['e', ['r']], ['e', ['r', 'n']]] with name 2in


[[16543, [0]], [41433, [58, -2]]]

In [449]:
id2ent[41433]
# id2rel[58]

'John_Aspinall_(zoo_owner)'

In [445]:
answer_set = achieve_answer(query_structure, valid_ent_backward, valid_ent_forward, True)
[id2ent[e] for e in answer_set]

Canterbury --> +isLocatedIn --> 
['CT_postcode_area', 'Kent', 'South_East_England', 'United_Kingdom']
----------------------------------------------------------------------------------------
John_Aspinall_(zoo_owner) --> +livesIn --> 
['Kent', 'Canterbury']
----------------------------------------------------------------------------------------


['CT_postcode_area', 'South_East_England', 'United_Kingdom']

In [374]:


all_relation_flag = True
print_flag = False
query = query_structure[0]

for ele in query[-1]:
    if (type(ele) != int) or (ele == -1):
        all_relation_flag = False
        print('not all relation')

if all_relation_flag:
    if type(query[0]) == int:
        ent_set = set([query[0]])
        
    for i in range(len(query[-1])):
        if query[-1][i] == -2:
            ent_set = set(range(len(ent_in))) - ent_set
        else:
            ent_set_traverse = set()
            for ent in ent_set:
                ent_set_traverse = ent_set_traverse.union(valid_ent_forward[ent][query[-1][i]])
                if print_flag:
                    print(f"{id2ent[ent]} --> {id2rel[query[-1][i]]} --> ")
                    print([id2ent[e] for e in ent_set_traverse])
                    print("----------------------------------------------------------------------------------------")
            ent_set = ent_set_traverse

In [378]:
ent_set

{102, 1058, 2224, 23672, 26145, 29134}

In [372]:
union_flag = False
query = query_structure

if len(query[-1]) == 1 and query[-1][0] == -1:
    union_flag = True

    for i in range(1, len(query)):
        if not union_flag:
            ent_set = ent_set.intersection()

{102, 1058, 2224, 23672, 26145, 29134}

In [324]:
print(id2ent[59696])
print(id2rel[0])
print(id2rel[24])
print(id2rel[16])

Battle_of_Lawfeld
+isLocatedIn
+participatedIn
+happenedIn


In [330]:
print(id2ent[2224])
print(id2rel[24])
print(id2ent[11777])

Dutch_Republic
+participatedIn
Second_Anglo-Mysore_War


In [325]:
[id2ent[e] for e in ent_set]

['France',
 'Kingdom_of_Mysore',
 'Prussia',
 'Portugal',
 'Duchy_of_Savoy',
 'Mysore',
 'Arnhem',
 'Thirteen_Colonies',
 'Denmark',
 'Crimean_Khanate',
 'Fribourg',
 'Hanover',
 'Kingdom_of_Prussia',
 'Kent',
 'England',
 'Philippsburg',
 'Guelders',
 'Brazil',
 'Minorca',
 'Bohuslän',
 'Europe',
 'Austria',
 'Blekinge',
 'Zutphen',
 'Bergen',
 'Río_de_la_Plata',
 'Nova_Scotia',
 'Lisbon',
 'Indian_subcontinent',
 'West_Africa',
 'Polish-Lithuanian_Commonwealth',
 'Kingdom_of_England',
 'Hainaut',
 'Sicily',
 'Low_Countries',
 'French_First_Republic',
 'Ottoman_Empire',
 'Russian_Empire',
 'Dalmatia',
 'Holland',
 'English_Channel',
 'Leuze-en-Hainaut',
 'Denmark-Norway',
 'Caribbean_Sea',
 'Jämtland',
 'West_Indies',
 'Danube',
 'Ukraine',
 'Tournai',
 'Friesland',
 'Höchstädt_an_der_Donau',
 'Middle_East',
 'Spain',
 'India',
 'Belgium',
 'Spanish_Empire',
 'Scandinavia',
 'Dutch_Republic',
 'Tsardom_of_Russia',
 'Electorate_of_Cologne',
 'Kingdom_of_Saxony',
 'Great_Britain',
 'Uni

In [31]:
same_structure = defaultdict(list)

for i in range(len(query_structure)):
    same_structure[list2tuple(query_structure[i])].append(i)

for i in range(len(query_structure)):
    if len(query_structure[i]) == 1 and query_structure[i][0] == 'u':
        assert i == len(query_structure) - 1
        query_structure[i][0] = -1
        continue
    broken_flag = fill_query(query_structure[i], ent_backward, ent_forward, answer, ent2id, rel2id)
    if broken_flag:
        print('broken')
        break

# for structure in same_structure:
#     if len(same_structure) != 1:
#         structure_set = set()
#         for i in same_structure[structure]:
# same_structure

NameError: name 'fill_query' is not defined

In [55]:
list2tuple(query_structure[0])

('e', ('r',))

defaultdict(list, {('e', ('r',)): [0, 1]})

In [48]:
idx = 3
query_structure = deepcopy(query_structures[idx])
query_name = query_names[idx]
print('general structure is', query_structure, "with name", query_name)

answer = random.sample(valid_ent_backward.keys(), 1)[0]
fill_query(query_structure, valid_ent_backward, valid_ent_forward, answer, ent2id, rel2id)

general structure is [['e', ['r']], ['e', ['r']]] with name 2i


In [49]:
query_structure[-1]

['e', ['r']]

In [38]:
query_structure

[['e', ['r']], [0, 7]]

In [33]:
def write_links(dataset, ent_out, small_ent_out, max_ans_num, name):
    
    queries = defaultdict(set)
    tp_answers = defaultdict(set)
    fn_answers = defaultdict(set)

    for ent in ent_out:
        for rel in ent_out[ent]:
            queries[('e', ('r',))].add((ent, (rel,)))
            tp_answers[(ent, (rel,))] = ent_out[ent][rel]        # true_positive: train_ent_out 에서 answer
            fn_answers[(ent, (rel,))] = small_ent_out[ent][rel]  # false_negative: valid_only_ent_out 에서 answer

    with open(f'./data/{dataset}/{name}-queries.pkl', 'wb') as f:
        pickle.dump(queries, f)
    with open(f'./data/{dataset}/{name}-tp-answers.pkl', 'wb') as f:
        pickle.dump(tp_answers, f)
    with open(f'./data/{dataset}/{name}-fn-answers.pkl', 'wb') as f:
        pickle.dump(fn_answers, f)

In [34]:
if gen_train:
    write_links(dataset, train_ent_out, defaultdict(lambda: defaultdict(set)), max_ans_num, 'train-' + query_name)
if gen_valid:
    write_links(dataset, train_ent_out, valid_only_ent_out, max_ans_num, 'valid-' + query_name)
if gen_test:
    write_links(dataset, valid_ent_out, test_only_ent_out, max_ans_num, 'test-' + query_name)

print("link prediction created!")

name_to_save = query_name
print("./data/{}/".format(dataset), name_to_save)

link prediction created!
./data/YAGO3/ 3p


    def ground_queries(dataset, query_structure, ent_in, ent_out, 
                       small_ent_in, small_ent_out, 
                       gen_num, max_ans_num, query_name, 
                       mode, ent2id, rel2id):

    ground_queries(dataset, query_structure, train_ent_in, train_ent_out,
                   defaultdict(lambda: defaultdict(set)),
                   defaultdict(lambda: defaultdict(set)),
                   gen_num[0], max_ans_num, query_name, 
                   'train', ent2id, rel2id)

## ground_queries

In [35]:
num_sampled, num_try, num_repeat, num_more_answer = 0, 0, 0, 0
num_broken, num_no_extra_answer, num_no_extra_negative, num_empty = 0, 0, 0, 0

In [36]:
tp_ans_num, fp_ans_num, fn_ans_num = [], [], []

queries = defaultdict(set)
tp_answers = defaultdict(set)
fp_answers = defaultdict(set)
fn_answers = defaultdict(set)

old_num_sampled = -1

In [37]:
gen_num = 1000
max_ans_num = 50
mode = 'train'

In [38]:
empty_query_structure = deepcopy(query_structure)
answer = random.sample(train_ent_in.keys(), 1)[0]
print(answer)
answer = 18

print(empty_query_structure)
print(answer)

71307
['e', ['r', 'r', 'r']]
18


## Fill Query

In [44]:
ent_in = train_ent_in
ent_out = train_ent_out
query_structure = empty_query_structure

In [45]:
type(query_structure[-1]) == list

True

In [46]:
all_relation_flag = True

In [47]:
for ele in query_structure[-1]:
    if ele not in ['r', 'n']:
        all_relation_flag = False
        print('break')

In [48]:
list(range(len(query_structure[-1]))[::-1])

[2, 1, 0]

In [55]:
print(ent_in[answer])
random.sample(ent_in[answer].keys(), 1)[0]

defaultdict(<class 'set'>, {40: {64575}, 51: {60981}})


40

In [59]:
r = -1
found = False
for j in range(40):
    r_tmp = random.sample(ent_in[answer].keys(), 1)[0]
    if r_tmp // 2 != r // 2 or r_tmp == r:
        r = r_tmp
        found = True
        break
print(found)
not found

True


False

In [60]:
query_structure[-1][2] = r
query_structure

['e', ['r', 'r', 51]]

In [62]:
query_structure[0]

'e'

In [61]:
answer = random.sample(ent_in[answer][r], 1)[0]
answer

60981

In [51]:
if all_relation_flag:
    r = -1
    for i in range(len(query_structure[-1]))[::-1]:
        if query_structure[-1][i] == 'n':
            query_structure[-1][i] = -2
            continue
        found = False
        for j in range(40):
            r_tmp = random.sample(ent_in[answer].keys(), 1)[0]
            if r_tmp // 2 != r // 2 or r_tmp == r:
                r = r_tmp
                found = True
                break
        if not found:
            return True
        
        query_structure[-1][i] = r
        answer = random.sample(ent_in[answer][r], 1)[0]

In [25]:
all_relation_flag = True
r = -1

In [75]:
for i in range(len(empty_query_structure[-1]))[::-1]:
    print(i)

2
1
0


In [79]:
print(train_ent_in[answer].keys())
r_tmp = random.sample(train_ent_in[answer].keys(), 1)[0]
r_tmp

dict_keys([40, 51])


40

In [80]:
if r_tmp // 2 != r // 2 or r_tmp == r:
    r = r_tmp
    found = True

print(found)
print(r)

True
40


In [82]:
empty_query_structure[-1][2] = r
empty_query_structure

['e', ['r', 'r', 40]]

In [88]:
train_ent_in[answer][r]

{64575}

In [93]:
empty_query_structure

['e', ['r', 'r', 51]]

In [96]:
r = -1
for i in range(len(empty_query_structure[-1]))[::-1]:
    if empty_query_structure[-1][i] == 'n':
        empty_query_structure[-1][i] = -2
        continue
    found = False
    for j in range(40):  
        print(train_ent_in[answer].keys())# 뭐야 이거?
        r_tmp = random.sample(train_ent_in[answer].keys(), 1)[0]  # answer 에 연결되어 있는 여러개의 relation 중 하나를 고른다.
        if r_tmp // 2 != r // 2 or r_tmp == r:
            r = r_tmp
            found = True
            break
    if not found:
        print(True)
    empty_query_structure[-1][i] = r
    answer = random.sample(train_ent_in[answer][r], 1)[0]

dict_keys([40, 51])
dict_keys([])


ValueError: Sample larger than population or is negative

In [22]:
while num_sampled < gen_num:
    if num_sampled != 0:
        if num_sampled % (gen_num // 100) == 0 and num_sampled != old_num_sampled:
            print(f'{mode} {query_structure}: [{num_sampled}/{gen_num}], avg time: {(time.time() - s0) / num_sampled},')
            print(f'try: {num_try}, repeat: {num_repeat}: more_answer: {num_more_answer}, broken: {num_broken},')
            print(f'no extra: {num_no_extra_answer}, no negative: {num_no_extra_negative} empty: {num_empty}')
            old_num_sampled = num_sampled
        print(f'%s %s: [%d/%d], avg time: %s, try: %s, repeat: %s: more_answer: %s, broken: %s, no extra: %s, no negative: %s empty: %s' % (
                mode,
                query_structure,
                num_sampled, gen_num, (time.time() - s0) / (num_sampled + 0.001), num_try, num_repeat, num_more_answer,
                num_broken, num_no_extra_answer, num_no_extra_negative, num_empty), end='\r')
        
        num_try += 1
        empty_query_structure = deepcopy(query_structure)
        answer = random.sample(ent_in.keys(), 1)[0]
        broken_flag = fill_query(empty_query_structure, ent_in, ent_out, answer, ent2id, rel2id)

SyntaxError: unexpected character after line continuation character (<ipython-input-22-2f0cb521c572>, line 11)

In [165]:
random.sample(train_ent_in.keys(), 1)[0]

109025

In [55]:
train_queries = defaultdict(set)
train_tp_answers = defaultdict(set)
train_fp_answers = defaultdict(set)
train_fn_answers = defaultdict(set)

In [56]:
t1, t2, t3, t4, t5, t6 = 0, 0, 0, 0, 0, 0

In [64]:
query_structure = query_structures[0]
query_name = query_names[0] if save_name else str(idx)

In [65]:
print ('general structure is', qs, "with name", query_name)

general structure is ['e', ['r']] with name 1p


### write_links

    write_links(dataset, train_ent_out, defaultdict(lambda: defaultdict(set)), max_ans_num, 'train-'+query_name)
    
    write_links(dataset, valid_only_ent_out, train_ent_out, max_ans_num, 'valid-'+query_name)

In [72]:
ent_out = train_ent_out
small_ent_out = defaultdict(lambda: defaultdict(set))
name = 'train-'+query_name

In [73]:
queries = defaultdict(set)
tp_answers = defaultdict(set)
fn_answers = defaultdict(set)
fp_answers = defaultdict(set)
num_more_answer = 0

for ent in ent_out:
    for rel in ent_out[ent]:
        if len(ent_out[ent][rel]) <= max_ans_num:
            queries[('e', ('r',))].add((ent, (rel,)))
            tp_answers[(ent, (rel,))] = small_ent_out[ent][rel]
            fn_answers[(ent, (rel,))] = ent_out[ent][rel]

In [78]:
fn_answers

defaultdict(set,
            {(0, (0,)): {1, 6163},
             (0, (3,)): {1, 6163},
             (0, (117,)): {526, 615, 1754, 2733, 2769, 3780, 4404, 5874},
             (0, (118,)): {526, 615, 1754, 2733, 2769, 3780, 4404, 5874},
             (0, (276,)): {626, 708, 1718, 2722},
             (0, (279,)): {626, 708, 1718, 2722},
             (0, (556,)): {90},
             (0, (559,)): {90},
             (0, (61,)): {6993},
             (0, (62,)): {6993},
             (0, (292,)): {302, 1059, 1914},
             (0, (295,)): {302, 1059, 1914},
             (0, (533,)): {13249},
             (0, (534,)): {13249},
             (0, (248,)): {420},
             (0, (251,)): {420},
             (0, (53,)): {101, 971, 1201, 2115, 2977, 3124, 4744, 4817, 8913},
             (0, (54,)): {101, 971, 1201, 2115, 2977, 3124, 4744, 4817, 8913},
             (0, (288,)): {2722},
             (0, (291,)): {2722},
             (0, (897,)): {4688},
             (0, (898,)): {4688},
             (0

In [75]:
tp_answers

defaultdict(set,
            {(0, (0,)): set(),
             (0, (3,)): set(),
             (0, (117,)): set(),
             (0, (118,)): set(),
             (0, (276,)): set(),
             (0, (279,)): set(),
             (0, (556,)): set(),
             (0, (559,)): set(),
             (0, (61,)): set(),
             (0, (62,)): set(),
             (0, (292,)): set(),
             (0, (295,)): set(),
             (0, (533,)): set(),
             (0, (534,)): set(),
             (0, (248,)): set(),
             (0, (251,)): set(),
             (0, (53,)): set(),
             (0, (54,)): set(),
             (0, (288,)): set(),
             (0, (291,)): set(),
             (0, (897,)): set(),
             (0, (898,)): set(),
             (0, (645,)): set(),
             (0, (646,)): set(),
             (0, (876,)): set(),
             (0, (879,)): set(),
             (0, (313,)): set(),
             (0, (314,)): set(),
             (0, (824,)): set(),
             (0, (827,)): set(),
 

    train_queries, train_tp_answers, train_fp_answers, train_fn_answers = ground_queries(dataset,           
                                                                                         query_structure,   
                                                                                         train_ent_in,      
                                                                                         train_ent_out,     
                                                                                         defaultdict(lambda:
                                                                                         defaultdict(lambda:
                                                                                         gen_num[0], max_ans
                                                                                         query_name, 'train'
                                                                                         ent2id, rel2id)    

In [79]:
query_structure

['e', ['r']]

In [81]:
from copy import deepcopy    

empty_query_structure = deepcopy(query_structure)

In [82]:
empty_query_structure

['e', ['r']]

In [85]:
answer = random.sample(ent_in.keys(), 1)

In [88]:
id2ent[answer[0]]

'4202'

In [90]:
id2ent

{0: '0',
 1: '1',
 2: '2',
 3: '3',
 4: '4',
 5: '5',
 6: '6',
 7: '7',
 8: '8',
 9: '9',
 10: '10',
 11: '11',
 12: '12',
 13: '13',
 14: '14',
 15: '15',
 16: '16',
 17: '17',
 18: '18',
 19: '19',
 20: '20',
 21: '21',
 22: '22',
 23: '23',
 24: '24',
 25: '25',
 26: '26',
 27: '27',
 28: '28',
 29: '29',
 30: '30',
 31: '31',
 32: '32',
 33: '33',
 34: '34',
 35: '35',
 36: '36',
 37: '37',
 38: '38',
 39: '39',
 40: '40',
 41: '41',
 42: '42',
 43: '43',
 44: '44',
 45: '45',
 46: '46',
 47: '47',
 48: '48',
 49: '49',
 50: '50',
 51: '51',
 52: '52',
 53: '53',
 54: '54',
 55: '55',
 56: '56',
 57: '57',
 58: '58',
 59: '59',
 60: '60',
 61: '61',
 62: '62',
 63: '63',
 64: '64',
 65: '65',
 66: '66',
 67: '67',
 68: '68',
 69: '69',
 70: '70',
 71: '71',
 72: '72',
 73: '73',
 74: '74',
 75: '75',
 76: '76',
 77: '77',
 78: '78',
 79: '79',
 80: '80',
 81: '81',
 82: '82',
 83: '83',
 84: '84',
 85: '85',
 86: '86',
 87: '87',
 88: '88',
 89: '89',
 90: '90',
 91: '91',
 92: '92