In [None]:
%matplotlib inline
import numpy as np
import pandas as pd
import pymc3 as pm
import seaborn as sns
import arviz as az

from matplotlib import pyplot as plt
from matplotlib.ticker import MaxNLocator
from pymc3.distributions.timeseries import GaussianRandomWalk
from theano import tensor as T

In [None]:
# Load the mastectomy data, cast `event` to int64, rename the metastasis column to
# `metastized` (the name used throughout this notebook) and recode it as 1 for "yes", else 0.
df = (
    pd.read_csv(pm.get_data("mastectomy.csv"))
    .astype({"event": np.int64})
    .rename(columns={"metastasized": "metastized"})
    .assign(metastized=lambda frame: (frame["metastized"] == "yes").astype(np.int64))
)
n_patients = len(df)
# Row positions 0..n_patients-1, used for plotting and for fancy-indexing the per-patient matrices.
patients = np.arange(n_patients)

In [None]:
# Number of patients (rows) in the mastectomy data.
n_patients

In [None]:
# Share of rows with event == 1 (labelled "Uncensored" in the plots); assumes event is coded 0/1.
df.event.mean()


In [None]:
# Preview the first rows after recoding `event` and `metastized` to integers.
df.head()

In [None]:
# Follow-up timeline: one horizontal line per patient from 0 to their last observed time,
# coloured by censoring status; black dots mark patients whose cancer had metastasized.
fig, ax = plt.subplots(figsize=(8, 6))

# Explicit colours (seaborn "deep" blue and red). Unpacking `sns.color_palette()[:3]` as
# (blue, _, red) picks the third colour, which is green in the current default cycle
# (blue, orange, green, red, ...). `blue` and `red` are reused by later cells.
blue, red = "#4C72B0", "#C44E52"

for event_value, color, label in [(0, blue, "Censored"), (1, red, "Uncensored")]:
    mask = df.event.values == event_value
    ax.hlines(patients[mask], 0, df[mask].time, color=color, label=label)

metastasized = df.metastized.values == 1
ax.scatter(
    df[metastasized].time,
    patients[metastasized],
    color="k",
    zorder=10,
    label="Metastasized",
)

ax.set_xlim(left=0)
ax.set_xlabel("Months since mastectomy")
ax.set_yticks([])
ax.set_ylabel("Subject")

# Patient positions run from 0 to n_patients - 1; pad both ends by the same margin.
ax.set_ylim(-0.25, (n_patients - 1) + 0.25)

ax.legend(loc="center right");

In [None]:
# Split follow-up time (months, same units as df.time) into fixed-width intervals;
# the model's baseline hazard takes one value per interval.
interval_length = 3
upper_bound = df.time.max() + interval_length + 1  # exclusive end for np.arange
interval_bounds = np.arange(0, upper_bound, interval_length)
n_intervals = len(interval_bounds) - 1
intervals = np.arange(n_intervals)

In [None]:
# Histogram of follow-up times, binned on the model's interval grid and split by censoring.
fig, ax = plt.subplots(figsize=(8, 6))

for event_value, color, label in [(1, red, "Uncensored"), (0, blue, "Censored")]:
    ax.hist(
        df[df.event == event_value].time.values,
        bins=interval_bounds,
        color=color,
        alpha=0.5,
        lw=0,
        label=label,
    )

ax.set_xlim(0, interval_bounds[-1])
ax.set_xlabel("Months since mastectomy")

# Integer ticks derived from the data. The previous hard-coded [0, 1, 2, 3]
# left any bin count above three without a tick label.
ax.yaxis.set_major_locator(MaxNLocator(integer=True))
ax.set_ylabel("Number of observations")

ax.legend();

In [None]:
# Index of the interval in which each patient's follow-up ends (death or censoring).
# ceil(t / L) - 1 puts a time exactly on a boundary (t == k * L) into interval k - 1.
# It replaces floor((t - 0.01) / L), which was only exact for integer times and
# gave -1 for t == 0. An index of -1 silently wraps to the LAST interval, so it is clipped at 0.
last_period = np.clip(
    np.ceil(df.time.values / interval_length).astype(int) - 1, 0, None
)

# death[i, j] = 1 if patient i's death was observed in interval j, else 0.
death = np.zeros((n_patients, n_intervals))
death[patients, last_period] = df.event.values

In [None]:
# Time at risk for each patient in each interval. Every interval that starts strictly
# before the patient's last time first gets a full interval_length. The interval where
# follow-up ends is then overwritten with the partial time actually spent in it.
# Strict `>` matters: with `>=`, a time exactly on a boundary (e.g. t = 6 with width 3)
# also credited a full interval to the interval starting at t, after follow-up had ended.
# Floats so that non-integer times are not truncated when assigned into the matrix.
exposure = np.greater.outer(df.time.values, interval_bounds[:-1]).astype(float) * interval_length
exposure[patients, last_period] = df.time.values - interval_bounds[last_period]

In [None]:
# Fixed seed (originally drawn from random.org) used for sampling and posterior prediction.
SEED = 644_567

In [None]:
# Number of time intervals (the length of the model's lambda0 vector).
n_intervals

In [None]:
with pm.Model() as model:
    # Baseline hazard: one Gamma(0.01, 0.01) rate per interval.
    lambda0 = pm.Gamma("lambda0", alpha=0.01, beta=0.01, shape=n_intervals)

    # Coefficient of the metastasis indicator on the log-hazard scale (very wide prior).
    beta = pm.Normal("beta", mu=0, sigma=1000)

    # Hazard of patient i in interval j: exp(beta * metastized_i) * lambda0_j.
    lambda_ = pm.Deterministic(
        "lambda_", T.outer(T.exp(beta * df.metastized), lambda0)
    )

    # Poisson mean for each (patient, interval): time at risk times hazard.
    mu = pm.Deterministic("mu", exposure * lambda_)

    obs = pm.Poisson("obs", mu=mu, observed=death)

In [None]:
# Posterior draws and tuning steps passed to pm.sample.
n_samples, n_tune = 1000, 1000

In [None]:
with model:
    # Posterior sampling; returned as an ArviZ InferenceData.
    trace = pm.sample(
        draws=n_samples,
        tune=n_tune,
        random_seed=SEED,
        return_inferencedata=True,
    )

In [None]:
%%time
with model:
    ppc = pm.sample_posterior_predictive(trace,  random_seed=SEED, samples=1000)
    print('Done infering.')
trace = az.concat(trace, az.from_pymc3(posterior_predictive=ppc))

In [None]:
# Posterior predictive check: compare the histogram of the intervals in which deaths
# were observed with the same histogram computed from posterior-predictive draws of `obs`.
observed_death_intervals = np.nonzero(death == 1)[1]  # interval index of each observed death
_, bins = np.histogram(observed_death_intervals)
bin_centers = bins[:-1] + np.diff(bins) / 2  # replaces the hard-coded 5.9 / 2 offset

# Indexed as [chain, draw, patient, interval]; only the first chain is used, as before.
ppc_obs = trace.posterior_predictive.obs.values
n_ppc_draws = min(100, ppc_obs.shape[1])  # avoid IndexError when fewer draws exist
ppc_freqs = np.array(
    [np.histogram(np.nonzero(ppc_obs[0, i])[1], bins=bins)[0] for i in range(n_ppc_draws)]
)

fig, ax = plt.subplots(figsize=(8, 6))
# Draw the bars first so the error bars stay visible on top of them.
ax.hist(observed_death_intervals, bins=bins, label="Observed")
ax.errorbar(
    bin_centers,
    ppc_freqs.mean(axis=0),
    yerr=ppc_freqs.std(axis=0),
    color="k",
    marker="o",
    label="Posterior predictive (mean ± 1 sd)",
)
ax.set_xlabel("Interval index of death")
ax.set_ylabel("Number of deaths")
ax.legend();

In [None]:
# Posterior samples of the baseline hazard, one value per interval on the last axis.
base_hazard = trace.posterior["lambda0"].values

# Hazard with metastasis: scale the baseline by exp(beta). exp(beta) gets a trailing
# singleton axis so that it broadcasts across the interval axis of base_hazard.
hazard_ratio = np.exp(np.atleast_2d(trace.posterior["beta"].values))
met_hazard = base_hazard * hazard_ratio[..., np.newaxis]

In [None]:
def cum_hazard(hazard, width=None):
    """Cumulative hazard Lambda(t) at the end of each interval.

    Parameters
    ----------
    hazard : numpy.ndarray
        Piecewise-constant hazard rates; the last axis indexes intervals.
    width : float, optional
        Interval width. Defaults to the notebook-level ``interval_length``.

    Returns
    -------
    numpy.ndarray
        Same shape as ``hazard``: running sum of ``width * hazard`` along the last axis.
    """
    if width is None:
        width = interval_length
    return (width * hazard).cumsum(axis=-1)


def survival(hazard, width=None):
    """Survival function S(t) = exp(-Lambda(t)) at the end of each interval.

    Parameters are the same as for ``cum_hazard``; the result has the shape of ``hazard``.
    """
    return np.exp(-cum_hazard(hazard, width=width))

In [None]:
def plot_with_hpd(x, hazard, f, ax, color=None, label=None, alpha=0.05):
    """Plot ``f`` of the mean hazard as a step line over a shaded percentile band.

    Parameters
    ----------
    x : array-like
        Horizontal positions for both the band and the step line.
    hazard : numpy.ndarray
        Hazard samples; axes 0 and 1 are reduced over (presumably chain and draw, TODO confirm).
    f : callable
        Transform applied to the hazard, e.g. ``cum_hazard`` or ``survival``.
    ax : matplotlib Axes
        Target axes.
    color, label : optional
        Passed through to matplotlib. ``label`` is attached to the line only.
    alpha : float
        The band spans the ``alpha / 2`` to ``1 - alpha / 2`` percentiles of ``f(hazard)``.
        Despite the name, this is an equal-tailed interval, not a highest-density interval.

    Notes
    -----
    The line is ``f`` applied to the mean hazard, not the mean of ``f(hazard)``.
    """
    line_values = f(hazard.mean(axis=(0, 1)))

    band_percentiles = 100 * np.array([alpha / 2.0, 1.0 - alpha / 2.0])
    lower, upper = np.percentile(f(hazard), band_percentiles, axis=(0, 1))

    ax.fill_between(x, lower, upper, color=color, alpha=0.25)
    ax.step(x, line_values, color=color, label=label)

In [None]:
# Left: cumulative hazard; right: survival function. Each panel shows both patient groups.
fig, (hazard_ax, surv_ax) = plt.subplots(ncols=2, sharex=True, sharey=False, figsize=(16, 6))

# (hazard samples, colour, legend label) per patient group; labels fix the former
# "metastized" misspelling in the rendered legend.
hazard_groups = [
    (base_hazard, blue, "Had not metastasized"),
    (met_hazard, red, "Metastasized"),
]
for group_hazard, color, label in hazard_groups:
    plot_with_hpd(
        interval_bounds[:-1], group_hazard, cum_hazard, hazard_ax, color=color, label=label
    )
    plot_with_hpd(interval_bounds[:-1], group_hazard, survival, surv_ax, color=color)

for panel_ax in (hazard_ax, surv_ax):
    panel_ax.set_xlim(0, df.time.max())
    panel_ax.set_xlabel("Months since mastectomy")

hazard_ax.set_ylabel(r"Cumulative hazard $\Lambda(t)$")
hazard_ax.legend(loc="upper left")

surv_ax.set_ylabel("Survival function $S(t)$")

fig.suptitle("Bayesian survival model");