The pure-JAX brain network model (BNM) simulation tool [vbjax](https://github.com/ins-amu/vbjax/tree/main/vbjax) is used to compare the PyMC and NumPyro workflows on the same model, a simple MPR-based network simulation.
Everything was run with Python 3.11 and the latest versions of the packages.

# A simple MPR-based simulation in vbjax

The model is essentially the vbjax README example of all-to-all coupled MPR nodes, simulated with the default parameters. The SDE version generates the data, and the ODE version is used later when fitting the statistical model. This procedure is inspired by [this](https://www.pymc.io/projects/examples/en/latest/time_series/Euler-Maruyama_and_SDEs.html) PyMC example, which happens to come from the Marseille theoretical neuroscience group.
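
The relationship between the SDE and ODE versions can be illustrated with a minimal NumPy sketch. It assumes a plain Euler-Maruyama scheme and a toy linear drift; `euler_maruyama`, `drift`, and all numeric values are hypothetical and not part of vbjax, whose integrators may use a different scheme.

```python
import numpy as np

def euler_maruyama(x0, drift, sigma, dt, noise):
    """Integrate dx = drift(x) dt + sigma dW with one noise sample per step."""
    xs = []
    x = x0
    for z in noise:
        # deterministic Euler step plus a noise increment scaled by sqrt(dt)
        x = x + dt * drift(x) + sigma * np.sqrt(dt) * z
        xs.append(x)
    return np.array(xs)

def drift(x):
    return -x  # toy linear drift, not the MPR model

rng = np.random.default_rng(0)
noise = rng.standard_normal(100)

# Same integrator, with and without noise: sigma=0 gives the ODE solution.
sde_path = euler_maruyama(1.0, drift, sigma=0.1, dt=0.01, noise=noise)
ode_path = euler_maruyama(1.0, drift, sigma=0.0, dt=0.01, noise=noise)
```

Fitting the noise-free trajectory to data generated with noise is what turns the roundtrip into an estimation problem rather than an exact inversion.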

In [None]:
import vbjax as vb
import jax.numpy as jnp
import numpy as np

In [None]:
# Define a network of all-to-all coupled MPR nodes; structural connectivities (SCs) do not seem to be easily supported yet.
def network(x, p):
    c = 0.03*x.sum(axis=1)
    return vb.mpr_dfun(x, c, p)
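
The coupling term can be checked in isolation. A small NumPy sketch, assuming the state array has shape `(2, N)` with rows `r` and `V` (as the simulation setup suggests); the values are made up:

```python
import numpy as np

N = 5
x = np.arange(2 * N, dtype=float).reshape(2, N)  # rows: r and V, columns: nodes

# Summing over axis=1 adds up all nodes for each state variable, so every node
# receives the same global input (its own activity included).
c = 0.03 * x.sum(axis=1)
```

The result has one entry per state variable, not per node, which is what makes this coupling all-to-all with a uniform weight of 0.03.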

In [None]:
N = 5 # number of regions
len = 500 # number of samples; note that this name shadows Python's built-in len()
_, loop_sde = vb.make_sde(dt=0.01, dfun=network, gfun=0.1) # loop_sde is jit compiled via jax
_, loop_ode = vb.make_ode(dt=0.01, dfun=network) 
zs = vb.randn(len, 2, N) # noise

In [None]:
# Run simulation with default parameters
xs = loop_sde(zs[0], zs[1:], vb.mpr_default_theta) # loop_sde(ics, noise, parameters)
xo = loop_ode(zs[0], np.linspace(0,len, num = len -1), vb.mpr_default_theta) # loop_ode(ics, time, parameters)
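
The time grid passed to `loop_ode` is worth a quick check. The sketch below only inspects the arrays with NumPy; `n_samples` and `dt` mirror the values used above, and whether vbjax uses the grid values or only the grid length is not verified here.

```python
import numpy as np

n_samples = 500
dt = 0.01

# Grid as used above: 499 points spread over the interval [0, 500].
ts = np.linspace(0, n_samples, num=n_samples - 1)
spacing = ts[1] - ts[0]

# A grid with the same length whose spacing matches the integrator step dt.
ts_dt = dt * np.arange(n_samples - 1)
```

Both grids have 499 entries, which matches the number of noise samples in `zs[1:]` used for the SDE run, but only the second one is spaced by `dt`.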

In [None]:
vb.plot_states(xs, 'rV', show=True)

In [None]:
vb.plot_states(xo, 'rV', show=True)

# Simple roundtrip estimation

Recover the default parameters and estimate the noise from the initial SDE simulation using HMC in NumPyro and PyMC. This is more a proof of concept and a way to get familiar with the frameworks than an especially interesting problem.

## Disentangle relations

```{mermaid}
flowchart LR
    NumPyro --> JAX
    
```

```{mermaid}
flowchart LR
    PyMC --> PyTensor <--> JAX & Numba & C
       
```

```{mermaid}
flowchart LR
    JAX --> XLA --> CPU & GPU & TPU    
```

TL;DR: Models and functions written in JAX can be wrapped and used with PyMC. PyTensor graphs, including gradient graphs, can be converted to JAX and vice versa. Because NumPyro is pure JAX, it can be used from PyMC directly, e.g. as a sampler. Staying with PyTensor promises more stability, easier debugging (self-proclaimed, not my experience), and mutable graphs, at the cost of more boilerplate and mental overhead.

## NumPyro

In [None]:
import os

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive
from jax.random import PRNGKey
import jax.random as rng
import matplotlib.pyplot as plt

In [None]:
def model(N, len, data=None):
    # we could sample the ics as well, but for simplicity we don't
    ics = zs[0]

    # parameters: tau, I, Delta, J, eta, cr, cv for mpr
    theta = numpyro.sample(
        "theta",
        dist.Normal(
            loc=jnp.zeros(7),
            scale=jnp.ones(7) * 15,
        ),
    )
    curr_theta =  vb.MPRTheta(
        tau = theta[0],
        I = theta[1],
        Delta = theta[2],
        J = theta[3],
        eta = theta[4],
        cr = theta[5],
        cv = theta[6]
    )
    # Predict using the ode model
    x = loop_ode(ics, np.linspace(0, len, num = len-1), curr_theta) # loop_ode(ics, time, parameters)

    # measurement error - a single standard deviation shared by all nodes
    # (use dist.LogNormal(0, 1).expand([N]) to sample one per node instead)
    sigma = numpyro.sample("sigma", dist.LogNormal(0, 1))
    # measured populations
    numpyro.sample("y", dist.Normal(x, sigma), obs=data)
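
The observed site `y` adds a Gaussian log-likelihood over every time point, state variable, and node. A minimal NumPy sketch of that quantity, with `gaussian_loglik` as a hypothetical helper and random arrays standing in for the simulation output:

```python
import numpy as np

def gaussian_loglik(data, pred, sigma):
    """Sum of independent Normal(pred, sigma) log-densities over all elements."""
    resid = (data - pred) / sigma
    return np.sum(-0.5 * resid**2 - np.log(sigma) - 0.5 * np.log(2 * np.pi))

rng = np.random.default_rng(0)
pred = rng.standard_normal((499, 2, 5))              # stand-in for the ODE prediction
data = pred + 0.1 * rng.standard_normal(pred.shape)  # stand-in for the noisy SDE data

ll_good = gaussian_loglik(data, pred, sigma=0.1)        # prediction close to the data
ll_bad = gaussian_loglik(data, pred + 1.0, sigma=0.1)   # prediction shifted away
```

Because a single scalar `sigma` is shared, `r` and `V` are assumed to have the same noise scale in this model.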

In [None]:
#| output: true
# use dense_mass for better mixing rate
mcmc = MCMC(
    NUTS(model, dense_mass=True),
    num_warmup=200,
    num_samples=500,
    num_chains=1,
    progress_bar=False #if "NUMPYRO_SPHINXBUILD" in os.environ else True,
)
mcmc.run(PRNGKey(1), N=5, len = 500, data=xs)

In [None]:
vb.mpr_default_theta

In [None]:
mcmc.print_summary()

In [None]:
# Get the posterior samples
samples = mcmc.get_samples()

# Extract the mean value
theta_mean = jnp.mean(samples['theta'], axis = 0)

# Build an MPRTheta from the posterior means
mpr_changed = vb.MPRTheta(
        tau = theta_mean[0],
        I = theta_mean[1],
        Delta = theta_mean[2],
        J = theta_mean[3],
        eta = theta_mean[4],
        cr = theta_mean[5],
        cv = theta_mean[6]
    )
xs_changed = loop_sde(zs[0], zs[1:], mpr_changed)
# Plot xs (black) and xs_changed (red) in the same axes
names = 'rV'
for i in range(xs.shape[1]):
        plt.subplot(xs.shape[1], 1, i+1)
        plt.plot(xs[:, i], 'k', alpha=0.3)
        plt.plot(xs_changed[:, i], 'r', alpha=0.3)
        plt.ylabel(names[i])
        plt.xlabel('time')
        plt.grid(1)
        
plt.tight_layout()

## PyMC

In [None]:
import pytensor
import pytensor.tensor as pt
from pytensor.graph import Apply, Op
from pytensor.link.jax.dispatch import jax_funcify

import jax

import pymc as pm
import pymc.sampling.jax

In [None]:
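# Setting this to True would enable float64 in JAX, but that has to happen earlier, before any arrays are created
# (the jax.config submodule import is gone in recent JAX releases; jax.config.update is the supported form)
jax.config.update("jax_enable_x64", False)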

In [None]:
# Instead we set the floatX type of pytensor to float32 to make it work with the JAX default 
dtype = 'float32'   
pytensor.config.floatX = dtype

A JAX function can be wrapped as a black-box function for PyMC, which requires defining the vector-Jacobian product (VJP) manually. vbjax stores parameters in named tuples, which is convenient but has to be wrapped to match the array-like interface of PyMC.
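
What a vector-Jacobian product computes can be shown on a toy function. The sketch below uses `jax.vjp` on a hypothetical `toy_fun` (not the MPR simulation) and compares the result with the explicit product of the cotangent and the full Jacobian:

```python
import jax
import jax.numpy as jnp

def toy_fun(params):
    # maps 3 parameters to 2 outputs
    return jnp.array([params[0] * params[1], jnp.sin(params[2])])

params = jnp.array([1.0, 2.0, 0.5])
gz = jnp.array([1.0, -1.0])  # cotangent: gradient of a scalar cost w.r.t. the outputs

# jax.vjp returns the primal output and a function that maps gz to gz @ J
out, vjp_fn = jax.vjp(toy_fun, params)
(vjp_value,) = vjp_fn(gz)

# reference: build the full Jacobian (2 x 3) and multiply explicitly
jac = jax.jacobian(toy_fun)(params)
explicit = gz @ jac
```

The VJP never forms the Jacobian, which matters here because the simulation output has 499 × 2 × 5 entries for only 7 parameters.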


In [None]:
def jax_fun(params):
    curr_theta =  vb.MPRTheta(
        tau = params[0],
        I = params[1],
        Delta = params[2],
        J = params[3],
        eta = params[4],
        cr = params[5],
        cv = params[6]
    )
    # xs = loop_sde(zs[0], zs[1:], curr_theta) # loop_sde(ics, noise, parameters)
    xo = loop_ode(zs[0], jnp.linspace(0, 500, num = 500-1), curr_theta) # loop_ode(ics, time, parameters)

    return xo

# jax.jit can wrap functions that call already-jitted functions (loop_ode here), which makes it convenient to build complex functions incrementally
jitted_jax_fun = jax.jit(jax_fun)
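
The conversion between a flat parameter vector and a named tuple can also be written by unpacking. A sketch with a hypothetical `Theta` tuple standing in for `vb.MPRTheta`, using the field order listed in the NumPyro model comment; the numeric values are placeholders:

```python
from collections import namedtuple
import numpy as np

# hypothetical stand-in for vb.MPRTheta
Theta = namedtuple("Theta", ["tau", "I", "Delta", "J", "eta", "cr", "cv"])

flat = np.arange(1.0, 8.0)  # placeholder values, not the vbjax defaults
theta = Theta(*flat)        # array -> named tuple, relies on the field order
back = np.array(theta)      # named tuple -> array
```

Unpacking only works if the vector is ordered like the tuple fields, so the explicit keyword form used above is the safer choice when the order is in doubt.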

In [None]:
jitted_jax_fun(np.array(vb.mpr_default_theta, dtype = dtype))

In [None]:
def vjp_jax_fun(params, gz):
    # gz is the cotangent: the gradient of the cost with respect to the output of jax_fun
    _, vjp_fn = jax.vjp(jax_fun, params)
    return vjp_fn(gz)[0]

jitted_vjp_jax_fun = jax.jit(vjp_jax_fun)

In [None]:
jitted_vjp_jax_fun(jnp.array(vb.mpr_default_theta, dtype = dtype), xs)

After that, both functions need to be wrapped in a PyTensor-compatible `Op` class that implements:

* `make_node`: Creates an `Apply` node that holds together the symbolic inputs and outputs of the operation
* `perform`: Python code that returns the evaluation of the operation, given concrete input values
* `grad`: Returns a PyTensor symbolic graph that represents the gradient expression of an output cost with respect to its inputs

In [None]:
class SolOp(Op):
    def make_node(self, params):
        inputs = [pt.as_tensor_variable(params, dtype = dtype)]
        outputs = [pt.tensor3(dtype=dtype)]
        return Apply(self, inputs, outputs)

    def perform(self, node, inputs, outputs):
        (params,) = inputs
        result = jitted_jax_fun(params)
        outputs[0][0] = np.asarray(result, dtype=dtype)

    def grad(self, inputs, output_gradients):
        (params,) = inputs
        (gz,) = output_gradients
        return [vjp_sol_op(params, gz)]
    
class VJPSolOp(Op):
    def make_node(self, params, gz):
        inputs = [pt.as_tensor_variable(params), pt.as_tensor_variable(gz)]
        outputs = [inputs[0].type()]
        return Apply(self, inputs, outputs)

    def perform(self, node, inputs, outputs):
        (params, gz) = inputs
        result = jitted_vjp_jax_fun(params, gz)
        outputs[0][0] = np.asarray(result, dtype=dtype)

sol_op = SolOp()
vjp_sol_op = VJPSolOp()

In [None]:
#| output: false
# Verify the gradients; this function helps a lot while debugging, as errors in the Op definition easily cause a segfault
# pytensor.gradient.verify_grad(sol_op, (jnp.array(vb.mpr_default_theta, dtype = dtype),), rng=np.random.default_rng())
pytensor.gradient.verify_grad(sol_op, (jnp.array(vb.mpr_default_theta, dtype = dtype),), mode='DebugMode', rng=np.random.default_rng())
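
Conceptually, such a gradient check compares an analytic gradient with finite differences. The NumPy sketch below illustrates the idea on a hypothetical scalar function; it is not PyTensor's implementation, and `cost`, `cost_grad`, and `numeric_grad` are made-up names:

```python
import numpy as np

def cost(p):
    return np.sum(np.sin(p) * p**2)

def cost_grad(p):
    # analytic gradient of cost
    return np.cos(p) * p**2 + 2 * p * np.sin(p)

def numeric_grad(f, p, eps=1e-6):
    """Central finite differences, one coordinate at a time."""
    g = np.zeros_like(p)
    for i in range(p.size):
        step = np.zeros_like(p)
        step[i] = eps
        g[i] = (f(p + step) - f(p - step)) / (2 * eps)
    return g

p = np.array([0.3, -1.2, 2.0])
max_err = np.max(np.abs(numeric_grad(cost, p) - cost_grad(p)))
```

This sketch runs in float64. In float32, as configured for this notebook, `eps` has to be considerably larger and the tolerance looser.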

In [None]:
# Register the JAX implementations to make them available to the PyTensor JAX linker
@jax_funcify.register(SolOp)
def sol_op_funcify(op, **kwargs):
    # jax_funcify has to return a JAX-traceable callable, not the PyTensor Op itself
    return jitted_jax_fun

@jax_funcify.register(VJPSolOp)
def vjp_sol_op_funcify(op, **kwargs):
    return jitted_vjp_jax_fun

In [None]:
with pm.Model() as model_pymc:
    params = pm.Normal("params", 0, 15, shape = 7)
    xo = sol_op(params)
    noise = pm.HalfNormal("noise")#, shape = 5)
    llike = pm.Normal("llike", mu=xo, sigma=noise, observed=xs)

In [None]:
with model_pymc:    
    trace = pm.sample(500, tune=200, chains = 1)

In [None]:
# # Sample with numpyro - not working yet. It compiles but then gets NaNs from somewhere
# with model_pymc:
#     samples = pm.sampling.jax.sample_numpyro_nuts(2000, tune=500, chains = 2, progressbar = False)

In [None]:
import arviz as az
az.plot_trace(trace)

In [None]:
az.plot_forest(trace)
# Plot true values into the forest plot in a hacky way
plt.plot(vb.mpr_default_theta, 2.525* np.linspace(1,7, num = 7), "x", color="r", alpha=0.4)
vb.mpr_default_theta

In [None]:
theta_post = np.mean(trace.posterior["params"][0,:,:], axis=0)#[:, 0, :]
# Plot theta with variable names on x axis
plt.plot(theta_post, "o", color="k", ms=10)
plt.plot(vb.mpr_default_theta, "x", color="r", label="True values")

In [None]:
theta_mean = np.asanyarray(theta_post, dtype = dtype)

# Build an MPRTheta from the posterior means
mpr_changed = vb.MPRTheta(
        tau = theta_mean[0],
        I = theta_mean[1],
        Delta = theta_mean[2],
        J = theta_mean[3],
        eta = theta_mean[4],
        cr = theta_mean[5],
        cv = theta_mean[6]
    )
xs_changed = loop_sde(zs[0], zs[1:], mpr_changed)

# Plot xs (black) and xs_changed (red) in the same axes
names = 'rV'
for i in range(xs.shape[1]):
        plt.subplot(xs.shape[1], 1, i+1)
        plt.plot(xs[:, i], 'k', alpha=0.3)
        plt.plot(xs_changed[:, i], 'r', alpha=0.3)
        plt.ylabel(names[i])
        plt.xlabel('time')
        plt.grid(1)
        
plt.tight_layout()

## Personal conclusion

* Dev time: The NumPyro example took less than a tenth of the time. The likely reason is the two-language problem created by PyTensor + JAX, which introduces several points of failure in interoperability that are hard to debug. Looking things up for PyMC also tends to be hard, because the API seems to have changed a lot from v3 to v5, leaving many outdated examples. This probably diminishes with experience, though the boilerplate overhead remains.

* Performance: The NumPyro example is faster, but only out of the box; the gap can probably be reduced with experience. How both frameworks scale to more complex models is unknown, and sensible benchmarks are actually hard.

* Documentation: Both are good enough.
