# Survival Models

Some distributions model events that may happen over time.

This family of models is called **Survival Models** because they were first used in the medical field to estimate patient survival probabilities over time.

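The main concept in survival analysis is **data censoring**: some subjects in the dataset will experience the event (here, death), but we have not observed it yet because it lies in the future. For those subjects we only know that they survived at least until the last observation.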
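To make censoring concrete, here is a minimal sketch of how right-censored observations are usually encoded. The event times and the study length are hypothetical values chosen for illustration; in real data the true times of censored subjects are unknown.

```python
import numpy as np

# Hypothetical true event times in days (unknown in practice for censored subjects)
true_times = np.array([5.0, 12.0, 30.0, 45.0, 80.0])
study_length = 40.0  # assumed end of the observation window

# Each subject is followed until the event or the end of the study, whichever comes first
observed_time = np.minimum(true_times, study_length)

# 1 = event observed, 0 = right-censored (event had not happened by the end of the study)
event_observed = (true_times <= study_length).astype(int)

print(observed_time)
print(event_observed)
```

The pair (observed time, event indicator) is the input most survival tools expect, including the `status` and `event_col` arguments used later in this notebook.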

Survival analysis is also useful for many other kinds of events, such as the time until a free user of a website converts to a paid plan.

## Manually Fitting a Probability Distribution to Data

In [None]:
# imports
import numpy as np
import scipy.stats as st
import statsmodels.datasets
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()

data = statsmodels.datasets.heart.load_pandas().data
data.sample(5)


In [None]:
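# Fit a probability distribution to the data

# Keep only uncensored patients (censors == 1 means the death was observed).
# A separate variable leaves `data` intact, including its censored rows.
uncensored = data[data.censors == 1]
survival = uncensored.survival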

fig, (ax1,ax2) = plt.subplots(1,2,figsize=(10,4))

ax1.plot(sorted(survival)[::-1],'o')
ax1.set_xlabel('Patient')
ax1.set_ylabel('Survival time (days)')

ax2.hist(survival,bins=15)
ax2.set_xlabel('Survival time (days)')
ax2.set_ylabel('Number of patients')

In [None]:
# Maximum likelihood estimate of the exponential rate: 1 / sample mean
smean = survival.mean()
rate = 1. / smean

smax = survival.max()
days = np.linspace(0., smax, 1000)

# Grid step: interval between two consecutive values in `days`
dt = smax / 999

dist_exp = st.expon.pdf(days, scale=1. / rate)
nbins = 30

fig, ax = plt.subplots(1,1,figsize=(6,4))

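ax.hist(survival, nbins)
# Scale the density to expected counts per bin:
# pdf * number of patients * approximate bin width
ax.plot(days, dist_exp * len(survival) * smax / nbins, '-r', lw=3)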

ax.set_xlabel('Survival time (days)')
ax.set_ylabel('Number of patients')
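The manual estimate above (rate = 1 / sample mean) is the maximum likelihood estimate for an exponential distribution. As a sanity check, the sketch below compares it with `scipy.stats.expon.fit` when the location is fixed at 0. It uses synthetic data, which is an assumption made here so that the example runs on its own; it is not the heart dataset.

```python
import numpy as np
import scipy.stats as st

# Synthetic survival times (an assumption for illustration, not the heart dataset)
rng = np.random.default_rng(0)
times = rng.exponential(scale=200.0, size=500)

# Manual maximum likelihood estimate: rate = 1 / sample mean
rate_manual = 1.0 / times.mean()

# scipy's fit with the location fixed at 0; it returns (loc, scale), where scale = 1 / rate
loc, scale = st.expon.fit(times, floc=0)
rate_scipy = 1.0 / scale

print(rate_manual, rate_scipy)
```

Without `floc=0`, scipy would also estimate a location shift, and the two rates would generally differ.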


## Statsmodels Survival Regression

Statsmodels provides survival regression through the [Cox Proportional Hazards model](https://www.statsmodels.org/stable/generated/statsmodels.duration.hazard_regression.PHReg.html#statsmodels.duration.hazard_regression.PHReg), which supports censored data.

Survival models relate the time that passes before some event occurs to one or more predictors that may be associated with that time. In a proportional hazards model, the effect of a unit increase in a predictor is multiplicative with respect to the hazard rate.

For example, the hazard rate is likely higher for a 97-year-old alcoholic chain smoker than for a 21-year-old college athlete.
A higher hazard rate leads to more deaths in less time, while a lower hazard rate means deaths occur more slowly.
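The multiplicative effect can be illustrated numerically. In a Cox model with a single predictor, the hazard is $h(t \mid x) = h_0(t)\, e^{\beta x}$, so the ratio of hazards between two subjects does not depend on the baseline hazard $h_0(t)$. The sketch below uses a hypothetical coefficient `beta`; it is not a value fitted to any dataset in this notebook.

```python
import numpy as np

# Hypothetical Cox coefficient per year of age (an assumption, not a fitted value)
beta = 0.03

def hazard_ratio(x_a, x_b, beta):
    """Ratio h(t | x_a) / h(t | x_b); the baseline hazard h_0(t) cancels out."""
    return np.exp(beta * (x_a - x_b))

# One extra year of age multiplies the hazard by exp(beta)
hr_one_year = hazard_ratio(51, 50, beta)

# Ten extra years multiply it by exp(beta) ten times over
hr_ten_years = hazard_ratio(60, 50, beta)

print(hr_one_year, hr_ten_years)
```

Because the effect is multiplicative, the ten-year hazard ratio equals the one-year ratio raised to the tenth power.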

In [None]:
# statsmodels survival regression
import statsmodels.api as sm 

y = data.survival
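# A Cox model has no intercept (the baseline hazard absorbs it), so no constant is added
X = data[['age']]
censor = data.censors  # 1 = death observed, 0 = censored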

est = sm.PHReg(y, X, status=censor).fit()
est.summary()

In [None]:
# Statsmodels also supports survival functions with right censoring:

sm.SurvfuncRight(y, status=censor).plot();


## Lifelines Package

A popular package for survival analysis is [lifelines](https://lifelines.readthedocs.io/en/latest/), which needs to be installed separately.

In [None]:
# Run if necessary
#!pip install lifelines

In [None]:
# lifelines package
from lifelines.datasets import load_waltons
df = load_waltons() # returns a Pandas DataFrame

df.sample(5)

In [None]:
# lifelines package, continued
from lifelines import KaplanMeierFitter

T = df['T']
E = df['E']

kmf = KaplanMeierFitter()
kmf.fit(T,E)
kmf.plot_survival_function();


In [None]:
# lifelines package, continued: we can split by group!

groups = df['group']
ix = (groups == 'miR-137') 

kmf.fit(T[~ix],E[~ix],label='control') 
ax = kmf.plot_survival_function()

kmf.fit(T[ix],E[ix],label='miR-137')
ax = kmf.plot_survival_function(ax=ax)


### Survival Regression in Lifelines

In [None]:
from lifelines.datasets import load_regression_dataset
df = load_regression_dataset()
df

In [None]:
# Survival Regression in lifelines

from lifelines import CoxPHFitter

# Using Cox Proportional Hazards model
cph = CoxPHFitter()
cph.fit(df, 'T', event_col='E')
cph.print_summary()  # prints the table itself; wrapping it in print() would also print "None"
cph.plot();


In [None]:
from lifelines import WeibullAFTFitter 

wft = WeibullAFTFitter()
wft.fit(df, 'T', event_col='E')
wft.print_summary()

wft.plot()