# Random Forest Survival Analysis

Tree-based algorithms are also a powerful alternative in survival analysis (SA). As in traditional machine learning, these types of models offer a good balance between robustness and explainability. This makes them a great option in clinical environments, where you need accurate predictions whose drivers can be identified and explained.

## About the Data

We are going to use breast cancer data from a German clinical trial. As described by scikit-survival:

> German Breast Cancer Study Group (GBSG-2) on the treatment of node-positive breast cancer patients. It contains data on 686 women and 8 prognostic factors:
>
> 1. age
> 2. estrogen receptor (estrec)
> 3. whether or not a hormonal therapy was administered (horTh)
> 4. menopausal status (menostat)
> 5. number of positive lymph nodes (pnodes)
> 6. progesterone receptor (progrec)
> 7. tumor size (tsize)
> 8. tumor grade (tgrade)

We are tasked with predicting recurrence-free survival time: the length of time from the end of primary treatment (such as surgery, radiation, or chemotherapy) until there is evidence of cancer recurrence or until the death of the patient, regardless of the cause.
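
Not every patient experiences the event during follow-up. For those patients we only know that they were event-free up to their last visit (they are *censored*), so the target needs two pieces of information per patient: an event indicator and a time. scikit-survival stores this as a NumPy structured array, and for GBSG-2 `load_gbsg2()` returns `y` with the fields `cens` (event indicator) and `time`. The sketch below builds a tiny array of the same shape by hand; the values are hypothetical and not taken from the dataset.

```python
import numpy as np

# Hypothetical example values, not taken from GBSG-2.
# "cens" is the event indicator (True = recurrence or death was observed),
# "time" is the follow-up time in days (event time, or last follow-up if censored).
y_toy = np.array(
    [(True, 1814.0), (False, 2018.0), (True, 712.0), (False, 1807.0)],
    dtype=[("cens", bool), ("time", float)],
)

events = y_toy["cens"]
times = y_toy["time"]
print(f"observed events: {events.sum()} of {len(y_toy)}")
print(f"censored (event-free at last follow-up): {(~events).sum()}")
print("earliest observed event at day", times[events].min())
```

For a censored row, the stored time is a lower bound on the true event time, which is why survival models cannot simply treat `time` as a regression target.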

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.cm as cm


from sklearn import set_config
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OrdinalEncoder

from sksurv.datasets import load_gbsg2
from sksurv.preprocessing import OneHotEncoder
from sksurv.ensemble import RandomSurvivalForest

from time_to_event.utils import dist_histogram, countplot

set_config(display="text")  # displays text representation of estimators

plt.rcParams["figure.figsize"] = (12, 8)
plt.style.use("ggplot")

SEED: int = 123456  # set seed for reproducibility

In [None]:
data_X, y = load_gbsg2()
data_X.head()


In [None]:
y[:5]

## EDA

In [None]:
eda_df = pd.concat(
    [
        data_X,
        pd.Series(y["cens"], name="cens"),
        pd.Series(y["time"], name="time"),
    ],
    axis=1,
)

In [None]:
eda_df.describe()

In [None]:
countplot(
    df=eda_df,
    x='menostat', 
    hue="cens",
    stat="percent"
)

In [None]:
countplot(
    df=eda_df,
    x='tgrade', 
    hue="cens",
    stat="percent"
)

In [None]:
dist_histogram(eda_df, "age", color="b", hue_="cens")

In [None]:
dist_histogram(eda_df, "estrec", color="b", hue_="cens")

In [None]:
dist_histogram(eda_df, "tsize", color="b", hue_="cens")

## Modelling

In [None]:
# We use one-hot encoding for the non-numeric variables.
# However, we want to maintain the order where it matters: tumor grade I < II < III.
ordered_grade = OrdinalEncoder(
    categories=[["I", "II", "III"]]  # lowest to highest grade
).fit_transform(
    data_X["tgrade"].astype(object).to_numpy().reshape(-1, 1)  # 2D array of strings
).ravel()  # back to 1D so it can be assigned as a single column

x_no_grade = data_X.drop(columns=["tgrade"])
X = OneHotEncoder().fit_transform(x_no_grade)
# re-adding the grade column with proper ordering information.
X["tgrade"] = ordered_grade
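
Why keep the order for the tumor grade? A small sketch on a hypothetical toy column with the same levels as `tgrade`: with ordinal codes, "grade II or higher" is a single threshold that a tree can learn in one split, whereas one-hot encoding spreads the grade over separate 0/1 columns that carry no notion of I < II < III. This uses plain pandas for illustration rather than the `OrdinalEncoder` and `OneHotEncoder` used above.

```python
import pandas as pd

# Hypothetical toy column with the same levels as GBSG-2's tgrade.
grades = pd.Series(["II", "I", "III", "II"], name="tgrade")

# Ordinal encoding: an explicit mapping that preserves I < II < III.
order = ["I", "II", "III"]
ordinal = grades.map({level: code for code, level in enumerate(order)})

# One-hot encoding: one indicator column per level; the order is lost.
one_hot = pd.get_dummies(grades, prefix="tgrade", dtype=int)

# With ordinal codes, "grade II or higher" is a single threshold split.
high_grade = ordinal >= 1

print(ordinal.tolist())
print(one_hot)
print(high_grade.tolist())
```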


In [None]:
TEST_SIZE: float = 0.3

# Split train and test
X_train, X_test, y_train, y_test = train_test_split(
    X, y, 
    test_size=TEST_SIZE,
    random_state=SEED
)

In [None]:
# Set up the model hyperparameters. I recommend experimenting with them to see how the results change.
# You can (and should) also tune them with cross-validation, as you would with traditional scikit-learn models.
model_params: dict = {
    "n_estimators": 1_000,
    "min_samples_split": 8,
    "min_samples_leaf": 12,
    "n_jobs": -1,
    "random_state": SEED,
}

rsf = RandomSurvivalForest(**model_params)
# Train the model
rsf.fit(X_train, y_train)

### Model Evaluation

We are still using the concordance index, as described in the [Introduction](./introduction.ipynb) notebook.
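
As a reminder of what the number measures, here is a simplified re-implementation of Harrell's concordance index, which is what `rsf.score` reports for a `RandomSurvivalForest`. It is a sketch for intuition only: pairs with tied times are simply skipped, whereas scikit-survival's `concordance_index_censored` handles ties and other edge cases more carefully. The function name and the toy data are hypothetical.

```python
import numpy as np


def harrell_c_index(event, time, risk):
    """Simplified Harrell's concordance index (a sketch, not scikit-survival's code).

    A pair (i, j) is comparable when subject i had the event (event[i] is True)
    and an earlier time than subject j. The pair is concordant when the model
    gives i the higher risk score; tied risk scores count as half.
    Pairs with tied times are skipped for simplicity.
    """
    event = np.asarray(event, dtype=bool)
    time = np.asarray(time, dtype=float)
    risk = np.asarray(risk, dtype=float)

    concordant = 0.0
    comparable = 0
    for i in range(len(time)):
        if not event[i]:
            continue  # a censored subject cannot be the earlier member of a comparable pair
        for j in range(len(time)):
            if time[i] < time[j]:
                comparable += 1
                if risk[i] > risk[j]:
                    concordant += 1.0
                elif risk[i] == risk[j]:
                    concordant += 0.5
    return concordant / comparable


# Hypothetical toy data: a higher risk score should mean an earlier event.
event = [True, True, False, True]
time = [100, 300, 400, 500]
risk = [0.9, 0.5, 0.7, 0.2]
print(harrell_c_index(event, time, risk))  # 4 of the 5 comparable pairs are concordant: 0.8
```

A value of 0.5 corresponds to random ordering and 1.0 to a perfect ranking of patients by time to event.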

In [None]:
# Evaluate on the held-out test set: score() returns Harrell's concordance index
print("Concordance Index:")
print(round(rsf.score(X_test, y_test), 2))

### Inspecting predictions

For simplicity, let's sort the test patients by pnodes and age, take the first _N_ and the last _N_ rows of the sorted dataset, and compare their predicted risk scores.

In [None]:
N: int = 5

X_test_sorted = X_test.sort_values(by=["pnodes", "age"])
X_test_sel = pd.concat((X_test_sorted.head(N), X_test_sorted.tail(N)))

X_test_sel

In [None]:
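# Keep the patients' original index so each risk score can be matched to a row above
pd.Series(rsf.predict(X_test_sel), index=X_test_sel.index, name="risk_score")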

Overall, the predicted risk scores are higher for the patients at the bottom of the sorted sample (more positive lymph nodes) than for those at the top. But what is going on with the first patient?

In [None]:
surv = rsf.predict_survival_function(X_test_sel, return_array=True)
# Pick one color per patient from the reversed CMRmap colormap;
# starting at 0.2 skips the near-white end, which is hard to see on the plot background
cmap = cm.CMRmap_r
colors = cmap(np.linspace(0.2, 1, len(surv)))

# Plot each patient's survival curve as a step function over the training event times
for i, (s, color) in enumerate(zip(surv, colors)):
    plt.step(rsf.unique_times_, s, where="post", label=f"Patient sample {i}", color=color)
plt.ylabel("Survival probability")
plt.xlabel("Time in days")
plt.legend()
plt.grid(True)

Another very useful output is the **predicted cumulative hazard function**. We can plot it with scikit-survival, but first, a brief introduction to what it represents.

In survival analysis, the predicted cumulative hazard function is a way to describe the risk of an event happening over time.

In simple terms:

**Hazard function:** think of the hazard function as the risk, or rate, of an event (such as failure, death, or relapse) happening at a specific time. It is an instantaneous rate rather than a probability: it tells you how quickly the event is occurring at that precise moment, given that the individual has survived up to that time.

**Cumulative hazard function:** the cumulative hazard function adds up these risks over time, accumulating the risk from the start of the observation period to a specific point in time.


In mathematical terms, if the hazard function at time $t$ is $h(t)$, then the cumulative hazard function $H(t)$ is the integral of $h(t)$ from the start time to time $t$:

$$
H(t) = \int_0^t h(u)\,du
$$

In survival analysis, the cumulative hazard function helps in understanding the total risk accumulated over a period. In continuous time it also determines the survival function, which tells us the probability of surviving up to a certain time, through the relationship $S(t) = \exp(-H(t))$.
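
A quick numerical check of these definitions, using a hypothetical hazard that grows linearly with time, $h(t) = a\,t$, chosen only because its integral has the closed form $H(t) = a\,t^2 / 2$. The sketch integrates the hazard on a time grid with the trapezoid rule and converts the result into a survival curve with $S(t) = \exp(-H(t))$. The constant `a` and the grid are arbitrary assumptions, not values from the model.

```python
import numpy as np

# Hypothetical hazard that grows linearly with time: h(t) = a * t.
a = 0.002
t = np.linspace(0.0, 50.0, 5001)  # time grid (arbitrary units)
h = a * t                         # hazard rate at each grid point

# Cumulative hazard H(t) = integral from 0 to t of h(u) du, via the trapezoid rule.
increments = 0.5 * (h[1:] + h[:-1]) * np.diff(t)
H = np.concatenate([[0.0], np.cumsum(increments)])

# Survival function recovered from the cumulative hazard: S(t) = exp(-H(t)).
S = np.exp(-H)

# Closed form for comparison: H(t) = a * t**2 / 2.
H_exact = a * t**2 / 2
print(f"H(50) approx {H[-1]:.4f} (exact {H_exact[-1]:.4f})")
print(f"S(50) approx {S[-1]:.4f}")
```

Because the hazard is never negative, $H(t)$ can only grow and $S(t)$ can only fall, which is the same shape you should see in the model's curves.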

In [None]:
chf = rsf.predict_cumulative_hazard_function(X_test_sel, return_array=True)

# Same color scheme as the survival plot, skipping the near-white end of the colormap
cmap = cm.CMRmap_r
colors = cmap(np.linspace(0.2, 1, len(chf)))

# Plot each patient's cumulative hazard as a step function over the training event times
for i, (c, color) in enumerate(zip(chf, colors)):
    plt.step(rsf.unique_times_, c, where="post", label=f"Patient sample {i}", color=color)
plt.ylabel("Cumulative hazard")
plt.xlabel("Time in days")
plt.legend()
plt.grid(True)

## Feature importance

Just as the concordance index takes over the role that ROC AUC plays in classification, we need a feature importance method that works with survival models. [**Permutation Importance**](https://scikit-learn.org/stable/modules/generated/sklearn.inspection.permutation_importance.html#sklearn.inspection.permutation_importance) lets us see which features are the main drivers of our predictions: it shuffles one feature at a time and measures how much the model's score drops.

Because scikit-survival estimators implement `score()` (for `RandomSurvivalForest`, the concordance index), they can be passed directly to scikit-learn's `permutation_importance` function.
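
To make the mechanics concrete, here is a simplified re-implementation of the idea on hypothetical toy data. It is not scikit-learn's code: the function names, the toy "model" (which just uses feature 0 as its risk score), and the random censoring are all assumptions for illustration. It re-defines a compact concordance index so the block runs on its own, then reports the mean and standard deviation of the score drop over repeated shuffles, mirroring the `importances_mean` and `importances_std` columns used below.

```python
import numpy as np


def c_index(event, time, risk):
    """Compact Harrell's C (sketch): tied times skipped, tied risks count as half."""
    comparable = event[:, None] & (time[:, None] < time[None, :])
    concordant = (risk[:, None] > risk[None, :]) + 0.5 * (risk[:, None] == risk[None, :])
    return (concordant * comparable).sum() / comparable.sum()


def permutation_importance_sketch(predict, X, event, time, n_repeats=10, seed=0):
    """Drop in concordance after shuffling each column, over n_repeats shuffles."""
    rng = np.random.default_rng(seed)
    baseline = c_index(event, time, predict(X))
    drops = np.empty((X.shape[1], n_repeats))
    for j in range(X.shape[1]):
        for r in range(n_repeats):
            X_perm = X.copy()
            # Shuffling breaks the link between feature j and the outcome
            X_perm[:, j] = rng.permutation(X_perm[:, j])
            drops[j, r] = baseline - c_index(event, time, predict(X_perm))
    return drops.mean(axis=1), drops.std(axis=1)


def predict(features):
    """Hypothetical 'model': the risk score is simply feature 0."""
    return features[:, 0]


# Hypothetical toy data: only feature 0 drives the event rate.
rng = np.random.default_rng(123)
X = rng.normal(size=(200, 2))
time = rng.exponential(scale=np.exp(-X[:, 0]))  # higher feature 0 -> earlier events
event = rng.random(200) < 0.7                   # roughly 30% flagged as censored

mean, std = permutation_importance_sketch(predict, X, event, time)
print("importances_mean:", mean.round(3))
print("importances_std: ", std.round(3))
```

Feature 1 gets an importance of exactly zero because the toy model never looks at it, while shuffling feature 0 pushes the concordance index toward 0.5.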

In [None]:
from sklearn.inspection import permutation_importance

N_REPEATS: int = 10

result = permutation_importance(rsf, X_test, y_test, n_repeats=N_REPEATS, random_state=SEED)

In [None]:
importance_df = pd.DataFrame(
    {
        k: result[k]
        for k in (
            "importances_mean",
            "importances_std",
        )
    },
    index=X_test.columns,
).sort_values(by="importances_mean", ascending=False)

importance_df

In [None]:
plt.bar(importance_df.index, importance_df.importances_mean)

# Error bars show the standard deviation across the permutation repeats
plt.errorbar(
    importance_df.index,
    importance_df.importances_mean,
    yerr=importance_df.importances_std,
    fmt="o",
    color="grey",
)

plt.xticks(rotation=45, ha="right")  # one-hot column names can be long
plt.title("Feature Importance with error bars")
plt.show()
