In [219]:
# store start time to get execution time of entire script
# (wall-clock reading; the last cell subtracts start_time from a later time.time() call)
import time
start_time = time.time()

In [220]:
# helper functions for displaying table data

import numpy as np
from IPython.display import display_html

# n is the number of columns to display data in
def display_side_by_side(series_obj, n):
    """Show ``series_obj`` as up to ``n`` HTML tables placed next to each other.

    The rows are split, in their existing order, into consecutive chunks and
    the chunks are handed to ``helper`` for inline rendering.

    Parameters
    ----------
    series_obj : anything accepted by ``pd.DataFrame(...)``
        Data to display (the notebook passes merged feature-importance frames).
    n : int
        Number of side-by-side tables to spread the rows over; at least 1.

    Raises
    ------
    ValueError
        If ``n`` is smaller than 1.

    Notes
    -----
    ``pd`` is looked up in the notebook namespace at call time, so pandas must
    have been imported as ``pd`` before the first call.
    """
    if n < 1:
        raise ValueError('n must be a positive number of columns')

    df = pd.DataFrame(series_obj)

    # Ceiling division. Rounding to the nearest integer could leave
    # n * partition smaller than len(df) and silently drop the last rows
    # (e.g. 5 rows with n=2 gave partition 2, so only 4 rows were shown).
    # max(1, ...) keeps the range step valid when the frame is empty.
    partition = max(1, -(-len(df) // n))

    args = [df.iloc[lower_bound:lower_bound + partition]
            for lower_bound in range(0, len(df), partition)]
    if not args:
        # Empty input: still render the (header-only) frame once.
        args.append(df)
    helper(args)

def helper(args):
    """Render every DataFrame in ``args`` as an HTML table on one row.

    Each frame is converted with ``to_html()``, the markup is concatenated, and
    every opening ``<table`` tag receives an inline ``display:inline`` style so
    the tables sit next to each other. The result is shown with
    ``display_html(..., raw=True)``; nothing is returned.

    Parameters
    ----------
    args : iterable of pandas.DataFrame
        Frames to display, left to right.
    """
    html_str = ''.join(df.to_html() for df in args)
    # Target only the opening tag. Replacing the bare word 'table' also
    # rewrote the closing </table> tags and any cell text containing "table".
    styled = html_str.replace('<table', '<table style="display:inline"')
    display_html(styled, raw=True)

In [221]:
# helper function for plotting out ground truth curves

import matplotlib.pyplot as plt

def get_ground_truth(data, horizon=365, time_col='Illicit_Days5', event_col='Illicit_Cens5'):
    """Compute the observed curve of rows not yet counted as relapsed, per day.

    For every day ``d`` in ``1..horizon`` the value is
    ``(len(data) - number of rows with event_col == 0 and 1 <= time_col <= d) / len(data)``.
    Rows whose ``time_col`` lies outside ``1..horizon`` never lower the curve.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain ``time_col`` and ``event_col``.
    horizon : int, default 365
        Number of days to cover; the default reproduces the original
        hard-coded 365-day window.
    time_col : str, default 'Illicit_Days5'
        Column holding the day count.
    event_col : str, default 'Illicit_Cens5'
        Column holding the indicator; rows where it equals 0 are counted.

    Returns
    -------
    labels : list of int
        The days ``1..horizon``.
    curve : list of float
        Fraction of rows remaining at each day, same length as ``labels``.

    Raises
    ------
    ValueError
        If ``data`` has no rows (the fraction would be undefined).

    Notes
    -----
    NOTE(review): rows with ``event_col == 0`` are the ones treated as relapsed
    here, while the modelling cells pass the same column to
    ``Surv.from_arrays`` as the event indicator -- confirm both places agree on
    what 0 and 1 mean.
    """
    n_total = len(data)
    if n_total == 0:
        raise ValueError('get_ground_truth needs at least one row')

    relapsed = data[data[event_col] == 0]
    counts = relapsed[time_col].value_counts().to_dict()

    labels = list(range(1, horizon + 1))
    remaining = n_total
    curve = []
    for day in labels:
        # Days on which nothing was recorded simply carry the previous level.
        remaining -= counts.get(day, 0)
        curve.append(remaining / n_total)
    return labels, curve

In [222]:
from sklearn.model_selection import cross_validate
from sksurv.ensemble import GradientBoostingSurvivalAnalysis
from sksurv.ensemble import RandomSurvivalForest
from sksurv.linear_model import CoxnetSurvivalAnalysis

def _mean_cv_score(model, X, y, cv):
    """Return the mean ``test_score`` of ``model`` from ``cross_validate`` (folds run with n_jobs=-1)."""
    scores = cross_validate(model, X, y, cv=cv, n_jobs=-1)
    return scores['test_score'].mean()


def run_models(X, y, label, cv=5, random_state=None):
    """Cross-validate and fit the three survival models on ``X``, ``y``.

    For each of GradientBoostingSurvivalAnalysis, RandomSurvivalForest and
    CoxnetSurvivalAnalysis (``l1_ratio=1``, ``fit_baseline_model=True``) the
    mean cross-validated test score is printed, then a fresh instance is fitted
    on all of ``X``, ``y``.

    Parameters
    ----------
    X : feature matrix passed to ``cross_validate`` and ``fit``.
    y : survival target passed to ``cross_validate`` and ``fit``.
    label : str
        Name of the score column in the returned table.
    cv : default 5
        Forwarded to ``cross_validate`` as ``cv``; 5 matches the original
        hard-coded value.
    random_state : default None
        Forwarded to the gradient-boosting and random-forest estimators so a
        run can be made repeatable. ``None`` keeps the previous unseeded
        behaviour. The Coxnet model is not given a seed.

    Returns
    -------
    concordance : pandas.DataFrame
        Columns ``'Model'`` and ``label``, one row per model.
    gbsa, rsf, rcr
        The three estimators fitted on the full ``X``, ``y``.

    Notes
    -----
    The row labelled 'Random Forest Boosted' (printed as 'RF Boosted') holds
    the GradientBoostingSurvivalAnalysis score. ``pd`` is looked up in the
    notebook namespace at call time.
    """
    gbsa_score = _mean_cv_score(
        GradientBoostingSurvivalAnalysis(random_state=random_state), X, y, cv)
    print('RF Boosted score:', gbsa_score)

    gbsa = GradientBoostingSurvivalAnalysis(random_state=random_state)
    gbsa.fit(X, y)

    rsf_score = _mean_cv_score(
        RandomSurvivalForest(random_state=random_state), X, y, cv)
    print('RF score:', rsf_score)

    # The final forest is fitted in parallel; the cross-validated one is not,
    # because cross_validate already spreads the folds over all cores.
    rsf = RandomSurvivalForest(n_jobs=-1, random_state=random_state)
    rsf.fit(X, y)

    # l1_ratio = 1 adjusts model to implement LASSO method for penalties
    # fit_baseline_model = True allows us to create survival/hazard plots after model is fit
    rcr_score = _mean_cv_score(
        CoxnetSurvivalAnalysis(fit_baseline_model=True, l1_ratio=1), X, y, cv)
    print('Lasso score:', rcr_score)

    rcr = CoxnetSurvivalAnalysis(fit_baseline_model=True, l1_ratio=1)
    rcr.fit(X, y)

    # concordance index
    concordance = pd.DataFrame(data={
        'Model': ['Random Forest Boosted', 'Random Forest', 'Lasso'],
        label: [gbsa_score, rsf_score, rcr_score],
    })

    # return scores and models
    return concordance, gbsa, rsf, rcr

In [223]:
def get_survival_graph(rsf, rcr, X, Y, label, filename):
    """Plot mean predicted survival for both models against the ground truth.

    Draws three lines on one figure -- the mean of ``rsf``'s predicted survival
    functions ('RF'), the mean of the ``.y`` values of ``rcr``'s predicted
    survival functions ('Lasso'), and the curve from ``get_ground_truth(Y)`` --
    then saves the figure to ``filename`` and shows it.

    Parameters
    ----------
    rsf, rcr : fitted estimators exposing ``predict_survival_function``.
    X : features to predict on.
    Y : frame passed to ``get_ground_truth``.
    label : str
        Figure title.
    filename : str
        Output image path; its parent directory is created when missing.

    Notes
    -----
    NOTE(review): the two model curves are plotted against their array
    position, not against the time points of the survival functions, while the
    ground truth is plotted against day numbers -- confirm the x-axes line up.
    """
    import os  # local import: only needed to make sure the output folder exists

    pred_surv_rsf = rsf.predict_survival_function(X)
    pred_surv_rcr = rcr.predict_survival_function(X)

    # display survival plot
    fig, ax = plt.subplots()
    fig.suptitle(label)
    ax.plot(np.mean(list(pred_surv_rsf), axis=0), label='RF')
    ax.plot(np.mean([person.y for person in pred_surv_rcr], axis=0), label='Lasso')
    labels, temp = get_ground_truth(Y)
    ax.plot(labels, temp, label='Ground Truth')
    ax.legend()
    ax.set_xlim(0, 365)
    ax.set_xticks(np.arange(0, 365, step=50))
    ax.set_yticks(np.arange(0, 1.1, step=0.1))

    # savefig raises if the target folder (e.g. 'graphs/') does not exist yet.
    out_dir = os.path.dirname(filename)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(filename)

    plt.show()

In [224]:
def get_feature_importance(features, gbsa, rcr, label, top_n=10):
    """Return the ``top_n`` most important features for both models.

    Parameters
    ----------
    features : sequence of feature names, aligned with the model inputs.
    gbsa : fitted model exposing ``feature_importances_``.
    rcr : fitted model exposing ``coef_`` and ``alphas_``; each feature's
        coefficients are averaged along axis 1 using ``alphas_`` as weights.
    label : str
        Name of the importance column in both returned frames.
    top_n : int, default 10
        Number of rows to keep per model (10 matches the original behaviour).

    Returns
    -------
    feature_importance_rf : pandas.DataFrame
        Columns ``'Feature'`` and ``label``, the ``top_n`` largest values of
        ``gbsa.feature_importances_``.
    feature_importance_lasso : pandas.DataFrame
        Columns ``'Feature'``, ``label`` (signed weighted-average coefficient)
        and ``label + '_abs'``, the ``top_n`` largest by absolute value.
    """
    # feature importances from Boosted Random Forest
    # nlargest already returns the rows in descending order; the extra in-place
    # sort that used to precede it only made the order of tied rows unspecified.
    feature_importance_rf = pd.DataFrame({'Feature': features,
                                          label: gbsa.feature_importances_})
    feature_importance_rf = feature_importance_rf.nlargest(top_n, [label])

    # feature importances from Lasso
    abs_label = label + '_abs'
    weighted_coef = np.average(rcr.coef_, weights=rcr.alphas_, axis=1)
    feature_importance_lasso = pd.DataFrame({'Feature': features,
                                             label: weighted_coef})
    feature_importance_lasso[abs_label] = np.absolute(feature_importance_lasso[label])
    feature_importance_lasso = feature_importance_lasso.nlargest(top_n, [abs_label])

    return feature_importance_rf, feature_importance_lasso

Survival Analysis by Age

In [225]:
import pandas as pd
# allow up to 500 rows and 500 columns when a frame is rendered in cell output
pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 500)
import csv  # NOTE(review): not referenced by any cell visible in this notebook

# load the dataset; the path is relative to the notebook's working directory
df = pd.read_csv('data/data_superset.csv')
df.head()

Unnamed: 0.2,Unnamed: 0,Unnamed: 0.1,Unnamed: 0.1.1,ID,State,City,agyaddr,Illicit_Days5,Illicit_Cens5,adol,xobsyr_0,female_cd,nonwhite_cd,unemplmt_cd,prsatx_cd,gvsg_cd,CWSg_0_cd,srprobg_cd,dssg_0_cd,epsg_0_cd,adhdg_0_cd,cdsg_0_cd,cjsig_0_cd,lrig_0_cd,srig_0_cd,SESg_0_cd,r4ag_0_cd,nonillicit_flag,primsev_cd_1,primsev_cd_2,primsev_cd_3,primsev_cd_4,primsev_cd_5,primsev_cd_6,Address,lat,lng,Geo_FIPS,murder_numg,hcd,%_U18g,%_female_householdg,%_unemployedg,%_public_assistanceg,%_povertyg
0,0,0,0,23223,FL,Miami,2140 South Dixie Hwy,365,0,0,2010,0,1,0,0,0,0,1,1,0,0,0,0,1,1,0,1,1,0,0,1,0,0,0,"2140 South Dixie Hwy, Miami, FL",25.743113,-80.228303,12086.0,0.0,0.0,1.0,1.0,1.0,0.0,0.0
1,1,1,1,857,OH,Cleveland,1276 West Third St. #400,365,0,1,2005,0,0,0,0,2,0,0,1,1,1,1,0,1,2,0,1,1,1,0,0,0,0,0,"1276 West Third St. #400, Cleveland, OH",41.501028,-81.697772,39035.0,0.0,,,,,,
2,2,2,2,929,OH,Cleveland,1276 West Third St. #400,354,0,1,2006,0,0,0,0,1,0,1,0,1,0,1,1,0,1,0,2,1,1,0,0,0,0,0,"1276 West Third St. #400, Cleveland, OH",41.501028,-81.697772,39035.0,0.0,1.0,1.0,1.0,1.0,0.0,0.0
3,3,3,3,951,OH,Cleveland,1276 West Third St. #400,365,0,1,2006,0,0,0,0,0,0,1,0,0,0,0,1,0,1,0,2,1,0,0,1,0,0,0,"1276 West Third St. #400, Cleveland, OH",41.501028,-81.697772,39035.0,0.0,1.0,1.0,1.0,1.0,0.0,0.0
4,4,4,4,1032,OH,Cleveland,1276 West Third St. #400,365,0,1,2006,0,0,0,0,2,0,1,1,1,1,1,0,2,1,0,2,1,0,0,1,0,0,0,"1276 West Third St. #400, Cleveland, OH",41.501028,-81.697772,39035.0,0.0,1.0,1.0,1.0,1.0,0.0,0.0


In [226]:
# subset to patients who have county level murder and socioeconomic data available
# keep only the rows where both county-level columns (murder_numg, hcd) are present
df = df.dropna(subset=['murder_numg', 'hcd'])
df.shape

(10034, 45)

In [227]:
# drop unnecessary columns
# drop unnecessary columns
cols_to_drop = ['Address','lat','lng','Geo_FIPS','xobsyr_0','Unnamed: 0','Unnamed: 0.1','Unnamed: 0.1.1',
                'ID','State','City','agyaddr']

# CONTROL statistics: the next line is active and also drops the county-level
# percentage columns; comment it out to keep them as features
# (the trailing comment lists two more columns that could be added to the drop list)
cols_to_drop = cols_to_drop + ['%_U18g','%_female_householdg','%_unemployedg','%_public_assistanceg','%_povertyg'] # 'murder_numg','hcd',

df.drop(columns=cols_to_drop, inplace=True)
# cast every remaining column to int (columns that were read as floats are truncated)
df = df.astype(int)
df.shape

(10034, 28)

In [228]:
# subset to only the illicit cases, i.e. rows whose nonillicit_flag is 0
df = df.loc[df['nonillicit_flag'] == 0]
# df.drop(columns=['nonillicit_flag'], inplace=True) # if not used to subset, remove feature since it's redundant

In [229]:
"""selected_features = ['nonwhite_cd', 'gvsg_cd', 'CWSg_0_cd', 'srprobg_cd', 'dssg_0_cd',
       'adhdg_0_cd', 'cdsg_0_cd', 'cjsig_0_cd', 'srig_0_cd', 'SESg_0_cd',
       'r4ag_0_cd', 'primsev_cd_4', 'primsev_cd_5', 'primsev_cd_6',
       'murder_numg', '%_U18g', '%_female_householdg','Illicit_Days5','Illicit_Cens5','adol']

df = df[selected_features]"""

"selected_features = ['nonwhite_cd', 'gvsg_cd', 'CWSg_0_cd', 'srprobg_cd', 'dssg_0_cd',\n       'adhdg_0_cd', 'cdsg_0_cd', 'cjsig_0_cd', 'srig_0_cd', 'SESg_0_cd',\n       'r4ag_0_cd', 'primsev_cd_4', 'primsev_cd_5', 'primsev_cd_6',\n       'murder_numg', '%_U18g', '%_female_householdg','Illicit_Days5','Illicit_Cens5','adol']\n\ndf = df[selected_features]"

In [230]:
# sanity check: row/column count after restricting to nonillicit_flag == 0
df.shape

(2627, 28)

In [231]:
# preview the filtered, integer-typed frame that the modelling cells start from
df.head()

Unnamed: 0,Illicit_Days5,Illicit_Cens5,adol,female_cd,nonwhite_cd,unemplmt_cd,prsatx_cd,gvsg_cd,CWSg_0_cd,srprobg_cd,dssg_0_cd,epsg_0_cd,adhdg_0_cd,cdsg_0_cd,cjsig_0_cd,lrig_0_cd,srig_0_cd,SESg_0_cd,r4ag_0_cd,nonillicit_flag,primsev_cd_1,primsev_cd_2,primsev_cd_3,primsev_cd_4,primsev_cd_5,primsev_cd_6,murder_numg,hcd
29,166,0,1,0,0,0,0,2,0,1,2,1,2,2,2,2,2,0,0,0,0,1,0,0,0,0,0,0
56,92,0,1,0,0,0,0,2,0,1,1,0,2,2,2,2,2,2,0,0,0,1,0,0,0,0,0,1
57,168,0,1,0,0,0,1,2,0,1,1,1,1,2,2,1,1,2,0,0,0,1,0,0,0,0,0,0
59,192,0,1,0,0,0,0,1,0,1,1,0,0,2,2,1,0,0,0,0,0,1,0,0,0,0,0,0
61,351,0,1,0,0,0,0,0,0,1,0,1,1,1,2,1,2,0,0,0,0,1,0,0,0,0,0,0


Full Population Survival Analysis

In [232]:
from sklearn.model_selection import train_test_split
from sksurv.util import Surv

predictor_var = 'Illicit_Days5'
censoring_var = 'Illicit_Cens5'

# outcome columns (indicator first, then day count) and the remaining columns as features
Y = df[[censoring_var, predictor_var]]
X = df.drop(columns=[censoring_var, predictor_var])

# structured array to ensure correct censoring
# NOTE(review): the Illicit_Cens5 column is passed as from_arrays' first argument
# (presumably the event indicator), whereas get_ground_truth counts rows with
# Illicit_Cens5 == 0 as relapsed -- confirm both use the same coding.
y = Surv.from_arrays(Y[censoring_var], Y[predictor_var])

print(X.shape, y.shape)

(2627, 26) (2627,)


In [233]:
%%time
# cross-validate and fit the three models on the full population; binds gbsa, rsf and rcr
full_concordance, gbsa, rsf, rcr = run_models(X, y, 'ALL')

RF Boosted score: 0.676739556540572
RF score: 0.6781947918518874
Lasso score: 0.6890468160091261
CPU times: user 12.8 s, sys: 1.19 s, total: 14 s
Wall time: 48.4 s


In [234]:
# inspect the penalty_factor_ attribute of the fitted Lasso (Coxnet) model
rcr.penalty_factor_

array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1.])

In [202]:
# plot mean predicted survival vs. ground truth for the full population
# NOTE(review): the saved output of this cell is a ValueError about mismatched
# feature counts, and its execution count (202) is out of sequence with the
# neighbouring cells -- it looks like it was last run against stale kernel
# state; re-run the notebook top to bottom to confirm.
get_survival_graph(rsf, rcr, X, Y, 'Survival: All Ages','graphs/survival_all.png')

ValueError: Number of features of the model must match the input. Model n_features is 23 and input n_features is 25.

Adolescent Survival Analysis

In [None]:
# Adolescent subset (adol == 1).
# The feature frame is built with a non-mutating drop: calling
# drop(..., inplace=True) on the filtered frame modifies a slice of df, which
# pandas flags with SettingWithCopyWarning, and it differed from the
# full-population cell, which works on an explicit copy.
# Note: adol itself stays in X and is constant (always 1) within this subset.
X = df[df.adol == 1]
Y = X[[censoring_var, predictor_var]]
X = X.drop(columns=[censoring_var, predictor_var])

y = Surv.from_arrays(Y[censoring_var], Y[predictor_var]) # structured array to ensure correct censoring

print(X.shape, y.shape)

In [None]:
%%time
# same pipeline on the adolescent subset; rebinds gbsa, rsf and rcr, replacing the full-population models
adol_concordance, gbsa, rsf, rcr = run_models(X, y, 'ADOL')

In [None]:
# survival plot for the adolescent subset, saved to graphs/survival_adol.png
get_survival_graph(rsf, rcr, X, Y, 'Survival: Adolescents', 'graphs/survival_adol.png')

In [None]:
# top features of the boosted model and the Lasso model fitted on the adolescent subset
adol_feature_importance_rf, adol_feature_importance_lasso = get_feature_importance(X.columns, gbsa, rcr, 'ADOL')

In [None]:
# Non-adolescent subset (adol == 0).
# The feature frame is built with a non-mutating drop: calling
# drop(..., inplace=True) on the filtered frame modifies a slice of df, which
# pandas flags with SettingWithCopyWarning, and it differed from the
# full-population cell, which works on an explicit copy.
# Note: adol itself stays in X and is constant (always 0) within this subset.
X = df[df.adol == 0]
Y = X[[censoring_var, predictor_var]]
X = X.drop(columns=[censoring_var, predictor_var])

y = Surv.from_arrays(Y[censoring_var], Y[predictor_var]) # structured array to ensure correct censoring

print(X.shape, y.shape)

In [None]:
%%time
# same pipeline on the non-adolescent subset; rebinds gbsa, rsf and rcr again
non_adol_concordance, gbsa, rsf, rcr = run_models(X, y, 'NON-ADOL')

In [None]:
# survival plot for the non-adolescent subset, saved to graphs/survival_non_adol.png
get_survival_graph(rsf, rcr, X, Y, 'Survival: Non-Adolescents', 'graphs/survival_non_adol.png')

In [None]:
# top features of the boosted model and the Lasso model fitted on the non-adolescent subset
non_adol_feature_importance_rf, non_adol_feature_importance_lasso = get_feature_importance(X.columns, gbsa, rcr, 'NON-ADOL')

Overall Statistics

In [None]:
# one row per model, one score column per population (ADOL, NON-ADOL, ALL)
overall_concordance = pd.concat(
    [
        adol_concordance,
        non_adol_concordance['NON-ADOL'],
        full_concordance['ALL'],
    ],
    axis=1,
)
overall_concordance

In [None]:
# outer-join the two Lasso top-feature tables on the feature name; a feature
# missing from one population gets 0 in that population's columns
overall_feature_importance_lasso = pd.merge(
    adol_feature_importance_lasso,
    non_adol_feature_importance_lasso,
    on='Feature',
    how='outer',
).fillna(0)
display_side_by_side(overall_feature_importance_lasso, 2)

In [None]:
# feature importance for lasso across all ages
# feature importance for lasso across all ages
# A dedicated name is used instead of `df`, so the dataset loaded at the top of
# the notebook is not overwritten and earlier cells can still be re-run.
lasso_importance = pd.DataFrame({'ADOL': overall_feature_importance_lasso['ADOL'].tolist(),
                                 'NON-ADOL': overall_feature_importance_lasso['NON-ADOL'].tolist()},
                                index=overall_feature_importance_lasso['Feature'].tolist())
lasso_importance = lasso_importance.sort_values(by=['ADOL'], ascending=False)
ax = lasso_importance.plot.bar(rot=50, figsize=(12, 12))
ax.grid()
fig = ax.get_figure()
# Title the figure that actually holds the bars. Calling plt.suptitle before
# plot.bar put the title on a different figure, because plot.bar without ax=
# draws on a figure of its own, so the saved image had no title.
fig.suptitle('Feature Importance: Lasso')

fig.savefig('graphs/feature_importance_lasso.png', bbox_inches='tight')

In [None]:
# outer-join the two boosted-model top-feature tables on the feature name; a
# feature missing from one population gets 0 in that population's column
overall_feature_importance_rf = pd.merge(
    adol_feature_importance_rf,
    non_adol_feature_importance_rf,
    on='Feature',
    how='outer',
).fillna(0)
display_side_by_side(overall_feature_importance_rf, 4)

In [None]:
# feature importance for lasso across all ages
# feature importance for the boosted random forest model across all ages
# A dedicated name is used instead of `df`, so the dataset loaded at the top of
# the notebook is not overwritten and earlier cells can still be re-run.
rf_importance = pd.DataFrame({'ADOL': overall_feature_importance_rf['ADOL'].tolist(),
                              'NON-ADOL': overall_feature_importance_rf['NON-ADOL'].tolist()},
                             index=overall_feature_importance_rf['Feature'].tolist())
rf_importance = rf_importance.sort_values(by=['ADOL'], ascending=False)
ax = rf_importance.plot.bar(rot=50, figsize=(12, 12))
ax.grid()
fig = ax.get_figure()
# Title the figure that actually holds the bars (plt.suptitle before plot.bar
# landed on a different figure). The title previously read 'Lasso', copied
# from the Lasso chart, although this chart shows the boosted-model importances.
fig.suptitle('Feature Importance: Random Forest Boosted')

fig.savefig('graphs/feature_importance_rf.png', bbox_inches='tight')

In [None]:
# features in top 10 of both models across all ages
# features that appear in the top-feature tables of both models (either population)
feature_importance_intersection = np.intersect1d(overall_feature_importance_rf['Feature'],
                                                 overall_feature_importance_lasso['Feature'])
# Join the names first. Passing them to print with sep=', ' also inserted a
# comma directly after the label ("Common Features:, a, b").
print('Common Features:', ', '.join(map(str, feature_importance_intersection)))

In [None]:
# print out total notebook execution time
# report the wall-clock time elapsed since the first cell stored start_time
total_seconds = int(time.time() - start_time)
minutes, seconds = divmod(total_seconds, 60)
print(f"--- {minutes} minutes {seconds} seconds ---")