In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import json
import os
import pandas as pd

from tvb_inversion.base.diagnostics import (zscore, shrinkage)
from tvb_inversion.sbi.plot import (
    plot_posterior_samples_model_parameters,
    plot_posterior_samples_global_parameters
)

%load_ext autoreload
%autoreload 2

In [None]:
# Pick the saved SBI run to analyse; earlier runs are kept here for quick switching.
# run_id = "2023-02-08_1831"  # around ground truth
# run_id = "2023-02-08_1859"  # shifted from ground truth
# run_id = "2023-02-09_1401"  # shifted from ground truth
run_id = "2023-02-24_1844"  # around ground truth


def _read_json(path):
    """Return the parsed contents of one JSON file saved for this run."""
    with open(path, "r") as fh:
        return json.load(fh)


# NumPy arrays written by the SBI run.
posterior_samples = np.load(f"sbi_data/posterior_samples_{run_id}.npy")
prior_samples = np.load(f"sbi_data/prior_samples_{run_id}.npy")
test_samples = np.load(f"sbi_data/test_samples_{run_id}.npy")
X = np.load(f"sbi_data/simulation_{run_id}.npy")
training_sims = np.load(f"sbi_data/training_sims_{run_id}.npy")
test_sims = np.load(f"sbi_data/test_sims_{run_id}.npy")

# JSON metadata: simulation settings, inference settings and training summary.
simulation_params = _read_json(f"sbi_data/sim_params_{run_id}.json")
inference_params = _read_json(f"sbi_data/inference_params_{run_id}.json")
summary = _read_json(f"sbi_data/summary_{run_id}.json")

In [None]:
# Shapes of the loaded arrays, in the order: training sims, test sims, prior draws, test parameters.
tuple(arr.shape for arr in (training_sims, test_sims, prior_samples, test_samples))

In [None]:
# Display the inference settings loaded for this run (includes the per-region "model_a" entry used below).
inference_params

In [None]:
# Training vs. validation log-probability per epoch, as recorded in the run summary.
f1 = plt.figure(figsize=(12, 8))
ax_curve = f1.add_subplot()
for summary_key, curve_label in (("training_log_probs", "training"), ("validation_log_probs", "validation")):
    ax_curve.plot(summary[summary_key], label=curve_label)
ax_curve.set_xlabel("epochs", size=16)
ax_curve.set_ylabel("log probability", size=16)
ax_curve.tick_params(axis="both", labelsize=16)
ax_curve.legend(fontsize=16)
plt.show()

In [None]:
n_model_a = len(inference_params["model_a"])
# Pool the draws of every test observation, keeping only the leading
# n_model_a entries of the last axis (the per-region model_a parameters).
model_a_draws = posterior_samples[:, :, :n_model_a].reshape(-1, n_model_a)
df = pd.DataFrame(
    data=model_a_draws,
    columns=[f"Region {i+1}" for i in range(n_model_a)],
)

plot_posterior_samples_model_parameters(df, simulation_params["model_a"])

In [None]:
from scipy.stats import spearmanr

n_model_a = len(inference_params["model_a"])
# Posterior mean of each region's model_a, pooled over all test observations.
pooled_model_a = posterior_samples[:, :, :n_model_a].reshape(-1, n_model_a)
posterior_model_a = pooled_model_a.mean(axis=0)
# Ground-truth per-region model_a values from the simulation settings.
gt_model_a = np.array(list(simulation_params["model_a"]))

# Rank correlation between ground truth and posterior means across regions.
spearman_correlation = spearmanr(gt_model_a, posterior_model_a)
rho, p_value = spearman_correlation.correlation, spearman_correlation.pvalue
print(rho, p_value)

In [None]:
# Posterior mean vs. ground truth of model_a, one star per region.
fig_spearman = plt.figure(figsize=(8, 5))
plt.plot(posterior_model_a, gt_model_a, linewidth=0, marker="*", markersize=12, color="blue")
# y = x reference: points on this line mean the posterior mean matches the ground truth.
ref_lo = min(np.min(posterior_model_a), np.min(gt_model_a))
ref_hi = max(np.max(posterior_model_a), np.max(gt_model_a))
plt.plot([ref_lo, ref_hi], [ref_lo, ref_hi], linestyle="--", color="gray", linewidth=1)
plt.xlabel("posterior mean model_a", size=16)
plt.ylabel("ground-truth model_a", size=16)
plt.tick_params(axis="both", labelsize=16)
plt.show()

In [None]:
# data = {
#     f"model_a_R{i+1}": posterior_samples[:, :, :len(inference_params["model_a"])].reshape(-1, len(inference_params["model_a"]))[:, i] for i in range(len(inference_params["model_a"]))
# }
# data["coupling_a"] = posterior_samples[:, :, -2].flatten()
# df = pd.DataFrame(data)
#
# # f2 = plt.figure(figsize=(15, 10))
# with sns.plotting_context(rc={"axes.labelsize":20}):
#     ax2 = sns.pairplot(data=df, kind="hist", height=5)  # , y_vars=["coupling_a"], x_vars=[k for k, _ in data.items() if "model_a" in k])
# ax2.tick_params(axis="both", labelsize=20)
# plt.show()

In [None]:
# All posterior draws of the two global parameters, which occupy the last two
# entries of the parameter axis: coupling_a (-2) and nsig (-1).
coupling_draws = posterior_samples[:, :, -2].flatten()
nsig_draws = posterior_samples[:, :, -1].flatten()
data = dict(coupling_a=coupling_draws, nsig=nsig_draws)
df = pd.DataFrame(data)

plot_posterior_samples_global_parameters(df, simulation_params)

In [None]:
#true_mean = np.array(
#    [v for v in simulation_params["model_a"]] + [simulation_params["coupling_a"], simulation_params["nsig"]])
#posterior_mean = posterior_samples.reshape((-1, *posterior_samples.shape[2:])).mean(axis=0)
#posterior_std = posterior_samples.reshape((-1, *posterior_samples.shape[2:])).std(axis=0)
#prior_std = prior_samples.std(axis=0)

In [None]:
# Posterior mean/std per test observation, reduced over the sample axis (axis 1).
posterior_mean = posterior_samples.mean(axis=1)
posterior_std = posterior_samples.std(axis=1)

# Ground-truth parameter vector: per-region model_a, then coupling_a, then nsig.
true_params = np.array(
    list(simulation_params["model_a"]) + [simulation_params["coupling_a"], simulation_params["nsig"]]
)
# NOTE(review): prior std is approximated as half of each ground-truth value
# (one row per test observation) instead of being estimated from prior_samples;
# confirm this matches the prior actually used for inference.
prior_std = np.tile(true_params / 2., (len(test_samples), 1))

In [None]:
# Average of the per-observation posterior means, one value per parameter.
np.mean(posterior_mean, axis=0)

In [None]:
# Per-test-observation diagnostics (indexed below as [observation, parameter]):
# posterior z-scores of the test parameters and posterior shrinkage relative to
# prior_std. The exact formulas live in tvb_inversion.base.diagnostics — TODO confirm.
zscores = zscore(test_samples, posterior_mean, posterior_std)
shrinkages = shrinkage(prior_std, posterior_std)

In [None]:
# Calibration diagnostic: posterior shrinkage (x) vs. posterior z-score (y).
# Faint dots are individual test observations; stars are per-parameter means.
n_regions = len(simulation_params["model_a"])
mean_shrinkage = shrinkages.mean(axis=0)
mean_zscore = zscores.mean(axis=0)

f4 = plt.figure(figsize=(12, 8))
for i in range(n_regions):
    # Same shades as before for <= 10 regions, but the green channel never
    # exceeds 1.0 (invalid RGBA) when there are more regions.
    region_color = (0, i / max(10.0, n_regions), 1, 1)
    plt.plot(shrinkages[:, i], zscores[:, i], color=region_color, linewidth=0, marker=".", markersize=8, alpha=0.1)
    # Star = mean over test observations of column i (previously row i was
    # plotted, i.e. one observation's whole parameter vector).
    plt.plot(mean_shrinkage[i], mean_zscore[i], color=region_color, linewidth=0, marker="*", markersize=12, label=f"model_a_R{i+1}")
plt.plot(shrinkages[:, -2], zscores[:, -2], color="red", linewidth=0, marker=".", markersize=8, alpha=0.1)
plt.plot(mean_shrinkage[-2], mean_zscore[-2], color="red", linewidth=0, marker="*", markersize=12, label="coupling_a")
plt.plot(shrinkages[:, -1], zscores[:, -1], color="green", linewidth=0, marker=".", markersize=8, alpha=0.1)
# nsig is the last column (-1); previously coupling_a's mean (-2) was re-plotted here.
plt.plot(mean_shrinkage[-1], mean_zscore[-1], color="green", linewidth=0, marker="*", markersize=12, label="nsig")
plt.xlabel("posterior shrinkage", size=16)
plt.ylabel("posterior z-score", size=16)
plt.legend(fontsize=16)
plt.tick_params(axis="both", labelsize=16)
plt.show()

In [None]:
# Set to True to permanently delete every file saved for `run_id`.
delete = False
if delete:
    # Every artifact loaded at the top of the notebook, including test_samples,
    # which the previous version forgot and left on disk.
    run_files = [
        f"sbi_data/posterior_samples_{run_id}.npy",
        f"sbi_data/prior_samples_{run_id}.npy",
        f"sbi_data/test_samples_{run_id}.npy",
        f"sbi_data/simulation_{run_id}.npy",
        f"sbi_data/training_sims_{run_id}.npy",
        f"sbi_data/test_sims_{run_id}.npy",
        f"sbi_data/sim_params_{run_id}.json",
        f"sbi_data/inference_params_{run_id}.json",
        f"sbi_data/summary_{run_id}.json",
    ]
    for path in run_files:
        os.remove(path)