In [None]:
import pandas as pd
import numpy as np
from tqdm import tqdm
import glob

import sys
# caution: path[0] is reserved for script path (or '' in REPL)
sys.path.insert(1, '../DBN_model_learning/')

from pickleObjects import *

# from sklearn.metrics import roc_auc_score
# from sklearn.metrics import average_precision_score, precision_score, recall_score, f1_score, confusion_matrix, brier_score_loss

In [None]:
# defining paths
structures_path = "../RAUS/FullNetwork/"

# Candidate model-structure directories: paths containing "no_race" and "UCLA",
# excluding any containing "count".
# sorted(): glob.glob returns entries in arbitrary, filesystem-dependent order, and the
# loop below keeps only the LAST directory in `model_name` -- without sorting, which
# model gets analysed could change from machine to machine.
models_struct_directories = sorted(
    model_dir
    for model_dir in glob.glob(structures_path + "*")
    if "no_race" in model_dir and "count" not in model_dir
    and "UCLA" in model_dir
    # and "PSJH" in model_dir
    # and "Combined" in model_dir
)

# Fail fast: every later cell needs `model_name`, which is only set inside the loop.
if not models_struct_directories:
    raise FileNotFoundError(
        f"No model directory under {structures_path!r} matches the filters above; "
        "`model_name` would be undefined for the rest of the notebook."
    )

# TODO: loop model directories -- for now the cells below analyse only the last
# (alphabetically) matching directory.
for models_struct_directory in tqdm(models_struct_directories):
    print(models_struct_directory)

    model_name = models_struct_directory.replace(structures_path, "")

In [None]:
# Display the model name picked by the directory loop (the last matching directory).
model_name

In [None]:
## reading sensitivity analysis dictionary
# Results file for the selected model.
# NOTE(review): loadObjects comes from the project's pickleObjects module -- presumably
# it unpickles; only point it at trusted files.
sens_data_dict = loadObjects(
    "../Data/genie_datasets/DBN_predictions/Results/sens_analysis_results/"
    f"aceiarb_with_med_{model_name}__sens_analysis_results"
)

In [None]:
# Sanity check: how many per-category prediction lists are stored for one Year 1 variable.
year1_upcr_preds = sens_data_dict["Year 1"]["time_zero_upcr_mean"]["predictions_list"]
len(year1_upcr_preds)

# Average risk per category — epoch 0 (stored under "Year 1")

In [None]:
def getSignVars(vars, data_dict, epoch_num):
    """Select the variables with at least one truthy KS-test flag in one epoch.

    Args:
        vars: variable names to check; each must be a key of
            data_dict["Year <epoch_num + 1>"].
        data_dict: sensitivity-analysis results. For each variable this reads
            "combinations" (indexed 0..len-1, each entry holding "KS_test_sign")
            and, for selected variables, "categories".
        epoch_num: 0-based epoch index; epoch k is stored under "Year k+1".

    Returns:
        (sign_vars, sign_vars_categories): the selected names, in input order,
        and the matching "categories" entry for each.
    """
    year_data = data_dict["Year " + str(1 + epoch_num)]

    sign_vars, sign_vars_categories = [], []
    for var_name in vars:
        var_data = year_data[var_name]
        combinations = var_data["combinations"]
        # Positional access (not iteration) so this works whether "combinations" is a
        # list or a dict keyed 0..n-1.
        ks_signs = [combinations[i]["KS_test_sign"] for i in range(len(combinations))]
        if np.any(ks_signs):
            sign_vars.append(var_name)
            sign_vars_categories.append(var_data["categories"])
    return sign_vars, sign_vars_categories

def getStats(vars, data_dict, epoch_num):
    """Per-category mean/std of predictions for each variable in one epoch.

    Args:
        vars: variable names to summarise; each must be a key of
            data_dict["Year <epoch_num + 1>"].
        data_dict: sensitivity-analysis results. For each variable this reads
            "categories" and "predictions_list"; predictions_list[c] is taken as
            the predictions for category c (assumed aligned with "categories").
        epoch_num: 0-based epoch index; epoch k is stored under "Year k+1".

    Returns:
        (var_avgs, var_stds, var_avgs_diffs), one entry per variable in `vars`:
        var_avgs[i][c] / var_stds[i][c] are the mean / population std (ddof=0)
        of the predictions for category c, and var_avgs_diffs[i] is
        max(var_avgs[i]) - min(var_avgs[i]) (0.0 when there are no categories).
    """
    year_data = data_dict["Year " + str(1 + epoch_num)]

    var_avgs, var_stds, var_avgs_diffs = [], [], []
    for var_name in vars:
        var_data = year_data[var_name]
        predictions = var_data["predictions_list"]
        n_categories = len(var_data["categories"])

        pred_avgs = [np.mean(predictions[c]) for c in range(n_categories)]
        # np.std, not np.mean: these values are used as error bars downstream.
        pred_stds = [np.std(predictions[c]) for c in range(n_categories)]

        # Spread of the category averages; guard the empty case, where
        # np.min/np.max would raise ValueError.
        var_avgs_diffs.append(np.max(pred_avgs) - np.min(pred_avgs) if pred_avgs else 0.0)

        var_avgs.append(pred_avgs)
        var_stds.append(pred_stds)

    return var_avgs, var_stds, var_avgs_diffs

In [None]:
# Year 1 (epoch 0): keep variables with a significant KS test, then rank them by how much
# the average predicted risk moves across their categories.
var_names = list(sens_data_dict["Year 1"].keys())
# Threshold is read from the first variable only -- presumably shared by all variables; TODO confirm.
threshold_year1 = sens_data_dict["Year 1"][var_names[0]]["threshold"]

sign_vars_year1, sign_vars_categories_year1 = getSignVars(var_names, sens_data_dict, epoch_num=0)

var_avgs_year1, var_stds_year1, var_avgs_diff_year1 = getStats(sign_vars_year1, sens_data_dict, epoch_num=0)

# Keep variables whose risk spread exceeds 0.0001 (0.01 percentage points), largest spread
# first, at most 20 of them. Index lists are used instead of np.array(...)[mask] because the
# per-variable category/average lists are ragged, which NumPy >= 1.24 refuses to pack.
keep_idx = [
    i for i in np.argsort(var_avgs_diff_year1, kind="stable")[::-1]
    if var_avgs_diff_year1[i] > 0.0001
][:20]
sign_vars_year1 = [sign_vars_year1[i] for i in keep_idx]
sign_vars_categories_year1 = [sign_vars_categories_year1[i] for i in keep_idx]
var_avgs_year1 = [var_avgs_year1[i] for i in keep_idx]
var_stds_year1 = [var_stds_year1[i] for i in keep_idx]
var_avgs_diff_year1 = [var_avgs_diff_year1[i] for i in keep_idx]

print(sign_vars_year1)

In [None]:
# Statistical power stored for each of the six years, read from the first variable in var_names.
for year in range(1, 7):
    print("Power: ", sens_data_dict[f"Year {year}"][var_names[0]]["Power"])

In [None]:
# Year 1 variables retained after the KS-significance, spread and top-20 filters.
sign_vars_year1

# Epoch 1 (stored under "Year 2")

In [None]:
# Year 2 (epoch 1): keep variables with a significant KS test, then rank them by how much
# the average predicted risk moves across their categories.
var_names = list(sens_data_dict["Year 2"].keys())
# Threshold is read from the first variable only -- presumably shared by all variables; TODO confirm.
threshold_year2 = sens_data_dict["Year 2"][var_names[0]]["threshold"]

sign_vars_year2, sign_vars_categories_year2 = getSignVars(var_names, sens_data_dict, epoch_num=1)

var_avgs_year2, var_stds_year2, var_avgs_diff_year2 = getStats(sign_vars_year2, sens_data_dict, epoch_num=1)

# Keep variables whose risk spread exceeds 0.0001 (0.01 percentage points), largest spread
# first, at most 20 of them. Index lists are used instead of np.array(...)[mask] because the
# per-variable category/average lists are ragged, which NumPy >= 1.24 refuses to pack.
keep_idx = [
    i for i in np.argsort(var_avgs_diff_year2, kind="stable")[::-1]
    if var_avgs_diff_year2[i] > 0.0001
][:20]
sign_vars_year2 = [sign_vars_year2[i] for i in keep_idx]
sign_vars_categories_year2 = [sign_vars_categories_year2[i] for i in keep_idx]
var_avgs_year2 = [var_avgs_year2[i] for i in keep_idx]
var_stds_year2 = [var_stds_year2[i] for i in keep_idx]
var_avgs_diff_year2 = [var_avgs_diff_year2[i] for i in keep_idx]

print(sign_vars_year2)

In [None]:
# Year 2 variables retained after the KS-significance, spread and top-20 filters.
sign_vars_year2

In [None]:
# Variables retained in both Year 1 and Year 2.
set(sign_vars_year1) & set(sign_vars_year2)

# Epoch 2 (stored under "Year 3")

In [None]:
# Year 3 (epoch 2): keep variables with a significant KS test, then rank them by how much
# the average predicted risk moves across their categories.
var_names = list(sens_data_dict["Year 3"].keys())
# Threshold is read from the first variable only -- presumably shared by all variables; TODO confirm.
threshold_year3 = sens_data_dict["Year 3"][var_names[0]]["threshold"]

sign_vars_year3, sign_vars_categories_year3 = getSignVars(var_names, sens_data_dict, epoch_num=2)

var_avgs_year3, var_stds_year3, var_avgs_diff_year3 = getStats(sign_vars_year3, sens_data_dict, epoch_num=2)

# Keep variables whose risk spread exceeds 0.0001 (0.01 percentage points), largest spread
# first, at most 20 of them. Index lists are used instead of np.array(...)[mask] because the
# per-variable category/average lists are ragged, which NumPy >= 1.24 refuses to pack.
keep_idx = [
    i for i in np.argsort(var_avgs_diff_year3, kind="stable")[::-1]
    if var_avgs_diff_year3[i] > 0.0001
][:20]
sign_vars_year3 = [sign_vars_year3[i] for i in keep_idx]
sign_vars_categories_year3 = [sign_vars_categories_year3[i] for i in keep_idx]
var_avgs_year3 = [var_avgs_year3[i] for i in keep_idx]
var_stds_year3 = [var_stds_year3[i] for i in keep_idx]
var_avgs_diff_year3 = [var_avgs_diff_year3[i] for i in keep_idx]

print(sign_vars_year3)


In [None]:
# Year 3 variables retained after the KS-significance, spread and top-20 filters.
sign_vars_year3

In [None]:
# Variables retained in both Year 2 and Year 3.
set(sign_vars_year2) & set(sign_vars_year3)

# Epoch 3 (stored under "Year 4")

In [None]:
# Year 4 (epoch 3): keep variables with a significant KS test, then rank them by how much
# the average predicted risk moves across their categories.
var_names = list(sens_data_dict["Year 4"].keys())
# Threshold is read from the first variable only -- presumably shared by all variables; TODO confirm.
threshold_year4 = sens_data_dict["Year 4"][var_names[0]]["threshold"]

sign_vars_year4, sign_vars_categories_year4 = getSignVars(var_names, sens_data_dict, epoch_num=3)

var_avgs_year4, var_stds_year4, var_avgs_diff_year4 = getStats(sign_vars_year4, sens_data_dict, epoch_num=3)

# Keep variables whose risk spread exceeds 0.0001 (0.01 percentage points), largest spread
# first, at most 20 of them. Index lists are used instead of np.array(...)[mask] because the
# per-variable category/average lists are ragged, which NumPy >= 1.24 refuses to pack.
keep_idx = [
    i for i in np.argsort(var_avgs_diff_year4, kind="stable")[::-1]
    if var_avgs_diff_year4[i] > 0.0001
][:20]
sign_vars_year4 = [sign_vars_year4[i] for i in keep_idx]
sign_vars_categories_year4 = [sign_vars_categories_year4[i] for i in keep_idx]
var_avgs_year4 = [var_avgs_year4[i] for i in keep_idx]
var_stds_year4 = [var_stds_year4[i] for i in keep_idx]
var_avgs_diff_year4 = [var_avgs_diff_year4[i] for i in keep_idx]

print(sign_vars_year4)

In [None]:
# Year 4 variables retained after the KS-significance, spread and top-20 filters.
sign_vars_year4

In [None]:
# Variables retained in both Year 3 and Year 4.
set(sign_vars_year3) & set(sign_vars_year4)

# Epoch 4 (stored under "Year 5")

In [None]:
# Year 5 (epoch 4): keep variables with a significant KS test, then rank them by how much
# the average predicted risk moves across their categories.
var_names = list(sens_data_dict["Year 5"].keys())
# Threshold is read from the first variable only -- presumably shared by all variables; TODO confirm.
threshold_year5 = sens_data_dict["Year 5"][var_names[0]]["threshold"]

sign_vars_year5, sign_vars_categories_year5 = getSignVars(var_names, sens_data_dict, epoch_num=4)

var_avgs_year5, var_stds_year5, var_avgs_diff_year5 = getStats(sign_vars_year5, sens_data_dict, epoch_num=4)

# Keep variables whose risk spread exceeds 0.0001 (0.01 percentage points), largest spread
# first, at most 20 of them. Index lists are used instead of np.array(...)[mask] because the
# per-variable category/average lists are ragged, which NumPy >= 1.24 refuses to pack.
keep_idx = [
    i for i in np.argsort(var_avgs_diff_year5, kind="stable")[::-1]
    if var_avgs_diff_year5[i] > 0.0001
][:20]
sign_vars_year5 = [sign_vars_year5[i] for i in keep_idx]
sign_vars_categories_year5 = [sign_vars_categories_year5[i] for i in keep_idx]
var_avgs_year5 = [var_avgs_year5[i] for i in keep_idx]
var_stds_year5 = [var_stds_year5[i] for i in keep_idx]
var_avgs_diff_year5 = [var_avgs_diff_year5[i] for i in keep_idx]

print(sign_vars_year5)

In [None]:
# Year 5 variables retained after the KS-significance, spread and top-20 filters.
sign_vars_year5

In [None]:
# Variables retained in both Year 4 and Year 5.
set(sign_vars_year4) & set(sign_vars_year5)

# Epoch 5 (stored under "Year 6")

In [None]:
# Year 6 (epoch 5): keep variables with a significant KS test, then rank them by how much
# the average predicted risk moves across their categories.
var_names = list(sens_data_dict["Year 6"].keys())
# Threshold is read from the first variable only -- presumably shared by all variables; TODO confirm.
threshold_year6 = sens_data_dict["Year 6"][var_names[0]]["threshold"]

sign_vars_year6, sign_vars_categories_year6 = getSignVars(var_names, sens_data_dict, epoch_num=5)

var_avgs_year6, var_stds_year6, var_avgs_diff_year6 = getStats(sign_vars_year6, sens_data_dict, epoch_num=5)

# Keep variables whose risk spread exceeds 0.0001 (0.01 percentage points), largest spread
# first, at most 20 of them. Index lists are used instead of np.array(...)[mask] because the
# per-variable category/average lists are ragged, which NumPy >= 1.24 refuses to pack.
keep_idx = [
    i for i in np.argsort(var_avgs_diff_year6, kind="stable")[::-1]
    if var_avgs_diff_year6[i] > 0.0001
][:20]
sign_vars_year6 = [sign_vars_year6[i] for i in keep_idx]
sign_vars_categories_year6 = [sign_vars_categories_year6[i] for i in keep_idx]
var_avgs_year6 = [var_avgs_year6[i] for i in keep_idx]
var_stds_year6 = [var_stds_year6[i] for i in keep_idx]
var_avgs_diff_year6 = [var_avgs_diff_year6[i] for i in keep_idx]

print(sign_vars_year6)

In [None]:
# Year 6 variables retained after the KS-significance, spread and top-20 filters.
sign_vars_year6

In [None]:
# Variables retained in both Year 5 and Year 6.
set(sign_vars_year5) & set(sign_vars_year6)

# Plots

In [None]:
def getSortedCatsStats(var_name, categories, pred_avgs, pred_stds):
    """Sort a variable's categories (and their aligned stats) by parsed category value.

    Each category label has every "S_" removed and is then parsed:
      * contains "__"         -> split on "___"; two or more parts become a tuple of floats
                                 ("_" read as the decimal point, rounded to 1 decimal);
      * contains "s_" (but no "minu") -> (0.0, <value after "s_">);
      * contains "le_"        -> (<value after "le_">, 1000000.0);
      * anything else         -> kept as the plain string.
    A "__" label that yields a single part is kept as a one-element list.

    `var_name` is accepted for signature symmetry but not used.

    Returns:
        (sort_float_cats, sorted_pred_avgs, sorted_pred_stds): parsed categories in
        ascending order and the pred_avgs / pred_stds entries reordered to match.
    """
    def _parse(raw_category):
        label = raw_category.replace("S_", "")
        if "__" in label:
            parts = label.split("___")
        elif "s_" in label and "minu" not in label:
            parts = ["0", label.replace("s_", "")]
        elif "le_" in label:
            parts = [label.replace("le_", ""), "1000000"]
        else:
            return label
        if len(parts) <= 1:
            return parts
        return tuple(np.round(float(part.replace("_", ".")), 1) for part in parts)

    parsed = [_parse(raw) for raw in categories]
    # One stable sort of positions by parsed value drives every reordering below.
    order = sorted(range(len(parsed)), key=parsed.__getitem__)

    sort_float_cats = [parsed[pos] for pos in order]
    sorted_pred_avgs = [pred_avgs[pos] for pos in order]
    sorted_pred_stds = [pred_stds[pos] for pos in order]
    return sort_float_cats, sorted_pred_avgs, sorted_pred_stds

In [None]:
import matplotlib.pyplot as plt

def getSensPlot(var_name, sort_float_cats, sorted_pred_avgs, sorted_pred_stds, threshold):
    """Plot a variable's average predicted risk per category, with error bars.

    Args:
        var_name: variable name, used as the title and the series' legend label.
        sort_float_cats: categories in display order; each is shown on the x axis via str().
        sorted_pred_avgs: average prediction per category, aligned with sort_float_cats.
        sorted_pred_stds: error-bar half-widths per category, aligned with sort_float_cats.
        threshold: value drawn as a horizontal red "Threshold" line.

    Returns None; the figure is left open so the notebook can render it.
    """
    fig, ax = plt.subplots(figsize=(30, 5))
    x_labels = [str(cat) for cat in sort_float_cats]

    ax.errorbar(x_labels, sorted_pred_avgs, yerr=sorted_pred_stds, label=var_name)
    ax.axhline(y=threshold, color="r", linestyle="-", label="Threshold")
    ax.set_title(var_name)
    ax.set_xlabel("range or categories")
    # Plain "%": "\%" is an invalid Python escape (SyntaxWarning on 3.12+) and, with the
    # default non-usetex renderer, is drawn with a literal backslash.
    ax.set_ylabel("Average probability of 40% decline")
    ax.grid(True)
    ax.legend()

In [None]:
# Year 1: one sensitivity plot per retained variable, categories in sorted order.
for var_name, categories, pred_avgs, pred_stds in zip(
    sign_vars_year1, sign_vars_categories_year1, var_avgs_year1, var_stds_year1
):
    sort_float_cats, sorted_pred_avgs, sorted_pred_stds = getSortedCatsStats(
        var_name, categories, pred_avgs, pred_stds
    )
    getSensPlot(var_name, sort_float_cats, sorted_pred_avgs, sorted_pred_stds, threshold_year1)



In [None]:
# Year 2: one sensitivity plot per retained variable, categories in sorted order.
for var_name, categories, pred_avgs, pred_stds in zip(
    sign_vars_year2, sign_vars_categories_year2, var_avgs_year2, var_stds_year2
):
    sort_float_cats, sorted_pred_avgs, sorted_pred_stds = getSortedCatsStats(
        var_name, categories, pred_avgs, pred_stds
    )
    getSensPlot(var_name, sort_float_cats, sorted_pred_avgs, sorted_pred_stds, threshold_year2)

var_name

In [None]:
# Year 3: one sensitivity plot per retained variable, categories in sorted order.
for var_name, categories, pred_avgs, pred_stds in zip(
    sign_vars_year3, sign_vars_categories_year3, var_avgs_year3, var_stds_year3
):
    sort_float_cats, sorted_pred_avgs, sorted_pred_stds = getSortedCatsStats(
        var_name, categories, pred_avgs, pred_stds
    )
    getSensPlot(var_name, sort_float_cats, sorted_pred_avgs, sorted_pred_stds, threshold_year3)

var_name

In [None]:
# Year 4: one sensitivity plot per retained variable, categories in sorted order.
for var_name, categories, pred_avgs, pred_stds in zip(
    sign_vars_year4, sign_vars_categories_year4, var_avgs_year4, var_stds_year4
):
    sort_float_cats, sorted_pred_avgs, sorted_pred_stds = getSortedCatsStats(
        var_name, categories, pred_avgs, pred_stds
    )
    getSensPlot(var_name, sort_float_cats, sorted_pred_avgs, sorted_pred_stds, threshold_year4)

var_name

In [None]:
# Year 5: one sensitivity plot per retained variable, categories in sorted order.
for var_name, categories, pred_avgs, pred_stds in zip(
    sign_vars_year5, sign_vars_categories_year5, var_avgs_year5, var_stds_year5
):
    sort_float_cats, sorted_pred_avgs, sorted_pred_stds = getSortedCatsStats(
        var_name, categories, pred_avgs, pred_stds
    )
    getSensPlot(var_name, sort_float_cats, sorted_pred_avgs, sorted_pred_stds, threshold_year5)

var_name

In [None]:
# Year 6: one sensitivity plot per retained variable, categories in sorted order.
for var_name, categories, pred_avgs, pred_stds in zip(
    sign_vars_year6, sign_vars_categories_year6, var_avgs_year6, var_stds_year6
):
    sort_float_cats, sorted_pred_avgs, sorted_pred_stds = getSortedCatsStats(
        var_name, categories, pred_avgs, pred_stds
    )
    getSensPlot(var_name, sort_float_cats, sorted_pred_avgs, sorted_pred_stds, threshold_year6)

var_name