In [None]:
import logging
from tempfile import mkdtemp

from lifelines import CoxPHFitter, KaplanMeierFitter
from lifelines.datasets import load_rossi
from lifelines.utils.sklearn_adapter import sklearn_adapter

from matplotlib import pyplot as plt
import pandas as pd
from scipy.stats import uniform
from sklearn.model_selection import RandomizedSearchCV, train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.utils.fixes import loguniform

from harmoniums import SurvivalHarmonium
from harmoniums.utils import reset_random_state
from harmoniums.views import plot

In [None]:
# Lower the root logger's threshold so INFO-level messages are shown.
logging.getLogger().setLevel(level=logging.INFO)
# NOTE(review): presumably seeds the random number generators used below so
# that reruns give the same results -- confirm in `harmoniums.utils`.
reset_random_state(1234)

# Compute budget: the number of hyperparameter configurations that each
# randomised search in this notebook evaluates.
n_iter = 25

# Lifelines example dataset
Dataset of convicts released from Maryland state prisons: https://lifelines.readthedocs.io/en/latest/lifelines.datasets.html#lifelines.datasets.load_rossi
The event of interest is the arrest of a released convict; `week` holds the time until that event and `arrest` indicates whether it was observed.

In [None]:
# Load the Rossi example dataset bundled with lifelines and preview its
# first five rows.
X = load_rossi()
X.head(n=5)

The dataset consists of four types of variables:
- Categorical variables: `fin`, `race`, `wexp`, `mar`, `paro`
- Numeric variable: `age`
- Ordinal variable: `prio`
- Time-to-event variable `week` (with event indicator `arrest`).

The ordinal variable `prio` indicates the number of convictions prior to current incarceration.

In [None]:
# Tabulate how many convicts have each number of prior convictions.
prio_counts = X['prio'].value_counts()
prio_counts

Because there is only a limited number of convicts with more than 5 prior convictions, let's group these together and encode the variable using one-hot encoding.

In [None]:
# Make dummy categories for the number of prior convictions.
prior_dummies = ['prio_0', 'prio_1', 'prio_2', 'prio_3', 'prio_4', 'prio_5', 'prio>5']

# Collapse every count above 5 into a single top level. Keeping the column
# purely integer (instead of mixing ints with the string '>5') means the level
# order no longer depends on how pandas happens to sort mixed-type values.
n_prio_levels = len(prior_dummies)
prio_capped = X['prio'].clip(upper=n_prio_levels - 1)

# Declare the levels explicitly so that the i-th dummy column is guaranteed to
# correspond to the i-th name in `prior_dummies`, and so that a level with no
# observations still yields a (all-zero) column.
prio_levels = prio_capped.astype(
    pd.CategoricalDtype(categories=list(range(n_prio_levels)))
)

# Pin the dtype: since pandas 2.0 `get_dummies` returns booleans by default,
# which would e.g. drop these columns from `X.describe()`. `uint8` is what
# earlier pandas versions produced.
prio_onehot = pd.get_dummies(prio_levels, dtype='uint8')
prio_onehot.columns = prior_dummies

# Replace the ordinal column by its one-hot encoding (aligned on the index).
X = X.drop('prio', axis=1).join(prio_onehot)

Putting this all together, we have the following set of variables.

In [None]:
# Role of every column in the models below.
event_columns = ['arrest']
survival_columns = ['week']
numeric_columns = ['age']
categorical_columns = ['fin', 'race', 'wexp', 'mar', 'paro', *prior_dummies]

# Standardise the numeric variable `age`.
# NOTE(review): the scaler is fitted on all rows, i.e. before the train/test
# split, so test-set statistics enter the scaling -- confirm this is intended.
age_scaler = StandardScaler()
X[numeric_columns] = age_scaler.fit_transform(X[numeric_columns])
X.describe()

In [None]:
# Hold out a quarter of the rows as test set (scikit-learn's default split,
# spelled out here). No `random_state` is passed, so the shuffle draws from
# NumPy's global random state.
X_train, X_test = train_test_split(X, test_size=0.25)

## Kaplan-Meier

In [None]:
# Fit a Kaplan-Meier estimator on the full dataset and plot the result.
kmf = KaplanMeierFitter()
kmf.fit(durations=X['week'], event_observed=X['arrest'])
kmf.plot()

# Cox regression
Use cross-validation to find the optimal penalty strength (`penalizer`) and the balance between the $L_1$ and $L_2$ penalties (`l1_ratio`).

In [None]:
# Wrap lifelines' `CoxPHFitter` for use with scikit-learn tooling; the column
# `arrest` is passed as the event indicator.
CoxRegression = sklearn_adapter(CoxPHFitter, event_col='arrest')
# CoxRegression is a class like the `LinearRegression` class or `SVC` class in scikit-learn

# Unfitted instance that serves as the base estimator for the hyperparameter search.
cph = CoxRegression()

In [None]:
# Search space: overall penalty strength and the L1 share of the penalty,
# both sampled log-uniformly.
cox_param_distributions = {
    "penalizer": loguniform(1e-5, 1e2),
    "l1_ratio": loguniform(1e-5, 1),
}

# Randomised search with 5-fold cross-validation, evaluating `n_iter`
# configurations on a single worker.
cox_search = RandomizedSearchCV(
    cph,
    param_distributions=cox_param_distributions,
    n_iter=n_iter,
    cv=5,
    n_jobs=1,
)

# `week` is the target; every other column (including the `arrest` event
# indicator configured on the adapter) is passed as part of the feature frame.
cph_cv = cox_search.fit(X_train.drop(columns='week'), X_train['week'])

In [None]:
# Continue with the estimator belonging to the best configuration found by
# the search, and show it.
cph = cph_cv.best_estimator_
print(cph)

## Accuracy
Compute Harrell's concordance index (c-index) on the test set.

In [None]:
# Score the selected Cox model on the held-out rows: `week` is the target,
# the remaining columns are the features.
cox_test_features = X_test.drop(columns='week')
cox_test_durations = X_test['week']
cph.score(cox_test_features, cox_test_durations)

# Harmonium
In view of the larger number of hyperparameters, we again use a randomised search (as for the Cox model above) rather than an exhaustive grid search, with the same budget of `n_iter` configurations. The models are scored using the `score` function, which is just the concordance index.

In [None]:
# Time horizon of the model (model events in the range [0, max_t]).
max_t = X['week'].max(axis=0)

log_dir = mkdtemp()
harmonium = SurvivalHarmonium(
    categorical_columns=categorical_columns,
    survival_columns=survival_columns,
    numeric_columns=numeric_columns,
    event_columns=event_columns,
    verbose=True,
    log_every_n_iterations=300, 
    time_horizon=[max_t],
    # Don't use median value because this value 
    # is not observed in the dataset (see Kaplan-Meier).
    risk_score_time_point=0.75 * max_t,
    # Don't evaluate any metrics for now.
    metrics=('log_likelihood', 'score', 'brier_loss'),
    X_validation=X_test,
    # Log to temporary directory.
    output=log_dir,
    CD_steps=1,
    n_hidden_units=2,
    persistent=True,
)

In [None]:
print(f'To follow training, run:\ntensorboard --logdir {log_dir}')

# Search space for the harmonium's training hyperparameters.
# NOTE(review): `loguniform` draws floats, so `n_epochs` and `mini_batch_size`
# are sampled as non-integers -- confirm SurvivalHarmonium accepts that.
harmonium_param_distributions = {
    "learning_rate": loguniform(1e-5, 0.1),
    "n_epochs": loguniform(1e1, 2e4),
    "momentum_fraction": uniform(0, 0.9),
    "mini_batch_size": loguniform(25, 1000),
    "weight_decay": loguniform(1e-5, 0.1),
    "persistent": [True, False],
}

# Randomised search with 5-fold cross-validation over `n_iter` configurations,
# using all available cores and refitting the best configuration afterwards.
harmonium_search = RandomizedSearchCV(
    harmonium,
    param_distributions=harmonium_param_distributions,
    n_iter=n_iter,
    cv=5,
    n_jobs=-1,
    refit=True,
)

# The harmonium receives the complete frame; the column roles were declared
# when the estimator was constructed.
harm_cv = harmonium_search.fit(X_train)
print(harm_cv.best_params_)

In [None]:
# Replace the template estimator by the one refitted with the best
# configuration (`refit=True` in the search).
harmonium = harm_cv.best_estimator_

By looking at the weights of the model, we can see what variables contribute to the selection (i.e., activation) of the latent state.

In [None]:
# Hyperparameter configuration selected by the randomised search.
harm_cv.best_params_

In [None]:
# Visualise the selected harmonium with the `harmoniums.views.plot` helper.
# NOTE(review): presumably this shows the model weights discussed in the text
# above -- confirm against `harmoniums.views`.
plot(harmonium)

## Training progress

In [None]:
score_train = harmonium.get_train_metrics()
score_val = harmonium.get_validation_metrics()


def _plot_train_and_validation(metric, ylabel):
    """Draw the train and validation curves of `metric` on the current axes.

    Reads the notebook-level `score_train` and `score_val` collections, labels
    the y-axis with `ylabel` and returns the frameless legend that is added.
    """
    plt.plot(score_train[metric], label='train')
    plt.plot(score_val[metric], label='test')
    plt.ylabel(ylabel)
    return plt.legend(frameon=False)


# First figure: Brier loss (left) next to the concordance index (right).
plt.subplot(1, 2, 1)
_plot_train_and_validation('brier_loss', 'Brier loss')

plt.subplot(1, 2, 2)
_plot_train_and_validation('score', "Harrell's concordance index")
plt.tight_layout()

# Second figure: log likelihood.
plt.figure()
_plot_train_and_validation('log_likelihood', 'log likelihood')

## Accuracy
What is the concordance on the test set?

In [None]:
# Score the selected harmonium on the held-out rows (the concordance index,
# according to the text in the "Harmonium" section).
harmonium.score(X_test)