# Cox PH analysis

### Environment setup

In [None]:
pip install lifelines

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors
import seaborn as sns
import math
import scipy.stats as sp
from scipy.cluster.hierarchy import (
    linkage,
    dendrogram,
    fcluster,
    set_link_color_palette,
)
from scipy.spatial.distance import squareform
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
import statsmodels as sm
from statsmodels.stats.outliers_influence import variance_inflation_factor
from ukb_cox_proportional_hazards_utils import (
    compute_is_cancer_at_recruitment,
    compute_survival_time_with_age_for_label,
)
from outlier_methods import detect_outliers
from lifelines import CoxPHFitter
import json

In [None]:
# Base S3 prefix for analysis outputs; the cleaned analysis table is written under it later on.
s3_path = 's3://file_path/analysis/'

In [None]:
# Cleaner variable names: load the JSON map that the plotting cells use to turn
# column names into display labels.

with open('ukb_feature_rename_map.json', 'r') as mapping_file:
    rename_mapping = json.loads(mapping_file.read())

### Prepare data

Exclude ineligible participants (other cancers, or any cancer before recruitment), then compute the event indicator and the age-based survival time for each remaining participant.

In [None]:
# Get cancer info and apply the cohort exclusions

df_merged = pd.read_csv('s3://file_path/file_1.csv', low_memory=False)
df_merged['is_cancer-0'] = df_merged.apply(compute_is_cancer_at_recruitment, axis=1)
df_label = pd.read_csv('s3://file_path/file_2.csv')

# Left-join the label dates onto the cohort once; every exclusion rule below is a
# query against this single frame rather than a repeated merge.
cohort_labels = pd.merge(df_merged[['eid']], df_label, on='eid', how='left')


def _eids_matching(expr):
    """Return the unique participant ids (`eid`) whose label dates satisfy `expr`."""
    return cohort_labels.query(expr, engine='python').eid.unique()


# NOTE(review): the date columns are compared exactly as read_csv parsed them;
# confirm they are ISO-formatted (or parse them) so that `>` orders chronologically.

# 1) No CRC label date but another cancer recorded
controls_with_othercancer = _eids_matching(
    'label_first_occurred_date.isna() & othercancer_first_occurred_date.notna()'
)
print(
    f'Number of participants who developed cancer other than CRC: {len(controls_with_othercancer)}'
)
survival_df = df_merged.loc[~df_merged.eid.isin(controls_with_othercancer), :]

# 2) Other cancer recorded before the CRC label date
othercancer_pre_crc = _eids_matching(
    '(label_first_occurred_date.notna()) & (othercancer_first_occurred_date.notna()) & (label_first_occurred_date>othercancer_first_occurred_date)'
)
print(
    f'Number of participants who developed other cancer prior to CRC: {len(othercancer_pre_crc)}'
)
survival_df = survival_df.loc[~survival_df.eid.isin(othercancer_pre_crc), :]

# 3) Other cancer recorded on the same date as the CRC label date
othercancer_with_crc = _eids_matching(
    '(label_first_occurred_date.notna()) & (othercancer_first_occurred_date.notna()) & (label_first_occurred_date==othercancer_first_occurred_date)'
)
print(
    f'Number of participants who developed other cancer same time as CRC: {len(othercancer_with_crc)}'
)
survival_df = survival_df.loc[~survival_df.eid.isin(othercancer_with_crc), :]

# 4) Flagged as having cancer at recruitment (counted among those still in the cohort)
cancer_prevalent = survival_df[(survival_df['is_cancer-0'] == True)].eid
print(f'Number of any cancer occurred before recruitment: {len(cancer_prevalent)}')

# Take an explicit copy: later cells add and overwrite columns on survival_df, which
# must not be an ambiguous slice of df_merged.
survival_df = survival_df.loc[~survival_df.eid.isin(cancer_prevalent), :].copy()
print(f'Number of participants left: {len(survival_df.eid)}')
survival_df.label_class.value_counts()

In [None]:
# Calculate survival as a factor of age of diagnosis

# Parse the date columns; values that cannot be parsed become NaT (errors='coerce').
date_cols = ['date_lfu', 'date_death', 'label_first_occurred_date', 'visit_date-0']
survival_df[date_cols] = survival_df[date_cols].apply(pd.to_datetime, errors='coerce')

# Censoring date passed to the row-wise helper, which yields three values per row
# that are stored as event_, age_ and obs_end_date.
censoring_date = pd.to_datetime('29-02-2020', format='%d-%m-%Y')
survival_df[['event_', 'age_', 'obs_end_date']] = survival_df.apply(
    compute_survival_time_with_age_for_label,
    axis=1,
    result_type='expand',
    censoring_date=censoring_date,
)
print(survival_df.shape)

# Drop columns suffixed -1/-2/-3, then strip everything from the first '-' onwards
# in the remaining column names (e.g. 'visit_date-0' -> 'visit_date').
nonbaseline_cols = [
    col for col in survival_df.columns if col.endswith(('-1', '-2', '-3'))
]
survival_df = survival_df.drop(columns=nonbaseline_cols)
survival_df = survival_df.rename(columns=lambda name: name.split('-')[0])
print(survival_df.shape)

In [None]:
# Recode categorical variables
#
# Missing values in categorical covariates become an explicit 'unk' level, and every
# categorical column gets a fixed category order. Values that are not listed in a
# column's categories become NaN when the column is converted by pd.Categorical.

survival_df['fasted'] = survival_df['fasted'].astype(float)
survival_df['ethnicity'] = survival_df['ethnicity'].apply(
    lambda x: 'unk' if pd.isnull(x) else ('white' if x == 1 else 'nonwhite')
)
# Quintiles of met_mins labelled 1-5, with missing values assigned to 'unk'
survival_df['met_mins'] = (
    pd.qcut(survival_df.loc[:, 'met_mins'], q=5, labels=range(1, 6))
    .values.add_categories('unk')
    .fillna('unk')
)

unk_fill_cols = [
    'redmeat_intake',
    'oily_fish_intake',
    'famhist_cancer',
    'edu_university',
    'regular_aspirin',
    'crc_screening',
    'health_rating',
    'alcohol',
    'smoke',
    'diseasehist_ibd',
    'diseasehist_diabetes',
    'diseasehist_cardiovascular',
    'diseasehist_anyliverbiliary',
]
survival_df = survival_df.replace({col: {np.nan: 'unk'} for col in unk_fill_cols})

# Code 4 of `smoke` is folded into 'unk'. Assign the result back: an in-place replace
# on the selected column is a chained operation that may not reach the frame.
survival_df['smoke'] = survival_df['smoke'].replace(4, 'unk')

binary_unk = [False, True, 'unk']
category_levels = {
    'ethnicity': ['white', 'nonwhite', 'unk'],
    'redmeat_intake': [0, 1, 2, 3, 4, 5, 'unk'],
    'oily_fish_intake': [0, 1, 2, 3, 4, 5, 'unk'],
    'famhist_cancer': binary_unk,
    'diseasehist_ibd': binary_unk,
    'diseasehist_cardiovascular': binary_unk,
    'diseasehist_diabetes': binary_unk,
    'diseasehist_anyliverbiliary': binary_unk,
    'edu_university': binary_unk,
    'regular_aspirin': binary_unk,
    'crc_screening': binary_unk,
    'health_rating': [4, 3, 2, 1, 'unk'],
    'alcohol': [0, 1, 2, 3, 4, 5, 6, 'unk'],
    'smoke': [0, 1, 2, 3, 'unk'],
    'met_mins': [1, 2, 3, 4, 5, 'unk'],
    'regular_statin': [False, True],
    'sex': [0, 1],
}
for col, levels in category_levels.items():
    survival_df[col] = pd.Categorical(survival_df[col], categories=list(levels))

In [None]:
# Select columns
# Removed 'crc_screening', since highly correlated with label

# Candidate covariates; the outcome columns event_ and age_ are prepended below.
selected_cols = [
    'age','sex','ethnicity','townsend','alcohol','smoke','fasted','redmeat_intake','oily_fish_intake',
    'famhist_cancer','edu_university','regular_aspirin','regular_statin','health_rating','diseasehist_ibd',
    'diseasehist_cardiovascular', 'diseasehist_diabetes','diseasehist_anyliverbiliary','met_mins','hgrip',
    'tlr','whr','bmi','height','met_rate','impedance','sleep_dur','sbp','dbp','pulse','hgb','hct','wbc',
    'rbc','plt','lym','mcv','mono','neut','eos','baso','n_rbc','reti','u_sodium','u_potas','u_cr','apoa',
    'apob','chol','hdl','ldl','tgly','urea','crp','tprotein','glu','phos','alb','alp','alt','ast','ggt',
    'urate','d_bil','t_bil','shbg','igf1','vitd','cysc','calc','hba1c','tst',
]

# Keep the frame's own column order; a column qualifies when the part of its name
# before the first '-' is one of the candidates.
feature_cols = [
    name for name in survival_df.columns if name.split('-')[0] in selected_cols
]
df = survival_df.loc[:, ['event_', 'age_'] + feature_cols]

In [None]:
# Remove NaNs: report missingness, drop every row with at least one missing value,
# then report again.

print(df.shape)
print(f'Number of rows with missing values: {df.isna().any(axis=1).sum()}')
# The 20 columns with the most missing values. Printed explicitly because a bare
# expression is only rendered when it is the last statement of a cell.
print(df.isna().sum(axis=0).sort_values(ascending=False).head(20))
df = df.dropna()
print(df.shape)
print(f'Number of rows with missing values: {df.isna().any(axis=1).sum()}')

# Event counts after the drop, keyed by event_ value
df.event_.value_counts().to_dict()

In [None]:
# Remove outliers based on percentiles

continuous_vars = [
    'hgrip','tlr','whr','height','met_rate','impedance','sleep_dur','sbp','dbp','pulse','bmi','hgb','hct',
    'wbc','rbc','plt','lym','mcv','mono','neut','eos','baso','n_rbc','reti','u_sodium','u_potas','u_cr',
    'apoa','apob','chol','hdl','ldl','tgly','urea','crp','tprotein','glu','phos','alb','alp','alt','ast',
    'ggt','urate','d_bil','t_bil','shbg','igf1','vitd','cysc','calc','hba1c','tst',
]

# Collect whatever detect_outliers flags for each continuous column; the collected
# values are used as row index labels in the drop below.
flagged = []
for col in continuous_vars:
    flagged.extend(
        detect_outliers(df, col, method='percentile', percentile_threshold=0.001)
    )

# A row flagged in several columns is only counted (and dropped) once.
outliers = np.unique(flagged)
print(f'Number of outliers: {len(outliers)}')

df = df.drop(index=outliers)
print(df.shape)

In [None]:
# Event counts in the current analysis frame, as a dict keyed by event_ value.
df.event_.value_counts().to_dict()

In [None]:
# Write the current analysis frame under the S3 prefix; index=False leaves the row index out.
df.to_csv(s3_path + 'file3.csv', index=False)

### Forward feature selection

Fit each variable separately (univariate Cox model) on the training dataset, and keep the variables with a p-value < 0.10 (a deliberately liberal threshold).

In [None]:
# 80/20 split stratified on the event indicator, with a fixed seed for reproducibility.
X_train, X_test = train_test_split(
    df, stratify=df['event_'], test_size=0.2, random_state=1
)

cph = CoxPHFitter()
# Candidate covariates: every column except the two outcome columns.
cols = list(df.drop(columns=['event_', 'age_']).columns)

# Accumulators filled by the screening loop:
# one entry per fitted model ...
mdl_name, c_idx, aic, p_val = [], [], [], []
# ... and one entry per variable that passes the p-value threshold.
var_keep, var_hr, var_se, var_pval = [], [], [], []

In [None]:
# Univariate screening: fit one Cox model per candidate covariate and keep those
# whose smallest per-term p-value is below 0.10.

# Reset the accumulators so that re-running this cell does not append duplicates.
mdl_name, c_idx, aic, p_val = [], [], [], []
var_keep, var_hr, var_se, var_pval = [], [], [], []

for c in cols:
    cph.fit(
        X_train, duration_col='age_', event_col='event_', formula=c, show_progress=False
    )
    fit_summary = cph.summary  # fetch the summary table once per fit
    min_p = fit_summary['p'].min()  # smallest p-value across the terms of this variable

    mdl_name.append(c)
    c_idx.append(round(cph.concordance_index_, 4))
    aic.append(round(cph.AIC_partial_, 2))
    p_val.append(round(min_p, 3))  # rounded for reporting only
    print('Model:', c, 'C-index:', c_idx[-1], 'AIC:', aic[-1], 'p:', p_val[-1])

    # Threshold on the unrounded p-value: rounding first would turn e.g. 0.0996 into
    # 0.1 and wrongly reject the variable.
    if min_p < 0.1:
        var_keep.append(c)
        # NOTE(review): coefficient and SE come from the FIRST term, whereas the
        # p-value is the minimum over all terms; for multi-level categoricals these
        # can refer to different levels.
        var_hr.append(fit_summary['coef'].iloc[0])
        var_se.append(1.96 * fit_summary['se(coef)'].iloc[0])  # 95% CI half-width
        var_pval.append(p_val[-1])

In [None]:
# Get univariate parameters for the plots

# Sort every list with ONE shared permutation (ascending log(HR)). Sorting each list
# through its own sorted(zip(...)) call can break ties differently per list and
# misalign the entries.
order = np.argsort(var_hr, kind='stable')
var_keep = [var_keep[k] for k in order]
var_se = [var_se[k] for k in order]
var_pval = [var_pval[k] for k in order]
var_hr = [var_hr[k] for k in order]

# Strip any '[...]' level suffix so that the names can be looked up in rename_mapping.
var_keep = [name.split('[', 1)[0] for name in var_keep]
var_names = [rename_mapping[name] for name in var_keep]

In [None]:
# Plot log(HR) of the selected features in univariate model

# Blue -> grey -> red gradient; one colour per variable, following the row order.
theme = matplotlib.colors.LinearSegmentedColormap.from_list(
    '', ['blue', 'gainsboro', 'red']
)
n_vars = len(var_keep)
row_colors = [list(theme(1.0 * k / n_vars)[:3]) for k in range(n_vars)]

fig, ax = plt.subplots(figsize=(2.5, 10))
# Horizontal bar per variable: point estimate +/- the stored CI half-width.
for row in range(len(var_hr)):
    lower = var_hr[row] - var_se[row]
    upper = var_hr[row] + var_se[row]
    ax.plot([lower, upper], [row, row], color=row_colors[row])
ax.set_yticks(np.arange(len(var_hr)))
ax.set_yticklabels(var_keep)
ax.set_xlabel('log(HR) 95% CI')
ax.axvline(x=0, color='silver', linestyle='--')  # reference line at log(HR) = 0
ax.scatter(var_hr, range(len(var_hr)), s=60, c=row_colors)
ax.set_yticklabels(var_names)  # final tick labels: display names
fig.savefig('./figures/paper_all_hazard_ratios.jpg', dpi=400, bbox_inches='tight')
plt.show()

In [None]:
# Restrict both splits to the screened variables plus the outcome columns.
# The outcome columns are filtered out before being appended, so re-running this
# cell cannot produce duplicated columns.
outcome_cols = ['event_', 'age_']
var_keep = [v for v in var_keep if v not in outcome_cols] + outcome_cols
X_train = X_train[var_keep].copy(deep=True)
X_test = X_test[var_keep].copy(deep=True)

### VIF - Remove correlated features

Find and remove variables that are highly correlated with each other (variance inflation factor > 10), to reduce multicollinearity and clarify the contribution of each predictor to the model.

In [None]:
# Numeric copy of the training split for the correlation / VIF analysis:
# 'unk' becomes missing, string booleans become 0/1, every column is cast to float,
# and rows with any missing value are dropped.
df = X_train.copy(deep=True)
df = df.replace('unk', np.nan)  # np.nan: the np.NaN alias was removed in NumPy 2.0
df = df.replace(['False', 'True'], [0, 1])
df = df.astype(float)
df = df.dropna()

RBC correlated with HCT, TBIL with DBIL, HDL with APOA, sex with TST, MET_RATE, height and impedance.

In [None]:
# Calculate VIF

# Hand-curated covariate list; it also becomes the list of variables kept downstream.
cols = [
    'calc','cysc','redmeat_intake','hdl','regular_aspirin','urea','age','oily_fish_intake','chol',
    'lym','shbg','vitd','u_cr','urate','u_potas','u_sodium','ggt','alt','dbp','pulse','crp','t_bil',
    'igf1','bmi','hgrip','mono','wbc','met_mins','baso','hgb','tgly','smoke','health_rating','reti',
    'famhist_cancer','alcohol','rbc','fasted','sex','tlr','whr',
]
var_keep = cols

# Design matrix with a constant column appended as the last column.
X = df[cols].copy(deep=True)
X['intercept'] = 1
design = X.values
vif = [variance_inflation_factor(design, j) for j in range(X.shape[1])]

# vif[:-1] leaves out the entry of the appended constant column.
vif_df = pd.DataFrame({'variable': df[cols].columns, 'vif': vif[:-1]})
vif_df.sort_values(by='vif', ascending=False).head(15)

In [None]:
# Plot dendrogram: Ward linkage on the dissimilarity 1 - |r| between covariates.

cols = df.drop(['age_', 'event_'], axis=1).columns
col_names = [rename_mapping[c] for c in cols]

set_link_color_palette(
    [
        'darkmagenta',
        'navy',
        'royalblue',
        'lightseagreen',
        'limegreen',
        'gold',
        'darkorange',
        'orangered',
        'crimson',
    ]
)
corrs = df[cols].corr()
plt.figure(figsize=(2, 10))
plt.xlabel('Ward distance', fontsize=11)
dissimilarity = 1 - abs(corrs)
Z = linkage(squareform(dissimilarity), 'ward')

# Options shared by both dendrogram calls so that the leaf order is identical.
dendrogram_opts = dict(
    orientation='left',
    color_threshold=1.1,
    count_sort='ascending',
    above_threshold_color='silver',
)

# R1 is only needed for its leaf order in raw column names (R1['ivl']), so it is
# computed with no_plot=True instead of being drawn underneath the labelled plot.
R1 = dendrogram(Z, labels=cols, no_plot=True, **dendrogram_opts)

# The dendrogram that is actually drawn, labelled with display names.
R = dendrogram(Z, labels=col_names, leaf_font_size=10, **dendrogram_opts)

plt.savefig('./figures/paper_dendrogram.jpg', dpi=400, bbox_inches='tight')
plt.show()

In [None]:
# Plot r-map in the leaf order of the dendrogram

# Reversed COPY of the leaf order: reversing R1['ivl'] in place would flip the order
# again each time this cell is re-run.
new_order = R1['ivl'][::-1]
col_names = [rename_mapping[c] for c in new_order]

corrs = df[new_order].corr()

# Hide the upper triangle (including the diagonal) so each pair is shown once.
mask = np.triu(np.ones_like(corrs, dtype=bool))
plt.figure(figsize=(14, 14))
sns.heatmap(
    corrs,
    annot=False,
    annot_kws={'size': 9},
    fmt='.2f',
    mask=mask,
    xticklabels=col_names,
    yticklabels=col_names,
    square=True,
    cbar_kws={'shrink': 0.5},
    cmap='bwr',
    vmin=-0.6,
    vmax=0.6,
).set(title='Intercorrelations - rmap')

plt.savefig('./figures/paper_rmap.jpg', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Restrict both splits to the retained variables plus the outcome columns.
# The outcome columns are filtered out before being appended, so re-running this
# cell cannot produce duplicated columns.
outcome_cols = ['event_', 'age_']
var_keep = [v for v in var_keep if v not in outcome_cols] + outcome_cols
X_train = X_train[var_keep].copy(deep=True)
X_test = X_test[var_keep].copy(deep=True)

### Backward elimination

Start by fitting a model with all of the selected features, then repeatedly remove the least significant feature (the one with the largest p-value) and refit, until every remaining feature contributes significantly to the model (p < 0.05).

In [None]:
# Deep copy of the training split; `df` is the frame the following model fits use.
df = X_train.copy(deep=True)

In [None]:
# Backward elimination
#
# Fit all candidate covariates, then repeatedly drop the variable whose best
# (smallest) per-term p-value is the largest and refit, until every remaining
# variable has at least one term with p <= P_REMOVE.

P_REMOVE = 0.0499  # a variable is removed while its smallest p-value exceeds this


def _min_p_per_variable(p_values):
    """Collapse per-term p-values into the smallest p-value per variable.

    Terms of a categorical covariate are named ``name[level]``; the text before the
    first '[' is taken as the variable name. Names are matched exactly, so a variable
    whose name is a prefix of another one is not mixed up with it.

    Returns a dict {variable name: min p}, in order of first appearance.
    """
    min_p = {}
    for term, p in p_values.items():
        name = str(term).split('[', 1)[0]
        if name not in min_p or p < min_p[name]:
            min_p[name] = p
    return min_p


def _fit_and_report(formula):
    """Fit `cph` on `df` with `formula`, print C-index and AIC, return min p per variable."""
    cph.fit(
        df,
        duration_col='age_',
        event_col='event_',
        formula=formula,
        show_progress=False,
    )
    print('Model:', formula)
    print(
        '-- C-index:', round(cph.concordance_index_, 6), 'AIC:', round(cph.AIC_partial_, 2)
    )
    return _min_p_per_variable(cph.summary['p'])


# Initiate with every covariate column of the training frame
cols = df.drop(['event_', 'age_'], axis=1).columns
mdl_cols = cols.to_list()
mdl_formula = ' + '.join(mdl_cols)

cph = CoxPHFitter()
min_p = _fit_and_report(mdl_formula)

while True:
    # Least significant variable; on ties the first one in model order is taken.
    worst = max(min_p, key=min_p.get)
    if not min_p[worst] > P_REMOVE:
        break  # every remaining variable is significant
    if len(mdl_cols) == 1:
        # Removing the last variable would leave an empty formula to fit.
        print('Stopping: only', worst, 'is left and it does not meet the threshold')
        break

    mdl_cols.remove(worst)
    print('Removing', worst)

    mdl_formula = ' + '.join(mdl_cols)
    min_p = _fit_and_report(mdl_formula)

In [None]:
# Print the summary of the model currently held in `cph`, to 4 decimals.
cph.print_summary(decimals=4)

In [None]:
# Calculate unadjusted HRs for comparison against the multivariate model
#
# One univariate Cox fit per retained covariate; one output row per model term, so a
# categorical covariate contributes one row per level term.

records = []
for c in mdl_cols:
    cph = CoxPHFitter()
    cph.fit(X_train, duration_col='age_', event_col='event_', formula=c, show_progress=False)
    fit_summary = cph.summary  # fetch the summary table once per fit
    c_index = round(cph.concordance_index_, 3)
    aic_partial = round(cph.AIC_partial_, 2)
    # Iterate rows by label rather than indexing the label-indexed columns by position.
    for term, row in fit_summary.iterrows():
        records.append(
            {
                'Covariate': term,
                'HR': round(row['exp(coef)'], 2),
                'CI_lower': round(row['exp(coef) lower 95%'], 2),
                'CI_upper': round(row['exp(coef) upper 95%'], 2),
                'AIC': aic_partial,
                'C_index': c_index,
                'p': round(row['p'], 5),
            }
        )

univariate_df = pd.DataFrame(
    records,
    columns=['Covariate', 'HR', 'CI_lower', 'CI_upper', 'AIC', 'C_index', 'p'],
)
univariate_df

In [None]:
# Test performance on the test set
#
# The model is fitted on the TRAINING split and then scored on the held-out test
# split. Refitting on X_test would report the in-sample fit of a new model rather
# than out-of-sample performance.

cph = CoxPHFitter()
cph.fit(
    X_train,
    duration_col='age_',
    event_col='event_',
    formula=' + '.join(mdl_cols),
    show_progress=False,
)
test_c_index = cph.score(X_test, scoring_method='concordance_index')
print(
    '-- Train C-index:', round(cph.concordance_index_, 6), 'AIC:', round(cph.AIC_partial_, 2)
)
print('-- Test C-index:', round(test_c_index, 6))

In [None]:
# Restrict both splits to the final model covariates plus the outcome columns.
# The outcome columns are filtered out before being appended, so re-running this
# cell cannot produce duplicated columns.
outcome_cols = ['event_', 'age_']
mdl_cols = [c for c in mdl_cols if c not in outcome_cols] + outcome_cols
X_train = X_train[mdl_cols].copy(deep=True)
X_test = X_test[mdl_cols].copy(deep=True)

### Plot results

In [None]:
# Refit the Cox model on a copy of the training split, using every column except the
# two outcome columns as covariates.
df = X_train.copy(deep=True)

cph = CoxPHFitter()
cols = df.drop(['event_', 'age_'], axis=1).columns
mdl_cols = cols.to_list()

cph.fit(
    df,
    duration_col='age_',
    event_col='event_',
    formula=' + '.join(mdl_cols),
    show_progress=False,
)  # entry_col='age',
# Per-term p-values as a dict, split into parallel containers of term names and values.
summary = cph.summary['p'].to_dict()
mdl_vars = list(summary.keys())
mdl_pvals = np.array(list(summary.values()))

In [None]:
# Print the summary of the model currently held in `cph`, to 3 decimals.
cph.print_summary(decimals=3)

In [None]:
# Default lifelines coefficient plot for the fitted model (log(HR) scale)

plt.figure(figsize=(5, 5))
ax = cph.plot()  # pass hazard_ratios=True to plot hazard ratios instead
plt.show()

In [None]:
# HR plot: same lifelines plot, requested with hazard_ratios=True

plt.figure(figsize=(5, 5))
ax = cph.plot(hazard_ratios=True)
plt.show()

In [None]:
# Get stats for the plots
#
# Keep the terms of the fitted model with p < 0.05, excluding 'unk' levels, sorted by
# ascending log(HR). Rows are filtered and sorted as one table, so the coefficient,
# CI half-width, p-value and name of a term cannot get out of step.

fit_summary = cph.summary
is_selected = [
    p < 0.05 and 'T.unk' not in term for term, p in fit_summary['p'].items()
]  # threshold by p
selected = fit_summary.loc[is_selected].sort_values('coef', kind='stable')

hr = selected['coef'].tolist()  # log hazard ratios
se = (1.96 * selected['se(coef)']).tolist()  # 95% CI half-widths
pval = selected['p'].tolist()

# Strip the '[...]' level suffix so that the names can be looked up in rename_mapping.
mdl_vars = [term.split('[', 1)[0] for term in selected.index]
var_names = [rename_mapping[v] for v in mdl_vars]

In [None]:
# Hand-written display labels, replacing any previously computed `var_names`.
# NOTE(review): these are matched to the plotted terms purely by position, so the
# list must have one entry per plotted term, in plotting order; re-check it whenever
# the fitted model or the significance filter changes.
var_names = [
    'Age',
    'Urea',
    'Cholesterol',
    'ALT',
    'SHBG',
    'Pulse',
    'Triglycerides',
    'Basophil %',
    'Family history of cancer',
    'Sex',
    'Alcohol intake 1-3 u/pm',
    'Alcohol intake 3-4 u/pw',
    'Alcohol intake 5-7 u/pw',
    'Waist-to-hip ratio',
]

In [None]:
# Plot using a risk gradient

# Blue -> grey -> red gradient; one colour per plotted term, following the row order.
theme = matplotlib.colors.LinearSegmentedColormap.from_list(
    '', ['blue', 'gainsboro', 'red']
)
n_terms = len(mdl_vars)
row_colors = [list(theme(1.0 * k / n_terms)[:3]) for k in range(n_terms)]

fig, ax = plt.subplots(figsize=(5, 5))
# Horizontal bar per term: point estimate +/- the stored CI half-width.
for row in range(len(hr)):
    lower = hr[row] - se[row]
    upper = hr[row] + se[row]
    ax.plot([lower, upper], [row, row], color=row_colors[row])
ax.set_yticks(np.arange(len(hr)))
ax.set_yticklabels(mdl_vars)
ax.set_xlabel('log(HR) 95% CI', fontsize=13)
ax.axvline(x=0, color='silver', linestyle='--')  # reference line at log(HR) = 0
ax.scatter(hr, range(len(hr)), s=60, c=row_colors)
ax.set_yticklabels(var_names, fontsize=13)  # final tick labels: display names
plt.xticks(fontsize=13)
fig.savefig('./figures/paper_hrplot.jpg', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Changing categories in string format to integers (e.g. '2' to 2)
# 'unk' is mapped to the integer code 9. The replacement modifies `df` in place.
# NOTE(review): only string values are listed here; confirm whether the category
# values in `df` are strings at this point or already booleans/integers.

df.replace(
    ['False', 'True', '0', '1', '2', '3', '4', '5', '6', 'unk'],
    [0, 1, 0, 1, 2, 3, 4, 5, 6, 9],
    inplace=True,
)
# Re-declare the two categorical covariates with integer categories (9 = unknown).
df['famhist_cancer'] = pd.Categorical(df['famhist_cancer'], categories=[0, 1, 9])
df['alcohol'] = pd.Categorical(df['alcohol'], categories=[0, 1, 2, 3, 4, 5, 6, 9])

In [None]:
# Partial covariate survival plots

theme = matplotlib.colors.LinearSegmentedColormap.from_list(
    '', ['blue', 'gainsboro', 'red']
)

# Refit on the recoded frame; no formula is given for this fit.
cph = CoxPHFitter()
cph.fit(df, event_col='event_', duration_col='age_', show_progress=False)

# One curve per waist-to-hip ratio value, coloured along the gradient.
whr_values = [0.6, 0.7, 0.8, 0.9, 1, 1.1, 1.2]
ax = cph.plot_partial_effects_on_outcome(
    covariates=['whr'],
    values=whr_values,
    cmap=theme,
    plot_baseline=False,
    figsize=(2.5, 5),
)
ax.set_xlim([60, 85])
ax.set_ylim([0.8, 1])
ax.set_ylabel('% Healthy', fontsize=13)
ax.set_xlabel('Age of diagnosis', fontsize=12)
ax.set_title('Waist-to-hip ratio')
ax.get_legend().remove()
plt.xticks(fontsize=11)
plt.yticks([0.80, 0.85, 0.90, 0.95, 1.00], fontsize=11)
plt.savefig('./figures/paper_survival_waist_to_hip.jpg', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Partial effect of alcohol intake on the survival curve, one curve per category code.
# No plt.figure() call beforehand: the lifelines call opens its own figure (sized by
# `figsize`), so a figure created here would only be rendered as an empty extra.
ax = cph.plot_partial_effects_on_outcome(
    covariates=['alcohol'],
    values=[0, 1, 2, 3, 4, 5, 6],
    cmap=theme,
    plot_baseline=False,
    figsize=(2.5, 5),
)
plt.xlim([60, 85])
plt.ylim([0.8, 1])
plt.ylabel('% Healthy', fontsize=13)
plt.xlabel('Age of diagnosis', fontsize=12)
plt.title('Alcohol intake')
ax.get_legend().remove()
plt.xticks(fontsize=11)
plt.yticks([0.80, 0.85, 0.90, 0.95, 1.00], fontsize=11)
plt.savefig('./figures/paper_survival_alcohol.jpg', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Partial effect of sex on the survival curve, one curve per category code.
# No plt.figure() call beforehand: the lifelines call opens its own figure (sized by
# `figsize`), so a figure created here would only be rendered as an empty extra.
ax = cph.plot_partial_effects_on_outcome(
    covariates=['sex'], values=[0, 1], cmap=theme, plot_baseline=False, figsize=(2.5, 5)
)
plt.xlim([60, 85])
plt.ylim([0.8, 1])
plt.ylabel('% Healthy', fontsize=13)
plt.xlabel('Age of diagnosis', fontsize=12)
plt.title('Sex')
ax.get_legend().remove()
plt.xticks(fontsize=11)
plt.yticks([0.80, 0.85, 0.90, 0.95, 1.00], fontsize=11)
plt.savefig('./figures/paper_survival_sex.jpg', dpi=300, bbox_inches='tight')
plt.show()