# Survival analysis

This notebook corresponds to step 3 of the WhARIO pipeline (cf. README.md). The notebook format makes it possible to plot the Kaplan-Meier curves and to select features or seeds based on the results of various experiments.

## Imports and useful functions

In [None]:
from functools import partial
import os
import itertools
import pickle
from random import randint
from statistics import NormalDist

from h5py import File
import lifelines
from lifelines import KaplanMeierFitter
from lifelines.fitters.coxph_fitter import CoxPHFitter
from lifelines.statistics import (
    logrank_test, pairwise_logrank_test
)
from lifelines.utils import concordance_index
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.feature_selection import f_regression
from sklearn.model_selection import train_test_split, KFold
from tqdm import tqdm

In [None]:
from multiprocess import Pool

In [None]:
# This import is used to ensure that the various processes
# created by the call to the Pool() class from the
# multiprocess library (a fork of multiprocessing) do not
# freeze. If fit_model_cv is defined inside the notebook, we
# observed that the script would hang randomly during execution.
from workers import fit_model_cv

In [None]:
def confidence_interval(data, confidence=0.95):
    """
    Normal-approximation confidence interval for the mean of `data`.

    Parameters
    ----------
    data: sequence of float
        The observed values. At least two values are required,
        otherwise ``NormalDist.from_samples`` raises
        ``statistics.StatisticsError``.
    confidence: float
        Coverage of the two-sided interval, in (0, 1).

    Returns
    ----------
    (left, right): tuple of float
        Lower and upper bounds of the interval, rounded to 4 decimals.
    """
    dist = NormalDist.from_samples(data)
    # Two-sided critical value (about 1.96 for a 95% interval).
    z = NormalDist().inv_cdf((1 + confidence) / 2.)
    # Standard error of the mean. `dist.stdev` is already the sample
    # standard deviation (n - 1 denominator), so it must be divided by
    # sqrt(n). Dividing by sqrt(n - 1) would double-correct.
    h = z * dist.stdev / (len(data) ** .5)
    left = round(dist.mean - h, 4)
    right = round(dist.mean + h, 4)
    return left, right

In [None]:
def load_dataset(filename):
    '''
    Load the table of clinical data.

    Parameters
    ----------
    filename: str or os.PathLike
        Path to the Excel file holding the clinical table. Relative
        paths are resolved against the current working directory and
        absolute paths are used as-is.

    Returns
    ----------
    df: pandas.DataFrame
        The rows of the table whose overall survival (``os_months``)
        is known and strictly positive and whose tumor PD-L1 score
        (``PD-L1_tumor``) is known, with an extra ``event_observed``
        column equal to ``1 - os_censored``.
    '''
    usecols = [
        'case_id', 'slide_id', 'PD-L1_tumor', 'recist',
        'io_response', 'os', 'os_months', 'os_censored',
        'center'
    ]
    # The path is passed through unchanged: prefixing it with './'
    # would turn an absolute path into a relative one.
    df = pd.read_excel(filename, usecols=usecols)
    # Filter out patients whose survival information is missing
    df = df.dropna(subset=['os_months'])
    df = df[df['os_months'] > 0]
    # Add event observation (opposite of censorship)
    df = df.assign(event_observed=1-df['os_censored'])
    # Filter out patients whose Tumor Proportion Score (TPS)
    # is unknown
    df = df.dropna(subset=['PD-L1_tumor'])
    return df

In [None]:
def get_num_clusters(df_clini, desc_path):
    '''
    Retrieve the total number of clusters.

    The count is the length (number of rows) of the pickled descriptor
    ``<desc_path>/<case_id>.pkl`` of the first case of ``df_clini``
    that has such a file. Cases without a descriptor file are skipped,
    as get_patient_descriptors also tolerates missing files.

    Parameters
    ----------
    df_clini: pandas.DataFrame
        Table with a ``case_id`` column.
    desc_path: str
        Directory containing the descriptor files.

    Returns
    ----------
    int
        The number of clusters.

    Raises
    ------
    FileNotFoundError
        If no case of ``df_clini`` has a descriptor file.
    '''
    for case_id in df_clini['case_id'].drop_duplicates():
        desc_filepath = os.path.join(desc_path, f"{case_id}.pkl")
        if not os.path.exists(desc_filepath):
            continue
        # NOTE: pickle.load can execute arbitrary code; only load
        # descriptor files coming from a trusted source.
        with open(desc_filepath, 'rb') as f:
            desc = pickle.load(f)
        return len(desc)
    raise FileNotFoundError(
        f"No descriptor file found in {desc_path!r} for the given cases."
    )

In [None]:
def get_patient_descriptors(
        df_clini, desc_path, apply_ratio=False):
    """
    Load patient descriptors for each case in df_clini.

    Each case is expected to have a pickled descriptor at
    ``<desc_path>/<case_id>.pkl``. It is flattened row by row into a
    single feature vector whose entries are named ``h{i}-{j}`` (row
    ``i``, column ``j``). Cases without a descriptor file are skipped
    and reported in ``to_remove``.

    Parameters
    ----------
    df_clini: pandas.DataFrame
        Contains the cases with clinical information (``case_id``
        column; duplicated case ids are loaded once).
    desc_path: str
        Path to the directory containing the patient descriptors.
    apply_ratio: bool
        If true, divide each row of a descriptor by that row's sum.
        A small epsilon avoids a division by zero for empty rows.

    Returns
    ----------
    df_cl: pandas.DataFrame
        The case-wise latent representations, indexed by ``case_id``.
    to_remove: list
        The case ids for which no descriptor file was found.

    Raises
    ------
    FileNotFoundError
        If none of the cases has a descriptor file.
    """
    cl_dict = {}
    to_remove = []
    num_clusters = None
    for case_id in df_clini['case_id'].drop_duplicates():
        desc_filepath = os.path.join(
            desc_path, f"{case_id}.pkl"
        )
        if not os.path.exists(desc_filepath):
            to_remove.append(case_id)
            continue
        # NOTE: pickle.load can execute arbitrary code; only load
        # descriptor files coming from a trusted source.
        with open(desc_filepath, 'rb') as f:
            # asarray so that list-based descriptors also support
            # the arithmetic and .ravel() below.
            desc = np.asarray(pickle.load(f))
        if num_clusters is None:
            # Number of clusters = number of descriptor rows, taken
            # from the first case that actually has a descriptor file.
            num_clusters = len(desc)
        if apply_ratio:
            desc = desc / (desc.sum(axis=1).reshape(-1, 1) + 1e-11)
        cl_dict[case_id] = desc.ravel()

    if num_clusters is None:
        raise FileNotFoundError(
            f"No descriptor file found in {desc_path!r} "
            "for the given cases."
        )
    columns = [
        f'h{i}-{j}'
        for i in range(num_clusters)
        for j in range(num_clusters+1)
    ]
    df_cl = pd.DataFrame.from_dict(
        cl_dict, orient='index', columns=columns
    )
    df_cl.index.name = 'case_id'
    return df_cl, to_remove

In [None]:
def join_df_w_cl(df, df_cl, to_remove):
    '''
    Attach the case-wise latent representations to the clinical table.

    Keeps the first row of each case, discards the cases listed in
    ``to_remove``, renumbers the remaining rows from 0 and left-joins
    ``df_cl`` (indexed by case id) on the ``case_id`` column.
    '''
    first_rows = df.drop_duplicates(subset='case_id')
    kept = first_rows.loc[~first_rows['case_id'].isin(to_remove)]
    return kept.reset_index(drop=True).join(df_cl, on='case_id')

In [None]:
def plot_km_curves(df, ax, pval=None, title=None, **km_kwargs):
    '''
    Kaplan-Meier curves plotting: one curve per risk group on ``ax``.

    Parameters
    ----------
    df: pandas.DataFrame
        Must contain the ``risk_group``, ``os_months`` and
        ``event_observed`` columns.
    ax: matplotlib.axes.Axes
        Axes to draw on.
    pval: float, optional
        If given, printed on the plot (e.g. the log-rank p-value).
    title: str, optional
        If given, used as the axes title.
    **km_kwargs:
        Forwarded to ``KaplanMeierFitter.plot`` (e.g. ``show_censors``).
    '''
    colors = {'high': '#ff5647', 'low': '#69adff'}
    for group in pd.unique(df["risk_group"]):
        in_group = df[df['risk_group'] == group]
        fitter = KaplanMeierFitter()
        fitter.fit(in_group['os_months'], in_group['event_observed'])
        # Groups other than 'high'/'low' get the default color cycle
        # instead of raising a KeyError.
        color_kw = {'color': colors[group]} if group in colors else {}
        fitter.plot(ax=ax, label=group, **color_kw, **km_kwargs)
    ax.set_xlabel('Months after treatment start')
    ax.set_ylabel('Probability of survival')

    if pval is not None:
        ax.text(
            0.8, 0.75,
            f'p = {pval:.3e}',
            transform=ax.transAxes
        )

    if title is not None:
        # Use the caller's title (it was previously ignored in favor of
        # a hard-coded string).
        ax.set_title(title)

In [None]:
def log_rank_pvalue(df):
    '''
    Return the log-rank test p-value between the ``high`` risk group
    and the other group of ``df``.

    ``df`` must contain the ``os_months``, ``risk_group`` and
    ``event_observed`` columns.
    '''
    results = pairwise_logrank_test(
        df['os_months'], df['risk_group'], df['event_observed']
    )
    p_values = results.summary['p']
    return p_values.loc['high'].item()

In [None]:
def assign_risk_groups(df, quantile=0.5):
    '''
    Split ``df`` into risk groups using a quantile of its predicted risk.

    Rows whose ``predicted_risk`` is strictly below the ``quantile`` of
    that column are labelled ``'low'``. All the others (including rows
    equal to the threshold) are labelled ``'high'``. The ``risk_group``
    column is written into ``df`` in place and ``df`` is returned.
    '''
    risk = df['predicted_risk']
    threshold = risk.quantile(q=quantile)
    labels = {True: 'low', False: 'high'}
    df['risk_group'] = (risk < threshold).map(labels)
    return df

In [None]:
def inference(
        model, test_set, *covariates,
        optimal_risk=None, quantile=0.5):
    '''
    Computes the C-index on the test-set and assigns risk groups.

    Parameters
    ----------
    model:
        A fitted survival model exposing ``score`` and
        ``predict_log_partial_hazard`` (e.g. lifelines' CoxPHFitter).
    test_set: pandas.DataFrame
        Must contain the covariates, ``os_months`` and
        ``event_observed`` columns.
    *covariates: str
        Names of the covariate columns used by the model.
    optimal_risk: float, optional
        Risk threshold separating the groups. Defaults to the
        ``quantile`` of the predicted risks on ``test_set``.
    quantile: float
        Quantile used when ``optimal_risk`` is not given.

    Returns
    ----------
    test_set: pandas.DataFrame
        A new frame with the added ``predicted_risk`` (log partial
        hazard) and ``risk_group`` ('low' below the threshold, 'high'
        otherwise) columns.
    test_c_index: float
        Concordance index of the model on ``test_set``.

    Raises
    ------
    TypeError
        If ``model`` has no ``predict_log_partial_hazard`` method.
    '''
    # Checked up front: without it there is no risk score to compute.
    if not hasattr(model, 'predict_log_partial_hazard'):
        raise TypeError(
            f"{type(model).__name__} does not provide "
            "predict_log_partial_hazard; cannot compute risks."
        )
    test_set_ = test_set.loc[
        :,
        list(covariates)+['os_months', 'event_observed']
    ]
    test_c_index = model.score(
        test_set_,
        scoring_method='concordance_index'
    )
    test_risks = model.predict_log_partial_hazard(test_set_)
    test_set = test_set.assign(predicted_risk=test_risks)
    if optimal_risk is None:
        optimal_risk = (
            test_set['predicted_risk'].quantile(q=quantile)
        )
    test_set['risk_group'] = test_set['predicted_risk'].map(
        lambda x: 'low' if x < optimal_risk else 'high'
    )
    return test_set, test_c_index

In [None]:
def has_no_features(df, *features):
    '''
    Return the ``case_id`` values of the rows of ``df`` whose given
    ``features`` sum to zero. For non-negative features, these are the
    cases that have none of the features, whose predicted risk is
    presumably reduced to the baseline hazard.
    '''
    per_case = df.loc[:, ['case_id', *features]].set_index('case_id')
    is_empty = (per_case.sum(axis=1) == 0).to_numpy()
    return per_case.index[is_empty]

## Load files

In [None]:
clini = load_dataset(
    './clinical_data.xlsx'
)
desc_path = "./patient_descriptors/"
cluster_distrib_path = './cluster_distrib/'
# load_dataset already drops the patients whose `os_months` is missing
# or not strictly positive, so the table can be used as is
# (`filter_os_months` is not defined anywhere in this notebook).
clini_os = clini

In [None]:
# Load cluster adjacency matrices and
# derive patient descriptors
num_clusters = get_num_clusters(clini_os, desc_path)
# Divide each descriptor row by its sum (see get_patient_descriptors)
apply_ratio = True
df, to_remove = get_patient_descriptors(
    clini_os, desc_path, apply_ratio=apply_ratio
)
# Keep one row per case, drop the cases without descriptor file and
# attach the flattened descriptors as the h{i}-{j} columns.
df = join_df_w_cl(clini_os, df, to_remove)
df.reset_index(drop=True, inplace=True)

In [None]:
# Feature selection process: greedy forward selection. At each round,
# every remaining candidate is appended in turn to the current
# selection and evaluated with models fitted by fit_model_cv on
# several random seeds. The candidate with the lowest average score
# is added to the selection as long as the score keeps improving.
# The score (lower is better) combines the mean C-index c and the
# log-rank p-value p, as explained in the paper:
#     score = 1 / (2 * c) + c / (-log10(p))
nseeds_per_candidate = 30
all_features = [
    f'h{i}-{j}'
    for i in range(num_clusters)
    for j in range(num_clusters+1)
]
feature_set = all_features.copy()
selection = []
best_score = np.inf
# A single worker pool is reused for the whole search instead of
# being re-created for every candidate feature.
with Pool() as pool:
    # Stop once every feature has been selected: pd.concat below
    # would fail on an empty list of candidates.
    while feature_set:
        # All candidates of a round share the same seeds so that their
        # scores are compared on the same experiments.
        seeds = [
            randint(0, 10_000) for _ in range(nseeds_per_candidate)
        ]
        subset_metrics = []
        for feat in feature_set:
            args = [
                (df, selection+[feat], 0.5, s)
                for s in seeds
            ]
            res = pool.starmap(fit_model_cv, args)
            # The third element returned by fit_model_cv is the
            # summary table of the run.
            summaries = pd.concat(
                [r[2] for r in res], ignore_index=True
            )
            summaries = summaries.drop(
                columns=['c_index_CI', 'c_index_std']
            )
            summaries['ci_to_log_p_ratio'] = (
                summaries['c_index_mean']
                / (-np.log10(summaries['log-rank pvalue']))
            )
            summaries['score'] = (
                1 / (2*summaries['c_index_mean'])
                + summaries['ci_to_log_p_ratio']
            )
            mean_stats = summaries.mean(axis=0).to_frame().T
            mean_stats.index = [feat]
            subset_metrics.append(mean_stats)
        subset_metrics = pd.concat(subset_metrics)
        best = subset_metrics.sort_values(by=['score']).iloc[0]
        best_round_score = best['score']
        best_feat = best.name
        # Written as `not <` (rather than `>=`) so that a NaN score
        # also stops the search.
        if not best_round_score < best_score:
            break
        best_score = best_round_score
        # add feature to the selected set
        selection.append(best_feat)
        # remove feature from the exploration space
        feature_set.remove(best_feat)

In [None]:
def repeat_model_fit(df, feature_set, nseeds=30):
    '''
    Fits a Cox PH model on a given feature set multiple times, each
    time with a different random seed, to find a good seed.

    Parameters
    ----------
    df: pandas.DataFrame
        The data frame containing the features
    feature_set: list
        The list of features to be kept in df
    nseeds: int
        The number of repetitions with a different seed

    Returns
    ----------
    summaries: pandas.DataFrame
        One row per experiment, indexed by its seed, summarizing the
        results (C index, p-values, etc...), followed by the 'avg',
        'min' and 'max' rows computed over those experiments.
    '''
    summaries = []
    # Honor `nseeds` (the loop previously always ran 30 times).
    for _ in range(nseeds):
        seed = randint(0, 10_000)
        _, _, summary, _ = fit_model_cv(
            df, feature_set, seed=seed
        )
        summary.index = [seed]
        summaries.append(summary)

    summaries = pd.concat(summaries)
    summaries = summaries.drop(columns=['c_index_CI'])
    # Selection score (lower is better):
    #     1 / (2 * c) + c / (-log10(p))
    summaries['ci_to_log_p_ratio'] = (
        summaries['c_index_mean']
        / (-np.log10(summaries['log-rank pvalue']))
    )
    summaries['score'] = (
        1 / (2*summaries['c_index_mean'])
        + summaries['ci_to_log_p_ratio']
    )
    mean_stats = summaries.mean(axis=0).to_frame().T
    mean_stats.index = ['avg']
    min_stats = summaries.min(axis=0).to_frame().T
    min_stats.index = ['min']
    max_stats = summaries.max(axis=0).to_frame().T
    max_stats.index = ['max']
    summaries = pd.concat(
        (summaries, mean_stats, min_stats, max_stats)
    )

    return summaries

In [None]:
def km_curves_and_summary(df, feature_set, seed=0):
    '''
    Plot the Kaplan-Meier curves for a specific seed, and print the
    metrics' summary of the model.

    Parameters
    ----------
    df: pandas.DataFrame
        The data frame containing the features
    feature_set: list
        The list of features to be kept in df
    seed: int
        The seed passed to fit_model_cv for this experiment.

    Returns
    ----------
    df_risk: pandas.DataFrame
        The per-case table returned by fit_model_cv for this seed,
        including the ``risk_group`` column used to draw the curves.
    '''
    # Plot the Kaplan-Meier curves for a specific seed
    df_risk, results, summary, models = fit_model_cv(
        df, feature_set, seed=seed
    )
    print(results.to_markdown(floatfmt=".3f"))
    print()
    summary.index = [seed]
    summary.index.name = 'seed'
    print(summary.to_markdown(floatfmt=".4f"))
    _, ax = plt.subplots(figsize=(8, 6))
    plot_km_curves(
        df_risk, ax, pval=summary.loc[seed, 'log-rank pvalue'],
        show_censors=True
    )
    # Print p-values for the features used by the models: one labelled
    # column per fitted model, whatever the number of models returned
    # (previously hard-coded to 5).
    cox_pvals = [
        model.summary['p'].rename(f'p (model {i})')
        for i, model in enumerate(models)
    ]
    print(pd.concat(cox_pvals, axis=1).to_markdown(floatfmt=".5f"))

    return df_risk

In [None]:
def compute_hazard_ratio(df_risk):
    '''
    Fit a Cox PH model on the risk group alone (low = 0, high = 1)
    and print its summary, which contains the Hazard Ratio (HR) of the
    high-risk group relative to the low-risk group.

    ``df_risk`` must contain the ``os_months``, ``event_observed`` and
    ``risk_group`` columns. Nothing is returned.
    '''
    columns = ['os_months', 'event_observed', 'risk_group']
    encoded = df_risk[columns].assign(
        risk_group=lambda d: d['risk_group'].map({'low': 0, 'high': 1})
    )
    model = CoxPHFitter()
    model.fit(
        encoded, duration_col='os_months', event_col='event_observed'
    )
    print(model.summary)

In [None]:
# Evaluate the selected features over several random seeds: one row
# per seed, followed by the avg/min/max rows.
summaries = repeat_model_fit(df, selection)
print(summaries.to_markdown(floatfmt=".4f"))

The table above shows the results obtained for each of the `nseeds` random seeds, followed by their average, minimum and maximum. From this table, one can select the best seed to plot the KM curves and print the survival metrics (C-index, p-values of the features); the Hazard Ratio is then computed with `compute_hazard_ratio`.

In [None]:
seed = 1  # enter the chosen seed here
# KM curves and metrics of the greedily selected features for that seed.
km_curves_and_summary(df, selection, seed=seed)

## Keep significant coefficients

At this point, from the results of the experiment above, one can remove the features for which the p-value is above the 0.05 threshold.

In [None]:
# Remove features with p > 0.05: list here the features of `selection`
# whose Cox p-value is below the 0.05 threshold.
# NOTE: this overwrites `feature_set`, which held the exploration space
# of the feature-selection loop.
feature_set = [] # enter the features here

In [None]:
# Seed search for the reduced feature set.
summary = repeat_model_fit(df, feature_set)
print(summary.to_markdown(floatfmt=".4f"))

In [None]:
seed = 1  # enter the chosen seed here
# Keep the per-case risk groups for the Hazard Ratio computation below.
df_risk = km_curves_and_summary(df, feature_set, seed=seed)

In [None]:
# Hazard Ratio of the risk groups obtained with the reduced feature set.
compute_hazard_ratio(df_risk)

## Combine features with PD-L1

### PD-L1 alone

In [None]:
seed = 1  # enter the chosen seed here
# Baseline: model using the tumor PD-L1 score only.
df_risk = km_curves_and_summary(df, ['PD-L1_tumor'], seed=seed)

In [None]:
# Hazard Ratio of the risk groups obtained with PD-L1 alone.
compute_hazard_ratio(df_risk)

### Combination

In [None]:
# Seed search for the combination of the PD-L1 score with the selected
# features (the same feature list as the one plotted in the next cell).
summary = repeat_model_fit(df, ['PD-L1_tumor']+feature_set)
print(summary.to_markdown(floatfmt=".4f"))

In [None]:
seed = 1  # enter the chosen seed here
# Model combining the PD-L1 score with the selected features.
df_risk = km_curves_and_summary(
    df, ['PD-L1_tumor']+feature_set, seed=seed
)

In [None]:
# Hazard Ratio of the risk groups obtained with the combined model.
compute_hazard_ratio(df_risk)