In [1]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
from tqdm import tqdm
from collections import Counter
import src.kgbench as kg


from kgbench import load, tic, toc, d


import numpy as np


In [2]:
import os

# Get the current directory
current_dir = os.getcwd()

# Get the parent directory
parent_dir = os.path.dirname(current_dir)

# Check if the current directory is already the parent directory
if current_dir != '/Users/macoftraopia/Documents/GitHub/RGCN-Explainer':
    # Set the parent directory as the current directory
    os.chdir(parent_dir)
!pwd

/Users/macoftraopia/Documents/GitHub/RGCN-Explainer


In [137]:
def keep_columns_with_non_zero_values(df):
    ''' Keep only columns with non-zero values'''

    df = df.fillna(0)
    # Get the column names with non-zero values
    non_zero_columns = df.columns[df.astype(bool).any(axis=0)]

    # Create a new DataFrame with only the columns containing non-zero values
    modified_df = df[non_zero_columns]

    return modified_df


def accuracy_func(df, comparison='prediction_full'):
    ''' Compute per class accuracy
    Args:
        df: dataframe containing the results of the experiments
        comparison: full, explain, explain binary, inverse
        binary: True if binary explanation is considered, False otherwise
        inverse: True if inverse explanation is considered, False otherwise
        overall: True if overall accuracy is considered, False otherwise    
    '''
    mismatch, matc = {}, {}
    match_count, mismatch_count = 0, 0

    
    for index, row in df.iterrows():

        c = np.argmax([float(num) for num in row[comparison][1:-1].split()])

        #original = int(row['label'])
        original = np.argmax(row['prediction_full'][1:-1].split())

        match_count, mismatch_count = (match_count + 1, mismatch_count) if original == c else (match_count, mismatch_count + 1)
        match_dict = matc if original == c else mismatch
        match_dict.setdefault(original, []).append(index)


    #labels = [int(i) for i in set(df['label'])]
    labels = [np.argmax([float(num) for num in row['prediction_full'][1:-1].split()]) for index, row in df.iterrows()]

    accuracy = {i: len(matc[i]) / (len(matc[i]) + len(mismatch[i])) if i in mismatch.keys() and i in matc.keys() else (1 if i not in mismatch.keys() else 0) for i in labels}
    #accuracy = {key: value for key, value in accuracy.items()}
    accuracy = dict(sorted(accuracy.items()))


    accuracy = list(accuracy.values())

    return mismatch, matc, accuracy

def fidelity(df, modality ,comparison_minus=None, comparison_inverse=None):
    ''' Compute fidelity as defined in the paper
    Args:
        df: dataframe containing the results of the experiments
         modality: minus or plus depending on the type of fidelity to compute
    '''

    mismatch_f, matc_f, accuracy_f = accuracy_func(df, comparison ='prediction_full')
    if modality == 'minus':
        mismatch_e, matc_e, accuracy_e = accuracy_func(df, comparison=comparison_minus)
        fidelity = {i: accuracy_f[i] - accuracy_e[i] for i in range(len(accuracy_f))}
        fidelity = {key: value for key, value in fidelity.items()}
        fidelity = dict(sorted(fidelity.items()))
    if modality == 'plus':
        mismatch_e, matc_e, accuracy_e = accuracy_func(df, comparison=comparison_inverse)
        fidelity = {i: accuracy_f[i] - accuracy_e[i] for i in range(len(accuracy_f))}
        fidelity = {key: value for key, value in fidelity.items()}
        fidelity = dict(sorted(fidelity.items()))

    return fidelity


def result_table_norel_int(path):
    d = pd.read_csv(path+'/Relations_Important_full_threshold.csv', sep=',')
    d.set_index('node_idx', inplace=True)
    d['label'] = d['label'].apply(lambda x: int(x[1]))

    fidelity_plus_threshold = np.round(list(fidelity(d, modality = 'plus', comparison_inverse= 'res_threshold_lekker_inverse').values()),3)
    fidelity_minus_threshold = np.round(list(fidelity(d, modality = 'minus', comparison_minus = 'prediction_threshold_lekker').values()),3)
    df = pd.DataFrame()
    sparsity = np.round(d.groupby('label')['sparsity_threshold'].mean(),3)
    df['Sparsity'] = sparsity
    df['Fidelity- '] = 1 - fidelity_minus_threshold
    df['Fidelity+ '] = fidelity_plus_threshold


    fidelity_plus_random = np.round(list(fidelity(d, modality = 'plus', comparison_inverse= 'res_random_inverse').values()),3)
    fidelity_minus_random = np.round(list(fidelity(d, modality = 'minus', comparison_minus = 'prediction_random').values()),3)
    df['Fidelity- Random'] = 1 - fidelity_minus_random
    df['Fidelity+ Random'] = fidelity_plus_random

    

    
    table = df.to_latex(index=True, caption = path.split('/')[-2].replace('_','-'), label = path.split('/')[-2].replace('_','-'),column_format='|c|c|c|c|c|c|c|c|c|c|c|c|c|c|c|c|')
    latex_table = table.replace('\\midrule', '\\hline')
    latex_table = latex_table.replace('\\toprule', '\\hline')
    latex_table = latex_table.replace('\\bottomrule', '\\hline')
    latex_table = latex_table.replace('\\begin{tabular}', '\\begin{adjustbox}{scale=0.5}\\begin{tabular}')  # Add scaling parameter
    latex_table = latex_table.replace('\\end{tabular}', '\\end{tabular}\\end{adjustbox}')  # Close the adjustbox environment
    return latex_table

def metrics(path):
    d = pd.read_csv(path+'/Relations_Important_full_threshold.csv', sep=',')


    fidelity_plus_threshold = np.round(np.round(list(fidelity(d, modality = 'plus', comparison_inverse= 'res_threshold_lekker_inverse').values()),3).mean(),3)
    fidelity_minus_threshold = np.round(1-np.round(list(fidelity(d, modality = 'minus', comparison_minus = 'prediction_threshold_lekker').values()),3).mean(),3)
    sparsity = np.round(d['sparsity_threshold'].mean(),3)


    fidelity_plus_random = np.round(np.round(list(fidelity(d, modality = 'plus', comparison_inverse= 'res_random_inverse').values()),3).mean(),3)
    fidelity_minus_random = np.round((1-np.round(list(fidelity(d, modality = 'minus', comparison_minus = 'prediction_random').values()),3)).mean(),3)

    score = sparsity + fidelity_plus_threshold + fidelity_minus_threshold 
    
    return score, sparsity, fidelity_minus_threshold, fidelity_plus_threshold, fidelity_minus_random, fidelity_plus_random

def table_metrics_overview(name, exp):
    init = ['normal','const','overall_frequency','relative_frequency','inverse_relative_frequency','domain_frequency','range_frequency']
    #init = ['normal','overall_frequency','relative_frequency','domain_frequency']
    df = pd.DataFrame(columns = ['init','Score','Sparsity', 'Fidelity-', 'Fidelity+', 'Fidelity- random', 'Fidelity+ random'])
    for i in init:
        #m = metrics(f'chk/{name}_chk/exp/init_{i}_hops_2_lr_0.5_adaptive_False_size_0.005_sizestd_adaptive_ent_10_type_1_killtype_True_break_no/Relation_Importance')
        m = metrics(f'chk/{name}_chk/exp/init_{i}_{exp}/Relation_Importance')
        df.loc[len(df)] = [i] + list(m)
    df.set_index('init', inplace=True)
    table = df.to_latex(index=True, caption = name, label = name,column_format='|c|c|c|c|c|c|c|')
    latex_table = table.replace('\\midrule', '\\hline')
    latex_table = latex_table.replace('\\toprule', '\\hline')
    latex_table = latex_table.replace('\\bottomrule', '\\hline')
    latex_table = latex_table.replace('\\begin{tabular}', '\\begin{adjustbox}{scale=0.5}\\begin{tabular}')  # Add scaling parameter
    latex_table = latex_table.replace('\\end{tabular}', '\\end{tabular}\\end{adjustbox}')  # Close the adjustbox environment
    print(latex_table)
    return latex_table



In [48]:
#STAT TESTS

from scipy.stats import f_oneway

def anova_relations(path,relations):
    ''' Perform ANOVA test to check if relations have significantly different frequencies between classes
    '''

    d = pd.read_csv(path+'/Relations_Important_full_threshold.csv', sep=',')

    df = d[relations]

    #df['label'] = d['label'].apply(lambda x: int(x[1]))
    df.loc[:, 'label'] = d['label'].apply(lambda x: int(x[1])).values

    df = df.groupby('label', as_index=False).mean()
    #df = keep_columns_with_non_zero_values(df)
    data = df.to_dict()

    classes = sorted(set(data['label'].values()))
    relation_freq_lists = {relation: [data[relation][i] for i in classes] for relation in data if relation != 'label'}

    # Perform the ANOVA test
    anova_results = {}
    for relation, freq_list in relation_freq_lists.items():
        anova_result = f_oneway(*[freq_list for _ in classes])
        anova_results[relation] = anova_result

    # Extract p-values from the ANOVA results
    p_values = {relation: result.pvalue for relation, result in anova_results.items()}

    # Define the significance level
    alpha = 0.9

    # Identify relations that have significantly different frequencies between classes
    significant_relations = [relation for relation, p_value in p_values.items() if p_value < alpha]

    # Print the results
    for relation in relation_freq_lists:
        if relation in significant_relations:
            print(f"The relation '{relation}' has significantly different frequencies between classes.")
        else:
            #print(f"The relation '{relation}' does not have significantly different frequencies between classes.")
            #print(f'{relation} not significant')
            pass


def anova_relations_inits(name):
    init = ['normal','const','overall_frequency','relative_frequency','inverse_relative_frequency','domain_frequency','range_frequency']

    for i in init:
        print(i)
        m = anova_relations(f'chk/{name}_chk/exp/init_{i}_hops_2_lr_0.5_adaptive_False_size_0.005_sizestd_adaptive_ent_10_type_1_killtype_True_break_no/Relation_Importance')




In [169]:
from src.rgcn_explainer_utils import *
name = 'mdgenre'
if name in ['aifb', 'mutag', 'bgs', 'am', 'mdgenre']:
    data = kg.load(name, torch=True, final=False)
if 'IMDb' in name:    
    data = torch.load(f'data/IMDB/finals/{name}.pt')
if 'dbo' in name:
    data = torch.load(f'data/DBO/finals/{name}.pt')
get_relations(data)
relations = [data.i2rel[i][0] for i in range(len(data.i2rel))]
data.triples = torch.Tensor(data.triples).to(int)
data.withheld = torch.Tensor(data.withheld).to(int)
data.training = torch.Tensor(data.training).to(int)
print('rel:', data.num_relations, 'ent:', data.num_entities, 'triples:', data.triples.shape)
print('training', data.training.shape, 'withheld', data.withheld.shape)

loaded data mdgenre (70.49s).
rel: 154 ent: 349344 triples: torch.Size([1252247, 3])
training torch.Size([4863, 2]) withheld torch.Size([500, 2])


In [170]:
#ALl initializations
if name=='aifb':
    exp = 'hops_2_lr_0.5_adaptive_False_size_0.005_sizestd_adaptive_ent_10_type_1_killtype_True_break_no'
if name == 'mutag':
    exp = 'hops_2_lr_0.5_adaptive_False_size_0.0005_sizestd_adaptive_ent_10_type_1_killtype_True_break_no'
table_metrics_overview(name, exp)

#anova_relations_inits(name)

FileNotFoundError: [Errno 2] No such file or directory: 'chk/mdgenre_chk/exp/init_normal_hops_2_lr_0.5_adaptive_False_size_0.0005_sizestd_adaptive_ent_10_type_1_killtype_True_break_no/Relation_Importance/Relations_Important_full_threshold.csv'

In [164]:
init = 'normal'
path = f'chk/{name}_chk/exp/init_{init}_{exp}/Relation_Importance'
path = 'chk/dbo_gender_chk/exp/init_relative_frequency_hops_2_lr_0.5_adaptive_False_size_0.0005_sizestd_adaptive_ent_10_type_1_killtype_False_break_no/Relation_Importance'
path = 'chk/dbo_gender_chk/exp/init_normal_hops_2_lr_0.5_adaptive_False_size_0.005_sizestd_adaptive_ent_10_type_1_killtype_True_break_no/Relation_Importance'

path = 'chk/dbo_gender_chk/exp/init_overall_frequency_hops_2_lr_0.5_adaptive_False_size_0.005_sizestd_adaptive_ent_10_type_1_killtype_True_break_no/Relation_Importance'
path = 'chk/dbo_gender_chk/exp/init_relative_frequency_hops_2_lr_0.5_adaptive_False_size_0.0005_sizestd_adaptive_ent_10_type_1_killtype_False_break_no/Relation_Importance'
path = 'chk/dbo_gender_chk/exp/init_domain_frequency_hops_2_lr_0.5_adaptive_False_size_0.005_sizestd_adaptive_ent_10_type_1_killtype_True_break_no/Relation_Importance'
#path = 'chk/dbo_gender_chk/exp/init_inverse_relative_frequency_hops_2_lr_0.5_adaptive_False_size_0.005_sizestd_adaptive_ent_10_type_1_killtype_True_break_no/Relation_Importance'
#path = 'chk/dbo_gender_chk/exp/init_normal_hops_2_lr_0.1_adaptive_False_size_0.0005_sizestd_adaptive_ent_1_type_1_killtype_False_break_no/Relation_Importance'
path = 'chk/dbo_gender_chk/exp/init_range_frequency_hops_2_lr_0.5_adaptive_False_size_0.005_sizestd_adaptive_ent_10_type_1_killtype_True_break_no/Relation_Importance'
print(result_table_norel_int(path) )
#print(anova_relations(path,relations) )

\begin{table}
\centering
\caption{init-range-frequency-hops-2-lr-0.5-adaptive-False-size-0.005-sizestd-adaptive-ent-10-type-1-killtype-True-break-no}
\label{init-range-frequency-hops-2-lr-0.5-adaptive-False-size-0.005-sizestd-adaptive-ent-10-type-1-killtype-True-break-no}
\begin{adjustbox}{scale=0.5}\begin{tabular}{|c|c|c|c|c|c|c|c|c|c|c|c|c|c|c|c|}
\hline
{} &  Sparsity &  Fidelity-  &  Fidelity+  &  Fidelity- Random &  Fidelity+ Random \\
label &           &             &             &                   &                   \\
\hline
0     &     0.378 &           1 &       0.099 &             0.975 &             0.006 \\
1     &     0.180 &           1 &       0.921 &             0.955 &             0.910 \\
\hline
\end{tabular}\end{adjustbox}
\end{table}



  table = df.to_latex(index=True, caption = path.split('/')[-2].replace('_','-'), label = path.split('/')[-2].replace('_','-'),column_format='|c|c|c|c|c|c|c|c|c|c|c|c|c|c|c|c|')


In [165]:

init = ['normal','overall_frequency','relative_frequency','inverse_relative_frequency','domain_frequency','range_frequency']
paths = ['chk/dbo_gender_chk/exp/init_normal_hops_2_lr_0.1_adaptive_False_size_0.0005_sizestd_adaptive_ent_1_type_1_killtype_False_break_no/Relation_Importance',
        'chk/dbo_gender_chk/exp/init_overall_frequency_hops_2_lr_0.5_adaptive_False_size_0.005_sizestd_adaptive_ent_10_type_1_killtype_True_break_no/Relation_Importance',
         'chk/dbo_gender_chk/exp/init_relative_frequency_hops_2_lr_0.5_adaptive_False_size_0.0005_sizestd_adaptive_ent_10_type_1_killtype_False_break_no/Relation_Importance',
         'chk/dbo_gender_chk/exp/init_inverse_relative_frequency_hops_2_lr_0.5_adaptive_False_size_0.005_sizestd_adaptive_ent_10_type_1_killtype_True_break_no/Relation_Importance',
         'chk/dbo_gender_chk/exp/init_domain_frequency_hops_2_lr_0.5_adaptive_False_size_0.005_sizestd_adaptive_ent_10_type_1_killtype_True_break_no/Relation_Importance',
         'chk/dbo_gender_chk/exp/init_range_frequency_hops_2_lr_0.5_adaptive_False_size_0.005_sizestd_adaptive_ent_10_type_1_killtype_True_break_no/Relation_Importance'
         
             ]
df = pd.DataFrame(columns = ['init','Score','Sparsity', 'Fidelity-', 'Fidelity+', 'Fidelity- random', 'Fidelity+ random'])
for i,j in zip(init,paths):
    #m = metrics(f'chk/{name}_chk/exp/init_{i}_hops_2_lr_0.5_adaptive_False_size_0.005_sizestd_adaptive_ent_10_type_1_killtype_True_break_no/Relation_Importance')
    m = metrics(j)
    df.loc[len(df)] = [i] + list(m)
df.set_index('init', inplace=True)
table = df.to_latex(index=True, caption = name, label = name,column_format='|c|c|c|c|c|c|c|')
latex_table = table.replace('\\midrule', '\\hline')
latex_table = latex_table.replace('\\toprule', '\\hline')
latex_table = latex_table.replace('\\bottomrule', '\\hline')
latex_table = latex_table.replace('\\begin{tabular}', '\\begin{adjustbox}{scale=0.5}\\begin{tabular}')  # Add scaling parameter
latex_table = latex_table.replace('\\end{tabular}', '\\end{tabular}\\end{adjustbox}')  # Close the adjustbox environment
print(latex_table)


\begin{table}
\centering
\caption{dbo_gender}
\label{dbo_gender}
\begin{adjustbox}{scale=0.5}\begin{tabular}{|c|c|c|c|c|c|c|}
\hline
{} &  Score &  Sparsity &  Fidelity- &  Fidelity+ &  Fidelity- random &  Fidelity+ random \\
init                       &        &           &            &            &                   &                   \\
\hline
normal                     &  1.836 &     0.805 &      1.000 &      0.031 &             0.568 &             0.037 \\
overall\_frequency          &  1.792 &     0.287 &      1.000 &      0.505 &             0.954 &             0.447 \\
relative\_frequency         &  1.580 &     0.602 &      0.736 &      0.242 &             0.742 &             0.242 \\
inverse\_relative\_frequency &  1.789 &     0.284 &      1.000 &      0.505 &             0.960 &             0.452 \\
domain\_frequency           &  1.817 &     0.329 &      0.989 &      0.499 &             0.948 &             0.447 \\
range\_frequency            &  1.799 &     0.289 &      1.00

  table = df.to_latex(index=True, caption = name, label = name,column_format='|c|c|c|c|c|c|c|')


In [171]:
init = 'overall_frequency'
path = 'chk/mdgenre_chk/exp/init_normal_hops_2_lr_0.5_adaptive_False_size_0.005_sizestd_adaptive_ent_10_type_1_killtype_True_break_no/init_normal_hops_2_lr_0.5_adaptive_False_size_0.005_sizestd_adaptive_ent_10_type_1_killtype_True_break_no/Relation_Importance'
#path = 'chk/dbo_gender_chk/exp/init_relative_frequency_hops_2_lr_0.5_adaptive_False_size_0.0005_sizestd_adaptive_ent_10_type_1_killtype_False_break_no/Relation_Importance'
print(result_table_norel_int(path) )
print(anova_relations(path,relations) )

ValueError: Length of values (8) does not match length of index (3)

In [173]:

def result_table_norel(path):
    d = pd.read_csv(path+'/Relations_Important_full_threshold.csv', sep=',')
    d.set_index('node_idx', inplace=True)
    d['label'] = d['label'].apply(lambda x: int(x[1]))
    df = pd.DataFrame()
    df['label'] = d['label']
    df['Sparsity'] = d['sparsity_threshold'].apply(lambda x: np.round(x,3))
    df['Fidelity-'] = d['fidelity_minus_threshold'].apply(lambda x: np.round(x,3))
    df['Fidelity+'] = d['fidelity_plus_threshold'].apply(lambda x: np.round(x,3))
    df['Fidelity- Random'] = d['fidelity_minus_random'].apply(lambda x: np.round(x,3))
    df['Fidelity+ Random'] = d['fidelity_plus_random'].apply(lambda x: np.round(x,3))
    df = df.groupby(by='label').mean()



    table = df.to_latex(index=True, caption = path.split('/')[-2].replace('_','-'), label = path.split('/')[-2].replace('_','-'),column_format='|c|m{3cm}|c|c|c|c|c|c|c|c|c|c|c|c|c|c|')
    latex_table = table.replace('\\midrule', '\\hline')
    latex_table = latex_table.replace('\\toprule', '\\hline')
    latex_table = latex_table.replace('\\bottomrule', '\\hline')
    latex_table = latex_table.replace('\\begin{tabular}', '\\begin{adjustbox}{scale=0.5}\\begin{tabular}')  # Add scaling parameter
    latex_table = latex_table.replace('\\end{tabular}', '\\end{tabular}\\end{adjustbox}')  # Close the adjustbox environment
    return latex_table

path = 'chk/mdgenre_chk/exp/init_normal_hops_2_lr_0.5_adaptive_False_size_0.005_sizestd_adaptive_ent_10_type_1_killtype_True_break_no/init_normal_hops_2_lr_0.5_adaptive_False_size_0.005_sizestd_adaptive_ent_10_type_1_killtype_True_break_no/Relation_Importance'
print(result_table_norel(path) )   

\begin{table}
\centering
\caption{init-normal-hops-2-lr-0.5-adaptive-False-size-0.005-sizestd-adaptive-ent-10-type-1-killtype-True-break-no}
\label{init-normal-hops-2-lr-0.5-adaptive-False-size-0.005-sizestd-adaptive-ent-10-type-1-killtype-True-break-no}
\begin{adjustbox}{scale=0.5}\begin{tabular}{|c|m{3cm}|c|c|c|c|c|c|c|c|c|c|c|c|c|c|}
\hline
{} &  Sparsity &  Fidelity- &  Fidelity+ &  Fidelity- Random &  Fidelity+ Random \\
label &           &            &            &                   &                   \\
\hline
0     &  0.988429 &   0.909571 &   0.083571 &          0.908429 &          0.083000 \\
1     &  0.986200 &   0.998800 &   0.031600 &          1.002800 &          0.034000 \\
2     &  0.988417 &   0.822541 &   0.164735 &          0.822159 &          0.164827 \\
\hline
\end{tabular}\end{adjustbox}
\end{table}



  table = df.to_latex(index=True, caption = path.split('/')[-2].replace('_','-'), label = path.split('/')[-2].replace('_','-'),column_format='|c|m{3cm}|c|c|c|c|c|c|c|c|c|c|c|c|c|c|')


In [176]:
#df = pd.read_csv(path+'/Relations_Important_full_threshold.csv', sep=',')
d = pd.read_csv(path+'/Relations_Important_full_threshold.csv', sep=',')
d.set_index('node_idx', inplace=True)
d['label'] = d['label'].apply(lambda x: int(x[1]))

fidelity_plus_threshold = np.round(list(fidelity(d, modality = 'plus', comparison_inverse= 'res_threshold_lekker_inverse').values()),3)
fidelity_minus_threshold = np.round(list(fidelity(d, modality = 'minus', comparison_minus = 'prediction_threshold_lekker').values()),3)
df = pd.DataFrame()
sparsity = np.round(d.groupby('label')['sparsity_threshold'].mean(),3)
df['Sparsity'] = sparsity
fidelity_minus_threshold

# df['Fidelity- '] = 1 - fidelity_minus_threshold

array([ 0.812,  0.   , -0.088,  0.175,  0.333,  0.552,  0.   ,  0.238])

In [4]:
from src.rgcn_explainer_utils import *
name = 'aifb'
if name in ['aifb', 'mutag', 'bgs', 'am', 'mdgenre', 'amplus', 'dmg777k']:
    data = kg.load(name, torch=True, final=False)
    data = prunee(data, 2)
data.withheld

loaded data aifb (0.2069s).


tensor([[5757,    2],
        [5797,    2],
        [5678,    0],
        [5900,    2],
        [5677,    2],
        [5731,    1],
        [5724,    0],
        [5791,    2],
        [5699,    0],
        [5857,    3],
        [5752,    3],
        [5688,    0],
        [5702,    0],
        [5714,    0],
        [5905,    1],
        [5795,    3],
        [5811,    2],
        [5708,    0],
        [5843,    0],
        [5873,    0],
        [5697,    0],
        [5753,    3],
        [5831,    2],
        [5839,    2],
        [5783,    0],
        [5755,    2],
        [5808,    1],
        [5844,    2],
        [5798,    3],
        [5701,    0],
        [5845,    0],
        [5861,    2],
        [5778,    0],
        [5854,    3],
        [5785,    1]], dtype=torch.int32)

In [67]:
v_threshold = torch.load('chk/aifb_chk/exp/init_normal_lr_0.5_size_0.0005_ent_1_type_1_wd_0.9_MFR_1/masked_adj/masked_ver_thresh5724')
h_threshold = torch.load('chk/aifb_chk/exp/init_normal_lr_0.5_size_0.0005_ent_1_type_1_wd_0.9_MFR_1/masked_adj/masked_hor_thresh5724')
node_idx = 5724
row_indices =  torch.nonzero(v_threshold.coalesce().indices()[1] == node_idx, as_tuple=False)[:, 0]
if v_threshold.coalesce().values()[row_indices].count_nonzero() != 0:
    print('node_idx in explanation')
else:
    print('node_idx not in explanation')
5724

node_idx not in explanation


5724

In [79]:
from src.rgcn_explainer_utils import *
v = torch.load('chk/aifb_chk/exp/init_normal_lr_0.5_size_0.0005_ent_1_type_1_wd_0.9_MFR_1/masked_adj/masked_ver5724')
row_indices =  torch.nonzero(v.coalesce().indices()[1] == node_idx, as_tuple=False)[:, 0]
if v.coalesce().values()[row_indices].count_nonzero() != 0:
    print('node_idx in explanation')
else:
    print('node_idx not in explanation')

node_idx in explanation


In [81]:
v.coalesce().values()[row_indices]

tensor([0.9581])

In [76]:
s = v_threshold.coalesce().indices()[0]%data.num_entities
r = torch.divide(v_threshold.coalesce().indices()[0], data.num_entities, rounding_mode = 'floor')
o = v_threshold.coalesce().indices()[1]
triples = torch.stack([s,r,o],dim=1)
triples
# if 56 in triples[:,0]:
#     print(triples[triples[:,0]==5678])
data.i2e[7327]

('http://www.aifb.uni-karlsruhe.de/Publikationen/viewPublikationOWL/id306instance',
 'iri')

In [71]:
s = h_threshold.coalesce().indices()[1]%data.num_entities
r = torch.divide(h_threshold.coalesce().indices()[0], data.num_entities, rounding_mode = 'floor')
o = h_threshold.coalesce().indices()[0]
triples = torch.stack([s,r,o],dim=1)
triples
# if 5678 in triples[:,0]:
#     print(triples[triples[:,0]==5678])


tensor([[   0,    0, 5677],
        [   0,    0, 5687],
        [3162,    0, 5693],
        [   0,    0, 5697],
        [   0,    0, 5699],
        [3162,    0, 5723],
        [   0,    0, 5754],
        [   0,    0, 5772],
        [   0,    0, 5805],
        [   0,    0, 5808],
        [   0,    0, 5816],
        [   0,    0, 5857],
        [   0,    0, 5858],
        [   0,    0, 5864],
        [6070,    0, 7327]])

In [66]:
data.i2e[6981]

('http://www.aifb.uni-karlsruhe.de/Publikationen/viewPublikationOWL/id1169instance',
 'iri')

In [24]:
def match_to_triples(v,h, data, sparse=True):
    """
    v: vertical adjacency matrix
    h: horizontal adjacency matrix
    data: dataset
    sparse: if True, the adjacency matrix is sparse, otherwise it is dense
    returns: the triples corresponding to the adjacency matrix (from stack indexes to original indexes)
    """

    n_ent = data.num_entities
    if sparse:
        pv,_ = torch.div(v.coalesce().indices(), n_ent, rounding_mode='floor')#v.coalesce().indices()//data.num_entities
        sv,ov = v.coalesce().indices()% n_ent
        result_v = torch.stack([sv,pv,ov], dim=1)
        ph,_ = torch.div(h.coalesce().indices(),  n_ent, rounding_mode='floor')#v.coalesce().indices()//data.num_entities
        sh,oh = h.coalesce().indices()% n_ent
        result_h = torch.stack([sh,ph,oh], dim=1)
        result = torch.cat((result_v, result_h), 0)


                    
    else:

        # _,ph = torch.div(h, data.num_entities, rounding_mode='floor')#v.coalesce().indices()//data.num_entities
        # sh,oh = h%data.num_entities
        # result_h = torch.stack([sh,ph,oh], dim=1)

        # pv, _ = torch.div(v, data.num_entities, rounding_mode='floor')#v.coalesce().indices()//data.num_entities
        # sv,ov = v%data.num_entities
        # result_v = torch.stack([sv,pv,ov], dim=1)

        # result = torch.cat((result_v, result_h), 0)

        if len(h )!= 0:
            _,ph = torch.div(h,  n_ent, rounding_mode='floor')#v.coalesce().indices()//data.num_entities
            sh,oh = h% n_ent
            result_h = torch.stack([sh,ph,oh], dim=1)
        if len(v)!=0:
            pv, _ = torch.div(v,  n_ent, rounding_mode='floor')#v.coalesce().indices()//data.num_entities
            sv,ov = v% n_ent
            result_v = torch.stack([sv,pv,ov], dim=1)
        if len(h) != 0 and len(v) != 0:
            result = torch.cat((result_v, result_h), 0)
            #print('all good')
        if len(h) == 0:
            result = result_v
            print('ph is empty')
        if len(v) == 0:
            result = result_h             
    
    return result
match_to_triples(v_threshold, h_threshold, data)


tensor([[6881,    2, 5678],
        [6881,    2, 5743],
        [6981,    2, 5745],
        [7045,    2, 5743],
        [7100,    2, 5746],
        [7393,    2, 5678],
        [7393,    2, 5745],
        [7393,    2, 5747],
        [7968,    2, 6404],
        [8013,    2, 5743],
        [8015,    2, 5746],
        [5357,   18, 5743],
        [5357,   18, 5745],
        [5357,   18, 5746],
        [5357,   18, 5777],
        [5408,   18, 5859],
        [5494,   18, 5745],
        [5494,   18, 5910],
        [5502,   21, 5837],
        [5937,   21, 5678],
        [5678,   30, 7045],
        [5678,   30, 7068],
        [5678,   30, 8013],
        [5745,   30, 6981],
        [5745,   30, 7393],
        [5746,   30, 7068],
        [6404,   30, 8013],
        [5746,   36, 5939],
        [5357,    0, 5678],
        [5357,    0, 5737],
        [5408,    0, 5744],
        [5413,    0, 5761],
        [5431,    0, 5879],
        [5431,    0, 5910],
        [5431,    0, 5913],
        [5450,    0,

In [72]:
v = v_threshold
h = h_threshold
pv,_ = torch.div(v.coalesce().indices(), data.num_entities, rounding_mode='floor')#v.coalesce().indices()//data.num_entities
sv,ov = v.coalesce().indices()%data.num_entities
result_v = torch.stack([sv,pv,ov], dim=1)
ph,_ = torch.div(h.coalesce().indices(), data.num_entities, rounding_mode='floor')#v.coalesce().indices()//data.num_entities
sh,oh = h.coalesce().indices()%data.num_entities
result_h = torch.stack([sh,ph,oh], dim=1)
result = torch.cat((result_v, result_h), 0)
result


tensor([[7327,    2, 5723],
        [5686,   10,    0],
        [5693,   10,    0],
        [5701,   10,    0],
        [5714,   10,    0],
        [5735,   10,    0],
        [5772,   10,    0],
        [5786,   10,    0],
        [5891,   10,    0],
        [5700,   27,    0],
        [5723,   27,    0],
        [5772,   27,    0],
        [5882,   27,    0],
        [5891,   27,    0],
        [6069,   30, 7327],
        [5677,    0,    0],
        [5687,    0,    0],
        [5693,    0, 3162],
        [5697,    0,    0],
        [5699,    0,    0],
        [5723,    0, 3162],
        [5754,    0,    0],
        [5772,    0,    0],
        [5805,    0,    0],
        [5808,    0,    0],
        [5816,    0,    0],
        [5857,    0,    0],
        [5858,    0,    0],
        [5864,    0,    0],
        [7327,    0, 6070]])