In [None]:
import os

from lifelines import CoxPHFitter
from matplotlib import pyplot as plt, rcParams
from numba import _helperlib
from numpy import *
import pandas as pd
import seaborn as sns

from harmoniums import SurvivalHarmonium
from harmoniums.datasets import load_blobs
from harmoniums.distributions import truncated_gamma_distribution
from harmoniums.log_metrics import log_weights_gradients_likelihood_callback, plot_to_tensor
from harmoniums.utils import reset_random_state
from harmoniums.views import plot

In [None]:
def to_param(m, v):
    """Convert a target mode and variance into gamma (shape, rate) parameters.

    Solves for ``(alpha, beta)`` such that ``(alpha - 1) / beta == m`` and
    ``alpha / beta**2 == v``. These are the mode and variance of a
    Gamma(shape=alpha, rate=beta) distribution. ``beta`` is the positive root of
    ``v * beta**2 - m * beta - 1 = 0``.

    Parameters
    ----------
    m : float
        Desired mode.
    v : float
        Desired variance (must be non-zero).

    Returns
    -------
    (alpha, beta) : tuple of float
    """
    half_ratio = m / (2 * v)
    rate = sqrt(1 / v + m ** 2 / (4 * v ** 2)) + half_ratio
    shape = rate * m + 1
    return shape, rate

In [None]:
# Parameters of the two generating gamma components, indexed 0 and 1:
# component k uses the pair (a[k], b[k]) in `load_blobs` / `truncated_gamma_distribution`.
# To the printed precision these equal to_param(0.25, 0.01) and
# to_param(0.75, 0.01): `a` holds the shapes, `b` the rates of gammas with
# modes 0.25 / 0.75 and variance 0.01 (assuming `truncated_gamma_distribution`
# takes (shape, rate) — TODO confirm).
a0, a1 = a = (8.12695264839553, 58.232827555487546)
b0, b1 = b = (28.50781059358212, 76.31043674065006)

In [None]:
# Joint density of two generating components, used by the contour figure below.
def compute_mode(a0, b0, a1, b1):
    """Evaluate a product of two truncated gamma densities on a grid over [0, 1]^2.

    The t1 axis uses parameters ``(a0, b0)`` and the t2 axis uses ``(a1, b1)``.
    Both densities are evaluated with ``truncated_gamma_distribution(..., 1.0)``;
    the 1.0 is presumably the truncation point, matching ``time_horizon``
    (confirm).

    Returns
    -------
    T1, T2 : ndarray
        100 x 100 coordinate grids produced by ``meshgrid``.
    P : ndarray
        Product density on that grid, ready for ``contour(T1, T2, P)``.
    """
    axis = linspace(0, 1, 100)
    grid_t1, grid_t2 = meshgrid(axis, axis)
    density_t1 = truncated_gamma_distribution(grid_t1, a0, b0, 1.0)
    density_t2 = truncated_gamma_distribution(grid_t2, a1, b1, 1.0)
    return grid_t1, grid_t2, density_t1 * density_t2

# Harmonium

## Fit data

In [None]:
def _likelihood_grid_points(y_value, T1, T2):
    """Build query points for ``model.log_likelihood`` at a fixed class label.

    Pairs every grid point (t1, t2) with label ``y_value`` and sets both event
    indicators to 1, so every point is treated as uncensored. Columns are in
    the order y, t1, t2, event_1, event_2, the same layout as the training
    matrix passed to ``plot_likelihood_contours``.
    """
    n_points = T1.size
    return pd.DataFrame({
        'y': full(n_points, float(y_value)),
        't1': T1.flatten(),
        't2': T2.flatten(),
        'event_1': ones(n_points),
        'event_2': ones(n_points),
    })


def plot_likelihood_contours(model, X_train, levels=None, n_grid=100, max_points=500):
    """Plot training data on top of the model's likelihood contours.

    Contours of ``exp(model.log_likelihood(...))`` are drawn on the unit square
    for class ``y = 0`` (red, dashed) and ``y = 1`` (blue, solid), with every
    grid point treated as uncensored. The leading rows of ``X_train`` are
    overlaid as a scatter plot. Points are coloured by ``y``, and the marker
    style shows whether either event indicator is 0 (``any_censored``).

    Parameters
    ----------
    model
        Fitted model exposing ``log_likelihood(array)``, which returns one
        log-likelihood per row.
    X_train : array-like
        Data with columns ``y, t1, t2, event_1, event_2``.
    levels : array-like, optional
        Contour levels of the likelihood. Defaults to ``[1, 5, 10, ..., 35]``.
    n_grid : int, default 100
        Grid points per axis on [0, 1]. The ``t = 0`` row and column are dropped.
    max_points : int or None, default 500
        Number of leading rows of ``X_train`` to scatter (``None`` = all rows).

    Returns
    -------
    matplotlib.figure.Figure
    """
    if levels is None:
        levels = array([1., 5., 10., 15., 20., 25., 30., 35.])
    contour_kwargs = {'linewidths': 2, 'levels': levels}

    f, ax = plt.subplots(figsize=(4, 3))

    # Likelihood on the unit square. The first grid value (t = 0) is skipped,
    # presumably because the survival likelihood is not finite there.
    t = linspace(0, 1, n_grid)[1:]
    T1, T2 = meshgrid(t, t)
    for y_value, color, linestyle in ((0, 'tab:red', 'dashed'), (1, 'tab:blue', None)):
        points = _likelihood_grid_points(y_value, T1, T2).to_numpy()
        likelihood = exp(model.log_likelihood(points)).reshape(T1.shape)
        ax.contour(T1, T2, likelihood, colors=color, linestyles=linestyle, **contour_kwargs)

    df = pd.DataFrame(X_train, columns=['y', 't1', 't2', 'event_1', 'event_2'])
    # A sample counts as censored when at least one of its event indicators is 0.
    df['any_censored'] = ~(df['event_1'].astype(bool) & df['event_2'].astype(bool))
    sns.scatterplot(
        data=df.iloc[:max_points],
        x='t1',
        y='t2',
        hue='y',
        style='any_censored',
        palette={0: 'tab:red', 1: 'tab:blue'},
        legend=False,
        zorder=10,  # draw points above the contour lines
        alpha=0.5,
        ax=ax,
    )

    ax.set_xlabel('$t_1$')
    ax.set_ylabel('$t_2$')
    f.tight_layout()
    return f

In [None]:
def log_metrics_tensorboard(
    model, X_train, X_validation, step: int, epoch: int,
):
    """Callback to log metrics and likelihood plots to tensorboard.

    Delegates to ``log_weights_gradients_likelihood_callback`` for the standard
    metrics. Roughly every tenth of ``model.n_epochs`` it also renders
    ``plot_likelihood_contours`` via ``plot_to_tensor`` and stores the result
    under the key ``'Likelihood contour'``.

    Returns
    -------
    (step, epoch, metrics)
        The caller's step/epoch and the (possibly extended) metrics dict.
    """
    _, _, metrics = log_weights_gradients_likelihood_callback(model, X_train, X_validation, step, epoch)

    # Aim for about 10 contour plots over the whole run. Clamp the interval to
    # at least 1 so runs with fewer than 10 epochs do not hit a modulo-by-zero.
    # int() also handles n_epochs passed as a float such as 3e5.
    plot_interval = max(1, int(model.n_epochs) // 10)
    if epoch % plot_interval == 0:
        fig = plot_likelihood_contours(model, X_train)
        metrics['Likelihood contour'] = plot_to_tensor(fig)
        # Release the figure once it is rendered so repeated callbacks do not
        # accumulate open figures. This is a no-op if plot_to_tensor already
        # closed it (NOTE(review): assumes plot_to_tensor renders eagerly).
        plt.close(fig)

    return step, epoch, metrics


In [None]:
# Reset the RNGs before sampling so the generated dataset is reproducible
# (presumably seeds both NumPy and Numba, whose states print_random_state checks).
reset_random_state(1234)

# Censored two-class dataset built from the component parameters a, b.
# Presumably columns y, t1, t2, event_1, event_2 (matching harm_params' column
# indices and plot_likelihood_contours) — TODO confirm against load_blobs.
X = load_blobs(a, b, censor=True)
# Plain C-contiguous float64 array for the harmonium.
X_censor = ascontiguousarray(X.to_numpy(), dtype=float64)

In [None]:
def print_random_state():
    """Print and return fingerprints of the NumPy and Numba global random states.

    The NumPy value is the last element of the key array in
    ``numpy.random.get_state()`` (the legacy global MT19937 state). The Numba
    value is the last element of the integer array returned by
    ``numba._helperlib.rnd_get_state`` for the state pointer obtained from
    ``rnd_get_np_state_ptr``. ``_helperlib`` is a private Numba API and may
    change between Numba versions.

    Only the last key word is reported, not the generator's position index,
    so this is a coarse fingerprint.

    Returns
    -------
    (np_state, nb_state)
        The NumPy and Numba fingerprints, in that order.
    """
    # Import explicitly rather than relying on `from numpy import *` exporting
    # the `random` submodule, which also shadows the stdlib `random` name.
    import numpy as np

    np_state = np.random.get_state()[1][-1]
    print('Numpy random_state', np_state)

    state_ptr = _helperlib.rnd_get_np_state_ptr()
    _, nb_ints = _helperlib.rnd_get_state(state_ptr)
    nb_state = nb_ints[-1]
    print('Numba random state', nb_state)
    return np_state, nb_state

In [None]:
learning_rate = 0.375
# Number of training epochs. It is a count, so keep it an int (previously the
# float 3e5); it feeds integer arithmetic such as `model.n_epochs // 10` in
# log_metrics_tensorboard.
n_epochs = 300_000
momentum_fraction = 0.1
mini_batch_size = 1000
weight_decay = 0.0  # weight decay disabled
persistent = True
CD_steps = 3  # presumably Gibbs steps per contrastive-divergence update


# Model parameters to keep fixed.
harm_params = {
    # Column layout of X_censor: 0 = class label y, 1-2 = times t1/t2,
    # 3-4 = the corresponding event indicators.
    "categorical_columns": [0],
    "survival_columns": [1, 2],
    "event_columns": [3, 4],
    "log_every_n_iterations": 300,
    # Matches the truncation point 1.0 used with truncated_gamma_distribution.
    "time_horizon": [1.0, 1.0],
    "metrics": log_metrics_tensorboard,
    "guess_weights": False,
    "CD_steps": CD_steps,
    "learning_rate": learning_rate,
    "n_epochs": n_epochs,
    "n_hidden_units": 4,
    "weight_decay": weight_decay,
    "persistent": persistent,
    "momentum_fraction": momentum_fraction,
    "mini_batch_size": mini_batch_size,
    'dry_run': False,
    "verbose": False,
    'random_state': 4321,
}

# Reset the RNGs again so the fit is reproducible regardless of how many draws
# the data generation above consumed.
reset_random_state(1234)

harmonium = SurvivalHarmonium(**harm_params).fit(X_censor)

In [None]:
np_state, nb_state = print_random_state()
# Reproducibility regression check: after the fit, the RNG fingerprints must
# match the reference run. A mismatch means training no longer follows the same
# random sequence (e.g. changed hyperparameters, library versions, or cell
# execution order).
assert np_state == 3239245755, f'NumPy RNG fingerprint changed: {np_state}'
assert nb_state == 4088525010, f'Numba RNG fingerprint changed: {nb_state}'

## Likelihood of actual solution

In [None]:
# Average per-sample log-likelihood of the training data under the fitted harmonium.
mean(harmonium.log_likelihood(X_censor))

In [None]:
# Visualise the fitted harmonium with harmoniums.views.plot.
# NOTE(review): the meaning of the three positional True flags is defined by
# that function's signature — pass them by keyword once confirmed.
plot(harmonium, True, True, True)

In [None]:
rcParams['font.family'] = 'sans-serif'

fig = plt.figure(figsize=(4,3))

# Densities of the two generating components along a single time axis.
t = linspace(0, 1, 100)
p_0 = truncated_gamma_distribution(t, a0, b0, 1.0)
p_1 = truncated_gamma_distribution(t, a1, b1, 1.0)

# 4x4 layout: joint contours in the lower-left 3x3 cells, the t1 projection on
# the top row and the t2 projection in the right-hand column.
grid = plt.GridSpec(4, 4, hspace=0.2, wspace=0.2)
main_ax = fig.add_subplot(grid[1:, :-1])
t1_projection = fig.add_subplot(grid[0, :-1], xticklabels=[])
t2_projection = fig.add_subplot(grid[1:, -1], yticklabels=[])

contour_kwargs = {
    'linewidths': 2,
    'levels': array([ 1.,  5., 10., 15., 20., 25., 30., 35.]),
}
# Red, dashed: t1 and t2 drawn from the same component.
c000 = main_ax.contour(*compute_mode(a0, b0, a0, b0), colors='tab:red', linestyles='dashed', **contour_kwargs)
main_ax.text(0.1, 0.075, '$v^{(1)}$', va='center')

c110 = main_ax.contour(*compute_mode(a1, b1, a1, b1), colors='tab:red', linestyles='dashed', **contour_kwargs)
main_ax.text(0.8, 0.95, '$v^{(2)}$', va='center')

# Blue, solid: t1 and t2 drawn from different components.
c011 = main_ax.contour(*compute_mode(a0, b0, a1, b1), colors='tab:blue', **contour_kwargs)
main_ax.text(0.1, 0.95, '$v^{(4)}$', va='center')

c101 = main_ax.contour(*compute_mode(a1, b1, a0, b0), colors='tab:blue', **contour_kwargs)
main_ax.text(0.8, 0.075, '$v^{(3)}$', va='center')

main_ax.set_xlim([0,1])
main_ax.set_ylim([0,1])
main_ax.set_xlabel('$t_1$')
main_ax.set_ylabel('$t_2$')

# Both the red and the blue modes use each component on each axis, so every
# component density is drawn twice: solid blue underneath, dashed red on top.
for density in (p_0, p_1):
    t1_projection.plot(t, density, '-', color='tab:blue', linewidth=2)
    t1_projection.plot(t, density, '--', color='tab:red', linewidth=2)
t1_projection.set_xlim([0, 1])
t1_projection.set_ylim([0, 6])
t1_projection.set(xticks=[], yticks=[])
t1_projection.set_xlabel('Projection $t_1$')
t1_projection.xaxis.set_label_position("top")

for density in (p_0, p_1):
    t2_projection.plot(density, t, '-', color='tab:blue', linewidth=2)
    t2_projection.plot(density, t, '--', color='tab:red', linewidth=2)
t2_projection.set_ylim([0, 1])
t2_projection.set_xlim([0, 6])
t2_projection.set(xticks=[], yticks=[])
t2_projection.set_ylabel('Projection $t_2$')
t2_projection.yaxis.set_label_position("right")

fig.tight_layout()
# savefig raises FileNotFoundError when the output directory does not exist.
os.makedirs('figs', exist_ok=True)
for extension in ('png', 'pdf', 'eps'):
    fig.savefig(f'figs/modes.{extension}')

In [None]:
# Training data over the fitted model's likelihood contours, saved for the paper figures.
f = plot_likelihood_contours(harmonium, X_censor)
# savefig raises FileNotFoundError when the output directory does not exist.
os.makedirs('figs', exist_ok=True)
f.savefig('figs/data_fit.eps')
f.savefig('figs/data_fit.pdf')

In [None]:
# Cox proportional-hazards model of duration t1 (event indicator event_1) with
# the class label y as the only covariate.
cox_t1 = CoxPHFitter().fit(
    X.loc[:, ['y', 't1', 'event_1']],
    duration_col='t1',
    event_col='event_1',
)
cox_t1.print_summary()

In [None]:
# Cox proportional-hazards model of duration t2 (event indicator event_2) with
# the class label y as the only covariate.
cox_t2 = CoxPHFitter().fit(
    X.loc[:, ['y', 't2', 'event_2']],
    duration_col='t2',
    event_col='event_2',
)
cox_t2.print_summary()