In [1]:
import os
import sys
import time
import argparse
import json
from tqdm import tqdm

import pandas as pd
import numpy as np
import pykeen
from pykeen.evaluation import rank_based_evaluator
import torch
from scipy import stats

from sheaf_kg.train_sheafE_betae import read_dataset, shuffle_datasets, dataset_to_device, sample_answers
import sheaf_kg.complex_functions as cf
# from sheaf_kg.complex_functions import test_batch

PyKEEN was unable to load dataset openbiolinkf1. Try uninstalling PyKEEN with ``pip uninstall pykeen`` then reinstalling
PyKEEN was unable to load dataset openbiolinkf2. Try uninstalling PyKEEN with ``pip uninstall pykeen`` then reinstalling


In [2]:
dataset = 'FB15k-237'
savename = 'TransE_1000epochs_64embdim_SoftplusLossloss_1234seed_20210519-1452'

savedir = f'/home/gebhart/projects/sheaf_kg/data/{dataset}'
complex_dataset_loc = f'/home/gebhart/projects/sheaf_kg/data/{dataset}-betae'

test_query_structures = ['1p','2p','3p','2i','3i','ip','pi']

In [3]:
model = torch.load(os.path.join(savedir, savename, 'trained_model.pkl'))

In [4]:
fun_map = {'1p':cf.L_p_traversal_transE,
           '2p':cf.L_p_traversal_transE,
           '3p':cf.L_p_traversal_transE,
           '2i':cf.L_i_traversal_transE,
           '3i':cf.L_i_traversal_transE,
           'ip':cf.L_ip_traversal_transE,
           'pi':cf.L_pi_traversal_transE
          }

def test_batch(model, test_data, model_inverses=False, test_batch_size=5,
                test_query_structures=test_query_structures,
                ks=[1,3,5,10]):
    with torch.no_grad():
        results = []
        for query_structure in test_query_structures:
            print('Running query : {}'.format(query_structure))
            all_avg_ranks = []
            cnt = 0
            num_test = len(test_data[query_structure]['answers'])
            for qix in tqdm(range(0, num_test//2, test_batch_size)):
                if num_test - qix == 1:
                    continue
                entities = test_data[query_structure]['entities'][qix:qix+test_batch_size]
                relations = test_data[query_structure]['relations'][qix:qix+test_batch_size]
                if model_inverses:
                    inverses = None
                else:
                    inverses = test_data[query_structure]['inverses'][qix:qix+test_batch_size]
                all_answers = test_data[query_structure]['answers'][qix:qix+test_batch_size]
                targets = torch.arange(model.entity_embeddings(indices=None).shape[0]).to(model.device)
                Q = fun_map[query_structure](model, entities, relations, targets, invs=inverses, p=1)
                
                max_len = len(max(all_answers, key=len))
                for i in range(max_len):
                    answers = [a[i] if len(a) > i else a[-1] for a in all_answers ]
                    if len(answers) > 0:
                        ranks = rank_based_evaluator.compute_rank_from_scores(Q[np.vstack((np.arange(len(answers)), answers))].unsqueeze(1), Q)
                        avg_rank = ranks['realistic'].cpu().numpy()
                        all_avg_ranks.append(avg_rank)
            all_avg_ranks = np.concatenate(all_avg_ranks)
            rd = {k: np.mean(all_avg_ranks <= k) for k in ks}
            mrr = np.reciprocal(stats.hmean(all_avg_ranks))
            rd['mrr'] = mrr if isinstance(mrr, float) else mrr[0]
            # rd['mr'] = np.mean(all_avg_ranks)
            results.append(rd)

        df = pd.DataFrame(results, index=test_query_structures)
        return df

In [5]:
datasets = read_dataset(complex_dataset_loc)
datasets = dataset_to_device(shuffle_datasets(datasets), model.device)

In [6]:
extension_df = test_batch(model, datasets['test-easy'], model_inverses=True, test_query_structures=test_query_structures)
extension_df*100

  0%|          | 2/1631 [00:00<01:36, 16.96it/s]

Running query : 1p


100%|██████████| 1631/1631 [07:20<00:00,  3.71it/s]
  0%|          | 0/489 [00:00<?, ?it/s]

Running query : 2p


100%|██████████| 489/489 [09:03<00:00,  1.11s/it]
  0%|          | 0/487 [00:00<?, ?it/s]

Running query : 3p


100%|██████████| 487/487 [02:06<00:00,  3.86it/s]
  1%|          | 5/426 [00:00<00:16, 26.09it/s]

Running query : 2i


100%|██████████| 426/426 [00:19<00:00, 21.58it/s]
  1%|          | 3/320 [00:00<00:12, 25.64it/s]

Running query : 3i


100%|██████████| 320/320 [00:13<00:00, 22.88it/s]
  0%|          | 1/455 [00:00<00:59,  7.58it/s]

Running query : ip


100%|██████████| 455/455 [01:05<00:00,  6.92it/s]
  0%|          | 1/429 [00:00<01:08,  6.27it/s]

Running query : pi


100%|██████████| 429/429 [00:35<00:00, 12.20it/s]


Unnamed: 0,1,3,5,10,mrr
1p,1.876706,5.326477,7.221794,11.578818,5.211004
2p,0.071693,0.200202,0.253874,1.010842,0.436865
3p,0.1134,0.311375,0.499573,0.56586,0.391315
2i,0.251105,0.448961,0.665198,1.313097,0.644156
3i,0.012945,0.495327,0.689878,1.373284,0.596254
ip,0.105085,0.112908,0.163072,0.295255,0.268965
pi,0.084397,0.179,0.188597,0.703355,0.338161
