# Flatiron Health mCRC: Survival metrics for key elgibility criteria
**Background: Calculate survival metrics for emulated trials involving patients meeting key elgibliity criteria. Hazard ratio for the full cohort is calculated from a Cox-IPTW model. Restricted mean survival time and median overall survival are calculated for phenotypes using an IPTW-adjusted KM curve.** 

## Part 1: Preprocessing

### 1.1 Import packages and create necessary functions

In [1]:
import numpy as np
import pandas as pd

from scipy import stats

from sksurv.nonparametric import kaplan_meier_estimator
from survive import KaplanMeier, SurvivalData

from lifelines import KaplanMeierFitter, CoxPHFitter
from lifelines.plotting import add_at_risk_counts
from lifelines.utils import median_survival_times, restricted_mean_survival_time
from lifelines.statistics import logrank_test

from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer 
from sklearn.linear_model import LogisticRegression
from sklearn.utils import resample

import warnings

In [2]:
# Function that returns number of rows and count of unique PatientIDs for a dataframe. 
def row_ID(dataframe):
    row = dataframe.shape[0]
    ID = dataframe['PatientID'].nunique()
    return row, ID

In [3]:
# Find index for value closest to input value. 
def find_nearest(array, value):
    array = np.asarray(array)
    idx = (np.abs(array - value)).argmin()
    return array[idx]

In [4]:
# Calculates median overeall suvival for risk groups. 
def mos(low, med, high, comp):
    low_os = low.median_survival_time_
    med_os = med.median_survival_time_
    high_os = high.median_survival_time_
    comp_os = comp.median_survival_time_
    mos = [low_os, med_os, high_os, comp_os]
    return (mos)

In [5]:
def rmst_mos_95ci(df, num_samples, drug, event, items_list, numerical_features, rmst_time):
    
    """
    Estimate the 95% confidence interval for RMST and mOS using bootstrap resampling.

    Parameters:
    - df: DataFrame containing survival data
    - num_samples: Number of bootstrap samples
    - drug: Treatment indicator variable
    - event: Event type ('death' or 'progression')
    - items_list: Feature list for IPTW 
    - numerical_features: List of numerical features
    - rmst_time: Time to calculate RMST 

    Returns:
    - mos_A_95: mOS 95% CI for treatment
    - mos_B_95: mOS 95% CI for control
    - rmst_A_95: RMST 95% CI for treatment
    - rmst_B_95: RMST 95% CI for control
    - difference_rmst_95: RMST 95% CI for difference between treatment and control 
    """
    
    np.random.seed(42)
    mos_A = []
    mos_B = []
    rmst_A_list = []
    rmst_B_list = []
    differences_rmst = []
    
    # Define variables based on the event type
    if event == 'death':
        time_column = 'timerisk_treatment'
        status_column = 'death_status'
        
    else:
        time_column = 'time_prog_treatment'
        status_column = 'pfs_status'
        
    # Set up preprocessor for logistical regression which will be for IPTW  
    numerical_transformer = Pipeline(steps = [
        ('imputer', SimpleImputer(strategy = 'median')),
        ('std_scaler', StandardScaler())])
        
    categorical_transformer = OneHotEncoder(handle_unknown = 'ignore')
    categorical_features = list(df.select_dtypes(include = ['category']).columns)
        
    preprocessor = ColumnTransformer(
        transformers = [
            ('num', numerical_transformer, numerical_features),
            ('cat', categorical_transformer, categorical_features)],
        remainder = 'passthrough')
    
    # Boostrap HR 
    for _ in range(num_samples):
        
        # Resample data with replacement
        resampled_df = resample(df).drop(columns = ['ps', 'weight'])
        
        # Calculated IPTW for the resampled group 
        df_x = preprocessor.fit_transform(resampled_df.filter(items = items_list))
                                           
        df_lr = LogisticRegression(max_iter = 1000)
        df_lr.fit(df_x, resampled_df[drug])
        warnings.filterwarnings("ignore")
                                           
        pred = df_lr.predict_proba(df_x)        
        resampled_df['ps'] = pred[:, 1]                          
        resampled_df['weight'] = (
                np.where(resampled_df[drug] == 1, 1/resampled_df['ps'], 1/(1 - resampled_df['ps'])))
    
        # mOS from IPTW-KM
        kmf_A = KaplanMeierFitter()
        kmf_A.fit(resampled_df.query(f'{drug} == 1')[time_column]/30,
                  resampled_df.query(f'{drug} == 1')[status_column], 
                  weights = resampled_df.query(f'{drug} == 1')['weight'])

        kmf_B = KaplanMeierFitter()
        kmf_B.fit(resampled_df.query(f'{drug} == 0')[time_column]/30,
                  resampled_df.query(f'{drug} == 0')[status_column], 
                  weights = resampled_df.query(f'{drug} == 0')['weight'])
    
        mos_A.append(kmf_A.median_survival_time_)
        mos_B.append(kmf_B.median_survival_time_)
        
        # RMST from IPTW-KM
        rmst_A = restricted_mean_survival_time(kmf_A, rmst_time)
        rmst_B = restricted_mean_survival_time(kmf_B, rmst_time)
        
        rmst_A_list.append(rmst_A)
        rmst_B_list.append(rmst_B)
        differences_rmst.append(rmst_A - rmst_B)

    # Calculate the 95% confidence interval
    results = pd.Series({
    'mos_A_95': np.percentile(mos_A, [2.5, 97.5]),
    'mos_B_95': np.percentile(mos_B, [2.5, 97.5]),
    'rmst_A_95': np.percentile(rmst_A_list, [2.5, 97.5]),
    'rmst_B_95': np.percentile(rmst_B_list, [2.5, 97.5]),
    'difference_rmst_95': np.percentile(differences_rmst, [2.5, 97.5])
    })
    
    return results

## Part 2: In silico trials 

### FIRE-3: FOLFIRI plus cetuximab vs. FOLFIRI plus bevacizumab in KRAS wild-type 

**INCLUSION**
* Untreated metastatic colorectal cancer
* First-line recepit of FOLFIRI/FOLFOX plus cetuximab or FOLFIRI/FOLFOX plus bevacizumab
* KRAS wild-type 

#### FOLFIRI plus Cetuximab

In [6]:
df_full = pd.read_csv('df_risk_crude.csv', index_col = 'PatientID', dtype = {'death_status': bool})
df_full.index.nunique()

34315

In [7]:
line_therapy = pd.read_csv('LineOfTherapy.csv')

In [8]:
line_therapy_fl = (
    line_therapy[line_therapy['PatientID'].isin(df_full.index)]
    .query('LineNumber == 1')
    .query('IsMaintenanceTherapy == False'))

In [9]:
line_therapy_fl[line_therapy_fl['LineName'].str.contains('FOLFOX|Cetuximab')].LineName.value_counts().head(10)

FOLFOX,Bevacizumab            6338
FOLFOX                        4140
FOLFOX,Bevacizumab-Awwb       1356
FOLFOX,Bevacizumab-Bvzr        462
FOLFOX,Panitumumab             456
FOLFIRI,Cetuximab              364
FOLFOXIRI                      253
FOLFOX,Cetuximab               245
FOLFOXIRI,Bevacizumab          228
FOLFOXIRI,Bevacizumab-Awwb     140
Name: LineName, dtype: int64

In [10]:
line_therapy_fl[line_therapy_fl['LineName'].str.contains('FOLFIRI|Cetuximab')].LineName.value_counts().head(10)

FOLFIRI,Bevacizumab         2240
FOLFIRI                      914
FOLFIRI,Bevacizumab-Awwb     557
FOLFIRI,Cetuximab            364
FOLFIRI,Panitumumab          261
FOLFOX,Cetuximab             245
FOLFIRI,Bevacizumab-Bvzr     165
Irinotecan,Cetuximab          90
Cetuximab                     70
FOLFIRI,Ziv-Aflibercept       56
Name: LineName, dtype: int64

In [11]:
fxi_cet = (
    line_therapy_fl
    .query('LineName == "FOLFIRI,Cetuximab" or LineName == "FOLFOX,Cetuximab"')
    [['PatientID', 'StartDate']]
)

In [12]:
fxi_cet.loc[:,'fxi_cet'] = 1

In [13]:
row_ID(fxi_cet)

(609, 568)

In [14]:
fxi_cet['StartDate'] = pd.to_datetime(fxi_cet['StartDate'])

In [15]:
fxi_cet = (
    fxi_cet
    .sort_values(['PatientID', 'StartDate'], ascending = [True, True])
    .drop_duplicates(subset = 'PatientID', keep = 'first')
)

In [16]:
row_ID(fxi_cet)

(568, 568)

#### FOLFIRI plus Bevacizumab

In [17]:
line_therapy_fl[line_therapy_fl['LineName'].str.contains('FOLFIRI|Bevacizumab')].LineName.value_counts().head(10)

FOLFOX,Bevacizumab                     6338
FOLFIRI,Bevacizumab                    2240
FOLFOX,Bevacizumab-Awwb                1356
FOLFIRI                                 914
CAPEOX,Bevacizumab                      764
FOLFIRI,Bevacizumab-Awwb                557
FOLFOX,Bevacizumab-Bvzr                 462
Fluorouracil,Leucovorin,Bevacizumab     423
Capecitabine,Bevacizumab                405
FOLFIRI,Cetuximab                       364
Name: LineName, dtype: int64

In [18]:
line_therapy_fl[line_therapy_fl['LineName'].str.contains('FOLFOX|Bevacizumab')].LineName.value_counts().head(10)

FOLFOX,Bevacizumab                     6338
FOLFOX                                 4140
FOLFIRI,Bevacizumab                    2240
FOLFOX,Bevacizumab-Awwb                1356
CAPEOX,Bevacizumab                      764
FOLFIRI,Bevacizumab-Awwb                557
FOLFOX,Bevacizumab-Bvzr                 462
FOLFOX,Panitumumab                      456
Fluorouracil,Leucovorin,Bevacizumab     423
Capecitabine,Bevacizumab                405
Name: LineName, dtype: int64

In [19]:
fxi_bev_comb = [
    'FOLFIRI,Bevacizumab',
    'FOLFIRI,Bevacizumab-Awwb',
    'FOLFIRI,Bevacizumab-Bvzr',
    'FOLFOX,Bevacizumab',
    'FOLFOX,Bevacizumab-Awwb',
    'FOLFOX,Bevacizumab-Bvzr'
]

fxi_bev = (
    line_therapy_fl
    .query('LineName == @fxi_bev_comb')
    [['PatientID', 'StartDate']]
)

In [20]:
fxi_bev.loc[:,'fxi_cet'] = 0

In [21]:
row_ID(fxi_bev)

(11118, 10785)

In [22]:
fxi_bev['StartDate'] = pd.to_datetime(fxi_bev['StartDate'])

In [23]:
fxi_bev = (
    fxi_bev
    .sort_values(['PatientID', 'StartDate'], ascending = [True, True])
    .drop_duplicates(subset = 'PatientID', keep = 'first')
)

In [24]:
row_ID(fxi_bev)

(10785, 10785)

In [25]:
fire = pd.concat([fxi_cet, fxi_bev])

In [26]:
row_ID(fire)

(11353, 11353)

In [27]:
fire = pd.merge(fire, df_full, on = 'PatientID', how = 'left')

In [28]:
row_ID(fire)

(11353, 11353)

#### KRAS wild type 

In [29]:
biomarkers = pd.read_csv('Enhanced_MetCRCBiomarkers.csv')

In [30]:
biomarkers = biomarkers[biomarkers['PatientID'].isin(fire['PatientID'])]

In [31]:
row_ID(biomarkers)

(46879, 10576)

In [32]:
biomarkers = pd.merge(biomarkers, fire[['PatientID', 'StartDate']], on = 'PatientID', how = 'left')

In [33]:
row_ID(biomarkers)

(46879, 10576)

In [34]:
biomarkers['ResultDate'] = pd.to_datetime(biomarkers['ResultDate'])

In [35]:
biomarkers['SpecimenReceivedDate'] = pd.to_datetime(biomarkers['SpecimenReceivedDate'])

In [36]:
biomarkers.loc[:, 'result_date'] = (
    np.where(biomarkers['ResultDate'].isna(), biomarkers['SpecimenReceivedDate'], biomarkers['ResultDate'])
)

In [37]:
biomarkers.loc[:, 'date_diff'] = (biomarkers['result_date'] - biomarkers['StartDate']).dt.days

In [38]:
kras = (
    biomarkers
    .query('BiomarkerName == "KRAS"')
    .query('date_diff <=30')
    .query('BiomarkerStatus == "Mutation negative"')
    [['PatientID', 'BiomarkerStatus']]
    .rename(columns = {'BiomarkerStatus': 'kras_n'})
    .drop_duplicates(subset = 'PatientID', keep = 'first')
)

In [39]:
row_ID(kras)

(3645, 3645)

In [40]:
fire = pd.merge(fire, kras, on  = 'PatientID', how = 'left')

In [41]:
row_ID(fire)

(11353, 11353)

#### Time from treatment to death or censor 

In [42]:
mortality_tr = pd.read_csv('mortality_cleaned_tr.csv')

In [43]:
mortality_te = pd.read_csv('mortality_cleaned_te.csv')

In [44]:
mortality_tr = mortality_tr[['PatientID', 'death_date', 'last_activity']]

In [45]:
mortality_te = mortality_te[['PatientID', 'death_date', 'last_activity']]

In [46]:
mortality = pd.concat([mortality_tr, mortality_te], ignore_index = True)
print(len(mortality), mortality.PatientID.is_unique)

34315 True


In [47]:
mortality.loc[:, 'last_activity'] = pd.to_datetime(mortality['last_activity'])

In [48]:
mortality.loc[:, 'death_date'] = pd.to_datetime(mortality['death_date'])

In [49]:
len(mortality)

34315

In [50]:
fire = pd.merge(fire, mortality, on = 'PatientID', how = 'left')

In [51]:
len(fire)

11353

In [52]:
conditions = [
    (fire['death_status'] == 1),
    (fire['death_status'] == 0)]

choices = [
    (fire['death_date'] - fire['StartDate']).dt.days,
    (fire['last_activity'] - fire['StartDate']).dt.days]

fire.loc[:, 'timerisk_treatment'] = np.select(conditions, choices)

In [53]:
fire = fire.query('timerisk_treatment >= 0')

In [54]:
row_ID(fire)

(11343, 11343)

#### Patient count 

In [55]:
fire = (
    fire
    .query('kras_n == "Mutation negative"')
)

In [56]:
low_cutoff_fire = fire.risk_score.quantile(1/3)

In [57]:
high_cutoff_fire = fire.risk_score.quantile(2/3)

In [58]:
print('FOLFIRI + Cetuximab total:',  fire.query('fxi_cet == 1').shape[0])
print('High risk:', fire.query('fxi_cet == 1').query('risk_score >= @high_cutoff_fire').shape[0])
print('Med risk:', fire.query('fxi_cet == 1').query('risk_score < @high_cutoff_fire and risk_score > @low_cutoff_fire').shape[0])
print('Low risk:', fire.query('fxi_cet == 1').query('risk_score <= @low_cutoff_fire').shape[0])

FOLFIRI + Cetuximab total: 499
High risk: 139
Med risk: 186
Low risk: 174


In [59]:
print('FOLFIRI + Bevacizumab:',  fire.query('fxi_cet == 0').shape[0])
print('High risk:', fire.query('fxi_cet == 0').query('risk_score >= @high_cutoff_fire').shape[0])
print('Med risk:', fire.query('fxi_cet == 0').query('risk_score < @high_cutoff_fire and risk_score > @low_cutoff_fire').shape[0])
print('Low risk:', fire.query('fxi_cet == 0').query('risk_score <= @low_cutoff_fire').shape[0])

FOLFIRI + Bevacizumab: 3142
High risk: 1075
Med risk: 1027
Low risk: 1040


#### Survival curves with covariate balancing 

In [60]:
fire = fire.set_index('PatientID')

In [61]:
fire['met_cat'] = pd.cut(fire['met_year'],
                         bins = [2010, 2015, float('inf')],
                         labels = ['11-15', '16-22'])

In [62]:
conditions = [
    ((fire['ecog_diagnosis'] == "1.0") | (fire['ecog_diagnosis'] == "0.0")),  
    ((fire['ecog_diagnosis'] == "2.0") | (fire['ecog_diagnosis'] == "3.0"))
]

choices = ['lt_2', 'gte_2']

fire['ecog_2'] = np.select(conditions, choices, default = 'unknown')

In [63]:
fire_iptw = fire.filter(items = ['death_status',
                                 'timerisk_treatment',
                                 'fxi_cet',
                                 'age',
                                 'gender',
                                 'race',
                                 'p_type',
                                 'crc_site',
                                 'met_cat',
                                 'delta_met_diagnosis',
                                 'commercial',
                                 'medicare',
                                 'medicaid',
                                 'ecog_2', 
                                 'ses',
                                 'albumin_diag',
                                 'weight_pct_change',
                                 'risk_score'])

In [64]:
fire_iptw.dtypes

death_status               bool
timerisk_treatment      float64
fxi_cet                   int64
age                       int64
gender                   object
race                     object
p_type                   object
crc_site                 object
met_cat                category
delta_met_diagnosis       int64
commercial              float64
medicare                float64
medicaid                float64
ecog_2                   object
ses                     float64
albumin_diag            float64
weight_pct_change       float64
risk_score              float64
dtype: object

In [65]:
to_be_categorical = list(fire_iptw.select_dtypes(include = ['object']).columns)

In [66]:
to_be_categorical

['gender', 'race', 'p_type', 'crc_site', 'ecog_2']

In [67]:
to_be_categorical.append('met_cat')

In [68]:
to_be_categorical.append('ses')

In [69]:
# Convert variables in list to categorical.
for x in list(to_be_categorical):
    fire_iptw[x] = fire_iptw[x].astype('category')

In [70]:
# List of numeric variables, excluding binary variables. 
numerical_features = ['age', 'delta_met_diagnosis', 'albumin_diag', 'weight_pct_change', 'risk_score']

# Transformer will first calculate column median and impute, and then apply a standard scaler. 
numerical_transformer = Pipeline(steps = [
    ('imputer', SimpleImputer(strategy = 'median')),
    ('std_scaler', StandardScaler())])

In [71]:
# List of categorical features.
categorical_features = list(fire_iptw.select_dtypes(include = ['category']).columns)

# One-hot-encode categorical features.
categorical_transformer = OneHotEncoder(handle_unknown = 'ignore')

In [72]:
preprocessor = ColumnTransformer(
    transformers = [
        ('num', numerical_transformer, numerical_features),
        ('cat', categorical_transformer, categorical_features)],
    remainder = 'passthrough')

In [73]:
fire_iptw_low = (
    fire_iptw
    .query('risk_score <= @low_cutoff_fire'))

fire_iptw_med = (
    fire_iptw
    .query('risk_score < @high_cutoff_fire and risk_score > @low_cutoff_fire'))

fire_iptw_high = (
    fire_iptw
    .query('risk_score >= @high_cutoff_fire'))

fire_iptw_all = fire_iptw

In [74]:
fire_low_x = preprocessor.fit_transform(fire_iptw_low.filter(items = ['age',
                                                                      'gender',
                                                                      'race',
                                                                      'p_type',
                                                                      'crc_site',
                                                                      'met_cat',
                                                                      'delta_met_diagnosis',
                                                                      'commercial',
                                                                      'medicare',
                                                                      'medicaid',
                                                                      'ecog_2', 
                                                                      'ses', 
                                                                      'albumin_diag', 
                                                                      'weight_pct_change',
                                                                      'risk_score']))

fire_med_x = preprocessor.fit_transform(fire_iptw_med.filter(items = ['age',
                                                                      'gender',
                                                                      'race',
                                                                      'p_type',
                                                                      'crc_site',
                                                                      'met_cat',
                                                                      'delta_met_diagnosis',
                                                                      'commercial',
                                                                      'medicare',
                                                                      'medicaid',
                                                                      'ecog_2', 
                                                                      'ses', 
                                                                      'albumin_diag', 
                                                                      'weight_pct_change',
                                                                      'risk_score']))

fire_high_x = preprocessor.fit_transform(fire_iptw_high.filter(items = ['age',
                                                                        'gender',
                                                                        'race',
                                                                        'p_type',
                                                                        'crc_site',
                                                                        'met_cat',
                                                                        'delta_met_diagnosis',
                                                                        'commercial',
                                                                        'medicare',
                                                                        'medicaid',
                                                                        'ecog_2', 
                                                                        'ses', 
                                                                        'albumin_diag', 
                                                                        'weight_pct_change',
                                                                        'risk_score']))

fire_all_x = preprocessor.fit_transform(fire_iptw_all.filter(items = ['age',
                                                                      'gender',
                                                                      'race',
                                                                      'p_type',
                                                                      'crc_site',
                                                                      'met_cat',
                                                                      'delta_met_diagnosis',
                                                                      'commercial',
                                                                      'medicare',
                                                                      'medicaid',
                                                                      'ecog_2', 
                                                                      'ses', 
                                                                      'albumin_diag', 
                                                                      'weight_pct_change',
                                                                      'risk_score']))

In [75]:
lr_fire_low = LogisticRegression(max_iter = 1000)
lr_fire_low.fit(fire_low_x, fire_iptw_low['fxi_cet'])

LogisticRegression(max_iter=1000)

In [76]:
lr_fire_med = LogisticRegression(max_iter = 1000)
lr_fire_med.fit(fire_med_x, fire_iptw_med['fxi_cet'])

LogisticRegression(max_iter=1000)

In [77]:
lr_fire_high = LogisticRegression(max_iter = 1000)
lr_fire_high.fit(fire_high_x, fire_iptw_high['fxi_cet'])

LogisticRegression(max_iter=1000)

In [78]:
lr_fire_all = LogisticRegression(max_iter = 1000)
lr_fire_all.fit(fire_all_x, fire_iptw_all['fxi_cet'])

LogisticRegression(max_iter=1000)

In [79]:
pred_low = lr_fire_low.predict_proba(fire_low_x)
pred_med = lr_fire_med.predict_proba(fire_med_x)
pred_high = lr_fire_high.predict_proba(fire_high_x)
pred_all = lr_fire_all.predict_proba(fire_all_x)

In [80]:
fire_iptw_low['ps'] = pred_low[:, 1]
fire_iptw_med['ps'] = pred_med[:, 1]
fire_iptw_high['ps'] = pred_high[:, 1]
fire_iptw_all['ps'] = pred_all[:, 1]

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  """Entry point for launching an IPython kernel.
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  This is separate from the ipykernel package so we can avoid doing imports until


In [81]:
fire_iptw_low['weight'] = (
    np.where(fire_iptw_low['fxi_cet'] == 1, 1/fire_iptw_low['ps'], 1/(1 - fire_iptw_low['ps'])))

fire_iptw_med['weight'] = (
    np.where(fire_iptw_med['fxi_cet'] == 1, 1/fire_iptw_med['ps'], 1/(1 - fire_iptw_med['ps'])))

fire_iptw_high['weight'] = (
    np.where(fire_iptw_high['fxi_cet'] == 1, 1/fire_iptw_high['ps'], 1/(1 - fire_iptw_high['ps'])))

fire_iptw_all['weight'] = (
    np.where(fire_iptw_all['fxi_cet'] == 1, 1/fire_iptw_all['ps'], 1/(1 - fire_iptw_all['ps'])))

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  """
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  


In [82]:
# Low KM curves
kmf_low_cet_f_iptw = KaplanMeierFitter()
kmf_low_bev_f_iptw = KaplanMeierFitter()

kmf_low_cet_f_iptw.fit(
    fire_iptw_low.query('fxi_cet == 1').timerisk_treatment/30,
    fire_iptw_low.query('fxi_cet == 1').death_status,
    weights = fire_iptw_low.query('fxi_cet == 1')['weight'])

kmf_low_bev_f_iptw.fit(
    fire_iptw_low.query('fxi_cet == 0').timerisk_treatment/30,
    fire_iptw_low.query('fxi_cet == 0').death_status,
    weights = fire_iptw_low.query('fxi_cet == 0')['weight'])

# Med KM curves
kmf_med_cet_f_iptw = KaplanMeierFitter()
kmf_med_bev_f_iptw = KaplanMeierFitter()

kmf_med_cet_f_iptw.fit(
    fire_iptw_med.query('fxi_cet == 1').timerisk_treatment/30,
    fire_iptw_med.query('fxi_cet == 1').death_status,
    weights = fire_iptw_med.query('fxi_cet == 1')['weight'])

kmf_med_bev_f_iptw.fit(
    fire_iptw_med.query('fxi_cet == 0').timerisk_treatment/30,
    fire_iptw_med.query('fxi_cet == 0').death_status,
    weights = fire_iptw_med.query('fxi_cet == 0')['weight'])

# High KM curves 
kmf_high_cet_f_iptw = KaplanMeierFitter()
kmf_high_bev_f_iptw = KaplanMeierFitter()

kmf_high_cet_f_iptw.fit(
    fire_iptw_high.query('fxi_cet == 1').timerisk_treatment/30,
    fire_iptw_high.query('fxi_cet == 1').death_status,
    weights = fire_iptw_high.query('fxi_cet == 1')['weight'])

kmf_high_bev_f_iptw.fit(
    fire_iptw_high.query('fxi_cet == 0').timerisk_treatment/30,
    fire_iptw_high.query('fxi_cet == 0').death_status,
    weights = fire_iptw_high.query('fxi_cet == 0')['weight'])

# All KM curves 
kmf_all_cet_f_iptw = KaplanMeierFitter()
kmf_all_bev_f_iptw = KaplanMeierFitter()

kmf_all_cet_f_iptw.fit(
    fire_iptw_all.query('fxi_cet == 1').timerisk_treatment/30,
    fire_iptw_all.query('fxi_cet == 1').death_status,
    weights = fire_iptw_all.query('fxi_cet == 1')['weight'])

kmf_all_bev_f_iptw.fit(
    fire_iptw_all.query('fxi_cet == 0').timerisk_treatment/30,
    fire_iptw_all.query('fxi_cet == 0').death_status,
    weights = fire_iptw_all.query('fxi_cet == 0')['weight'])

  It's important to know that the naive variance estimates of the coefficients are biased. Instead use Monte Carlo to
  estimate the variances. See paper "Variance estimation when using inverse probability of treatment weighting (IPTW) with survival analysis"
  or "Adjusted Kaplan-Meier estimator and log-rank test with inverse probability of treatment weighting for survival data."
                  


<lifelines.KaplanMeierFitter:"KM_estimate", fitted with 3640.42 total observations, 1380.8 right-censored observations>

#### Calculating survival metrics

In [83]:
cet_fire_median_os = mos(kmf_low_cet_f_iptw,
                         kmf_med_cet_f_iptw,
                         kmf_high_cet_f_iptw,
                         kmf_all_cet_f_iptw)

bev_fire_median_os = mos(kmf_low_bev_f_iptw,
                         kmf_med_bev_f_iptw,
                         kmf_high_bev_f_iptw, 
                         kmf_all_bev_f_iptw)

In [84]:
fire_iptw_all_imputed = fire_iptw_all.copy()
fire_iptw_all_imputed['albumin_diag'] = fire_iptw_all_imputed['albumin_diag'].fillna(fire_iptw_all_imputed['albumin_diag'].median())
fire_iptw_all_imputed['weight_pct_change'] = fire_iptw_all_imputed['weight_pct_change'].fillna(fire_iptw_all_imputed['weight_pct_change'].median())
fire_iptw_all_imputed['ses'] = fire_iptw_all_imputed['ses'].cat.add_categories('unknown')
fire_iptw_all_imputed['ses'] = fire_iptw_all_imputed['ses'].fillna('unknown')

In [85]:
fire_hr_all = CoxPHFitter()
fire_hr_all.fit(fire_iptw_all_imputed,
                duration_col = 'timerisk_treatment', 
                event_col = 'death_status', 
                formula = 'fxi_cet + age + gender + race + p_type + crc_site + met_cat + delta_met_diagnosis + commercial + medicare + medicaid + ecog_2 + ses + albumin_diag + weight_pct_change + risk_score',
                weights_col = 'weight',
                robust = True)

<lifelines.CoxPHFitter: fitted with 7266.24 total observations, 2898.89 right-censored observations>

In [86]:
fire_all_rmst_mos_95 = rmst_mos_95ci(fire_iptw_all,
                                     1000,
                                     'fxi_cet',
                                     'death',
                                     ['age',
                                      'gender',
                                      'race',
                                      'p_type',
                                      'crc_site',
                                      'met_cat',
                                      'delta_met_diagnosis',
                                      'commercial',
                                      'medicare',
                                      'medicaid',
                                      'ecog_2', 
                                      'ses', 
                                      'albumin_diag', 
                                      'weight_pct_change',
                                      'risk_score'],
                                     ['age', 'delta_met_diagnosis', 'albumin_diag', 'weight_pct_change', 'risk_score'],
                                     48)

In [87]:
fire_low_rmst_mos_95 = rmst_mos_95ci(fire_iptw_low,
                                     1000,
                                     'fxi_cet',
                                     'death',
                                     ['age',
                                      'gender',
                                      'race',
                                      'p_type',
                                      'crc_site',
                                      'met_cat',
                                      'delta_met_diagnosis',
                                      'commercial',
                                      'medicare',
                                      'medicaid',
                                      'ecog_2', 
                                      'ses', 
                                      'albumin_diag', 
                                      'weight_pct_change',
                                      'risk_score'],
                                     ['age', 'delta_met_diagnosis', 'albumin_diag', 'weight_pct_change', 'risk_score'],
                                     48)

In [88]:
fire_med_rmst_mos_95 = rmst_mos_95ci(fire_iptw_med,
                                     1000,
                                     'fxi_cet',
                                     'death',
                                     ['age',
                                      'gender',
                                      'race',
                                      'p_type',
                                      'crc_site',
                                      'met_cat',
                                      'delta_met_diagnosis',
                                      'commercial',
                                      'medicare',
                                      'medicaid',
                                      'ecog_2', 
                                      'ses', 
                                      'albumin_diag', 
                                      'weight_pct_change',
                                      'risk_score'],
                                     ['age', 'delta_met_diagnosis', 'albumin_diag', 'weight_pct_change', 'risk_score'],
                                     48)

In [89]:
fire_high_rmst_mos_95 = rmst_mos_95ci(fire_iptw_high,
                                     1000,
                                     'fxi_cet',
                                     'death',
                                     ['age',
                                      'gender',
                                      'race',
                                      'p_type',
                                      'crc_site',
                                      'met_cat',
                                      'delta_met_diagnosis',
                                      'commercial',
                                      'medicare',
                                      'medicaid',
                                      'ecog_2', 
                                      'ses', 
                                      'albumin_diag', 
                                      'weight_pct_change',
                                      'risk_score'],
                                     ['age', 'delta_met_diagnosis', 'albumin_diag', 'weight_pct_change', 'risk_score'],
                                     48)

In [90]:
fire3_data = [
    {'trial_name': 'FIRE-3', 
     'risk_group': 'low', 
     'r_trt_mos': cet_fire_median_os[0],
     'r_trt_mos_95': fire_low_rmst_mos_95.mos_A_95,
     'r_cont_mos': bev_fire_median_os[0],
     'r_cont_mos_95': fire_low_rmst_mos_95.mos_B_95,
     'r_mos_diff': cet_fire_median_os[0] - bev_fire_median_os[0], 
     'rct_trt_arm': 33.1, 
     'rct_cont_arm': 25.6,
     'rct_mos_diff': 33.1-25.6,
     'trt_rmst': restricted_mean_survival_time(kmf_low_cet_f_iptw, 48),
     'trt_rmst_95': fire_low_rmst_mos_95.rmst_A_95,
     'cont_rmst': restricted_mean_survival_time(kmf_low_bev_f_iptw, 48),
     'cont_rmst_95': fire_low_rmst_mos_95.rmst_B_95,
     'diff_rmst': restricted_mean_survival_time(kmf_low_cet_f_iptw, 48) - restricted_mean_survival_time(kmf_low_bev_f_iptw, 48),
     'diff_rmst_95': fire_low_rmst_mos_95.difference_rmst_95,
     'rcount': fire.query('risk_score <= @low_cutoff_fire').shape[0],
     'rcount_chemo': fire.query('risk_score <= @low_cutoff_fire').shape[0]},
    
    {'trial_name': 'FIRE-3', 
     'risk_group': 'medium', 
     'r_trt_mos': cet_fire_median_os[1],
     'r_trt_mos_95': fire_med_rmst_mos_95.mos_A_95,
     'r_cont_mos': bev_fire_median_os[1],
     'r_cont_mos_95': fire_med_rmst_mos_95.mos_B_95,
     'r_mos_diff': cet_fire_median_os[1] - bev_fire_median_os[1], 
     'rct_trt_arm': 33.1, 
     'rct_cont_arm': 25.6,
     'rct_mos_diff': 33.1-25.6,
     'trt_rmst': restricted_mean_survival_time(kmf_med_cet_f_iptw, 48),
     'trt_rmst_95': fire_med_rmst_mos_95.rmst_A_95,
     'cont_rmst': restricted_mean_survival_time(kmf_med_bev_f_iptw, 48),
     'cont_rmst_95': fire_med_rmst_mos_95.rmst_B_95,
     'diff_rmst': restricted_mean_survival_time(kmf_med_cet_f_iptw, 48) - restricted_mean_survival_time(kmf_med_bev_f_iptw, 48),
     'diff_rmst_95': fire_med_rmst_mos_95.difference_rmst_95,
     'rcount': fire.query('risk_score < @high_cutoff_fire and risk_score > @low_cutoff_fire').shape[0],
     'rcount_chemo': fire.query('risk_score < @high_cutoff_fire and risk_score > @low_cutoff_fire').shape[0]},
    
    {'trial_name': 'FIRE-3', 
     'risk_group': 'high', 
     'r_trt_mos': cet_fire_median_os[2],
     'r_trt_mos_95': fire_high_rmst_mos_95.mos_A_95,
     'r_cont_mos': bev_fire_median_os[2],
     'r_cont_mos_95': fire_high_rmst_mos_95.mos_B_95,
     'r_mos_diff': cet_fire_median_os[2] - bev_fire_median_os[2], 
     'rct_trt_arm': 33.1, 
     'rct_cont_arm': 25.6,
     'rct_mos_diff': 33.1-25.6,
     'trt_rmst': restricted_mean_survival_time(kmf_high_cet_f_iptw, 48),
     'trt_rmst_95': fire_high_rmst_mos_95.rmst_A_95,
     'cont_rmst': restricted_mean_survival_time(kmf_high_bev_f_iptw, 48),
     'cont_rmst_95': fire_high_rmst_mos_95.rmst_B_95,
     'diff_rmst': restricted_mean_survival_time(kmf_high_cet_f_iptw, 48) - restricted_mean_survival_time(kmf_high_bev_f_iptw, 48),
     'diff_rmst_95': fire_high_rmst_mos_95.difference_rmst_95,
     'rcount': fire.query('risk_score >= @high_cutoff_fire').shape[0],
     'rcount_chemo': fire.query('risk_score >= @high_cutoff_fire').shape[0]},
    
    {'trial_name': 'FIRE-3', 
     'risk_group': 'all', 
     'r_hr': fire_hr_all.hazard_ratios_['fxi_cet'],
     'r_hr_95': [fire_hr_all.summary.loc['fxi_cet']['exp(coef) lower 95%'], fire_hr_all.summary.loc['fxi_cet']['exp(coef) upper 95%']],
     'r_trt_mos': cet_fire_median_os[3],
     'r_trt_mos_95': fire_all_rmst_mos_95.mos_A_95,
     'r_cont_mos': bev_fire_median_os[3],
     'r_cont_mos_95': fire_all_rmst_mos_95.mos_B_95,
     'r_mos_diff': cet_fire_median_os[3] - bev_fire_median_os[3], 
     'rct_trt_arm': 33.1, 
     'rct_cont_arm': 25.6,
     'rct_mos_diff': 33.1-25.6,
     'rcount': fire.shape[0], 
     'rcount_chemo': fire.shape[0]}
]

In [91]:
rtrials_mos_rmst_boot = pd.DataFrame(fire3_data)

In [92]:
rtrials_mos_rmst_boot

Unnamed: 0,trial_name,risk_group,r_trt_mos,r_trt_mos_95,r_cont_mos,r_cont_mos_95,r_mos_diff,rct_trt_arm,rct_cont_arm,rct_mos_diff,trt_rmst,trt_rmst_95,cont_rmst,cont_rmst_95,diff_rmst,diff_rmst_95,rcount,rcount_chemo,r_hr,r_hr_95
0,FIRE-3,low,52.966667,"[44.6, 79.33333333333333]",45.666667,"[41.833333333333336, 49.7375]",7.3,33.1,25.6,7.5,38.281983,"[36.04663466888957, 40.48323978734367]",36.504141,"[35.48581458249381, 37.41696911469564]",1.777842,"[-0.5838694387336758, 4.208611907412448]",1214,1214,,
1,FIRE-3,medium,31.366667,"[25.766666666666666, 38.666666666666664]",27.433333,"[25.6, 28.940833333333327]",3.933333,33.1,25.6,7.5,29.867315,"[27.29896583015385, 32.73738804047254]",28.195098,"[27.17145420900001, 29.25245036130986]",1.672217,"[-1.1156099852901362, 4.768794970309147]",1213,1213,,
2,FIRE-3,high,12.9,"[9.228333333333333, 16.066666666666666]",12.833333,"[12.0, 13.566666666666666]",0.066667,33.1,25.6,7.5,17.359847,"[14.447207444085366, 20.420586102958595]",15.850093,"[15.056639934480772, 16.571586885290465]",1.509754,"[-1.5206126846172205, 4.746733916826746]",1214,1214,,
3,FIRE-3,all,31.366667,"[27.866666666666667, 35.03333333333333]",24.9,"[24.032500000000002, 26.433333333333334]",6.466667,33.1,25.6,7.5,,,,,,,3641,3641,0.927254,"[0.8059101056843947, 1.0668680479449384]"


In [93]:
rtrials_mos_rmst_boot.to_csv('rtrials_mos_rmst_boot.csv', index = False)