# Flatiron Health mPC: Survival metrics for appropriate chemo dosing 
**Background: Calculate survival metrics for emulated trials involving patients who receive appropriate upfront dosing of chemotherapeutics. Hazard ratio for the full cohort is calculated from a Cox-IPTW model. Restricted mean survival time and median overall survival are calculated for phenotypes using an IPTW-adjusted KM curve.**

## Part 1: Preprocessing

### 1.1 Import packages and create necessary functions

In [1]:
import numpy as np
import pandas as pd

from scipy import stats

from sksurv.nonparametric import kaplan_meier_estimator
from survive import KaplanMeier, SurvivalData

from lifelines import KaplanMeierFitter, CoxPHFitter
from lifelines.plotting import add_at_risk_counts
from lifelines.utils import median_survival_times, restricted_mean_survival_time
from lifelines.statistics import logrank_test

from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer 
from sklearn.linear_model import LogisticRegression
from sklearn.utils import resample

import warnings

In [2]:
# Function that returns number of rows and count of unique PatientIDs for a dataframe. 
def row_ID(dataframe):
    row = dataframe.shape[0]
    ID = dataframe['PatientID'].nunique()
    return row, ID

In [3]:
# Find index for value closest to input value. 
def find_nearest(array, value):
    array = np.asarray(array)
    idx = (np.abs(array - value)).argmin()
    return array[idx]

In [4]:
# Calculates median overeall suvival for risk groups. 
def mos(low, med, high, comp):
    low_os = low.median_survival_time_
    med_os = med.median_survival_time_
    high_os = high.median_survival_time_
    comp_os = comp.median_survival_time_
    mos = [low_os, med_os, high_os, comp_os]
    return (mos)

In [5]:
def rmst_mos_95ci(df, num_samples, drug, event, items_list, numerical_features, rmst_time):
    
    """
    Estimate the 95% confidence interval for RMST and mOS using bootstrap resampling.

    Parameters:
    - df: DataFrame containing survival data
    - num_samples: Number of bootstrap samples
    - drug: Treatment indicator variable
    - event: Event type ('death' or 'progression')
    - items_list: Feature list for IPTW 
    - numerical_features: List of numerical features
    - rmst_time: Time to calculate RMST 

    Returns:
    - mos_A_95: mOS 95% CI for treatment
    - mos_B_95: mOS 95% CI for control
    - rmst_A_95: RMST 95% CI for treatment
    - rmst_B_95: RMST 95% CI for control
    - difference_rmst_95: RMST 95% CI for difference between treatment and control 
    """
    
    np.random.seed(42)
    mos_A = []
    mos_B = []
    rmst_A_list = []
    rmst_B_list = []
    differences_rmst = []
    
    # Define variables based on the event type
    if event == 'death':
        time_column = 'timerisk_treatment'
        status_column = 'death_status'
        
    else:
        time_column = 'time_prog_treatment'
        status_column = 'pfs_status'
        
    # Set up preprocessor for logistical regression which will be for IPTW  
    numerical_transformer = Pipeline(steps = [
        ('imputer', SimpleImputer(strategy = 'median')),
        ('std_scaler', StandardScaler())])
        
    categorical_transformer = OneHotEncoder(handle_unknown = 'ignore')
    categorical_features = list(df.select_dtypes(include = ['category']).columns)
        
    preprocessor = ColumnTransformer(
        transformers = [
            ('num', numerical_transformer, numerical_features),
            ('cat', categorical_transformer, categorical_features)],
        remainder = 'passthrough')
    
    # Boostrap HR 
    for _ in range(num_samples):
        
        # Resample data with replacement
        resampled_df = resample(df).drop(columns = ['ps', 'weight'])
        
        # Calculated IPTW for the resampled group 
        df_x = preprocessor.fit_transform(resampled_df.filter(items = items_list))
                                           
        df_lr = LogisticRegression(max_iter = 1000)
        df_lr.fit(df_x, resampled_df[drug])
        warnings.filterwarnings("ignore")
                                           
        pred = df_lr.predict_proba(df_x)        
        resampled_df['ps'] = pred[:, 1]                          
        resampled_df['weight'] = (
                np.where(resampled_df[drug] == 1, 1/resampled_df['ps'], 1/(1 - resampled_df['ps'])))
    
        # mOS from IPTW-KM
        kmf_A = KaplanMeierFitter()
        kmf_A.fit(resampled_df.query(f'{drug} == 1')[time_column]/30,
                  resampled_df.query(f'{drug} == 1')[status_column], 
                  weights = resampled_df.query(f'{drug} == 1')['weight'])

        kmf_B = KaplanMeierFitter()
        kmf_B.fit(resampled_df.query(f'{drug} == 0')[time_column]/30,
                  resampled_df.query(f'{drug} == 0')[status_column], 
                  weights = resampled_df.query(f'{drug} == 0')['weight'])
    
        mos_A.append(kmf_A.median_survival_time_)
        mos_B.append(kmf_B.median_survival_time_)
        
        # RMST from IPTW-KM
        rmst_A = restricted_mean_survival_time(kmf_A, rmst_time)
        rmst_B = restricted_mean_survival_time(kmf_B, rmst_time)
        
        rmst_A_list.append(rmst_A)
        rmst_B_list.append(rmst_B)
        differences_rmst.append(rmst_A - rmst_B)

    # Calculate the 95% confidence interval
    results = pd.Series({
    'mos_A_95': np.percentile(mos_A, [2.5, 97.5]),
    'mos_B_95': np.percentile(mos_B, [2.5, 97.5]),
    'rmst_A_95': np.percentile(rmst_A_list, [2.5, 97.5]),
    'rmst_B_95': np.percentile(rmst_B_list, [2.5, 97.5]),
    'difference_rmst_95': np.percentile(differences_rmst, [2.5, 97.5])
    })
    
    return results

In [6]:
cutoff = pd.read_csv('risk_cutoff_prostate.csv', index_col = 0)

## Part 2: In silico trials 

### CHAARTED: docetaxel vs. ADT in metastatic, castration-sensitive prostate cancer  

**INCLUSION**
* Untreated metastatic prostate cancer, except up to 4 months of ADT 
* Castration-sensitive
* Received ADT or docetaxel plus ADT
* Received correct dose of docetaxel

#### ADT

In [7]:
df_full = pd.read_csv('df_risk_crude.csv', index_col = 'PatientID', dtype = {'death_status': bool})
df_full.index.nunique()

18927

In [8]:
df_full.reset_index(inplace = True)

In [9]:
adt = pd.read_csv('Enhanced_MetPC_ADT.csv')

In [10]:
adt = (
    adt[adt['PatientID'].isin(df_full['PatientID'])]
    .query('TreatmentSetting == "Advanced"')
)

In [11]:
row_ID(adt)

(17863, 17863)

In [12]:
adt.loc[:, 'StartDate'] = pd.to_datetime(adt['StartDate'])

In [13]:
adt = adt.rename(columns = {'StartDate': 'StartDate_adt'})

In [14]:
df_full = pd.merge(df_full, adt[['PatientID', 'StartDate_adt']], on = 'PatientID', how = 'left')

In [15]:
row_ID(df_full)

(18927, 18927)

In [16]:
enhanced_met = pd.read_csv('Enhanced_MetProstate.csv')

In [17]:
enhanced_met = enhanced_met[enhanced_met['PatientID'].isin(df_full['PatientID'])]

In [18]:
enhanced_met.loc[:, 'MetDiagnosisDate'] = pd.to_datetime(enhanced_met['MetDiagnosisDate'])

In [19]:
enhanced_met.loc[:, 'CRPCDate'] = pd.to_datetime(enhanced_met['CRPCDate'])

In [20]:
df_full = pd.merge(df_full, enhanced_met[['PatientID', 'MetDiagnosisDate', 'CRPCDate']], on = 'PatientID', how = 'left')

In [21]:
row_ID(df_full)

(18927, 18927)

In [22]:
# Find all that start ADT within -120 to +90 days of metastatic diagnosis 
chaarted_adt = (
    df_full
    .assign(adt_diff = (df_full['StartDate_adt'] - df_full['MetDiagnosisDate']).dt.days)
    .query('adt_diff >= -120 and adt_diff <= 90')
)

In [23]:
# Find all that have missing CPRC date or date is >90 after metastatic diagnosis 
chaarted_adt = (
    chaarted_adt
    .assign(crpc_diff = (chaarted_adt['CRPCDate'] - chaarted_adt['MetDiagnosisDate']).dt.days)
    .query('crpc_diff > 90 or CRPCDate.isna()', engine = 'python')
)

In [24]:
row_ID(chaarted_adt)

(10475, 10475)

In [25]:
line_therapy = pd.read_csv('LineOfTherapy.csv')

In [26]:
zero = (
    line_therapy.query('LineNumber == 0')
    .PatientID
)

In [27]:
# Exclude patients with missing treatment information (ie, LineNumber == 0)
chaarted_adt = chaarted_adt[~chaarted_adt.PatientID.isin(zero)]

In [28]:
row_ID(chaarted_adt)

(10117, 10117)

In [29]:
line_therapy_cont = line_therapy.query('LineSetting != "nmCRPC"')

In [30]:
# List of FDA approved drugs for mPC as of July 2023. Clinical study drug is also included. 
fda_yes = [
    'Abiraterone',
    'Apalutamide',
    'Cabazitaxel',
    'Carboplatin',
    'Cisplatin',
    'Darolutamide',
    'Docetaxel',
    'Enzalutamide',
    'Mitoxantrone',
    'Olaparib',
    'Oxaliplatin',
    'Paclitaxel',
    'Pembrolizumab',
    'Radium-223',
    'Rucaparib',
    'Sipuleucel-T',
    'Clinical Study Drug'
]

In [31]:
line_therapy_cont = line_therapy_cont[line_therapy_cont['LineName'].str.contains('|'.join(fda_yes))]

In [32]:
line_therapy_cont = (
    line_therapy_cont
    .sort_values(by = ['PatientID', 'StartDate'], ascending = [True, True])
)

In [33]:
line_therapy_cont['line_number'] = (
    line_therapy_cont.groupby('PatientID')['LineNumber'].cumcount()+1
)

In [34]:
# First line therapy is in castrate-resistant setting 
fl_crpc = (
    line_therapy_cont[line_therapy_cont.PatientID.isin(chaarted_adt.PatientID)]
    .query('line_number == 1 & LineSetting == "mCRPC"')
    .PatientID
)

In [35]:
# Never received therapy other than ADT
notrt_adt = (
    chaarted_adt[~chaarted_adt.PatientID.isin(line_therapy_cont.PatientID)]
    .PatientID
)

In [36]:
adt_IDs = np.concatenate((fl_crpc, notrt_adt))

In [37]:
len(adt_IDs)

6218

In [38]:
chaarted_adt = chaarted_adt[chaarted_adt.PatientID.isin(adt_IDs)]

In [39]:
chaarted_adt.loc[:,'adt_dotx'] = 0

In [40]:
row_ID(chaarted_adt)

(6218, 6218)

In [41]:
chaarted_adt.sample(3)

Unnamed: 0,PatientID,Gender,race,ethnicity,age,p_type,NStage,MStage,Histology,GleasonScore,...,other_met,prim_treatment,early_adt,risk_score,StartDate_adt,MetDiagnosisDate,CRPCDate,adt_diff,crpc_diff,adt_dotx
1851,F152923E39963,M,white,unknown,81,COMMUNITY,N1,M1c,Adenocarcinoma,4 + 3 = 7,...,0.0,unknown,0.0,0.744153,2018-03-01,2018-02-26,NaT,3.0,,0
16592,F18E3F4AF30E3,M,other,unknown,71,BOTH,N0,M0,Adenocarcinoma,4 + 3 = 7,...,0.0,radiation,0.0,-1.424172,2020-03-04,2020-02-12,NaT,21.0,,0
15592,F7963F446BE53,M,white,unknown,77,COMMUNITY,Unknown / Not documented,M1,Adenocarcinoma,8,...,0.0,unknown,0.0,-0.272794,2014-11-04,2014-11-04,2016-03-31,0.0,513.0,0


#### Docetaxel + ADT

In [42]:
# Find those that start ADT within -120 to 90 days of metastaic diagnosis 
chaarted_dotx = (
    df_full
    .assign(adt_diff = (df_full['StartDate_adt'] - df_full['MetDiagnosisDate']).dt.days)
    .query('adt_diff >= -120 and adt_diff <= 90')
)

In [43]:
# Find all that have missing CPRC date or date is >90 days metastatic diagnosis 
chaarted_dotx = (
    chaarted_dotx
    .assign(crpc_diff = (chaarted_dotx['CRPCDate'] - chaarted_dotx['MetDiagnosisDate']).dt.days)
    .query('crpc_diff > 90 or CRPCDate.isna()', engine = 'python')
)

In [44]:
row_ID(chaarted_dotx)

(10475, 10475)

In [45]:
# Find start time of first line of mHSPC therapy. 
line_therapy_fl = (
    line_therapy[line_therapy['PatientID'].isin(chaarted_dotx['PatientID'])]
    .query('LineSetting == "mHSPC"')
    .sort_values(by = ['PatientID', 'StartDate'], ascending = [True, True])
    .drop_duplicates(subset = ['PatientID'], keep = 'first')
    .rename(columns = {'StartDate': 'StartDate_dotx'})
)

In [46]:
row_ID(line_therapy_fl)

(4042, 4042)

In [47]:
line_therapy_fl.loc[:, 'StartDate_dotx'] = pd.to_datetime(line_therapy_fl['StartDate_dotx'])

In [48]:
line_therapy_fl[line_therapy_fl['LineName'].str.contains('Docetaxel')].LineName.value_counts().head(10)

Docetaxel                                         1363
Abiraterone,Docetaxel                               27
Carboplatin,Docetaxel                               12
Clinical Study Drug,Docetaxel                       11
Darolutamide,Docetaxel                               9
Docetaxel,Enzalutamide                               6
Carboplatin,Docetaxel,Estramustine                   5
Docetaxel,Ketoconazole                               3
Apalutamide,Docetaxel                                2
Carboplatin,Docetaxel,Estramustine,Thalidomide       2
Name: LineName, dtype: int64

In [49]:
line_dotx = line_therapy_fl.query('LineName == "Docetaxel"')

In [50]:
row_ID(line_dotx)

(1363, 1363)

In [51]:
chaarted_dotx = pd.merge(chaarted_dotx, line_dotx[['PatientID', 'StartDate_dotx']], on = 'PatientID', how = 'left')

In [52]:
row_ID(chaarted_dotx)

(10475, 10475)

In [53]:
# Find all patients that start enzalutamide within 180 days of metastatic diagnosis
chaarted_dotx = (
    chaarted_dotx
    .assign(fl_diff = (chaarted_dotx['StartDate_dotx'] - chaarted_dotx['StartDate_adt']).dt.days)
    .query('fl_diff >= -90 and fl_diff <= 90')
)

In [54]:
len(chaarted_dotx)

1179

In [55]:
chaarted_dotx.sample(3)

Unnamed: 0,PatientID,Gender,race,ethnicity,age,p_type,NStage,MStage,Histology,GleasonScore,...,prim_treatment,early_adt,risk_score,StartDate_adt,MetDiagnosisDate,CRPCDate,adt_diff,crpc_diff,StartDate_dotx,fl_diff
8612,FFDE777AC9676,M,unknown,unknown,68,COMMUNITY,Unknown / Not documented,M1,Adenocarcinoma,9,...,unknown,0.0,-0.609793,2021-02-22,2021-01-21,NaT,32.0,,2021-03-18,24.0
6245,F8A9F738A55B2,M,white,unknown,81,ACADEMIC,NX,M1,Adenocarcinoma,Unknown / Not documented,...,unknown,0.0,0.683789,2020-03-04,2020-02-13,NaT,20.0,,2020-03-11,7.0
10193,F11302A8DC709,M,other,unknown,71,COMMUNITY,N0,M1b,"Prostate cancer, NOS",Unknown / Not documented,...,unknown,0.0,0.515485,2019-02-27,2019-02-01,2019-08-07,26.0,187.0,2019-04-10,42.0


In [56]:
chaarted_dotx.loc[:,'adt_dotx'] = 1

In [57]:
chaarted = pd.concat([chaarted_adt, chaarted_dotx], ignore_index = True)

In [58]:
row_ID(chaarted)

(7397, 7397)

In [59]:
chaarted.adt_dotx.value_counts(dropna = False)

0    6218
1    1179
Name: adt_dotx, dtype: int64

In [60]:
chaarted.sample(3)

Unnamed: 0,PatientID,Gender,race,ethnicity,age,p_type,NStage,MStage,Histology,GleasonScore,...,early_adt,risk_score,StartDate_adt,MetDiagnosisDate,CRPCDate,adt_diff,crpc_diff,adt_dotx,StartDate_dotx,fl_diff
5724,F95405E6D9132,M,white,unknown,77,ACADEMIC,Unknown / Not documented,M1,Adenocarcinoma,9,...,0.0,-0.284602,2019-05-02,2019-04-10,2021-06-04,22.0,786.0,0,NaT,
2289,F90855E5D12B4,M,white,unknown,67,COMMUNITY,Unknown / Not documented,M1,Adenocarcinoma,9,...,0.0,-0.476806,2019-01-23,2019-01-15,2019-07-25,8.0,191.0,0,NaT,
2842,F3F342A0D9882,M,white,unknown,78,COMMUNITY,Unknown / Not documented,M0,Adenocarcinoma,8,...,0.0,-0.127889,2015-02-19,2015-01-06,NaT,44.0,,0,NaT,


#### Docetaxel dosing

In [61]:
med_order = pd.read_csv('MedicationOrder.csv', low_memory = False)

In [62]:
med_order['ExpectedStartDate'] = np.where(med_order['ExpectedStartDate'].isna(), 
                                          med_order['OrderedDate'], 
                                          med_order['ExpectedStartDate'])

In [63]:
med_order.loc[:, 'ExpectedStartDate'] = pd.to_datetime(med_order['ExpectedStartDate'])

In [64]:
chaarted_dotx['StartDate_dotx'] = pd.to_datetime(chaarted_dotx['StartDate_dotx'])

In [65]:
med_order_dotx = (
    med_order[med_order['PatientID'].isin(chaarted_dotx.PatientID)]
    .query('CommonDrugName == "docetaxel"')
)

In [66]:
med_order_dotx.shape

(9407, 18)

In [67]:
med_order_dotx = pd.merge(med_order_dotx, 
                          chaarted_dotx[['PatientID', 'StartDate_dotx']], 
                          on = 'PatientID', 
                          how = 'left')

In [68]:
med_order_dotx.shape

(9407, 19)

In [69]:
med_order_dotx.loc[:, 'date_diff'] = (med_order_dotx['ExpectedStartDate'] - med_order_dotx['StartDate_dotx']).dt.days.abs()

In [70]:
med_order_dotx = med_order_dotx.query('date_diff <= 14')

In [71]:
dotx_index = med_order_dotx.groupby('PatientID')['date_diff'].idxmin()

In [72]:
dotx_dose = med_order_dotx.loc[dotx_index].query('RelativeOrderedUnits == "mg/m2"')[['PatientID', 'RelativeOrderedAmount']]

In [73]:
dotx_dose = dotx_dose.rename(columns = {'RelativeOrderedAmount': 'dotx_dose_mgm2'})

In [74]:
dotx_IDs = dotx_dose.query('dotx_dose_mgm2 >= 75').PatientID

In [75]:
chaarted_dotx = chaarted_dotx[chaarted_dotx['PatientID'].isin(dotx_IDs)]

In [76]:
chaarted = pd.concat([chaarted_adt, chaarted_dotx])

In [77]:
row_ID(chaarted)

(7166, 7166)

#### Time from ADT treatment to death or censor 

In [78]:
mortality_tr = pd.read_csv('mortality_cleaned_tr.csv')

In [79]:
mortality_te = pd.read_csv('mortality_cleaned_te.csv')

In [80]:
mortality_tr = mortality_tr[['PatientID', 'death_date', 'last_activity']]

In [81]:
mortality_te = mortality_te[['PatientID', 'death_date', 'last_activity']]

In [82]:
mortality = pd.concat([mortality_tr, mortality_te], ignore_index = True)
row_ID(mortality)

(18927, 18927)

In [83]:
mortality.loc[:, 'last_activity'] = pd.to_datetime(mortality['last_activity'])

In [84]:
mortality.loc[:, 'death_date'] = pd.to_datetime(mortality['death_date'])

In [85]:
len(mortality)

18927

In [86]:
chaarted = pd.merge(chaarted, mortality, on = 'PatientID', how = 'left')

In [87]:
row_ID(chaarted)

(7166, 7166)

In [88]:
conditions = [
    (chaarted['death_status'] == 1),
    (chaarted['death_status'] == 0)]

choices = [
    (chaarted['death_date'] - chaarted['StartDate_adt']).dt.days,
    (chaarted['last_activity'] - chaarted['StartDate_adt']).dt.days]

chaarted.loc[:, 'timerisk_treatment'] = np.select(conditions, choices)

chaarted = chaarted.query('timerisk_treatment >= 0')

#### Patient count 

In [89]:
low_cutoff_chaarted = cutoff.loc['chaarted'].low

In [90]:
high_cutoff_chaarted = cutoff.loc['chaarted'].high

In [91]:
print('Docetaxel + ADT:',  chaarted.query('adt_dotx == 1').shape[0])
print('High risk:', chaarted.query('adt_dotx == 1').query('risk_score >= @high_cutoff_chaarted').shape[0])
print('Med risk:', chaarted.query('adt_dotx == 1').query('risk_score < @high_cutoff_chaarted and risk_score > @low_cutoff_chaarted').shape[0])
print('Low risk:', chaarted.query('adt_dotx == 1').query('risk_score <= @low_cutoff_chaarted').shape[0])

Docetaxel + ADT: 948
High risk: 241
Med risk: 346
Low risk: 361


In [92]:
print('ADT:',  chaarted.query('adt_dotx == 0').shape[0])
print('High risk:', chaarted.query('adt_dotx == 0').query('risk_score >= @high_cutoff_chaarted').shape[0])
print('Med risk:', chaarted.query('adt_dotx == 0').query('risk_score < @high_cutoff_chaarted and risk_score > @low_cutoff_chaarted').shape[0])
print('Low risk:', chaarted.query('adt_dotx == 0').query('risk_score <= @low_cutoff_chaarted').shape[0])

ADT: 6218
High risk: 2150
Med risk: 2041
Low risk: 2027


#### Survival curves with covariate balancing 

In [93]:
chaarted = chaarted.set_index('PatientID')

In [94]:
chaarted['met_cat'] = pd.cut(chaarted['met_year'],
                             bins = [2010, 2015, float('inf')],
                             labels = ['11-15', '16-22'])

In [95]:
conditions = [
    ((chaarted['ecog_diagnosis'] == "1.0") | (chaarted['ecog_diagnosis'] == "0.0")),  
    ((chaarted['ecog_diagnosis'] == "2.0") | (chaarted['ecog_diagnosis'] == "3.0"))
]

choices = ['lt_2', 'gte_2']

chaarted['ecog_2'] = np.select(conditions, choices, default = 'unknown')

In [96]:
chaarted_iptw = chaarted.filter(items = ['death_status',
                                         'timerisk_treatment',
                                         'adt_dotx',
                                         'age',
                                         'race',
                                         'p_type',
                                         'met_cat',
                                         'delta_met_diagnosis',
                                         'commercial',
                                         'medicare',
                                         'medicaid',
                                         'ecog_2',
                                         'prim_treatment',
                                         'PSAMetDiagnosis',
                                         'albumin_diag', 
                                         'weight_pct_change',
                                         'risk_score'])

In [97]:
chaarted_iptw.dtypes

death_status               bool
timerisk_treatment      float64
adt_dotx                  int64
age                       int64
race                     object
p_type                   object
met_cat                category
delta_met_diagnosis       int64
commercial              float64
medicare                float64
medicaid                float64
ecog_2                   object
prim_treatment           object
PSAMetDiagnosis         float64
albumin_diag            float64
weight_pct_change       float64
risk_score              float64
dtype: object

In [98]:
to_be_categorical = list(chaarted_iptw.select_dtypes(include = ['object']).columns)

In [99]:
to_be_categorical

['race', 'p_type', 'ecog_2', 'prim_treatment']

In [100]:
to_be_categorical.append('met_cat')

In [101]:
# Convert variables in list to categorical.
for x in list(to_be_categorical):
    chaarted_iptw[x] = chaarted_iptw[x].astype('category')

In [102]:
# List of numeric variables, excluding binary variables. 
numerical_features = ['age', 'delta_met_diagnosis', 'PSAMetDiagnosis', 'albumin_diag', 'weight_pct_change', 'risk_score']

# Transformer will first calculate column median and impute, and then apply a standard scaler. 
numerical_transformer = Pipeline(steps = [
    ('imputer', SimpleImputer(strategy = 'median')),
    ('std_scaler', StandardScaler())])

In [103]:
# List of categorical features.
categorical_features = list(chaarted_iptw.select_dtypes(include = ['category']).columns)

# One-hot-encode categorical features.
categorical_transformer = OneHotEncoder(handle_unknown = 'ignore')

In [104]:
preprocessor = ColumnTransformer(
    transformers = [
        ('num', numerical_transformer, numerical_features),
        ('cat', categorical_transformer, categorical_features)],
    remainder = 'passthrough')

In [105]:
chaarted_iptw_low = (
    chaarted_iptw
    .query('risk_score <= @low_cutoff_chaarted'))

chaarted_iptw_med = (
    chaarted_iptw
    .query('risk_score < @high_cutoff_chaarted and risk_score > @low_cutoff_chaarted'))

chaarted_iptw_high = (
    chaarted_iptw
    .query('risk_score >= @high_cutoff_chaarted'))

chaarted_iptw_all = chaarted_iptw

In [106]:
chaarted_low_x = preprocessor.fit_transform(chaarted_iptw_low.filter(items = ['age',
                                                                              'race',
                                                                              'p_type',
                                                                              'delta_met_diagnosis',
                                                                              'met_cat',
                                                                              'commercial',
                                                                              'medicare',
                                                                              'medicaid',
                                                                              'ecog_2',
                                                                              'prim_treatment',
                                                                              'PSAMetDiagnosis', 
                                                                              'albumin_diag', 
                                                                              'weight_pct_change',
                                                                              'risk_score']))

chaarted_med_x = preprocessor.fit_transform(chaarted_iptw_med.filter(items = ['age',
                                                                              'race',
                                                                              'p_type',
                                                                              'delta_met_diagnosis',
                                                                              'met_cat',
                                                                              'commercial',
                                                                              'medicare',
                                                                              'medicaid',
                                                                              'ecog_2',
                                                                              'prim_treatment',
                                                                              'PSAMetDiagnosis', 
                                                                              'albumin_diag', 
                                                                              'weight_pct_change', 
                                                                              'risk_score']))

chaarted_high_x = preprocessor.fit_transform(chaarted_iptw_high.filter(items = ['age',
                                                                                'race',
                                                                                'p_type',
                                                                                'delta_met_diagnosis',
                                                                                'met_cat',
                                                                                'commercial',
                                                                                'medicare',
                                                                                'medicaid',
                                                                                'ecog_2',
                                                                                'prim_treatment',
                                                                                'PSAMetDiagnosis', 
                                                                                'albumin_diag', 
                                                                                'weight_pct_change',
                                                                                'risk_score']))

chaarted_all_x = preprocessor.fit_transform(chaarted_iptw_all.filter(items = ['age',
                                                                              'race',
                                                                              'p_type',
                                                                              'delta_met_diagnosis',
                                                                              'met_cat',
                                                                              'commercial',
                                                                              'medicare',
                                                                              'medicaid',
                                                                              'ecog_2',
                                                                              'prim_treatment',
                                                                              'PSAMetDiagnosis', 
                                                                              'albumin_diag', 
                                                                              'weight_pct_change', 
                                                                              'risk_score']))

In [107]:
lr_chaarted_low = LogisticRegression(max_iter = 1000)
lr_chaarted_low.fit(chaarted_low_x, chaarted_iptw_low['adt_dotx'])

LogisticRegression(max_iter=1000)

In [108]:
lr_chaarted_med = LogisticRegression(max_iter = 1000)
lr_chaarted_med.fit(chaarted_med_x, chaarted_iptw_med['adt_dotx'])

LogisticRegression(max_iter=1000)

In [109]:
lr_chaarted_high = LogisticRegression(max_iter = 1000)
lr_chaarted_high.fit(chaarted_high_x, chaarted_iptw_high['adt_dotx'])

LogisticRegression(max_iter=1000)

In [110]:
lr_chaarted_all = LogisticRegression(max_iter = 1000)
lr_chaarted_all.fit(chaarted_all_x, chaarted_iptw_all['adt_dotx'])

LogisticRegression(max_iter=1000)

In [111]:
pred_low = lr_chaarted_low.predict_proba(chaarted_low_x)
pred_med = lr_chaarted_med.predict_proba(chaarted_med_x)
pred_high = lr_chaarted_high.predict_proba(chaarted_high_x)
pred_all = lr_chaarted_all.predict_proba(chaarted_all_x)

In [112]:
chaarted_iptw_low['ps'] = pred_low[:, 1]
chaarted_iptw_med['ps'] = pred_med[:, 1]
chaarted_iptw_high['ps'] = pred_high[:, 1]
chaarted_iptw_all['ps'] = pred_all[:, 1]

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  """Entry point for launching an IPython kernel.
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  This is separate from the ipykernel package so we can avoid doing imports until


In [113]:
chaarted_iptw_low['weight'] = (
    np.where(chaarted_iptw_low['adt_dotx'] == 1, 1/chaarted_iptw_low['ps'], 1/(1 - chaarted_iptw_low['ps'])))

chaarted_iptw_med['weight'] = (
    np.where(chaarted_iptw_med['adt_dotx'] == 1, 1/chaarted_iptw_med['ps'], 1/(1 - chaarted_iptw_med['ps'])))

chaarted_iptw_high['weight'] = (
    np.where(chaarted_iptw_high['adt_dotx'] == 1, 1/chaarted_iptw_high['ps'], 1/(1 - chaarted_iptw_high['ps'])))

chaarted_iptw_all['weight'] = (
    np.where(chaarted_iptw_all['adt_dotx'] == 1, 1/chaarted_iptw_all['ps'], 1/(1 - chaarted_iptw_all['ps'])))

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  """
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  


In [114]:
# Low KM curves
kmf_low_dotx_chaarted_iptw = KaplanMeierFitter()
kmf_low_adt_chaarted_iptw = KaplanMeierFitter()

kmf_low_dotx_chaarted_iptw.fit(
    chaarted_iptw_low.query('adt_dotx == 1').timerisk_treatment/30,
    chaarted_iptw_low.query('adt_dotx == 1').death_status,
    weights = chaarted_iptw_low.query('adt_dotx == 1')['weight'])

kmf_low_adt_chaarted_iptw.fit(
    chaarted_iptw_low.query('adt_dotx == 0').timerisk_treatment/30,
    chaarted_iptw_low.query('adt_dotx == 0').death_status,
    weights = chaarted_iptw_low.query('adt_dotx == 0')['weight'])

# Med KM curves
kmf_med_dotx_chaarted_iptw = KaplanMeierFitter()
kmf_med_adt_chaarted_iptw = KaplanMeierFitter()

kmf_med_dotx_chaarted_iptw.fit(
    chaarted_iptw_med.query('adt_dotx == 1').timerisk_treatment/30,
    chaarted_iptw_med.query('adt_dotx == 1').death_status,
    weights = chaarted_iptw_med.query('adt_dotx == 1')['weight'])

kmf_med_adt_chaarted_iptw.fit(
    chaarted_iptw_med.query('adt_dotx == 0').timerisk_treatment/30,
    chaarted_iptw_med.query('adt_dotx == 0').death_status,
    weights = chaarted_iptw_med.query('adt_dotx == 0')['weight'])

# High KM curves 
kmf_high_dotx_chaarted_iptw = KaplanMeierFitter()
kmf_high_adt_chaarted_iptw = KaplanMeierFitter()

kmf_high_dotx_chaarted_iptw.fit(
    chaarted_iptw_high.query('adt_dotx == 1').timerisk_treatment/30,
    chaarted_iptw_high.query('adt_dotx == 1').death_status,
    weights = chaarted_iptw_high.query('adt_dotx == 1')['weight'])

kmf_high_adt_chaarted_iptw.fit(
    chaarted_iptw_high.query('adt_dotx == 0').timerisk_treatment/30,
    chaarted_iptw_high.query('adt_dotx == 0').death_status,
    weights = chaarted_iptw_high.query('adt_dotx == 0')['weight'])

# All KM curves 
kmf_all_dotx_chaarted_iptw = KaplanMeierFitter()
kmf_all_adt_chaarted_iptw = KaplanMeierFitter()

kmf_all_dotx_chaarted_iptw.fit(
    chaarted_iptw_all.query('adt_dotx == 1').timerisk_treatment/30,
    chaarted_iptw_all.query('adt_dotx == 1').death_status,
    weights = chaarted_iptw_all.query('adt_dotx == 1')['weight'])

kmf_all_adt_chaarted_iptw.fit(
    chaarted_iptw_all.query('adt_dotx == 0').timerisk_treatment/30,
    chaarted_iptw_all.query('adt_dotx == 0').death_status,
    weights = chaarted_iptw_all.query('adt_dotx == 0')['weight'])

  It's important to know that the naive variance estimates of the coefficients are biased. Instead use Monte Carlo to
  estimate the variances. See paper "Variance estimation when using inverse probability of treatment weighting (IPTW) with survival analysis"
  or "Adjusted Kaplan-Meier estimator and log-rank test with inverse probability of treatment weighting for survival data."
                  


<lifelines.KaplanMeierFitter:"KM_estimate", fitted with 7182.66 total observations, 3233.44 right-censored observations>

#### Calculating survival metrics 

In [115]:
dotx_chaarted_median_os = mos(kmf_low_dotx_chaarted_iptw,
                              kmf_med_dotx_chaarted_iptw,
                              kmf_high_dotx_chaarted_iptw,
                              kmf_all_dotx_chaarted_iptw)

adt_chaarted_median_os = mos(kmf_low_adt_chaarted_iptw,
                             kmf_med_adt_chaarted_iptw,
                             kmf_high_adt_chaarted_iptw,
                             kmf_all_adt_chaarted_iptw)

In [116]:
chaarted_iptw_all_imputed = chaarted_iptw_all.copy()
chaarted_iptw_all_imputed['albumin_diag'] = chaarted_iptw_all_imputed['albumin_diag'].fillna(chaarted_iptw_all_imputed['albumin_diag'].median())
chaarted_iptw_all_imputed['weight_pct_change'] = chaarted_iptw_all_imputed['weight_pct_change'].fillna(chaarted_iptw_all_imputed['weight_pct_change'].median())
chaarted_iptw_all_imputed['PSAMetDiagnosis'] = chaarted_iptw_all_imputed['PSAMetDiagnosis'].fillna(chaarted_iptw_all_imputed['PSAMetDiagnosis'].median())

In [117]:
chaarted_hr_all = CoxPHFitter()
chaarted_hr_all.fit(chaarted_iptw_all_imputed,
                    duration_col = 'timerisk_treatment',
                    event_col = 'death_status',
                    formula = 'adt_dotx + age + race + p_type + delta_met_diagnosis + met_cat + commercial + medicare + medicaid + ecog_2 + prim_treatment + PSAMetDiagnosis + albumin_diag + weight_pct_change + risk_score',
                    weights_col = 'weight', 
                    robust = True)

<lifelines.CoxPHFitter: fitted with 13469 total observations, 6554.9 right-censored observations>

In [118]:
chaarted_all_rmst_mos_95 = rmst_mos_95ci(chaarted_iptw_all,
                                         1000,
                                         'adt_dotx',
                                         'death',
                                         ['age',
                                          'race',
                                          'p_type',
                                          'delta_met_diagnosis',
                                          'met_cat',
                                          'commercial',
                                          'medicare',
                                          'medicaid',
                                          'ecog_2',
                                          'prim_treatment',
                                          'PSAMetDiagnosis', 
                                          'albumin_diag', 
                                          'weight_pct_change',
                                          'risk_score'],
                                         ['age', 'delta_met_diagnosis', 'PSAMetDiagnosis', 'albumin_diag', 'weight_pct_change', 'risk_score'],
                                         60)

In [119]:
chaarted_low_rmst_mos_95 = rmst_mos_95ci(chaarted_iptw_low,
                                         1000,
                                         'adt_dotx',
                                         'death',
                                         ['age',
                                          'race',
                                          'p_type',
                                          'delta_met_diagnosis',
                                          'met_cat',
                                          'commercial',
                                          'medicare',
                                          'medicaid',
                                          'ecog_2',
                                          'prim_treatment',
                                          'PSAMetDiagnosis', 
                                          'albumin_diag', 
                                          'weight_pct_change',
                                          'risk_score'],
                                         ['age', 'delta_met_diagnosis', 'PSAMetDiagnosis', 'albumin_diag', 'weight_pct_change', 'risk_score'],
                                         60)

In [120]:
chaarted_med_rmst_mos_95 = rmst_mos_95ci(chaarted_iptw_med,
                                         1000,
                                         'adt_dotx',
                                         'death',
                                         ['age',
                                          'race',
                                          'p_type',
                                          'delta_met_diagnosis',
                                          'met_cat',
                                          'commercial',
                                          'medicare',
                                          'medicaid',
                                          'ecog_2',
                                          'prim_treatment',
                                          'PSAMetDiagnosis', 
                                          'albumin_diag', 
                                          'weight_pct_change',
                                          'risk_score'],
                                         ['age', 'delta_met_diagnosis', 'PSAMetDiagnosis', 'albumin_diag', 'weight_pct_change', 'risk_score'],
                                         60)

In [121]:
chaarted_high_rmst_mos_95 = rmst_mos_95ci(chaarted_iptw_high,
                                          1000,
                                          'adt_dotx',
                                          'death',
                                          ['age',
                                           'race',
                                           'p_type',
                                           'delta_met_diagnosis',
                                           'met_cat',
                                           'commercial',
                                           'medicare',
                                           'medicaid',
                                           'ecog_2',
                                           'prim_treatment',
                                           'PSAMetDiagnosis', 
                                           'albumin_diag', 
                                           'weight_pct_change',
                                           'risk_score'],
                                          ['age', 'delta_met_diagnosis', 'PSAMetDiagnosis', 'albumin_diag', 'weight_pct_change', 'risk_score'],
                                          60)

In [122]:
chaarted_data = [
    {'trial_name': 'CHAARTED', 
     'risk_group': 'low', 
     'r_trt_mos': dotx_chaarted_median_os[0],
     'r_trt_mos_95': chaarted_low_rmst_mos_95.mos_A_95,
     'r_cont_mos': adt_chaarted_median_os[0],
     'r_cont_mos_95': chaarted_low_rmst_mos_95.mos_B_95,
     'r_mos_diff': dotx_chaarted_median_os[0] - adt_chaarted_median_os[0], 
     'rct_trt_arm': 57.6,
     'rct_cont_arm': 44.0,
     'rct_mos_diff': 57.6-44.0,
     'trt_rmst': restricted_mean_survival_time(kmf_low_dotx_chaarted_iptw, 60),
     'trt_rmst_95': chaarted_low_rmst_mos_95.rmst_A_95,
     'cont_rmst': restricted_mean_survival_time(kmf_low_adt_chaarted_iptw, 60),
     'cont_rmst_95': chaarted_low_rmst_mos_95.rmst_B_95,
     'diff_rmst': restricted_mean_survival_time(kmf_low_dotx_chaarted_iptw, 60) - restricted_mean_survival_time(kmf_low_adt_chaarted_iptw, 60),
     'diff_rmst_95': chaarted_low_rmst_mos_95.difference_rmst_95,
     'rcount': chaarted.query('risk_score <= @low_cutoff_chaarted').shape[0],
     'rcount_chemo': chaarted.query('risk_score <= @low_cutoff_chaarted').query('adt_dotx == 1').shape[0]},
    
    {'trial_name': 'CHAARTED', 
     'risk_group': 'medium', 
     'r_trt_mos': dotx_chaarted_median_os[1],
     'r_trt_mos_95': chaarted_med_rmst_mos_95.mos_A_95,
     'r_cont_mos': adt_chaarted_median_os[1],
     'r_cont_mos_95': chaarted_med_rmst_mos_95.mos_B_95,
     'r_mos_diff': dotx_chaarted_median_os[1] - adt_chaarted_median_os[1], 
     'rct_trt_arm': 57.6,
     'rct_cont_arm': 44.0,
     'rct_mos_diff': 57.6-44.0,
     'trt_rmst': restricted_mean_survival_time(kmf_med_dotx_chaarted_iptw, 60),
     'trt_rmst_95': chaarted_med_rmst_mos_95.rmst_A_95,
     'cont_rmst': restricted_mean_survival_time(kmf_med_adt_chaarted_iptw, 60),
     'cont_rmst_95': chaarted_med_rmst_mos_95.rmst_B_95,
     'diff_rmst': restricted_mean_survival_time(kmf_med_dotx_chaarted_iptw, 60) - restricted_mean_survival_time(kmf_med_adt_chaarted_iptw, 60),
     'diff_rmst_95': chaarted_med_rmst_mos_95.difference_rmst_95,
     'rcount': chaarted.query('risk_score < @high_cutoff_chaarted and risk_score > @low_cutoff_chaarted').shape[0],
     'rcount_chemo': chaarted.query('risk_score < @high_cutoff_chaarted and risk_score > @low_cutoff_chaarted').query('adt_dotx == 1').shape[0]},
    
    {'trial_name': 'CHAARTED', 
     'risk_group': 'high', 
     'r_trt_mos': dotx_chaarted_median_os[2],
     'r_trt_mos_95': chaarted_high_rmst_mos_95.mos_A_95,
     'r_cont_mos': adt_chaarted_median_os[2],
     'r_cont_mos_95': chaarted_high_rmst_mos_95.mos_B_95,
     'r_mos_diff': dotx_chaarted_median_os[2] - adt_chaarted_median_os[2], 
     'rct_trt_arm': 57.6,
     'rct_cont_arm': 44.0,
     'rct_mos_diff': 57.6-44.0,
     'trt_rmst': restricted_mean_survival_time(kmf_high_dotx_chaarted_iptw, 60),
     'trt_rmst_95': chaarted_high_rmst_mos_95.rmst_A_95,
     'cont_rmst': restricted_mean_survival_time(kmf_high_adt_chaarted_iptw, 60),
     'cont_rmst_95': chaarted_high_rmst_mos_95.rmst_B_95,
     'diff_rmst': restricted_mean_survival_time(kmf_high_dotx_chaarted_iptw, 60) - restricted_mean_survival_time(kmf_high_adt_chaarted_iptw, 60),
     'diff_rmst_95': chaarted_high_rmst_mos_95.difference_rmst_95,
     'rcount': chaarted.query('risk_score >= @high_cutoff_chaarted').shape[0],
     'rcount_chemo': chaarted.query('risk_score >= @high_cutoff_chaarted').query('adt_dotx == 1').shape[0]},
    
    {'trial_name': 'CHAARTED', 
     'risk_group': 'all', 
     'r_hr': chaarted_hr_all.hazard_ratios_['adt_dotx'],
     'r_hr_95': [chaarted_hr_all.summary.loc['adt_dotx']['exp(coef) lower 95%'], chaarted_hr_all.summary.loc['adt_dotx']['exp(coef) upper 95%']],
     'r_trt_mos': dotx_chaarted_median_os[3],
     'r_trt_mos_95': chaarted_all_rmst_mos_95.mos_A_95,
     'r_cont_mos': adt_chaarted_median_os[3],
     'r_cont_mos_95': chaarted_all_rmst_mos_95.mos_B_95,
     'r_mos_diff': dotx_chaarted_median_os[3] - adt_chaarted_median_os[3], 
     'rct_trt_arm': 57.6,
     'rct_cont_arm': 44.0,
     'rct_mos_diff': 57.6-44.0,
     'rcount': chaarted.shape[0], 
     'rcount_chemo': chaarted.query('adt_dotx == 1').shape[0]}
]

## Part 3. Combining dictionaries 

In [123]:
rtrials_dc_mos = pd.DataFrame(chaarted_data)

In [124]:
rtrials_dc_mos

Unnamed: 0,trial_name,risk_group,r_trt_mos,r_trt_mos_95,r_cont_mos,r_cont_mos_95,r_mos_diff,rct_trt_arm,rct_cont_arm,rct_mos_diff,trt_rmst,trt_rmst_95,cont_rmst,cont_rmst_95,diff_rmst,diff_rmst_95,rcount,rcount_chemo,r_hr,r_hr_95
0,CHAARTED,low,80.6,"[61.7, nan]",68.0,"[63.36666666666667, 73.10999999999999]",12.6,57.6,44.0,13.6,49.796901,"[47.051933083473564, 52.524283582442614]",47.98562,"[47.01948226038321, 48.8800089610551]",1.811281,"[-1.0512328163416451, 4.788807710004043]",2388,361,,
1,CHAARTED,medium,53.633333,"[47.3, 60.56666666666667]",40.7,"[38.9, 43.266666666666666]",12.933333,57.6,44.0,13.6,46.120593,"[42.73538334360622, 48.91362771520557]",39.333863,"[38.47958185085627, 40.24858630726919]",6.786731,"[3.5097680783937584, 9.675200275337048]",2387,346,,
2,CHAARTED,high,33.933333,"[26.6, 37.266666666666666]",21.766667,"[20.5, 23.0]",12.166667,57.6,44.0,13.6,33.291605,"[28.78578051770819, 36.96250730469531]",25.693386,"[24.845824834082585, 26.56462173504901]",7.598219,"[3.036143661792651, 11.182298080872085]",2391,241,,
3,CHAARTED,all,51.866667,"[43.63333333333333, 58.666666666666664]",38.366667,"[37.4325, 39.968333333333334]",13.5,57.6,44.0,13.6,,,,,,,7166,948,0.628292,"[0.5465610039502893, 0.7222437605448319]"


In [125]:
rtrials_dc_mos.to_csv('rtrials_dc_mos.csv', index = False)