In [None]:
import pymc as pm
import numpy as np
import pytensor
import pytensor.tensor as pyt
import arviz as az
import matplotlib.pyplot as plt
from datetime import datetime, timedelta

from tvb.simulator.lab import *
from tvb.simulator.backend.pytensor import PytensorBackend

from tvb_inversion.base.observation_models import linear
from tvb_inversion.pymc.prior import PymcPrior
from tvb_inversion.pymc.stats_model import PymcModel
from tvb_inversion.pymc.inference import EstimatorPYMC
from tvb_inversion.pymc.plot import plot_posterior_samples

%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
# Two-region toy connectome: mutual unit weights and a tract length of 2 between the regions.
conn = connectivity.Connectivity()
conn.weights = np.array([[0., 1.], [1., 0.]])
conn.region_labels = np.array(["R1", "R2"])
# Centres are drawn at random; tract_lengths are set explicitly below, so the centres are
# presumably not used for delays here. Seed them anyway so reruns build an identical connectivity.
conn.centres = np.random.default_rng(42).random((2, 3))
conn.tract_lengths = np.array([[0., 2.], [2., 0.]])
conn.configure()

# Generic 2D oscillator with a different `a` per region, linear coupling and additive
# noise (fixed seed) integrated with stochastic Euler; the Raw monitor records every step.
sim = simulator.Simulator(
    model=models.oscillator.Generic2dOscillator(a=np.array([0.75, 2.25])),
    connectivity=conn,
    coupling=coupling.Linear(),
    integrator=integrators.EulerStochastic(
        dt=1.0,
        noise=noise.Additive(
            nsig=np.array([1e-4]),
            noise_seed=42
        )
    ),
    monitors=[monitors.Raw()],
    simulation_length=250
)

In [None]:
# First configure pass; presumably needed so derived attributes such as conn.horizon
# (used in the next cell) are populated. The ';' suppresses the Simulator repr.
sim.configure();

In [None]:
# Start every state variable of every region at zero over the whole history horizon
# (trailing axis of length 1).
initial_shape = (conn.horizon, sim.model.nvar, conn.number_of_regions, 1)
sim.initial_conditions = np.zeros(initial_shape)

In [None]:
# Re-configure so the simulator picks up the initial conditions assigned above;
# the trailing ';' suppresses the Simulator repr, as in the first configure cell.
sim.configure();

In [None]:
# A single monitor (Raw) is attached, so run() yields exactly one (time, data) pair.
[(t, X)] = sim.run()

In [None]:
# Shape of the recorded data; later cells index it as X[time, var, region, 0]
# (assumes axes are time, state variable, region, mode — TODO confirm).
X.shape

In [None]:
# Simulated first recorded variable for both regions.
f1, ax = plt.subplots(figsize=(10, 6))
for region_idx, region_name in enumerate(["Region 1", "Region 2"]):
    ax.plot(X[:, 0, region_idx, 0], label=region_name)
ax.set_ylabel("V (a.u.)", fontsize=16)
ax.set_xlabel("time (ms)", fontsize=16)
ax.legend(fontsize=16)
ax.tick_params(axis="both", labelsize=16)
plt.show()

In [None]:
# Render the PyTensor simulation template for this simulator and build the `kernel`
# and `default_noise` functions from it. Model and coupling parameter "a" are listed in
# mparams/cparams (presumably so they become symbolic inputs — confirm with the template).
template = '<%include file="pytensor-sim.py.mako"/>'
content = {
    "sim": sim,
    "mparams": ["a"],
    "cparams": ["a"],
    "np": np,
    "pyt": pyt,
}
backend = PytensorBackend()
kernel, default_noise = backend.build_py_func(
    template,
    content,
    name="kernel,default_noise",
    print_source=True,
)

In [None]:
# Scale of the relative perturbations used by the non-centred priors below.
def_std = 0.5

model = pm.Model()
with model:
    # Model parameter `a`: additive N(0, 0.75^2) perturbation around the simulated values.
    model_a_star = pm.Normal(name="model_a_star", mu=0.0, sigma=1.0, shape=sim.model.a.shape)
    model_a = pm.Deterministic(name="model_a", var=sim.model.a + 0.75 * model_a_star)

    # Coupling parameter `a`: relative perturbation around the simulated value.
    coupling_a_star = pm.Normal(name="coupling_a_star", mu=0.0, sigma=1.0)
    coupling_a = pm.Deterministic(name="coupling_a",
                                  var=sim.coupling.a[0].item() * (1.0 + def_std * coupling_a_star))

    # Initial state: relative perturbation of sim.initial_conditions (trailing mode axis dropped).
    # NOTE(review): the initial conditions were set with np.zeros(...) before configuring, so this
    # product appears to be identically zero and x_init_star cannot be informed by the data —
    # confirm whether an additive perturbation was intended.
    x_init_star = pm.Normal(name="x_init_star", mu=0.0, sigma=1.0,
                            shape=sim.initial_conditions.shape[:-1])
    x_init = pm.Deterministic(name="x_init",
                              var=sim.initial_conditions[:, :, :, 0] * (1.0 + def_std * x_init_star))

    # Noise intensity: log-normal around the simulated value so nsig is always > 0.
    # A linear form nsig0 * (1 + def_std * nsig_star) turns negative for nsig_star < -1/def_std,
    # which makes sqrt(2 * nsig * dt) below NaN and the log-probability undefined.
    # nsig_star = 0 still maps exactly to the simulated nsig.
    nsig_star = pm.Normal(name="nsig_star", mu=0.0, sigma=1.0)
    nsig = pm.Deterministic(name="nsig",
                            var=sim.integrator.noise.nsig[0].item() * pyt.exp(def_std * nsig_star))

    # Noise increments, one per recorded time step, state variable and region,
    # with standard deviation sqrt(2 * nsig * dt).
    dWt_star = pm.Normal(name="dWt_star", mu=0.0, sigma=1.0,
                         shape=(X.shape[0], sim.model.nvar, sim.connectivity.number_of_regions))
    dWt = pm.Deterministic(
            name="dWt", var=pyt.sqrt(2.0 * nsig * sim.integrator.dt) * dWt_star)

    # Linear observation model: gain around 1 and offset around 0.
    amplitude_star = pm.Normal(name="amplitude_star", mu=0.0, sigma=1.0)
    amplitude = pm.Deterministic(name="amplitude", var=1.0 * (1.0 + def_std * amplitude_star))

    offset_star = pm.Normal(name="offset_star", mu=0.0, sigma=1.0)
    offset = pm.Deterministic(name="offset", var=def_std * offset_star)

    # Observation noise scale: def_std * HalfNormal(1), hence strictly positive.
    observation_noise_star = pm.HalfNormal(name="observation_noise_star", sigma=1.0)
    observation_noise = pm.Deterministic(name="observation_noise", var=def_std * observation_noise_star)
    

In [None]:
# Pair each simulator / observation-model parameter path with the PyMC variable modelling it.
prior_spec = [
    ("model.a", model_a),
    ("coupling.a", coupling_a),
    ("x_init", x_init),
    ("integrator.noise.nsig", nsig),
    ("dWt_star", dWt_star),
    ("observation.amplitude", amplitude),
    ("observation.offset", offset),
    ("observation_noise", observation_noise),
]
prior = PymcPrior(
    model=model,
    names=[param_name for param_name, _ in prior_spec],
    dist=[param_var for _, param_var in prior_spec],
)

In [None]:
# Wrap the simulator together with the prior defined above into a statistical model object.
pymc_model = PymcModel(sim=sim, params=prior)

In [None]:
# Symbolically run the compiled kernel inside the model context, driven by the prior's
# initial state (x_init), noise increments (dWt) and model/coupling parameters.
# The trace buffer has one entry per simulated time point.
n_steps = len(t)
state_shape = sim.initial_conditions[:, :, :, 0].shape

with pymc_model.model:
    x_sim = kernel(
        state=x_init,
        weights=sim.connectivity.weights,
        trace=pyt.zeros((n_steps, *state_shape)),
        parmat=sim.model.spatial_parameter_matrix,
        noise=dWt,
        idelays=sim.connectivity.delay_indices,
        mparams=prior.get_model_params(),
        cparams=prior.get_coupling_params(),
    )

In [None]:
with pymc_model.model:
    # Apply the linear observation model (amplitude/offset from the prior) to the simulated trace.
    obs_params = prior.get_observation_model_params()
    x_hat = pm.Deterministic(name="x_hat", var=linear(x_sim, **obs_params))

    # Gaussian likelihood of the recorded data around x_hat indexed by sim.model.cvar,
    # falling back to unit noise if the prior has no "observation_noise" entry.
    x_obs = pm.Normal(
        name="x_obs",
        mu=x_hat[:, sim.model.cvar, 0, :],
        sigma=prior.dict.get("observation_noise", 1.0),
        shape=X.shape[:-1],
        observed=X[:, :, :, 0],
    )

In [None]:
# Display the assembled PyMC model (random variables, deterministics, likelihood).
pymc_model.model

In [None]:
# Estimator bound to the statistical model; sampling is launched by run_inference further below.
pymc_estimator = EstimatorPYMC(stats_model=pymc_model)

In [None]:
# Sampler budget for a quick run; a full run uses draws=500, tune=500, cores=4.
draws, tune, cores = 250, 250, 2

In [None]:
# NUTS settings; a higher target_accept trades speed for fewer divergences.
sampler_settings = dict(draws=draws, tune=tune, cores=cores, target_accept=0.9)
inference_data, inference_summary = pymc_estimator.run_inference(**sampler_settings)

In [None]:
# Sampler health report: chain/draw counts, divergences, acceptance, tree depth, wall time.
sample_stats = inference_data.sample_stats
chains = len(sample_stats.chain)
draws = len(sample_stats.draw)
print("chains:", chains)
print("draws:", draws)

divergent = sample_stats["diverging"].values
n_divergent = np.count_nonzero(divergent)
print("Number of Divergent: %d" % n_divergent)
divperc = n_divergent / (chains * draws) * 100
print("Percentage of Divergent: %.1f" % divperc)

print("Mean tree accept: %.2f" % sample_stats.acceptance_rate.values.mean())
print("Mean tree depth: %.2f" % sample_stats.tree_depth.values.mean())
print("Sampling time:", str(timedelta(seconds=sample_stats.sampling_time)))

In [None]:
# Reference (ground-truth) values on the natural parameter scale, overlaid on the posteriors.
reference_values = [
    ("model_a[0]", sim.model.a[0]),
    ("model_a[1]", sim.model.a[1]),
    ("coupling_a", sim.coupling.a[0]),
    ("nsig", sim.integrator.noise.nsig[0]),
    ("observation_noise", 0.0),
    ("amplitude", 1.0),
    ("offset", 0.0),
]
init_params = dict(reference_values)
plot_posterior_samples(pymc_estimator.inference_data, init_params)

In [None]:
# On the standardized (*_star) scale the ground truth is 0 for every parameter.
init_params = dict.fromkeys(
    [
        "model_a_star[0]",
        "model_a_star[1]",
        "coupling_a_star",
        "nsig_star",
        "observation_noise_star",
        "amplitude_star",
        "offset_star",
    ],
    0.0,
)
plot_posterior_samples(pymc_estimator.inference_data, init_params)

In [None]:
# Posterior-predictive draws of x_obs, flattened over (chain, draw) into one sample axis.
# The sample count is read from the data itself: the number of chains is not necessarily
# equal to `cores`, so `cores * draws` could mis-shape (or fail to reshape) the array.
x_obs_pp = pymc_estimator.inference_data.posterior_predictive.x_obs
n_pp_samples = x_obs_pp.sizes["chain"] * x_obs_pp.sizes["draw"]
posterior_x_obs = x_obs_pp.values.reshape((n_pp_samples, *X.shape[:-1]))

In [None]:
# 95% posterior-predictive band vs. observed data, one panel per region.
# Both panels show the first recorded variable (index 0 on the variable axis); the second panel
# is region 2, not the W variable as it was previously labelled. The band's two bounding lines
# carry a single legend entry.
f3, axes3 = plt.subplots(nrows=2, ncols=1, figsize=(18, 15))
for region_idx, ax in enumerate(axes3):
    region_name = f"Region {region_idx + 1}"
    lower, upper = np.percentile(posterior_x_obs[:, :, 0, region_idx], [2.5, 97.5], axis=0)
    ax.plot(lower, "k", label=rf"$V_{{95\% PP}}(t)$, {region_name}")
    ax.plot(upper, "k")
    ax.plot(X[:, 0, region_idx, 0], label=f"V_observed, {region_name}")
    ax.legend(fontsize=16)
    ax.set_xlabel("time (ms)", fontsize=16)
    ax.set_ylabel("V (a.u.)", fontsize=16)
    ax.tick_params(axis="both", labelsize=16)

plt.show()

In [None]:
# Summary statistics of the standardized (*_star) scalar parameters.
star_summary_rows = [f"model_a_star[{i}]" for i in range(len(sim.model.a))]
star_summary_rows += ["coupling_a_star", "nsig_star", "amplitude_star", "offset_star", "observation_noise_star"]
pymc_estimator.inference_summary.loc[star_summary_rows]

In [None]:
# Summary statistics of the same parameters on their natural scale.
param_summary_rows = [f"model_a[{i}]" for i in range(len(sim.model.a))]
param_summary_rows += ["coupling_a", "nsig", "amplitude", "offset", "observation_noise"]
pymc_estimator.inference_summary.loc[param_summary_rows]

In [None]:
# Aggregate the per-element summary rows of the initial-state parameters.
summary_df = pymc_estimator.inference_summary
x_init_rows = summary_df.index.str.contains("x_init_star")
summary_df.loc[x_init_rows].describe()

In [None]:
# Aggregate the per-element summary rows of the noise-increment parameters.
summary_df = pymc_estimator.inference_summary
dwt_rows = summary_df.index.str.contains("dWt_star")
summary_df.loc[dwt_rows].describe()

In [None]:
# information criteria
# waic, loo = pymc_estimator.information_criteria()

In [None]:
# Posterior z-scores of the natural-scale parameters with respect to the ground truth.
natural_vars = ["model_a", "coupling_a", "nsig", "amplitude", "offset", "observation_noise"]
posterior_mean = pymc_estimator.get_posterior_mean(natural_vars)
posterior_std = pymc_estimator.get_posterior_std(natural_vars)

# Same order as natural_vars: a[0], a[1], coupling a, nsig, amplitude, offset, observation noise.
ground_truth = np.array([
    sim.model.a[0],
    sim.model.a[1],
    sim.coupling.a[0],
    sim.integrator.noise.nsig[0],
    1.0,
    0.0,
    0.0,
])

zscores = pymc_estimator.compute_zscore(ground_truth, posterior_mean, posterior_std)

In [None]:
# Posterior shrinkage of the standardized (*_star) parameters: posterior std compared
# with the std of the prior each one was drawn from.
shrinkage_vars = ["model_a_star", "coupling_a_star", "nsig_star",
                  "amplitude_star", "offset_star", "observation_noise_star"]
# Separate name so the natural-scale posterior_std from the z-score cell is not clobbered.
posterior_std_star = pymc_estimator.get_posterior_std(shrinkage_vars)

# All *_star priors are Normal(0, 1) except observation_noise_star ~ HalfNormal(sigma=1),
# whose std is sqrt(1 - 2/pi) ~= 0.603, not 1. model_a_star contributes one entry per
# element of sim.model.a; coupling/nsig/amplitude/offset contribute one each.
n_standard_normal = len(sim.model.a) + 4
halfnormal_std = np.sqrt(1.0 - 2.0 / np.pi)
prior_std = np.append(np.ones(n_standard_normal), halfnormal_std)

shrinkages = pymc_estimator.compute_shrinkage(prior_std, posterior_std_star)

In [None]:
# Shrinkage vs. z-score for the dynamical parameters (entries 0-3 of the vectors above).
point_groups = [
    (slice(0, 2), "blue", "model_a"),
    (2, "red", "coupling_a"),
    (3, "green", "nsig"),
]
f3, ax = plt.subplots(figsize=(12, 8))
for idx, color, label in point_groups:
    ax.plot(shrinkages[idx], zscores[idx], color=color, linewidth=0,
            marker="*", markersize=12, label=label)
ax.set_xlabel("posterior shrinkage")
ax.set_ylabel("posterior zscore")
ax.legend()
plt.show()

In [None]:
# Persist posterior samples (NetCDF, uncompressed) and the summary table (JSON)
# under a minute-resolution timestamp prefix.
run_id = datetime.now().strftime("%Y-%m-%d_%H%M")
netcdf_path = f"{run_id}_2nodes_test.nc"
summary_path = f"{run_id}_2nodes_test.json"

pymc_estimator.inference_data.to_netcdf(filename=netcdf_path, compress=False)
pymc_estimator.inference_summary.to_json(path_or_buf=summary_path)