In [1]:
import jax
from jax.flatten_util import ravel_pytree
import jax.numpy as jnp
import numpy as np
from numpy.polynomial.hermite import hermgauss
import numpyro
import numpyro.distributions as dist
from numpyro.infer.util import initialize_model

from pathlib import Path
from scipy.stats import norm
from scipy.optimize import minimize

from bsmodel import BSModel
import bridgestan as bs
from klhr import KLHR
from klhr_sinh import KLHRSINH

jax.config.update('jax_enable_x64', True)

In [2]:
#bs.set_bridgestan_path(Path.home() / "bridgestan")

In [3]:
#bs_model = BSModel(stan_file = "stan/one_exponential.stan", data_file = "stan/one_exponential.json")

In [2]:
def model():
    x = numpyro.sample("x", dist.Exponential(30.0))

In [3]:
key = jax.random.PRNGKey(seed = 0)
param_info, potential, postprocess, _ = initialize_model(key, model, dynamic_args = True)

In [4]:
rng = np.random.default_rng(204)
theta0 = np.zeros(1) # rng.normal(size = 1)
rho = rng.normal(size = 1)
rho /= np.linalg.norm(rho)

In [6]:
bs_model.log_density_gradient(theta0, propto=False)

(-26.598802618337846, array([-29.]))

In [5]:
def logdensity(theta):
    return -potential()(theta)
def gradient(theta):
    return jax.grad(logdensity)(theta)
theta_init = param_info.z
_, unflatten = ravel_pytree(theta_init)

In [6]:
logdensity(unflatten(theta0)), gradient(unflatten(theta0))

(Array(-26.59880262, dtype=float64), {'x': Array(-29., dtype=float64)})

In [7]:
klhr = KLHR(bs_model, theta = theta0)

In [8]:
eta = np.array([0.0, 1.0])

In [9]:
klhr._L(eta, rho)

(1202.3962963053273, array([1980.08488407, 7801.77239206]))

In [7]:
x, w = hermgauss(16)
xjnp = jnp.asarray(x)
wjnp = jnp.asarray(w)

In [32]:
def gausshermite_estimate(x, w, rho, eta):
    m, s = eta[0], jnp.exp(eta[1]) + 1e-10
    y = jnp.sqrt(2.0) * s * x + m
    xi = y * rho + theta0
    logp = logdensity(unflatten(xi))
    return w * logp 

In [33]:
ghe = jax.vmap(gausshermite_estimate, in_axes = (0, 0, None, None))

In [39]:
eta = jnp.array([0.0, 1.0])
rho = jnp.asarray(rho)
def L(eta):
    return -(jnp.sum(ghe(xjnp, wjnp, rho, eta)) / jnp.sqrt(jnp.pi) + eta[1])

In [47]:
eps = 1e-8
etap = jnp.array([0.0, 1.0 + eps])
(L(etap)- L(eta)) / eps

Array(8916.09447535, dtype=float64)

In [46]:
gradL = jax.grad(L)

In [42]:
gradL(eta)

Array([1205.79749369, 8916.09434916], dtype=float64)

In [38]:
hessL = jax.hessian(L)
hessL(eta)

Array([[ 1206.79749369,  8917.09434916],
       [ 8917.09434916, 83723.09801715]], dtype=float64)