In [None]:
from pathlib import Path

import pymc3 as pm
import numpy as np
import theano
import theano.tensor as tt
import arviz as az
import matplotlib.pyplot as plt

from tvb.simulator.lab import *
from tvb_inversion.pymc3.prior import Pymc3Prior
from tvb_inversion.pymc3.stats_model import Pymc3Model
from tvb_inversion.pymc3.stats_model_builder import DefaultStochasticPymc3ModelBuilder
from tvb_inversion.pymc3.inference import EstimatorPYMC

%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
# Connectivity loaded via Connectivity.from_file() (no explicit file argument).
conn = connectivity.Connectivity.from_file()

# Alternative for quick tests: a hand-built 2-node connectivity.
#conn = connectivity.Connectivity()
#conn.weights = np.array([[0., 2.], [2., 0.]])
#conn.region_labels = np.array(["R1", "R2"])
#conn.centres = np.array([[0.1, 0.1, 0.1], [0.2, 0.1, 0.1]])
#conn.tract_lengths = np.array([[0., 2.5], [2.5, 0.]])
#conn.configure()

# Stochastic Heun scheme with additive noise; the fixed noise_seed keeps runs reproducible.
heun_integrator = integrators.HeunStochastic(
    dt=1.0,
    noise=noise.Additive(nsig=np.array([1e-4]), noise_seed=42),
)

sim = simulator.Simulator(
    model=models.oscillator.Generic2dOscillator(),
    connectivity=conn,
    coupling=coupling.Difference(),
    integrator=heun_integrator,
    monitors=[monitors.Raw()],
    simulation_length=250,
)

In [None]:
# Configure the simulator before it is run with sim.run() in the next cell.
sim.configure()

In [None]:
# sim.run() yields one (time, data) pair per monitor; only the Raw monitor is attached.
[(t, X)] = sim.run()

In [None]:
# Dimensions of the Raw monitor output (presumably time, state variable, node, mode — confirm).
np.shape(X)

In [None]:
# Quick look at the simulated data: state variable 0 of every node, plotted against
# simulation time `t` (not the sample index) so the time axis is in ms.
fig_raw, ax_raw = plt.subplots()
ax_raw.plot(t, X[:, 0, :, 0], linewidth=0.8)
ax_raw.set(
    title="Simulated activity (state variable 0, all nodes)",
    xlabel="time (ms)",
    ylabel="state variable 0",
)
plt.show()

In [None]:
# Priors for the inferred parameters, written in non-centred form: each parameter is a
# shift + scale of a standard-normal latent "*_star" variable, and the transformed value
# is recorded with pm.Deterministic so it appears under its own name in the trace.
model = pm.Model()
with model:
    # Prior for the simulator's "model.a": Normal(mean -2.0, std 1.0).
    a_model_star = pm.Normal(name="a_model_star", mu=0.0, sigma=1.0)
    a_model = pm.Deterministic(name="a_model", var=-2.0 + 1.0 * a_model_star)

    # Prior for the simulator's "coupling.a": Normal(mean 0.1, std 0.05).
    a_coupling_star = pm.Normal(name="a_coupling_star", mu=0.0, sigma=1.0)
    a_coupling = pm.Deterministic(name="a_coupling", var=0.1 + 0.05 * a_coupling_star)

In [None]:
# Build the default stochastic PyMC3 model around the configured Simulator instance `sim`.
# Note: `simulator` would be the tvb.simulator.simulator *module* pulled in by the star
# import, not the simulator object, so `sim` must be passed here.
model_builder = DefaultStochasticPymc3ModelBuilder(sim, model=model, observation=X)
prior = model_builder.set_default_prior(def_std=0.1)
model_builder.configure()
# NOTE(review): `stats_model` is not used by the later cells, and `prior` is reassigned
# in the next cell — decide which construction path this notebook should use.
stats_model = model_builder.build()

In [None]:
# Map simulator parameter paths / noise terms to the PyMC3 variables acting as their priors.
# NOTE(review): `dynamic_noise` and `global_noise` are not defined in any visible cell
# (the pm.Model cell defines only the a_model / a_coupling variables), so this cell
# raises NameError on a fresh kernel. TODO: define them inside `with model:` first.
# NOTE(review): this reassigns `prior`, discarding the one returned by
# model_builder.set_default_prior(...) in the model-builder cell.
prior = Pymc3Prior(
    names=["model.a", "coupling.a", "dynamic_noise", "global_noise"], 
    dist=[a_model, a_coupling, dynamic_noise, global_noise]
)

In [None]:
# Wrap the PyMC3 model, the prior mapping and the simulator into a tvb_inversion stats model.
pymc_model = Pymc3Model(model=model, sim=sim, params=prior)

In [None]:
# Estimator that fits the stats model to the simulated observation X.
pymc_estimator = EstimatorPYMC(observation=X, stats_model=pymc_model)

In [None]:
# Sampler settings passed to run_inference: number of draws, number of tuning steps,
# and number of cores (presumably forwarded to pm.sample — confirm in EstimatorPYMC).
draws, tune, cores = 500, 500, 2

In [None]:
# Run MCMC with the settings above; target_accept=0.9 is presumably forwarded to the
# NUTS step sampler (confirm in EstimatorPYMC.run_inference).
inference_data = pymc_estimator.run_inference(
    draws,
    tune,
    cores,
    target_accept=0.9,
)

In [None]:
# Ground-truth parameter values used to simulate X, keyed by the names of the
# corresponding PyMC3 variables; passed to plot_posterior_samples below.
init_params = {
    # a_model is inferred (prior "model.a") and reported in the summary, so its
    # reference value belongs here too.
    "a_model": sim.model.a[0],
    "a_coupling": sim.coupling.a[0],
    # None is passed as the state argument — presumably fine because Additive
    # noise's gfun does not depend on the state (confirm).
    "noise_gfun": sim.integrator.noise.gfun(None)[0],
    # No observation noise was added to the simulated data.
    "global_noise": 0.0,
}

In [None]:
# Plot the posterior samples together with the ground-truth values in init_params
# (presumably drawn as reference markers — confirm in EstimatorPYMC).
pymc_estimator.plot_posterior_samples(init_params)

In [None]:
# Flatten the (chain, draw) axes of the posterior-predictive x_obs into one sample axis,
# keeping X's shape without its trailing mode axis. -1 lets numpy infer the sample count,
# so this does not assume the number of chains equals `cores`.
posterior_x_obs = inference_data.posterior_predictive.x_obs.values.reshape((-1, *X.shape[:-1]))

In [None]:
# Posterior-predictive check: central 95% posterior-predictive band (shaded) versus the
# observed trace of state variable 0 ("V"), one panel per node (nodes 0 and 1).
# Assumes posterior_x_obs is indexed (sample, time, state variable, node) and X is
# (time, state variable, node, mode), as produced by the reshape in the previous cell.
pp_percentiles = [2.5, 97.5]
state_var = 0
f3, axes3 = plt.subplots(nrows=2, ncols=1, figsize=(18, 15))
for ax, node in zip(axes3, (0, 1)):
    lower, upper = np.percentile(posterior_x_obs[:, :, state_var, node], pp_percentiles, axis=0)
    # One shaded band -> one legend entry (two separate percentile lines would each get the label).
    ax.fill_between(t, lower, upper, color="k", alpha=0.3,
                    label=r"$V_{95\% PP}(t)$" + f", node {node}")
    ax.plot(t, X[:, state_var, node, 0], label=f"V_observed, node {node}")
    ax.legend(fontsize=16)
    ax.set_xlabel("time (ms)", fontsize=16)
    ax.tick_params(axis="both", labelsize=16)

plt.show()

In [None]:
# Posterior summary restricted to the parameters of interest.
summary_rows = ["a_model", "a_coupling", "global_noise", "noise_gfun"]
pymc_estimator.inference_summary.loc[summary_rows]

In [None]:
# Persist the full InferenceData as netCDF. The output directory is created first,
# because writing into a missing directory fails.
netcdf_path = Path("pymc3_data") / "test1.nc"
netcdf_path.parent.mkdir(parents=True, exist_ok=True)
pymc_estimator.inference_data.to_netcdf(filename=str(netcdf_path), compress=False)