In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import json
import os
import pandas as pd

from tvb_inversion.base.diagnostics import (zscore, shrinkage)

In [None]:
# Choose which SBI run to analyse. Runs are tagged by timestamp; the trailing
# comment records whether that run's observation was around or shifted from
# the ground truth.
# run_id = "2023-02-02_1238"  # around ground truth
# run_id = "2023-02-02_1319"  # around ground truth
# run_id = "2023-02-02_1807"  # around ground truth
# run_id = "2023-02-03_1318"  # around ground truth
# run_id = "2023-02-03_1326"  # around ground truth
# run_id = "2023-02-03_1351"  # shifted from ground truth
# run_id = "2023-02-03_1427"  # shifted from ground truth
run_id = "2023-02-08_1402"  # around ground truth
# run_id = "2023-02-08_1752"  # shifted from ground truth


def _load_npy(stem):
    """Load ``sbi_data/<stem>_<run_id>.npy`` for the currently selected run."""
    return np.load(f"sbi_data/{stem}_{run_id}.npy")


def _load_json(stem):
    """Parse ``sbi_data/<stem>_<run_id>.json`` for the currently selected run."""
    with open(f"sbi_data/{stem}_{run_id}.json", "r") as handle:
        return json.load(handle)


# Arrays saved by the inference run.
posterior_samples = _load_npy("posterior_samples")
prior_samples = _load_npy("prior_samples")
test_samples = _load_npy("test_samples")
X = _load_npy("simulation")
training_sims = _load_npy("training_sims")
test_sims = _load_npy("test_sims")

# Run configuration and training summary.
simulation_params = _load_json("sim_params")
inference_params = _load_json("inference_params")
summary = _load_json("summary")

In [None]:
# Display the inference configuration loaded from sbi_data/inference_params_<run_id>.json.
inference_params

In [None]:
# Sanity check: shapes of the training/test simulations and the prior/test parameter samples.
(
    training_sims.shape,
    test_sims.shape,
    prior_samples.shape,
    test_samples.shape,
)

In [None]:
# Per-epoch training and validation log-probability recorded in the run summary.
f1, ax_loss = plt.subplots(figsize=(12, 8))
for summary_key, curve_label in (
    ("training_log_probs", "training"),
    ("validation_log_probs", "validation"),
):
    ax_loss.plot(summary[summary_key], label=curve_label)
ax_loss.set_xlabel("epochs", size=16)
ax_loss.set_ylabel("log probability", size=16)
ax_loss.tick_params(axis="both", labelsize=16)
ax_loss.legend(fontsize=16)
plt.show()

In [None]:
# Marginal posteriors pooled over all test simulations and posterior draws,
# drawn as violins with the ground-truth value as a dashed reference line.
# NOTE(review): assumes posterior_samples is indexed [test_sim, draw, param] and
# that columns 0/1 are model_a/nsig — confirm against inference_params.
data = {
    "model_a": posterior_samples[:, :, 0].flatten(),
    "nsig": posterior_samples[:, :, 1].flatten(),
    # "measurement_noise": posterior_samples[:, :, -1].flatten(),
}
df = pd.DataFrame(data)

# Size the grid from `data` so (un)commenting a parameter above needs no other
# change; squeeze=False keeps axes3 2-D even with a single parameter, so the
# flat indexing below always works.
n_panels = len(data)
f3, axes3 = plt.subplots(ncols=1, nrows=n_panels, figsize=(10, 7.5 * n_panels), squeeze=False)
for i, key in enumerate(data):
    ax = axes3.reshape(-1)[i]
    # NOTE(review): `bw` is deprecated in newer seaborn (>= 0.13) in favour of
    # `bw_method`; kept here for compatibility with older installs.
    sns.violinplot(y=df[key], bw=.1, ax=ax)
    plt.setp(ax.collections, alpha=.5)
    # Parameters absent from simulation_params get a reference line at 0.0
    # (presumably a noise term whose true value is zero — confirm).
    ax.axhline(
        simulation_params.get(key, 0.0),
        linestyle="--",
        linewidth=2,
        color="black",
        # Only the first panel contributes a legend entry.
        label="ground truth" if i == 0 else "_nolegend_",
    )
    if i == 0:
        ax.legend(fontsize=16)
    ax.tick_params(axis="both", labelsize=16)
    ax.set_ylabel(key, size=16)
plt.show()

# num_params = len(inference_params)
# nrows = int(np.ceil(np.sqrt(num_params)))
# ncols = int(np.ceil(num_params / nrows))
#
# fig1, axes1 = plt.subplots(nrows=nrows, ncols=ncols, figsize=(20, 16))
# for ax in axes1.reshape(-1):
#     ax.set_axis_off()
# for i, (key, value) in enumerate(inference_params.items()):
#     posterior_ = posterior_samples[:, :, i]
#     posterior_ = posterior_.flatten()
#     ax = axes1.reshape(-1)[i]
#     ax.set_axis_on()
#     ax.hist(posterior_, bins=100, alpha=0.5)
#     ax.axvline(simulation_params[key], color="r", label="simulation parameter")
#     ax.axvline(posterior_.mean(), color="k", label="posterior mean")
#     # ax.axvline(prior_samples[:, i].mean(), color="k", linestyle="-.", label="prior mean")
#
#     ax.set_title(key, fontsize=18)
#     ax.tick_params(axis="both", labelsize=16)
# try:
#     axes1[0, 0].legend(fontsize=18)
# except IndexError:
#     axes1[0].legend(fontsize=18)
# plt.show()

In [None]:
#ground_truth = np.asarray([value for _, value in simulation_params.items()])
#posterior_mean = posterior_samples.reshape((-1, *posterior_samples.shape[2:])).mean(axis=0)
#posterior_std = posterior_samples.reshape((-1, *posterior_samples.shape[2:])).std(axis=0)
#prior_std = prior_samples.std(axis=0)

In [None]:
# Per-test-simulation posterior summaries: reduce over the posterior-draw axis (axis 1).
draw_axis = 1
posterior_mean = np.mean(posterior_samples, axis=draw_axis)
posterior_std = np.std(posterior_samples, axis=draw_axis)

# Prior std is approximated as half of each simulation parameter's value, repeated
# once per test sample.
# NOTE(review): presumably this mirrors how the prior was constructed — confirm.
# Column order follows simulation_params' insertion order and must line up with the
# parameter axis of posterior_samples; a mismatch in count or order silently skews
# the shrinkage computed below.
sim_param_values = np.asarray(list(simulation_params.values()))
prior_std = np.tile(sim_param_values / 2., (len(test_samples), 1))

In [None]:
# Average the posterior means across test simulations (axis 0).
np.mean(posterior_mean, axis=0)

In [None]:
# Calibration diagnostics from tvb_inversion: posterior z-score of each test sample
# and posterior shrinkage relative to the prior std.
zscores, shrinkages = (
    zscore(test_samples, posterior_mean, posterior_std),
    shrinkage(prior_std, posterior_std),
)

In [None]:
# Posterior z-score vs. shrinkage: faint dots are individual test simulations,
# stars are the averages over all test simulations for each parameter.
# NOTE(review): column order of zscores/shrinkages is assumed to follow the
# posterior parameter axis (0 = model_a, 1 = nsig) — confirm.
param_styles = (("model_a", "blue"), ("nsig", "green"))
mean_shrinkage = shrinkages.mean(axis=0)
mean_zscore = zscores.mean(axis=0)

fig2, ax_diag = plt.subplots(figsize=(12, 8))
for col, (param_name, color) in enumerate(param_styles):
    ax_diag.plot(shrinkages[:, col], zscores[:, col], color=color, linewidth=0,
                 marker=".", markersize=8, alpha=0.1)
    ax_diag.plot(mean_shrinkage[col], mean_zscore[col], color=color, linewidth=0,
                 marker="*", markersize=12, label=param_name)
ax_diag.set_title("posterior zscore vs. shrinkage", size=16)
ax_diag.set_xlabel("posterior shrinkage", size=16)
ax_diag.set_ylabel("posterior zscore", size=16)
ax_diag.legend(fontsize=16)
ax_diag.tick_params(axis="both", labelsize=16)
plt.show()

In [None]:
def sbi_artifact_paths(run_id, data_dir="sbi_data"):
    """Return the paths of every file saved for one SBI run.

    Covers the same artifacts the loading cell reads (including the
    test_samples array, which the previous cleanup list missed).
    """
    npy_stems = ("posterior_samples", "prior_samples", "test_samples",
                 "simulation", "training_sims", "test_sims")
    json_stems = ("sim_params", "inference_params", "summary")
    return ([f"{data_dir}/{stem}_{run_id}.npy" for stem in npy_stems]
            + [f"{data_dir}/{stem}_{run_id}.json" for stem in json_stems])


# Safety switch: set to True to permanently delete all files of the selected run.
delete = False
if delete:
    for artifact_path in sbi_artifact_paths(run_id):
        os.remove(artifact_path)