In [None]:
import pandas as pd
import pymc3 as pm
import numpy as np
import seaborn as sns
import arviz as az
import matplotlib.pyplot as plt
import spc_os
from spc_vis import my_plot_ppc

RANDOM_SEED = 28101990
import pickle

from math import ceil
import theano.tensor as T
import scipy.stats as st

In [None]:
# Reload imported modules automatically before each cell runs (autoreload mode 2).
%load_ext autoreload
%autoreload 2
# Load the lab_black extension (presumably formats cells with black - confirm).
%load_ext lab_black
# Print environment and package version information via the watermark extension.
%load_ext watermark
%watermark -n -u -v -iv -w -m

In [None]:
# Project directories, given relative to the notebook's working directory.
# Only processed_data_dir (input CSV) and models_dir (stored trace) are used below.
raw_data_dir = "../data/raw/"
interim_data_dir = "../data/interim/"
processed_data_dir = "../data/processed/"
external_data_dir = "../data/external/"
models_dir = "../models/"

In [None]:
# File the inference data is loaded from (infer=False) or written to (save_data=True).
model_path = f"{models_dir}tpore_survival_analysis_individual_sim.nc"

In [None]:
# infer: True -> sample the model in this session; False -> load the trace stored at model_path.
infer = True
# save_data: True -> the last cell writes the trace to model_path.
save_data = True
print(model_path)

# Load data

In [None]:
# Read the processed dataset and discard the "Unnamed: 0" column
# (presumably an index column saved along with the CSV - confirm).
df = pd.read_csv(f"{processed_data_dir}data.csv").drop(columns=["Unnamed: 0"])

In [None]:
# Treat Replica as a categorical variable; the integer codes and dummies below rely on it.
df["Replica"] = df["Replica"].astype("category")

In [None]:
# Integer code of each row's Replica category.
df["Replica_enc"] = df["Replica"].cat.codes

In [None]:
# Map each integer category code to its Replica label. The mapping is taken from
# the categorical's own category list, which is the same source that defines
# `Replica.cat.codes` and the column order of the one-hot dummies, so code i is
# guaranteed to refer to the same label everywhere (no reliance on a separate sort).
category_dic = dict(enumerate(df["Replica"].cat.categories))

In [None]:
# Display the code -> label mapping for inspection.
category_dic

In [None]:
# Number of distinct Replica categories; sets the length of the `beta` vector in the model.
n_categories = len(category_dic)

In [None]:
# One-hot encoding of Replica (one indicator column per category, names prefixed "Replica_").
# Used in the model to pick each simulation's beta.
dummies = pd.get_dummies(df.Replica, prefix="Replica")

In [None]:
# Copy every indicator column into the main frame under the same name.
for name, indicator in dummies.items():
    df[name] = indicator

In [None]:
# Rescale tpore by 10 and store it as integers. The interval width of 15 used
# later is annotated "1.5 ns", so this presumably converts ns into 0.1 ns units
# - TODO confirm.
# NOTE(review): astype(int) truncates toward zero rather than rounding, and the
# cell is not idempotent: running it twice rescales the column again.
df["tpore"] = (df["tpore"] * 10).astype(int)

In [None]:
# Preview the prepared frame (encoded Replica, dummy columns, rescaled tpore).
df.head()

## Visualize Data

In [None]:
# Summary statistics of tpore for each Replica.
df.groupby(df["Replica"])["tpore"].describe()

In [None]:
# Normalised tpore histogram per Replica on a shared x axis; `_` swallows the axes repr.
_ = df["tpore"].hist(by=df["Replica"], sharex=True, density=True, bins=10)

In [None]:
# Histogram of tpore pooled over all Replicas.
_ = df["tpore"].hist(bins=50)

## Visualize Priors

These are the shapes of the priors used.

In [None]:
# Density of the prior placed on the baseline hazard lambda0.
# The model uses pm.Gamma("lambda0", 5, 1), i.e. shape alpha=5 and rate beta=1;
# scipy expresses the same distribution with a=alpha and scale=1/beta.
# (The previous alpha=1 plotted an exponential density that the model does not use.)
beta = 1
alpha = 5
d = st.gamma(scale=1 / beta, a=alpha)
# Range widened so the right tail of Gamma(5, 1) is visible.
x = np.linspace(0, 15, 150)
tau_0_pdf = d.pdf(x)
plt.plot(x, tau_0_pdf, "k-", lw=2)
plt.ylabel("prior density")
plt.xlabel("lambda0(t)")

## Prepare data

In [None]:
# One row of df per simulation; `sims` is the row index vector used for fancy indexing.
n_sims = df.shape[0]
sims = np.arange(n_sims)
interval_length = 15  # 1.5 ns
# Interval edges start at 0 and extend past the largest observed tpore so that
# every observation falls inside some interval.
interval_bounds = np.arange(0, df.tpore.max() + interval_length + 1, interval_length)
# Number of intervals is one less than the number of edges.
n_intervals = interval_bounds.size - 1
intervals = np.arange(n_intervals)

In [None]:
# Index of the interval containing each simulation's event. Subtracting 0.01
# before dividing places an event lying exactly on an interval edge into the
# interval that ends at that edge.
# NOTE(review): a tpore of 0 would yield index -1 (wraps to the last interval) - confirm none occur.
last_period = np.floor((df.tpore - 0.01) / interval_length).astype(int)

# Event indicators, shape (n_sims, n_intervals): 1 where the event happened, 0 elsewhere.
pore = np.zeros(shape=(n_sims, n_intervals))
pore[sims, last_period] = 1.0

In [None]:
# Time at risk of each simulation in each interval, shape (n_sims, n_intervals).
# A simulation is exposed for the whole interval when its event time lies strictly
# beyond the interval's lower edge. The comparison is strict because `last_period`
# assigns an event lying exactly on an edge to the interval that ENDS there; with
# `>=` such a simulation would also be credited a full interval of exposure in the
# interval after its event.
exposure = (
    np.greater.outer(df.tpore.values, interval_bounds[:-1]) * interval_length
)
# In the event interval only the time elapsed since its lower edge counts.
exposure[sims, last_period] = df.tpore - interval_bounds[last_period]

## Run Model

In [None]:
with pm.Model() as model:

    # Baseline hazard: one Gamma(alpha=5, beta=1) distributed rate per time interval.
    lambda0 = pm.Gamma("lambda0", 5, 1, shape=n_intervals)

    # One log hazard multiplier per Replica category, with a wide Normal prior.
    beta = pm.Normal("beta", 0, sigma=100, shape=(n_categories))

    # Hazard per simulation and interval: the one-hot `dummies` select the beta of
    # each simulation's category, exp() of it scales the baseline hazard.
    lambda_ = pm.Deterministic(
        "lambda_", T.outer(T.exp(T.dot(beta, dummies.T)), lambda0)
    )
    # Poisson mean per cell = time at risk * hazard.
    mu = pm.Deterministic("mu", exposure * lambda_)
    # Multiplicative effect of each category on the baseline hazard.
    exp_beta = pm.Deterministic("exp_beta", np.exp(beta))

    # Poisson likelihood on the 0/1 event indicators.
    obs = pm.Poisson(
        "obs",
        mu,
        observed=pore,
    )

In [None]:
# Render the model's dependency graph.
pm.model_to_graphviz(model)

In [None]:
%%time
# Either draw posterior samples now (1000 tuning + 1000 kept draws per chain, 8 cores,
# fixed seed) or reuse the inference data previously stored at model_path.
if infer:
    with model:
        trace = pm.sample(1000, tune=1000, random_seed=RANDOM_SEED, return_inferencedata=True, cores=8)
else:
     trace=az.from_netcdf(model_path)

In [None]:
# Give a freshly sampled trace readable dimensions: the interval axis of lambda0
# becomes "t" and the category axis of beta / exp_beta becomes "Membrane".
# Skipped when the trace was loaded from disk (presumably already relabelled
# before it was saved - confirm).
if infer:
    # Drop the auto-generated integer indexes before renaming.
    trace.posterior = trace.posterior.reset_index(
        ["beta_dim_0", "exp_beta_dim_0", "lambda0_dim_0"], drop=True
    )
    trace = trace.rename(
        {
            "lambda0_dim_0": "t",
            "beta_dim_0": "Membrane",
            "exp_beta_dim_0": "Membrane",
        }
    )
    # t: lower edge of each interval divided by 10 (undoing the tpore rescaling);
    # Membrane: the Replica labels in code order.
    trace = trace.assign_coords(
        t=interval_bounds[:-1] / 10,
        Membrane=list(category_dic.values()),
    )

## Posterior and prior predictive checks

In [None]:
%%time
# Draw posterior predictive samples for a freshly sampled trace and merge them
# into the same inference data object. Nothing is done when the trace was loaded.
if infer:
    with model:
        ppc = pm.sample_posterior_predictive(trace,  random_seed=RANDOM_SEED, samples=10000)
        print('Done infering.')
    trace = az.concat(trace, az.from_pymc3(posterior_predictive=ppc))

In [None]:
%%time
# Prior predictive sampling is intentionally disabled: the author considers it
# unimportant because the priors are uninformative. Kept for reference.
#if infer:
#    with model:
#        prior = pm.sample_prior_predictive(random_seed=RANDOM_SEED, samples=1000)
#        trace.extend(az.from_pymc3(prior=prior))

## Convergence diagnostics

In [None]:
# Trace plots for beta and lambda0; the subplot cap is lifted so every component is drawn.
with az.rc_context(rc={"plot.max_subplots": None}):
    az.plot_trace(trace, var_names=["beta", "lambda0"])

In [None]:
# Autocorrelation of the draws with chains combined; the subplot cap is lifted as above.
with az.rc_context(rc={"plot.max_subplots": None}):
    az.plot_autocorr(trace, combined=True, var_names=["lambda0", "beta"])

In [None]:
def get_survival_function(trace, dx=None, n_points=None):
    """Turn sampled hazards into survival probabilities on a grid of time points.

    The hazard values along the last axis are treated as samples spaced ``dx``
    apart and integrated cumulatively with the trapezoidal rule; the survival
    probability is ``exp(-cumulative hazard)``.

    Parameters
    ----------
    trace : object exposing a ``.values`` array
        Hazard samples, e.g. ``trace.posterior.lambda_``. The last axis is the
        interval axis; the callers in this notebook pass 4-d arrays.
    dx : float, optional
        Spacing between consecutive hazard samples. Defaults to the
        notebook-level ``interval_length``.
    n_points : int, optional
        Number of time points to return. Defaults to ``n_intervals - 1``.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n_points, *leading_axes)``. Entry ``k`` is the survival
        probability obtained by integrating the first ``k + 1`` hazard samples,
        so entry 0 is always 1.
    """
    hazard = np.asarray(trace.values)
    if dx is None:
        dx = interval_length
    if n_points is None:
        n_points = n_intervals - 1

    # Area of the trapezoid between each pair of neighbouring samples. A running
    # sum yields every prefix integral in a single pass instead of integrating
    # each prefix from scratch.
    segments = 0.5 * dx * (hazard[..., 1:] + hazard[..., :-1])
    start = np.zeros(hazard.shape[:-1] + (1,), dtype=segments.dtype)
    cumulative = np.concatenate([start, np.cumsum(segments, axis=-1)], axis=-1)

    # Points requested beyond the available samples repeat the full integral,
    # matching what slicing past the end of the axis produces.
    idx = np.minimum(np.arange(n_points), cumulative.shape[-1] - 1)
    cumulative = np.take(cumulative, idx, axis=-1)

    # Put the time axis first, as callers index the result by time point.
    return np.exp(-np.moveaxis(cumulative, -1, 0))

In [None]:
def get_ecdf(data):
    """Return the points of the empirical CDF of ``data``.

    The first element is the sample sorted in ascending order; the second gives,
    for the i-th smallest value (1-based), the cumulative fraction ``i / n``.
    """
    ordered = np.sort(data)
    count = ordered.size
    fractions = np.arange(1, count + 1) / count
    return ordered, fractions

In [None]:
def get_hdi(x, axis, alpha=0.06):
    """Median and central percentile interval of ``x`` along ``axis``, ignoring NaNs.

    Despite the name, the interval is the equal-tailed one bounded by the
    ``alpha / 2`` and ``1 - alpha / 2`` percentiles (94% wide for the default
    ``alpha``), not a highest-density interval.

    Returns ``(median, bounds)`` where ``bounds[0]`` is the lower and
    ``bounds[1]`` the upper percentile.
    """
    tail = alpha / 2.0
    bounds = np.nanpercentile(x, [100 * tail, 100 * (1.0 - tail)], axis=axis)
    centre = np.nanmedian(x, axis=axis)
    return centre, bounds

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(6, 4))

# Survival curves computed from the posterior hazards lambda_ of every simulation.
survival_function = get_survival_function(trace.posterior.lambda_)
# Empirical CDF of the observed pore times (divided by 10 to undo the rescaling)
ax.plot(*get_ecdf(df.tpore / 10), label="obs.")

# Empirical CDF of the observations after binning: lower edge of each event interval
binned_data = np.where(pore[:, :] == 1)[1] * interval_length / 10
ax.plot(*get_ecdf(binned_data), label="obs. binned")

# Model CDF = 1 - survival. Median and percentile band are taken over
# chains, draws and simulations (axes 1-3); axis 0 is the time point.
hdi = get_hdi(survival_function[:, :, :, :], axis=(1, 2, 3))
x = np.arange(n_intervals - 1) * interval_length / 10.0
ax.plot(x, 1 - hdi[0], label="Posterior Predictive Check")
ax.fill_between(x, 1 - hdi[1][0, :], 1 - hdi[1][1, :], alpha=0.1, color="g")
ax.set_xlabel("t-pore (ns)")
ax.set_ylabel("CDF(t-pore)")
ax.set_title("Posterior Predictive Check")
ax.legend()

In [None]:
# Same comparison as the pooled plot, drawn separately for each Replica
# on a grid with four panels per row.
n_categories = len(category_dic)
n_rows = ceil(n_categories / 4)
fig, ax = plt.subplots(n_rows, 4, figsize=(6 * 4, 4 * n_rows))

# Flatten so panels can be addressed with a single index.
ax = ax.flatten()
for i in range(n_categories):
    # Mask by replica type
    mask = df.Replica == category_dic[i]

    # Survival curves restricted to the simulations of this replica.
    survival_function = get_survival_function(trace.posterior.lambda_[:, :, mask, :])
    # Empirical CDF of the observed pore times
    ax[i].plot(*get_ecdf(df[mask].tpore / 10), label="obs.")

    # Empirical CDF of the binned observations
    binned_data = np.where(pore[mask, :] == 1)[1] * interval_length / 10
    ax[i].plot(*get_ecdf(binned_data), label="obs. binned")

    # Model CDF = 1 - survival; median and percentile band over chains, draws and simulations
    hdi = get_hdi(survival_function[:, :, :, :], axis=(1, 2, 3))
    x = np.arange(n_intervals - 1) * interval_length / 10.0
    ax[i].plot(x, 1 - hdi[0], label="Posterior Predictive Check")
    ax[i].fill_between(x, 1 - hdi[1][0, :], 1 - hdi[1][1, :], alpha=0.1, color="g")
    ax[i].set_xlabel("t-pore (ns)")
    ax[i].set_ylabel("CDF(t-pore)")
    ax[i].set_title(f"Posterior Predictive Check {category_dic[i]}")
    ax[i].legend()
fig.tight_layout()

## Analyze

### Plot posterior

In [None]:
# Forest plot of the baseline hazard per interval, chains combined.
variable = "lambda0"
ax = az.plot_forest(trace, var_names=variable, combined=True)
ax[0].set_xlabel("lambda0[t]")

In [None]:
# Forest plot of the per-category log hazard multipliers, chains combined.
variable = "beta"
ax = az.plot_forest(trace, var_names=variable, combined=True)
ax[0].set_xlabel("beta")

In [None]:
# Forest plot of exp(beta), the multiplicative effect on the baseline hazard.
variable = "exp_beta"
ax = az.plot_forest(trace, var_names=variable, combined=True)
ax[0].set_xlabel("exp(beta)")

In [None]:
# HDI of exp_beta per category.
hdi = az.hdi(trace.posterior, var_names=["exp_beta"])
for i in range(n_categories):
    # NOTE(review): this prints the mean of the values along the HDI axis, i.e.
    # presumably the midpoint of the lower/upper bounds, not a posterior mean - confirm intent.
    print(f"{category_dic[i]} {hdi.exp_beta[i,:].values.mean()}")

In [None]:
# Left panel: baseline hazard over time. Right panel: hazard of each category.
fig, ax = plt.subplots(1, 2, figsize=(20, 7))
lambda0 = trace.posterior.lambda0.values
beta = trace.posterior.beta.values
# Median and percentile band over chains and draws (axes 0 and 1).
y, hdi = get_hdi(lambda0, (0, 1))
# Lower edge of each interval, divided by 10 to undo the tpore rescaling.
x = interval_bounds[:-1] / 10
ax[0].fill_between(x, hdi[0], hdi[1], alpha=0.25, step="pre", color="grey")
ax[0].step(x, y, label="baseline", color="grey")
for i in range(n_categories):
    # Indexing with [i] keeps the last axis so exp(beta_i) broadcasts across the intervals.
    lam = np.exp(beta[:, :, [i]]) * lambda0
    y, hdi = get_hdi(lam, (0, 1))
    ax[1].fill_between(x, hdi[0], hdi[1], alpha=0.25, step="pre")
    ax[1].step(x, y, label=f"{category_dic[i]}")

ax[0].legend(loc="best")
ax[0].set_ylabel("lambda0")
ax[0].set_xlabel("t (ns)")
ax[1].legend(loc="best")
ax[1].set_ylabel("lambda_i")
ax[1].set_xlabel("t (ns)")

## Save Model?

In [None]:
# Show where the trace will be written.
print(model_path)

In [None]:
# Persist the inference data as netCDF when saving is enabled.
if save_data:
    # spc_os.remove is a project helper; presumably it deletes an existing file
    # at model_path before the new one is written - confirm.
    spc_os.remove(model_path)
    trace.to_netcdf(model_path)