In [12]:
import os
import json

import numpy as np
# ``plt`` must be the pyplot interface: the bare ``matplotlib`` package does not expose
# plotting functions such as ``plot`` or ``subplots``.
import matplotlib.pyplot as plt
from rich import print as pprint
from sklearn.metrics import auc, mean_absolute_error, r2_score

from truthful_counterfactuals.metrics import threshold_error_reduction
from truthful_counterfactuals.utils import EXPERIMENTS_PATH, latex_table, latex_table_element_mean, render_latex
from truthful_counterfactuals.visualization import plot_threshold_error_reductions

# Kernel working directory (normally the notebook's folder); the generated table_1.tex /
# table_1.pdf files are written here.
PATH = os.getcwd()
# Root folder containing one sub-folder per experiment, each holding the individual run folders.
RESULTS_PATH = os.path.join(EXPERIMENTS_PATH, 'results')
# Name of the JSON file with the per-element results inside every run folder.
RESULTS_FILE = 'results.json'
# Name of the model checkpoint inside a run folder (not referenced by the cells of this notebook).
MODEL_FILE = 'model.ckpt'

# Number of bins passed to ``threshold_error_reduction`` when computing the UER curves.
NUM_BINS = 50

In [22]:
def _run_paths(experiment: str, run_prefix: str, num_runs: int = 5) -> list[str]:
    """Return ``RESULTS_PATH/<experiment>/<run_prefix>_<i>`` for ``i = 1 .. num_runs``."""
    return [
        os.path.join(RESULTS_PATH, experiment, f'{run_prefix}_{index}')
        for index in range(1, num_runs + 1)
    ]


# Every entry becomes one row of the results table: the model architecture, the
# uncertainty quantification method and the folders of its independent repetitions.
result_map = {
    # == GAT RESULTS ==
    'random': {
        'model': 'GATv2',
        'method': 'Random',
        'paths': _run_paths('quantify_uncertainty__rand', 'logp_gat'),
    },
    'gat_ens_cal': {
        'model': 'GATv2',
        'method': 'ENS(3)',
        'paths': _run_paths('quantify_uncertainty__ens', 'logp_gat_cal'),
    },
    'gat_mve_cal': {
        'model': 'GATv2',
        'method': 'MVE',
        'paths': _run_paths('quantify_uncertainty__mve', 'logp_gat_cal'),
    },
    'gat_swag_cal': {
        'model': 'GATv2',
        'method': 'SWAG',
        'paths': _run_paths('quantify_uncertainty__swag', 'logp_gat_cal'),
    },
    'gat_ens_mve_cal': {
        'model': 'GATv2',
        'method': 'ENS(3)+MVE',
        'paths': _run_paths('quantify_uncertainty__ens_mve', 'logp_gat_cal'),
    },
    # == GIN RESULTS (currently disabled) ==
    # 'gin_ens':     {'model': 'GIN', 'method': 'ENS(3)',      'paths': _run_paths('quantify_uncertainty__ens', 'logp_gin')},
    # 'gin_ens_cal': {'model': 'GIN', 'method': 'ENS(3)+Cal.', 'paths': _run_paths('quantify_uncertainty__ens', 'logp_gin_cal')},
    # 'gin_mve':     {'model': 'GIN', 'method': 'MVE',         'paths': _run_paths('quantify_uncertainty__mve', 'logp_gin')},
    # 'gin_mve_cal': {'model': 'GIN', 'method': 'MVE+Cal.',    'paths': _run_paths('quantify_uncertainty__mve', 'logp_gin_cal')},
}

def get_value(value: float | list):
    if isinstance(value, list):
        return value[0]
    return value

In [23]:
def evaluate_run(results_path: str) -> dict:
    """Load one run's results file and compute its evaluation metrics.

    :param results_path: path of the JSON results file of a single run. It is expected to
        hold a list of records with 'graph_labels', 'prediction' and 'uncertainty' entries,
        each either a scalar or a list (unwrapped via ``get_value``).
    :returns: dict with the keys 'mae', 'r2' (prediction performance), 'corr' (Pearson
        correlation between absolute error and uncertainty), and 'auc_mean' / 'auc_max'
        (area under the uncertainty error reduction curve using the mean / max error).
    """
    with open(results_path, mode='r') as file:
        results = json.load(file)

    out_true = np.array([get_value(result['graph_labels']) for result in results])
    out_pred = np.array([get_value(result['prediction']) for result in results])
    uncertainties = np.array([get_value(result['uncertainty']) for result in results])

    # The absolute prediction error is the quantity the uncertainty estimate should track.
    errors = np.abs(out_true - out_pred)

    # UER-AUC: area under the curve returned by threshold_error_reduction, once with the
    # mean error and once with the max error as aggregation.
    ths_mean, rds_mean = threshold_error_reduction(uncertainties, errors, error_func=np.mean, num_bins=NUM_BINS)
    ths_max, rds_max = threshold_error_reduction(uncertainties, errors, error_func=np.max, num_bins=NUM_BINS)

    return {
        'mae': mean_absolute_error(out_true, out_pred),
        'r2': r2_score(out_true, out_pred),
        'corr': np.corrcoef(errors, uncertainties)[0, 1],
        'auc_mean': auc(ths_mean, rds_mean),
        'auc_max': auc(ths_max, rds_max),
    }


print('processing the results...')
for key, data in result_map.items():
    print(f' * processing {key}')

    run_metrics = [evaluate_run(os.path.join(path, RESULTS_FILE)) for path in data['paths']]

    # The per-run lists are rebuilt from scratch, so re-running this cell is idempotent.
    data['mae_values'] = [metrics['mae'] for metrics in run_metrics]
    data['r2_values'] = [metrics['r2'] for metrics in run_metrics]
    data['corr_values'] = [metrics['corr'] for metrics in run_metrics]
    data['auc_mean_values'] = [metrics['auc_mean'] for metrics in run_metrics]
    data['auc_max_values'] = [metrics['auc_max'] for metrics in run_metrics]

    pprint(data)

processing the results...
 * processing random


 * processing gat_ens_cal


 * processing gat_mve_cal


 * processing gat_swag_cal


 * processing gat_ens_mve_cal


In [24]:
print('summarizing the results...')

# One table row per configured method; the metric cells hold the per-run value lists
# and are handed to latex_table unchanged.
rows = [
    [
        entry['model'],
        entry['method'],
        entry['mae_values'],
        entry['r2_values'],
        entry['corr_values'],
        entry['auc_mean_values'],
        entry['auc_max_values'],
    ]
    for entry in result_map.values()
]

table_columns = [
    'Model',
    'Method',
    'MAE',
    r'$R^2$',
    r'$\rho$',
    r'$\text{UER-AUC}_{\text{mean}}$',
    r'$\text{UER-AUC}_{\text{max}}$',
]
_, content = latex_table(column_names=table_columns, rows=rows)

# Persist the LaTeX source next to the notebook, then render it to a PDF.
tex_path = os.path.join(PATH, 'table_1.tex')
with open(tex_path, mode='w') as tex_file:
    tex_file.write(content)

pdf_path = os.path.join(PATH, 'table_1.pdf')
render_latex({'content': content}, pdf_path)



summarizing the results...


In [25]:
from IPython.display import display, IFrame, display_pdf

# display_pdf only shows objects that provide a PDF representation, or raw PDF data when
# raw=True. A bare path string has neither, so passing pdf_path directly displays nothing.
# Hand over the file's bytes instead.
# NOTE(review): whether the PDF is actually rendered depends on the frontend supporting
# the application/pdf mimetype (e.g. JupyterLab's PDF viewer) — confirm in the target UI.
with open(pdf_path, mode='rb') as pdf_file:
    pdf_bytes = pdf_file.read()

display_pdf(pdf_bytes, raw=True)