In [None]:
from tvb.simulator.models.linear import Linear
from tvb.simulator.integrators import HeunStochastic
from tvb.simulator.backend.theano import TheanoBackend

import matplotlib.pyplot as plt
import numpy as np
import arviz as az
import pymc3 as pm
import theano.tensor as tt
import theano
import math
from tqdm import tqdm

%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
# Ground-truth value of the Linear model's `gamma`: it parameterises the model
# that generates the synthetic data and is the reference line for the posterior.
γ: float = -0.45

In [None]:
# initialize model instance
# gamma is supplied as a one-element array rather than a bare float
linear_model = Linear(gamma=np.asarray([γ]))
linear_model.configure()

# initialize integrator instance
# stochastic Heun integrator; dt=0.1 is the step also used by the inference model below
integrator = HeunStochastic(dt=0.1)
# noise intensity (one value, shared by the single state variable — TODO confirm for multi-variable models)
integrator.noise.nsig = np.array([0.001])
integrator.noise.configure()
# NOTE(review): presumably tells the white-noise generator which dt to scale its draws by
integrator.noise.configure_white(dt=integrator.dt)
# NOTE(review): random_state=None presumably leaves the noise stream unseeded, so every
# run produces a different X; pass a fixed state for reproducible results
integrator.set_random_state(random_state=None)
integrator.configure()
integrator.configure_boundaries(linear_model)

In [None]:
# run simulation for 1 node
simulation_length = 100  # total simulated time, in the same units as integrator.dt
stimulus = 0.0
local_coupling = 0.0
number_of_nodes = 1
# initial state drawn uniformly from [-1, 1); array shape is (1, number_of_nodes, 1)
current_state = np.random.uniform(low=-1.0, high=1.0, size=[1, number_of_nodes, 1])
# the node is uncoupled, so its coupling input is identically zero
node_coupling = np.zeros([1, number_of_nodes, 1])
n_steps = int(math.ceil(simulation_length / integrator.dt))

state = current_state
X = [current_state.copy()]
for _ in range(n_steps):
    state = integrator.integrate(state, linear_model, node_coupling, local_coupling, stimulus)
    X.append(state.copy())

X = np.asarray(X)
# Sample k is taken after k steps of size dt. Because n_steps is rounded up,
# np.linspace(0, simulation_length, n_steps + 1) would only match these times
# when simulation_length is an exact multiple of dt.
t = integrator.dt * np.arange(n_steps + 1)

In [None]:
# simulated trajectory of the single state variable of node 0
fig1, ax1 = plt.subplots(figsize=(14, 8))
ax1.plot(t, X[:, 0, 0, 0])
ax1.set_xlabel("time (ms)")
ax1.set_ylabel("states")
plt.show()

### Inference using PyMC3
Inference is performed with the probabilistic programming library PyMC3, which implements gradient-based Bayesian inference methods. Here we use the No-U-Turn Sampler (NUTS). The gradients it needs are available because the model's time stepping is expressed as a Theano graph (via `theano.scan`).

In [None]:
class sim(object):
    """Minimal stand-in for a TVB simulator.

    Only ``model`` is provided; it is what the Theano backend template and
    ``pymcInference.scheme`` read from ``sim``.
    """

    model = linear_model

In [None]:
# theano backend template
# Mako source rendered by TheanoBackend().build_py_func with `sim` in its
# namespace (see pymcInference). The plain import lines are emitted verbatim;
# the included theano-dfuns.py.mako presumably generates the Theano `dfuns`
# function for `sim.model` — TODO confirm against the TVB backend templates.
template = """
import theano
import theano.tensor as tt
import numpy as np
import pymc3 as pm 
<%include file="theano-dfuns.py.mako"/>
"""

In [None]:
class pymcInference:
    """PyMC3 model inferring the Linear model's ``gamma`` from the simulated series ``X``.

    The stochastic integration that produced ``X`` is rebuilt as a ``theano.scan``
    over latent per-step noise, so NUTS can differentiate through the whole
    trajectory. The model reads the notebook globals ``X``, ``integrator``,
    ``sim`` and ``template``.

    Attributes:
        pymc_model: ``pm.Model`` holding every random variable.
        dt: integration step, copied from ``integrator.dt``.
        priors: maps dfun parameter names to the PyMC3 tensors passed to the dfun.
        kernel: Theano dfun generated from ``template`` for ``sim.model``.
    """

    def __init__(self):
        self.dt = integrator.dt
        # Generate the Theano dfun once, up front; scheme() only calls it.
        self.kernel = TheanoBackend().build_py_func(template, dict(sim=sim),
                                                    name='dfuns', print_source=True)

        # initialize pymc model
        self.pymc_model = pm.Model()
        with self.pymc_model:
            # the first observed state is treated as known
            x0 = X[0]
            gfun = integrator.noise.gfun(None)[0]

            # inference parameters
            # non-centred parameterisation, equivalent to gamma ~ Normal(-0.5, 0.25);
            # the resulting tensor is passed to the dfun as a parameter override
            gamma_star = pm.Normal(name="gamma_star", mu=0.0, sd=1.0)
            gamma = pm.Deterministic(name="gamma", var=-0.5 + 0.25 * gamma_star)
            self.priors = {"gamma": gamma}

            # dynamic noise: one innovation per integration step, i.e. len(X) - 1 of them
            n_steps = X.shape[0] - 1
            noise = pm.Normal(name="noise", mu=0.0, sd=1.0, shape=(n_steps,) + X.shape[1:])
            # TVB's white noise scales standard-normal draws by sqrt(dt) before applying
            # gfun, so the per-step increment has standard deviation sqrt(dt) * gfun
            dynamic_noise = pm.Deterministic(name="dynamic_noise",
                                             var=np.sqrt(self.dt) * gfun * noise)

            # observation noise
            BoundedNormal = pm.Bound(pm.Normal, lower=0.0)
            obs_noise = BoundedNormal("obs_noise", mu=0.0, sd=1.0)

            # simulate the time series with theano.scan; the scan's k-th output is the
            # state after k + 1 steps, i.e. the model's prediction for X[k + 1]
            x_steps, _ = theano.scan(
                fn=self.scheme,
                sequences=[dynamic_noise],
                outputs_info=[x0],
                n_steps=n_steps
            )
            # prepend x0 so that x_sim[k] is aligned with X[k]
            x_sim = tt.concatenate([tt.shape_padleft(tt.as_tensor_variable(x0)), x_steps], axis=0)

            x_obs = pm.Normal(name="x_obs", mu=x_sim, sd=obs_noise, shape=X.shape, observed=X)

    def scheme(self, x_eta, x_prev):
        """Advance the state by one stochastic Heun (predictor-corrector) step.

        Mirrors the HeunStochastic integrator that generated ``X``, so the
        inference model matches the data-generating process.

        Args:
            x_eta: noise increment for this step (one slice of ``dynamic_noise``).
            x_prev: state at the previous step.

        Returns:
            The state after one step of size ``self.dt``.
        """
        parmat = sim.model.spatial_parameter_matrix
        gamma = self.priors["gamma"]
        # zero coupling input: the simulated node is uncoupled
        cX = tt.zeros(x_prev.shape)

        def drift(x):
            # the generated dfun returns the state derivative (its first argument is a
            # zero buffer of the state's shape)
            return self.kernel(tt.zeros(x.shape), x, cX, parmat, gamma=gamma)

        # predictor: Euler-Maruyama estimate of the next state
        dX_prev = drift(x_prev)
        x_inter = x_prev + self.dt * dX_prev + x_eta
        # corrector: average the drift at both ends; the same noise increment is reused
        dX_inter = drift(x_inter)
        return x_prev + 0.5 * self.dt * (dX_prev + dX_inter) + x_eta

In [None]:
# global inference parameters
draws, tune = 250, 250  # kept samples / tuning steps, per chain
cores = 2               # worker processes used by pm.sample

In [None]:
# initialize pymcInference instance and run inference
pymc_inference = pymcInference()
with pymc_inference.pymc_model:
    # gradient-based sampling of the model's free variables
    trace = pm.sample(
        draws=draws,
        tune=tune,
        cores=cores,
        target_accept=0.8,
    )
    # replay the fitted model to draw posterior-predictive time series
    posterior_predictive = pm.sample_posterior_predictive(trace=trace)
    # bundle posterior, predictive draws and sampler statistics into one ArviZ object
    inference_data = az.from_pymc3(trace=trace, posterior_predictive=posterior_predictive)

# the summary only needs the ArviZ object, not the model context
summary = az.summary(inference_data)

In [None]:
# get posterior samples and posterior predicted time series
# Flatten the (chain, draw) dimensions. The sample count is n_chains * draws,
# because tuning samples are discarded by default, so it is not hard-coded here.
posterior_gamma = inference_data.posterior.gamma.values.reshape(-1)
posterior_obs_noise = inference_data.posterior.obs_noise.values.reshape(-1)
posterior_x_obs = inference_data.posterior_predictive.x_obs.values.reshape((-1, X.shape[0]))

In [None]:
fig2, axes2 = plt.subplots(ncols=2, nrows=1, figsize=(15, 5))

# posterior of gamma against the value used to simulate the data
axes2[0].hist(posterior_gamma, bins=100)
axes2[0].axvline(γ, color="r", label=r"$\gamma$")
axes2[0].set_title("gamma")
axes2[0].legend()

# X is observed exactly as simulated (no added measurement noise), so 0 is the reference
axes2[1].hist(posterior_obs_noise, bins=100)
axes2[1].axvline(0.0, color="r", label=r"observation noise")
axes2[1].set_title("observation noise")
axes2[1].legend()

plt.show()

In [None]:
fig3, ax3 = plt.subplots(figsize=(14, 8))
# 95% posterior-predictive interval drawn as one shaded band, so it gets a single
# legend entry (plotting both percentile curves with one label duplicated it)
x_pred_lo, x_pred_hi = np.percentile(posterior_x_obs, [2.5, 97.5], axis=0)
ax3.fill_between(t, x_pred_lo, x_pred_hi, color="k", alpha=0.3,
                 label=r"$X_{pred}^{95\% PP}(t)$")
ax3.plot(t, X[:, 0, 0, 0], label=r"$X_{obs}$")
ax3.legend(fontsize=15)
ax3.set_xlabel("time (ms)")
ax3.set_ylabel("states")
plt.show()

In [None]:
# information criteria, both estimated from the pointwise log-likelihood in inference_data
waic_criterion, loo_criterion = az.waic(inference_data), az.loo(inference_data)

In [None]:
# ArviZ >= 0.8 reports these estimates under "elpd_waic"/"elpd_loo" (log scale by
# default); older releases used "waic"/"loo". Use whichever key this version provides.
waic_key = "elpd_waic" if "elpd_waic" in waic_criterion else "waic"
loo_key = "elpd_loo" if "elpd_loo" in loo_criterion else "loo"
print("WAIC: %.2f" % waic_criterion[waic_key])
print("LOO: %.2f" % loo_criterion[loo_key])

In [None]:
# summary statistics, restricted to the model parameter and the observation noise
summary.loc[["gamma", "obs_noise"], :]

In [None]:
# sampler diagnostics; trace["diverging"] concatenates the flags of all chains
divergent = trace["diverging"]
n_divergent = int(divergent.sum())
print("Number of Divergent: %d" % n_divergent)
# Percentage over every kept draw of every chain. Using divergent.size avoids
# assuming that the number of chains equals `cores`.
divperc = 100.0 * n_divergent / divergent.size
print("Percentage of Divergent: %.1f" % divperc)
# two decimals so the value can be compared with target_accept=0.8
print("Mean tree accept: %.2f" % trace['mean_tree_accept'].mean())
print("Sampling time in minutes: %.0f" % (inference_data.sample_stats.sampling_time / 60))