In [None]:
import pymc3 as pm
import numpy as np
import theano
import theano.tensor as tt
import arviz as az
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import json
from datetime import timedelta
import pickle
from pymc3.backends.base import MultiTrace

from tvb_inversion.base.diagnostics import (zscore, shrinkage)
from tvb_inversion.pymc3.plot import (
    plot_posterior_samples_model_parameters,
    plot_posterior_samples_global_parameters,
    posterior_pairplot
)

%load_ext autoreload
%autoreload 2

In [None]:
run_id = "2023-03-13_1329"

# Load the stored inference data and summary table of the "-3" run for this run id.
inference_data = az.from_netcdf(f"pymc3_data/{run_id}-3_idata.nc")
inference_summary = pd.read_json(f"pymc3_data/{run_id}-3_isummary.json")

# The `with` statement closes each file on exit, so no explicit close() is needed.
with open(f"pymc3_data/{run_id}_iparams.json", "r") as f:
    inference_params = json.load(f)
with open(f"pymc3_data/{run_id}_sim_params.json", "r") as f:
    sim_params = json.load(f)

# inference_data = az.from_netcdf(f"pymc3_data/{run_id}.nc")
# inference_summary = pd.read_json(f"pymc3_data/{run_id}.json")

In [None]:
# Load the pickled tuning model stored for this run id.
# NOTE(review): pickle.load can execute arbitrary code while unpickling; only load
# files that were produced by a trusted run.
with open(f"pymc3_data/{run_id}_tuning_model.pkl", "rb") as buff:
    data = pickle.load(buff)

# Only the "model" entry of the unpickled object is used below.
model = data["model"]

In [None]:
with model:
    # Reload the trace saved on disk, then rebuild a MultiTrace from chains 0, 1 and 2,
    # each sliced with [5:].
    trace = pm.load_trace(".pymc_2.trace")
    tuning_trace = MultiTrace(
        [trace._straces[chain_idx][5:] for chain_idx in (0, 1, 2)]
    )

In [None]:
with model:
    # Draw posterior predictive samples from the tuning trace ...
    tuning_posterior_predictive = pm.sample_posterior_predictive(trace=tuning_trace)
    # ... and bundle trace and predictive samples into one ArviZ object.
    tuning_inference_data = az.from_pymc3(
        trace=tuning_trace,
        posterior_predictive=tuning_posterior_predictive,
        save_warmup=True,
    )

In [None]:
# Load the stored inference data of three runs for this run id:
# the unsuffixed file plus the "-2" and "-3" files.
inference_data1 = az.from_netcdf(f"pymc3_data/{run_id}_idata.nc")
inference_data2 = az.from_netcdf(f"pymc3_data/{run_id}-2_idata.nc")
inference_data3 = az.from_netcdf(f"pymc3_data/{run_id}-3_idata.nc")

In [None]:
# Concatenate the "-2" and "-3" inference data along the "draw" dimension.
cinference_data = az.concat([inference_data2, inference_data3], dim="draw")

In [None]:
# ArviZ summary table of the concatenated inference data.
cinference_summary = az.summary(cinference_data)

In [None]:
# Observed x_obs values, taken from the observed_data group as an array.
X = inference_data.observed_data.x_obs.values

In [None]:
# Number of chains and draws, read from the lengths of the sample_stats coordinates.
chains = len(inference_data.sample_stats.chain)
draws = len(inference_data.sample_stats.draw)

In [None]:
# Report basic sampler diagnostics for the loaded run.
print("#chains: ", chains)
print("#draws: ", draws)

sample_stats = inference_data.sample_stats
divergent = sample_stats["diverging"].values
# Count the divergent transitions once and reuse the count for the percentage.
n_divergent = divergent.nonzero()[0].size
print("Number of Divergent: %d" % n_divergent)
divperc = (n_divergent / (chains * draws)) * 100
print("Percentage of Divergent: %.1f" % divperc)

print("Mean tree accept: %.2f" % sample_stats.acceptance_rate.values.mean())
print("Mean tree depth: %.2f" % sample_stats.tree_depth.values.mean())
print("Sampling time:", str(timedelta(seconds=sample_stats.sampling_time)))

In [None]:
# Summary rows of the "*_star" parameters: one row per model_a entry plus the global ones.
star_rows = [f"model_a_star[{i}]" for i in range(len(inference_params["model_a"]))]
star_rows += ["coupling_a_star", "nsig_star", "amplitude_star", "offset_star", "measurement_noise_star"]
inference_summary.loc[star_rows]

In [None]:
# Ground-truth values in the same order as the summary rows selected below;
# the last three entries (1.0, 0.0, 0.0) line up with amplitude, offset and measurement_noise.
ground_truth = np.array(
    list(sim_params["model_a"]) + [sim_params["coupling_a"], sim_params["nsig"], 1.0, 0.0, 0.0])

summary_rows = [f"model_a[{i}]" for i in range(len(inference_params["model_a"]))]
summary_rows.extend(["coupling_a", "nsig", "amplitude", "offset", "measurement_noise"])

inference_summary_ = inference_summary.loc[summary_rows]
# Put the ground truth in the first column, next to the posterior statistics.
inference_summary_.insert(0, "ground_truth", ground_truth)
inference_summary_

In [None]:
# Flatten all leading axes of the model_a samples so each row is one posterior sample.
model_a_samples = inference_data.posterior.model_a.values
region_names = [f"Region {i+1}" for i in range(len(inference_params["model_a"]))]
df = pd.DataFrame(
    data=model_a_samples.reshape(-1, model_a_samples.shape[-1]),
    columns=region_names,
)

# Plot columns 0-9 against the matching ground-truth values.
plot_posterior_samples_model_parameters(df.iloc[:, :10], sim_params["model_a"][:10])

In [None]:
# Flatten all leading axes of the model_a samples so each row is one posterior sample.
model_a_samples = inference_data.posterior.model_a.values
region_names = [f"Region {i+1}" for i in range(len(inference_params["model_a"]))]
df = pd.DataFrame(
    data=model_a_samples.reshape(-1, model_a_samples.shape[-1]),
    columns=region_names,
)

# Plot columns 10-19 against the matching ground-truth values.
plot_posterior_samples_model_parameters(df.iloc[:, 10:20], sim_params["model_a"][10:20])

In [None]:
from scipy.stats import spearmanr

# Posterior means of model_a (kept as a single-column 2-D array) and the matching ground truth.
model_a_rows = [f"model_a[{i}]" for i in range(len(inference_params["model_a"]))]
posterior_model_a = inference_summary.loc[model_a_rows][["mean"]].to_numpy()
gt_model_a = np.array(list(sim_params["model_a"]))

# Rank correlation between ground truth and posterior mean.
spearman_correlation = spearmanr(gt_model_a, posterior_model_a)

print(spearman_correlation.correlation, spearman_correlation.pvalue)

In [None]:
# Scatter of posterior mean (x) against ground truth (y) for model_a, markers only.
fig_spearman = plt.figure(figsize=(8, 5))
star_style = {"linewidth": 0, "marker": "*", "markersize": 12, "color": "blue"}
plt.plot(posterior_model_a, gt_model_a, **star_style)
plt.ylabel("ground truth", fontsize=16)
plt.xlabel("posterior mean", fontsize=16)
plt.tick_params(axis="both", labelsize=16)
# plt.savefig(save_path + "rank_correlation.png", dpi=600, bbox_inches="tight")

In [None]:
# data = {
#     f"model_a_R{i+1}": inference_data.posterior.model_a.values.reshape(-1, inference_data.posterior.model_a.values.shape[-1])[:, i] for i in range(len(inference_params["model_a"]))
# }
# data["coupling_a"] = inference_data.posterior.coupling_a.values.flatten()
# df = pd.DataFrame(data)
#
# posterior_pairplot(df, size=2)

In [None]:
# One flattened sample vector per global parameter.
posterior = inference_data.posterior
data = {
    name: posterior[name].values.flatten()
    for name in ("coupling_a", "nsig", "measurement_noise")
}
df = pd.DataFrame(data)

plot_posterior_samples_global_parameters(df, sim_params)

In [None]:
# Inspect the shape of the posterior predictive samples of x_obs.
inference_data.posterior_predictive.x_obs.values.shape

In [None]:
# Collapse the leading sample axes of the posterior predictive x_obs into one axis,
# keeping the trailing axes equal to the shape of the observations X.
# -1 lets numpy infer the number of samples instead of hard-coding 1200, so the cell
# also works for runs with a different number of chains or draws.
posterior_x_obs = inference_data.posterior_predictive.x_obs.values.reshape(
    (-1, *X.shape))

In [None]:
# Posterior predictive check for the first `num_regions` regions, one panel per region.
num_regions = 6
f3, axes3 = plt.subplots(nrows=num_regions, ncols=1, figsize=(25, num_regions*12))
for i, ax in enumerate(axes3):
    # 2.5th and 97.5th percentiles over the sample axis, one row per bound.
    pp_bounds = np.percentile(posterior_x_obs[:, :, 0, i], [2.5, 97.5], axis=0)
    ax.plot(pp_bounds.T, "k", label=r"$V_{95\% PP}(t)$")
    ax.plot(X[:, 0, i], label="V_observed")

    ax.set_title(f"Region {i+1}", fontsize=16)
    ax.set_xlabel("time (ms)", fontsize=16)
    ax.tick_params(axis="both", labelsize=16)
    ax.legend(fontsize=16)

plt.show()

In [None]:
# Compute WAIC and LOO for the loaded inference data and print the estimates.
# NOTE(review): the result attributes `.waic` / `.loo` may be named differently in
# other ArviZ versions - confirm against the installed version.
waic = az.waic(inference_data)
loo = az.loo(inference_data)
print("WAIC: ", waic.waic)
print("LOO: ", loo.loo)

In [None]:
def _stack_posterior_samples(idata, params):
    """Stack the posterior samples of several parameters into one 2-D array.

    Each parameter's `idata.posterior[param].values` array is turned into a
    (sample, component) block and the blocks are concatenated along axis 1.

    * Arrays with at most two axes are treated as one scalar per sample and
      become a single column.
    * Arrays with three or more axes keep one row per element of the first two
      axes (combined) and flatten all remaining axes into columns. For 3-D
      arrays this equals the previous `reshape(-1, shape[-1])`; higher-rank
      parameters were previously flattened into a single column.

    Parameters
    ----------
    idata : object exposing `posterior[param].values` as array-likes.
    params : iterable of parameter names to stack, in output column order.

    Returns
    -------
    numpy.ndarray of shape (n_samples, total_components).
    """
    blocks = []
    for param in params:
        values = np.asarray(idata.posterior[param].values)
        if values.ndim <= 2:
            blocks.append(values.reshape(-1, 1))
        else:
            n_samples = values.shape[0] * values.shape[1]
            # Explicit column count (no -1) keeps the reshape valid for empty arrays.
            n_components = int(np.prod(values.shape[2:]))
            blocks.append(values.reshape(n_samples, n_components))
    return np.concatenate(blocks, axis=1)


def get_posterior_mean(idata, params):
    """Return the per-component posterior mean of `params` as a 1-D array.

    Components follow the order of `params`; see `_stack_posterior_samples`
    for how each parameter is flattened.
    """
    return _stack_posterior_samples(idata, params).mean(axis=0)


def get_posterior_std(idata, params):
    """Return the per-component posterior standard deviation of `params` as a 1-D array.

    Components follow the order of `params`; see `_stack_posterior_samples`
    for how each parameter is flattened.
    """
    return _stack_posterior_samples(idata, params).std(axis=0)

In [None]:
# zscores
posterior_mean = get_posterior_mean(inference_data,
                                    ["model_a", "coupling_a", "nsig"])

posterior_std = get_posterior_std(inference_data,
                                  ["model_a", "coupling_a", "nsig"])

ground_truth = np.array(
    [v for v in sim_params["model_a"]] + [sim_params["coupling_a"], sim_params["nsig"]])

zscores = zscore(ground_truth, posterior_mean, posterior_std)

In [None]:
def zscore_gt(true_mean, post_mean):
    """Absolute deviation of the posterior mean from the true value, relative to the true value.

    Works element-wise on arrays. The result is not finite where `true_mean` is zero.
    """
    deviation = post_mean - true_mean
    return np.abs(deviation / true_mean)

def zscore_prior(true_mean, post_mean, prior_std):
    """Absolute deviation of the posterior mean from the true value, in units of `prior_std`.

    Works element-wise on arrays.
    """
    deviation = post_mean - true_mean
    return np.abs(deviation / prior_std)

In [None]:
# shrinkages
# NOTE: this rebinds posterior_std to the std of the "*_star" parameters.
posterior_std = get_posterior_std(inference_data,
    ["model_a_star", "coupling_a_star", "nsig_star"])

# Unit prior std for every component. Sized from posterior_std instead of the
# hard-coded 78, so the cell keeps working when the number of regions changes.
prior_std = np.ones(posterior_std.shape)

shrinkages = shrinkage(prior_std, posterior_std)

In [None]:
# Posterior z-score against posterior shrinkage: blue stars for the model_a entries,
# red for coupling_a (second to last entry) and green for nsig (last entry).
f4 = plt.figure(figsize=(12,8))
star_marker = {"linewidth": 0, "marker": "*", "markersize": 12}
for i in range(len(sim_params["model_a"])):
    plt.plot(shrinkages[i], zscores[i], color="blue", label=f"model_a_R{i+1}", **star_marker)
plt.plot(shrinkages[-2], zscores[-2], color="red", label="coupling_a", **star_marker)
plt.plot(shrinkages[-1], zscores[-1], color="green", label="nsig", **star_marker)
plt.xlabel("posterior shrinkage", size=16)
plt.ylabel("posterior zscore", size=16)
# plt.xlim(xmax=1.05)
# plt.legend(fontsize=16)
plt.tick_params(axis="both", labelsize=16)
plt.plot();

In [None]:
from pymc3.sampling import _init_jitter
from pymc3.step_methods.hmc import quadpotential

# Build the path from run_id (set near the top of the notebook) so this cell always
# loads the model of the same run as the rest of the analysis. `pickle` is already
# imported in the notebook's import cell.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open(f"pymc3_data/{run_id}_tuning_model.pkl", "rb") as buff:
    data = pickle.load(buff)

model = data["model"]
with model:
    # Reload the stored trace and estimate the covariance from all of its samples.
    trace = pm.load_trace(".pymc_2.trace")
    cov_est_load = pm.trace_cov(trace)

In [None]:
# Shape of the covariance estimate computed from the reloaded trace.
cov_est_load.shape

In [None]:
# Sample dictionaries of chains 0-3, each taken after slicing the chain with [5:].
chain1, chain2, chain3, chain4 = (
    trace._straces[chain_idx][5:].samples for chain_idx in range(4)
)

In [None]:
# Inspect the `vars` attribute of chain 0's trace.
trace._straces[0].vars

In [None]:
from pymc3.backends.base import MultiTrace
# Reduced trace: chains 0, 1 and 2 only, each sliced with [5:].
rtrace = MultiTrace([trace._straces[chain_idx][5:] for chain_idx in range(3)])

In [None]:
# Covariance estimate from the reduced trace (rtrace), evaluated in the model context.
with model:
    rcov_est = pm.trace_cov(rtrace)

In [None]:
# Diagonal of the reduced-trace covariance estimate.
np.diag(rcov_est)

In [None]:
# Diagonal of the full-trace covariance estimate, for comparison with rcov_est.
# `cov_est` is only assigned further down the notebook, so referencing it here raises
# NameError on a fresh kernel; `cov_est_load` holds the same pm.trace_cov(trace) result.
np.diag(cov_est_load)

In [None]:
# Histogram (100 bins) of column 0 of the model_a samples in chain3.
plt.hist(chain3["model_a"][:, 0], bins=100);

In [None]:
def _nondivergent_moments(model, strace):
    """Return (diverging, mean, std, var) for one chain.

    `diverging` is the chain's "diverging" sampler stat. The moments are taken over
    the array form (model.dict_to_array) of every sample whose diverging flag is
    False. The sample array is built once and reused for all three moments.
    """
    diverging = strace._get_sampler_stats(varname="diverging", sampler_idx=0, burn=0, thin=1)
    kept = np.where(np.logical_not(diverging))[0]
    points = np.asarray([model.dict_to_array(strace[i]) for i in kept])
    return diverging, points.mean(axis=0), points.std(axis=0), points.var(axis=0)


with model:
    diverging1, mean1, std1, var1 = _nondivergent_moments(model, trace._straces[0])
    diverging2, mean2, std2, var2 = _nondivergent_moments(model, trace._straces[1])
    diverging3, mean3, std3, var3 = _nondivergent_moments(model, trace._straces[2])

    # Chain 3: only the diverging flags are read; its moments are deliberately not
    # computed from samples. Placeholders (zero mean, unit std/var) are used instead,
    # sized from mean1 rather than a hard-coded length so they always match the
    # other chains' vectors.
    diverging4 = trace._straces[3]._get_sampler_stats(varname="diverging", sampler_idx=0, burn=0, thin=1)
    mean4 = np.zeros(mean1.shape)
    std4 = np.ones(mean1.shape)
    var4 = np.ones(mean1.shape)

In [None]:
# Stack the four per-chain mean vectors and average them across chains.
np.vstack((mean1, mean2, mean3, mean4)).mean(axis=0)

In [None]:
# Plot the "step_size" sampler stat at the indices where "diverging" is False.
plt.plot(trace.get_sampler_stats("step_size")[np.where(trace.get_sampler_stats("diverging")==False)[0].tolist()])

In [None]:
# Mean of the "step_size" sampler stat over the indices where "diverging" is False.
trace.get_sampler_stats("step_size")[np.where(trace.get_sampler_stats("diverging")==False)[0].tolist()].mean()

In [None]:
# Per-chain start points: for every variable, the mean of its samples over axis 0.
start1, start2, start3, start4 = (
    {name: samples.mean(axis=0) for name, samples in chain.items()}
    for chain in (chain1, chain2, chain3, chain4)
)

In [None]:
# Collect the four per-chain start dictionaries into one list.
start = [start1, start2, start3, start4]

In [None]:
# Inspect the start point computed for the first chain.
start1

In [None]:
# Print every variable name in chain4 together with the shape of its mean over axis 0.
for var, val in chain4.items():
    print(var, val.mean(axis=0).shape)

In [None]:
# Shape of the model_a samples stored for chain 0.
trace._straces[0].samples["model_a"].shape

In [None]:
# Number of nonzero entries in the "diverging" sampler stat.
trace.get_sampler_stats("diverging").nonzero()[0].size

In [None]:
# Plot the "step_size_bar" sampler stat.
plt.plot(trace.get_sampler_stats("step_size_bar"))

In [None]:
# Plot the "model_logp" sampler stat.
plt.plot(trace.get_sampler_stats("model_logp"))

In [None]:
# Mean of the "mean_tree_accept" sampler stat.
trace.get_sampler_stats("mean_tree_accept").mean()

In [None]:
# Histogram (100 bins) of column 0 of the model_a samples, from row index 800 onward.
plt.hist(trace["model_a"][800:, 0], bins=100);

In [None]:
# Last sample of every chain listed in trace.chains.
[trace._straces[i][-1] for i in trace.chains]

In [None]:
# Inspect the per-chain trace container of the loaded trace.
trace._straces

In [None]:
# Last sample of chain 0.
trace._straces[0][-1]

In [None]:
# Array form (model.dict_to_array) of the chain-0 sample at index -25.
model.dict_to_array(trace._straces[0][-25])

In [None]:
# Average the array form of the chain-0 samples at indices 0 .. len(trace[:-25]) - 1,
# then convert the mean vector back to a variable dictionary.
model.array_to_dict(np.asarray([model.dict_to_array(trace._straces[0][i]) for i in range(len(trace[:-25]))]).mean(axis=0))

In [None]:
# model_a values of the trace after applying the [:-25] slice.
trace[:-25]["model_a"]

In [None]:
# NOTE: this rebinds `start`, replacing the list of per-chain means built earlier.
# NOTE(review): _init_jitter is a private pymc3 helper; the meaning of the positional
# arguments 4 and 10 is not visible here - confirm against the installed pymc3 version.
start = _init_jitter(model, 4, 10)

In [None]:
# Inspect the start points returned by _init_jitter.
start

In [None]:
# Shape of the element-wise mean over the array forms of all start points.
np.mean([model.dict_to_array(vals) for vals in start], axis=0).shape

In [None]:

# Covariance estimate from the full loaded trace, evaluated in the model context.
with model:
    cov_est = pm.trace_cov(trace)

In [None]:
# List the per-chain trace objects (the values of the trace._straces mapping).
[val for _, val in trace._straces.items()]

In [None]:
# Sorted keys of the trace._straces mapping.
sorted(trace._straces)

In [None]:
# Shape of the full-trace covariance estimate.
cov_est.shape

In [None]:
# Build a diagonal adapting quadpotential from the trace statistics.
with model:
    # Size of the first axis of the covariance estimate.
    n = cov_est.shape[0]
    # Array form of the trace point at index -25.
    mean = model.dict_to_array(trace[-25])
    # Diagonal of the covariance estimate.
    var = np.diag(cov_est)
    # NOTE(review): positional arguments are (n, mean, var, 1); the meaning of the
    # trailing 1 is not visible here - confirm against the QuadPotentialDiagAdapt signature.
    potential = quadpotential.QuadPotentialDiagAdapt(n, mean, var, 1)

In [None]:
# Mean array form of the trace points at indices 0 .. len(trace[:-25]) - 1.
np.asarray([model.dict_to_array(trace[i]) for i in range(len(trace[:-25]))]).mean(axis=0)

In [None]:
# Mean array form of the trace points at indices 0 .. len(trace[:]) - 1.
np.asarray([model.dict_to_array(trace[i]) for i in range(len(trace[:]))]).mean(axis=0)

In [None]:
# Mean "step_size_bar" over trace[:-25], scaled by n ** (1/4); n is set in the potential cell.
step_scale = trace[:-25].get_sampler_stats("step_size_bar").mean() * (n ** (1/4))

In [None]:
# Multiply by (1/n) ** (1/4), which undoes the n ** (1/4) scaling applied to step_scale.
step_scale * ((1/n) ** (1/4))

In [None]:
# Inspect the quadpotential object built above.
potential

In [None]:
# Values of every variable in the trace point at index -25.
[val for var, val in trace[-25].items()]

In [None]:
# Names of every variable in the trace point at index -25.
[var for var, val in trace[-25].items()]

In [None]:
# Indices of the nonzero entries of the "diverging" sampler stat.
trace.get_sampler_stats("diverging").nonzero()

In [None]:
# Diagonal of the full-trace covariance estimate.
np.diag(cov_est)