In [1]:
%load_ext autoreload
%autoreload 2
import os
import sys
import time
import pickle
from tqdm import tqdm
os.chdir('/home/gebhart/projects/sheaf_kg')

import sheaf_kg.harmonic_extension as harmonic_extension

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import pykeen
import torch
from pykeen.pipeline import pipeline

In [2]:
dataset = 'FB15k-237'
num_test = 100
model_name = 'StructuredEmbedding_betae_1000epochs_64dim_SoftplusLossloss_1234seed_20210203-1208'
save_loc = '/home/gebhart/projects/sheaf_kg/data/{}/{}/trained_model.pkl'.format(dataset,model_name)
result_loc = '/home/gebhart/projects/sheaf_kg/data/{}'.format(dataset)
q2b_path = '/home/gebhart/projects/sheaf_kg/data/{}-q2b'.format(dataset)
# q2b_path = '/home/gebhart/projects/sheaf_kg/data/{}-betae'.format(dataset)

In [3]:
model = torch.load(save_loc).to('cpu')

In [4]:
performance_df = pd.read_csv(os.path.join(result_loc, model_name+'.csv'))
performance_df

Unnamed: 0.1,Unnamed: 0,Side,Type,Metric,Value
0,0,both,avg,adjusted_mean_rank,0.092477
1,1,tail,avg,adjusted_mean_rank,0.092754
2,2,head,avg,adjusted_mean_rank,0.092199
3,3,both,avg,mean_rank,641.545259
4,4,both,avg,mean_reciprocal_rank,0.124593
5,5,both,avg,hits_at_1,0.075705
6,6,both,avg,hits_at_3,0.12774
7,7,both,avg,hits_at_5,0.159507
8,8,both,avg,hits_at_10,0.219004
9,9,both,best,mean_rank,641.544733


In [5]:
query_structures = [('e', ('r', 'r')), ('e', ('r', 'r', 'r')), (('e', ('r',)), ('e', ('r',))), (('e', ('r',)), ('e', ('r',)), ('e', ('r',))), (('e', ('r', 'r')), ('e', ('r',))), ((('e', ('r',)), ('e', ('r',))), ('r',))]

query_name_dict = {('e',('r',)): '1p', 
                    ('e', ('r', 'r')): '2p',
                    ('e', ('r', 'r', 'r')): '3p',
                    (('e', ('r',)), ('e', ('r',))): '2i',
                    (('e', ('r',)), ('e', ('r',)), ('e', ('r',))): '3i',
                    ((('e', ('r',)), ('e', ('r',))), ('r',)): 'ip',
                    (('e', ('r', 'r')), ('e', ('r',))): 'pi',
                    (('e', ('r',)), ('e', ('r', 'n'))): '2in',
                    (('e', ('r',)), ('e', ('r',)), ('e', ('r', 'n'))): '3in',
                    ((('e', ('r',)), ('e', ('r', 'n'))), ('r',)): 'inp',
                    (('e', ('r', 'r')), ('e', ('r', 'n'))): 'pin',
                    (('e', ('r', 'r', 'n')), ('e', ('r',))): 'pni',
                    (('e', ('r',)), ('e', ('r',)), ('u',)): '2u-DNF',
                    ((('e', ('r',)), ('e', ('r',)), ('u',)), ('r',)): 'up-DNF',
                    ((('e', ('r', 'n')), ('e', ('r', 'n'))), ('n',)): '2u-DM',
                    ((('e', ('r', 'n')), ('e', ('r', 'n'))), ('n', 'r')): 'up-DM'
                }

In [6]:
with open(os.path.join(q2b_path,'test-queries.pkl'), 'rb') as f:
    test_queries = pickle.load(f)

with open(os.path.join(q2b_path,'test-easy-answers.pkl'), 'rb') as f:
    test_answers = pickle.load(f)
    
with open(os.path.join(q2b_path,'id2rel.pkl'), 'rb') as f:
    id2rel = pickle.load(f)
    
with open(os.path.join(q2b_path,'id2ent.pkl'), 'rb') as f:
    id2ent = pickle.load(f)

In [7]:
def map_ent(e):
    return e

def map_rel(r):
    orientation = 1
    relname = id2rel[r]
    if relname[0] == '-':
        orientation = -1
    return r, orientation

In [8]:
def L_p(query, model):
    '''query of form ('e', ('r', 'r', ... , 'r')).
    here we assume 2 or more relations are present so 2p or greater
    '''
    ent = map_ent(query[0])
    invs = []
    rels = []
    for r in query[1]:
        mapped_id, orientation = map_rel(r) 
        rels.append(mapped_id)
        invs.append(orientation)
    n_path_ents = len(rels)
    
    edge_indices = np.concatenate([np.arange(0,n_path_ents)[:,np.newaxis].T, np.arange(1,n_path_ents+1)[:,np.newaxis].T], axis=0)
    
    rel_idx_tensor = torch.LongTensor(rels)
    left_restrictions = model.left_relation_embeddings(indices=rel_idx_tensor).view(-1, model.embedding_dim, model.embedding_dim).detach().numpy()
    right_restrictions = model.right_relation_embeddings(indices=rel_idx_tensor).view(-1, model.embedding_dim, model.embedding_dim).detach().numpy()
    
    restrictions = np.empty((len(rels), 2, left_restrictions.shape[2], left_restrictions.shape[1]))
    restrictions[:,0,:,:] = left_restrictions
    restrictions[:,1,:,:] = right_restrictions
#     for invix in range(len(invs)):
#         if invix == -1:
#             restrictions[:,0,:,:] = right_restrictions
#             restrictions[:,1,:,:] = left_restrictions
#         else:
#             restrictions[:,0,:,:] = left_restrictions
#             restrictions[:,1,:,:] = right_restrictions
    
    ent_idx_tensor = torch.LongTensor([ent])
    source_embeddings = model.entity_embeddings(indices=ent_idx_tensor).view(-1, model.embedding_dim).detach().numpy()
    
    B = np.array([0,n_path_ents],np.int)
    U = np.array(range(1,n_path_ents),np.int)
    source_vertices = [0]
    target_vertices = [1]
    LSchur = harmonic_extension.Kron_reduction(edge_indices, restrictions, B, U)
    return LSchur, source_vertices, target_vertices, source_embeddings

def L_i(query, model):
    '''query of form (('e', ('r',)), ('e', ('r',)), ... , ('e', ('r',)))'''
    num_intersects = len(query)
    ents = []
    rels = []
    invs = []
    for pair in query:
        ents.append(map_ent(pair[0]))
        rel, inv = map_rel(pair[1][0])
        rels.append(rel)
        invs.append(inv)
    n_ents = num_intersects
    edge_indices = np.concatenate([np.arange(0,n_ents)[:,np.newaxis].T, np.full(n_ents,n_ents)[:,np.newaxis].T], axis=0)
    
    rel_idx_tensor = torch.LongTensor(rels)
    left_restrictions = model.left_relation_embeddings(indices=rel_idx_tensor).view(-1, model.embedding_dim, model.embedding_dim).detach().numpy()
    right_restrictions = model.right_relation_embeddings(indices=rel_idx_tensor).view(-1, model.embedding_dim, model.embedding_dim).detach().numpy()
    
    restrictions = np.empty((len(rels), 2, left_restrictions.shape[2], left_restrictions.shape[1]))
    restrictions[:,0,:,:] = left_restrictions
    restrictions[:,1,:,:] = right_restrictions
#     for invix in range(len(invs)):
#         if invix == -1:
#             restrictions[:,0,:,:] = right_restrictions
#             restrictions[:,1,:,:] = left_restrictions
#         else:
#             restrictions[:,0,:,:] = left_restrictions
#             restrictions[:,1,:,:] = right_restrictions
    
    ent_idx_tensor = torch.LongTensor(ents)
    source_embeddings = model.entity_embeddings(indices=ent_idx_tensor).view(-1, model.embedding_dim).detach().numpy()
    
    L = harmonic_extension.Laplacian(edge_indices, restrictions)
    source_vertices = np.arange(n_ents)
    target_vertices = [n_ents]
    return L, source_vertices, target_vertices, source_embeddings

def L_ip(query, model):
    '''query of form ((('e', ('r',)), ('e', ('r',))), ('r',))'''
    ents = [map_ent(t[0]) for t in query[0]]
    rel0, inv0 = map_rel(query[0][0][1][0])
    rel1, inv1 = map_rel(query[0][1][1][0])
    rel2, inv2 = map_rel(query[1][0])
    rels = [rel0, rel1, rel2]
    invs = [inv0, inv1, inv2]
    n_ents = len(ents)
    edge_indices = np.array([[0,2],[1,2],[2,3]],np.int).T
    
    rel_idx_tensor = torch.LongTensor(rels)
    left_restrictions = model.left_relation_embeddings(indices=rel_idx_tensor).view(-1, model.embedding_dim, model.embedding_dim).detach().numpy()
    right_restrictions = model.right_relation_embeddings(indices=rel_idx_tensor).view(-1, model.embedding_dim, model.embedding_dim).detach().numpy()
    
    restrictions = np.empty((len(rels), 2, left_restrictions.shape[2], left_restrictions.shape[1]))
    restrictions[:,0,:,:] = left_restrictions
    restrictions[:,1,:,:] = right_restrictions
#     for invix in range(len(invs)):
#         if invix == -1:
#             restrictions[:,0,:,:] = right_restrictions
#             restrictions[:,1,:,:] = left_restrictions
#         else:
#             restrictions[:,0,:,:] = left_restrictions
#             restrictions[:,1,:,:] = right_restrictions
    
    ent_idx_tensor = torch.LongTensor(ents)
    source_embeddings = model.entity_embeddings(indices=ent_idx_tensor).view(-1, model.embedding_dim).detach().numpy()
    
    B = np.array([0,2,3],np.int)
    U = np.array([1],np.int)
    source_vertices = [0,1]
    target_vertices = [2]
    LSchur = harmonic_extension.Kron_reduction(edge_indices, restrictions, B, U)
    return LSchur, source_vertices, target_vertices, source_embeddings
    
def L_pi(query, model):
    '''query of form (('e', ('r', 'r')), ('e', ('r',)))'''
    ents = [map_ent(t[0]) for t in query]
    rel0, inv0 = map_rel(query[0][1][0])
    rel1, inv1 = map_rel(query[0][1][1])
    rel2, inv2 = map_rel(query[1][1][0])
    rels = [rel0, rel1, rel2]
    invs = [inv0, inv1, inv2]
    n_ents = len(ents)
    edge_indices = np.array([[0,2],[2,3],[1,3]],np.int).T
    
    rel_idx_tensor = torch.LongTensor(rels)
    left_restrictions = model.left_relation_embeddings(indices=rel_idx_tensor).view(-1, model.embedding_dim, model.embedding_dim).detach().numpy()
    right_restrictions = model.right_relation_embeddings(indices=rel_idx_tensor).view(-1, model.embedding_dim, model.embedding_dim).detach().numpy()
    
    restrictions = np.empty((len(rels), 2, left_restrictions.shape[2], left_restrictions.shape[1]))
    restrictions[:,0,:,:] = left_restrictions
    restrictions[:,1,:,:] = right_restrictions
#     for invix in range(len(invs)):
#         if invix == -1:
#             restrictions[:,0,:,:] = right_restrictions
#             restrictions[:,1,:,:] = left_restrictions
#         else:
#             restrictions[:,0,:,:] = left_restrictions
#             restrictions[:,1,:,:] = right_restrictions
    
    ent_idx_tensor = torch.LongTensor(ents)
    source_embeddings = model.entity_embeddings(indices=ent_idx_tensor).view(-1, model.embedding_dim).detach().numpy()
    
    B = np.array([0,1,3],np.int)
    U = np.array([2],np.int)
    source_vertices = [0,1]
    target_vertices = [2]
    LSchur = harmonic_extension.Kron_reduction(edge_indices, restrictions, B, U)
    return LSchur, source_vertices, target_vertices, source_embeddings

query_name_fn_dict = {'2p': L_p, '3p':L_p, '2i': L_i, '3i':L_i, 'ip':L_ip, 'pi': L_pi}

In [9]:
target_embeddings = model.entity_embeddings(indices=None).view(-1, model.embedding_dim).detach().numpy().T
allhits1 = []
allhits3 = []
allhits5 = []
allhits10 = []
allmrr = []
query_names = []
for query_structure in query_structures:
    print('Running query : {}'.format(query_structure))
    query_name = query_name_dict[query_structure]
    query_names.append(query_name)
    fn = query_name_fn_dict[query_name]
    hits1 = 0.
    hits3 = 0.
    hits5 = 0.
    hits10 = 0.
    mrr = 0.
    cnt = 0
    queries = list(test_queries[query_structure])
    for query in tqdm(queries[:num_test]):
        if len(test_answers[query]) > 0:
            # we have a non-trivial "easy" query
            answers = [map_ent(a) for a in test_answers[query]]
            L, source_vertices, target_vertices, source_embeddings = fn(query, model)
            Q = harmonic_extension.compute_costs(L,source_vertices,target_vertices,source_embeddings.flatten(),target_embeddings,source_embeddings.shape[1])
            
            sortd = np.sort(Q)
            idxleft = np.searchsorted(sortd, Q[answers], side='left') + 1 # optimistic ranking
            idxright = np.searchsorted(sortd, Q[answers], side='right') + 1 # pessimistic ranking
            
            hits1 += ((np.mean(idxleft <= 1) + np.mean(idxright <= 1)) / 2.)
            hits3 += ((np.mean(idxleft <= 3) + np.mean(idxright <= 3)) / 2.)
            hits5 += ((np.mean(idxleft <= 5) + np.mean(idxright <= 5)) / 2.)
            hits10 += ((np.mean(idxleft <= 10) + np.mean(idxright <= 10)) / 2.)
            mrr += ((np.mean(1./idxleft) + np.mean(1./idxright)) / 2.)
            cnt += 1
    if cnt > 0:
        allhits1.append(hits1/cnt)
        allhits3.append(hits3/cnt)
        allhits5.append(hits5/cnt)
        allhits10.append(hits10/cnt)
        allmrr.append(mrr/cnt)
    else:
        default = 0.
        allhits1.append(default)
        allhits3.append(default)
        allhits5.append(default)
        allhits10.append(default)
        allmrr.append(default)

  2%|▏         | 2/100 [00:00<00:05, 17.79it/s]

Running query : ('e', ('r', 'r'))


100%|██████████| 100/100 [00:02<00:00, 38.92it/s]
  3%|▎         | 3/100 [00:00<00:03, 25.93it/s]

Running query : ('e', ('r', 'r', 'r'))


100%|██████████| 100/100 [00:03<00:00, 28.81it/s]
  7%|▋         | 7/100 [00:00<00:01, 63.83it/s]

Running query : (('e', ('r',)), ('e', ('r',)))


100%|██████████| 100/100 [00:01<00:00, 66.33it/s]
  7%|▋         | 7/100 [00:00<00:01, 67.67it/s]

Running query : (('e', ('r',)), ('e', ('r',)), ('e', ('r',)))


100%|██████████| 100/100 [00:01<00:00, 91.47it/s]
  4%|▍         | 4/100 [00:00<00:03, 30.85it/s]

Running query : (('e', ('r', 'r')), ('e', ('r',)))


100%|██████████| 100/100 [00:02<00:00, 36.51it/s]
  6%|▌         | 6/100 [00:00<00:01, 54.49it/s]

Running query : ((('e', ('r',)), ('e', ('r',))), ('r',))


100%|██████████| 100/100 [00:02<00:00, 35.91it/s]


In [10]:
cols = ['hits@1', 'hits@3', 'hits@5', 'hits@10', 'mrr']
df = pd.DataFrame(np.array([allhits1, allhits3, allhits5, allhits10, allmrr]).T, columns=cols, index=query_names) 

In [11]:
df * 100

Unnamed: 0,hits@1,hits@3,hits@5,hits@10,mrr
2p,0.0,0.0,0.009191,0.016557,0.05349
3p,0.003355,0.010552,0.030914,0.060451,0.068648
2i,0.0,0.0,0.00035,0.003455,0.042166
3i,0.002332,0.079103,0.099662,0.186467,0.085972
pi,0.000324,0.002065,0.010739,0.021858,0.047615
ip,0.004624,0.023318,0.038614,0.075164,0.06563
