In [1]:
%load_ext autoreload
%autoreload 2
import os
import sys
import time
import pickle
from tqdm import tqdm
# for some reason, need to go to the sheaf_kg directory in order for torch.load to work
os.chdir('/home/gebhart/projects/sheaf_kg/sheaf_kg')

import sheaf_kg.batch_harmonic_extension as harmonic_extension
from sheaf_kg.sheafE_models import SheafE_Multisection, SheafE_Diag

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import pykeen
import torch
from pykeen.pipeline import pipeline
from train_sheafE_betae import read_dataset, dataset_to_device, shuffle_datasets
import cvxpy as cp 
import torch
from cvxpylayers.torch import CvxpyLayer

In [2]:
torch.manual_seed(0)
np.random.seed(0)

In [3]:
dataset = 'FB15k-237'
use_section = 0
device = 'cpu'
train_test_queries = 'test'
model_name = 'SheafE_Complex_Queries_16embdim_16esdim_1sec_2norm_1epochs_SoftplusLoss_20210424-1401'
save_loc = '/home/gebhart/projects/sheaf_kg/data/{}/{}/trained_model.pkl'.format(dataset,model_name)
dataset_loc = '/home/gebhart/projects/sheaf_kg/data/{}-betae'.format(dataset)
model = torch.load(save_loc, map_location=device)

In [4]:
def cvxpy_problem(edge_index,dv,de,input_nodes):
    nv = torch.max(torch.max(edge_index)).item() + 1
    ne = edge_index.shape[1]
    x = cp.Variable(nv*dv)
    d = cp.Parameter((ne*de,nv*dv))
    xB = cp.Parameter(len(input_nodes)*dv)
    norm_constraints = [cp.norm(x[i*dv:(i+1)*dv]) <= 1 for i in range(nv) if i not in input_nodes]
    boundary_constraints = [x[v*dv:(v+1)*dv] == xB[i*dv:(i+1)*dv] for (i,v) in enumerate(input_nodes)]
    constraints = norm_constraints + boundary_constraints
    objective = cp.Minimize(cp.norm(d @ x))
    problem = cp.Problem(objective,constraints=constraints)

    layer = CvxpyLayer(problem, parameters=[d,xB], variables=[x])

    return layer

def coboundary(edge_index,restriction_maps):
    ne = edge_index.shape[1]
    nv = torch.max(torch.max(edge_index)).item() + 1 #assume there are vertices indexed 0...max
    de = restriction_maps.shape[2]
    dv = restriction_maps.shape[3]   
    d = torch.zeros((ne*de,nv*dv))
    for e in range(ne):
        h = edge_index[0,e]
        t = edge_index[1,e]
        d[e*de:(e+1)*de,h*dv:(h+1)*dv] = restriction_maps[e,0,:,:]
        d[e*de:(e+1)*de,t*dv:(t+1)*dv] = -restriction_maps[e,1,:,:]
    return d

def linear_chain(ne):
    edge_index = torch.zeros((2,ne),dtype=torch.int)
    for e in range(ne):
        edge_index[0,e] = e
        edge_index[1,e] = e + 1
    return edge_index

In [5]:
datasets = read_dataset(dataset_loc)
datasets = dataset_to_device(shuffle_datasets(datasets), device)

In [6]:
test_query_structures = ['2p', '3p']

query_structures = [('e', ('r', 'r')), ('e', ('r', 'r', 'r')), (('e', ('r',)), ('e', ('r',))), (('e', ('r',)), ('e', ('r',)), ('e', ('r',))), (('e', ('r', 'r')), ('e', ('r',))), ((('e', ('r',)), ('e', ('r',))), ('r',))]

query_name_dict = {('e',('r',)): '1p', 
                    ('e', ('r', 'r')): '2p',
                    ('e', ('r', 'r', 'r')): '3p',
                    (('e', ('r',)), ('e', ('r',))): '2i',
                    (('e', ('r',)), ('e', ('r',)), ('e', ('r',))): '3i',
                    ((('e', ('r',)), ('e', ('r',))), ('r',)): 'ip',
                    (('e', ('r', 'r')), ('e', ('r',))): 'pi',
                    (('e', ('r',)), ('e', ('r', 'n'))): '2in',
                    (('e', ('r',)), ('e', ('r',)), ('e', ('r', 'n'))): '3in',
                    ((('e', ('r',)), ('e', ('r', 'n'))), ('r',)): 'inp',
                    (('e', ('r', 'r')), ('e', ('r', 'n'))): 'pin',
                    (('e', ('r', 'r', 'n')), ('e', ('r',))): 'pni',
                    (('e', ('r',)), ('e', ('r',)), ('u',)): '2u-DNF',
                    ((('e', ('r',)), ('e', ('r',)), ('u',)), ('r',)): 'up-DNF',
                    ((('e', ('r', 'n')), ('e', ('r', 'n'))), ('n',)): '2u-DM',
                    ((('e', ('r', 'n')), ('e', ('r', 'n'))), ('n', 'r')): 'up-DM'
                }

In [7]:
# ds = pykeen.datasets.get_dataset(dataset=dataset, dataset_kwargs=dict(create_inverse_triples=False))
# training = ds.training.mapped_triples
# 237*8*16*8*16/training.shape[0]

In [8]:
layer_2p = cvxpy_problem(linear_chain(2), model.embedding_dim, model.edge_stalk_dim, [0])
layer_3p = cvxpy_problem(linear_chain(3), model.embedding_dim, model.edge_stalk_dim, [0])

In [9]:
def L_2p(model, entities, relations, targets, invs=None, sec=0):
    all_ents = entities
    all_rels = relations
    all_invs = invs
    n_path_ents = all_rels.shape[1]
    num_queries = all_ents.shape[0]
    
    edge_indices = np.concatenate([np.arange(0,n_path_ents)[:,np.newaxis].T, np.arange(1,n_path_ents+1)[:,np.newaxis].T], axis=0)
    edge_indices = torch.LongTensor(np.repeat(edge_indices[np.newaxis, :, :], num_queries, axis=0))

    left_restrictions = torch.index_select(model.left_embeddings, 0, all_rels.flatten()).view(-1, all_rels.shape[1], model.edge_stalk_dim, model.embedding_dim)
    right_restrictions = torch.index_select(model.right_embeddings, 0, all_rels.flatten()).view(-1, all_rels.shape[1], model.edge_stalk_dim, model.embedding_dim)

    restrictions = torch.cat((left_restrictions.unsqueeze(2), right_restrictions.unsqueeze(2)), dim=2)
    if all_invs is not None:
        for ainvix in range(all_invs.shape[0]):
            invs = all_invs[ainvix]
            for invix in range(invs.shape[0]):
                if invs[invix] == -1:
                    tmp = torch.clone(restrictions[ainvix,invix,0,:,:])
                    restrictions[ainvix,invix,0,:,:] = restrictions[ainvix,invix,1,:,:]
                    restrictions[ainvix,invix,1,:,:] = tmp

    source_embeddings = torch.index_select(model.ent_embeddings, 0, all_ents)[:,:,sec]
    target_embeddings = torch.index_select(model.ent_embeddings, 0, targets)[:,:,sec]
    
    d = harmonic_extension.coboundary(edge_indices, restrictions)
    ret = torch.empty((num_queries, targets.shape[0]))
    for qix in range(num_queries):
        xopts, = layer_2p(d[qix], source_embeddings[qix].flatten())
        r = xopts.reshape((-1,source_embeddings.shape[1]))
        t = r[-1]
        ret[qix,:] = torch.linalg.norm(t[None,:] - target_embeddings, ord=2, dim=1)
#         ret[qix,:] = torch.matmul(target_embeddings, t)
    return ret

def L_3p(model, entities, relations, targets, invs=None, sec=0):
    all_ents = entities
    all_rels = relations
    all_invs = invs
    n_path_ents = all_rels.shape[1]
    num_queries = all_ents.shape[0]
    
    edge_indices = np.concatenate([np.arange(0,n_path_ents)[:,np.newaxis].T, np.arange(1,n_path_ents+1)[:,np.newaxis].T], axis=0)
    edge_indices = torch.LongTensor(np.repeat(edge_indices[np.newaxis, :, :], num_queries, axis=0))

    left_restrictions = torch.index_select(model.left_embeddings, 0, all_rels.flatten()).view(-1, all_rels.shape[1], model.edge_stalk_dim, model.embedding_dim)
    right_restrictions = torch.index_select(model.right_embeddings, 0, all_rels.flatten()).view(-1, all_rels.shape[1], model.edge_stalk_dim, model.embedding_dim)

    restrictions = torch.cat((left_restrictions.unsqueeze(2), right_restrictions.unsqueeze(2)), dim=2)
    if all_invs is not None:
        for ainvix in range(all_invs.shape[0]):
            invs = all_invs[ainvix]
            for invix in range(invs.shape[0]):
                if invs[invix] == -1:
                    tmp = torch.clone(restrictions[ainvix,invix,0,:,:])
                    restrictions[ainvix,invix,0,:,:] = restrictions[ainvix,invix,1,:,:]
                    restrictions[ainvix,invix,1,:,:] = tmp

    source_embeddings = torch.index_select(model.ent_embeddings, 0, all_ents)[:,:,sec]
    target_embeddings = torch.index_select(model.ent_embeddings, 0, targets)[:,:,sec]
    
    d = harmonic_extension.coboundary(edge_indices, restrictions)
    ret = torch.empty((num_queries, targets.shape[0]))
    for qix in range(num_queries):
        xopts, = layer_3p(d[qix], source_embeddings[qix].flatten())
        r = xopts.reshape((-1,source_embeddings.shape[1]))
        t = r[-1]
#         ret[qix,:] = torch.matmul(target_embeddings, t)
        ret[qix,:] = torch.linalg.norm(t[None,:] - target_embeddings, ord=2, dim=1)
    return ret

query_name_fn_dict = {'2p': L_2p, '3p':L_3p}#, '2i': L_i, '3i':L_i, 'ip':L_ip, 'pi': L_pi}

In [10]:
def test(model, test_data, model_inverses=False, sec=0, test_batch_size=5):
    with torch.no_grad():
        allhits1 = []
        allhits3 = []
        allhits5 = []
        allhits10 = []
        allmrr = []
        query_names = []
        for query_structure in test_query_structures:
            print('Running query : {}'.format(query_structure))
            hits1 = 0.
            hits3 = 0.
            hits5 = 0.
            hits10 = 0.
            mrr = 0.
            cnt = 0
            num_test = len(test_data[query_structure]['answers'])
            for qix in tqdm(range(0, num_test//2, test_batch_size)):
                if num_test - qix == 1:
                    continue
                entities = test_data[query_structure]['entities'][qix:qix+test_batch_size]
                relations = test_data[query_structure]['relations'][qix:qix+test_batch_size]
                if model_inverses:
                    inverses = None
                else:
                    inverses = test_data[query_structure]['inverses'][qix:qix+test_batch_size]
                all_answers = test_data[query_structure]['answers'][qix:qix+test_batch_size]
                targets = torch.arange(model.num_entities)
                Q = query_name_fn_dict[query_structure](model, entities, relations, targets, invs=inverses, sec=sec)
#                 Q = Q[:,:,sec]
                for i in range(Q.shape[0]):
                    Qi = Q[i].squeeze()
                    answers = all_answers[i]
                    sortd,_ = torch.sort(Qi)
                    idxleft = torch.searchsorted(sortd, Qi[answers], right=False) + 1
                    idxright = torch.searchsorted(sortd, Qi[answers], right=True) + 1
                    nl = idxleft.shape[0]
                    nr = idxright.shape[0]
                    # idxright = idxleft # throw this for optimistic ranking
                    hits1 += ((torch.sum(idxleft <= 1)/nl + torch.sum(idxright <= 1)/nr) / 2.)
                    hits3 += ((torch.sum(idxleft <= 3)/nl + torch.sum(idxright <= 3)/nr) / 2.)
                    hits5 += ((torch.sum(idxleft <= 5)/nl + torch.sum(idxright <= 5)/nr) / 2.)
                    hits10 += ((torch.sum(idxleft <= 10)/nl + torch.sum(idxright <= 10)/nr) / 2.)
                    mrr += ((torch.sum(1./idxleft)/nl + torch.sum(1./idxright)/nr) / 2.)
                    cnt += 1
            if cnt > 0:
                allhits1.append(hits1.item()/cnt)
                allhits3.append(hits3.item()/cnt)
                allhits5.append(hits5.item()/cnt)
                allhits10.append(hits10.item()/cnt)
                allmrr.append(mrr.item()/cnt)
            else:
                default = 0.
                allhits1.append(default)
                allhits3.append(default)
                allhits5.append(default)
                allhits10.append(default)
                allmrr.append(default)

        cols = ['hits@1', 'hits@3', 'hits@5', 'hits@10', 'mrr']
        df = pd.DataFrame(np.array([allhits1, allhits3, allhits5, allhits10, allmrr]).T, columns=cols, index=test_query_structures)
        return df

In [11]:
df = test(model, datasets['test-easy'], model_inverses=True)

  0%|          | 2/489 [00:00<00:34, 13.94it/s]

Running query : 2p


100%|██████████| 489/489 [00:25<00:00, 19.31it/s]
  0%|          | 1/487 [00:00<01:18,  6.20it/s]

Running query : 3p


100%|██████████| 487/487 [00:36<00:00, 13.53it/s]


In [12]:
df.head()*100

Unnamed: 0,hits@1,hits@3,hits@5,hits@10,mrr
2p,0.028667,0.14457,0.197666,0.298292,0.173938
3p,0.04952,0.174237,0.250832,0.352918,0.205781
