In [None]:
import pymc3 as pm
import numpy as np
import theano
import theano.tensor as tt
import arviz as az
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import json
from datetime import timedelta

from tvb_inversion.pymc3.plot import plot_posterior_samples
from tvb_inversion.base.diagnostics import (zscore, shrinkage)

%load_ext autoreload
%autoreload 2

In [None]:
# Identifier of the inference run whose artefacts are loaded below.
# Alternative runs, kept for reference:
# run_id = "2023-01-25_1608"  # around ground truth
# run_id = "2023-01-26_0951"  # around ground truth
# run_id = "2023-01-28_1535"  # around ground truth
# run_id = "2023-02-01_1200"  # around ground truth
run_id = "2023-02-02_1233"

# InferenceData (netCDF) and the summary table (JSON) stored for this run.
inference_data = az.from_netcdf(f"pymc3_data/{run_id}_idata.nc")
inference_summary = pd.read_json(f"pymc3_data/{run_id}_isummary.json")

# Inference and simulation settings. The `with` statement closes each file on
# exit, so no explicit close() call is needed.
with open(f"pymc3_data/{run_id}_iparams.json", "r") as f:
    inference_params = json.load(f)
with open(f"pymc3_data/{run_id}_sim_params.json", "r") as f:
    sim_params = json.load(f)

# inference_data = az.from_netcdf(f"pymc3_data/{run_id}.nc")
# inference_summary = pd.read_json(f"pymc3_data/{run_id}.json")

In [None]:
# Display the inference settings loaded from the run's "_iparams.json" file.
inference_params

In [None]:
# Observed `x_obs` data stored with the run, as a plain NumPy array.
# NOTE(review): later cells index it as X[:, 0, i] with i labelled as a region;
# the exact axis layout should be confirmed against the simulation code.
X = inference_data.observed_data["x_obs"].values

In [None]:
# Number of chains and of draws per chain, taken from the coordinate lengths
# of the sample-statistics group.
chains, draws = (len(inference_data.sample_stats[dim]) for dim in ("chain", "draw"))

In [None]:
# Sampler diagnostics: run size, divergences, acceptance rate, tree depth and
# the recorded sampling time.
print("#chains: ", chains)
print("#draws: ", draws)

# Boolean array flagging divergent transitions; count the flagged entries once
# and reuse the count for both the absolute and the relative figure.
divergent = inference_data.sample_stats["diverging"].values
num_divergent = np.count_nonzero(divergent)
print("Number of Divergent: %d" % num_divergent)
divperc = (num_divergent / (chains * draws)) * 100
print("Percentage of Divergent: %.1f" % divperc)

mean_accept = inference_data.sample_stats["acceptance_rate"].values.mean()
mean_depth = inference_data.sample_stats["tree_depth"].values.mean()
print("Mean tree accept: %.2f" % mean_accept)
print("Mean tree depth: %.2f" % mean_depth)

# `sampling_time` is read as a number of seconds and shown as H:MM:SS.
elapsed = timedelta(seconds=inference_data.sample_stats.sampling_time)
print("Sampling time:", str(elapsed))

In [None]:
# Summary rows of the "*_star" variables: one row per model_a_star entry,
# followed by the scalar ones.
star_rows = [f"model_a_star[{i}]" for i in range(len(inference_params["model_a"]))]
star_rows += ["coupling_a_star", "nsig_star", "amplitude_star", "offset_star", "measurement_noise_star"]
inference_summary.loc[star_rows]

In [None]:
# Reference values in the same order as the selected summary rows: the
# simulated model_a entries, coupling_a and nsig, then 1.0 / 0.0 / 0.0 for
# amplitude / offset / measurement_noise.
ground_truth = np.array(
    [*sim_params["model_a"], sim_params["coupling_a"], sim_params["nsig"], 1.0, 0.0, 0.0]
)

param_rows = [f"model_a[{i}]" for i in range(len(inference_params["model_a"]))]
param_rows.extend(["coupling_a", "nsig", "amplitude", "offset", "measurement_noise"])

# Select the rows, then put the reference values in the first column so they
# sit next to the posterior statistics.
inference_summary_ = inference_summary.loc[param_rows]
inference_summary_.insert(0, "ground_truth", ground_truth)
inference_summary_

In [None]:
# Violin plot of the pooled posterior samples of model_a (all leading axes
# merged into one sample axis), one violin per region.
model_a_samples = inference_data.posterior.model_a.values
region_names = [f"Region {i+1}" for i in range(len(inference_params["model_a"]))]
df = pd.DataFrame(
    data=model_a_samples.reshape(-1, model_a_samples.shape[-1]),
    columns=region_names,
)

f1 = plt.figure(figsize=(15, 10))
ax1 = sns.violinplot(data=df, bw=.1)
plt.setp(ax1.collections, alpha=.5)

# axhline takes xmin/xmax as fractions of the axes width; each column gets an
# equal share, so the simulated value is drawn only across its own violin.
# Only the first line carries a legend label to keep a single legend entry.
for i, v in enumerate(sim_params["model_a"]):
    label = "ground truth" if i == 0 else "_nolegend_"
    plt.axhline(
        v,
        xmin=i*(1/len(df.columns)),
        xmax=(i+1)*(1/len(df.columns)),
        linestyle="--",
        linewidth=2,
        color="black",
        label=label,
    )

plt.title("model_a", fontsize=16)
plt.tick_params(axis="both", labelsize=16)
plt.legend(fontsize=16)
plt.show()

In [None]:
# One column of pooled posterior samples per model_a entry, plus coupling_a.
model_a_pooled = inference_data.posterior.model_a.values
model_a_pooled = model_a_pooled.reshape(-1, model_a_pooled.shape[-1])
data = {
    f"model_a_R{i+1}": model_a_pooled[:, i]
    for i in range(len(inference_params["model_a"]))
}
data["coupling_a"] = inference_data.posterior.coupling_a.values.flatten()
df = pd.DataFrame(data)

# Regression panels of every model_a column (y) against coupling_a (x).
model_a_columns = [name for name in data if "model_a" in name]
with sns.plotting_context(rc={"axes.labelsize":20}):
    ax2 = sns.pairplot(data=df, kind="reg", x_vars=["coupling_a"], y_vars=model_a_columns, height=5)
ax2.tick_params(axis="both", labelsize=20)
plt.show()

In [None]:
# Values handed to plot_posterior_samples alongside the inference data
# (presumably drawn as reference values — confirm against plot_posterior_samples).
# coupling_a and nsig come from the simulation settings; the remaining three use
# the same constants as the ground-truth column of the summary table.
# init_params = {f"model_a[{i}]": sim_params["model_a"][i] for i in range(len(sim_params["model_a"]))}
init_params = {
    "coupling_a": sim_params["coupling_a"],
    "nsig": sim_params["nsig"],
    "measurement_noise": 0.0,
    # Plain float: a trailing comma after the value (`= 1.0,`) would store the
    # tuple (1.0,) instead of a scalar like every other entry.
    "amplitude": 1.0,
    "offset": 0.0,
}

plot_posterior_samples(inference_data, init_params)

In [None]:
# Values handed to plot_posterior_samples for the "*_star" variables, all 0.0
# (presumably drawn as reference values — confirm against plot_posterior_samples).
# init_params = {f"model_a_star[{i}]": 0.0 for i in range(len(sim_params["model_a"]))}
init_params = {
    "coupling_a_star": 0.0,
    "nsig_star": 0.0,
    "measurement_noise_star": 0.0,
    # Plain float: a trailing comma after the value (`= 0.0,`) would store the
    # tuple (0.0,) instead of a scalar like every other entry.
    "amplitude_star": 0.0,
    "offset_star": 0.0,
}

plot_posterior_samples(inference_data, init_params)

In [None]:
# Merge the chain and draw axes of the posterior-predictive x_obs samples into
# a single sample axis; every sample keeps the shape of the observed array X.
pooled_shape = (chains*draws,) + tuple(X.shape)
posterior_x_obs = inference_data.posterior_predictive.x_obs.values.reshape(pooled_shape)

In [None]:
# Posterior-predictive check: for each plotted region, the observed trace is
# drawn together with the 2.5th and 97.5th percentiles (taken across samples)
# of the posterior-predictive draws.
num_regions = 6  # number of regions plotted, indexed along the last axis

# squeeze=False always returns a 2-D Axes array; ravel() restores the 1-D
# layout used below, so indexing also works when num_regions == 1.
f3, axes3 = plt.subplots(nrows=num_regions, ncols=1, figsize=(25, num_regions*12), squeeze=False)
axes3 = axes3.ravel()
for i in range(num_regions):
    ax = axes3[i]

    lower, upper = np.percentile(posterior_x_obs[:, :, 0, i], [2.5, 97.5], axis=0)
    # Label only one of the two bounds: a single plot() call with one label
    # attaches it to both lines and the legend would list the entry twice.
    ax.plot(lower, "k", label=r"$V_{95\% PP}(t)$")
    ax.plot(upper, "k")
    ax.plot(X[:, 0, i], label="V_observed")
    ax.set_title(f"Region {i+1}", fontsize=16)
    ax.legend(fontsize=16)
    ax.set_xlabel("time (ms)", fontsize=16)
    ax.tick_params(axis="both", labelsize=16)

plt.show()

In [None]:
# Information criteria computed by ArviZ from the loaded inference data.
waic, loo = az.waic(inference_data), az.loo(inference_data)
print("WAIC: ", waic.waic)
print("LOO: ", loo.loo)

In [None]:
def _stack_posterior(idata, params):
    """Stack the posterior samples of ``params`` into one 2-D array.

    Parameters
    ----------
    idata : object
        Anything exposing ``idata.posterior[param].values`` as a NumPy array
        (e.g. an ArviZ ``InferenceData``).
    params : sequence of str
        Names of the posterior variables, in the order their columns should
        appear in the result.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n_samples, n_columns)``. For each parameter the first
        two axes are merged into the sample axis and all remaining axes are
        flattened into columns, so a scalar parameter contributes one column
        and a vector parameter one column per entry.

    Raises
    ------
    ValueError
        Raised by ``numpy.concatenate`` when ``params`` is empty or when the
        parameters do not share the same number of samples.
    """
    columns = []
    for param in params:
        values = idata.posterior[param].values
        if values.ndim < 2:
            # Fewer than two axes: treat every element as one sample of a scalar.
            columns.append(values.reshape(-1, 1))
            continue
        n_samples = values.shape[0] * values.shape[1]
        # The product over an empty shape is 1, so scalar parameters yield one column.
        n_entries = int(np.prod(values.shape[2:]))
        columns.append(values.reshape(n_samples, n_entries))
    return np.concatenate(columns, axis=1)


def get_posterior_mean(idata, params):
    """Return the posterior mean of every entry of ``params``.

    The samples are pooled over the first two axes of each variable. The result
    is a 1-D array whose entries follow the order of ``params``, with
    multi-entry parameters expanded entry by entry.
    """
    return _stack_posterior(idata, params).mean(axis=0)


def get_posterior_std(idata, params):
    """Return the posterior standard deviation of every entry of ``params``.

    Same pooling and ordering as ``get_posterior_mean``. The standard deviation
    is NumPy's default one (``ddof=0``).
    """
    return _stack_posterior(idata, params).std(axis=0)

In [None]:
# zscores
# Posterior mean and standard deviation of model_a (all entries), coupling_a
# and nsig, compared with the values used in the simulation.
zscore_params = ["model_a", "coupling_a", "nsig"]
posterior_mean = get_posterior_mean(inference_data, zscore_params)
posterior_std = get_posterior_std(inference_data, zscore_params)

ground_truth = np.array([*sim_params["model_a"], sim_params["coupling_a"], sim_params["nsig"]])

zscores = zscore(ground_truth, posterior_mean, posterior_std)

In [None]:
# shrinkages
# Posterior standard deviation of the "*_star" variables, in the same order as
# the z-scores (model_a entries, then coupling_a, then nsig).
posterior_std = get_posterior_std(inference_data,
    ["model_a_star", "coupling_a_star", "nsig_star"])

# Prior standard deviation of 1 for every entry
# (NOTE(review): presumably the "*_star" variables have unit-scale priors — confirm
# against the model definition). Sized from posterior_std instead of a
# hard-coded length so the cell works for any number of regions.
prior_std = np.ones(posterior_std.shape)

shrinkages = shrinkage(prior_std, posterior_std)

In [None]:
# Posterior z-score against posterior shrinkage, one marker per parameter.
# The first entries of `shrinkages` / `zscores` belong to the model_a regions;
# the last two belong to coupling_a and nsig.
n_regions = len(sim_params["model_a"])

f4 = plt.figure(figsize=(12,8))
for i in range(n_regions):
    # Green channel runs over [0, 1) across the regions. Normalising by the
    # region count keeps the RGBA tuple valid for any number of regions (a
    # fixed divisor of 10 breaks beyond 10 regions).
    region_color = (0, i / n_regions, 1, 1)
    plt.plot(shrinkages[i], zscores[i], color=region_color, linewidth=0, marker="*", markersize=12, label=f"model_a_R{i+1}")
plt.plot(shrinkages[-2], zscores[-2], color="red", linewidth=0, marker="*", markersize=12, label="coupling_a")
plt.plot(shrinkages[-1], zscores[-1], color="green", linewidth=0, marker="*", markersize=12, label="nsig")
plt.xlabel("posterior shrinkage", size=16)
plt.ylabel("posterior zscore", size=16)
plt.xlim(right=1.05)
plt.legend(fontsize=16)
plt.tick_params(axis="both", labelsize=16)
plt.show()