In [1]:
import jax
from jax.flatten_util import ravel_pytree
import jax.numpy as jnp
import numpy as np
from numpy.polynomial.hermite import hermgauss
import numpyro
import numpyro.distributions as dist
from numpyro.infer.util import initialize_model

from pathlib import Path
from scipy.stats import norm
from scipy.optimize import minimize

from bsmodel import BSModel
import bridgestan as bs
from klhr import KLHR
from klhr_sinh import KLHRSINH

jax.config.update('jax_enable_x64', True)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Tell BridgeStan to use the `bridgestan` directory under the current user's home.
bs.set_bridgestan_path(Path.home() / "bridgestan")

In [3]:
# Wrap the Stan program and its JSON data in the project's BSModel helper.
# `bs_model` is later queried for log density / gradient and handed to KLHR.
bs_model = BSModel(stan_file = "stan/one_exponential.stan", data_file = "stan/one_exponential.json")

In [2]:
def model():
    """NumPyro model with a single latent site ``"x"`` drawn from ``Exponential(30.0)``.

    The function declares the sample site only; it has no observed sites and
    returns ``None``.
    """
    numpyro.sample("x", dist.Exponential(30.0))

def earnings():
    """NumPyro model declaring four latent sites: ``s``, ``b0``, ``b1`` and ``sigma``.

    ``s`` is drawn from ``Exponential(0.01)`` and is passed as the third argument
    of the ``StudentT(5.0, 0.0, s)`` priors on ``b0`` and ``b1``; ``sigma`` is drawn
    from ``Exponential(0.1)``. No observed sites are declared in the visible code,
    and the function returns ``None``.
    """
    scale = numpyro.sample("s", dist.Exponential(0.01))
    # b0 and b1 share the same prior family and the same sampled scale.
    numpyro.sample("b0", dist.StudentT(5.0, 0.0, scale))
    numpyro.sample("b1", dist.StudentT(5.0, 0.0, scale))
    numpyro.sample("sigma", dist.Exponential(0.1))

In [3]:
# Fixed PRNG key so that model initialisation is reproducible.
key = jax.random.PRNGKey(seed = 0)
# initialize_model returns (parameter info, potential function, postprocessing function, model trace);
# the trace is discarded. With dynamic_args=True, `potential` must first be called with the
# model's arguments (none for `model`) to obtain the actual potential function -- see its use
# as `potential()(theta)` in `logdensity`.
param_info, potential, postprocess, _ = initialize_model(key, model, dynamic_args = True)

In [4]:
# Seeded generator so the evaluation point and direction are reproducible.
rng = np.random.default_rng(seed=204)

# One-dimensional evaluation point, drawn from a standard normal.
theta0 = rng.normal(size=1)

# Random direction, scaled to unit Euclidean length.
rho = rng.normal(size=1)
rho = rho / np.linalg.norm(rho)

In [5]:
# Reference values from the Stan model: log density and gradient at theta0.
# NOTE(review): propto=False presumably keeps the normalising constants, as in BridgeStan -- confirm in BSModel.
bs_model.log_density_gradient(theta0, propto=False)

(-72.45078145775592, array([-75.79188037]))

In [5]:
def logdensity(theta):
    """Return the model's log density at ``theta`` (the negated NumPyro potential).

    ``theta`` is a parameter pytree in the structure produced by ``unflatten``.
    ``potential`` was created with ``dynamic_args=True``, so it is first called
    with no model arguments to obtain the potential function itself.
    """
    potential_fn = potential()
    return -potential_fn(theta)
def gradient(theta):
    """Return the gradient of ``logdensity`` at ``theta``, in the same pytree structure."""
    grad_fn = jax.grad(logdensity)
    return grad_fn(theta)
# Initial parameter values produced by initialize_model (a pytree keyed by site name).
theta_init = param_info.z
# Only the inverse map is kept: `unflatten` turns a flat vector back into that pytree structure.
_, unflatten = ravel_pytree(theta_init)

In [6]:
# Same evaluation through NumPyro/JAX; the values are expected to agree with
# bs_model.log_density_gradient(theta0, propto=False).
logdensity(unflatten(theta0)), gradient(unflatten(theta0))

(Array(-72.45078146, dtype=float64), {'x': Array(-75.79188037, dtype=float64)})

In [6]:
# Build the project's KLHR object on the Stan model, with theta0 as its `theta` argument.
klhr = KLHR(bs_model, theta = theta0)

In [7]:
# Two-component parameter vector passed to klhr._L. In the JAX re-implementation
# (gausshermite_estimate) eta[0] is used as a location and exp(eta[1]) as a scale.
eta = np.array([0.0, 1.0])

In [8]:
# Evaluate KLHR's private objective at (eta, rho).
# NOTE(review): the displayed pair is presumably (value, gradient w.r.t. eta) -- confirm in klhr.py.
klhr._L(eta, rho)

(np.float64(3083.7338598747006), array([ 3088.07495878, 22824.34808386]))

In [8]:
# 16-point Gauss-Hermite quadrature rule: sample points `x` and weights `w` (NumPy arrays).
x, w = hermgauss(16)

# JAX copies of the nodes and weights for use inside traced functions.
xjnp, wjnp = jnp.asarray(x), jnp.asarray(w)

In [9]:
@jax.jit
def gausshermite_estimate(x, w, rho, eta):
    """Weighted log-density term for a single quadrature node.

    Parameters
    ----------
    x : quadrature node.
    w : quadrature weight belonging to ``x``.
    rho : direction along which the density is evaluated.
    eta : two-component vector; ``eta[0]`` is the location and ``exp(eta[1])``
        (plus a 1e-10 floor) is the scale of the one-dimensional variable.

    Returns
    -------
    ``w`` times the log density at ``theta0 + y * rho``, where
    ``y = sqrt(2) * scale * x + location``.

    Notes
    -----
    Reads the notebook globals ``theta0``, ``logdensity`` and ``unflatten``.
    Because the function is jitted, those globals are captured when it is traced.
    """
    location = eta[0]
    scale = jnp.exp(eta[1]) + 1e-10
    # Map the quadrature node onto the scale of the one-dimensional variable.
    step = jnp.sqrt(2.0) * scale * x + location
    point = step * rho + theta0
    return w * logdensity(unflatten(point))

In [10]:
# Vectorise over axis 0 of the nodes and weights; rho and eta are shared by every node.
# gausshermite_estimate is already jitted, so the outer jit compiles the batched version.
ghe = jax.jit(jax.vmap(gausshermite_estimate, in_axes = (0, 0, None, None)))

In [11]:
# JAX versions of eta and rho for the jitted objective; `rho` is rebound from the
# NumPy array created earlier to a JAX array holding the same values.
eta = jnp.array([0.0, 1.0])
rho = jnp.asarray(rho)
@jax.jit
def L(eta, rho):
    """Objective in ``eta`` for a fixed direction ``rho``.

    Sums the per-node terms from ``ghe`` over the global quadrature nodes and
    weights (``xjnp``, ``wjnp``), divides by ``sqrt(pi)``, adds ``eta[1]``, and
    returns the negative of the result.
    """
    node_terms = ghe(xjnp, wjnp, rho, eta)
    quadrature = jnp.sum(node_terms) / jnp.sqrt(jnp.pi)
    return -(quadrature + eta[1])

In [12]:
# Step size for a finite-difference probe of the objective.
eps = 1e-8
# eta with its second component shifted by eps.
etap = jnp.array([0.0, 1.0 + eps])
# Objective at the perturbed point. The forward difference (L(etap, rho) - L(eta, rho)) / eps
# would estimate dL/d eta[1]; only the perturbed value is displayed here.
L(etap, rho)

Array(3083.73408812, dtype=float64)

In [13]:
# Compiled gradient of L; jax.grad differentiates with respect to the first argument (eta) by default.
gradL = jax.jit(jax.grad(L))

In [14]:
# Gradient of the objective at eta; compare with the second element displayed by klhr._L(eta, rho).
gradL(eta, rho)

Array([ 3088.07495878, 22824.34808302], dtype=float64)

In [15]:
# NOTE(review): 5473 is an unexplained constant -- its origin is not shown in this
# notebook. TODO: document where it comes from or remove this scratch cell.
5473 / np.sqrt(np.pi)

np.float64(3087.8095907568704)

In [38]:
# Hessian of the objective with respect to eta; jax.hessian differentiates
# with respect to the first argument by default.
hessL = jax.hessian(L)
# L is defined as L(eta, rho), so rho must be passed explicitly; calling
# hessL(eta) alone raises a missing-argument TypeError on a fresh kernel.
hessL(eta, rho)

Array([[ 1206.79749369,  8917.09434916],
       [ 8917.09434916, 83723.09801715]], dtype=float64)

In [4]:
from scipy.optimize import minimize
import numpy as np

def f(x):
    """Toy objective: a quadratic bowl centred at (1, -2) plus a small sinusoid in x[0].

    Returns ``(x[0]-1)**2 + (x[1]+2)**2 + 0.1*sin(3*x[0])``.
    """
    u, v = x[0], x[1]
    bowl = (u - 1)**2 + (v + 2)**2
    return bowl + 0.1*np.sin(3*u)

def grad(x):
    """Analytic gradient of ``f`` as a length-2 NumPy array."""
    d_first = 2*(x[0]-1) + 0.3*np.cos(3*x[0])
    d_second = 2*(x[1]+2)
    return np.array([d_first, d_second])

def hess(x):
    """Analytic Hessian of ``f`` as a 2x2 NumPy array.

    The Hessian is diagonal: only the (0, 0) entry depends on ``x``.
    """
    curvature_first = 2 - 0.9*np.sin(3*x[0])
    return np.diag([curvature_first, 2.0])

# Starting point for the optimiser.
x0 = np.array([0.0, 0.0])

# Minimise f with SciPy's "trust-exact" method, supplying the hand-written gradient
# and Hessian. The options set the gradient tolerance (gtol), the iteration cap
# (maxiter), and print a convergence summary (disp).
res = minimize(
    f, x0,
    method="trust-exact",
    jac=grad,
    hess=hess,
    options={"gtol": 1e-2, "maxiter": 200, "disp": True}
)

# Report the minimiser and the gradient norm at the solution.
print(res.x, np.linalg.norm(res.jac))

Optimization terminated successfully.
         Current function value: -0.007898
         Iterations: 4
         Function evaluations: 5
         Gradient evaluations: 5
         Hessian evaluations: 5
[ 1.14386048 -2.        ] 0.00024682719389046426


In [2]:
# Display the full optimisation result object.
res

 message: Optimization terminated successfully.
 success: True
  status: 0
     fun: -0.007898310268378332
       x: [ 1.144e+00 -2.000e+00]
     nit: 4
     jac: [ 2.468e-04  0.000e+00]
    nfev: 5
    njev: 5
    nhev: 5
    hess: [[ 2.257e+00  0.000e+00]
           [ 0.000e+00  2.000e+00]]

In [13]:
from scipy.optimize import minimize
import numpy as np
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def _fj(x):
    # objective
    return (x[0]-1)**2 + (x[1]+2)**2 + 0.1*jnp.sin(3*x[0])

# Gradient and Hessian of _fj obtained by automatic differentiation,
# replacing the hand-written derivative formulas.
_grad = jax.grad(_fj)
_hess = jax.hessian(_fj)

def fj(x):
    """Evaluate ``_fj`` on ``x`` (converted to a JAX array) and cast the result to float64."""
    value = _fj(jnp.asarray(x))
    return value.astype(np.float64)

def grad(x):
    """Gradient of ``_fj`` at ``x`` via ``_grad``, returned as a NumPy array.

    Note: this rebinds the name ``grad`` used by the earlier hand-written gradient.
    """
    jax_gradient = _grad(jnp.asarray(x))
    return np.asarray(jax_gradient)

def hess(x):
    """Hessian of ``_fj`` at ``x`` via ``_hess``, returned as a NumPy array.

    Note: this rebinds the name ``hess`` used by the earlier hand-written Hessian.
    """
    jax_hessian = _hess(jnp.asarray(x))
    return np.asarray(jax_hessian)

# Starting point for the optimiser.
x0 = np.array([0.0, 0.0])

# Same trust-exact solve as the analytic version, but with the objective, gradient
# and Hessian supplied by the JAX-backed wrappers; solver options are left at their defaults.
res = minimize(
    fj, x0,
    method="trust-exact",
    jac=grad,
    hess=hess
)

# Display the full optimisation result object.
res

 message: Optimization terminated successfully.
 success: True
  status: 0
     fun: -0.007898323763467961
       x: [ 1.144e+00 -2.000e+00]
     nit: 5
     jac: [ 1.547e-08  0.000e+00]
    nfev: 6
    njev: 6
    nhev: 6
    hess: [[ 2.257e+00  0.000e+00]
           [ 0.000e+00  2.000e+00]]

In [14]:
class T:
    """Minimal example of storing a closure, built by a private method, on an instance."""

    def __init__(self):
        # Build the callable once and keep it as the instance attribute `_f`.
        self._f = self.__f()

    def __f(self):
        """Return a function that doubles its argument."""
        def double(x):
            return 2 * x
        return double

In [15]:
# Instantiate T; __init__ stores the doubling closure on the instance as `_f`.
t = T()

In [16]:
# Call the stored closure: 2 * 2.
t._f(2)

4

In [37]:
def f(x):
    """Log density at the point ``x * rho`` (``rho`` is the notebook-level direction).

    Note: this rebinds the name ``f`` used by the earlier toy objective.
    """
    point = unflatten(x * rho)
    return logdensity(point)

# Derivative of f with respect to x.
g = jax.grad(f)

In [44]:
def h(x, rho):
    """Log density at the point ``x * rho``."""
    point = unflatten(x * rho)
    return logdensity(point)

# Batch h over its first argument only; rho is shared by every element.
hh = jax.vmap(h, in_axes=(0, None))

def L(xjnp):
    """Weighted sum of the log density over the points ``xjnp``, using the global weights ``wjnp``.

    Note: this rebinds the name ``L`` used by the earlier two-argument objective.
    """
    per_point = hh(xjnp, rho)
    return jnp.sum(per_point * wjnp)

# Gradient of the weighted sum with respect to each point; this rebinds `g`.
g = jax.grad(L)
g(xjnp)

Array([ 2.87863825e-10,  2.33696190e-07,  2.60901605e-05,  8.72183972e-04,
        1.19692979e-02,  7.97545366e-02,  2.81287321e-01,  5.15511537e-01,
        3.34169062e-01, -2.60768418e-01, -3.03409148e-01, -6.66628221e-02,
       -5.30100421e-03, -1.72874676e-04, -1.79308288e-06, -2.36963592e-09],      dtype=float64)

In [18]:
# Draw 10 integers from [0, 7) using the notebook's seeded generator.
rng.integers(7, size = 10)

array([1, 4, 4, 1, 3, 0, 3, 1, 2, 0])