In [1]:
import pymc as pm
import numpy as np

y = np.array([28, 8, -3, 7, -1, 1, 18, 12],dtype=np.float32)
sigma = np.array([15, 10, 16, 11, 9, 11, 10, 18],dtype=np.float32)
J = len(y)




# Model

In [4]:
with pm.Model() as model:
    # Latent pooled effect size
    mu = pm.Normal("mu", 0, sigma=1e6)
    obs = pm.Normal("obs", mu, sigma=sigma, observed=y)

In [5]:
import bayeux as bx

bx_model = bx.Model.from_pymc(model)

2024-08-20 10:43:27.414725: W external/xla/xla/service/gpu/nvptx_compiler.cc:836] The NVIDIA driver's CUDA version is 12.2 which is older than the PTX compiler version (12.6.20). Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [6]:
bx_model.log_density(bx_model.test_point)

Array(-46.18996085, dtype=float64)

In [7]:
bx_model.test_point

{'mu': array(0.)}

In [8]:
test_point=0.
func = model.logp_dlogp_function()
func.set_extra_values({})
ll,gr=func(np.array([test_point]))
print(ll,gr)

-46.189960847437575 [0.46353275]


In [9]:
import distrax
from jax import grad
import jax.numpy as jnp

def log_likelihood(test_point):
    log_prior=distrax.Normal(0.,1e6).log_prob(test_point['mu'])
    log_like=distrax.Independent(distrax.Normal(test_point['mu'],sigma)).log_prob(y)
    return log_prior+jnp.sum(log_like)

dlog_p = grad(log_likelihood)

print(log_likelihood(bx_model.test_point),dlog_p(bx_model.test_point))

-46.189960184308596 {'mu': Array(0.46353275, dtype=float64)}


# Inference

In [10]:
with model:
    i_data_pmc=pm.sample(2000)

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu]
  self.pid = os.fork()


Output()

  self.pid = os.fork()


Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 3 seconds.


In [11]:
import jax 

idata_bx = bx_model.mcmc.numpyro_nuts(seed=jax.random.key(0))



sample: 100%|██████████| 1500/1500 [00:05<00:00, 252.67it/s]


In [12]:
bx_jax = bx.Model(
    log_density=log_likelihood,
    test_point=bx_model.test_point)

idata_jax=bx_jax.mcmc.numpyro_nuts(seed=jax.random.key(0))

sample: 100%|██████████| 1500/1500 [00:05<00:00, 262.48it/s]


# Diagnostics

In [13]:
import arviz as az

az.summary(i_data_pmc)

Unnamed: 0,mean,sd,hdi_3%,hdi_97%,mcse_mean,mcse_sd,ess_bulk,ess_tail,r_hat
mu,7.792,4.046,0.234,15.394,0.073,0.052,3064.0,5222.0,1.0


In [14]:
az.summary(idata_bx)

Unnamed: 0,mean,sd,hdi_3%,hdi_97%,mcse_mean,mcse_sd,ess_bulk,ess_tail,r_hat
mu,7.796,4.096,-0.097,15.195,0.074,0.052,3084.0,4420.0,1.0


In [15]:
az.summary(idata_jax)

Unnamed: 0,mean,sd,hdi_3%,hdi_97%,mcse_mean,mcse_sd,ess_bulk,ess_tail,r_hat
mu,7.822,4.069,-0.029,15.113,0.075,0.053,2961.0,4317.0,1.0


In [98]:
log_normal=distrax.Transformed(distrax.Normal(0.,1.),distrax.Lambda(lambda x : jnp.exp(x)))
y=log_normal.sample(seed=jax.random.key(0))

In [99]:
distrax.Normal(0.,1.).log_prob(jnp.log(y))-jnp.log(y)

Array(-0.44210142, dtype=float64)

In [100]:
log_normal.log_prob(y)

Array(-0.44210142, dtype=float64)

In [107]:
x,ld=distrax.Lambda(lambda x : jnp.exp(x)).inverse_and_log_det(y)

In [108]:
distrax.Normal(0.,1.).log_prob(x)+ld

Array(-0.44210142, dtype=float64)