In [1]:
import os
os.chdir('../sheaf_kg')
import sys
import time
import argparse
import json
from tqdm import tqdm

import pandas as pd
import numpy as np
import pykeen
from pykeen.evaluation import rank_based_evaluator
import torch
from scipy import stats

from sheaf_kg.train_sheafE_betae import read_dataset, shuffle_datasets, dataset_to_device, sample_answers
import sheaf_kg.complex_functions as cf
# from sheaf_kg.complex_functions import test_batch

PyKEEN was unable to load dataset openbiolinkf1. Try uninstalling PyKEEN with ``pip uninstall pykeen`` then reinstalling
PyKEEN was unable to load dataset openbiolinkf2. Try uninstalling PyKEEN with ``pip uninstall pykeen`` then reinstalling


In [2]:
# --- Evaluation configuration -------------------------------------------
# Benchmark whose trained models and complex-query splits are evaluated.
dataset = 'FB15k-237'

# All runs share the same hyper-parameter prefix; only the seed and the
# timestamp suffix differ between the saved model directories.
_run_prefix = ('SheafE_Translational_64embdim_64esdim_1sec_2norm_0.0lbda_'
               '250epochs_MarginRankingLossloss_None_')
_seed_stamps = [(11, '20210602-2005'),
                (33, '20210602-2046'),
                (44, '20210602-2109')]
savenames = ['{}{}seed_{}'.format(_run_prefix, seed, stamp)
             for seed, stamp in _seed_stamps]

# Directory holding one sub-directory per saved run, and the location of the
# complex-query dataset passed to read_dataset below.
savedir = '/home/gebhart/projects/sheaf_kg/data/{}'.format(dataset)
complex_dataset_loc = '/home/gebhart/projects/sheaf_kg/data/{}-betae'.format(dataset)

# Query structures to evaluate, in reporting order.
test_query_structures = '1p 2p 3p 2i 3i ip pi'.split()


In [3]:
# Dispatch table: query-structure name -> traversal function used to score it.
# Path queries of every length share one function, as do intersections of
# every width; the two mixed structures each have their own.
fun_map = {
    **dict.fromkeys(('1p', '2p', '3p'), cf.L_p_traversal_transE),
    **dict.fromkeys(('2i', '3i'), cf.L_i_traversal_transE),
    'ip': cf.L_ip_traversal_transE,
    'pi': cf.L_pi_traversal_transE,
}

def test_batch(model, test_data, model_inverses=False, test_batch_size=5,
                test_query_structures=test_query_structures,
                ks=(1,3,5,10)):
    """Evaluate ``model`` on complex queries and return hits@k / MRR per structure.

    For every query structure the queries are scored in batches against all
    entities via the matching function in ``fun_map``, each answer is ranked
    with PyKEEN's ``compute_rank_from_scores`` (the ``'realistic'`` rank is
    used), and the ranks of all answers of all queries are pooled.

    Parameters
    ----------
    model
        Trained model; must expose ``num_entities`` and ``device``.
    test_data
        Mapping ``query_structure -> dict`` with sliceable ``'entities'``,
        ``'relations'`` and ``'answers'`` entries, plus ``'inverses'`` when
        ``model_inverses`` is False.
    model_inverses : bool
        When True, ``invs=None`` is passed to the traversal function and the
        ``'inverses'`` entry of ``test_data`` is never read.
    test_batch_size : int
        Number of queries scored per call to the traversal function.
    test_query_structures
        Structures to evaluate; each must be a key of ``fun_map`` and of
        ``test_data``. Also used as the index of the returned frame.
    ks
        Cut-offs for which the fraction of ranks ``<= k`` is reported.

    Returns
    -------
    pandas.DataFrame
        One row per query structure, one column per ``k`` plus ``'mrr'``.
        A structure for which no batch was evaluated gets NaN in every column.
    """
    with torch.no_grad():
        # Candidate set is every entity id; it does not depend on the batch,
        # so build it once instead of once per batch.
        targets = torch.arange(model.num_entities).to(model.device)
        results = []
        for query_structure in test_query_structures:
            print('Running query : {}'.format(query_structure))
            structure_data = test_data[query_structure]
            all_avg_ranks = []
            num_test = len(structure_data['answers'])
            for qix in tqdm(range(0, num_test, test_batch_size)):
                # NOTE(review): a trailing batch containing exactly one query is
                # skipped, so that query never contributes to the metrics.
                # Presumably the traversal functions cannot handle a batch of
                # size one -- confirm before removing this guard.
                if num_test - qix == 1:
                    continue
                batch = slice(qix, qix + test_batch_size)
                entities = structure_data['entities'][batch]
                relations = structure_data['relations'][batch]
                if model_inverses:
                    inverses = None
                else:
                    inverses = structure_data['inverses'][batch]
                all_answers = structure_data['answers'][batch]
                # Q is indexed below as Q[query, entity]; presumably one score
                # per (query in batch, candidate entity) -- TODO confirm.
                Q = fun_map[query_structure](model, entities, relations, targets, invs=inverses, p=2, variety='sheafE')
                answer_lens = np.array([len(a) for a in all_answers])
                # Queries are processed grouped by their number of answers.
                for l in np.unique(answer_lens):
                    idxs = np.where(answer_lens == l)[0]
                    answers = [all_answers[j] for j in idxs]
                    # Offsets 0..l-1 subtracted from the sorted ranks; this
                    # appears to discount the query's other correct answers
                    # that are ranked ahead (filtered ranking).
                    filter_fix = torch.arange(l)
                    for aix in range(len(idxs)):
                        ranks = rank_based_evaluator.compute_rank_from_scores(Q[idxs[aix],answers[aix]].unsqueeze(1), Q[idxs[aix],:].unsqueeze(0))
                        avg_rank = (torch.sort(ranks['realistic'].cpu(), dim=0)[0] - filter_fix).numpy()
                        all_avg_ranks.append(avg_rank)

            # np.concatenate raises on an empty list, which happens when the
            # structure has zero queries or a single (skipped) one.
            pooled_ranks = np.concatenate(all_avg_ranks) if all_avg_ranks else np.empty(0)
            if pooled_ranks.size == 0:
                rd = {k: np.nan for k in ks}
                rd['mrr'] = np.nan
            else:
                rd = {k: np.mean(pooled_ranks <= k) for k in ks}
                # Reciprocal of the harmonic mean equals the mean reciprocal rank.
                mrr = np.reciprocal(stats.hmean(pooled_ranks))
                # hmean may return a Python/NumPy scalar of any float width or a
                # length-one array depending on input shape and library version;
                # normalise all of those to a plain float.
                rd['mrr'] = float(np.ravel(mrr)[0])
            # rd['mr'] = np.mean(all_avg_ranks)
            results.append(rd)

        df = pd.DataFrame(results, index=test_query_structures)
        return df

In [4]:
# Read the complex-query splits, shuffle them, then move them to the GPU.
datasets = dataset_to_device(
    shuffle_datasets(read_dataset(complex_dataset_loc)),
    'cuda',
)

In [5]:
# Evaluate every saved run on the 'test-easy' split; one metrics frame per run.
results = []
for savename in savenames:
    model_path = os.path.join(savedir, savename, 'trained_model.pkl')
    # CAUTION: torch.load unpickles the file, which can execute arbitrary
    # code -- only load checkpoints from a trusted source.
    model = torch.load(model_path)
    extension_df = test_batch(
        model,
        datasets['test-easy'],
        model_inverses=True,
        test_query_structures=test_query_structures,
    )
    results.append(extension_df)

  2%|▏         | 49/3262 [00:00<00:06, 486.96it/s]

Running query : 1p


100%|██████████| 3262/3262 [00:06<00:00, 507.69it/s]
  2%|▏         | 24/977 [00:00<00:04, 234.40it/s]

Running query : 2p


100%|██████████| 977/977 [00:04<00:00, 212.32it/s]
  2%|▏         | 16/974 [00:00<00:06, 157.82it/s]

Running query : 3p


100%|██████████| 974/974 [00:06<00:00, 150.33it/s]
  5%|▍         | 42/852 [00:00<00:01, 416.66it/s]

Running query : 2i


100%|██████████| 852/852 [00:02<00:00, 416.32it/s]
  7%|▋         | 43/640 [00:00<00:01, 423.23it/s]

Running query : 3i


100%|██████████| 640/640 [00:01<00:00, 423.90it/s]
  2%|▏         | 22/909 [00:00<00:04, 215.88it/s]

Running query : ip


100%|██████████| 909/909 [00:03<00:00, 227.46it/s]
  4%|▍         | 35/857 [00:00<00:02, 341.55it/s]

Running query : pi


100%|██████████| 857/857 [00:02<00:00, 337.38it/s]
  2%|▏         | 52/3262 [00:00<00:06, 516.70it/s]

Running query : 1p


100%|██████████| 3262/3262 [00:06<00:00, 511.88it/s]
  2%|▏         | 24/977 [00:00<00:03, 238.32it/s]

Running query : 2p


100%|██████████| 977/977 [00:04<00:00, 212.56it/s]
  2%|▏         | 16/974 [00:00<00:06, 156.59it/s]

Running query : 3p


100%|██████████| 974/974 [00:06<00:00, 150.03it/s]
  5%|▍         | 42/852 [00:00<00:01, 418.40it/s]

Running query : 2i


100%|██████████| 852/852 [00:02<00:00, 419.96it/s]
  7%|▋         | 43/640 [00:00<00:01, 425.23it/s]

Running query : 3i


100%|██████████| 640/640 [00:01<00:00, 426.09it/s]
  2%|▏         | 22/909 [00:00<00:04, 215.84it/s]

Running query : ip


100%|██████████| 909/909 [00:03<00:00, 227.85it/s]
  4%|▍         | 34/857 [00:00<00:02, 339.92it/s]

Running query : pi


100%|██████████| 857/857 [00:02<00:00, 335.08it/s]
  2%|▏         | 52/3262 [00:00<00:06, 512.86it/s]

Running query : 1p


100%|██████████| 3262/3262 [00:06<00:00, 512.91it/s]
  2%|▏         | 24/977 [00:00<00:04, 236.58it/s]

Running query : 2p


100%|██████████| 977/977 [00:04<00:00, 213.11it/s]
  2%|▏         | 16/974 [00:00<00:06, 158.62it/s]

Running query : 3p


100%|██████████| 974/974 [00:06<00:00, 149.86it/s]
  5%|▍         | 42/852 [00:00<00:01, 417.29it/s]

Running query : 2i


100%|██████████| 852/852 [00:02<00:00, 417.97it/s]
  7%|▋         | 43/640 [00:00<00:01, 423.46it/s]

Running query : 3i


100%|██████████| 640/640 [00:01<00:00, 425.74it/s]
  2%|▏         | 22/909 [00:00<00:04, 214.22it/s]

Running query : ip


100%|██████████| 909/909 [00:03<00:00, 227.44it/s]
  4%|▍         | 35/857 [00:00<00:02, 337.91it/s]

Running query : pi


100%|██████████| 857/857 [00:02<00:00, 334.88it/s]


In [6]:
# Element-wise mean of the per-run metric frames, displayed as percentages.
num_runs = len(results)
mndf = sum(results) / num_runs
mndf * 100

Unnamed: 0,1,3,5,10,mrr
1p,4.757007,11.570886,15.572693,21.46436,10.167606
2p,0.6393,1.372207,1.906338,2.850881,1.481599
3p,0.277892,0.614917,0.870591,1.383304,0.752916
2i,0.462078,1.428388,2.32651,4.323176,1.895931
3i,0.324724,1.043929,1.819205,3.933996,1.79743
ip,0.445792,0.955036,1.474849,2.747254,1.287176
pi,0.225104,0.698068,1.189696,2.648452,1.289039


In [7]:
# Metrics as rows, query structures as columns; values in percent, rendered as LaTeX.
latex_table = (mndf.T.round(4) * 100).to_latex()
print(latex_table)

\begin{tabular}{lrrrrrrr}
\toprule
{} &     1p &    2p &    3p &    2i &    3i &    ip &    pi \\
\midrule
1   &   4.76 &  0.64 &  0.28 &  0.46 &  0.32 &  0.45 &  0.23 \\
3   &  11.57 &  1.37 &  0.61 &  1.43 &  1.04 &  0.96 &  0.70 \\
5   &  15.57 &  1.91 &  0.87 &  2.33 &  1.82 &  1.47 &  1.19 \\
10  &  21.46 &  2.85 &  1.38 &  4.32 &  3.93 &  2.75 &  2.65 \\
mrr &  10.17 &  1.48 &  0.75 &  1.90 &  1.80 &  1.29 &  1.29 \\
\bottomrule
\end{tabular}

