In [None]:
import json
import os
from collections import defaultdict

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [None]:
# Directory scanned for benchmark result files; the plotting cells look results
# up as `<data_key>.json`, so only entries with that suffix are loaded.
result_dir = '/home/ubuntu/DIG/benchmarks/xgraph/result_jsons'

# results[<file name>][<explainer key>] -> first value stored under that explainer.
results = defaultdict(dict)
for dataset in sorted(os.listdir(result_dir)):
    # Skip stray entries (checkpoint folders, editor backups, ...) that would
    # otherwise make open()/json.load() fail.
    if not dataset.endswith('.json'):
        continue
    # Context manager guarantees the handle is closed even if parsing fails.
    with open(os.path.join(result_dir, dataset), 'r') as file:
        per_explainer = json.load(file)
    for explainer, result in per_explainer.items():
        # Each explainer entry is a mapping; only its first value is kept.
        # NOTE(review): presumably there is a single entry per explainer — confirm.
        results[dataset][explainer] = next(iter(result.values()))

In [None]:
# Dataset key (used to build the `<key>.json` / `<key>.log` file names) -> display title.
# Insertion order determines the panel order in the figures.
dataset_mapping = dict(
    ba_2motifs='BA-2Motifs',
    ba_shapes='BA-shapes',
    bbbp='BBBP',
    graph_sst2='Graph-SST2',
    graph_sst5='Graph-SST5',
    twitter='Graph-Twitter',
)

# Explainer key -> (legend label, matplotlib marker format string).
explainer_mapping = dict(
    subgraphx=('SubgraphX', 'o'),
    gnn_explainer=('GNNExplainer', 's'),
    pgexplainer=('PGExplainer', '^'),
)

In [None]:
# Fidelity+ against sparsity: a 2x3 grid with one panel per dataset and one
# marker series per explainer.
fig, axs = plt.subplots(2, 3, figsize=(12, 4), sharex=True)
transparent = (0, 0, 0, 0)  # RGBA with zero alpha
for idx, (data_key, dataset) in enumerate(dataset_mapping.items()):
    row, col = divmod(idx, 3)  # fill the grid row by row
    ax = axs[row, col]
    per_explainer = results[f'{data_key}.json']
    for key, (explainer, marker) in explainer_mapping.items():
        curve = per_explainer[key]
        ax.plot(curve['sparsity'], curve['fidelity'], marker, label=explainer)
    # Invisible '.' axis labels — presumably placeholders that keep room for
    # the figure-level labels set below.
    ax.set_xlabel('.', color=transparent)
    ax.set_ylabel('.', color=transparent)
    ax.set_title(dataset)
    ax.set_xlim(0.4, 0.9)
# Every panel plots the same explainers, so the last axes supplies the
# handles/labels for the single shared legend.
handles, labels = ax.get_legend_handles_labels()
fig.legend(handles, labels, loc='center', bbox_to_anchor=(0.5, 1.05), ncol=3)
fig.supxlabel('Sparsity')
fig.supylabel('Fidelity+')
fig.tight_layout(pad=0.25)

In [None]:
# Fidelity- against sparsity: a 2x3 grid with one panel per dataset and one
# marker series per explainer.
fig, axs = plt.subplots(2, 3, figsize=(12, 4), sharex=True)
transparent = (0, 0, 0, 0)  # RGBA with zero alpha
for idx, (data_key, dataset) in enumerate(dataset_mapping.items()):
    row, col = divmod(idx, 3)  # fill the grid row by row
    ax = axs[row, col]
    per_explainer = results[f'{data_key}.json']
    for key, (explainer, marker) in explainer_mapping.items():
        curve = per_explainer[key]
        ax.plot(curve['sparsity'], curve['fidelity_inv'], marker, label=explainer)
    # Invisible '.' axis labels — presumably placeholders that keep room for
    # the figure-level labels set below.
    ax.set_xlabel('.', color=transparent)
    ax.set_ylabel('.', color=transparent)
    ax.set_title(dataset)
    ax.set_xlim(0.4, 0.9)
# Every panel plots the same explainers, so the last axes supplies the
# handles/labels for the single shared legend.
handles, labels = ax.get_legend_handles_labels()
fig.legend(handles, labels, loc='center', bbox_to_anchor=(0.5, 1.05), ncol=3)
fig.supxlabel('Sparsity')
fig.supylabel('Fidelity-')
fig.tight_layout(pad=0.25)

In [None]:
# Root of the run logs, laid out as <log_dir>/<explainer>/<data_key>.log.
log_dir = '/home/ubuntu/DIG/logs'

# times[<dataset display name>][<explainer dir name>] -> elapsed seconds between
# the first and last line of that run's log.
times = defaultdict(dict)
for explainer in sorted(os.listdir(log_dir)):
    # Only sub-directories hold per-dataset logs; ignore stray files.
    if not os.path.isdir(os.path.join(log_dir, explainer)):
        continue
    for data_key, dataset in dataset_mapping.items():
        log_path = os.path.join(log_dir, explainer, f'{data_key}.log')
        try:
            # Read inside the try so a missing/unreadable log is reported like
            # any other unusable log instead of aborting the whole cell; the
            # context manager closes the handle either way.
            with open(log_path) as log_file:
                lines = log_file.readlines()
            # Assumes the first and last lines parse as timestamps; an empty
            # file (IndexError) or unparsable line lands in the handler below.
            elapsed = pd.to_datetime(lines[-1]) - pd.to_datetime(lines[0])
            times[dataset][explainer] = elapsed.total_seconds()
        except Exception:
            # Best-effort: report the (dataset, explainer) pair that has no
            # usable timing. `Exception` (not a bare except) lets
            # KeyboardInterrupt / SystemExit still stop the cell.
            print(dataset, explainer)