In [1]:
import os
import pickle
from collections import defaultdict 
from tqdm import tqdm

import numpy as np
import pykeen
from pykeen.pipeline import pipeline

In [2]:
# Dataset name as understood by pykeen; the BetaE query files for it live in a
# '<dataset>-betae' directory under the data root.
dataset = 'FB15k-237'
# Root data directory. Set the SHEAF_KG_DATA_DIR environment variable to point
# elsewhere instead of editing this cell; the default is the original
# machine-specific location, so existing runs are unaffected.
data_dir = os.environ.get('SHEAF_KG_DATA_DIR', '/home/gebhart/projects/sheaf_kg/data')
betae_path = os.path.join(data_dir, '{}-betae'.format(dataset))

In [3]:
# Every BetaE query structure (nested tuples of 'e' anchor-entity slots and
# 'r' relation slots; BetaE's negation/union variants also use 'n' / 'u')
# mapped to its short conventional name.
query_name_dict = {
    # projection chains
    ('e', ('r',)): '1p',
    ('e', ('r', 'r')): '2p',
    ('e', ('r', 'r', 'r')): '3p',
    # intersections and intersection/projection mixes
    (('e', ('r',)), ('e', ('r',))): '2i',
    (('e', ('r',)), ('e', ('r',)), ('e', ('r',))): '3i',
    ((('e', ('r',)), ('e', ('r',))), ('r',)): 'ip',
    (('e', ('r', 'r')), ('e', ('r',))): 'pi',
    # variants containing 'n'
    (('e', ('r',)), ('e', ('r', 'n'))): '2in',
    (('e', ('r',)), ('e', ('r',)), ('e', ('r', 'n'))): '3in',
    ((('e', ('r',)), ('e', ('r', 'n'))), ('r',)): 'inp',
    (('e', ('r', 'r')), ('e', ('r', 'n'))): 'pin',
    (('e', ('r', 'r', 'n')), ('e', ('r',))): 'pni',
    # variants containing 'u' (DNF form) and their De Morgan ('DM') rewrites
    (('e', ('r',)), ('e', ('r',)), ('u',)): '2u-DNF',
    ((('e', ('r',)), ('e', ('r',)), ('u',)), ('r',)): 'up-DNF',
    ((('e', ('r', 'n')), ('e', ('r', 'n'))), ('n',)): '2u-DM',
    ((('e', ('r', 'n')), ('e', ('r', 'n'))), ('n', 'r')): 'up-DM',
}

# The structures this notebook actually remaps, in processing order. They are
# looked up by name from query_name_dict so the two definitions cannot drift.
query_structures = [
    structure
    for wanted_name in ('1p', '2p', '3p', '2i', '3i', 'pi', 'ip')
    for structure, name in query_name_dict.items()
    if name == wanted_name
]

In [4]:
# Load the dataset through pykeen so every remapped id below is in pykeen's id
# space. Inverse triples are not created: BetaE encodes direction in a prefix on
# the relation name, which orient_rel turns into a +-1 flag instead.
ds = pykeen.datasets.get_dataset(dataset=dataset, dataset_kwargs=dict(create_inverse_triples=False))
training = ds.training.mapped_triples  # not used by the cells shown here
relid2label = ds.training.relation_id_to_label
label2relid = {label: rel_id for rel_id, label in relid2label.items()}

entid2label = ds.training.entity_id_to_label
label2entid = {label: ent_id for ent_id, label in entid2label.items()}

# Inverting a dict silently drops entries when two ids share a label, which
# would quietly remap BetaE ids to the wrong pykeen id. Require bijections.
assert len(label2relid) == len(relid2label), 'duplicate relation labels in pykeen mapping'
assert len(label2entid) == len(entid2label), 'duplicate entity labels in pykeen mapping'

You're trying to map triples with 30 entities and 0 relations that are not in the training set. These triples will be excluded from the mapping.
In total 28 from 20466 triples were filtered out


In [5]:
def _load_betae_pickle(filename):
    """Return the unpickled contents of ``filename`` inside ``betae_path``.

    NOTE(review): pickle.load can execute arbitrary code; only point this at
    trusted BetaE data files.
    """
    with open(os.path.join(betae_path, filename), 'rb') as handle:
        return pickle.load(handle)

# Queries are looked up by query structure and answers by query (see the remap
# cells below); only the "easy" test answers are loaded.
test_queries = _load_betae_pickle('test-queries.pkl')
test_answers = _load_betae_pickle('test-easy-answers.pkl')
train_queries = _load_betae_pickle('train-queries.pkl')
train_answers = _load_betae_pickle('train-answers.pkl')

# BetaE id -> name tables, used to bridge BetaE ids to pykeen ids via labels.
id2rel = _load_betae_pickle('id2rel.pkl')
id2ent = _load_betae_pickle('id2ent.pkl')

In [6]:
def map_ent(e):
    """Translate a BetaE entity id to the pykeen entity id carrying the same label."""
    label = id2ent[e]
    return label2entid[label]

def map_rel(r):
    """Translate a BetaE relation id to a pykeen relation id.

    BetaE relation names start with a one-character direction marker ('-' marks
    an inverse, see orient_rel). The marker is dropped before the pykeen label
    lookup, so both directions map to the same pykeen relation.
    """
    return label2relid[id2rel[r][1:]]

def orient_rel(r):
    """Return -1 when BetaE relation ``r`` is an inverse ('-'-prefixed name), else 1."""
    return -1 if id2rel[r][0] == '-' else 1

In [7]:
def L_p(queries):
    '''Remap projection-chain queries of the form ('e', ('r', 'r', ..., 'r')).

    Works for chains of any length, so it serves the 1p, 2p and 3p structures
    (query_name_fn_dict routes all three here).

    Args:
        queries: iterable of BetaE queries, each an (anchor entity id,
            tuple of relation ids) pair.

    Returns:
        (all_ents, all_rels, all_invs): all_ents holds one pykeen entity id per
        query; all_rels / all_invs hold, per query, a list with one pykeen
        relation id / one +-1 orientation flag per hop.
    '''
    all_ents = []
    all_rels = []
    all_invs = []
    for query in queries:
        anchor, path = query[0], query[1]
        all_ents.append(map_ent(anchor))
        all_rels.append([map_rel(r) for r in path])
        all_invs.append([orient_rel(r) for r in path])
    return all_ents, all_rels, all_invs

def L_i(queries):
    '''Remap intersection queries of the form (('e', ('r',)), ('e', ('r',)), ... , ('e', ('r',))).

    Each branch is an (anchor entity, single relation) pair, and any number of
    branches is accepted (2i, 3i).

    Returns three lists holding, per query, the branch anchors as pykeen entity
    ids, the branch relations as pykeen relation ids, and their +-1 orientations.
    '''
    all_ents, all_rels, all_invs = [], [], []
    for query in queries:
        anchors = [branch[0] for branch in query]
        relations = [branch[1][0] for branch in query]
        all_ents.append([map_ent(e) for e in anchors])
        all_rels.append([map_rel(r) for r in relations])
        all_invs.append([orient_rel(r) for r in relations])
    return all_ents, all_rels, all_invs

def L_ip(queries):
    '''Remap intersect-then-project queries of the form ((('e', ('r',)), ('e', ('r',))), ('r',)).

    Returns three lists holding, per query, the two branch anchors as pykeen
    entity ids, the relations [branch-1, branch-2, final projection] as pykeen
    relation ids, and the matching +-1 orientations.
    '''
    all_ents, all_rels, all_invs = [], [], []
    for query in queries:
        branches = query[0]
        projection = query[1]
        rel_ids = (branches[0][1][0], branches[1][1][0], projection[0])
        all_ents.append([map_ent(branch[0]) for branch in branches])
        all_rels.append([map_rel(r) for r in rel_ids])
        all_invs.append([orient_rel(r) for r in rel_ids])
    return all_ents, all_rels, all_invs

def L_pi(queries):
    '''Remap project-then-intersect queries of the form (('e', ('r', 'r')), ('e', ('r',))).

    Returns three lists holding, per query, both anchors as pykeen entity ids,
    the relations [chain hop 1, chain hop 2, single branch] as pykeen relation
    ids, and the matching +-1 orientations.
    '''
    all_ents, all_rels, all_invs = [], [], []
    for query in queries:
        chain, single = query[0], query[1]
        rel_ids = (chain[1][0], chain[1][1], single[1][0])
        all_ents.append([map_ent(branch[0]) for branch in query])
        all_rels.append([map_rel(r) for r in rel_ids])
        all_invs.append([orient_rel(r) for r in rel_ids])
    return all_ents, all_rels, all_invs
    
# Query name -> remapping function. Chains of any length share L_p, and
# intersections of any width share L_i.
query_name_fn_dict = dict.fromkeys(('1p', '2p', '3p'), L_p)
query_name_fn_dict.update(dict.fromkeys(('2i', '3i'), L_i))
query_name_fn_dict.update(ip=L_ip, pi=L_pi)

In [8]:
# train queries: remap every structure from BetaE ids to pykeen ids and save the
# anchors, relations and orientation flags as .npy files plus the answers as a
# pickle, all under betae_path.
for query_structure in query_structures:
    query_name = query_name_dict[query_structure]
    print('Query Structure: {}'.format(query_name))
    # Materialise the queries once. train_queries[...] is presumably an unordered
    # set (BetaE pickle format; verify), and the answer lists must line up
    # row-for-row with the query arrays, so both are built from this one list
    # rather than from two separate iterations of the container.
    queries = list(train_queries[query_structure])
    all_answers = [[map_ent(a) for a in train_answers[query]] for query in queries]
    train_ents, train_rels, train_invs = query_name_fn_dict[query_name](queries)
    train_ents = np.array(train_ents)
    train_rels = np.array(train_rels)
    train_invs = np.array(train_invs)
    # Downstream code pairs answers[i] with row i of each array.
    assert len(all_answers) == len(train_ents) == len(train_rels) == len(train_invs)
    print(train_ents.shape, train_rels.shape, train_invs.shape)
    np.save(os.path.join(betae_path, '{}_remapped_train_entities'.format(query_name)), train_ents)
    np.save(os.path.join(betae_path, '{}_remapped_train_relations'.format(query_name)), train_rels)
    np.save(os.path.join(betae_path, '{}_remapped_train_inverses'.format(query_name)), train_invs)
    with open(os.path.join(betae_path, '{}_remapped_train_answers.pkl'.format(query_name)), 'wb') as f:
        pickle.dump(all_answers, f)

Query Structure: 1p
(149689,) (149689, 1) (149689, 1)
Query Structure: 2p
(149689,) (149689, 2) (149689, 2)
Query Structure: 3p
(149689,) (149689, 3) (149689, 3)
Query Structure: 2i
(149689, 2) (149689, 2) (149689, 2)
Query Structure: 3i
(149689, 3) (149689, 3) (149689, 3)
Query Structure: pi
(0,) (0,) (0,)
Query Structure: ip
(0,) (0,) (0,)


In [9]:
# test queries: remap each structure, keeping only queries that have at least
# one easy answer, and save the results under '<name>_remapped_test-easy_*'.
for query_structure in query_structures:
    query_name = query_name_dict[query_structure]
    print('Query Structure: {}'.format(query_name))
    queries = [q for q in test_queries[query_structure] if len(test_answers[q]) > 0]
    all_answers = [[map_ent(a) for a in test_answers[query]] for query in queries]
    test_ents, test_rels, test_invs = map(np.array, query_name_fn_dict[query_name](queries))
    print(test_ents.shape, test_rels.shape, test_invs.shape)
    outputs = (
        ('{}_remapped_test-easy_entities', test_ents),
        ('{}_remapped_test-easy_relations', test_rels),
        ('{}_remapped_test-easy_inverses', test_invs),
    )
    for template, array in outputs:
        np.save(os.path.join(betae_path, template.format(query_name)), array)
    with open(os.path.join(betae_path, '{}_remapped_test-easy_answers.pkl'.format(query_name)), 'wb') as f:
        pickle.dump(all_answers, f)

Query Structure: 1p
(16310,) (16310, 1) (16310, 1)
Query Structure: 2p
(4883,) (4883, 2) (4883, 2)
Query Structure: 3p
(4867,) (4867, 3) (4867, 3)
Query Structure: 2i
(4260, 2) (4260, 2) (4260, 2)
Query Structure: 3i
(3196, 3) (3196, 3) (3196, 3)
Query Structure: pi
(4285, 2) (4285, 3) (4285, 3)
Query Structure: ip
(4545, 2) (4545, 3) (4545, 3)


In [10]:
# did it work? Reload the saved 2p test artefacts and check that they line up.
ents = np.load(os.path.join(betae_path, '{}_remapped_test-easy_entities.npy'.format('2p')))
rels = np.load(os.path.join(betae_path, '{}_remapped_test-easy_relations.npy'.format('2p')))
# NOTE: pickle.load executes code from the file; this reads back our own output.
with open(os.path.join(betae_path, '{}_remapped_test-easy_answers.pkl'.format('2p')), 'rb') as f:
    ans = pickle.load(f)

assert ents.shape == (len(ans),)       # one anchor entity per query
assert rels.shape == (len(ans), 2)     # '2p' = two relation hops per query
assert all(len(a) > 0 for a in ans)    # queries without easy answers were filtered out
# Remapped ids must be valid indices into pykeen's entity/relation tables.
assert ((0 <= ents) & (ents < ds.num_entities)).all()
assert ((0 <= rels) & (rels < ds.num_relations)).all()

In [11]:
# Relation count of pykeen's dataset (loaded without inverse triples); every
# remapped relation id saved above should be smaller than this.
ds.num_relations

237

In [12]:
# One anchor entity per reloaded 2p test query.
ents.shape

(4883,)

In [13]:
# Two pykeen relation ids (one per hop) for each reloaded 2p test query.
rels

array([[132,  39],
       [ 32, 190],
       [  9,   3],
       ...,
       [ 23, 173],
       [204, 201],
       [183,  10]])

In [14]:
# Anchor entity ids of the 2p test queries, now in pykeen's entity id space.
ents

array([ 7708,  6397, 12720, ..., 13980,  6596, 14351])