In [None]:
import pandas as pd
import os 
from matplotlib import pyplot as plt
import numpy as np
import seaborn as sns

import statsmodels.api as sm
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score, roc_auc_score, precision_recall_curve
from sklearn import metrics

import warnings
warnings.filterwarnings("ignore")

# Extract Data

In [None]:
# Load every `*cond_table*` file of the snapshot and tag its rows with the
# indicator parsed from the file name (3rd `_`-separated token).
path = 'data_backup/m2_mds_2024-05-31'
cond_tables = []
# sorted(): os.listdir order is filesystem-dependent; sorting makes the
# concatenation (row) order reproducible across machines and runs.
for fname in sorted(os.listdir(path)):
    if 'cond_table' not in fname:
        continue
    indicator = fname.split('_')[2]
    print(indicator)
    cond_tables.append(pd.read_csv(os.path.join(path, fname)).assign(indicator=indicator))

if not cond_tables:
    raise FileNotFoundError(f"no 'cond_table' files found in {path!r}")

df = pd.concat(cond_tables, ignore_index=True)
# NSCLC and SCLC are pooled into a single 'Lung' indicator.
df['indicator'] = df['indicator'].replace({'NSCLC': 'Lung', 'SCLC': 'Lung'})
# .copy() so the column assignment below targets a real frame, not a view.
df = df[df['indicator'] != 'Renal'].copy()
df['label'] = np.where(df.event == 1, 'Event', 'Censor')
df.head()

In [None]:
# Cohort id for each cancer indicator (position in the list = id). The ids are
# stored as float; 11 (MDS) is the id assigned to the inference data on load.
cohort = pd.DataFrame({
    'indicator': ['Lung', 'Pancreas', 'Melanoma', 'Colorectal', 'Prostate', 'Bladder',
                  'Breast', 'Gastricesophagus', 'HCC', 'Renal', 'Ovarian', 'MDS'],
    'cohort': [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11],
})

In [None]:
# Inference output: tag every row with cohort 11, expose the horizon columns as
# P60/P90/P120/P150, and derive overall and 90-day ("temporal") outcome labels.
final = (
    pd.read_csv(path + '/inference/inference_final_data.csv')
      .assign(cohort=11)
      .rename(columns={'60': 'P60', '90': 'P90', '120': 'P120', '150': 'P150'})
)
final = final.assign(
    label=np.where(final.event == 1, 'Event', 'Censor'),
    # an event only counts for the 90-day target when 0 <= TTE <= 90
    event_temporal=np.where(final.TTE.between(0, 90), final.event, 0),
)
final['label_temporal'] = np.where(final.event_temporal == 1, 'Event', 'Censor')

# NOTE(review): merge() without `on=` joins on every shared column (at least
# `cohort`) — consider making the key explicit.
final = final.merge(cohort)
final.head()

# Description of Training Data

In [None]:
# Unique training patients per indicator (ascending) and their share in percent.
agg = (df.groupby('indicator', as_index=False)
         .chai_patient_id.nunique()
         .sort_values('chai_patient_id', ascending=True))
agg['share'] = (agg.chai_patient_id / agg.chai_patient_id.sum() * 100).round(2)

panels = [
    (agg.chai_patient_id, 'Unique Patients By Cancer Indicator', 'Unique Patients (#)'),
    (agg.share, 'Share of Unique Patients By Cancer Indicator', 'Share of Unique Patients (%)'),
]
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 6))
for axis, (values, title, xlabel) in zip(ax, panels):
    axis.barh(agg.indicator, values)
    axis.set_title(title)
    axis.set_xlabel(xlabel)

plt.tight_layout(pad=.5)
plt.savefig(f'{path}/inference/cohort_count.png')
plt.show()

In [None]:
## High Frequency cancers: indicators holding at least 5% of unique training patients
frequent_indicators = agg.loc[agg['share'] >= 5, 'indicator']
high = df[df['indicator'].isin(frequent_indicators)]
high.shape

In [None]:
# Side-by-side boxplots (outliers hidden) of `duration` and `TTE` for the
# high-frequency indicators, split by outcome label.
fig, ax = plt.subplots(1, 2, figsize=(18, 9))
fig.suptitle('Distributions of Patient Journey')

journey_panels = {
    'duration': 'Distribution of Duration between LOT Index and End Point',
    'TTE': 'Distribution of Duration between Random Observation Point and End Point',
}
for panel_ax, (y_col, panel_title) in zip(ax, journey_panels.items()):
    sns.boxplot(data=high, x='indicator', y=y_col, hue='label',
                palette='Set2', showfliers=False, ax=panel_ax).set(title=panel_title)

plt.savefig(f'{path}/inference/Duration_TTE_Distribution.png')
plt.show()

# Description of Predicted Probability 

In [None]:
# Unique inference patients per indicator (descending) and their fraction (0-1) of the total.
agg1 = (final.groupby('indicator')['chai_patient_id'].nunique()
             .reset_index()
             .sort_values(by='chai_patient_id', ascending=False))
agg1 = agg1.assign(share=agg1['chai_patient_id'] / agg1['chai_patient_id'].sum())
agg1

In [None]:
# 90-day predicted probability per indicator, split by the overall outcome label.
# set_theme changes the default figure size for every later figure as well.
sns.set_theme(rc={'figure.figsize': (15, 6)})
p90_box = sns.boxplot(data=final, x='indicator', y='P90', hue='label',
                      palette='Set2', showfliers=False)
p90_box.set_title('Predicted Probability Distribution over 90 Days by Indicator')

plt.savefig(f'{path}/inference/Pred_prob_distribution_90_indicator.png')
plt.show()

In [None]:
# Same view as the label-based boxplot, but split by the 90-day outcome label.
sns.set_theme(rc={'figure.figsize': (15, 6)})
temporal_box = sns.boxplot(data=final, x='indicator', y='P90', hue='label_temporal',
                           palette='Set2', showfliers=False)
temporal_box.set_title('Temporal Validation Based on 90-Day Predicted Probability')

plt.savefig(f'{path}/inference/Temporal_val_pred_prob_distribution_90.png')
plt.show()

In [None]:
# Preview the first five rows of the inference frame.
final.iloc[:5]

In [None]:
# P90 summary for patients with an observed event. `final` is loaded with
# cohort=11, which maps only to 'MDS', so a 'Lung' filter here is always empty.
final.loc[(final['label'] == 'Event') & (final['indicator'] == 'MDS'), 'P90'].describe()

In [None]:
# P90 summary for censored patients. `final` is loaded with cohort=11, which
# maps only to 'MDS', so a 'Lung' filter here is always empty.
final.loc[(final['label'] == 'Censor') & (final['indicator'] == 'MDS'), 'P90'].describe()

In [None]:
def getF1(df, prob_col='P90', target_col='event_temporal', thresholds=range(0, 100)):
    """Precision, recall and F1 obtained by thresholding a predicted probability.

    For each integer threshold ``th`` (in percent) a row is predicted positive
    when ``df[prob_col] >= th / 100``; predictions are scored against
    ``df[target_col]``.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``prob_col`` and ``target_col`` (binary 0/1 target).
    prob_col : str, default 'P90'
        Column with the predicted probability.
    target_col : str, default 'event_temporal'
        Column with the actual class.
    thresholds : iterable of int, default range(0, 100)
        Thresholds expressed in percent.

    Returns
    -------
    pandas.DataFrame
        Columns ``th``, ``precision``, ``recall``, ``f1``; one row per threshold.
        Undefined scores (no predicted / no actual positives) are reported as 0.
    """
    # Extract once instead of re-slicing (and mutating a slice) per threshold.
    probs = df[prob_col].to_numpy()
    actual = df[target_col].to_numpy()

    rows = []
    for th in thresholds:
        pred = np.where(probs >= th / 100, 1, 0)
        rows.append({
            'th': th,
            'precision': precision_score(actual, pred, zero_division=0),
            'recall': recall_score(actual, pred, zero_division=0),
            'f1': f1_score(actual, pred, zero_division=0),
        })
    # Built once from records: DataFrame.append was removed in pandas 2.0.
    return pd.DataFrame(rows, columns=['th', 'precision', 'recall', 'f1'])

In [None]:
# Temporal-validation report on the 90-day probability (P90) per indicator:
# a quartile table collected into `out`, plus a 2x2 figure with the ROC curve,
# precision/recall/F1 vs threshold, the PRC, and the share of actual 90-day
# events falling in each P90 quartile.
quartile_tables = []
for cancer in ['MDS']:
    # .copy(): a column is added below; without it pandas may write into a view
    # of `final` (SettingWithCopyWarning, hidden by the global warnings filter).
    tmp = final[final.indicator == cancer].copy()
    # duplicates='drop' may yield fewer than 4 bins when P90 has tied bin edges.
    tmp['quantile'] = pd.qcut(tmp['P90'], q=4, labels=False, duplicates='drop')

    ## Ranges per quantile (built from records; DataFrame.append was removed in pandas 2.0)
    probs = pd.DataFrame([
        {'indicator': cancer,
         'quantile': qt,
         'min_prob': tmp.loc[tmp['quantile'] == qt, 'P90'].min(),
         'max_prob': tmp.loc[tmp['quantile'] == qt, 'P90'].max()}
        for qt in range(4)
    ])

    ## Actual 90-day events captured per quantile; quantiles without events are
    ## absent here and therefore dropped by the inner merge below.
    events = tmp[tmp.event_temporal == 1]
    quint = events.groupby(['indicator', 'quantile'], as_index=False).chai_patient_id.nunique()
    quint.columns = ['indicator', 'quantile', 'actual_quantile']
    quint['actual'] = events.chai_patient_id.nunique()
    quint['total'] = tmp.chai_patient_id.nunique()
    quint['share'] = quint.actual_quantile / quint.actual
    quint['pareto'] = quint['actual_quantile'].cumsum() / quint['actual_quantile'].sum()

    ## Merge; per-indicator tables are concatenated into `out` after the loop
    merged = probs.merge(quint, on=['indicator', 'quantile'])
    quartile_tables.append(merged)
    # assign() returns a new frame, so the collected table keeps no 'Q' column.
    mrg = merged.assign(
        Q='Q' + merged['quantile'].astype(str)
          + ': (' + round(merged.min_prob * 100, 2).astype(str)
          + '%, ' + round(merged.max_prob * 100, 2).astype(str) + '%)'
    )

    ## Precision / recall / F1 per threshold
    f1 = getF1(tmp)

    ## ROC curve and AUC
    auc = roc_auc_score(tmp['event_temporal'], tmp['P90'])
    fpr, tpr, _ = metrics.roc_curve(tmp['event_temporal'], tmp['P90'])

    ## Precision-recall curve
    precision, recall, thresholds = precision_recall_curve(tmp['event_temporal'], tmp['P90'], pos_label=1)

    ## Plot reports
    fig, ax = plt.subplots(2, 2, figsize=(18, 9))
    # NOTE(review): the title says "Pan Solid Model" but the file is saved as
    # "Liquid Model" — confirm which model this notebook validates.
    fig.suptitle('Temporal Validation of Pan Solid Model: '+cancer)

    ## Plot ROC
    ax[0,0].plot(fpr, tpr, label="data 1, auc="+str(round(auc,2)))
    ax[0,0].set_xlabel('FPR', fontsize=8)
    ax[0,0].set_ylabel('TPR', fontsize=8)
    ax[0,0].legend(loc=4)
    ax[0,0].set_title('AUC', fontsize=10)

    ## Plot F1/Precision/Recall w.r.t Threshold
    ax[0,1].plot(f1.th, f1.precision, label='Precision')
    ax[0,1].plot(f1.th, f1.recall, label='Recall')
    ax[0,1].plot(f1.th, f1.f1, label='F1')
    ax[0,1].set_xlabel('Threshold', fontsize=8)
    ax[0,1].set_ylabel('F1/Precision/Recall', fontsize=8)
    ax[0,1].set_title('F1/Precision/Recall', fontsize=10)
    ax[0,1].legend(loc=4)

    ## Plot PRC
    ax[1,0].plot(recall, precision)
    ax[1,0].set_xlabel('Recall', fontsize=8)
    ax[1,0].set_ylabel('Precision', fontsize=8)
    ax[1,0].set_title('PRC', fontsize=10)
    ax[1,0].axvline(0.5, linestyle='--')
    ax[1,0].axhline(0.5, linestyle='--')

    ## Plot Share of actual cases per quantile
    ax[1,1].bar(mrg.Q, mrg.share*100)
    ax[1,1].set_xlabel('Quantile of P90', fontsize=8)
    ax[1,1].set_ylabel('Share of Actual Cases (%)', fontsize=8)
    ax[1,1].set_title('Share of Actual Cases By Quantile')

    # Lay out after everything is drawn so labels and the suptitle are accounted for.
    fig.tight_layout(pad=2)
    fig.savefig(f'{path}/inference/Temporal Validation of Liquid Model: {cancer}.png')

out = pd.concat(quartile_tables, ignore_index=True).sort_values(['total', 'indicator'], ascending=False)

In [None]:
# Show only the size of the inference frame; rendering the full frame would put
# patient-level rows (including chai_patient_id) into the saved notebook output.
final.shape

In [None]:
def plot_stacked_bar_chart(df, location, cuts = 10, axes = None):
    """Draw stacked Event/Censor shares for ``cuts`` equal-size groups of ``P90``.

    Rows are ranked by ``P90`` (ties broken by row order, ``method='first'``) and
    split into ``cuts`` equal-count groups with ``pd.qcut``. Each bar shows, for
    one group, the fraction of unique ``chai_patient_id`` values with
    ``event_temporal == 1`` ("Event") and ``== 0`` ("Censor"). Every non-empty
    segment is annotated with its percentage.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs ``P90``, ``event_temporal`` (0/1) and ``chai_patient_id``.
        It is not modified.
    location : hashable
        Key (or index) of the target Axes inside ``axes``.
    cuts : int, default 10
        Number of equal-count groups (10 gives deciles).
    axes : dict or array of matplotlib Axes, optional
        Container indexed by ``location``. When omitted, the notebook-global
        ``ax`` is used, matching how this helper was originally called.
    """
    from matplotlib.ticker import PercentFormatter

    if axes is None:
        axes = ax  # legacy behaviour: the figure built in the calling cell
    target = axes[location]

    # Rank/group on a copy so the caller's frame (often a slice of `final`)
    # is not mutated.
    ranked = df.copy()
    ranked['rank'] = ranked['P90'].rank(method='first')
    ranked['decile'] = pd.qcut(ranked['rank'], q=cuts, labels=False)

    ## Unique patients per (group, outcome). reindex + fillna keep both outcome
    ## columns even when a group, or the whole frame, has no events.
    prc_d = (ranked.groupby(['decile', 'event_temporal'], as_index=False)
                   .chai_patient_id.nunique()
                   .pivot_table(values='chai_patient_id', index=['decile'],
                                columns='event_temporal')
                   .reindex(columns=[0, 1])
                   .fillna(0)
                   .reset_index())
    prc_d.columns = ['decile', 'Censor_abs', 'Event_abs']
    prc_d['total'] = prc_d[['Censor_abs', 'Event_abs']].sum(axis=1)
    prc_d['Censor'] = prc_d.Censor_abs / prc_d.total
    prc_d['Event'] = prc_d.Event_abs / prc_d.total
    shares = prc_d[['Event', 'Censor']]

    ## Stacked bars. Heights are fractions (0-1), so the y ticks are formatted as percent.
    shares.plot(kind='bar', stacked=True, ax=target, colormap='tab20c')
    target.legend(loc="upper left", ncol=2)
    target.set_xlabel("Decile" if cuts == 10 else f"{cuts}-quantile group")
    target.set_ylabel("Share of Patients (%)")
    target.yaxis.set_major_formatter(PercentFormatter(xmax=1))
    target.set_title('Share of Actual Cases By '+str(cuts)+' groups of Predicted Probability')

    ## Label each segment at its vertical midpoint
    for bar_pos, (_, row) in enumerate(shares.iterrows()):
        for proportion, top in zip(row, row.cumsum()):
            if proportion == 0:
                continue  # zero-height segment: nothing to label
            target.text(x=bar_pos - 0.05,
                        y=top - proportion / 2,
                        s=f'({np.round(proportion * 100, 1)}%)',
                        color="black",
                        rotation=90,
                        rotation_mode='anchor',
                        fontsize=10)

In [None]:
# Per-indicator figure: KDE of P90 by 90-day outcome, the PRC, and stacked
# Event/Censor shares for 10 and 20 equal-size P90 groups.
for cancer in ['MDS']:
    # Independent copy: columns added to this frame downstream must not touch `final`.
    tmp = final[final.indicator == cancer].copy()
    n = tmp.shape[0]

    ## Plot. constrained_layout handles the spacing; an extra fig.tight_layout()
    ## would conflict with it.
    fig, ax = plt.subplot_mosaic([['left', 'right'], ['middle', 'middle'], ['bottom', 'bottom']],
                                 constrained_layout=True, figsize=(15, 15))
    # NOTE(review): the title says "Pan Solid Model" but the file is saved as
    # "Liquid Model" — confirm which model this notebook validates.
    fig.suptitle('Temporal Validation of Pan Solid Model: '+cancer+', N = '+str(n))

    ## Plot KDE plot (each outcome normalised on its own: common_norm=False)
    sns.kdeplot(data        = tmp,
                x           = 'P90',
                hue         = 'label_temporal',
                fill        = True,
                common_norm = False,
                alpha       = 0.5,
                ax          = ax['left']
               )
    ax['left'].set_xlabel('Predicted 90 Day Probability of Patient Availability', fontsize=8)
    ax['left'].set_ylabel('Probability Density', fontsize=8)
    # Relocate seaborn's hue legend. A bare ax.legend() rebuilds the legend from
    # labelled artists, and the KDE fills carry none, so the legend would end up empty.
    sns.move_legend(ax['left'], loc='lower right')
    ax['left'].set_title('90 day Predicted Probability', fontsize=10)

    ## Plot PRC, labelled with its average precision (AP)
    precision, recall, thresholds = precision_recall_curve(tmp['event_temporal'], tmp['P90'])
    ap = metrics.average_precision_score(tmp['event_temporal'], tmp['P90'])
    ax['right'].plot(recall, precision, label="AP="+str(round(ap, 2)))
    ax['right'].legend(loc=4)
    ax['right'].set_xlabel('Recall', fontsize=8)
    ax['right'].set_ylabel('Precision', fontsize=8)
    ax['right'].set_title('PRC', fontsize=10)
    ax['right'].axvline(0.5, linestyle='--')
    ax['right'].axhline(0.5, linestyle='--')

    ## Precision/Recall by Decile
    plot_stacked_bar_chart(tmp, cuts=10, location='middle')

    ## Precision / Recall by 5th percentile
    plot_stacked_bar_chart(tmp, cuts=20, location='bottom')
    fig.savefig(f'{path}/inference/Temporal Validation of Liquid Model Decile: {cancer}.png')
    plt.show()

In [None]:
# Overlaid P90 histograms by 90-day outcome for `tmp` (the cohort from the cell above).
# Semi-transparent and labelled so the Event bars do not hide the Censor bars.
fig_hist, ax_hist = plt.subplots()
for outcome, outcome_label in [(0, 'Censor'), (1, 'Event')]:
    tmp.loc[tmp['event_temporal'] == outcome, 'P90'].hist(
        bins=100, alpha=0.5, label=outcome_label, ax=ax_hist)
ax_hist.set_xlabel('P90')
ax_hist.set_ylabel('Count')
ax_hist.legend();

In [None]:
# Row counts per overall `event` value for `tmp` (the cohort selected above).
tmp.event.value_counts()

In [None]:
# Row counts per 90-day `event_temporal` value for `tmp` (the cohort selected above).
tmp.event_temporal.value_counts()