In [None]:
import os

import pymc3 as pm
import numpy as np
import theano
import theano.tensor as tt
import arviz as az
import matplotlib.pyplot as plt

from tvb.simulator.lab import *
from tvb_inversion.pymc3.prior import Pymc3Prior
from tvb_inversion.pymc3.stats_model import Pymc3Model
from tvb_inversion.pymc3.inference import EstimatorPYMC, plot_posterior_samples

%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
# Minimal two-region connectome ("R1", "R2") with symmetric weights and tract lengths,
# built directly from arrays instead of loading a default connectivity from file.
conn = connectivity.Connectivity(
    weights=np.array([[0., 2.], [2., 0.]]),
    region_labels=np.array(["R1", "R2"]),
    centres=np.array([[0.1, 0.1, 0.1], [0.2, 0.1, 0.1]]),
    tract_lengths=np.array([[0., 2.5], [2.5, 0.]]),
)
conn.configure()

# Generic 2D oscillator nodes with difference coupling, integrated by stochastic Heun
# with additive noise (fixed seed so the "observed" data is reproducible), recorded
# by a Raw monitor.
sim = simulator.Simulator(
    connectivity=conn,
    model=models.oscillator.Generic2dOscillator(),
    coupling=coupling.Difference(),
    integrator=integrators.HeunStochastic(
        dt=1.0,
        noise=noise.Additive(noise_seed=42, nsig=np.array([0.003])),
    ),
    monitors=[monitors.Raw()],
    simulation_length=250,
)

In [None]:
# Finalize the simulator configuration before calling sim.run().
sim.configure()

In [None]:
# sim.run() yields one (time, data) pair per monitor; exactly one (Raw) is attached.
[(t, X)] = sim.run()

In [None]:
# Shape of the Raw-monitor output. Later cells index it as
# X[time, state variable, region, mode] (TVB convention — confirm for other monitors).
X.shape

In [None]:
# Quick look at the simulated activity: state variable 0 (V) of every region versus
# simulation time, one labelled line per region.
f1, ax1 = plt.subplots(figsize=(18, 10))
for node, region in enumerate(conn.region_labels):
    ax1.plot(t, X[:, 0, node, 0], label=region)
ax1.set_xlabel("time (ms)")
ax1.set_ylabel("V")
ax1.legend();

In [None]:
# Prior model. Both inferred parameters use a non-centred parameterization:
# a standard-normal "star" variable that is shifted and scaled deterministically.
# No observation-noise variable is declared in this model.
model = pm.Model()
with model:
    # Coupling strength: a = 0.1 + 0.05 * N(0, 1).
    a_coupling_star = pm.Normal(name="a_coupling_star", mu=0.0, sigma=1.0)
    a_coupling = pm.Deterministic(name="a_coupling", var=0.1 + 0.05 * a_coupling_star)

    # Additive-noise amplitude: nsig = 0.003 + 0.001 * nsig_star.
    # nsig_star is truncated at -3 (= -0.003 / 0.001) so that nsig >= 0, i.e. only the
    # physically required non-negativity is enforced. Truncating at 0 instead would
    # confine nsig to [0.003, inf), putting the simulated value (0.003) exactly on the
    # support boundary and biasing the posterior upwards.
    BoundedNormal = pm.Bound(pm.Normal, lower=-3.0)
    nsig_star = BoundedNormal(name="nsig_star", mu=0.0, sigma=1.0)
    nsig = pm.Deterministic(name="nsig", var=0.003 + 0.001 * nsig_star)

In [None]:
# Link simulator attribute paths (cf. sim.coupling.a and sim.integrator.noise.nsig)
# to the pymc3 variables defined in the model cell.
prior = Pymc3Prior(
    dist=[a_coupling, nsig],
    names=["coupling.a", "integrator.noise.nsig"],
)

In [None]:
# Combine the simulator, the parameter prior and the pymc3 model into one statistical model.
pymc_model = Pymc3Model(sim=sim, params=prior, model=model)

In [None]:
# Estimator that fits pymc_model to the simulated data X (used here as the observation).
pymc_estimator = EstimatorPYMC(stats_model=pymc_model, observation=X)

In [None]:
# Sampler settings: posterior draws (per chain), tuning steps, and number of cores.
draws, tune, cores = 250, 250, 2

In [None]:
# Run the sampler with the settings above. target_accept=0.9 is forwarded to the
# estimator — presumably to pymc3's NUTS step-size adaptation; confirm in EstimatorPYMC.
inference_data = pymc_estimator.run_inference(draws, tune, cores, target_accept=0.9)

In [None]:
# Ground-truth values used to generate X, passed to plot_posterior_samples for comparison
# with the inferred posteriors. No measurement noise was added to X in this notebook.
init_params = dict(
    a_coupling=sim.coupling.a[0],
    nsig=sim.integrator.noise.nsig[0],
    observation_noise=0.0,
)

In [None]:
# Plot the posterior samples alongside the ground-truth values in init_params
# (presumably drawn as reference markers — see plot_posterior_samples).
plot_posterior_samples(inference_data, init_params)

In [None]:
# Flatten chains and draws into one sample axis: (n_chains * n_draws, *X.shape[:-1]).
# -1 lets numpy infer the sample count, so this no longer assumes n_chains == cores
# (with pymc3's default, chains = max(2, cores), e.g. 2 chains when cores=1).
posterior_x_obs = inference_data.posterior_predictive.x_obs.values.reshape((-1, *X.shape[:-1]))

In [None]:
def _plot_ppc_band(ax, times, samples, observed, region):
    """Overlay the 95% posterior-predictive interval of V on the observed V trace of one region.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    times : 1-D array of time points (ms), one per entry of `observed`.
    samples : array (n_posterior_samples, n_times) of posterior-predictive draws.
    observed : array (n_times,) holding the simulated ("observed") trace.
    region : region label used in the title and legend.
    """
    lower, upper = np.percentile(samples, [2.5, 97.5], axis=0)
    # A single shaded band gives exactly one legend entry; plotting the (n_times, 2)
    # bounds array with one label would produce two identical legend entries.
    ax.fill_between(times, lower, upper, color="k", alpha=0.3,
                    label=rf"$V_{{95\% PP}}(t)$, {region}")
    ax.plot(times, observed, label=f"V_observed, {region}")
    ax.set_title(f"Region {region}", fontsize=16)
    ax.legend(fontsize=16)
    ax.set_xlabel("time (ms)", fontsize=16)
    ax.tick_params(axis="both", labelsize=16)


# Posterior-predictive check: one panel per region for state variable 0 (V).
# Axis 2 of posterior_x_obs / X is the region axis, so the second panel shows V of the
# second region (R2), not W.
n_regions = X.shape[2]
f2, axes2 = plt.subplots(nrows=n_regions, ncols=1, figsize=(18, 15), squeeze=False)
axes2 = axes2[:, 0]
for node, ax in enumerate(axes2):
    _plot_ppc_band(ax, t, posterior_x_obs[:, :, 0, node], X[:, 0, node, 0],
                   conn.region_labels[node])

plt.show()

In [None]:
# Summary rows for the inferred parameters.
# NOTE(review): "observation_noise" is not declared in `model` in this notebook; presumably
# the estimator adds it — verify, otherwise this .loc lookup raises KeyError.
pymc_estimator.inference_summary.loc[["a_coupling", "nsig", "observation_noise"]]

In [None]:
# Sampler diagnostics: divergent transitions and mean acceptance rate.
divergent = inference_data.sample_stats.diverging.values  # boolean, one entry per (chain, draw)
print("Number of Divergent %d" % int(divergent.sum()))
# Fraction of divergent transitions over all chains and draws, in percent.
divperc = divergent.mean() * 100
print("Percentage of Divergent %.1f" % divperc)
# Two decimals so the value can be compared against target_accept=0.9
# (with one decimal, e.g. 0.86 and 0.94 both print as 0.9).
print("Mean tree accept %.2f" % float(inference_data.sample_stats.acceptance_rate.mean()))

In [None]:
# sampling_time is reported in seconds; convert to fractional minutes rather than
# floor-dividing, which would report e.g. a 119 s run as 1.0 minute.
print("Sampling time in minutes: %.1f" % (float(inference_data.sample_stats.sampling_time) / 60))

In [None]:
# Persist the estimator's InferenceData to netCDF (uncompressed) for later analysis.
netcdf_path = "pymc3_data/20221206_couplinga_nsig.nc"
# The netCDF writer does not create missing parent directories, so make sure it exists.
os.makedirs(os.path.dirname(netcdf_path), exist_ok=True)
pymc_estimator.inference_data.to_netcdf(filename=netcdf_path, compress=False)