# Import Libraries

In [1]:
import os
import pickle
import pandas as pd
import networkx as nx
from collections import Counter
import random

import data_params as input_data_params

# Get Explanation Paths

In [2]:
def add_dict_counter(counting_dict: dict, key: str, add_val: int):
    """Accumulate ``add_val`` under ``key`` in ``counting_dict`` (mutated in place).

    A key that is not yet present is initialised to ``add_val``; an existing
    key has ``add_val`` added to its current value. Nothing is returned.
    """
    # First occurrence: store the value directly and stop.
    if key not in counting_dict:
        counting_dict[key] = add_val
        return

    counting_dict[key] += add_val

In [3]:
# Disease identifier, read from the project-level data_params module.
DISEASE_PREFIX = input_data_params.disease
# Membership test: the former `== 'dmd' or 'hd' or 'oi'` was always truthy
# (a non-empty string literal is truthy), so it never rejected anything.
assert DISEASE_PREFIX in ('dmd', 'hd', 'oi'), f'Unsupported disease prefix: {DISEASE_PREFIX!r}'

# Dataset variants to analyse and the embedding method used in folder names.
DATASET_PREFIXES = ['prev', 'restr']
embedding_method = 'e2v'

# Whether the results of the seeded embedding runs are analysed.
seeded_emb = False

# Suffix appended to the per-dataset output folder name.
fixed_emb = '_seeded' if seeded_emb else ''

# Maps each dataset prefix to its explanation sub-folder name (from data_params).
expl_folders = input_data_params.expl_folders
explanations_per_dataset = {}
explanation_pairs_per_dataset = {}

# Output lives in <parent of the current working directory>/output/<disease>.
curr_working_dir = os.path.dirname(os.getcwd())
curr_output_dir = os.path.join(curr_working_dir, 'output', DISEASE_PREFIX)

chosen_explanations_per_dataset = {}

# Per dataset: '<drug_id> <symptom_id>' -> number of (in)complete explanations.
complete_explanation_counts_per_dataset = {}
incomplete_explanation_counts_per_dataset = {}

# For every dataset: load each stored explanation graph together with its
# drug-symptom pair, keep the complete explanations, and tally how many
# complete / incomplete explanations exist per pair.
for dataset_prefix in DATASET_PREFIXES:
    complete_explanation_counts = {}
    incomplete_explanation_counts = {}

    # Node table used to translate node ids into labels and row indices.
    nodes = pd.read_csv(f'../output/{DISEASE_PREFIX}/{dataset_prefix}_{DISEASE_PREFIX}_indexed_nodes.csv')

    dataset_output_dir = os.path.join(
        curr_output_dir,
        f'{dataset_prefix}_{embedding_method}{fixed_emb}',
        expl_folders[dataset_prefix],
    )
    print(dataset_output_dir)

    all_explanations, all_graphs, all_pairs = [], [], []

    for item in os.listdir(dataset_output_dir):
        # Only graph pickles drive the loop; their pair files are derived below.
        if '.gpickle' not in item:
            continue

        # Files with 'incomplete' in their name are counted but not analysed.
        is_complete = 'incomplete' not in item

        # NOTE(review): pickle.load executes arbitrary code from the file;
        # only use this on output produced by this project.
        with open(os.path.join(dataset_output_dir, item), 'rb') as graph_file:
            G = pickle.load(graph_file)
        if is_complete:
            all_graphs.append(G)

        file_name_explanation = item.split('_graph.gpickle')[0]
        all_explanations.append(file_name_explanation)
        pair_file_name = f'{file_name_explanation}_pair.pkl'

        with open(os.path.join(dataset_output_dir, pair_file_name), 'rb') as pair_file:
            loaded_info = pickle.load(pair_file)

        # Resolve each node id once and read both the row index and the label.
        symptom_rows = nodes.loc[nodes['id'] == loaded_info['symptom_id']]
        symptom_index = symptom_rows.index[0]
        symptom_label = symptom_rows['label'].values[0]

        drug_rows = nodes.loc[nodes['id'] == loaded_info['drug_id']]
        drug_index = drug_rows.index[0]
        drug_label = drug_rows['label'].values[0]

        if is_complete:
            # '<label> <index>' strings, stored as [symptom, drug].
            all_pairs.append([f'{symptom_label} {symptom_index}', f'{drug_label} {drug_index}'])
            target_counts = complete_explanation_counts
        else:
            target_counts = incomplete_explanation_counts

        add_dict_counter(
            counting_dict=target_counts,
            key=f'{loaded_info["drug_id"]} {loaded_info["symptom_id"]}',
            add_val=1,
        )

    explanations_per_dataset[dataset_prefix] = all_graphs
    explanation_pairs_per_dataset[dataset_prefix] = all_pairs

    complete_explanation_counts_per_dataset[dataset_prefix] = complete_explanation_counts
    incomplete_explanation_counts_per_dataset[dataset_prefix] = incomplete_explanation_counts

c:\Users\rzwart\Documents\GitHub\XAI-FO\output\oi\prev_e2v\expl_9
c:\Users\rzwart\Documents\GitHub\XAI-FO\output\oi\restr_e2v\expl_8


In [4]:
# Complete-explanation counts per dataset; inner keys are '<drug_id> <symptom_id>'.
complete_explanation_counts_per_dataset

{'prev': {'80 HP:0002829': 9},
 'restr': {'4072 HP:0000978': 2, '4072 HP:0005692': 3}}

In [5]:
# Incomplete-explanation counts per dataset; inner keys are '<drug_id> <symptom_id>'.
incomplete_explanation_counts_per_dataset

{'prev': {},
 'restr': {'4072 HP:0002119': 8,
  '2552 HP:0004586': 8,
  '4072 HP:0005692': 5,
  '4072 HP:0001371': 9,
  '4072 HP:0001382': 8,
  '2552 HP:0002645': 8,
  '2552 HP:0002757': 8,
  '2552 HP:0000883': 8,
  '2552 HP:0002659': 8,
  '2552 HP:0000444': 8,
  '4072 HP:0000978': 6,
  '2552 HP:0000703': 9,
  '760 HP:0000365': 8,
  '2351 HP:0002823': 8,
  '2552 HP:0000365': 8,
  '4072 HP:0002953': 8,
  '2552 HP:0002953': 8,
  '4072 HP:0006487': 8}}

# Number of Types of Edges, Nodes, Metapaths

In [6]:
def get_occurrence_df(c, label_name):
    """Tabulate a Counter as a frequency table sorted by descending count.

    Parameters
    ----------
    c : collections.Counter
        Mapping from item to number of appearances.
    label_name : str
        Name of the column that holds the counted items.

    Returns
    -------
    pandas.DataFrame
        Columns ``label_name``, ``'Percentage'`` (share of the total count,
        formatted as a string such as ``'66.67%'``) and ``'Appearances'``
        (raw count). An empty frame with these columns is returned when
        there is nothing to count.
    """
    columns = [label_name, 'Percentage', 'Appearances']

    # Compute the total once; sum() also works on Python versions that
    # predate Counter.total() (added in 3.10).
    total = sum(c.values())
    if not total:
        # Nothing counted: avoid the KeyError / ZeroDivisionError of the
        # naive path and hand back an empty, correctly-shaped table.
        return pd.DataFrame(columns=columns)

    records = [
        {
            label_name: item,
            'Percentage': '{:,.2f}%'.format(count / total * 100.0),
            'Appearances': count,
        }
        for item, count in c.items()
    ]
    c_df = pd.DataFrame(records, columns=columns)
    return c_df.sort_values(by=['Appearances'], ascending=False)

def count_occurrences(G):
    """Build frequency tables of node types, edge types and metapaths in ``G``.

    Node types are read from each node's ``'type'`` attribute and edge types
    from each edge's ``'label'`` attribute. A metapath is the triplet
    ``(type of first endpoint, edge label, type of second endpoint)``.

    Returns
    -------
    tuple
        ``(node_types_df, edge_types_df, metapaths_df)``, each produced by
        ``get_occurrence_df`` with the label columns ``'Node Type'``,
        ``'Edge Type'`` and ``'Metapath'`` respectively.
    """
    node_types = [attributes['type'] for _, attributes in G.nodes(data=True)]

    edge_types, triplets = [], []
    for source, target, attributes in G.edges(data=True):
        label = attributes['label']
        edge_types.append(label)
        # Endpoint types are looked up on the graph's node attribute view.
        triplets.append((G.nodes[source]['type'], label, G.nodes[target]['type']))

    return (
        get_occurrence_df(Counter(node_types), 'Node Type'),
        get_occurrence_df(Counter(edge_types), 'Edge Type'),
        get_occurrence_df(Counter(triplets), 'Metapath'),
    )

def get_shortest_path_len_drug_symptom(n1, n2, G):
    """Return the shortest-path length from node ``n1`` to node ``n2`` in ``G``.

    Thin wrapper around ``networkx.shortest_path_length``; any exception that
    call raises (e.g. when no path exists) propagates to the caller.
    """
    return nx.shortest_path_length(G, source=n1, target=n2)

In [7]:
def _mean_or_nan(values):
    """Return the arithmetic mean of ``values``, or NaN when ``values`` is empty."""
    return sum(values) / len(values) if values else float('nan')


# One summary record per dataset; converted to a DataFrame in a later cell.
obj_measures_list = []

for dataset_prefix in DATASET_PREFIXES:
    node_types_per_expl = []
    edge_types_per_expl = []
    triplet_types_per_expl = []
    path_len_per_expl = []

    unique_pairs = []

    nr_expl = 0
    # Graphs and pairs were appended together for complete explanations only,
    # so the two lists are expected to be aligned index by index.
    for expl_g, pair in zip(explanations_per_dataset[dataset_prefix], explanation_pairs_per_dataset[dataset_prefix]):
        node_types_df, edge_types_df, metapaths_df = count_occurrences(expl_g)

        # Number of distinct node types / edge types / metapaths in this explanation.
        node_types_per_expl.append(node_types_df['Node Type'].nunique())
        edge_types_per_expl.append(edge_types_df['Edge Type'].nunique())
        triplet_types_per_expl.append(metapaths_df['Metapath'].nunique())

        # The pair entries are used directly as node identifiers in the graph.
        node_1, node_2 = pair
        path_len_per_expl.append(get_shortest_path_len_drug_symptom(node_1, node_2, expl_g))

        unique_pairs.append(tuple(pair))

        nr_expl += 1

    # Compute each average once; a dataset without complete explanations
    # yields NaN instead of raising ZeroDivisionError.
    avg_node_types = _mean_or_nan(node_types_per_expl)
    avg_edge_types = _mean_or_nan(edge_types_per_expl)
    avg_triplet_types = _mean_or_nan(triplet_types_per_expl)
    avg_path_len = _mean_or_nan(path_len_per_expl)

    print('All drug-symptom pairs explained in the explanations found from dataset', dataset_prefix, set(unique_pairs))

    print(f'For the {nr_expl} explanations generated from dataset {dataset_prefix}')
    print(f'Average number of node types: {avg_node_types}')
    print(f'Average number of edge types: {avg_edge_types}')
    print(f'Average number of triplets: {avg_triplet_types}')
    print(f'Average shortest path length between drug and symptom pair: {avg_path_len}')

    obj_measures_list.append({
        'disease prefix': DISEASE_PREFIX,
        'dataset prefix': dataset_prefix,
        'embedding': embedding_method,
        'seed': seeded_emb,
        'explanation overlap': expl_folders[dataset_prefix],
        'total explanations': nr_expl,
        'avg number of node types': avg_node_types,
        'avg number of edge types': avg_edge_types,
        'avg number of triplets': avg_triplet_types,
        'avg shortest path length between drug and symptom': avg_path_len
    })

All drug-symptom pairs explained in the explanations found from dataset prev {('Arthralgia 4498', 'aclarubicin 24')}
For the 9 explanations generated from dataset prev
Average number of node types: 3.0
Average number of edge types: 4.0
Average number of triplets: 5.0
Average shortest path length between drug and symptom pair: 2.0
All drug-symptom pairs explained in the explanations found from dataset restr {('Joint hyperflexibility 7114', 'ascorbic acid 5460'), ('Bruising susceptibility 6967', 'ascorbic acid 5460')}
For the 5 explanations generated from dataset restr
Average number of node types: 5.2
Average number of edge types: 5.2
Average number of triplets: 6.4
Average shortest path length between drug and symptom pair: 3.8


In [8]:
# Turn the per-dataset measurement records into one table (one row per record)
# and display it as the cell output.
obj_measures_df = pd.DataFrame(obj_measures_list)
obj_measures_df

Unnamed: 0,disease prefix,dataset prefix,embedding,seed,explanation overlap,total explanations,avg number of node types,avg number of edge types,avg number of triplets,avg shortest path length between drug and symptom
0,oi,prev,e2v,False,expl_9,9,3.0,4.0,5.0,2.0
1,oi,restr,e2v,False,expl_8,5,5.2,5.2,6.4,3.8


In [9]:
# The file name used to depend on `dataset_prefix` leaking out of an earlier
# loop (hidden state). Make that explicit: the last entry of DATASET_PREFIXES
# is what the leaked variable held after a top-to-bottom run, so the output
# path is unchanged.
# NOTE(review): the table holds rows for ALL datasets, not only this prefix —
# consider a dataset-independent file name if downstream readers allow it.
objective_measures_file_prefix = DATASET_PREFIXES[-1]
obj_measures_df.to_csv(f'../output/{DISEASE_PREFIX}/{DISEASE_PREFIX}_{objective_measures_file_prefix}_explanation_objective_measurements.csv', index=False)

In [10]:
# For every dataset: list each overlapping symptom-drug pair with the number of
# complete and incomplete explanations found for it, and write that to CSV.
for dataset_prefix in DATASET_PREFIXES:
    # Node table used to translate node ids into labels.
    nodes = pd.read_csv(f'../output/{DISEASE_PREFIX}/{dataset_prefix}_{DISEASE_PREFIX}_indexed_nodes.csv')

    dataset_output_dir = os.path.join(curr_output_dir, f'{dataset_prefix}_{embedding_method}{fixed_emb}')
    print(dataset_output_dir)

    # The folder name 'expl_<n>' carries the overlap number used in the pickle name.
    overlap_nr = expl_folders[dataset_prefix].replace('expl_', '')
    overlap_file = f'symptom_drug_pair_overlapping_{overlap_nr}_runs_{DISEASE_PREFIX}_{dataset_prefix}_{embedding_method}{fixed_emb}'
    print(f'Use overlapping pairs found in {overlap_file}')

    # NOTE(review): pickle.load executes arbitrary code from the file;
    # only use this on output produced by this project.
    with open(f'{dataset_output_dir}/{overlap_file}.pkl', 'rb') as f:
        overlapping_symptom_drug_pairs = pickle.load(f)

    explanation_numbers = []

    for symptom_id, drug_id in overlapping_symptom_drug_pairs:
        symptom_label = nodes.loc[nodes['id'] == symptom_id, 'label'].values[0]
        drug_label = nodes.loc[nodes['id'] == drug_id, 'label'].values[0]

        # Same '<drug_id> <symptom_id>' key format as used when counting.
        key_val = f'{drug_id} {symptom_id}'

        # Pairs without any stored explanation get a count of 0.
        explanation_numbers.append({
            'drug': drug_label,
            'symptom ID': symptom_id,
            'symptom': symptom_label,
            'complete explanations': complete_explanation_counts_per_dataset[dataset_prefix].get(key_val, 0),
            'incomplete explanations': incomplete_explanation_counts_per_dataset[dataset_prefix].get(key_val, 0),
        })

    explanation_numbers_df = pd.DataFrame(explanation_numbers)
    results_csv_path = f'../output/{DISEASE_PREFIX}/{dataset_prefix}_{embedding_method}{fixed_emb}/{expl_folders[dataset_prefix]}/{DISEASE_PREFIX}_{dataset_prefix}_{expl_folders[dataset_prefix]}_explanation_results.csv'
    explanation_numbers_df.to_csv(results_csv_path, index=False)

c:\Users\rzwart\Documents\GitHub\XAI-FO\output\oi\prev_e2v
Use overlapping pairs found in symptom_drug_pair_overlapping_9_runs_oi_prev_e2v


PermissionError: [Errno 13] Permission denied: '../output/oi/prev_e2v/expl_9/oi_prev_expl_9_explanation_results.csv'