In [1]:
import argparse
import json
import os
import random
import sys
import time

import numpy as np
import pandas as pd
import pykeen
import torch
from pykeen.evaluation import rank_based_evaluator
from scipy import stats
from tqdm import tqdm

from sheaf_kg.train_sheafE_betae import read_dataset, shuffle_datasets, dataset_to_device, sample_answers
import sheaf_kg.complex_functions as cf
# from sheaf_kg.complex_functions import test_batch

PyKEEN was unable to load dataset openbiolinkf1. Try uninstalling PyKEEN with ``pip uninstall pykeen`` then reinstalling
PyKEEN was unable to load dataset openbiolinkf2. Try uninstalling PyKEEN with ``pip uninstall pykeen`` then reinstalling


In [2]:
# Experiment configuration.
# DATA_ROOT defaults to the original author's data directory; set the
# SHEAF_KG_DATA environment variable to point at your own copy instead of
# editing this cell.
DATA_ROOT = os.environ.get('SHEAF_KG_DATA', '/home/gebhart/projects/sheaf_kg/data')

dataset = 'FB15k-237'
# Run directory (under savedir) holding trained_model.pkl; the name appears to
# encode model / epochs / embedding dim / loss / seed / timestamp.
savename = 'TransE_1000epochs_64embdim_SoftplusLossloss_1234seed_20210519-1452'

savedir = f'{DATA_ROOT}/{dataset}'
complex_dataset_loc = f'{DATA_ROOT}/{dataset}-betae'

test_query_structures = ['1p', '2p', '3p', '2i', '3i', 'ip', 'pi']


In [3]:
# Load the trained model object saved by the training run.
# NOTE: torch.load unpickles arbitrary objects -- only point this at
# checkpoints you produced or otherwise trust.
model_path = os.path.join(savedir, savename, 'trained_model.pkl')
model = torch.load(model_path)

In [7]:
# Dispatch table: query-structure name -> scoring function from
# sheaf_kg.complex_functions. The chain structures share
# L_p_traversal_transE, the pure intersections share L_i_traversal_transE,
# and 'ip' / 'pi' each have a dedicated function.
fun_map = {
    **{structure: cf.L_p_traversal_transE for structure in ('1p', '2p', '3p')},
    **{structure: cf.L_i_traversal_transE for structure in ('2i', '3i')},
    'ip': cf.L_ip_traversal_transE,
    'pi': cf.L_pi_traversal_transE,
}

def _filtered_realistic_ranks(scores, answers):
    """Rank every answer of one query against that query's non-answer entities.

    Parameters
    ----------
    scores : torch.Tensor
        1-D score row of a single query over all candidate entities, indexed by
        entity id. PyKEEN's rank computation treats higher as better --
        NOTE(review): confirm the cf.*_traversal_transE functions return
        scores rather than distances.
    answers : sequence of int or 1-D integer tensor
        Entity ids of the query's answers.

    Returns
    -------
    list of float
        One PyKEEN "realistic" rank per entry of ``answers`` (same order). The
        other answers of the same query are filtered out, so answers never
        outrank each other.
    """
    num_answers = len(answers)
    if num_answers == 0:
        return []
    is_negative = torch.ones(scores.shape[0], dtype=torch.bool, device=scores.device)
    is_negative[answers] = False
    true_scores = scores[answers].reshape(num_answers, 1)
    negative_scores = scores[is_negative].unsqueeze(0).expand(num_answers, -1)
    # PyKEEN expects all_scores to include the true score itself. Only the
    # number of entries above / tied with it matters, so appending it last
    # gives the same rank as leaving it at its original position.
    all_scores = torch.cat([negative_scores, true_scores], dim=1)
    ranks = rank_based_evaluator.compute_rank_from_scores(true_scores, all_scores)
    return ranks['realistic'].cpu().tolist()


def _summarize_ranks(ranks, ks):
    """Hits@k (fraction of ranks <= k) for each k plus MRR; NaN everywhere if ``ranks`` is empty."""
    if ranks.size == 0:
        summary = {k: float('nan') for k in ks}
        summary['mrr'] = float('nan')
        return summary
    summary = {k: np.mean(ranks <= k) for k in ks}
    # MRR = mean(1 / rank); equal to the reciprocal of the harmonic mean of the ranks.
    summary['mrr'] = float(np.mean(1.0 / ranks))
    return summary


def test_batch(model, test_data, model_inverses=False, test_batch_size=5,
               test_query_structures=test_query_structures,
               ks=(1, 3, 5, 10), eval_fraction=0.5):
    """Evaluate complex-query answering with filtered, rank-based metrics.

    For each query structure, queries are scored in batches with the matching
    function from ``fun_map``. Every answer of a query is then ranked against
    all entities that are not answers of that query, using PyKEEN's
    "realistic" rank (average of optimistic and pessimistic rank under ties).

    Parameters
    ----------
    model
        Trained model; must expose ``entity_embeddings(indices=None)`` and
        ``device`` and be accepted by the ``fun_map`` scoring functions.
    test_data : dict
        Maps query structure -> dict with 'entities', 'relations', 'answers'
        (one answer collection per query) and, when ``model_inverses`` is
        False, 'inverses'.
    model_inverses : bool
        If True, ``invs=None`` is passed to the scoring function; otherwise the
        per-query 'inverses' slice is passed.
    test_batch_size : int
        Number of queries scored per scoring-function call.
    test_query_structures : sequence of str
        Structures to evaluate; also the row order of the returned frame.
    ks : iterable of int
        Cut-offs for hits@k.
    eval_fraction : float in [0, 1]
        Fraction of each structure's queries, taken from the start, to
        evaluate. The default 0.5 reproduces the original behaviour of
        evaluating only the first half. Batches started inside that range are
        not truncated, so the last one may reach slightly past it.

    Returns
    -------
    pandas.DataFrame
        Indexed by query structure; one column per k (hits@k as a fraction)
        plus 'mrr'. Structures with no ranked answers get NaN.

    Raises
    ------
    ValueError
        If ``eval_fraction`` is outside [0, 1].
    """
    if not 0.0 <= eval_fraction <= 1.0:
        raise ValueError(f'eval_fraction must be in [0, 1], got {eval_fraction}')
    with torch.no_grad():
        # Loop-invariant: every entity id is a candidate answer for every query.
        num_entities = model.entity_embeddings(indices=None).shape[0]
        targets = torch.arange(num_entities, device=model.device)
        results = []
        for query_structure in test_query_structures:
            print('Running query : {}'.format(query_structure))
            score_fn = fun_map[query_structure]
            structure_data = test_data[query_structure]
            num_test = len(structure_data['answers'])
            num_eval = int(num_test * eval_fraction)
            ranks = []
            for qix in tqdm(range(0, num_eval, test_batch_size)):
                # Skip a trailing batch that would hold only the final query
                # (presumably the scoring functions mishandle a batch of one --
                # TODO confirm). Only reachable when the evaluated range
                # extends to the last query.
                if num_test - qix == 1:
                    continue
                batch = slice(qix, qix + test_batch_size)
                inverses = None if model_inverses else structure_data['inverses'][batch]
                batch_scores = score_fn(model, structure_data['entities'][batch],
                                        structure_data['relations'][batch], targets,
                                        invs=inverses, p=1)
                for query_scores, answers in zip(batch_scores, structure_data['answers'][batch]):
                    ranks.extend(_filtered_realistic_ranks(query_scores, answers))
            results.append(_summarize_ranks(np.asarray(ranks, dtype=float), ks))

        return pd.DataFrame(results, index=test_query_structures)

In [8]:
# shuffle_datasets presumably randomises query order (not visible here), and
# test_batch evaluates only a prefix of each structure's queries by default.
# Seed every RNG it might draw from so the evaluated subset -- and therefore
# the reported metrics -- are the same on every run.
SEED = 1234  # same seed as recorded in `savename` for the training run
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

raw_datasets = read_dataset(complex_dataset_loc)
datasets = dataset_to_device(shuffle_datasets(raw_datasets), model.device)

In [9]:
extension_df = test_batch(
    model,
    datasets['test-easy'],
    model_inverses=True,
    test_query_structures=test_query_structures,
)
# test_batch reports hits@k and MRR as fractions; display them as percentages.
extension_df.mul(100)

  0%|          | 8/1631 [00:00<00:21, 76.66it/s]

Running query : 1p


100%|██████████| 1631/1631 [03:10<00:00,  8.55it/s]
  0%|          | 0/489 [00:00<?, ?it/s]

Running query : 2p


100%|██████████| 489/489 [04:40<00:00,  1.74it/s]
  0%|          | 0/487 [00:00<?, ?it/s]

Running query : 3p


100%|██████████| 487/487 [07:03<00:00,  1.15it/s]
  0%|          | 0/426 [00:00<?, ?it/s]

Running query : 2i


100%|██████████| 426/426 [00:53<00:00,  7.95it/s]
  0%|          | 0/320 [00:00<?, ?it/s]

Running query : 3i


100%|██████████| 320/320 [00:52<00:00,  6.08it/s]
  0%|          | 0/455 [00:00<?, ?it/s]

Running query : ip


100%|██████████| 455/455 [04:15<00:00,  1.78it/s]
  0%|          | 0/429 [00:00<?, ?it/s]

Running query : pi


100%|██████████| 429/429 [01:50<00:00,  3.88it/s]


Unnamed: 0,1,3,5,10,mrr
1p,1.400165,3.759938,5.508668,8.765593,3.836776
2p,0.150866,0.319384,0.47025,0.746322,0.431486
3p,0.062978,0.149572,0.217114,0.356988,0.246314
2i,0.155744,0.464329,0.675212,1.114389,0.58375
3i,0.121675,0.407741,0.636852,1.025176,0.644173
ip,0.015701,0.070207,0.133012,0.340716,0.231943
pi,0.040637,0.10498,0.152874,0.311554,0.303205
