In [None]:
import pandas as pd
import pymc3 as pm
import numpy as np
import seaborn as sns
import arviz as az
import matplotlib.pyplot as plt
import spc_os
from spc_vis import my_plot_ppc

# Seed passed as `random_seed` to the PyMC3 sampling calls in this notebook.
RANDOM_SEED = 28101990
import pickle

from math import ceil
import theano.tensor as T
import scipy.stats as st

In [None]:
# Reload edited local modules automatically before each cell is executed.
%load_ext autoreload
%autoreload 2
# Cell auto-formatter (presumably the JupyterLab variant of nb_black; confirm).
%load_ext lab_black
# Print version and machine information so the run is documented.
%load_ext watermark
%watermark -n -u -v -iv -w -m

In [None]:
# Project folders, given relative to the notebook's working directory.
raw_data_dir = "../data/raw/"
interim_data_dir = "../data/interim/"
processed_data_dir = "../data/processed/"
external_data_dir = "../data/external/"
# Fitted traces are stored here (see `model_path`).
models_dir = "../models/"

In [None]:
# NetCDF file the trace is loaded from (infer = False) and saved to.
# Plain string literal: the original f-string had no placeholders.
model_path = models_dir + "tpore_survival_analysis_same_membrane.nc"

In [None]:
# True: run the samplers below. False: load the trace stored at `model_path`.
infer = False

# Load data

In [None]:
# Read the processed table and discard its unnamed leading column
# (presumably an index written out by an earlier `to_csv`; confirm).
data_csv = f"{processed_data_dir}data.csv"
df = pd.read_csv(data_csv).drop(columns="Unnamed: 0")

In [None]:
# Group simulations by membrane. Bracket assignment is used deliberately:
# attribute assignment only updates a column that already exists and otherwise
# sets a plain Python attribute without creating the column.
df["Replica"] = df["membrane"]

In [None]:
# Store the grouping label as a pandas categorical.
df["Replica"] = df["Replica"].astype("category")

In [None]:
# Integer code of each row's category.
df["Replica_enc"] = df["Replica"].cat.codes

In [None]:
# Map position in the sorted list of unique labels -> label.
category_dic = dict(enumerate(np.unique(df["Replica"])))

In [None]:
# Display the index -> label mapping.
category_dic

In [None]:
# Number of distinct Replica labels; used as the shape of `beta` in the model.
n_categories = len(category_dic)

In [None]:
# One indicator column per Replica label, named "Replica_<label>".
dummies = pd.get_dummies(df["Replica"], prefix="Replica")

In [None]:
# Copy every indicator column into the main frame.
for name, column in dummies.items():
    df[name] = column

In [None]:
# Rescale pore times by 10 and truncate to integers.
# Not idempotent: re-running this cell scales the column again.
df["tpore"] = (df["tpore"] * 10).astype(int)

In [None]:
# Preview the prepared frame.
df.head()

## Visualize Data

In [None]:
# Summary statistics of the pore time for each Replica label.
df.groupby("Replica")["tpore"].describe()

In [None]:
# Normalised pore-time histogram per Replica label on a shared x axis.
_ = df["tpore"].hist(
    by=df["Replica"],
    bins=10,
    density=True,
    sharex=True,
)

In [None]:
# Pore-time histogram over all simulations.
_ = df["tpore"].hist(bins=50)

## Visualize Priors

These are the shapes of the priors used.

In [None]:
# Density of the Gamma(alpha=1, beta=1) prior placed on each lambda0 entry.
alpha, beta = 1, 1
x = np.linspace(0, 10, 100)
d = st.gamma(a=alpha, scale=1 / beta)  # scipy takes scale = 1 / rate
tau_0_pdf = d.pdf(x)
plt.plot(x, tau_0_pdf, "k-", lw=2)
plt.xlabel("lambda0(t)")

## Prepare data

In [None]:
# Split the time axis into fixed-width intervals, one lambda0 entry per interval.
interval_length = 20  # 2 ns
n_sims = len(df)
sims = np.arange(n_sims)
t_max = df["tpore"].max()
interval_bounds = np.arange(0, t_max + interval_length + 1, interval_length)
n_intervals = len(interval_bounds) - 1
intervals = np.arange(n_intervals)

In [None]:
# Index of the interval in which each simulation's pore formed. The small
# offset puts a time lying exactly on a boundary into the earlier interval.
last_period = np.floor((df["tpore"] - 0.01) / interval_length).astype(int)

# Event indicator: pore[i, j] is 1 only in simulation i's last interval.
pore = np.zeros((n_sims, n_intervals))
pore[sims, last_period] = 1.0

In [None]:
# Time at risk per (simulation, interval): a full interval length wherever the
# pore time has reached the interval's lower bound ...
at_risk = np.greater_equal.outer(df["tpore"].values, interval_bounds[:-1])
exposure = at_risk * interval_length
# ... and only the elapsed part of the interval in which the pore formed.
exposure[sims, last_period] = df["tpore"] - interval_bounds[last_period]

## Run Model

In [None]:
# Piecewise-constant hazard model fitted as a Poisson regression on the
# (simulation, interval) event indicators.
with pm.Model() as model:

    # Baseline hazard: one Gamma(1, 1) distributed rate per time interval.
    lambda0 = pm.Gamma("lambda0", 1, 1, shape=n_intervals)

    # One log-hazard coefficient per Replica category.
    # NOTE(review): every category has its own dummy column (no reference
    # level) and lambda0 is free as well, so a common shift of beta can be
    # absorbed by lambda0; only the priors appear to pin it down. Confirm
    # that this is intended.
    beta = pm.Normal("beta", 0, sigma=10, shape=(n_categories))

    # Hazard per (simulation, interval): exp(beta of the row's category) times
    # the baseline hazard of the interval.
    lambda_ = pm.Deterministic(
        "lambda_", T.outer(T.exp(T.dot(beta, dummies.T)), lambda0)
    )
    # Expected event count = time at risk * hazard.
    mu = pm.Deterministic("mu", exposure * lambda_)
    # Multiplicative effect of each category on the baseline hazard.
    exp_beta = pm.Deterministic("exp_beta", np.exp(beta))

    # Likelihood of the 0/1 event indicators in `pore`.
    obs = pm.Poisson(
        "obs",
        mu,
        observed=pore,
    )

In [None]:
# Render the model's dependency graph.
pm.model_to_graphviz(model)

In [None]:
%%time
# Either reuse the trace stored at `model_path` or draw fresh samples.
if not infer:
    trace = az.from_netcdf(model_path)
else:
    with model:
        trace = pm.sample(
            1000,
            tune=1000,
            cores=8,
            random_seed=RANDOM_SEED,
            return_inferencedata=True,
        )

## Posterior and prior predictive

In [None]:
%%time
# Posterior predictive draws are only generated together with a fresh fit.
if infer:
    with model:
        ppc = pm.sample_posterior_predictive(
            trace, samples=1000, random_seed=RANDOM_SEED
        )
        print('Done infering.')
    # Converted outside the model context, exactly as before, then merged into
    # a new InferenceData object.
    ppc_idata = az.from_pymc3(posterior_predictive=ppc)
    trace = az.concat(trace, ppc_idata)

In [None]:
%%time
# Prior predictive draws, added to the existing InferenceData in place.
if infer:
    with model:
        prior = pm.sample_prior_predictive(samples=1000, random_seed=RANDOM_SEED)
        prior_idata = az.from_pymc3(prior=prior)
        trace.extend(prior_idata)

## Convergence diagnostics

In [None]:
# Trace plots for beta and lambda0; the subplot cap is lifted so that no
# panel is dropped.
with az.rc_context(rc={"plot.max_subplots": None}):
    az.plot_trace(trace, var_names=["beta", "lambda0"])

In [None]:
# Autocorrelation of the lambda0 and beta samples with all chains combined;
# the subplot cap is lifted so that no panel is dropped.
with az.rc_context(rc={"plot.max_subplots": None}):
    az.plot_autocorr(trace, combined=True, var_names=["lambda0", "beta"])

In [None]:
def get_hdi(x, axis, alpha=0.06):
    """Return the NaN-aware mean and a central interval of ``x`` along ``axis``.

    Despite the name, the interval is an equal-tailed percentile interval
    (``alpha / 2`` cut off on each side), not a highest-density interval.

    Parameters
    ----------
    x : array-like
        Samples to summarise.
    axis : int or tuple of int
        Axis or axes to reduce over.
    alpha : float, optional
        Total probability mass left outside the interval (default 0.06,
        i.e. a 94 % interval).

    Returns
    -------
    center : numpy.ndarray
        ``np.nanmean`` of ``x`` over ``axis``.
    interval : numpy.ndarray
        Lower and upper percentile stacked along a new leading axis of size 2.
    """
    tail = alpha / 2.0
    bounds_pct = 100 * np.array([tail, 1.0 - tail])
    interval = np.nanpercentile(x, bounds_pct, axis=axis)
    center = np.nanmean(x, axis=axis)
    return center, interval

In [None]:
# Posterior predictive check: distribution of the interval in which the pore
# forms, observed vs. simulated from the posterior.
fig, ax = plt.subplots(1, 1, figsize=(6, 4))

# n_intervals + 1 edges give exactly one bin per interval. With only
# n_intervals edges the last two intervals share a bin, because numpy's final
# bin is closed on the right.
bin_edges = np.arange(n_intervals + 1)
freq0, bins = np.histogram(np.where(pore[:, :] == 1)[1], density=True, bins=bin_edges)

# Pull the array out once instead of on every loop iteration; only the first
# chain is used.
ppc_obs = trace.posterior_predictive.obs.values
n_draws = ppc_obs.shape[1]
draw_freqs = np.array(
    [
        np.histogram(np.where(ppc_obs[0, draw, :, :])[1], bins=bins, density=True)[0]
        for draw in range(n_draws)
    ]
)
y, hdi = get_hdi(draw_freqs, 0)

# `errorbar` expects non-negative distances below/above y, while `hdi` holds
# absolute percentile values, so convert (and guard against the mean lying
# outside the interval).
yerr = np.clip(np.vstack([y - hdi[0], hdi[1] - y]), 0, None)

ax.errorbar(
    bins[:-1] * interval_length,
    y / interval_length,
    yerr=yerr / interval_length,
    label="posterior predictive",
    lw=2,
    capsize=6,
)
ax.bar(
    bins[:-1] * interval_length,
    freq0 / interval_length,
    color="C1",
    width=interval_length,
    label="observed binned",
)
sns.kdeplot(
    df.tpore,
    ax=ax,
    color="grey",
    label="data",
    lw=2.5,
)
ax.set_xlabel("t-pore (ns) * 10")
ax.set_ylabel("p(t-pore)")
ax.legend()
fig.tight_layout()

In [None]:
# Prior predictive check: distribution of the interval in which the pore
# forms, observed vs. simulated from the prior.
fig, ax = plt.subplots(1, 1, figsize=(6, 4))

# n_intervals + 1 edges give exactly one bin per interval. With only
# n_intervals edges the last two intervals share a bin, because numpy's final
# bin is closed on the right.
bin_edges = np.arange(n_intervals + 1)
freq0, bins = np.histogram(np.where(pore[:, :] == 1)[1], density=True, bins=bin_edges)

# Pull the array out once instead of on every loop iteration.
prior_obs = trace.prior_predictive.obs.values
n_draws = prior_obs.shape[1]
draw_freqs = np.array(
    [
        np.histogram(np.where(prior_obs[0, draw, :, :])[1], bins=bins, density=True)[0]
        for draw in range(n_draws)
    ]
)
# Summarise the PRIOR draws. Without this line the cell silently re-plotted
# the `y`/`hdi` left over from the posterior predictive cell.
y, hdi = get_hdi(draw_freqs, 0)

# `errorbar` expects non-negative distances below/above y, while `hdi` holds
# absolute percentile values, so convert (and guard against the mean lying
# outside the interval).
yerr = np.clip(np.vstack([y - hdi[0], hdi[1] - y]), 0, None)

ax.errorbar(
    bins[:-1] * interval_length,
    y / interval_length,
    yerr=yerr / interval_length,
    label="prior predictive",
    lw=2,
    capsize=6,
)
ax.bar(
    bins[:-1] * interval_length,
    freq0 / interval_length,
    color="C1",
    width=interval_length,
    label="observed binned",
)
sns.kdeplot(
    df.tpore,
    ax=ax,
    color="grey",
    label="data",
    lw=2.5,
)
ax.set_xlabel("t-pore (ns) * 10")
ax.set_ylabel("p(t-pore)")
ax.legend()
fig.tight_layout()

## Analyze

### Plot posterior

In [None]:
# Forest plot of the baseline hazard per interval, chains combined.
variable = "lambda0"
az.plot_forest(trace, var_names=variable, combined=True)

In [None]:
# Forest plot of the per-category coefficients, chains combined.
variable = "beta"
az.plot_forest(trace, var_names=variable, combined=True)

In [None]:
# Forest plot of exp(beta), chains combined.
variable = "exp_beta"
az.plot_forest(trace, var_names=variable, combined=True)

In [None]:
# Posterior mean of lambda0 over all chains and draws.
trace.posterior.lambda0.mean(dim=["draw", "chain"])

In [None]:
# Shape of the lambda0 sample array.
trace.posterior.lambda0.values.shape

In [None]:
# Shape of the beta sample array.
trace.posterior.beta.values.shape

In [None]:
# Shape check only: a single beta column (kept as a length-1 last axis)
# broadcast against the lambda0 samples.
(trace.posterior.beta.values[:, :, [0]] * trace.posterior.lambda0.values).shape

In [None]:
# Posterior hazard as a function of time: baseline (left) and per category
# (right), each with the percentile band returned by `get_hdi`.
fig, ax = plt.subplots(1, 2, figsize=(20, 7))
# From here on `lambda0` and `beta` are posterior sample arrays and no longer
# the model's random variables.
lambda0 = trace.posterior.lambda0.values
beta = trace.posterior.beta.values
x = interval_bounds[:-1]

# Left panel: baseline hazard, summarised over chains and draws.
y, hdi = get_hdi(lambda0, (0, 1))
ax[0].fill_between(x, hdi[0], hdi[1], alpha=0.25, step="pre")
ax[0].step(x, y, label="baseline")

# Right panel: baseline hazard scaled by exp(beta) of each category.
for i in range(n_categories):
    lam = np.exp(beta[:, :, [i]]) * lambda0
    y, hdi = get_hdi(lam, (0, 1))
    ax[1].fill_between(x, hdi[0], hdi[1], alpha=0.25, step="pre")
    ax[1].step(x, y, label=f"{category_dic[i]}")

# The x axis carries time (interval lower bounds, in the scaled tpore units)
# and the y axis the hazard; the original labels had these mixed up.
for a in ax:
    a.legend(loc="best")
    a.set_xlabel("t (ns) * 10")
    a.set_ylabel("lambda (ns*10)^-1")

**TODO**
+ modify parameters using copies of notebooks: nbins, priors
+ more categories
+ t-dependent B

## Save model

In [None]:
# Show where the trace is stored.
print(model_path)

In [None]:
# Persist the trace only when it was freshly sampled. With infer = False the
# trace was read from this very file (possibly lazily), so deleting the file
# and rewriting it gains nothing and risks losing the stored model.
if infer:
    spc_os.remove(model_path)
    trace.to_netcdf(model_path)