In [None]:
import pymc3 as pm
import numpy as np
import theano
import theano.tensor as tt
import arviz as az
import matplotlib.pyplot as plt
import pandas as pd
import json
from datetime import timedelta

from tvb_inversion.pymc3.plot import plot_posterior_samples
from tvb_inversion.base.diagnostics import (zscore, shrinkage)

In [None]:
# Select the inference run to analyse. Earlier runs, kept for reference:
# run_id = "2023-01-11_2055"  # around ground truth
# run_id = "2023-01-12_1752"  # around ground truth
# run_id = "2023-01-13_1919"  # shifted from ground truth
# run_id = "2023-01-17_1710"  # shifted from ground truth
# run_id = "2023-01-18_1755"  # shifted from ground truth
# run_id = "2023-01-24_1712"  # shifted from ground truth
# run_id = "2023-01-25_1332"  # shifted from ground truth
run_id = "2023-02-10_1605"
DATA_DIR = "pymc3_data"


def _load_json(path):
    """Read a UTF-8 encoded JSON file and return the decoded object.

    The `with` block closes the file; no explicit close() is needed.
    """
    with open(path, "r", encoding="utf-8") as fh:
        return json.load(fh)


# Artifacts written by the inference run: posterior InferenceData, summary table,
# inference settings, and the simulation parameters used to generate the data.
inference_data = az.from_netcdf(f"{DATA_DIR}/{run_id}_idata.nc")
inference_summary = pd.read_json(f"{DATA_DIR}/{run_id}_isummary.json")
inference_params = _load_json(f"{DATA_DIR}/{run_id}_iparams.json")
sim_params = _load_json(f"{DATA_DIR}/{run_id}_sim_params.json")

In [None]:
# Display the inference settings loaded from f"{run_id}_iparams.json" (last expression -> cell output).
inference_params

In [None]:
# Observed data the model was conditioned on, as a plain NumPy array.
X = inference_data.observed_data["x_obs"].values

In [None]:
# Number of MCMC chains and of draws per chain, taken from the sampler-stats coordinates.
chains, draws = (len(inference_data.sample_stats[dim]) for dim in ("chain", "draw"))

In [None]:
# Sampler diagnostics: divergences, mean acceptance rate, mean tree depth, wall-clock time.
print("#chains: ", chains)
print("#draws: ", draws)
divergent = inference_data.sample_stats["diverging"].values
# Count the flagged divergent transitions once and reuse the count below.
n_divergent = int(np.count_nonzero(divergent))
print("Number of Divergent: %d" % n_divergent)
divperc = (n_divergent / (chains * draws)) * 100
print("Percentage of Divergent: %.1f" % divperc)
print("Mean tree accept: %.2f" % inference_data.sample_stats.acceptance_rate.values.mean())
print("Mean tree depth: %.2f" % inference_data.sample_stats.tree_depth.values.mean())
# `sampling_time` is read as seconds and rendered as H:MM:SS.
print("Sampling time:", str(timedelta(seconds=inference_data.sample_stats.sampling_time)))

In [None]:
# Posterior summary rows for the "_star" sampling parameters.
star_rows = [
    "model_a_star",
    "nsig_star",
    "amplitude_star",
    "offset_star",
    "measurement_noise_star",
]
inference_summary.loc[star_rows]

In [None]:
# Posterior summary rows for the model parameters on their natural scale.
param_rows = [
    "model_a",
    "nsig",
    "amplitude",
    "offset",
    "measurement_noise",
]
inference_summary.loc[param_rows]

In [None]:
# Reference values passed to plot_posterior_samples: model_a / nsig come from the
# simulation; the observation-model values are fixed constants here.
# NOTE(review): presumably the simulation applied no noise, unit amplitude and zero offset — confirm.
init_params = dict(
    model_a=sim_params["model_a"],
    nsig=sim_params["nsig"],
    measurement_noise=0.0,
    amplitude=1.0,
    offset=0.0,
)

plot_posterior_samples(inference_data, init_params)

In [None]:
# Reference value 0.0 for every "_star" parameter.
# NOTE(review): assumes the _star parameters are standardized with zero-centred priors — confirm.
init_params = dict.fromkeys(
    ["model_a_star", "nsig_star", "measurement_noise_star", "amplitude_star", "offset_star"],
    0.0,
)

plot_posterior_samples(inference_data, init_params)

In [None]:
# Pool chains and draws into one leading sample axis and drop X's last axis.
# NOTE(review): the reshape only succeeds when the element counts match, e.g. when
# the trailing axis of x_obs has length 1 — confirm this holds for every run.
posterior_x_obs = np.reshape(
    inference_data.posterior_predictive["x_obs"].values,
    (chains * draws,) + X.shape[:-1],
)

In [None]:
# Observed trace (index 0 of the second axis) against the pointwise 95% posterior-predictive interval.
f3, ax = plt.subplots(figsize=(10, 5))
pp_lower, pp_upper = np.percentile(posterior_x_obs[:, :, 0], [2.5, 97.5], axis=0)
# Label only one bound: labelling both put the same entry in the legend twice.
ax.plot(pp_lower, "k", label=r"$V_{95\% PP}(t)$")
ax.plot(pp_upper, "k", label="_nolegend_")
ax.plot(X[:, 0, 0], label="V_observed")
ax.legend(fontsize=16)
# NOTE(review): the x-axis is the sample index; the "ms" label assumes a 1 ms step — confirm.
ax.set_xlabel("time (ms)", fontsize=16)
ax.set_title("Posterior predictive check", fontsize=16)
ax.tick_params(axis="both", labelsize=16)

plt.show()

In [None]:
# Information criteria computed by ArviZ from the stored InferenceData.
waic = az.waic(inference_data)
loo = az.loo(inference_data)


def _elpd_value(result, current_key, legacy_key):
    """Return the headline estimate of an ArviZ ELPDData result.

    ArviZ >= 0.11 stores it under `elpd_waic` / `elpd_loo`, so attribute access
    `waic.waic` raises AttributeError there. Older releases used `waic` / `loo`.
    Either key is accepted so the cell runs on both.
    """
    return result[current_key] if current_key in result.index else result[legacy_key]


# NOTE(review): on ArviZ >= 0.11 the default scale is "log" (ELPD, higher is better),
# not the deviance scale — confirm which scale is in use before comparing runs.
print("WAIC: ", _elpd_value(waic, "elpd_waic", "waic"))
print("LOO: ", _elpd_value(loo, "elpd_loo", "loo"))

In [None]:
def _posterior_draws(idata, param):
    """Return every posterior sample of `param` as a flat 1-D array (chains and draws pooled)."""
    return np.ravel(idata.posterior[param].values)


def get_posterior_mean(idata, params):
    """Posterior mean of each parameter, pooling all chains and draws.

    Args:
        idata: object exposing ``idata.posterior[name].values`` (e.g. ArviZ InferenceData).
        params: iterable of posterior variable names.

    Returns:
        1-D array with one mean per entry of ``params`` (empty if ``params`` is empty).
    """
    # Reduce per parameter so variables with different sizes are still handled.
    return np.array([_posterior_draws(idata, p).mean() for p in params])


def get_posterior_std(idata, params):
    """Posterior standard deviation (ddof=0) of each parameter, pooling all chains and draws.

    Args:
        idata: object exposing ``idata.posterior[name].values`` (e.g. ArviZ InferenceData).
        params: iterable of posterior variable names.

    Returns:
        1-D array with one std per entry of ``params`` (empty if ``params`` is empty).
    """
    return np.array([_posterior_draws(idata, p).std() for p in params])

In [None]:
# zscores: distance of the simulation ground truth from the posterior mean, passed
# to `zscore` together with the posterior std, for model_a and nsig.
posterior_mean, posterior_std = (
    get_posterior_mean(inference_data, ["model_a", "nsig"]),
    get_posterior_std(inference_data, ["model_a", "nsig"]),
)

ground_truth = np.array([sim_params[key] for key in ("model_a", "nsig")])

zscores = zscore(ground_truth, posterior_mean, posterior_std)

In [None]:
# shrinkages: posterior spread of the "_star" parameters against their prior spread.
# Note: this rebinds `posterior_std` (previously the z-score cell's unstarred values).
posterior_std = get_posterior_std(inference_data, ["model_a_star", "nsig_star"])

# NOTE(review): assumes both _star parameters have unit-std priors (e.g. N(0, 1)) — confirm in the model.
prior_std = np.full(2, 1.0)

shrinkages = shrinkage(prior_std, posterior_std)

In [None]:
# One marker per parameter: posterior shrinkage (x) vs posterior z-score (y).
f4, ax = plt.subplots(figsize=(12, 8))
for idx, (label, color) in enumerate([("model_a", "blue"), ("nsig", "green")]):
    ax.plot(shrinkages[idx], zscores[idx], color=color, linewidth=0,
            marker="*", markersize=12, label=label)
ax.set_xlabel("posterior shrinkage")
ax.set_ylabel("posterior zscore")
ax.set_title("Posterior z-score vs. shrinkage")
ax.legend(fontsize=16)
# The original ended with a no-op `plt.plot();`; show the figure explicitly instead.
plt.show()