# Final Figures

In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl
import arviz as az
import numpy as np
import pandas as pd

In [None]:
%load_ext autoreload
%autoreload 2
%load_ext lab_black
%load_ext watermark
%watermark -n -u -v -iv -w -m

## RC values for final figures

In [None]:
# Matplotlib rc overrides shared by all final figures.  The object returned by
# ``mpl.rc_context`` is applied as a decorator (``@final_fig_decorator``) to
# the plotting functions in this notebook.
final_fig_decorator = mpl.rc_context(
    {
        # Matplotlib deprecated list-valued LaTeX preambles in 3.3; a single
        # newline-joined string is the supported form.
        "text.latex.preamble": "\n".join(
            [
                r"\usepackage{siunitx}",
                r"\usepackage{mhchem}",
            ]
        ),
        "axes.labelsize": "20",
        "xtick.labelsize": "15",
        "ytick.labelsize": "20",
        "axes.titlesize": "15",
        # Optional LaTeX rendering, left disabled:
        # "font.family": "serif",
        # "font.serif": ["Times"],
        # "text.usetex": True,
    }
)

In [None]:
def load_trace(model_path, url_data):
    """Load an ArviZ trace from ``model_path``, downloading it if that fails.

    Parameters
    ----------
    model_path : str
        Path of the netCDF file to read with ``az.from_netcdf``.
    url_data : str
        URL fetched when ``model_path`` cannot be loaded.

    Returns
    -------
    The object returned by ``az.from_netcdf``.

    Notes
    -----
    The fallback download is written to the current working directory under
    the last ``/``-separated component of ``model_path`` and loaded from there.
    """
    try:
        return az.from_netcdf(model_path)
    except Exception:
        # Intentionally broad: the exception type raised for a missing or
        # unreadable file is not pinned down here.  Unlike a bare ``except``,
        # this still lets KeyboardInterrupt / SystemExit propagate.
        print("Need to download model from OSF.")

    import urllib.request

    model_name = model_path.split("/")[-1]
    urllib.request.urlretrieve(url_data, model_name)
    return az.from_netcdf(model_name)

## Load data and initial values

In [None]:
# Relative path of the directory holding the processed CSV that is read below.
processed_data_dir = "../data/processed/"

In [None]:
# Read the processed table and discard the "Unnamed: 0" column.
df = pd.read_csv(processed_data_dir + "data.csv").drop(columns="Unnamed: 0")

In [None]:
# Group simulations by membrane.  Item assignment is used because attribute
# assignment only updates a column that already exists; it never creates one.
df["Replica"] = df["membrane"]

In [None]:
# Store the grouping column as a pandas categorical.
df["Replica"] = df["Replica"].astype("category")

In [None]:
# Integer code of each row's category.
df["Replica_enc"] = df["Replica"].cat.codes

In [None]:
# Map position in the sorted unique "Replica" values -> the value itself.
category_dic = dict(enumerate(np.unique(df["Replica"])))

In [None]:
# Bare expression: shows the index -> category mapping as the cell output.
category_dic

In [None]:
# Number of distinct "Replica" categories.
n_categories = len(category_dic)

In [None]:
# One indicator column per category, named "Replica_<category>".
dummies = pd.get_dummies(df["Replica"], prefix="Replica")

In [None]:
# Copy every indicator column into the main table under the same name.
for name, indicator in dummies.items():
    df[name] = indicator

In [None]:
# Rescale tpore by a factor of 10 and truncate to integers.  The plotting code
# divides by 10 again before labelling axes in ns.
df["tpore"] = (df["tpore"] * 10).astype(int)

In [None]:
# One row of `df` per simulation.
n_sims = len(df)
sims = np.arange(n_sims)

interval_length = 15  # 1.5 ns

# Interval edges from 0, spaced `interval_length` apart, extending past the
# largest tpore so that every simulation falls inside some interval.
interval_bounds = np.arange(0, df.tpore.max() + interval_length + 1, interval_length)
n_intervals = len(interval_bounds) - 1
intervals = np.arange(n_intervals)

In [None]:
# Index of the interval in which each simulation's tpore falls.  Subtracting
# 0.01 makes a tpore lying exactly on an edge count towards the interval that
# ends there rather than the one that starts there.
last_period = np.floor((df.tpore - 0.01) / interval_length).astype(int)

# Indicator matrix (simulation x interval): 1 only in that last interval.
pore = np.zeros((n_sims, n_intervals))
pore[sims, last_period] = 1.0

In [None]:
# A full `interval_length` for every interval whose start is <= tpore ...
exposure = interval_length * np.greater_equal.outer(
    df.tpore.values, interval_bounds[:-1]
)
# ... except the last one, which only gets the time elapsed since its start.
exposure[sims, last_period] = df.tpore - interval_bounds[last_period]

## Plotting Functions

In [None]:
def get_survival_function_t_dep(trace, dx=None):
    """Return ``exp(-cumulative integral)`` of a 4-D array along its last axis.

    Parameters
    ----------
    trace : object with ``.shape`` and ``.values``
        ``trace.values`` is indexed as a 4-D array; the last axis is the one
        that is integrated.
    dx : float, optional
        Sample spacing for the trapezoidal rule.  Defaults to the notebook
        global ``interval_length``.

    Returns
    -------
    numpy.ndarray
        Array whose leading axis has length ``trace.shape[-1] - 1``.  Entry
        ``k`` is ``exp(-I_k)``, where ``I_k`` is the trapezoidal integral over
        the first ``k + 1`` samples of the last axis (so entry 0 is all ones).
        The remaining axes are the first three axes of ``trace.values``.
    """
    if dx is None:
        dx = interval_length
    # ``np.trapz`` is deprecated since NumPy 2.0 in favour of ``np.trapezoid``;
    # use whichever this NumPy provides.
    integrate = getattr(np, "trapezoid", None) or np.trapz

    values = trace.values
    n_steps = trace.shape[-1] - 1
    cumulative = [
        integrate(values[:, :, :, : step + 1], axis=3, dx=dx)
        for step in range(n_steps)
    ]
    return np.exp(-np.array(cumulative))

In [None]:
def get_ecdf(data):
    """Return the points of the empirical CDF of ``data``.

    The sorted values are paired with the cumulative fractions
    ``1/n, 2/n, ..., 1``.  A leading ``0`` is inserted into both arrays so the
    curve starts at the origin.

    Returns
    -------
    (x, y) : tuple of numpy.ndarray
        Sorted values and cumulative fractions, each with the extra leading 0.
    """
    ordered = np.sort(data)
    n_points = ordered.size
    cumulative = np.arange(1, n_points + 1) / n_points
    return (
        np.insert(ordered, 0, 0.0, axis=0),
        np.insert(cumulative, 0, 0.0, axis=0),
    )

In [None]:
def get_hdi(x, axis, alpha=0.06):
    """Return the NaN-ignoring median and a central percentile interval.

    Despite the name, the interval is the equal-tailed one between the
    ``alpha / 2`` and ``1 - alpha / 2`` quantiles, not a highest-density
    interval.

    Parameters
    ----------
    x : array-like
        Samples to summarise.
    axis : int or tuple of int
        Axis or axes reduced by the median and the percentiles.
    alpha : float
        Total probability left outside the interval (default 0.06, i.e. 94 %).

    Returns
    -------
    (median, bounds)
        ``bounds[0]`` is the lower limit and ``bounds[1]`` the upper limit.
    """
    tail = alpha / 2.0
    quantiles = 100 * np.array([tail, 1.0 - tail])
    bounds = np.nanpercentile(x, quantiles, axis=axis)
    center = np.nanmedian(x, axis=axis)
    return center, bounds

In [None]:
def get_survival_function(trace, dx=None, n_steps=None):
    """Return ``exp(-cumulative integral)`` of a 4-D array along axis 2.

    Parameters
    ----------
    trace : object with ``.values``
        ``trace.values`` is indexed as a 4-D array; axis 2 is integrated.
    dx : float, optional
        Sample spacing for the trapezoidal rule.  Defaults to the notebook
        global ``interval_length``.
    n_steps : int, optional
        Number of cumulative integrals to compute.  Defaults to the notebook
        global ``n_intervals`` minus one.

    Returns
    -------
    numpy.ndarray
        Array whose leading axis has length ``n_steps``.  Entry ``k`` is
        ``exp(-I_k)``, where ``I_k`` is the trapezoidal integral over the first
        ``k + 1`` samples of axis 2 (so entry 0 is all ones).  The remaining
        axes are axes 0, 1 and 3 of ``trace.values``.
    """
    if dx is None:
        dx = interval_length
    if n_steps is None:
        n_steps = n_intervals - 1
    # ``np.trapz`` is deprecated since NumPy 2.0 in favour of ``np.trapezoid``;
    # use whichever this NumPy provides.
    integrate = getattr(np, "trapezoid", None) or np.trapz

    values = trace.values
    cumulative = [
        integrate(values[:, :, : step + 1, :], axis=2, dx=dx)
        for step in range(n_steps)
    ]
    return np.exp(-np.array(cumulative))

In [None]:
@final_fig_decorator
def plot_posterior_exp_beta(trace, path_out):
    """Draw a forest plot of the ``expbeta`` variable and save it.

    Parameters
    ----------
    trace
        Object handed to ``az.plot_forest``; it must expose ``expbeta``.
    path_out : str
        File the figure is saved to.

    Returns
    -------
    (fig, ax)
        The Matplotlib figure and axes that were drawn on.
    """
    fig, ax = plt.subplots(1, 1, figsize=(5, 5))
    az.plot_forest(trace, var_names="expbeta", combined=True, ax=ax)
    ax.set_xlabel(r"$\exp\left(\beta\right)$")

    # Reduce the last y tick label to its final whitespace-separated token.
    labels = ax.get_yticklabels()
    last_label = labels[-1]
    last_label.set_text(last_label.get_text().split()[-1])
    ax.set_yticklabels(labels)

    ax.set_title(r"94% Credible Intervals")
    fig.tight_layout()
    fig.savefig(path_out)
    return fig, ax

In [None]:
@final_fig_decorator
def plot_ppc(trace, interval_length, title, path_out, t_dep_beta=False):
    """Plot the data's empirical CDFs against the posterior-predictive CDF.

    Parameters
    ----------
    trace
        Object exposing ``trace.posterior.lambda_``.
    interval_length
        Interval width used to place the binned data and the x axis (the
        result is divided by 10 for the ns axis).
    title : str
        Axes title.
    path_out : str
        File the figure is saved to.
    t_dep_beta : bool
        Selects which survival-function helper is applied to ``lambda_``.

    Returns
    -------
    (fig, ax)

    Notes
    -----
    Reads the notebook globals ``df``, ``pore`` and ``n_intervals``.
    """
    fig, ax = plt.subplots(1, 1, figsize=(6, 4))

    # NOTE(review): ``t_dep_beta=True`` picks ``get_survival_function`` and the
    # default picks ``get_survival_function_t_dep``.  The helper names suggest
    # the opposite pairing -- confirm against the axis order of ``lambda_``.
    hazard = trace.posterior.lambda_.astype(np.float16)
    if t_dep_beta:
        survival_function = get_survival_function(hazard)
    else:
        survival_function = get_survival_function_t_dep(hazard)

    # Empirical CDF of the raw data.
    ax.plot(*get_ecdf(df.tpore / 10), label="data")

    # Empirical CDF of the data after binning into intervals.
    binned_data = np.where(pore == 1)[1] * interval_length / 10
    ax.plot(*get_ecdf(binned_data), label="data binned")

    # Posterior predictive: median and interval over every axis but the first.
    median, bounds = get_hdi(survival_function, axis=(1, 2, 3))
    x = np.arange(n_intervals - 1) * interval_length / 10.0
    ax.plot(x, 1 - median, label="Posterior Predictive Check")
    ax.fill_between(x, 1 - bounds[0, :], 1 - bounds[1, :], alpha=0.1, color="g")

    ax.set_xlabel("$t_{pore}$ (ns)")
    ax.set_ylabel("CDF(t)")
    ax.set_title(title)
    ax.legend()
    fig.tight_layout()
    fig.savefig(path_out)
    return fig, ax

In [None]:
@final_fig_decorator
def plot_posterior_lambda0(trace, path_out, title="", ylim=None):
    """Plot the posterior of ``lambda0`` as a step function and save it.

    Parameters
    ----------
    trace
        Object exposing ``trace.posterior.lambda0``; its first two axes are
        reduced to a median and an interval with ``get_hdi``.
    path_out : str
        File the figure is saved to.
    title : str
        Axes title.
    ylim : sequence of two floats, optional
        y-axis limits, applied before the figure is saved.  ``None`` keeps
        Matplotlib's automatic limits.

    Returns
    -------
    (fig, ax)

    Notes
    -----
    Reads the notebook global ``interval_bounds`` for the x positions.
    """
    fig, ax = plt.subplots(1, 1, figsize=(8, 5))
    lambda0 = trace.posterior.lambda0.values
    y, hdi = get_hdi(lambda0, (0, 1))
    x = interval_bounds[:-1] / 10
    ax.fill_between(x, hdi[0], hdi[1], alpha=0.25, step="post", color="grey")
    ax.step(x, y, label="baseline", color="grey", where="post")
    ax.set_ylabel(r"$\lambda_0(t)$ ns$^{-1}$")
    ax.set_xlabel("$t$ (ns)")
    ax.set_title(title)
    if ylim is not None:
        # Use the caller's limits; previously any non-None value was replaced
        # by a hard-coded [0, 12].
        ax.set_ylim(ylim)
    fig.tight_layout()
    fig.savefig(path_out)
    return fig, ax

In [None]:
@final_fig_decorator
def plot_posterior_beta_of_t(trace, category_dic, path_out):
    """Plot the posterior of ``beta`` as one step curve per category.

    Parameters
    ----------
    trace
        Object exposing ``trace.posterior.beta``.  Its values are indexed as a
        4-D array whose last axis selects the category.
    category_dic : dict
        Maps ``0 .. len(category_dic) - 1`` to the category names used in the
        legend.
    path_out : str
        File the figure is saved to.

    Returns
    -------
    (fig, ax)

    Notes
    -----
    Reads the notebook global ``interval_bounds`` for the x positions.
    """
    fig, ax = plt.subplots(1, 1, figsize=(9, 6))

    betas = trace.posterior.beta.values
    # Left edge of every interval; identical for all categories, so it is
    # computed once outside the loop.
    x = interval_bounds[:-1] / 10
    for i in range(len(category_dic)):
        # Median and interval over the first two axes for category ``i``.
        y, hdi = get_hdi(betas[:, :, :, i], axis=(0, 1))
        ax.step(x, y, where="post", label=f"Beta {category_dic[i]}")
        ax.fill_between(
            x,
            hdi[0],
            hdi[1],
            step="post",
            alpha=0.1,
        )
    ax.set_xlabel(r"$t_{pore}$ (ns)")
    ax.set_ylabel(r"$\beta (t)$")
    ax.set_title(r"Time dependent $\beta$ model")
    ax.legend()
    fig.tight_layout()
    fig.savefig(path_out)
    return fig, ax

In [None]:
@final_fig_decorator
def plot_kde_t_pore(df0, path_out):
    """Plot kernel density estimates of ``tpore`` per ``Replica`` and save them.

    Parameters
    ----------
    df0 : pandas.DataFrame
        Needs the columns ``tpore`` and ``Replica``.  The frame is not
        modified: the division of ``tpore`` by 10 is done on a copy.
    path_out : str
        File the figure is saved to.

    Returns
    -------
    (fig, ax)
    """
    from seaborn import kdeplot

    fig, ax = plt.subplots(1, 1, figsize=(9, 6))
    # ``assign`` returns a new frame, so the caller's data is left untouched.
    rescaled = df0.assign(tpore=df0.tpore / 10)
    kdeplot(data=rescaled, x="tpore", hue="Replica", ax=ax, lw=3)
    legend = ax.get_legend()
    if legend is not None:
        legend.set_title("")
    ax.set_xlabel(r"$t_{pore}$ (ns)")
    ax.set_ylabel(r"$p(t_{pore})$")
    ax.set_yticklabels([])
    ax.set_yticks([])
    ax.set_xlim(left=0, right=20)
    fig.tight_layout()
    fig.savefig(path_out)
    return fig, ax

## Posterior Predictive Checks

In [None]:
# Posterior predictive check for the "same membrane" model.
# load_trace falls back to downloading from `url_data` if the local file
# cannot be read.
model_path = "../models/tpore_survival_analysis_same_membrane.nc"
url_data = "https://osf.io/pgjtm/download"
trace = load_trace(model_path, url_data)
fig, ax = plot_ppc(
    trace,
    interval_length,
    "Combining same membrane and same field",
    "../reports/final_figures/ppc_same_membrane.svg",
)

In [None]:
# Posterior predictive check for the "individual simulations" model.
model_path = "../models/tpore_survival_analysis_individual_sim.nc"
url_data = "https://osf.io/rkc97/download"
trace = load_trace(model_path, url_data)
fig, ax = plot_ppc(
    trace,
    interval_length,
    "Separating individual simulations",
    "../reports/final_figures/ppc_individual_sim.svg",
)

In [None]:
# Posterior predictive check for the time-dependent beta model.
# t_dep_beta=True switches plot_ppc to its other survival-function helper.
model_path = "../models/tpore_survival_analysis_time_dep_beta_same_membrane.nc"
url_data = "https://osf.io/yh6fw/download"
trace = load_trace(model_path, url_data)
fig, ax = plot_ppc(
    trace,
    interval_length,
    r"Time dependent $\beta$",
    "../reports/final_figures/ppc_time_dependent_beta.svg",
    t_dep_beta=True,
)

## Posterior Beta Values

In [None]:
# Forest plot of exp(beta) for the "same membrane" model.
model_path = "../models/tpore_survival_analysis_same_membrane.nc"
url_data = "https://osf.io/pgjtm/download"
trace = load_trace(model_path, url_data)
# plot_posterior_exp_beta looks the variable up under the name "expbeta".
trace.posterior = trace.posterior.rename({"exp_beta": "expbeta"})
fig, ax = plot_posterior_exp_beta(
    trace, "../reports/final_figures/posterior_exp_beta_same_membrane.svg"
)

In [None]:
# Forest plot of exp(beta) for the "individual simulations" model.
model_path = "../models/tpore_survival_analysis_individual_sim.nc"
url_data = "https://osf.io/rkc97/download"
trace = load_trace(model_path, url_data)
# plot_posterior_exp_beta looks the variable up under the name "expbeta".
trace.posterior = trace.posterior.rename({"exp_beta": "expbeta"})

# When rendering text with LaTeX, underscores in the coordinate labels need to
# be escaped first; the disabled snippet below does that.

# change_underscores = lambda s: s.replace("_", "\_")
# change_underscores = np.vectorize(change_underscores)
# trace = trace.assign_coords(
#    Membrane=change_underscores(trace.posterior.Membrane.values)
# )
fig, ax = plot_posterior_exp_beta(
    trace, "../reports/final_figures/posterior_exp_beta_individual_sim.svg"
)

## Plotting lambda posterior

In [None]:
# Posterior of the baseline hazard lambda0 for the "same membrane" model.
model_path = "../models/tpore_survival_analysis_same_membrane.nc"
url_data = "https://osf.io/pgjtm/download"
trace = load_trace(model_path, url_data)
trace.posterior = trace.posterior.rename({"exp_beta": "expbeta"})
# The y-limits go in as an argument because plot_posterior_lambda0 saves the
# figure before it returns; limits set on `ax` afterwards never reach the SVG.
fig, ax = plot_posterior_lambda0(
    trace,
    "../reports/final_figures/posterior_lambda0_same_membrane.svg",
    ylim=[0, 12],
)

## Plotting the lambda posterior for the time-dependent model

In [None]:
# Posterior of the baseline hazard lambda0 for the time-dependent beta model.
model_path = "../models/tpore_survival_analysis_time_dep_beta_same_membrane.nc"
url_data = "https://osf.io/yh6fw/download"
trace = load_trace(model_path, url_data)
trace.posterior = trace.posterior.rename({"exp_beta": "expbeta"})
# Drop the "lambda0_dim_0" dimension from lambda0 before plotting.
trace.posterior["lambda0"] = trace.posterior.lambda0.squeeze("lambda0_dim_0", drop=True)
fig, ax = plot_posterior_lambda0(
    trace,
    "../reports/final_figures/posterior_lambda0_time_dep_beta.svg",
    title=r"Time dependent $\beta$ model",
    ylim=[0, 12],
)

## Time-dependent beta

In [None]:
# Posterior of beta per interval and category for the time-dependent model.
model_path = "../models/tpore_survival_analysis_time_dep_beta_same_membrane.nc"
url_data = "https://osf.io/yh6fw/download"
trace = load_trace(model_path, url_data)
fig, ax = plot_posterior_beta_of_t(
    trace,
    category_dic,
    "../reports/final_figures/posterior_beta_of_t.svg",
)

## Distribution of poration times (kernel density estimates)

In [None]:
# Pass a copy so the plotting helper cannot modify the shared `df`.
_ = plot_kde_t_pore(df.copy(), "../reports/final_figures/t_pore_distribution.svg")