
For Gaussians with $N(0,1)$ priors, we can calculate this as (see Kingma Welling 2013 VAE paper in the appendix)
$$
    KL(q(\mu_1|\sigma_1) || N(0,1)) = -\frac{1}{2}  (1 + log(\sigma^2) - \mu^2 - \sigma^2) 
$$

In [3]:
import numpy as np
import distrax

mu = 2.
s  = 3.

distrax.Normal(loc=2.0,scale=3.0).kl_divergence(distrax.Normal(loc=0.0,scale=1.0))
print(-(1. + np.log(s**2) - mu**2 - s**2)/2.)

4.9013877
4.90138771133189


In [74]:
import distrax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)
key1,key2=jax.random.split(key)
X = jnp.concatenate([jnp.ones(100).reshape(-1,1),jax.random.normal(key1,shape=(100,1))],axis=-1)
beta=jnp.array([2,1]).reshape(-1,1)
y = X@beta + 0.1 * jax.random.normal(key,shape=(100,1))

In [100]:
sigma=0.1

def log_likelihood(test_point):
    #log_prior_eta=distrax.Normal(0.,1.).log_prob(test_point['eta']).sum()
    #log_prior_mu=distrax.Normal(0.,1.).log_prob(test_point['mu']).sum()
    #log_prior_tau=distrax.Transformed(distrax.Normal(loc=0., scale=1.),
    #                                distrax.Lambda(lambda x:jnp.exp(x))).log_prob(test_point['tau']).sum()
    loc=(test_point['mu']+test_point['tau']*test_point['eta']).reshape(-1,1)
    log_like=distrax.Independent(distrax.Normal(loc=X@loc,scale=sigma)).log_prob(y).sum()
    return log_like


def elbo(test_point,key):
    key1,key2=jax.random.split(key)
    q_dist=distrax.Normal(test_point['mu'],test_point['tau'])
    test_point.update({'eta':distrax.Normal(0.,1.).sample(seed=key1,sample_shape=(2,))})
    loss=log_likelihood(test_point)+q_dist.kl_divergence(distrax.Normal(0.,1.)).sum()
    return loss,key2



test_point={'mu':jnp.array([0.0,0.0]),'tau':jnp.array([1.0,1.0]),'eta':jnp.array([1.0,1.0])}
print(elbo(test_point,key2))

(Array(-2043.5426, dtype=float32), Array([ 754422081, 3987528881], dtype=uint32))


In [None]:
import bayeux as bx 

def transform_fn(test_point):
  return {'mu': test_point['mu'], 'tau': jnp.exp(test_point['tau']),
          'eta':test_point['eta']}

bx_jax = bx.Model(
    log_density=elbo,
    test_point=test_point,
    transform_fn=transform_fn)

idata_jax=bx_jax.mcmc.numpyro_nuts(seed=jax.random.key(0))