# Flatiron Health mPC: Survival metrics for strict elgibility criteria
**Background: Calculate survival metrics for emulated trials involving patients meeting strict elgibliity criteria. Hazard ratio for the full cohort is calculated from a Cox-IPTW model. Restricted mean survival time and median overall survival are calculated for phenotypes using an IPTW-adjusted KM curve.** 

## Part 1: Identify patients with exclusion criteria

In [1]:
import numpy as np
import pandas as pd

In [2]:
# Function that returns number of rows and count of unique PatientIDs for a dataframe. 
def row_ID(dataframe):
    row = dataframe.shape[0]
    ID = dataframe['PatientID'].nunique()
    return row, ID

In [3]:
train = pd.read_csv('train_full.csv')
row_ID(train)

(15141, 15141)

In [4]:
test = pd.read_csv('test_full.csv')
row_ID(test)

(3786, 3786)

In [5]:
df = pd.concat([train, test], ignore_index = True)
row_ID(df)

(18927, 18927)

In [6]:
df.query('cns_met == 0').sample(3)

Unnamed: 0,PatientID,Gender,race,ethnicity,age,p_type,NStage,MStage,Histology,GleasonScore,...,peritoneum_met,liver_met,other_gi_met,cns_met,bone_met,lymph_met,kidney_bladder_met,other_met,prim_treatment,early_adt
2782,F344700ABCC3D,M,white,unknown,76,COMMUNITY,Unknown / Not documented,M1,Adenocarcinoma,3 + 4 = 7,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,unknown,0.0
16841,FC22B8818581B,M,white,unknown,75,COMMUNITY,Unknown / Not documented,M1,Adenocarcinoma,9,...,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,unknown,0.0
10344,FC8175DD6CFA1,M,white,unknown,66,COMMUNITY,Unknown / Not documented,M1,Adenocarcinoma,8,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,unknown,0.0


### 1. Cardiac disease (CHF or MI) in the preceding 6 months 

In [7]:
diagnosis = pd.read_csv('Diagnosis.csv')

In [8]:
diagnosis = diagnosis[diagnosis['PatientID'].isin(df['PatientID'])]       

In [9]:
diagnosis.loc[:, 'DiagnosisDate'] = pd.to_datetime(diagnosis['DiagnosisDate'])

In [10]:
enhanced_met = pd.read_csv('Enhanced_MetProstate.csv')

In [11]:
enhanced_met.loc[:, 'MetDiagnosisDate'] = pd.to_datetime(enhanced_met['MetDiagnosisDate'])

In [12]:
row_ID(diagnosis)

(605074, 18927)

In [13]:
diagnosis = pd.merge(diagnosis, enhanced_met[['PatientID', 'MetDiagnosisDate']], on = 'PatientID', how = 'left')

In [14]:
row_ID(diagnosis)

(605074, 18927)

In [15]:
diagnosis.loc[:, 'date_diff'] = (diagnosis['DiagnosisDate'] - diagnosis['MetDiagnosisDate']).dt.days

In [16]:
diagnosis.loc[:, 'diagnosis_code'] = diagnosis['DiagnosisCode'].replace('\.', '', regex = True)

In [17]:
# ICD-9 dataframe with unique codes for each patient. 
diagnosis_9 = (
    diagnosis
    .query('date_diff <= 30 and date_diff > -365')
    .query('DiagnosisCodeSystem == "ICD-9-CM"')
    .drop_duplicates(subset = (['PatientID', 'DiagnosisCode']), keep = 'first')
    .filter(items = ['PatientID', 'DiagnosisCode', 'diagnosis_code'])
)

In [18]:
cardiac_9_IDs = (
    diagnosis_9[diagnosis_9['diagnosis_code'].str.match('428|'
                                                        '410')].PatientID.unique())

In [19]:
len(cardiac_9_IDs)

15

In [20]:
# ICD-9 dataframe with unique codes for each patient. 
diagnosis_10 = (
    diagnosis
    .query('date_diff <= 30 and date_diff > -365')
    .query('DiagnosisCodeSystem == "ICD-10-CM"')
    .drop_duplicates(subset = (['PatientID', 'DiagnosisCode']), keep = 'first')
    .filter(items = ['PatientID', 'DiagnosisCode', 'diagnosis_code'])
)

In [21]:
cardiac_10_IDs = (
    diagnosis_10[diagnosis_10['diagnosis_code'].str.match('I50|'
                                                          'I21')].PatientID.unique())

In [22]:
len(cardiac_10_IDs)

195

In [23]:
cardiac_IDs = np.unique(np.concatenate([cardiac_9_IDs, cardiac_10_IDs]))

In [24]:
len(cardiac_IDs)

210

### 2. Viral hepatitis and chronic liver disease 

In [25]:
liv_9_IDs = (
    diagnosis_9[diagnosis_9['diagnosis_code'].str.match('070|'
                                                        '571')].PatientID.unique())

In [26]:
len(liv_9_IDs)

18

In [27]:
liv_10_IDs = (
    diagnosis_10[diagnosis_10['diagnosis_code'].str.match('B1[56789]|'
                                                          'K7[0234]')].PatientID.unique())

In [28]:
len(liv_10_IDs)

65

In [29]:
liv_IDs = np.unique(np.concatenate([liv_9_IDs, liv_10_IDs]))

In [30]:
len(liv_IDs)

82

In [31]:
del diagnosis
del diagnosis_10
del diagnosis_9
del enhanced_met

## Part 2: In-silico trials 

### Import packages and create necessary functions

In [32]:
from scipy import stats

from sksurv.nonparametric import kaplan_meier_estimator
from survive import KaplanMeier, SurvivalData

from lifelines import KaplanMeierFitter, CoxPHFitter
from lifelines.plotting import add_at_risk_counts
from lifelines.utils import median_survival_times, restricted_mean_survival_time
from lifelines.statistics import logrank_test

from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer 
from sklearn.linear_model import LogisticRegression
from sklearn.utils import resample

import warnings

In [33]:
# Function that returns number of rows and count of unique PatientIDs for a dataframe. 
def row_ID(dataframe):
    row = dataframe.shape[0]
    ID = dataframe['PatientID'].nunique()
    return row, ID

In [34]:
# Find index for value closest to input value. 
def find_nearest(array, value):
    array = np.asarray(array)
    idx = (np.abs(array - value)).argmin()
    return array[idx]

In [35]:
# Calculates median overeall suvival for risk groups. 
def mos(low, med, high, comp):
    low_os = low.median_survival_time_
    med_os = med.median_survival_time_
    high_os = high.median_survival_time_
    comp_os = comp.median_survival_time_
    mos = [low_os, med_os, high_os, comp_os]
    return (mos)

In [36]:
def rmst_mos_95ci(df, num_samples, drug, event, items_list, numerical_features, rmst_time):
    
    """
    Estimate the 95% confidence interval for RMST and mOS using bootstrap resampling.

    Parameters:
    - df: DataFrame containing survival data
    - num_samples: Number of bootstrap samples
    - drug: Treatment indicator variable
    - event: Event type ('death' or 'progression')
    - items_list: Feature list for IPTW 
    - numerical_features: List of numerical features
    - rmst_time: Time to calculate RMST 

    Returns:
    - mos_A_95: mOS 95% CI for treatment
    - mos_B_95: mOS 95% CI for control
    - rmst_A_95: RMST 95% CI for treatment
    - rmst_B_95: RMST 95% CI for control
    - difference_rmst_95: RMST 95% CI for difference between treatment and control 
    """
    
    np.random.seed(42)
    mos_A = []
    mos_B = []
    rmst_A_list = []
    rmst_B_list = []
    differences_rmst = []
    
    # Define variables based on the event type
    if event == 'death':
        time_column = 'timerisk_treatment'
        status_column = 'death_status'
        
    else:
        time_column = 'time_prog_treatment'
        status_column = 'pfs_status'
        
    # Set up preprocessor for logistical regression which will be for IPTW  
    numerical_transformer = Pipeline(steps = [
        ('imputer', SimpleImputer(strategy = 'median')),
        ('std_scaler', StandardScaler())])
        
    categorical_transformer = OneHotEncoder(handle_unknown = 'ignore')
    categorical_features = list(df.select_dtypes(include = ['category']).columns)
        
    preprocessor = ColumnTransformer(
        transformers = [
            ('num', numerical_transformer, numerical_features),
            ('cat', categorical_transformer, categorical_features)],
        remainder = 'passthrough')
    
    # Boostrap HR 
    for _ in range(num_samples):
        
        # Resample data with replacement
        resampled_df = resample(df).drop(columns = ['ps', 'weight'])
        
        # Calculated IPTW for the resampled group 
        df_x = preprocessor.fit_transform(resampled_df.filter(items = items_list))
                                           
        df_lr = LogisticRegression(max_iter = 1000)
        df_lr.fit(df_x, resampled_df[drug])
        warnings.filterwarnings("ignore")
                                           
        pred = df_lr.predict_proba(df_x)        
        resampled_df['ps'] = pred[:, 1]                          
        resampled_df['weight'] = (
                np.where(resampled_df[drug] == 1, 1/resampled_df['ps'], 1/(1 - resampled_df['ps'])))
    
        # mOS from IPTW-KM
        kmf_A = KaplanMeierFitter()
        kmf_A.fit(resampled_df.query(f'{drug} == 1')[time_column]/30,
                  resampled_df.query(f'{drug} == 1')[status_column], 
                  weights = resampled_df.query(f'{drug} == 1')['weight'])

        kmf_B = KaplanMeierFitter()
        kmf_B.fit(resampled_df.query(f'{drug} == 0')[time_column]/30,
                  resampled_df.query(f'{drug} == 0')[status_column], 
                  weights = resampled_df.query(f'{drug} == 0')['weight'])
    
        mos_A.append(kmf_A.median_survival_time_)
        mos_B.append(kmf_B.median_survival_time_)
        
        # RMST from IPTW-KM
        rmst_A = restricted_mean_survival_time(kmf_A, rmst_time)
        rmst_B = restricted_mean_survival_time(kmf_B, rmst_time)
        
        rmst_A_list.append(rmst_A)
        rmst_B_list.append(rmst_B)
        differences_rmst.append(rmst_A - rmst_B)

    # Calculate the 95% confidence interval
    results = pd.Series({
    'mos_A_95': np.percentile(mos_A, [2.5, 97.5]),
    'mos_B_95': np.percentile(mos_B, [2.5, 97.5]),
    'rmst_A_95': np.percentile(rmst_A_list, [2.5, 97.5]),
    'rmst_B_95': np.percentile(rmst_B_list, [2.5, 97.5]),
    'difference_rmst_95': np.percentile(differences_rmst, [2.5, 97.5])
    })
    
    return results

In [37]:
cutoff = pd.read_csv('risk_cutoff_prostate.csv', index_col = 0)

### CHAARTED: docetaxel vs. ADT in metastatic, castration-sensitive prostate cancer  

**INCLUSION**
* Untreated metastatic prostate cancer, except up to 4 months of ADT 
* Castration-sensitive
* Received ADT or docetaxel plus ADT
* No active cardiac disease in the year preceding metastatic diagnosis 
* ECOG is not 3 or 4 at time of metastatic diagnosis 
* Adequate organ function at time of mestastatic diagnosis  

#### ADT

In [38]:
df_full = pd.read_csv('df_risk_crude.csv', index_col = 'PatientID', dtype = {'death_status': bool})
df_full.index.nunique()

18927

In [39]:
df_full.reset_index(inplace = True)

In [40]:
adt = pd.read_csv('Enhanced_MetPC_ADT.csv')

In [41]:
adt = (
    adt[adt['PatientID'].isin(df_full['PatientID'])]
    .query('TreatmentSetting == "Advanced"')
)

In [42]:
row_ID(adt)

(17863, 17863)

In [43]:
adt.loc[:, 'StartDate'] = pd.to_datetime(adt['StartDate'])

In [44]:
adt = adt.rename(columns = {'StartDate': 'StartDate_adt'})

In [45]:
df_full = pd.merge(df_full, adt[['PatientID', 'StartDate_adt']], on = 'PatientID', how = 'left')

In [46]:
row_ID(df_full)

(18927, 18927)

In [47]:
enhanced_met = pd.read_csv('Enhanced_MetProstate.csv')

In [48]:
enhanced_met = enhanced_met[enhanced_met['PatientID'].isin(df_full['PatientID'])]

In [49]:
enhanced_met.loc[:, 'MetDiagnosisDate'] = pd.to_datetime(enhanced_met['MetDiagnosisDate'])

In [50]:
enhanced_met.loc[:, 'CRPCDate'] = pd.to_datetime(enhanced_met['CRPCDate'])

In [51]:
df_full = pd.merge(df_full, enhanced_met[['PatientID', 'MetDiagnosisDate', 'CRPCDate']], on = 'PatientID', how = 'left')

In [52]:
row_ID(df_full)

(18927, 18927)

In [53]:
# Find all that start ADT within -120 to +90 days of metastatic diagnosis 
chaarted_adt = (
    df_full
    .assign(adt_diff = (df_full['StartDate_adt'] - df_full['MetDiagnosisDate']).dt.days)
    .query('adt_diff >= -120 and adt_diff <= 90')
)

In [54]:
# Find all that have missing CPRC date or date is >90 after metastatic diagnosis 
chaarted_adt = (
    chaarted_adt
    .assign(crpc_diff = (chaarted_adt['CRPCDate'] - chaarted_adt['MetDiagnosisDate']).dt.days)
    .query('crpc_diff > 90 or CRPCDate.isna()', engine = 'python')
)

In [55]:
row_ID(chaarted_adt)

(10475, 10475)

In [56]:
line_therapy = pd.read_csv('LineOfTherapy.csv')

In [57]:
zero = (
    line_therapy.query('LineNumber == 0')
    .PatientID
)

In [58]:
# Exclude patients with missing treatment information (ie, LineNumber == 0)
chaarted_adt = chaarted_adt[~chaarted_adt.PatientID.isin(zero)]

In [59]:
row_ID(chaarted_adt)

(10117, 10117)

In [60]:
line_therapy_cont = line_therapy.query('LineSetting != "nmCRPC"')

In [61]:
# List of FDA approved drugs for mPC as of July 2023. Clinical study drug is also included. 
fda_yes = [
    'Abiraterone',
    'Apalutamide',
    'Cabazitaxel',
    'Carboplatin',
    'Cisplatin',
    'Darolutamide',
    'Docetaxel',
    'Enzalutamide',
    'Mitoxantrone',
    'Olaparib',
    'Oxaliplatin',
    'Paclitaxel',
    'Pembrolizumab',
    'Radium-223',
    'Rucaparib',
    'Sipuleucel-T',
    'Clinical Study Drug'
]

In [62]:
line_therapy_cont = line_therapy_cont[line_therapy_cont['LineName'].str.contains('|'.join(fda_yes))]

In [63]:
line_therapy_cont = (
    line_therapy_cont
    .sort_values(by = ['PatientID', 'StartDate'], ascending = [True, True])
)

In [64]:
line_therapy_cont['line_number'] = (
    line_therapy_cont.groupby('PatientID')['LineNumber'].cumcount()+1
)

In [65]:
# First line therapy is in castrate-resistant setting 
fl_crpc = (
    line_therapy_cont[line_therapy_cont.PatientID.isin(chaarted_adt.PatientID)]
    .query('line_number == 1 & LineSetting == "mCRPC"')
    .PatientID
)

In [66]:
# Never received therapy other than ADT
notrt_adt = (
    chaarted_adt[~chaarted_adt.PatientID.isin(line_therapy_cont.PatientID)]
    .PatientID
)

In [67]:
adt_IDs = np.concatenate((fl_crpc, notrt_adt))

In [68]:
len(adt_IDs)

6218

In [69]:
chaarted_adt = chaarted_adt[chaarted_adt.PatientID.isin(adt_IDs)]

In [70]:
chaarted_adt.loc[:,'adt_dotx'] = 0

In [71]:
row_ID(chaarted_adt)

(6218, 6218)

In [72]:
chaarted_adt.sample(3)

Unnamed: 0,PatientID,Gender,race,ethnicity,age,p_type,NStage,MStage,Histology,GleasonScore,...,other_met,prim_treatment,early_adt,risk_score,StartDate_adt,MetDiagnosisDate,CRPCDate,adt_diff,crpc_diff,adt_dotx
2193,F92E6E2301977,M,white,unknown,77,COMMUNITY,Unknown / Not documented,Unknown / Not documented,"Prostate cancer, NOS",Unknown / Not documented,...,0.0,prostatectomy,1.0,0.248337,2017-06-30,2017-05-24,2019-02-06,37.0,623.0,0
4919,FB161DF16921F,M,other,unknown,65,COMMUNITY,Unknown / Not documented,M0,Adenocarcinoma,Unknown / Not documented,...,0.0,prostatectomy,1.0,0.807494,2020-02-12,2020-01-25,2020-11-03,18.0,283.0,0
8210,F11936921F2EA,M,other,unknown,72,COMMUNITY,N0,M0,Adenocarcinoma,4 + 3 = 7,...,0.0,radiation,0.0,-0.579715,2020-09-10,2020-08-10,NaT,31.0,,0


#### Docetaxel + ADT

In [73]:
# Find those that start ADT within -120 to 90 days of metastaic diagnosis 
chaarted_dotx = (
    df_full
    .assign(adt_diff = (df_full['StartDate_adt'] - df_full['MetDiagnosisDate']).dt.days)
    .query('adt_diff >= -120 and adt_diff <= 90')
)

In [74]:
# Find all that have missing CPRC date or date is >90 days metastatic diagnosis 
chaarted_dotx = (
    chaarted_dotx
    .assign(crpc_diff = (chaarted_dotx['CRPCDate'] - chaarted_dotx['MetDiagnosisDate']).dt.days)
    .query('crpc_diff > 90 or CRPCDate.isna()', engine = 'python')
)

In [75]:
row_ID(chaarted_dotx)

(10475, 10475)

In [76]:
# Find start time of first line of mHSPC therapy. 
line_therapy_fl = (
    line_therapy[line_therapy['PatientID'].isin(chaarted_dotx['PatientID'])]
    .query('LineSetting == "mHSPC"')
    .sort_values(by = ['PatientID', 'StartDate'], ascending = [True, True])
    .drop_duplicates(subset = ['PatientID'], keep = 'first')
    .rename(columns = {'StartDate': 'StartDate_dotx'})
)

In [77]:
row_ID(line_therapy_fl)

(4042, 4042)

In [78]:
line_therapy_fl.loc[:, 'StartDate_dotx'] = pd.to_datetime(line_therapy_fl['StartDate_dotx'])

In [79]:
line_therapy_fl[line_therapy_fl['LineName'].str.contains('Docetaxel')].LineName.value_counts().head(10)

Docetaxel                                         1363
Abiraterone,Docetaxel                               27
Carboplatin,Docetaxel                               12
Clinical Study Drug,Docetaxel                       11
Darolutamide,Docetaxel                               9
Docetaxel,Enzalutamide                               6
Carboplatin,Docetaxel,Estramustine                   5
Docetaxel,Ketoconazole                               3
Apalutamide,Docetaxel                                2
Carboplatin,Docetaxel,Estramustine,Thalidomide       2
Name: LineName, dtype: int64

In [80]:
line_dotx = line_therapy_fl.query('LineName == "Docetaxel"')

In [81]:
row_ID(line_dotx)

(1363, 1363)

In [82]:
chaarted_dotx = pd.merge(chaarted_dotx, line_dotx[['PatientID', 'StartDate_dotx']], on = 'PatientID', how = 'left')

In [83]:
row_ID(chaarted_dotx)

(10475, 10475)

In [84]:
# Find all patients that start enzalutamide within 180 days of metastatic diagnosis
chaarted_dotx = (
    chaarted_dotx
    .assign(fl_diff = (chaarted_dotx['StartDate_dotx'] - chaarted_dotx['StartDate_adt']).dt.days)
    .query('fl_diff >= -90 and fl_diff <= 90')
)

In [85]:
len(chaarted_dotx)

1179

In [86]:
chaarted_dotx.sample(3)

Unnamed: 0,PatientID,Gender,race,ethnicity,age,p_type,NStage,MStage,Histology,GleasonScore,...,prim_treatment,early_adt,risk_score,StartDate_adt,MetDiagnosisDate,CRPCDate,adt_diff,crpc_diff,StartDate_dotx,fl_diff
9840,F5FCAA2C38D7A,M,white,unknown,71,COMMUNITY,N1,M1,"Prostate cancer, NOS",Unknown / Not documented,...,unknown,0.0,0.290659,2017-05-01,2017-04-21,2018-11-14,10.0,572.0,2017-05-19,18.0
8582,F1469558A6279,M,other,unknown,69,COMMUNITY,N1,M1b,Adenocarcinoma,10,...,unknown,0.0,-0.580464,2019-03-08,2019-02-22,NaT,14.0,,2019-03-29,21.0
2574,F3C26499047F6,M,white,unknown,73,COMMUNITY,Unknown / Not documented,M1c,Adenocarcinoma,9,...,unknown,0.0,0.159134,2014-10-16,2014-10-16,2016-02-10,0.0,482.0,2014-11-12,27.0


In [87]:
chaarted_dotx.loc[:,'adt_dotx'] = 1

In [88]:
chaarted = pd.concat([chaarted_adt, chaarted_dotx], ignore_index = True)

In [89]:
row_ID(chaarted)

(7397, 7397)

In [90]:
chaarted.adt_dotx.value_counts(dropna = False)

0    6218
1    1179
Name: adt_dotx, dtype: int64

In [91]:
chaarted.sample(3)

Unnamed: 0,PatientID,Gender,race,ethnicity,age,p_type,NStage,MStage,Histology,GleasonScore,...,early_adt,risk_score,StartDate_adt,MetDiagnosisDate,CRPCDate,adt_diff,crpc_diff,adt_dotx,StartDate_dotx,fl_diff
4435,FB886C0E0E4FE,M,white,unknown,68,COMMUNITY,NX,M1,Adenocarcinoma,10,...,0.0,1.381747,2020-07-14,2020-06-03,NaT,41.0,,0,NaT,
2387,F45F299130180,M,white,unknown,68,COMMUNITY,Unknown / Not documented,M1,"Prostate cancer, NOS",Unknown / Not documented,...,0.0,0.12568,2018-09-27,2018-09-06,2020-01-06,21.0,487.0,0,NaT,
5779,F501F16E7DFE2,M,white,unknown,57,COMMUNITY,Unknown / Not documented,M1,Adenocarcinoma,Unknown / Not documented,...,0.0,0.837091,2018-05-23,2018-05-06,2019-08-22,17.0,473.0,0,NaT,


#### Time from ADT treatment to death or censor 

In [92]:
mortality_tr = pd.read_csv('mortality_cleaned_tr.csv')

In [93]:
mortality_te = pd.read_csv('mortality_cleaned_te.csv')

In [94]:
mortality_tr = mortality_tr[['PatientID', 'death_date', 'last_activity']]

In [95]:
mortality_te = mortality_te[['PatientID', 'death_date', 'last_activity']]

In [96]:
mortality = pd.concat([mortality_tr, mortality_te], ignore_index = True)
row_ID(mortality)

(18927, 18927)

In [97]:
mortality.loc[:, 'last_activity'] = pd.to_datetime(mortality['last_activity'])

In [98]:
mortality.loc[:, 'death_date'] = pd.to_datetime(mortality['death_date'])

In [99]:
len(mortality)

18927

In [100]:
chaarted = pd.merge(chaarted, mortality, on = 'PatientID', how = 'left')

In [101]:
row_ID(chaarted)

(7397, 7397)

In [102]:
conditions = [
    (chaarted['death_status'] == 1),
    (chaarted['death_status'] == 0)]

choices = [
    (chaarted['death_date'] - chaarted['StartDate_adt']).dt.days,
    (chaarted['last_activity'] - chaarted['StartDate_adt']).dt.days]

chaarted.loc[:, 'timerisk_treatment'] = np.select(conditions, choices)

chaarted = chaarted.query('timerisk_treatment >= 0')

#### Patient count 

In [103]:
row_ID(chaarted)

(7397, 7397)

In [104]:
# Exclude those with active cardiac disease in the year preceding metastatic diagnosis 
chaarted = chaarted[~chaarted['PatientID'].isin(cardiac_IDs)]

In [105]:
# Exclude those with ECOG 3 or 4 at time of metastatic diagnosis 
chaarted = chaarted.query('ecog_diagnosis != "3.0" and ecog_diagnosis != "4.0"')

In [106]:
# Exclude those with abnormal organ function at time of metastatic diagnosis 
chaarted = (
    chaarted
    .query('creatinine_diag < 2 or creatinine_diag_na == 1')
    .query('hemoglobin_diag > 9 or hemoglobin_diag_na == 1')
    .query('total_bilirubin_diag < 3 or total_bilirubin_diag_na == 1')
)

In [107]:
row_ID(chaarted)

(6812, 6812)

In [108]:
low_cutoff_chaarted = cutoff.loc['chaarted'].low

In [109]:
high_cutoff_chaarted = cutoff.loc['chaarted'].high

In [110]:
print('Docetaxel + ADT:',  chaarted.query('adt_dotx == 1').shape[0])
print('High risk:', chaarted.query('adt_dotx == 1').query('risk_score >= @high_cutoff_chaarted').shape[0])
print('Med risk:', chaarted.query('adt_dotx == 1').query('risk_score < @high_cutoff_chaarted and risk_score > @low_cutoff_chaarted').shape[0])
print('Low risk:', chaarted.query('adt_dotx == 1').query('risk_score <= @low_cutoff_chaarted').shape[0])

Docetaxel + ADT: 1088
High risk: 240
Med risk: 410
Low risk: 438


In [111]:
print('ADT:',  chaarted.query('adt_dotx == 0').shape[0])
print('High risk:', chaarted.query('adt_dotx == 0').query('risk_score >= @high_cutoff_chaarted').shape[0])
print('Med risk:', chaarted.query('adt_dotx == 0').query('risk_score < @high_cutoff_chaarted and risk_score > @low_cutoff_chaarted').shape[0])
print('Low risk:', chaarted.query('adt_dotx == 0').query('risk_score <= @low_cutoff_chaarted').shape[0])

ADT: 5724
High risk: 1708
Med risk: 2002
Low risk: 2014


#### Survival curves with covariate balancing 

In [112]:
chaarted = chaarted.set_index('PatientID')

In [113]:
chaarted['met_cat'] = pd.cut(chaarted['met_year'],
                             bins = [2010, 2015, float('inf')],
                             labels = ['11-15', '16-22'])

In [114]:
conditions = [
    ((chaarted['ecog_diagnosis'] == "1.0") | (chaarted['ecog_diagnosis'] == "0.0")),  
    ((chaarted['ecog_diagnosis'] == "2.0") | (chaarted['ecog_diagnosis'] == "3.0"))
]

choices = ['lt_2', 'gte_2']

chaarted['ecog_2'] = np.select(conditions, choices, default = 'unknown')

In [115]:
chaarted_iptw = chaarted.filter(items = ['death_status',
                                         'timerisk_treatment',
                                         'adt_dotx',
                                         'age',
                                         'race',
                                         'p_type',
                                         'met_cat',
                                         'delta_met_diagnosis',
                                         'commercial',
                                         'medicare',
                                         'medicaid',
                                         'ecog_2',
                                         'prim_treatment',
                                         'PSAMetDiagnosis',
                                         'albumin_diag', 
                                         'weight_pct_change',
                                         'risk_score'])

In [116]:
chaarted_iptw.dtypes

death_status               bool
timerisk_treatment      float64
adt_dotx                  int64
age                       int64
race                     object
p_type                   object
met_cat                category
delta_met_diagnosis       int64
commercial              float64
medicare                float64
medicaid                float64
ecog_2                   object
prim_treatment           object
PSAMetDiagnosis         float64
albumin_diag            float64
weight_pct_change       float64
risk_score              float64
dtype: object

In [117]:
to_be_categorical = list(chaarted_iptw.select_dtypes(include = ['object']).columns)

In [118]:
to_be_categorical

['race', 'p_type', 'ecog_2', 'prim_treatment']

In [119]:
to_be_categorical.append('met_cat')

In [120]:
# Convert variables in list to categorical.
for x in list(to_be_categorical):
    chaarted_iptw[x] = chaarted_iptw[x].astype('category')

In [121]:
# List of numeric variables, excluding binary variables. 
numerical_features = ['age', 'delta_met_diagnosis', 'PSAMetDiagnosis', 'albumin_diag', 'weight_pct_change', 'risk_score']

# Transformer will first calculate column median and impute, and then apply a standard scaler. 
numerical_transformer = Pipeline(steps = [
    ('imputer', SimpleImputer(strategy = 'median')),
    ('std_scaler', StandardScaler())])

In [122]:
# List of categorical features.
categorical_features = list(chaarted_iptw.select_dtypes(include = ['category']).columns)

# One-hot-encode categorical features.
categorical_transformer = OneHotEncoder(handle_unknown = 'ignore')

In [123]:
preprocessor = ColumnTransformer(
    transformers = [
        ('num', numerical_transformer, numerical_features),
        ('cat', categorical_transformer, categorical_features)],
    remainder = 'passthrough')

In [124]:
chaarted_iptw_low = (
    chaarted_iptw
    .query('risk_score <= @low_cutoff_chaarted'))

chaarted_iptw_med = (
    chaarted_iptw
    .query('risk_score < @high_cutoff_chaarted and risk_score > @low_cutoff_chaarted'))

chaarted_iptw_high = (
    chaarted_iptw
    .query('risk_score >= @high_cutoff_chaarted'))

chaarted_iptw_all = chaarted_iptw

In [125]:
chaarted_low_x = preprocessor.fit_transform(chaarted_iptw_low.filter(items = ['age',
                                                                              'race',
                                                                              'p_type',
                                                                              'delta_met_diagnosis',
                                                                              'met_cat',
                                                                              'commercial',
                                                                              'medicare',
                                                                              'medicaid',
                                                                              'ecog_2',
                                                                              'prim_treatment',
                                                                              'PSAMetDiagnosis', 
                                                                              'albumin_diag', 
                                                                              'weight_pct_change',
                                                                              'risk_score']))

chaarted_med_x = preprocessor.fit_transform(chaarted_iptw_med.filter(items = ['age',
                                                                              'race',
                                                                              'p_type',
                                                                              'delta_met_diagnosis',
                                                                              'met_cat',
                                                                              'commercial',
                                                                              'medicare',
                                                                              'medicaid',
                                                                              'ecog_2',
                                                                              'prim_treatment',
                                                                              'PSAMetDiagnosis', 
                                                                              'albumin_diag', 
                                                                              'weight_pct_change', 
                                                                              'risk_score']))

chaarted_high_x = preprocessor.fit_transform(chaarted_iptw_high.filter(items = ['age',
                                                                                'race',
                                                                                'p_type',
                                                                                'delta_met_diagnosis',
                                                                                'met_cat',
                                                                                'commercial',
                                                                                'medicare',
                                                                                'medicaid',
                                                                                'ecog_2',
                                                                                'prim_treatment',
                                                                                'PSAMetDiagnosis', 
                                                                                'albumin_diag', 
                                                                                'weight_pct_change',
                                                                                'risk_score']))

chaarted_all_x = preprocessor.fit_transform(chaarted_iptw_all.filter(items = ['age',
                                                                              'race',
                                                                              'p_type',
                                                                              'delta_met_diagnosis',
                                                                              'met_cat',
                                                                              'commercial',
                                                                              'medicare',
                                                                              'medicaid',
                                                                              'ecog_2',
                                                                              'prim_treatment',
                                                                              'PSAMetDiagnosis', 
                                                                              'albumin_diag', 
                                                                              'weight_pct_change', 
                                                                              'risk_score']))

In [126]:
lr_chaarted_low = LogisticRegression(max_iter = 1000)
lr_chaarted_low.fit(chaarted_low_x, chaarted_iptw_low['adt_dotx'])

LogisticRegression(max_iter=1000)

In [127]:
lr_chaarted_med = LogisticRegression(max_iter = 1000)
lr_chaarted_med.fit(chaarted_med_x, chaarted_iptw_med['adt_dotx'])

LogisticRegression(max_iter=1000)

In [128]:
lr_chaarted_high = LogisticRegression(max_iter = 1000)
lr_chaarted_high.fit(chaarted_high_x, chaarted_iptw_high['adt_dotx'])

LogisticRegression(max_iter=1000)

In [129]:
lr_chaarted_all = LogisticRegression(max_iter = 1000)
lr_chaarted_all.fit(chaarted_all_x, chaarted_iptw_all['adt_dotx'])

LogisticRegression(max_iter=1000)

In [130]:
pred_low = lr_chaarted_low.predict_proba(chaarted_low_x)
pred_med = lr_chaarted_med.predict_proba(chaarted_med_x)
pred_high = lr_chaarted_high.predict_proba(chaarted_high_x)
pred_all = lr_chaarted_all.predict_proba(chaarted_all_x)

In [131]:
chaarted_iptw_low['ps'] = pred_low[:, 1]
chaarted_iptw_med['ps'] = pred_med[:, 1]
chaarted_iptw_high['ps'] = pred_high[:, 1]
chaarted_iptw_all['ps'] = pred_all[:, 1]

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  """Entry point for launching an IPython kernel.
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  This is separate from the ipykernel package so we can avoid doing imports until


In [132]:
chaarted_iptw_low['weight'] = (
    np.where(chaarted_iptw_low['adt_dotx'] == 1, 1/chaarted_iptw_low['ps'], 1/(1 - chaarted_iptw_low['ps'])))

chaarted_iptw_med['weight'] = (
    np.where(chaarted_iptw_med['adt_dotx'] == 1, 1/chaarted_iptw_med['ps'], 1/(1 - chaarted_iptw_med['ps'])))

chaarted_iptw_high['weight'] = (
    np.where(chaarted_iptw_high['adt_dotx'] == 1, 1/chaarted_iptw_high['ps'], 1/(1 - chaarted_iptw_high['ps'])))

chaarted_iptw_all['weight'] = (
    np.where(chaarted_iptw_all['adt_dotx'] == 1, 1/chaarted_iptw_all['ps'], 1/(1 - chaarted_iptw_all['ps'])))

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  """
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  


In [133]:
# Low KM curves
kmf_low_dotx_chaarted_iptw = KaplanMeierFitter()
kmf_low_adt_chaarted_iptw = KaplanMeierFitter()

kmf_low_dotx_chaarted_iptw.fit(
    chaarted_iptw_low.query('adt_dotx == 1').timerisk_treatment/30,
    chaarted_iptw_low.query('adt_dotx == 1').death_status,
    weights = chaarted_iptw_low.query('adt_dotx == 1')['weight'])

kmf_low_adt_chaarted_iptw.fit(
    chaarted_iptw_low.query('adt_dotx == 0').timerisk_treatment/30,
    chaarted_iptw_low.query('adt_dotx == 0').death_status,
    weights = chaarted_iptw_low.query('adt_dotx == 0')['weight'])

# Med KM curves
kmf_med_dotx_chaarted_iptw = KaplanMeierFitter()
kmf_med_adt_chaarted_iptw = KaplanMeierFitter()

kmf_med_dotx_chaarted_iptw.fit(
    chaarted_iptw_med.query('adt_dotx == 1').timerisk_treatment/30,
    chaarted_iptw_med.query('adt_dotx == 1').death_status,
    weights = chaarted_iptw_med.query('adt_dotx == 1')['weight'])

kmf_med_adt_chaarted_iptw.fit(
    chaarted_iptw_med.query('adt_dotx == 0').timerisk_treatment/30,
    chaarted_iptw_med.query('adt_dotx == 0').death_status,
    weights = chaarted_iptw_med.query('adt_dotx == 0')['weight'])

# High KM curves 
kmf_high_dotx_chaarted_iptw = KaplanMeierFitter()
kmf_high_adt_chaarted_iptw = KaplanMeierFitter()

kmf_high_dotx_chaarted_iptw.fit(
    chaarted_iptw_high.query('adt_dotx == 1').timerisk_treatment/30,
    chaarted_iptw_high.query('adt_dotx == 1').death_status,
    weights = chaarted_iptw_high.query('adt_dotx == 1')['weight'])

kmf_high_adt_chaarted_iptw.fit(
    chaarted_iptw_high.query('adt_dotx == 0').timerisk_treatment/30,
    chaarted_iptw_high.query('adt_dotx == 0').death_status,
    weights = chaarted_iptw_high.query('adt_dotx == 0')['weight'])

# All KM curves 
kmf_all_dotx_chaarted_iptw = KaplanMeierFitter()
kmf_all_adt_chaarted_iptw = KaplanMeierFitter()

kmf_all_dotx_chaarted_iptw.fit(
    chaarted_iptw_all.query('adt_dotx == 1').timerisk_treatment/30,
    chaarted_iptw_all.query('adt_dotx == 1').death_status,
    weights = chaarted_iptw_all.query('adt_dotx == 1')['weight'])

kmf_all_adt_chaarted_iptw.fit(
    chaarted_iptw_all.query('adt_dotx == 0').timerisk_treatment/30,
    chaarted_iptw_all.query('adt_dotx == 0').death_status,
    weights = chaarted_iptw_all.query('adt_dotx == 0')['weight'])

  It's important to know that the naive variance estimates of the coefficients are biased. Instead use Monte Carlo to
  estimate the variances. See paper "Variance estimation when using inverse probability of treatment weighting (IPTW) with survival analysis"
  or "Adjusted Kaplan-Meier estimator and log-rank test with inverse probability of treatment weighting for survival data."
                  


<lifelines.KaplanMeierFitter:"KM_estimate", fitted with 6830.03 total observations, 3154.94 right-censored observations>

#### Calculating survival metrics 

In [134]:
dotx_chaarted_median_os = mos(kmf_low_dotx_chaarted_iptw,
                              kmf_med_dotx_chaarted_iptw,
                              kmf_high_dotx_chaarted_iptw,
                              kmf_all_dotx_chaarted_iptw)

adt_chaarted_median_os = mos(kmf_low_adt_chaarted_iptw,
                             kmf_med_adt_chaarted_iptw,
                             kmf_high_adt_chaarted_iptw,
                             kmf_all_adt_chaarted_iptw)

In [135]:
chaarted_iptw_all_imputed = chaarted_iptw_all.copy()
chaarted_iptw_all_imputed['albumin_diag'] = chaarted_iptw_all_imputed['albumin_diag'].fillna(chaarted_iptw_all_imputed['albumin_diag'].median())
chaarted_iptw_all_imputed['weight_pct_change'] = chaarted_iptw_all_imputed['weight_pct_change'].fillna(chaarted_iptw_all_imputed['weight_pct_change'].median())
chaarted_iptw_all_imputed['PSAMetDiagnosis'] = chaarted_iptw_all_imputed['PSAMetDiagnosis'].fillna(chaarted_iptw_all_imputed['PSAMetDiagnosis'].median())

In [136]:
chaarted_hr_all = CoxPHFitter()
chaarted_hr_all.fit(chaarted_iptw_all_imputed,
                    duration_col = 'timerisk_treatment',
                    event_col = 'death_status',
                    formula = 'adt_dotx + age + race + p_type + delta_met_diagnosis + met_cat + commercial + medicare + medicaid + ecog_2 + prim_treatment + PSAMetDiagnosis + albumin_diag + weight_pct_change + risk_score',
                    weights_col = 'weight', 
                    robust = True)

<lifelines.CoxPHFitter: fitted with 13604.6 total observations, 6749.91 right-censored observations>

In [137]:
chaarted_all_rmst_mos_95 = rmst_mos_95ci(chaarted_iptw_all,
                                         1000,
                                         'adt_dotx',
                                         'death',
                                         ['age',
                                          'race',
                                          'p_type',
                                          'delta_met_diagnosis',
                                          'met_cat',
                                          'commercial',
                                          'medicare',
                                          'medicaid',
                                          'ecog_2',
                                          'prim_treatment',
                                          'PSAMetDiagnosis', 
                                          'albumin_diag', 
                                          'weight_pct_change',
                                          'risk_score'],
                                         ['age', 'delta_met_diagnosis', 'PSAMetDiagnosis', 'albumin_diag', 'weight_pct_change', 'risk_score'],
                                         60)

In [138]:
chaarted_low_rmst_mos_95 = rmst_mos_95ci(chaarted_iptw_low,
                                         1000,
                                         'adt_dotx',
                                         'death',
                                         ['age',
                                          'race',
                                          'p_type',
                                          'delta_met_diagnosis',
                                          'met_cat',
                                          'commercial',
                                          'medicare',
                                          'medicaid',
                                          'ecog_2',
                                          'prim_treatment',
                                          'PSAMetDiagnosis', 
                                          'albumin_diag', 
                                          'weight_pct_change',
                                          'risk_score'],
                                         ['age', 'delta_met_diagnosis', 'PSAMetDiagnosis', 'albumin_diag', 'weight_pct_change', 'risk_score'],
                                         60)

In [139]:
chaarted_med_rmst_mos_95 = rmst_mos_95ci(chaarted_iptw_med,
                                         1000,
                                         'adt_dotx',
                                         'death',
                                         ['age',
                                          'race',
                                          'p_type',
                                          'delta_met_diagnosis',
                                          'met_cat',
                                          'commercial',
                                          'medicare',
                                          'medicaid',
                                          'ecog_2',
                                          'prim_treatment',
                                          'PSAMetDiagnosis', 
                                          'albumin_diag', 
                                          'weight_pct_change',
                                          'risk_score'],
                                         ['age', 'delta_met_diagnosis', 'PSAMetDiagnosis', 'albumin_diag', 'weight_pct_change', 'risk_score'],
                                         60)

In [140]:
chaarted_high_rmst_mos_95 = rmst_mos_95ci(chaarted_iptw_high,
                                          1000,
                                          'adt_dotx',
                                          'death',
                                          ['age',
                                           'race',
                                           'p_type',
                                           'delta_met_diagnosis',
                                           'met_cat',
                                           'commercial',
                                           'medicare',
                                           'medicaid',
                                           'ecog_2',
                                           'prim_treatment',
                                           'PSAMetDiagnosis', 
                                           'albumin_diag', 
                                           'weight_pct_change',
                                           'risk_score'],
                                          ['age', 'delta_met_diagnosis', 'PSAMetDiagnosis', 'albumin_diag', 'weight_pct_change', 'risk_score'],
                                          60)

In [141]:
chaarted_data = [
    {'trial_name': 'CHAARTED', 
     'risk_group': 'low', 
     's_trt_mos': dotx_chaarted_median_os[0],
     's_trt_mos_95': chaarted_low_rmst_mos_95.mos_A_95,
     's_cont_mos': adt_chaarted_median_os[0],
     's_cont_mos_95': chaarted_low_rmst_mos_95.mos_B_95,
     's_mos_diff': dotx_chaarted_median_os[0] - adt_chaarted_median_os[0], 
     'rct_trt_arm': 57.6,
     'rct_cont_arm': 44.0,
     'rct_mos_diff': 57.6-44.0,
     's_trt_rmst': restricted_mean_survival_time(kmf_low_dotx_chaarted_iptw, 60),
     's_trt_rmst_95': chaarted_low_rmst_mos_95.rmst_A_95,
     's_cont_rmst': restricted_mean_survival_time(kmf_low_adt_chaarted_iptw, 60),
     's_cont_rmst_95': chaarted_low_rmst_mos_95.rmst_B_95,
     's_diff_rmst': restricted_mean_survival_time(kmf_low_dotx_chaarted_iptw, 60) - restricted_mean_survival_time(kmf_low_adt_chaarted_iptw, 60),
     's_diff_rmst_95': chaarted_low_rmst_mos_95.difference_rmst_95,
     'scount': chaarted.query('risk_score <= @low_cutoff_chaarted').shape[0]},
    
    {'trial_name': 'CHAARTED', 
     'risk_group': 'medium', 
     's_trt_mos': dotx_chaarted_median_os[1],
     's_trt_mos_95': chaarted_med_rmst_mos_95.mos_A_95,
     's_cont_mos': adt_chaarted_median_os[1],
     's_cont_mos_95': chaarted_med_rmst_mos_95.mos_B_95,
     's_mos_diff': dotx_chaarted_median_os[1] - adt_chaarted_median_os[1], 
     'rct_trt_arm': 57.6,
     'rct_cont_arm': 44.0,
     'rct_mos_diff': 57.6-44.0,
     's_trt_rmst': restricted_mean_survival_time(kmf_med_dotx_chaarted_iptw, 60),
     's_trt_rmst_95': chaarted_med_rmst_mos_95.rmst_A_95,
     's_cont_rmst': restricted_mean_survival_time(kmf_med_adt_chaarted_iptw, 60),
     's_cont_rmst_95': chaarted_med_rmst_mos_95.rmst_B_95,
     's_diff_rmst': restricted_mean_survival_time(kmf_med_dotx_chaarted_iptw, 60) - restricted_mean_survival_time(kmf_med_adt_chaarted_iptw, 60),
     's_diff_rmst_95': chaarted_med_rmst_mos_95.difference_rmst_95,
     'scount': chaarted.query('risk_score < @high_cutoff_chaarted and risk_score > @low_cutoff_chaarted').shape[0]},
    
    {'trial_name': 'CHAARTED', 
     'risk_group': 'high', 
     's_trt_mos': dotx_chaarted_median_os[2],
     's_trt_mos_95': chaarted_high_rmst_mos_95.mos_A_95,
     's_cont_mos': adt_chaarted_median_os[2],
     's_cont_mos_95': chaarted_high_rmst_mos_95.mos_B_95,
     's_mos_diff': dotx_chaarted_median_os[2] - adt_chaarted_median_os[2], 
     'rct_trt_arm': 57.6,
     'rct_cont_arm': 44.0,
     'rct_mos_diff': 57.6-44.0,
     's_trt_rmst': restricted_mean_survival_time(kmf_high_dotx_chaarted_iptw, 60),
     's_trt_rmst_95': chaarted_high_rmst_mos_95.rmst_A_95,
     's_cont_rmst': restricted_mean_survival_time(kmf_high_adt_chaarted_iptw, 60),
     's_cont_rmst_95': chaarted_high_rmst_mos_95.rmst_B_95,
     's_diff_rmst': restricted_mean_survival_time(kmf_high_dotx_chaarted_iptw, 60) - restricted_mean_survival_time(kmf_high_adt_chaarted_iptw, 60),
     's_diff_rmst_95': chaarted_high_rmst_mos_95.difference_rmst_95,
     'scount': chaarted.query('risk_score >= @high_cutoff_chaarted').shape[0]},
    
    {'trial_name': 'CHAARTED', 
     'risk_group': 'all', 
     's_hr': chaarted_hr_all.hazard_ratios_['adt_dotx'],
     's_hr_95': [chaarted_hr_all.summary.loc['adt_dotx']['exp(coef) lower 95%'], chaarted_hr_all.summary.loc['adt_dotx']['exp(coef) upper 95%']],
     's_trt_mos': dotx_chaarted_median_os[3],
     's_trt_mos_95': chaarted_all_rmst_mos_95.mos_A_95,
     's_cont_mos': adt_chaarted_median_os[3],
     's_cont_mos_95': chaarted_all_rmst_mos_95.mos_B_95,
     's_mos_diff': dotx_chaarted_median_os[3] - adt_chaarted_median_os[3], 
     'rct_trt_arm': 57.6,
     'rct_cont_arm': 44.0,
     'rct_mos_diff': 57.6-44.0,
     'scount': chaarted.shape[0]}
]

### LATITUDE : abiraterone vs. ADT in metastatic, castration-sensitive prostate cancer  

**INCLUSION**
* Untreated metastatic prostate cancer, except up to 3 months of ADT 
* Castration-sensitive
* Received ADT or abiraterone plus ADT
* No active cardiac disease, viral hepatitis, or chronic liver disease in the year preceding metastatic diagnosis
* No CNS metatastasis at time of metastatic diagnosis 
* ECOG is not 3 or 4 at time of metastatic diagnosis 
* Adequate organ function at time of mestastatic diagnosis 

#### ADT

In [142]:
df_full = pd.read_csv('df_risk_crude.csv', index_col = 'PatientID', dtype = {'death_status': bool})
df_full.index.nunique()

18927

In [143]:
df_full.reset_index(inplace = True)

In [144]:
adt = pd.read_csv('Enhanced_MetPC_ADT.csv')

In [145]:
adt = (
    adt[adt['PatientID'].isin(df_full['PatientID'])]
    .query('TreatmentSetting == "Advanced"')
)

In [146]:
row_ID(adt)

(17863, 17863)

In [147]:
adt.loc[:, 'StartDate'] = pd.to_datetime(adt['StartDate'])

In [148]:
adt = adt.rename(columns = {'StartDate': 'StartDate_adt'})

In [149]:
df_full = pd.merge(df_full, adt[['PatientID', 'StartDate_adt']], on = 'PatientID', how = 'left')

In [150]:
row_ID(df_full)

(18927, 18927)

In [151]:
enhanced_met = pd.read_csv('Enhanced_MetProstate.csv')

In [152]:
enhanced_met = enhanced_met[enhanced_met['PatientID'].isin(df_full['PatientID'])]

In [153]:
enhanced_met.loc[:, 'MetDiagnosisDate'] = pd.to_datetime(enhanced_met['MetDiagnosisDate'])

In [154]:
enhanced_met.loc[:, 'CRPCDate'] = pd.to_datetime(enhanced_met['CRPCDate'])

In [155]:
df_full = pd.merge(df_full, enhanced_met[['PatientID', 'MetDiagnosisDate', 'CRPCDate']], on = 'PatientID', how = 'left')

In [156]:
row_ID(df_full)

(18927, 18927)

In [157]:
# Find all that start ADT within -90 to +90 days of metastatic diagnosis 
latitude_adt = (
    df_full
    .assign(adt_diff = (df_full['StartDate_adt'] - df_full['MetDiagnosisDate']).dt.days)
    .query('adt_diff >= -90 and adt_diff <= 90')
)

In [158]:
# Find all that have missing CPRC date or date is >90 after metastatic diagnosis 
latitude_adt = (
    latitude_adt
    .assign(crpc_diff = (latitude_adt['CRPCDate'] - latitude_adt['MetDiagnosisDate']).dt.days)
    .query('crpc_diff > 90 or CRPCDate.isna()', engine = 'python')
)

In [159]:
row_ID(latitude_adt)

(10397, 10397)

In [160]:
line_therapy = pd.read_csv('LineOfTherapy.csv')

In [161]:
zero = (
    line_therapy.query('LineNumber == 0')
    .PatientID
)

In [162]:
# Exclude patients with missing treatment information (ie, LineNumber == 0)
latitude_adt = latitude_adt[~latitude_adt.PatientID.isin(zero)]

In [163]:
row_ID(latitude_adt)

(10047, 10047)

In [164]:
line_therapy_cont = line_therapy.query('LineSetting != "nmCRPC"')

In [165]:
# List of FDA approved drugs for mPC as of July 2023. Clinical study drug is also included. 
fda_yes = [
    'Abiraterone',
    'Apalutamide',
    'Cabazitaxel',
    'Carboplatin',
    'Cisplatin',
    'Darolutamide',
    'Docetaxel',
    'Enzalutamide',
    'Mitoxantrone',
    'Olaparib',
    'Oxaliplatin',
    'Paclitaxel',
    'Pembrolizumab',
    'Radium-223',
    'Rucaparib',
    'Sipuleucel-T',
    'Clinical Study Drug'
]

In [166]:
line_therapy_cont = line_therapy_cont[line_therapy_cont['LineName'].str.contains('|'.join(fda_yes))]

In [167]:
line_therapy_cont = (
    line_therapy_cont
    .sort_values(by = ['PatientID', 'StartDate'], ascending = [True, True])
)

In [168]:
line_therapy_cont['line_number'] = (
    line_therapy_cont.groupby('PatientID')['LineNumber'].cumcount()+1
)

In [169]:
# First line therapy is in castrate-resistant setting 
fl_crpc = (
    line_therapy_cont[line_therapy_cont.PatientID.isin(latitude_adt.PatientID)]
    .query('line_number == 1 & LineSetting == "mCRPC"')
    .PatientID
)

In [170]:
# Never received therapy other than ADT
notrt_adt = (
    latitude_adt[~latitude_adt.PatientID.isin(line_therapy_cont.PatientID)]
    .PatientID
)

In [171]:
adt_IDs = np.concatenate((fl_crpc, notrt_adt))

In [172]:
len(adt_IDs)

6167

In [173]:
latitude_adt = latitude_adt[latitude_adt.PatientID.isin(adt_IDs)]

In [174]:
latitude_adt.loc[:,'adt_abi'] = 0

In [175]:
row_ID(latitude_adt)

(6167, 6167)

In [176]:
latitude_adt.sample(3)

Unnamed: 0,PatientID,Gender,race,ethnicity,age,p_type,NStage,MStage,Histology,GleasonScore,...,other_met,prim_treatment,early_adt,risk_score,StartDate_adt,MetDiagnosisDate,CRPCDate,adt_diff,crpc_diff,adt_abi
16908,FDBE5B7A39BFC,M,unknown,not_hispanic_latino,76,COMMUNITY,Unknown / Not documented,Unknown / Not documented,Adenocarcinoma,Less than or equal to 6,...,0.0,unknown,0.0,0.133486,2014-01-06,2013-12-26,NaT,11.0,,0
3801,F709FBBB2317F,M,other,unknown,66,BOTH,N1,M1c,Adenocarcinoma,10,...,0.0,unknown,0.0,0.231824,2019-10-10,2019-09-12,2020-06-19,28.0,281.0,0
5784,FA5828C6B919C,M,white,unknown,79,COMMUNITY,N0,M0,Adenocarcinoma,3 + 4 = 7,...,0.0,radiation,0.0,-0.771469,2016-06-01,2016-06-13,NaT,-12.0,,0


#### Treatment arm: Abiraterone + ADT

In [177]:
# Find those that start ADT within -90 to 90 days of metastaic diagnosis 
latitude_abi = (
    df_full
    .assign(adt_diff = (df_full['StartDate_adt'] - df_full['MetDiagnosisDate']).dt.days)
    .query('adt_diff >= -90 and adt_diff <= 90')
)

In [178]:
# Find all that have missing CPRC date or date is >30 days metastatic diagnosis 
latitude_abi = (
    latitude_abi
    .assign(crpc_diff = (latitude_abi['CRPCDate'] - latitude_abi['MetDiagnosisDate']).dt.days)
    .query('crpc_diff > 90 or CRPCDate.isna()', engine = 'python')
)

In [179]:
row_ID(latitude_abi)

(10397, 10397)

In [180]:
# Find start time of first line of mHSPC therapy. 
line_therapy_fl = (
    line_therapy[line_therapy['PatientID'].isin(latitude_abi['PatientID'])]
    .query('LineSetting == "mHSPC"')
    .sort_values(by = ['PatientID', 'StartDate'], ascending = [True, True])
    .drop_duplicates(subset = ['PatientID'], keep = 'first')
    .rename(columns = {'StartDate': 'StartDate_abi'})
)

In [181]:
row_ID(line_therapy_fl)

(4022, 4022)

In [182]:
line_therapy_fl.loc[:, 'StartDate_abi'] = pd.to_datetime(line_therapy_fl['StartDate_abi'])

In [183]:
line_therapy_fl[line_therapy_fl['LineName'].str.contains('Abiraterone')].LineName.value_counts().head(10)

Abiraterone                             1165
Abiraterone,Docetaxel                     27
Abiraterone,Enzalutamide                   8
Abiraterone,Clinical Study Drug            4
Abiraterone,Apalutamide                    4
Abiraterone,Apalutamide,Enzalutamide       1
Abiraterone,Capecitabine                   1
Abiraterone,Medroxyprogesterone            1
Abiraterone,Fluorouracil                   1
Abiraterone,Tamoxifen                      1
Name: LineName, dtype: int64

In [184]:
line_abi = line_therapy_fl.query('LineName == "Abiraterone"')

In [185]:
row_ID(line_abi)

(1165, 1165)

In [186]:
latitude_abi = pd.merge(latitude_abi, line_abi[['PatientID', 'StartDate_abi']], on = 'PatientID', how = 'left')

In [187]:
row_ID(latitude_abi)

(10397, 10397)

In [188]:
# Find all patients that start abiraterone within 90 days of ADT start
latitude_abi = (
    latitude_abi
    .assign(abi_diff = (latitude_abi['StartDate_abi'] - latitude_abi['StartDate_adt']).dt.days)
    .query('abi_diff >= -90 and abi_diff <= 90')
)

In [189]:
row_ID(latitude_abi)

(876, 876)

In [190]:
latitude_abi.sample(3)

Unnamed: 0,PatientID,Gender,race,ethnicity,age,p_type,NStage,MStage,Histology,GleasonScore,...,prim_treatment,early_adt,risk_score,StartDate_adt,MetDiagnosisDate,CRPCDate,adt_diff,crpc_diff,StartDate_abi,abi_diff
5626,FDCFBC42AE924,M,white,unknown,64,COMMUNITY,Unknown / Not documented,M1,Adenocarcinoma,4 + 3 = 7,...,unknown,0.0,-0.761958,2018-09-10,2018-08-16,NaT,25.0,,2018-09-20,10.0
3271,F61E53757C642,M,white,unknown,80,COMMUNITY,Unknown / Not documented,M1,"Prostate cancer, NOS",Unknown / Not documented,...,unknown,0.0,1.649394,2017-05-15,2017-05-15,NaT,0.0,,2017-05-25,10.0
3052,F7967DA73760A,M,black,unknown,71,ACADEMIC,Unknown / Not documented,Unknown / Not documented,"Prostate cancer, NOS",9,...,radiation,1.0,-0.435467,2021-09-01,2021-08-19,NaT,13.0,,2021-09-28,27.0


In [191]:
latitude_abi.loc[:,'adt_abi'] = 1

In [192]:
latitude = pd.concat([latitude_adt, latitude_abi], ignore_index = True)

In [193]:
row_ID(latitude)

(7043, 7043)

In [194]:
latitude.adt_abi.value_counts(dropna = False)

0    6167
1     876
Name: adt_abi, dtype: int64

#### Time from ADT treatment to death or censor 

In [195]:
mortality_tr = pd.read_csv('mortality_cleaned_tr.csv')

In [196]:
mortality_te = pd.read_csv('mortality_cleaned_te.csv')

In [197]:
mortality_tr = mortality_tr[['PatientID', 'death_date', 'last_activity']]

In [198]:
mortality_te = mortality_te[['PatientID', 'death_date', 'last_activity']]

In [199]:
mortality = pd.concat([mortality_tr, mortality_te], ignore_index = True)
row_ID(mortality)

(18927, 18927)

In [200]:
mortality.loc[:, 'last_activity'] = pd.to_datetime(mortality['last_activity'])

In [201]:
mortality.loc[:, 'death_date'] = pd.to_datetime(mortality['death_date'])

In [202]:
len(mortality)

18927

In [203]:
latitude = pd.merge(latitude, mortality, on = 'PatientID', how = 'left')

In [204]:
row_ID(latitude)

(7043, 7043)

In [205]:
conditions = [
    (latitude['death_status'] == 1),
    (latitude['death_status'] == 0)]

choices = [
    (latitude['death_date'] - latitude['StartDate_adt']).dt.days,
    (latitude['last_activity'] - latitude['StartDate_adt']).dt.days]

latitude.loc[:, 'timerisk_treatment'] = np.select(conditions, choices)

latitude = latitude.query('timerisk_treatment >= 0')

#### Patient count 

In [206]:
row_ID(latitude)

(7043, 7043)

In [207]:
# Exclude those with active cardiac disease in the year preceding metastatic diagnosis 
latitude = latitude[~latitude['PatientID'].isin(cardiac_IDs)]

In [208]:
# Exclude those with viral hepatitis or chronic liver disease in the year preceding metastatic diagnosis 
latitude = latitude[~latitude['PatientID'].isin(liv_IDs)]

In [209]:
# Exclude those with CNS metastasis at time of metastatic diagnosis 
latitude = latitude.query('cns_met == 0')

In [210]:
# Exclude those with ECOG of 3 or 4 at time of metastatic diagnosis 
latitude = latitude.query('ecog_diagnosis != "3.0" and ecog_diagnosis != "4.0"')

In [211]:
# Exclude those with abnormal organ function at time of metastatic diagnosis 
latitude = (
    latitude
    .query('creatinine_diag < 2 or creatinine_diag_na == 1')
    .query('hemoglobin_diag > 9 or hemoglobin_diag_na == 1')
    .query('total_bilirubin_diag < 3 or total_bilirubin_diag_na == 1')
)

In [212]:
row_ID(latitude)

(6418, 6418)

In [213]:
low_cutoff_latitude = cutoff.loc['latitude'].low

In [214]:
high_cutoff_latitude = cutoff.loc['latitude'].high

In [215]:
print('Abiraterone + ADT:',  latitude.query('adt_abi == 1').shape[0])
print('High risk:', latitude.query('adt_abi == 1').query('risk_score >= @high_cutoff_latitude').shape[0])
print('Med risk:', latitude.query('adt_abi == 1').query('risk_score < @high_cutoff_latitude and risk_score > @low_cutoff_latitude').shape[0])
print('Low risk:', latitude.query('adt_abi == 1').query('risk_score <= @low_cutoff_latitude').shape[0])

Abiraterone + ADT: 783
High risk: 204
Med risk: 245
Low risk: 334


In [216]:
print('ADT:',  latitude.query('adt_abi == 0').shape[0])
print('High risk:', latitude.query('adt_abi == 0').query('risk_score >= @high_cutoff_latitude').shape[0])
print('Med risk:', latitude.query('adt_abi == 0').query('risk_score < @high_cutoff_latitude and risk_score > @low_cutoff_latitude').shape[0])
print('Low risk:', latitude.query('adt_abi == 0').query('risk_score <= @low_cutoff_latitude').shape[0])

ADT: 5635
High risk: 1615
Med risk: 2035
Low risk: 1985


#### Survival curves with covariate balancing 

In [217]:
latitude = latitude.set_index('PatientID')

In [218]:
latitude['met_cat'] = pd.cut(latitude['met_year'],
                             bins = [2010, 2018, float('inf')],
                             labels = ['11-18', '19-22'])

In [219]:
conditions = [
    ((latitude['ecog_diagnosis'] == "1.0") | (latitude['ecog_diagnosis'] == "0.0")),  
    ((latitude['ecog_diagnosis'] == "2.0") | (latitude['ecog_diagnosis'] == "3.0"))
]

choices = ['lt_2', 'gte_2']

latitude['ecog_2'] = np.select(conditions, choices, default = 'unknown')

In [220]:
latitude_iptw = latitude.filter(items = ['death_status',
                                         'timerisk_treatment',
                                         'adt_abi',
                                         'age',
                                         'race',
                                         'p_type',
                                         'met_cat',
                                         'delta_met_diagnosis',
                                         'commercial',
                                         'medicare',
                                         'medicaid',
                                         'ecog_2',
                                         'prim_treatment',
                                         'PSAMetDiagnosis',
                                         'albumin_diag', 
                                         'weight_pct_change',
                                         'risk_score'])

In [221]:
latitude_iptw.dtypes

death_status               bool
timerisk_treatment      float64
adt_abi                   int64
age                       int64
race                     object
p_type                   object
met_cat                category
delta_met_diagnosis       int64
commercial              float64
medicare                float64
medicaid                float64
ecog_2                   object
prim_treatment           object
PSAMetDiagnosis         float64
albumin_diag            float64
weight_pct_change       float64
risk_score              float64
dtype: object

In [222]:
to_be_categorical = list(latitude_iptw.select_dtypes(include = ['object']).columns)

In [223]:
to_be_categorical

['race', 'p_type', 'ecog_2', 'prim_treatment']

In [224]:
to_be_categorical.append('met_cat')

In [225]:
# Convert variables in list to categorical.
for x in list(to_be_categorical):
    latitude_iptw[x] = latitude_iptw[x].astype('category')

In [226]:
# List of numeric variables, excluding binary variables. 
numerical_features = ['age', 'delta_met_diagnosis', 'PSAMetDiagnosis', 'albumin_diag', 'weight_pct_change', 'risk_score']

# Transformer will first calculate column median and impute, and then apply a standard scaler. 
numerical_transformer = Pipeline(steps = [
    ('imputer', SimpleImputer(strategy = 'median')),
    ('std_scaler', StandardScaler())])

In [227]:
# List of categorical features.
categorical_features = list(latitude_iptw.select_dtypes(include = ['category']).columns)

# One-hot-encode categorical features.
categorical_transformer = OneHotEncoder(handle_unknown = 'ignore')

In [228]:
preprocessor = ColumnTransformer(
    transformers = [
        ('num', numerical_transformer, numerical_features),
        ('cat', categorical_transformer, categorical_features)],
    remainder = 'passthrough')

In [229]:
latitude_iptw_low = (
    latitude_iptw
    .query('risk_score <= @low_cutoff_latitude'))

latitude_iptw_med = (
    latitude_iptw
    .query('risk_score < @high_cutoff_latitude and risk_score > @low_cutoff_latitude'))

latitude_iptw_high = (
    latitude_iptw
    .query('risk_score >= @high_cutoff_latitude'))

latitude_iptw_all = latitude_iptw

In [230]:
latitude_low_x = preprocessor.fit_transform(latitude_iptw_low.filter(items = ['age',
                                                                              'race',
                                                                              'p_type',
                                                                              'met_cat',
                                                                              'delta_met_diagnosis',
                                                                              'commercial',
                                                                              'medicare',
                                                                              'medicaid', 
                                                                              'ecog_2',
                                                                              'prim_treatment',
                                                                              'PSAMetDiagnosis', 
                                                                              'albumin_diag',
                                                                              'weight_pct_change',
                                                                              'risk_score']))

latitude_med_x = preprocessor.fit_transform(latitude_iptw_med.filter(items = ['age',
                                                                              'race',
                                                                              'p_type',
                                                                              'met_cat',
                                                                              'delta_met_diagnosis',
                                                                              'commercial',
                                                                              'medicare',
                                                                              'medicaid', 
                                                                              'ecog_2',
                                                                              'prim_treatment',
                                                                              'PSAMetDiagnosis', 
                                                                              'albumin_diag',
                                                                              'weight_pct_change',
                                                                              'risk_score']))

latitude_high_x = preprocessor.fit_transform(latitude_iptw_high.filter(items = ['age',
                                                                                'race',
                                                                                'p_type',
                                                                                'met_cat',
                                                                                'delta_met_diagnosis',
                                                                                'commercial',
                                                                                'medicare',
                                                                                'medicaid', 
                                                                                'ecog_2',
                                                                                'prim_treatment',
                                                                                'PSAMetDiagnosis', 
                                                                                'albumin_diag',
                                                                                'weight_pct_change', 
                                                                                'risk_score']))

latitude_all_x = preprocessor.fit_transform(latitude_iptw_all.filter(items = ['age',
                                                                              'race',
                                                                              'p_type',
                                                                              'met_cat',
                                                                              'delta_met_diagnosis',
                                                                              'commercial',
                                                                              'medicare',
                                                                              'medicaid', 
                                                                              'ecog_2',
                                                                              'prim_treatment',
                                                                              'PSAMetDiagnosis', 
                                                                              'albumin_diag',
                                                                              'weight_pct_change',
                                                                              'risk_score']))

In [231]:
lr_latitude_low = LogisticRegression(max_iter = 1000)
lr_latitude_low.fit(latitude_low_x, latitude_iptw_low['adt_abi'])

LogisticRegression(max_iter=1000)

In [232]:
lr_latitude_med = LogisticRegression(max_iter = 1000)
lr_latitude_med.fit(latitude_med_x, latitude_iptw_med['adt_abi'])

LogisticRegression(max_iter=1000)

In [233]:
lr_latitude_high = LogisticRegression(max_iter = 1000)
lr_latitude_high.fit(latitude_high_x, latitude_iptw_high['adt_abi'])

LogisticRegression(max_iter=1000)

In [234]:
lr_latitude_all = LogisticRegression(max_iter = 1000)
lr_latitude_all.fit(latitude_all_x, latitude_iptw_all['adt_abi'])

LogisticRegression(max_iter=1000)

In [235]:
pred_low = lr_latitude_low.predict_proba(latitude_low_x)
pred_med = lr_latitude_med.predict_proba(latitude_med_x)
pred_high = lr_latitude_high.predict_proba(latitude_high_x)
pred_all = lr_latitude_all.predict_proba(latitude_all_x)

In [236]:
latitude_iptw_low['ps'] = pred_low[:, 1]
latitude_iptw_med['ps'] = pred_med[:, 1]
latitude_iptw_high['ps'] = pred_high[:, 1]
latitude_iptw_all['ps'] = pred_all[:, 1]

In [237]:
latitude_iptw_low['weight'] = (
    np.where(latitude_iptw_low['adt_abi'] == 1, 1/latitude_iptw_low['ps'], 1/(1 - latitude_iptw_low['ps'])))

latitude_iptw_med['weight'] = (
    np.where(latitude_iptw_med['adt_abi'] == 1, 1/latitude_iptw_med['ps'], 1/(1 - latitude_iptw_med['ps'])))

latitude_iptw_high['weight'] = (
    np.where(latitude_iptw_high['adt_abi'] == 1, 1/latitude_iptw_high['ps'], 1/(1 - latitude_iptw_high['ps'])))

latitude_iptw_all['weight'] = (
    np.where(latitude_iptw_all['adt_abi'] == 1, 1/latitude_iptw_all['ps'], 1/(1 - latitude_iptw_all['ps'])))

In [238]:
# Low KM curves
kmf_low_abi_latitude_iptw = KaplanMeierFitter()
kmf_low_adt_latitude_iptw = KaplanMeierFitter()

kmf_low_abi_latitude_iptw.fit(
    latitude_iptw_low.query('adt_abi == 1').timerisk_treatment/30,
    latitude_iptw_low.query('adt_abi == 1').death_status,
    weights = latitude_iptw_low.query('adt_abi == 1')['weight'])

kmf_low_adt_latitude_iptw.fit(
    latitude_iptw_low.query('adt_abi == 0').timerisk_treatment/30,
    latitude_iptw_low.query('adt_abi == 0').death_status,
    weights = latitude_iptw_low.query('adt_abi == 0')['weight'])

# Med KM curves
kmf_med_abi_latitude_iptw = KaplanMeierFitter()
kmf_med_adt_latitude_iptw = KaplanMeierFitter()

kmf_med_abi_latitude_iptw.fit(
    latitude_iptw_med.query('adt_abi == 1').timerisk_treatment/30,
    latitude_iptw_med.query('adt_abi == 1').death_status,
    weights = latitude_iptw_med.query('adt_abi == 1')['weight'])

kmf_med_adt_latitude_iptw.fit(
    latitude_iptw_med.query('adt_abi == 0').timerisk_treatment/30,
    latitude_iptw_med.query('adt_abi == 0').death_status,
    weights = latitude_iptw_med.query('adt_abi == 0')['weight'])

# High KM curves 
kmf_high_abi_latitude_iptw = KaplanMeierFitter()
kmf_high_adt_latitude_iptw = KaplanMeierFitter()

kmf_high_abi_latitude_iptw.fit(
    latitude_iptw_high.query('adt_abi == 1').timerisk_treatment/30,
    latitude_iptw_high.query('adt_abi == 1').death_status,
    weights = latitude_iptw_high.query('adt_abi == 1')['weight'])

kmf_high_adt_latitude_iptw.fit(
    latitude_iptw_high.query('adt_abi == 0').timerisk_treatment/30,
    latitude_iptw_high.query('adt_abi == 0').death_status,
    weights = latitude_iptw_high.query('adt_abi == 0')['weight'])

# All KM curves 
kmf_all_abi_latitude_iptw = KaplanMeierFitter()
kmf_all_adt_latitude_iptw = KaplanMeierFitter()

kmf_all_abi_latitude_iptw.fit(
    latitude_iptw_all.query('adt_abi == 1').timerisk_treatment/30,
    latitude_iptw_all.query('adt_abi == 1').death_status,
    weights = latitude_iptw_all.query('adt_abi == 1')['weight'])

kmf_all_adt_latitude_iptw.fit(
    latitude_iptw_all.query('adt_abi == 0').timerisk_treatment/30,
    latitude_iptw_all.query('adt_abi == 0').death_status,
    weights = latitude_iptw_all.query('adt_abi == 0')['weight'])

<lifelines.KaplanMeierFitter:"KM_estimate", fitted with 6421.25 total observations, 3053.71 right-censored observations>

#### Calculating survival metrics 

In [239]:
abi_latitude_median_os = mos(kmf_low_abi_latitude_iptw,
                             kmf_med_abi_latitude_iptw,
                             kmf_high_abi_latitude_iptw,
                             kmf_all_abi_latitude_iptw)

adt_latitude_median_os = mos(kmf_low_adt_latitude_iptw,
                             kmf_med_adt_latitude_iptw,
                             kmf_high_adt_latitude_iptw,
                             kmf_all_adt_latitude_iptw)

In [240]:
latitude_iptw_all_imputed = latitude_iptw_all.copy()
latitude_iptw_all_imputed['albumin_diag'] = latitude_iptw_all_imputed['albumin_diag'].fillna(latitude_iptw_all_imputed['albumin_diag'].median())
latitude_iptw_all_imputed['weight_pct_change'] = latitude_iptw_all_imputed['weight_pct_change'].fillna(latitude_iptw_all_imputed['weight_pct_change'].median())
latitude_iptw_all_imputed['PSAMetDiagnosis'] = latitude_iptw_all_imputed['PSAMetDiagnosis'].fillna(latitude_iptw_all_imputed['PSAMetDiagnosis'].median())

In [241]:
latitude_hr_all = CoxPHFitter()
latitude_hr_all.fit(latitude_iptw_all_imputed,
                    duration_col = 'timerisk_treatment',
                    event_col = 'death_status',
                    formula = 'adt_abi + age + race + p_type + delta_met_diagnosis + met_cat + commercial + medicare + medicaid + ecog_2 + prim_treatment + PSAMetDiagnosis + albumin_diag + weight_pct_change + risk_score',
                    weights_col = 'weight', 
                    robust = True)

<lifelines.CoxPHFitter: fitted with 12656.2 total observations, 6942.01 right-censored observations>

In [242]:
latitude_all_rmst_mos_95 = rmst_mos_95ci(latitude_iptw_all,
                                         1000,
                                         'adt_abi',
                                         'death',
                                         ['age',
                                          'race',
                                          'p_type',
                                          'delta_met_diagnosis',
                                          'met_cat',
                                          'commercial',
                                          'medicare',
                                          'medicaid',
                                          'ecog_2',
                                          'prim_treatment',
                                          'PSAMetDiagnosis', 
                                          'albumin_diag', 
                                          'weight_pct_change',
                                          'risk_score'],
                                         ['age', 'delta_met_diagnosis', 'PSAMetDiagnosis', 'albumin_diag', 'weight_pct_change', 'risk_score'],
                                         60)

In [243]:
latitude_low_rmst_mos_95 = rmst_mos_95ci(latitude_iptw_low,
                                         1000,
                                         'adt_abi',
                                         'death',
                                         ['age',
                                          'race',
                                          'p_type',
                                          'delta_met_diagnosis',
                                          'met_cat',
                                          'commercial',
                                          'medicare',
                                          'medicaid',
                                          'ecog_2',
                                          'prim_treatment',
                                          'PSAMetDiagnosis', 
                                          'albumin_diag', 
                                          'weight_pct_change',
                                          'risk_score'],
                                         ['age', 'delta_met_diagnosis', 'PSAMetDiagnosis', 'albumin_diag', 'weight_pct_change', 'risk_score'],
                                         60)

In [244]:
latitude_med_rmst_mos_95 = rmst_mos_95ci(latitude_iptw_med,
                                         1000,
                                         'adt_abi',
                                         'death',
                                         ['age',
                                          'race',
                                          'p_type',
                                          'delta_met_diagnosis',
                                          'met_cat',
                                          'commercial',
                                          'medicare',
                                          'medicaid',
                                          'ecog_2',
                                          'prim_treatment',
                                          'PSAMetDiagnosis', 
                                          'albumin_diag', 
                                          'weight_pct_change',
                                          'risk_score'],
                                         ['age', 'delta_met_diagnosis', 'PSAMetDiagnosis', 'albumin_diag', 'weight_pct_change', 'risk_score'],
                                         60)

In [245]:
latitude_high_rmst_mos_95 = rmst_mos_95ci(latitude_iptw_high,
                                          1000,
                                          'adt_abi',
                                          'death',
                                          ['age',
                                           'race',
                                           'p_type',
                                           'delta_met_diagnosis',
                                           'met_cat',
                                           'commercial',
                                           'medicare',
                                           'medicaid',
                                           'ecog_2',
                                           'prim_treatment',
                                           'PSAMetDiagnosis', 
                                           'albumin_diag', 
                                           'weight_pct_change',
                                           'risk_score'],
                                          ['age', 'delta_met_diagnosis', 'PSAMetDiagnosis', 'albumin_diag', 'weight_pct_change', 'risk_score'],
                                          60)

In [246]:
latitude_data = [
    {'trial_name': 'LATITUDE', 
     'risk_group': 'low', 
     's_trt_mos': abi_latitude_median_os[0],
     's_trt_mos_95': latitude_low_rmst_mos_95.mos_A_95,
     's_cont_mos': adt_latitude_median_os[0],
     's_cont_mos_95': latitude_low_rmst_mos_95.mos_B_95,
     's_mos_diff': abi_latitude_median_os[0] - adt_latitude_median_os[0], 
     'rct_trt_arm': 53.3,
     'rct_cont_arm': 36.5,
     'rct_mos_diff': 53.3-36.5,
     's_trt_rmst': restricted_mean_survival_time(kmf_low_abi_latitude_iptw, 55),
     's_trt_rmst_95': latitude_low_rmst_mos_95.rmst_A_95,
     's_cont_rmst': restricted_mean_survival_time(kmf_low_adt_latitude_iptw, 55),
     's_cont_rmst_95': latitude_low_rmst_mos_95.rmst_B_95,
     's_diff_rmst': restricted_mean_survival_time(kmf_low_abi_latitude_iptw, 55) - restricted_mean_survival_time(kmf_low_adt_latitude_iptw, 55),
     's_diff_rmst_95': latitude_low_rmst_mos_95.difference_rmst_95,
     'scount': latitude.query('risk_score <= @low_cutoff_latitude').shape[0]},
    
    {'trial_name': 'LATITUDE', 
     'risk_group': 'medium', 
     's_trt_mos': abi_latitude_median_os[1],
     's_trt_mos_95': latitude_med_rmst_mos_95.mos_A_95,
     's_cont_mos': adt_latitude_median_os[1],
     's_cont_mos_95': latitude_med_rmst_mos_95.mos_B_95,
     's_mos_diff': abi_latitude_median_os[1] - adt_latitude_median_os[1], 
     'rct_trt_arm': 53.3,
     'rct_cont_arm': 36.5,
     'rct_mos_diff': 53.3-36.5,
     's_trt_rmst': restricted_mean_survival_time(kmf_med_abi_latitude_iptw, 55),
     's_trt_rmst_95': latitude_med_rmst_mos_95.rmst_A_95,
     's_cont_rmst': restricted_mean_survival_time(kmf_med_adt_latitude_iptw, 55),
     's_cont_rmst_95': latitude_med_rmst_mos_95.rmst_B_95,
     's_diff_rmst': restricted_mean_survival_time(kmf_med_abi_latitude_iptw, 55) - restricted_mean_survival_time(kmf_med_adt_latitude_iptw, 55),
     's_diff_rmst_95': latitude_med_rmst_mos_95.difference_rmst_95,
     'scount': latitude.query('risk_score < @high_cutoff_latitude and risk_score > @low_cutoff_latitude').shape[0]},
    
    {'trial_name': 'LATITUDE', 
     'risk_group': 'high', 
     's_trt_mos': abi_latitude_median_os[2],
     's_trt_mos_95': latitude_high_rmst_mos_95.mos_A_95,
     's_cont_mos': adt_latitude_median_os[2],
     's_cont_mos_95': latitude_high_rmst_mos_95.mos_B_95,
     's_mos_diff': abi_latitude_median_os[2] - adt_latitude_median_os[2], 
     'rct_trt_arm': 53.3,
     'rct_cont_arm': 36.5,
     'rct_mos_diff': 53.3-36.5,
     's_trt_rmst': restricted_mean_survival_time(kmf_high_abi_latitude_iptw, 55),
     's_trt_rmst_95': latitude_high_rmst_mos_95.rmst_A_95,
     's_cont_rmst': restricted_mean_survival_time(kmf_high_adt_latitude_iptw, 55),
     's_cont_rmst_95': latitude_high_rmst_mos_95.rmst_B_95,
     's_diff_rmst': restricted_mean_survival_time(kmf_high_abi_latitude_iptw, 55) - restricted_mean_survival_time(kmf_high_adt_latitude_iptw, 55),
     's_diff_rmst_95': latitude_high_rmst_mos_95.difference_rmst_95,
     'scount': latitude.query('risk_score >= @high_cutoff_latitude').shape[0]},
    
    {'trial_name': 'LATITUDE', 
     'risk_group': 'all', 
     's_hr': latitude_hr_all.hazard_ratios_['adt_abi'],
     's_hr_95': [latitude_hr_all.summary.loc['adt_abi']['exp(coef) lower 95%'], latitude_hr_all.summary.loc['adt_abi']['exp(coef) upper 95%']],
     's_trt_mos': abi_latitude_median_os[3],
     's_trt_mos_95': latitude_all_rmst_mos_95.mos_A_95,
     's_cont_mos': adt_latitude_median_os[3],
     's_cont_mos_95': latitude_all_rmst_mos_95.mos_B_95,
     's_mos_diff': abi_latitude_median_os[3] - adt_latitude_median_os[3], 
     'rct_trt_arm': 53.3,
     'rct_cont_arm': 36.5,
     'rct_mos_diff': 53.3-36.5,
     'scount': latitude.shape[0]}
]

## Part 3. Combining dictionaries 

In [247]:
data_combined = chaarted_data + latitude_data

In [248]:
strials_mos_rmst_boot = pd.DataFrame(data_combined)

In [249]:
strials_mos_rmst_boot

Unnamed: 0,trial_name,risk_group,s_trt_mos,s_trt_mos_95,s_cont_mos,s_cont_mos_95,s_mos_diff,rct_trt_arm,rct_cont_arm,rct_mos_diff,s_trt_rmst,s_trt_rmst_95,s_cont_rmst,s_cont_rmst_95,s_diff_rmst,s_diff_rmst_95,scount,s_hr,s_hr_95
0,CHAARTED,low,64.266667,"[52.06666666666667, nan]",67.2,"[61.9325, 72.0]",-2.933333,57.6,44.0,13.6,48.72703,"[46.13612317343343, 51.23276174385276]",47.858002,"[46.7837613260294, 48.64792355207267]",0.869028,"[-1.832633684606853, 3.8607031875290265]",2452,,
1,CHAARTED,medium,53.633333,"[48.06666666666667, 59.766666666666666]",40.633333,"[38.3, 42.833333333333336]",13.0,57.6,44.0,13.6,45.966564,"[43.4785805786226, 48.20021582547429]",39.205572,"[38.250708661764335, 40.17180875478196]",6.760992,"[4.076642720472807, 9.196725232541223]",2412,,
2,CHAARTED,high,33.466667,"[30.866666666666667, 40.18583333333331]",23.0,"[21.766666666666666, 24.43583333333333]",10.466667,57.6,44.0,13.6,34.857454,"[31.773337846382287, 38.066628143010135]",26.834129,"[25.96522412365742, 27.785075966078836]",8.023324,"[4.812327495042032, 11.41617599395063]",1948,,
3,CHAARTED,all,50.066667,"[45.03, 58.03333333333333]",40.4,"[38.833333333333336, 41.9]",9.666667,57.6,44.0,13.6,,,,,,,6812,0.677007,"[0.5961952571037064, 0.7687731222068331]"
4,LATITUDE,low,inf,"[61.2, nan]",67.766667,"[63.46666666666667, 72.5]",inf,53.3,36.5,16.8,46.62329,"[46.96211160534267, 52.55956674977289]",45.176353,"[47.160523288858116, 48.8673236530062]",1.446937,"[-1.1008736490334547, 4.573271735286719]",2319,,
5,LATITUDE,medium,52.566667,"[34.63333333333333, nan]",40.5,"[38.365, 42.333333333333336]",12.066667,53.3,36.5,16.8,40.190502,"[37.8583820676278, 46.539484193413415]",37.419524,"[38.137716373819316, 39.881327183030095]",2.770978,"[-1.1586046091214963, 7.571774061246315]",2280,,
6,LATITUDE,high,26.833333,"[21.366666666666667, 35.1]",23.0,"[21.633333333333333, 24.433333333333334]",3.833333,53.3,36.5,16.8,28.788163,"[25.065864285418787, 34.282272216692476]",26.293633,"[25.927139984199155, 27.97050550476299]",2.49453,"[-1.864322971934037, 7.473714429001309]",1819,,
7,LATITUDE,all,49.966667,"[41.2, 58.63333333333333]",40.633333,"[39.13333333333333, 42.06666666666667]",9.333333,53.3,36.5,16.8,,,,,,,6418,0.84184,"[0.7128522148037691, 0.9941670306129405]"


In [250]:
strials_mos_rmst_boot.to_csv('strials_mos_rmst_boot.csv', index = False)