In [118]:
import _pickle as pickle
import getpass
import os
from collections import Counter, defaultdict
from datetime import datetime, timedelta
from statistics import mean

import numpy as np
import pandas as pd
import pymysql
from scipy.stats import wasserstein_distance
from tqdm import tqdm

In [2]:
# Feature-importance table; later cells read its `feature` and `importance` columns.
importance = pd.read_csv(filepath_or_buffer='../data/final_feature_importances.csv')

In [5]:
# Map each visit id (first CSV column, cast to int) to the value in the
# second column of the model-prediction file. The file is read without a header row.
prediction_rows = np.array(pd.read_csv('../data/rfr_model_depth_69_trees_190_preds.csv', header=None))
visit_probability = {int(row_visit): row_prob for row_visit, row_prob in prediction_rows}

len(visit_probability)

1573113

In [15]:
# Load the three pickled variable-name lists.
# NOTE: pickle can execute arbitrary code while loading; only load files that
# were produced by this project.
# Context managers are used so each file handle is closed right after loading.
with open('../data/demographic_variables.p', 'rb') as handle:
    demographic_variables = pickle.load(handle)
with open('../data/datetime_variables.p', 'rb') as handle:
    datetime_variables = pickle.load(handle)
with open('../data/diag_variables.p', 'rb') as handle:
    diag_variables = pickle.load(handle)

# Combined variable list; the last datetime variable is deliberately left out
# by the [:-1] slice. (Fixes the misspelled name `demograhic_variables`, which
# raised NameError on a fresh kernel.)
all_variables = demographic_variables + datetime_variables[:-1] + diag_variables
len(all_variables)

1637

In [17]:
# Demographic matrix, read without a header row so columns get integer labels.
# The next cell uses column 0 as the visit id and columns 1-11 as ==1 indicators.
all_demog = pd.read_csv(filepath_or_buffer='../data/all_visit_demographic_matrix.csv', header=None)


In [42]:
# variable name -> list of visit ids for which that variable's indicator equals 1
all_variable_cases = defaultdict(list)

demog_visit_ids = all_demog[0]
for column in range(1, 12):
    # Matrix column `column` is paired with demographic_variables[column - 1].
    variable = demographic_variables[column - 1]
    flagged = all_demog[column] == 1
    all_variable_cases[variable] = list(demog_visit_ids[flagged])

In [108]:
# Date/time indicator matrix, read without a header row (integer column labels).
all_datetime = pd.read_csv('../data/all_visit_date_matrix.csv', header=None)

# Matrix column `column` is paired with datetime_variables[column - 1]; the
# variable is stored under its str() form as the dictionary key.
# NOTE(review): this range stops at datetime_variables[-3], while all_variables
# is built from datetime_variables[:-1] (i.e. up to [-2]) -- confirm whether the
# second-to-last datetime variable is meant to be skipped here.
datetime_visit_ids = all_datetime[0]
for column in range(1, len(datetime_variables) - 1):
    variable = str(datetime_variables[column - 1])
    flagged = all_datetime[column] == 1
    all_variable_cases[variable] = list(datetime_visit_ids[flagged])

In [26]:
# start tunnel first: ssh -f [uni]@mimir.dbmi.columbia.edu -L 3307:127.0.0.1:3306 -N

# Credentials are never hard-coded: they come from the environment variables
# MIMIR_DB_USER / MIMIR_DB_PASSWORD, falling back to an interactive prompt.
# The port defaults to 3307, the local end of the SSH tunnel above, and can be
# overridden with MIMIR_DB_PORT. (The original cell had an empty `port = ,`
# argument, which is a SyntaxError.)
conn = pymysql.connect(host="127.0.0.1",
                       user=os.environ.get("MIMIR_DB_USER") or input("uni: "),
                       port=int(os.environ.get("MIMIR_DB_PORT", 3307)),
                       passwd=os.environ.get("MIMIR_DB_PASSWORD") or getpass.getpass("sql password: "),
                       db="user_vr2430")  # database
cur = conn.cursor()

In [44]:
# Distinct (ICD-10 code truncated at the first '.', visit_id) pairs, skipping
# rows whose truncated code is empty.
icd10_visit_query = '''select distinct substring_index(icd10, '.', 1), visit_id
                from vfinal_1_predict_covid_conditions
                where substring_index(icd10, '.', 1) != '';'''
cur.execute(icd10_visit_query)

# Each truncated code becomes a key in all_variable_cases, collecting its visits.
for code, visit in cur.fetchall():
    all_variable_cases[code].append(visit)
    

In [64]:
# Signed Wasserstein distance, per variable, between the predicted values of
# visits that have the variable ("cases") and all other scored visits.
# The sign is positive when the cases have the higher mean, negative otherwise.
wasserstein_distance_results = {}

# The set of scored visits does not change inside the loop, so build it once
# instead of rebuilding it for every variable.
scored_visits = set(visit_probability.keys())

for variable in tqdm(all_variable_cases):
    case_visits = set(all_variable_cases[variable])
    cases_prob = [visit_probability[k] for k in case_visits & scored_visits]
    non_cases_prob = [visit_probability[k] for k in scored_visits - case_visits]
    # Same guard as the training/eval cells: mean() raises StatisticsError on an
    # empty list, so a variable with an empty group is skipped (the summary table
    # renders missing entries as '-').
    if len(cases_prob) == 0 or len(non_cases_prob) == 0:
        continue
    distance = wasserstein_distance(cases_prob, non_cases_prob)
    if mean(cases_prob) > mean(non_cases_prob):
        wasserstein_distance_results[variable] = distance
    else:
        wasserstein_distance_results[variable] = -1*distance

100%|██████████| 1636/1636 [58:22<00:00,  2.14s/it]


In [56]:
# Load the pickled positive/negative visit collections for the training and
# evaluation splits.
# NOTE: pickle can execute arbitrary code while loading; only load files that
# were produced by this project.
# Context managers are used so each file handle is closed right after loading.
with open('../data/negative_training_set.p', 'rb') as handle:
    negative_training_set = pickle.load(handle)
with open('../data/positive_training_set.p', 'rb') as handle:
    positive_training_set = pickle.load(handle)
with open('../data/negative_eval_set.p', 'rb') as handle:
    negative_eval_set = pickle.load(handle)
with open('../data/positive_eval_set.p', 'rb') as handle:
    positive_eval_set = pickle.load(handle)

# Each split is the negatives followed by the positives.
training_set = negative_training_set + positive_training_set
eval_set = negative_eval_set + positive_eval_set

In [69]:
# Signed Wasserstein distance per variable, restricted to training-set visits.
# The sign is positive when the cases have the higher mean, negative otherwise.
wasserstein_distance_results_training = {}

# Training visits that actually have a predicted value. Built once because it is
# loop-invariant, and applied to BOTH groups: previously only the non-cases were
# restricted to scored visits, so an unscored training visit among the cases
# would have raised KeyError.
scored_training_visits = set(training_set) & set(visit_probability.keys())

for variable in tqdm(all_variable_cases):
    case_visits = set(all_variable_cases[variable])
    cases_prob = [visit_probability[k] for k in case_visits & scored_training_visits]
    non_cases_prob = [visit_probability[k] for k in scored_training_visits - case_visits]
    # Nothing to compare if either group is empty within the training set.
    if len(cases_prob) == 0 or len(non_cases_prob) == 0:
        continue
    distance = wasserstein_distance(cases_prob, non_cases_prob)
    if mean(cases_prob) > mean(non_cases_prob):
        wasserstein_distance_results_training[variable] = distance
    else:
        wasserstein_distance_results_training[variable] = -1*distance

100%|██████████| 1636/1636 [02:56<00:00,  9.29it/s]


In [70]:
# Signed Wasserstein distance per variable, restricted to evaluation-set visits.
# The sign is positive when the cases have the higher mean, negative otherwise.
wasserstein_distance_results_eval = {}

# Evaluation visits that actually have a predicted value. Built once because it
# is loop-invariant, and applied to BOTH groups: previously only the non-cases
# were restricted to scored visits, so an unscored evaluation visit among the
# cases would have raised KeyError.
scored_eval_visits = set(eval_set) & set(visit_probability.keys())

for variable in tqdm(all_variable_cases):
    case_visits = set(all_variable_cases[variable])
    cases_prob = [visit_probability[k] for k in case_visits & scored_eval_visits]
    non_cases_prob = [visit_probability[k] for k in scored_eval_visits - case_visits]
    # Nothing to compare if either group is empty within the evaluation set.
    if len(cases_prob) == 0 or len(non_cases_prob) == 0:
        continue
    distance = wasserstein_distance(cases_prob, non_cases_prob)
    if mean(cases_prob) > mean(non_cases_prob):
        wasserstein_distance_results_eval[variable] = distance
    else:
        wasserstein_distance_results_eval[variable] = -1*distance

100%|██████████| 1636/1636 [02:56<00:00,  9.26it/s]


In [142]:
##feature name, importance, EMD_training, EMD_eval, EMD_full

# One row per feature: [feature, importance, training distance, eval distance,
# all-visits distance]. A distance that was not computed is written as '-'.
table_data = []

# Result dictionaries in the same order as the distance columns above.
distance_results_in_column_order = (
    wasserstein_distance_results_training,
    wasserstein_distance_results_eval,
    wasserstein_distance_results,
)

for feature in tqdm(list(all_variable_cases.keys())):
    # Look the importance up once per feature (it was previously computed twice).
    matches = importance.importance[importance.feature == str(feature)]
    # A feature with no row in the importance table used to raise IndexError;
    # skip it, as is done for zero-importance features.
    if matches.empty:
        continue
    feature_importance = matches.iloc[0]
    if feature_importance == 0:
        continue
    distances = [results.get(feature, '-') for results in distance_results_in_column_order]
    # Drop features for which no distance could be computed at all.
    if distances == ['-', '-', '-']:
        continue
    table_data.append([feature, feature_importance] + distances)

100%|██████████| 1661/1661 [00:00<00:00, 2588.68it/s]


In [143]:
# Spot-check one datetime variable. Datetime keys are stored in
# all_variable_cases under str(variable), so the lookup key is normalised the
# same way (a no-op if the variable is already a string).
wasserstein_distance_results[str(datetime_variables[2])]

0.4209260198004161

In [144]:
# Column headers for the exported tables, one per entry of a table_data row.
table_columns = [
    'feature',
    'importance',
    'wasserstein_distance_training',
    'wasserstein_distance_eval',
    'wasserstein_distance_all_visits',
]

In [146]:
# Write every retained feature, highest importance first, without the index column.
(
    pd.DataFrame(data=table_data, columns=table_columns)
    .sort_values(by='importance', ascending=False)
    .to_csv('table_s3_all_features.csv', index=False)
)

In [149]:
# Write only the 20 highest-importance features, without the index column.
(
    pd.DataFrame(data=table_data, columns=table_columns)
    .sort_values(by='importance', ascending=False)
    .head(20)
    .to_csv('table_2_20_feature_importance.csv', index=False)
)