In [1]:
%load_ext autoreload
%autoreload 2
import os
import sys
import time
import pickle
from tqdm import tqdm
os.chdir('/home/gebhart/projects/sheaf_kg')

import sheaf_kg.harmonic_extension as harmonic_extension

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import pykeen
import torch
from pykeen.pipeline import pipeline

In [2]:
dataset = 'FB15k-237'
num_test = 1000
model_name = 'SheafE_multisection_50_sections_1000epochs_64dim_SoftplusLossloss_1235seed_20210202-0825'
save_loc = '/home/gebhart/projects/sheaf_kg/data/{}/{}/trained_model.pkl'.format(dataset,model_name)
# q2b_path = '/home/gebhart/projects/sheaf_kg/data/{}-q2b'.format(dataset)
q2b_path = '/home/gebhart/projects/sheaf_kg/data/{}-betae'.format(dataset)

In [3]:
query_structures = [('e', ('r', 'r')), ('e', ('r', 'r', 'r')), (('e', ('r',)), ('e', ('r',))), (('e', ('r',)), ('e', ('r',)), ('e', ('r',))), (('e', ('r', 'r')), ('e', ('r',))), ((('e', ('r',)), ('e', ('r',))), ('r',))]

query_name_dict = {('e',('r',)): '1p', 
                    ('e', ('r', 'r')): '2p',
                    ('e', ('r', 'r', 'r')): '3p',
                    (('e', ('r',)), ('e', ('r',))): '2i',
                    (('e', ('r',)), ('e', ('r',)), ('e', ('r',))): '3i',
                    ((('e', ('r',)), ('e', ('r',))), ('r',)): 'ip',
                    (('e', ('r', 'r')), ('e', ('r',))): 'pi',
                    (('e', ('r',)), ('e', ('r', 'n'))): '2in',
                    (('e', ('r',)), ('e', ('r',)), ('e', ('r', 'n'))): '3in',
                    ((('e', ('r',)), ('e', ('r', 'n'))), ('r',)): 'inp',
                    (('e', ('r', 'r')), ('e', ('r', 'n'))): 'pin',
                    (('e', ('r', 'r', 'n')), ('e', ('r',))): 'pni',
                    (('e', ('r',)), ('e', ('r',)), ('u',)): '2u-DNF',
                    ((('e', ('r',)), ('e', ('r',)), ('u',)), ('r',)): 'up-DNF',
                    ((('e', ('r', 'n')), ('e', ('r', 'n'))), ('n',)): '2u-DM',
                    ((('e', ('r', 'n')), ('e', ('r', 'n'))), ('n', 'r')): 'up-DM'
                }

In [4]:
from typing import Optional

from pykeen.models import StructuredEmbedding
from pykeen.models.base import _OldAbstractModel
from pykeen.nn import Embedding
from pykeen.losses import Loss
from pykeen.nn.init import xavier_uniform_
from pykeen.regularizers import Regularizer
from pykeen.triples import TriplesFactory
from pykeen.typing import DeviceHint
from pykeen.utils import compose

from torch.nn import functional
from torch.nn.parameter import Parameter

class ModifiedSE(_OldAbstractModel):

    def __init__(
        self,
        triples_factory: TriplesFactory,
        embedding_dim: int = 64,
        scoring_fct_norm: int = 2,
        num_sections: int = 50,
        loss: Optional[Loss] = None,
        preferred_device: DeviceHint = None,
        random_seed: Optional[int] = None,
        regularizer: Optional[Regularizer] = None,
    ) -> None:
        r"""Initialize SE.

        :param embedding_dim: The entity embedding dimension $d$. Is usually $d \in [50, 300]$.
        :param scoring_fct_norm: The $l_p$ norm. Usually 1 for SE.
        """
        super().__init__(
            triples_factory=triples_factory,
            loss=loss,
            preferred_device=preferred_device,
            random_seed=random_seed,
            regularizer=regularizer,
        )

        self.embedding_dim = embedding_dim
        self.num_sections = num_sections
        self.scoring_fct_norm = scoring_fct_norm

        esize = (triples_factory.num_entities, num_sections, embedding_dim)
        self.ent_embeddings = Parameter(nn.init.xavier_uniform_(torch.empty(esize, device=preferred_device, dtype=torch.float32)),requires_grad=True)

        tsize = (triples_factory.num_relations, embedding_dim, embedding_dim)
        self.left_embeddings = Parameter(nn.init.xavier_uniform_(torch.empty(tsize, device=preferred_device, dtype=torch.float32)),requires_grad=True)
        self.right_embeddings = Parameter(nn.init.xavier_uniform_(torch.empty(tsize, device=preferred_device, requires_grad=True, dtype=torch.float32)),requires_grad=True)

    def _reset_parameters_(self):  # noqa: D102
        self.ent_embeddings = nn.init.xavier_uniform_(self.ent_embeddings)
        self.left_embeddings = nn.init.xavier_uniform_(self.left_embeddings)
        self.right_embeddings = nn.init.xavier_uniform_(self.right_embeddings)


    def score_hrt(self, hrt_batch: torch.LongTensor) -> torch.FloatTensor:  # noqa: D102
        # Get embeddings
        h = torch.index_select(self.ent_embeddings, 0, hrt_batch[:, 0]).view(-1, self.embedding_dim, self.num_sections)
        rel_h = torch.index_select(self.left_embeddings, 0, hrt_batch[:, 1])
        rel_t = torch.index_select(self.right_embeddings, 0, hrt_batch[:, 1])
        t = torch.index_select(self.ent_embeddings, 0, hrt_batch[:, 2]).view(-1, self.embedding_dim, self.num_sections)

        # Project entities
        proj_h = rel_h @ h
        proj_t = rel_t @ t
        scores = -torch.norm(proj_h - proj_t, dim=(1,2), p=self.scoring_fct_norm)
        return scores


    def score_t(self, hr_batch: torch.LongTensor, slice_size: int = None) -> torch.FloatTensor:  # noqa: D102
        # Get embeddings
        h = torch.index_select(self.ent_embeddings, 0, hr_batch[:, 0]).view(-1, self.embedding_dim, self.num_sections)
        rel_h = torch.index_select(self.left_embeddings, 0, hr_batch[:, 1])
        rel_t = torch.index_select(self.right_embeddings, 0, hr_batch[:, 1])
        rel_t = rel_t.view(-1, 1, self.embedding_dim, self.embedding_dim)
        t_all = self.ent_embeddings.view(1, -1, self.embedding_dim, self.num_sections)

        if slice_size is not None:
            raise ValueError('Not implemented')

        else:
            # Project entities
            proj_h = rel_h @ h
            proj_t = rel_t @ t_all

        scores = -torch.norm(proj_h[:, None, :, :] - proj_t[:, :, :, :], dim=(-1,-2), p=self.scoring_fct_norm)

        return scores


    def score_h(self, rt_batch: torch.LongTensor, slice_size: int = None) -> torch.FloatTensor:  # noqa: D102
        # Get embeddings
        h_all = self.ent_embeddings.view(1, -1, self.embedding_dim, self.num_sections)
        rel_h = torch.index_select(self.left_embeddings, 0, rt_batch[:, 0])
        rel_h = rel_h.view(-1, 1, self.embedding_dim, self.embedding_dim)
        rel_t = torch.index_select(self.right_embeddings, 0, rt_batch[:, 0])
        t = torch.index_select(self.ent_embeddings, 0, rt_batch[:, 1]).view(-1, self.embedding_dim, self.num_sections)

        if slice_size is not None:
            raise ValueError('Not implemented')
        else:
            # Project entities
            proj_h = rel_h @ h_all
            proj_t = rel_t @ t

        scores = -torch.norm(proj_h[:, :, :, :] - proj_t[:, None, :, :], dim=(-1,-2), p=self.scoring_fct_norm)

        return scores

In [5]:
model = torch.load(save_loc).to('cpu')

In [6]:
ds = pykeen.datasets.get_dataset(dataset=dataset)
training = ds.training.mapped_triples
relid2label = ds.training.relation_id_to_label 
label2relid = {v:k for k,v in relid2label.items()}

entid2label = ds.training.entity_id_to_label 
label2entid = {v:k for k,v in entid2label.items()}

You're trying to map triples with 30 entities and 0 relations that are not in the training set. These triples will be excluded from the mapping.
In total 28 from 20466 triples were filtered out


In [7]:
with open(os.path.join(q2b_path,'test-queries.pkl'), 'rb') as f:
    test_queries = pickle.load(f)

with open(os.path.join(q2b_path,'test-easy-answers.pkl'), 'rb') as f:
    test_answers = pickle.load(f)
    
with open(os.path.join(q2b_path,'id2rel.pkl'), 'rb') as f:
    id2rel = pickle.load(f)
    
with open(os.path.join(q2b_path,'id2ent.pkl'), 'rb') as f:
    id2ent = pickle.load(f)

In [8]:
def map_ent(e):
    return label2entid[id2ent[e]]
def map_rel(r):
    orientation = 1
    relname = id2rel[r]
    if relname[0] == '-':
        orientation = -1
    return label2relid[relname[1:]], orientation

In [9]:
section = 25
def L_p(query, model):
    '''query of form ('e', ('r', 'r', ... , 'r')).
    here we assume 2 or more relations are present so 2p or greater
    '''
    ent = map_ent(query[0])
    invs = []
    rels = []
    for r in query[1]:
        mapped_id, orientation = map_rel(r) 
        rels.append(mapped_id)
        invs.append(orientation)
    n_path_ents = len(rels)
    
    edge_indices = np.concatenate([np.arange(0,n_path_ents)[:,np.newaxis].T, np.arange(1,n_path_ents+1)[:,np.newaxis].T], axis=0)
    
    rel_idx_tensor = torch.LongTensor(rels)
    left_restrictions = torch.index_select(model.left_embeddings, 0, rel_idx_tensor).view(-1, model.embedding_dim, model.embedding_dim).detach().numpy()
    right_restrictions = torch.index_select(model.right_embeddings, 0, rel_idx_tensor).view(-1, model.embedding_dim, model.embedding_dim).detach().numpy()
    
    restrictions = np.empty((len(rels), 2, left_restrictions.shape[2], left_restrictions.shape[1]))
#     restrictions[:,0,:,:] = left_restrictions
#     restrictions[:,1,:,:] = right_restrictions
    for invix in range(len(invs)):
        if invix == -1:
            restrictions[:,0,:,:] = right_restrictions
            restrictions[:,1,:,:] = left_restrictions
        else:
            restrictions[:,0,:,:] = left_restrictions
            restrictions[:,1,:,:] = right_restrictions
    
    ent_idx_tensor = torch.LongTensor([ent])
    source_embeddings = torch.index_select(model.ent_embeddings, 0, ent_idx_tensor)
#     source_embeddings = source_embeddings[:,section,:].view(-1, model.embedding_dim).detach().numpy()
    source_embeddings = torch.mean(source_embeddings, dim=1).view(-1, model.embedding_dim).detach().numpy()
    
    B = np.array([0,n_path_ents],np.int)
    U = np.array(range(1,n_path_ents),np.int)
    source_vertices = [0]
    target_vertices = [1]
    LSchur = harmonic_extension.Kron_reduction(edge_indices, restrictions, B, U)
    return LSchur, source_vertices, target_vertices, source_embeddings

def L_i(query, model):
    '''query of form (('e', ('r',)), ('e', ('r',)), ... , ('e', ('r',)))'''
    num_intersects = len(query)
    ents = []
    rels = []
    invs = []
    for pair in query:
        ents.append(map_ent(pair[0]))
        rel, inv = map_rel(pair[1][0])
        rels.append(rel)
        invs.append(inv)
    n_ents = num_intersects
    edge_indices = np.concatenate([np.arange(0,n_ents)[:,np.newaxis].T, np.full(n_ents,n_ents)[:,np.newaxis].T], axis=0)
    
    rel_idx_tensor = torch.LongTensor(rels)
    left_restrictions = torch.index_select(model.left_embeddings, 0, rel_idx_tensor).view(-1, model.embedding_dim, model.embedding_dim).detach().numpy()
    right_restrictions = torch.index_select(model.right_embeddings, 0, rel_idx_tensor).view(-1, model.embedding_dim, model.embedding_dim).detach().numpy()

    restrictions = np.empty((len(rels), 2, left_restrictions.shape[2], left_restrictions.shape[1]))
#     restrictions[:,0,:,:] = left_restrictions
#     restrictions[:,1,:,:] = right_restrictions
    for invix in range(len(invs)):
        if invix == -1:
            restrictions[:,0,:,:] = right_restrictions
            restrictions[:,1,:,:] = left_restrictions
        else:
            restrictions[:,0,:,:] = left_restrictions
            restrictions[:,1,:,:] = right_restrictions
    
    ent_idx_tensor = torch.LongTensor(ents)
    source_embeddings = torch.index_select(model.ent_embeddings, 0, ent_idx_tensor)
#     source_embeddings = source_embeddings[:,section,:].view(-1, model.embedding_dim).detach().numpy()
    source_embeddings = torch.mean(source_embeddings, dim=1).view(-1, model.embedding_dim).detach().numpy()
    
    L = harmonic_extension.Laplacian(edge_indices, restrictions)
    source_vertices = np.arange(n_ents)
    target_vertices = [n_ents]
    return L, source_vertices, target_vertices, source_embeddings

def L_ip(query, model):
    '''query of form ((('e', ('r',)), ('e', ('r',))), ('r',))'''
    ents = [map_ent(t[0]) for t in query[0]]
    rel0, inv0 = map_rel(query[0][0][1][0])
    rel1, inv1 = map_rel(query[0][1][1][0])
    rel2, inv2 = map_rel(query[1][0])
    rels = [rel0, rel1, rel2]
    invs = [inv0, inv1, inv2]
    n_ents = len(ents)
    edge_indices = np.array([[0,2],[1,2],[2,3]],np.int).T
    
    rel_idx_tensor = torch.LongTensor(rels)
    left_restrictions = torch.index_select(model.left_embeddings, 0, rel_idx_tensor).view(-1, model.embedding_dim, model.embedding_dim).detach().numpy()
    right_restrictions = torch.index_select(model.right_embeddings, 0, rel_idx_tensor).view(-1, model.embedding_dim, model.embedding_dim).detach().numpy()
    
    restrictions = np.empty((len(rels), 2, left_restrictions.shape[2], left_restrictions.shape[1]))
#     restrictions[:,0,:,:] = left_restrictions
#     restrictions[:,1,:,:] = right_restrictions
    for invix in range(len(invs)):
        if invix == -1:
            restrictions[:,0,:,:] = right_restrictions
            restrictions[:,1,:,:] = left_restrictions
        else:
            restrictions[:,0,:,:] = left_restrictions
            restrictions[:,1,:,:] = right_restrictions
    
    ent_idx_tensor = torch.LongTensor(ents)
    source_embeddings = torch.index_select(model.ent_embeddings, 0, ent_idx_tensor)
#     source_embeddings = source_embeddings[:,section,:].view(-1, model.embedding_dim).detach().numpy()
    source_embeddings = torch.mean(source_embeddings, dim=1).view(-1, model.embedding_dim).detach().numpy()
    
    B = np.array([0,2,3],np.int)
    U = np.array([1],np.int)
    source_vertices = [0,1]
    target_vertices = [2]
    LSchur = harmonic_extension.Kron_reduction(edge_indices, restrictions, B, U)
    return LSchur, source_vertices, target_vertices, source_embeddings
    
def L_pi(query, model):
    '''query of form (('e', ('r', 'r')), ('e', ('r',)))'''
    ents = [map_ent(t[0]) for t in query]
    rel0, inv0 = map_rel(query[0][1][0])
    rel1, inv1 = map_rel(query[0][1][1])
    rel2, inv2 = map_rel(query[1][1][0])
    rels = [rel0, rel1, rel2]
    invs = [inv0, inv1, inv2]
    n_ents = len(ents)
    edge_indices = np.array([[0,2],[2,3],[1,3]],np.int).T
    
    rel_idx_tensor = torch.LongTensor(rels)
    left_restrictions = torch.index_select(model.left_embeddings, 0, rel_idx_tensor).view(-1, model.embedding_dim, model.embedding_dim).detach().numpy()
    right_restrictions = torch.index_select(model.right_embeddings, 0, rel_idx_tensor).view(-1, model.embedding_dim, model.embedding_dim).detach().numpy()
    
    restrictions = np.empty((len(rels), 2, left_restrictions.shape[2], left_restrictions.shape[1]))
#     restrictions[:,0,:,:] = left_restrictions
#     restrictions[:,1,:,:] = right_restrictions
    for invix in range(len(invs)):
        if invix == -1:
            restrictions[:,0,:,:] = right_restrictions
            restrictions[:,1,:,:] = left_restrictions
        else:
            restrictions[:,0,:,:] = left_restrictions
            restrictions[:,1,:,:] = right_restrictions
    
    ent_idx_tensor = torch.LongTensor(ents)
    source_embeddings = torch.index_select(model.ent_embeddings, 0, ent_idx_tensor)
#     source_embeddings = source_embeddings[:,section,:].view(-1, model.embedding_dim).detach().numpy()
    source_embeddings = torch.mean(source_embeddings, dim=1).view(-1, model.embedding_dim).detach().numpy()
    
    B = np.array([0,1,3],np.int)
    U = np.array([2],np.int)
    source_vertices = [0,1]
    target_vertices = [2]
    LSchur = harmonic_extension.Kron_reduction(edge_indices, restrictions, B, U)
    return LSchur, source_vertices, target_vertices, source_embeddings

query_name_fn_dict = {'2p': L_p, '3p':L_p, '2i': L_i, '3i':L_i, 'ip':L_ip, 'pi': L_pi}

In [10]:
# target_embeddings = model.ent_embeddings[:,section,:].view(-1, model.embedding_dim).detach().numpy().T
target_embeddings = torch.mean(model.ent_embeddings, dim=1).view(-1, model.embedding_dim).detach().numpy().T
allhits1 = []
allhits3 = []
allhits5 = []
allhits10 = []
allmrr = []
query_names = []
for query_structure in query_structures:
    print('Running query : {}'.format(query_structure))
    query_name = query_name_dict[query_structure]
    query_names.append(query_name)
    fn = query_name_fn_dict[query_name]
    hits1 = 0.
    hits3 = 0.
    hits5 = 0.
    hits10 = 0.
    mrr = 0.
    cnt = 0
    queries = list(test_queries[query_structure])
    for query in tqdm(queries[:num_test]):
        if len(test_answers[query]) > 0:
            # we have a non-trivial "easy" query
            answers = [map_ent(a) for a in test_answers[query]]
            L, source_vertices, target_vertices, source_embeddings = fn(query, model)
            Q = harmonic_extension.compute_costs(L,source_vertices,target_vertices,source_embeddings.flatten(),target_embeddings,source_embeddings.shape[1])
            sortd = np.sort(Q)
            idxleft = np.searchsorted(sortd, Q[answers], side='left') + 1
            idxright = np.searchsorted(sortd, Q[answers], side='right') + 1
#             idxright = idxleft # throw this for optimistic ranking
            hits1 += ((np.mean(idxleft <= 1) + np.mean(idxright <= 1)) / 2.)
            hits3 += ((np.mean(idxleft <= 3) + np.mean(idxright <= 3)) / 2.)
            hits5 += ((np.mean(idxleft <= 5) + np.mean(idxright <= 5)) / 2.)
            hits10 += ((np.mean(idxleft <= 10) + np.mean(idxright <= 10)) / 2.)
            mrr += ((np.mean(1./idxleft) + np.mean(1./idxright)) / 2.)
            cnt += 1
    if cnt > 0:
        allhits1.append(hits1/cnt)
        allhits3.append(hits3/cnt)
        allhits5.append(hits5/cnt)
        allhits10.append(hits10/cnt)
        allmrr.append(mrr/cnt)
    else:
        default = 0.
        allhits1.append(default)
        allhits3.append(default)
        allhits5.append(default)
        allhits10.append(default)
        allmrr.append(default)

  0%|          | 2/1000 [00:00<01:04, 15.51it/s]

Running query : ('e', ('r', 'r'))


100%|██████████| 1000/1000 [00:35<00:00, 28.31it/s]
  0%|          | 3/1000 [00:00<00:34, 29.19it/s]

Running query : ('e', ('r', 'r', 'r'))


100%|██████████| 1000/1000 [00:47<00:00, 21.25it/s]
  1%|          | 6/1000 [00:00<00:17, 55.76it/s]

Running query : (('e', ('r',)), ('e', ('r',)))


100%|██████████| 1000/1000 [00:15<00:00, 63.36it/s]
  1%|          | 7/1000 [00:00<00:15, 65.07it/s]

Running query : (('e', ('r',)), ('e', ('r',)), ('e', ('r',)))


100%|██████████| 1000/1000 [00:12<00:00, 81.78it/s]
  0%|          | 5/1000 [00:00<00:21, 45.28it/s]

Running query : (('e', ('r', 'r')), ('e', ('r',)))


100%|██████████| 1000/1000 [00:30<00:00, 32.53it/s]
  0%|          | 5/1000 [00:00<00:28, 34.47it/s]

Running query : ((('e', ('r',)), ('e', ('r',))), ('r',))


100%|██████████| 1000/1000 [00:33<00:00, 29.87it/s]


In [11]:
cols = ['hits@1', 'hits@3', 'hits@5', 'hits@10', 'mrr']
df = pd.DataFrame(np.array([allhits1, allhits3, allhits5, allhits10, allmrr]).T,columns=cols, index=query_names) 

In [12]:
df * 100

Unnamed: 0,hits@1,hits@3,hits@5,hits@10,mrr
2p,0.003393,0.02518,0.042993,0.098226,0.073679
3p,0.000759,0.01017,0.01811,0.051208,0.061706
2i,0.004543,0.030328,0.04677,0.108958,0.093733
3i,0.005219,0.042664,0.065942,0.144193,0.104747
pi,0.003386,0.01303,0.024361,0.060081,0.072847
ip,0.003193,0.019107,0.040064,0.080648,0.080585
