In [1]:
import os
import json
import pandas as pd
import torch
import numpy as np 
from parse import *

In [2]:
datasets = ['FB15k-237','NELL995']
model_names = ['SheafE_Translational', 'SheafE_Multisection']
results_loc = '../data/{}'
complex_metrics = ['1','10','mrr']
groupby_cols = ['model','embdim','esdim','sec','orthogonal']
resnames = ['1p','2p','3p','2i','3i','ip','pi']
orthogonals = [0,0.1,0.01]

In [3]:
def infer_params_from_filename(fname):
#     print(fname)
    shared_names = ['class','model','embdim','esdim','sec','norm','lbda']
    parse_str_seed_orthogonal = '{}_{}_{:d}embdim_{:d}esdim_{:d}sec_{:d}norm_{:f}lbda_{:f}orthogonal_{:d}epochs_{}loss_{}_{}seed_{:d}-{:d}'
    seed_orthogonal_names = shared_names + ['orthogonal','epochs','loss','sampler','seed','date','time']
    
    parse_str_original_orthogonal = '{}_{}_{:d}embdim_{:d}esdim_{:d}sec_{:d}norm_{:f}lbda_{:f}orthogonal_{:d}epochs_{}loss_{}_{:d}-{:d}'
    original_orthogonal_names = [n for n in seed_orthogonal_names if n != 'seed']
    
    parse_str_seed = '{}_{}_{:d}embdim_{:d}esdim_{:d}sec_{:d}norm_{:f}lbda_{:d}epochs_{}loss_{}_{}seed_{:d}-{:d}'
    seed_names = [n for n in seed_orthogonal_names if n != 'orthogonal']
    
    parse_str_original = '{}_{}_{:d}embdim_{:d}esdim_{:d}sec_{:d}norm_{:f}lbda_{:d}epochs_{}loss_{}_{:d}-{:d}'
    original_names = [n for n in seed_names if n != 'seed']
    
    parse_options = [parse_str_seed_orthogonal, parse_str_original_orthogonal, parse_str_seed, parse_str_original]
    parse_names = [seed_orthogonal_names, original_orthogonal_names, seed_names, original_names]
    for parse_option, parse_name in zip(parse_options,parse_names):
        parsed = parse(parse_option, fname)
        if parsed is not None:
            return {parse_name[i]:parsed[i] for i in range(len(parse_name))}
    
    print('ignoring', fname)
    

In [4]:
results = []
complex_results = []
idx = 0
for dataset in datasets:
    dataset_dirname = results_loc.format(dataset)
    dataset_complex_dirname = os.path.join(dataset_dirname, 'complex')
    subdirs = [f.name for f in os.scandir(dataset_dirname) if f.is_dir()]
    for subdir in subdirs:
        for model_name in model_names:
            if model_name in subdir:
                complex_fname = os.path.join(dataset_complex_dirname, subdir + '.csv')
                params = infer_params_from_filename(subdir)
                pk_result_fname = os.path.join(dataset_dirname, subdir, 'results.json')
                if params is not None and os.path.exists(complex_fname) and os.path.exists(pk_result_fname):
                    
                    with open(pk_result_fname) as json_file:
                        pkr = json.load(json_file)
                        
                    cr = pd.read_csv(complex_fname, index_col=0)[complex_metrics].to_dict()
                    try:
                        pk_mrr = pkr['metrics']['inverse_harmonic_mean_rank']['both']['realistic']
                        pk_10 = pkr['metrics']['hits_at_k']['both']['realistic']['10']
                    except KeyError:
                        # old pykeen results format
                        pk_mrr = pkr['metrics']['mean_reciprocal_rank']['both']['avg']
                        pk_10 = pkr['metrics']['hits_at_k']['both']['avg']['10']
                    
                    r = {'id':idx,'dataset':dataset,'gc_mrr':pk_mrr,'gc_10':pk_10}
                    results.append({**r,**params,**cr})
                    idx += 1
# cdf = pd.concat(complex_results, ignore_index=True)
df = pd.DataFrame(results)
df['date'] = pd.to_datetime(df['date'], format='%Y%m%d')
df['orthogonal'] = df['orthogonal'].fillna(0)
df['seed'] = df['seed'].fillna(1234)
# df = df.merge(cdf, on='id', how='left')

In [5]:
tdf = df[(df['epochs'] == 250)&(df['loss'] == 'MarginRankingLoss')]
tdf

Unnamed: 0,id,dataset,gc_mrr,gc_10,class,model,embdim,esdim,sec,norm,...,orthogonal,epochs,loss,sampler,date,time,1,10,mrr,seed
0,0,FB15k-237,0.083767,0.119581,SheafE,Multisection,64,64,16,2,...,1.00,250,MarginRankingLoss,,2021-05-26,818,"{'1p': 0.014874080438167856, '2p': 8.717724710...","{'1p': 0.03534917533736199, '2p': 0.0006964905...","{'1p': 0.022465451246184414, '2p': 0.000738137...",1234
1,1,FB15k-237,0.126326,0.158381,SheafE,Multisection,64,64,32,2,...,10.00,250,MarginRankingLoss,,2021-05-29,203,"{'1p': 0.01870309253652432, '2p': 0.0002049592...","{'1p': 0.042139785790110657, '2p': 0.000927417...","{'1p': 0.02738468229616293, '2p': 0.0009014042...",22
2,2,FB15k-237,0.101673,0.258636,SheafE,Translational,64,64,64,2,...,0.00,250,MarginRankingLoss,,2021-05-27,2103,"{'1p': 6.195812456887471e-05, '2p': 0.00788861...","{'1p': 0.3079979677735141, '2p': 0.09891649810...","{'1p': 0.10815869263452142, '2p': 0.0408387581...",11
3,3,FB15k-237,0.091088,0.224092,SheafE,Translational,64,64,16,2,...,0.01,250,MarginRankingLoss,,2021-05-27,1456,"{'1p': 3.304433310339985e-05, '2p': 0.00216181...","{'1p': 0.09521311529580874, '2p': 0.0173046835...","{'1p': 0.03208190681743316, '2p': 0.0084451183...",11
4,4,FB15k-237,0.057308,0.132865,SheafE,Translational,64,64,64,2,...,0.10,250,MarginRankingLoss,,2021-05-31,2223,"{'1p': 2.891379146547487e-05, '2p': 0.00066310...","{'1p': 0.038558606190029696, '2p': 0.009276957...","{'1p': 0.014393511683882419, '2p': 0.003991357...",22
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
286,286,NELL995,0.000225,0.000175,SheafE,Multisection,32,32,32,2,...,10.00,250,MarginRankingLoss,,2021-05-27,2346,"{'1p': 0.004852538883142744, '2p': 3.216647652...","{'1p': 0.00961808005722516, '2p': 0.0001994321...","{'1p': 0.00676109242502187, '2p': 0.0002073673...",11
287,287,NELL995,0.000202,0.000193,SheafE,Multisection,32,32,8,2,...,10.00,250,MarginRankingLoss,,2021-05-29,741,"{'1p': 0.0255386608152652, '2p': 0.00176057848...","{'1p': 0.056993165846632714, '2p': 0.003150170...","{'1p': 0.03709597984492658, '2p': 0.0023547954...",22
288,288,NELL995,0.000281,0.000333,SheafE,Multisection,32,32,8,2,...,0.00,250,MarginRankingLoss,,2021-05-27,519,"{'1p': 0.24524654184106484, '2p': 3.0022044758...","{'1p': 0.6127635305603619, '2p': 0.00017798783...","{'1p': 0.3650828275584577, '2p': 0.00020045160...",11
289,289,NELL995,0.000715,0.000648,SheafE,Translational,32,32,32,2,...,10.00,250,MarginRankingLoss,,2021-06-02,2235,"{'1p': 0.00021266106659191308, '2p': 0.0009714...","{'1p': 0.0005219862543619683, '2p': 0.00118372...","{'1p': 0.00043227668866414187, '2p': 0.0011593...",33


In [6]:
fb_df = tdf[tdf['dataset'] == 'FB15k-237'].reset_index(drop=True)

In [7]:
nell_df = tdf[(tdf['dataset'] == 'NELL995')&(tdf['10'].apply(lambda x: x['1p'] > 0.001))].reset_index(drop=True)

In [8]:
def simplify_group(sv, embdim):
    sv = sv[sv.index.get_level_values('embdim') == embdim].droplevel('embdim',axis=0)
    sv = sv[sv.index.get_level_values('orthogonal').isin(orthogonals)]
    sv[resnames] = 100*sv[resnames].round(4)
    sv = sv[resnames]
    return sv

In [9]:
for complex_metric in complex_metrics:
    joined = fb_df.join(pd.DataFrame(fb_df[complex_metric].values.tolist()))
    grouped = joined.groupby(groupby_cols)
    embdim = 64
    sv = grouped.mean()
    sv = simplify_group(sv, embdim)
    sv.to_excel(f'FB15k-237/{complex_metric}.xlsx')
    sv = grouped.std()
    sv = simplify_group(sv, embdim)
    sv.to_excel(f'FB15k-237/{complex_metric}_std.xlsx')

In [10]:
grouped.count()

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,id,dataset,gc_mrr,gc_10,class,norm,lbda,epochs,loss,sampler,...,10,mrr,seed,1p,2p,3p,2i,3i,ip,pi
model,embdim,esdim,sec,orthogonal,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1
Multisection,64,16,1,0.0,3,3,3,3,3,3,3,3,3,3,...,3,3,3,3,3,3,3,3,3,3
Multisection,64,32,1,0.0,3,3,3,3,3,3,3,3,3,3,...,3,3,3,3,3,3,3,3,3,3
Multisection,64,64,1,0.0,3,3,3,3,3,3,3,3,3,3,...,3,3,3,3,3,3,3,3,3,3
Multisection,64,64,16,0.0,3,3,3,3,3,3,3,3,3,3,...,3,3,3,3,3,3,3,3,3,3
Multisection,64,64,16,0.01,3,3,3,3,3,3,3,3,3,3,...,3,3,3,3,3,3,3,3,3,3
Multisection,64,64,16,0.1,3,3,3,3,3,3,3,3,3,3,...,3,3,3,3,3,3,3,3,3,3
Multisection,64,64,16,1.0,3,3,3,3,3,3,3,3,3,3,...,3,3,3,3,3,3,3,3,3,3
Multisection,64,64,16,10.0,3,3,3,3,3,3,3,3,3,3,...,3,3,3,3,3,3,3,3,3,3
Multisection,64,64,32,0.0,5,5,5,5,5,5,5,5,5,5,...,5,5,5,5,5,5,5,5,5,5
Multisection,64,64,32,0.01,3,3,3,3,3,3,3,3,3,3,...,3,3,3,3,3,3,3,3,3,3


In [11]:
sv

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,1p,2p,3p,2i,3i,ip,pi
model,esdim,sec,orthogonal,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
Multisection,16,1,0.0,0.16,0.06,0.11,0.23,0.19,0.01,0.01
Multisection,32,1,0.0,0.11,0.02,0.03,0.08,0.07,0.01,0.01
Multisection,64,1,0.0,0.04,0.02,0.09,0.09,0.09,0.01,0.01
Multisection,64,16,0.0,0.24,0.0,0.0,0.12,0.21,0.0,0.0
Multisection,64,16,0.01,7.28,0.04,0.03,4.78,2.78,0.02,0.01
Multisection,64,16,0.1,2.5,0.02,0.02,1.15,1.05,0.0,0.0
Multisection,64,32,0.0,0.1,0.32,0.52,0.21,0.41,0.07,0.06
Multisection,64,32,0.01,3.64,0.03,0.02,1.22,0.7,0.01,0.01
Multisection,64,32,0.1,2.39,0.01,0.01,0.91,0.82,0.0,0.0
Multisection,64,64,0.0,0.13,0.28,0.44,0.41,0.45,0.07,0.06


In [12]:
for complex_metric in complex_metrics:
    joined = nell_df.join(pd.DataFrame(nell_df[complex_metric].values.tolist()))
    grouped = joined.groupby(groupby_cols)
    embdim = 32
    sv = grouped.mean()
    sv = simplify_group(sv, embdim)
    sv.to_excel(f'NELL995/{complex_metric}.xlsx')
    sv = grouped.std()
    sv = simplify_group(sv, embdim)
    sv.to_excel(f'NELL995/{complex_metric}_std.xlsx')

In [13]:
grouped.count()

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,id,dataset,gc_mrr,gc_10,class,norm,lbda,epochs,loss,sampler,...,10,mrr,seed,1p,2p,3p,2i,3i,ip,pi
model,embdim,esdim,sec,orthogonal,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1
Multisection,32,8,1,0.0,4,4,4,4,4,4,4,4,4,4,...,4,4,4,4,4,4,4,4,4,4
Multisection,32,16,1,0.0,4,4,4,4,4,4,4,4,4,4,...,4,4,4,4,4,4,4,4,4,4
Multisection,32,32,1,0.0,4,4,4,4,4,4,4,4,4,4,...,4,4,4,4,4,4,4,4,4,4
Multisection,32,32,8,0.0,3,3,3,3,3,3,3,3,3,3,...,3,3,3,3,3,3,3,3,3,3
Multisection,32,32,8,0.01,6,6,6,6,6,6,6,6,6,6,...,6,6,6,6,6,6,6,6,6,6
Multisection,32,32,8,0.1,6,6,6,6,6,6,6,6,6,6,...,6,6,6,6,6,6,6,6,6,6
Multisection,32,32,8,1.0,6,6,6,6,6,6,6,6,6,6,...,6,6,6,6,6,6,6,6,6,6
Multisection,32,32,8,10.0,6,6,6,6,6,6,6,6,6,6,...,6,6,6,6,6,6,6,6,6,6
Multisection,32,32,16,0.0,3,3,3,3,3,3,3,3,3,3,...,3,3,3,3,3,3,3,3,3,3
Multisection,32,32,16,0.01,6,6,6,6,6,6,6,6,6,6,...,6,6,6,6,6,6,6,6,6,6


In [14]:
sv

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,1p,2p,3p,2i,3i,ip,pi
model,esdim,sec,orthogonal,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
Multisection,8,1,0.0,0.14,0.0,0.02,0.05,0.01,0.0,0.0
Multisection,16,1,0.0,0.22,0.0,0.0,0.03,0.01,0.0,0.0
Multisection,32,1,0.0,0.16,0.0,0.0,0.07,0.1,0.0,0.0
Multisection,32,8,0.0,0.29,0.0,0.0,0.05,0.28,0.0,0.0
Multisection,32,8,0.01,0.16,0.21,0.2,0.04,0.05,0.08,0.19
Multisection,32,8,0.1,0.25,0.08,0.08,0.02,0.03,0.02,0.1
Multisection,32,16,0.0,0.23,0.0,0.0,0.09,0.12,0.0,0.0
Multisection,32,16,0.01,0.15,0.18,0.17,0.03,0.02,0.06,0.18
Multisection,32,16,0.1,0.17,0.07,0.07,0.03,0.06,0.02,0.09
Multisection,32,32,0.0,0.25,0.0,0.0,0.08,0.29,0.0,0.0
