In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

from sklearn import set_config
from sksurv.datasets import load_veterans_lung_cancer
from sksurv.nonparametric import kaplan_meier_estimator


from sksurv.linear_model import CoxPHSurvivalAnalysis
from sksurv.preprocessing import OneHotEncoder


from time_to_event.utils import dist_histogram, countplot



# Plot styling: apply the ggplot style sheet first, then pin the default figure size,
# so the size is never overridden by a style sheet that also defines figure.figsize.
plt.style.use("ggplot")
plt.rcParams.update({"figure.figsize": (12, 8)})

## About the data

We are working with the well-known Veterans' Administration Lung Cancer trial data, which has been publicly available for research for decades. The dataset comes from a clinical trial that followed lung cancer patients after treatment. For each patient it records how long they survived and whether their death was observed during the study. It contains the following columns.

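- Treatment: the type of lung cancer treatment; standard or test drug.
- Celltype: the type of cancer cell involved; squamous, smallcell, adeno or large.
- Karnofsky_score: the patient's [Karnofsky performance score](http://www.npcrc.org/files/news/karnofsky_performance_scale.pdf).
- Months_from_Diagnosis: the time since diagnosis, in months.
- Age_in_years: the patient's age, in years.
- Prior_therapy: whether the patient received any prior therapy; no or yes.
- Status: whether the patient's death was observed during the study (`True`), or the patient was still alive when last observed, i.e. the record is censored (`False`).
- Survival_in_days: the survival time in days since treatment, or, for censored patients, the time until they were last observed.

The first six columns are the covariates (`data_x` below); `Status` and `Survival_in_days` together form the outcome (`data_y`).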

The key question to answer in this type of analysis is: "What is the probability that a given patient is still alive at time _t_?"

This is a hard problem because of censoring. For example, if we last contacted patient X at time _t5_, we know for certain that they were alive at _t5_. We cannot know whether they will still be alive at _t10_; all we know is that their survival time exceeds _t5_. Such observations are called *censored*. Survival analysis relies on probability models that use this partial information instead of discarding it.


In [None]:
# Load the dataset: data_x holds the patient covariates (a DataFrame; columns are
# accessed by name below) and data_y holds the outcome records, whose "Status" and
# "Survival_in_days" fields are read by name in later cells.
data_x, data_y = load_veterans_lung_cancer()

In [None]:
# Peek at the first five patients' covariates.
data_x.head(n=5)

In [None]:
# The outcome data_y is not a plain label vector: each record pairs an event
# indicator ("Status") with the time at which that event or censoring occurred
# ("Survival_in_days"). Both fields are accessed by name in the cells below.
data_y[:5]

In [None]:
# Build a single DataFrame for EDA: the covariates plus the two outcome fields of data_y.
# `assign` places the array values positionally, so every outcome stays on the same row
# as its covariates whatever index data_x carries, and it raises if the lengths differ.
# (Wrapping the fields in bare `pd.Series` gives them a fresh 0..n-1 index that `concat`
# aligns by label, which silently misaligns rows or fills NaNs for any other index.)
eda_df = data_x.assign(
    Status=data_y["Status"],
    Survival_in_days=data_y["Survival_in_days"],
)

## EDA

Simple data exploration. You can expand on this for a deeper analysis.

In [None]:
# Summary statistics for the EDA frame (pandas `describe` reports the numeric
# columns by default; categorical columns such as Celltype are not included).
eda_df.describe()

In [None]:
# Cell-type distribution, coloured by event status (hue="Status"), as percentages.
countplot(df=eda_df, x="Celltype", hue="Status", stat="percent")


In [None]:
# Treatment-arm distribution, coloured by event status (hue="Status"), as percentages.
countplot(df=eda_df, x="Treatment", hue="Status", stat="percent")


In [None]:
# Distribution of survival times, drawn with the project helper dist_histogram
# (time_to_event.utils); see that module for the exact plot it produces.
dist_histogram(eda_df, "Survival_in_days", color="r")

In [None]:
# Distribution of patient age, in years.
dist_histogram(eda_df, "Age_in_years", color="b")

In [None]:
# Patient age again, passing hue_="Status" — presumably to split the histogram by
# event status; confirm against time_to_event.utils.dist_histogram.
dist_histogram(eda_df, "Age_in_years", color="b", hue_="Status")

In [None]:
# Months since diagnosis, with hue_="Status" (presumably split by event status,
# as in the age plot above; see time_to_event.utils.dist_histogram).
dist_histogram(eda_df, "Months_from_Diagnosis", color="b", hue_="Status")

In [None]:
# Pairwise view of eda_df coloured by event status: histograms on the diagonal and
# scatter plots off the diagonal. Each PairGrid step returns the grid, so `g` ends up
# bound to the same grid object either way.
g = (
    sns.PairGrid(eda_df, hue="Status")
    .map_diag(sns.histplot)
    .map_offdiag(sns.scatterplot)
    .add_legend()
)
plt.show()

We can see that, unfortunately, most patients died during the study. For our analysis, this means that most of our observations are uncensored: we know exactly when the event (**event == death**) happened.

## Simple survival probability function.

One of the most common methods for estimating the survival function is the [Kaplan Meier Estimator](https://en.wikipedia.org/wiki/Kaplan%E2%80%93Meier_estimator).

In medical research, it is often used to measure the fraction of patients who survive for a certain amount of time after treatment.

In [None]:

# Kaplan-Meier estimate of the survival function for the whole cohort, with a
# pointwise confidence band built on the log-log scale (which keeps it within [0, 1]).
time, survival_prob, conf_int = kaplan_meier_estimator(
    data_y["Status"], data_y["Survival_in_days"], conf_type="log-log"
)
plt.step(time, survival_prob, where="post")
plt.fill_between(time, conf_int[0], conf_int[1], alpha=0.25, step="post")
plt.ylim(0, 1)
plt.ylabel(r"est. probability of survival $\hat{S}(t)$")
plt.xlabel("time $t$ (days)")
plt.title("Kaplan-Meier estimate, all patients")
# Render explicitly, like the other plotting cells, instead of echoing the label's Text repr.
plt.show()


We can see that most patients had died by _t200_ (200 days), and the curve practically flattens out after _t400_.

## Adding other variables

In [None]:
# One Kaplan-Meier curve (with log-log confidence band) per treatment arm. As in the
# cell-type plot, the legend reports each group's size so the band widths can be read
# against the number of patients behind them.
for treatment_type in ("standard", "test"):
    mask_treat = data_x["Treatment"] == treatment_type
    time_treatment, survival_prob_treatment, conf_int = kaplan_meier_estimator(
        data_y["Status"][mask_treat],
        data_y["Survival_in_days"][mask_treat],
        conf_type="log-log",
    )

    plt.step(
        time_treatment,
        survival_prob_treatment,
        where="post",
        label=f"Treatment = {treatment_type} (n = {mask_treat.sum()})",
    )
    plt.fill_between(time_treatment, conf_int[0], conf_int[1], alpha=0.25, step="post")

plt.ylim(0, 1)
plt.ylabel(r"est. probability of survival $\hat{S}(t)$")
plt.xlabel("time $t$ (days)")
plt.legend(loc="best")
# Render explicitly instead of letting the cell echo the Legend object's repr.
plt.show()

In [None]:
# One Kaplan-Meier curve (with log-log confidence band) per tumour cell type; the
# legend shows how many patients fall into each group.
for value in data_x["Celltype"].unique():
    mask = data_x["Celltype"].eq(value)
    status_cell = data_y["Status"][mask]
    days_cell = data_y["Survival_in_days"][mask]
    time_cell, survival_prob_cell, conf_int = kaplan_meier_estimator(
        status_cell, days_cell, conf_type="log-log"
    )
    label = f"{value} (n = {mask.sum()})"
    plt.step(time_cell, survival_prob_cell, where="post", label=label)
    # conf_int holds the lower and upper band as its two rows.
    plt.fill_between(time_cell, *conf_int, alpha=0.25, step="post")

plt.ylim(0, 1)
plt.ylabel(r"est. probability of survival $\hat{S}(t)$")
plt.xlabel("time $t$")
plt.legend(loc="best")
plt.show()

## Multivariate Model

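We can fit a linear model to pinpoint the factors that increase or decrease a patient's risk of death. Here we use the Cox proportional hazards model, which plays a role in survival analysis similar to that of logistic regression in classification.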

In [None]:

# Turn the categorical covariates into numeric indicator columns with scikit-survival's
# OneHotEncoder so the linear Cox model below can use them. The encoded column names
# shown by .head() are the ones x_new must match further down.
data_x_numeric = OneHotEncoder().fit_transform(data_x)
data_x_numeric.head()

In [None]:
# Let's create 4 hypothetical patients who have just joined the trial and estimate their
# survival probability. All four share the same baseline profile (65 years old,
# Karnofsky score 60, 1 month since diagnosis, no prior therapy) and differ only in
# cell type (squamous vs. small cell) and treatment arm (test vs. standard).
# Rows are built by column name instead of positional lists, so a value can never land
# under the wrong feature if the encoded column order changes.
baseline_patient = {
    "Age_in_years": 65,
    "Celltype=large": 0,
    "Celltype=smallcell": 0,
    "Celltype=squamous": 0,
    "Karnofsky_score": 60,
    "Months_from_Diagnosis": 1,
    "Prior_therapy=yes": 0,
    "Treatment=test": 0,
}
# Fail loudly if the encoder's column names differ from the ones spelled out above.
assert set(baseline_patient) == set(data_x_numeric.columns), "x_new columns do not match data_x_numeric"

new_patients = {
    1: {**baseline_patient, "Celltype=squamous": 1, "Treatment=test": 1},
    2: {**baseline_patient, "Celltype=squamous": 1},
    3: {**baseline_patient, "Celltype=smallcell": 1},
    4: {**baseline_patient, "Celltype=smallcell": 1, "Treatment=test": 1},
}
# Select the training columns to reproduce the exact column order the model was fit on.
x_new = pd.DataFrame.from_dict(new_patients, orient="index")[data_x_numeric.columns]
x_new

In [None]:
set_config(display="text")  # displays text representation of estimators. Needed in some Jupyter environments.

# Cox proportional hazards model: one of the most common models for survival analysis.
estimator = CoxPHSurvivalAnalysis()
estimator.fit(data_x_numeric, data_y)

# Predict a survival function for each new patient and plot it.
N: int = 1_000  # plotting horizon, in days
pred_surv = estimator.predict_survival_function(x_new)
for patient_id, surv_func in zip(x_new.index, pred_surv):
    # A predicted StepFunction is only defined over the time range seen in training and
    # raises ValueError outside it, so evaluate it at its own time points (surv_func.x)
    # instead of a fixed grid, keeping only those before the horizon N.
    time_points = surv_func.x[surv_func.x < N]
    plt.step(time_points, surv_func(time_points), where="post", label=f"New patient {patient_id}")
plt.ylim(0, 1)
plt.ylabel(r"est. probability of survival $\hat{S}(t)$")
plt.xlabel("time $t$ (days)")
plt.legend(loc="best")
plt.show()

In [None]:
# Cox coefficients are log hazard ratios: a positive value raises the hazard (risk of
# death, i.e. worse survival) and a negative value lowers it. They are per unit of each
# feature, and the features use different scales (years, months, score points, 0/1
# indicators), so compare signs rather than raw bar heights across features.
cox_coefficients = pd.Series(
    estimator.coef_,  # available because the Cox model is linear in the features
    index=data_x_numeric.columns,
    name="coef",
).sort_values()
ax = cox_coefficients.plot(
    kind="bar",
    title="Cox model coefficients",
    ylabel="coefficient (log hazard ratio)",
)
ax.axhline(0, color="black", linewidth=0.8)  # reference line: no effect on the hazard
plt.show()

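In a Cox model, a positive coefficient means a higher hazard (a higher risk of death, hence lower survival), and a negative coefficient means a lower one. With that in mind: is the treatment (`Treatment=test`) working?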

## Evaluation

Evaluation works differently from a traditional classification problem. However, survival analysis has a metric very similar to the [ROC score](https://www.evidentlyai.com/classification-metrics/explain-roc-curve#:~:text=The%20ROC%20AUC%20score%20can,inadequate%20for%20any%20real%20applications.), called the Concordance Index. It measures how often the model ranks patients correctly, i.e. assigns the higher risk to the patient who dies first. Its interpretation is essentially the same: a value of 0.5 means no better than random, 1 means a perfect ranking, and values below 0.5 are worse than random.

It is, of course, built into the scikit-survival library: the `score` method of its survival estimators returns Harrell's concordance index.

In [None]:
# Concordance index of the fitted Cox model, rounded to two decimals.
# NOTE(review): this is computed on the same data the model was fitted on, so it is an
# optimistic (in-sample) estimate; use a held-out split or cross-validation for an
# honest evaluation.
round(estimator.score(data_x_numeric, data_y), 2)
