Remove x64 in gp example (#190)
* remove x64 in gp example

* fix mistake in the order of solve_triangular

* skip nan predictions

* reduce jitter to default 1e-6

* clip variance prediction
fehiepsi authored and martinjankowiak committed Jun 5, 2019
1 parent 4f52179 commit ba1fb67
Showing 2 changed files with 12 additions and 12 deletions.
examples/gp.py (19 changes: 8 additions & 11 deletions)
@@ -1,23 +1,20 @@
-import matplotlib
-matplotlib.use('Agg')  # noqa: E402
-import matplotlib.pyplot as plt
-import jax
-
 import argparse
 
+import matplotlib
+import matplotlib.pyplot as plt
 import numpy as onp
-from jax import vmap
 
+import jax
 import jax.numpy as np
 import jax.random as random
+from jax import vmap
 
 import numpyro.distributions as dist
 from numpyro.handlers import sample
 from numpyro.hmc_util import initialize_model
 from numpyro.mcmc import mcmc
 
-from jax.config import config
-# we use double precision to minimize any possible numerical instabilities in jax linear algebra
-config.update('jax_enable_x64', True)
+matplotlib.use('Agg')  # noqa: E402
+
 """
 In this example we show how to use NUTS to sample from the posterior
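The removed block above is the point of the commit: the example no longer forces 64-bit precision globally and instead relies on the jitter and clipping changes below to keep float32 linear algebra stable. For reference only (not part of the diff), this is the opt-in the removed lines used, which downstream users can still apply if a single-precision run misbehaves; it must run at startup, before any arrays are created:

from jax.config import config
# re-enable double precision for all of JAX, mirroring the removed lines above
config.update('jax_enable_x64', True)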
@@ -26,7 +23,7 @@
 
 
 # squared exponential kernel with diagonal noise term
-def kernel(X, Z, var, length, noise, jitter=1.0e-5, include_noise=True):
+def kernel(X, Z, var, length, noise, jitter=1.0e-6, include_noise=True):
     deltaXsq = np.power((X[:, None] - Z) / length, 2.0)
     k = var * np.exp(-0.5 * deltaXsq)
     if include_noise:
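The jitter term is what keeps the kernel matrix positive definite now that the example runs in single precision; this hunk lowers it from 1.0e-5 to the library default of 1.0e-6. The noise branch is cut off in the hunk above, so the following sketch (not part of the diff) completes it under the usual assumption that noise and jitter are added to the diagonal, with toy hyperparameter values:

import jax.numpy as np

def kernel(X, Z, var, length, noise, jitter=1.0e-6, include_noise=True):
    # squared exponential kernel: var * exp(-0.5 * ((x - z) / length) ** 2)
    deltaXsq = np.power((X[:, None] - Z) / length, 2.0)
    k = var * np.exp(-0.5 * deltaXsq)
    if include_noise:
        # assumed completion: diagonal noise + jitter keeps k_XX positive
        # definite, so downstream Cholesky/inverse calls survive float32
        k += (noise + jitter) * np.eye(X.shape[0])
    return k

X = np.linspace(-1.0, 1.0, 5)
k_XX = kernel(X, X, var=1.0, length=0.5, noise=0.1)  # (5, 5) symmetric PD matrix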
@@ -65,7 +62,7 @@ def predict(rng, X, Y, X_test, var, length, noise):
     k_XX = kernel(X, X, var, length, noise, include_noise=True)
     K_xx_inv = np.linalg.inv(k_XX)
     K = k_pp - np.matmul(k_pX, np.matmul(K_xx_inv, np.transpose(k_pX)))
-    sigma_noise = np.sqrt(np.diag(K)) * jax.random.normal(rng, (X_test.shape[0],))
+    sigma_noise = np.sqrt(np.clip(np.diag(K), a_min=0.)) * jax.random.normal(rng, X_test.shape[:1])
     mean = np.matmul(k_pX, np.matmul(K_xx_inv, Y))
     # we return both the mean function and a sample from the posterior predictive for the
     # given set of hyperparameters
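The clip implements the "skip nan predictions" / "clip variance prediction" items from the commit message: in float32, the predictive variance np.diag(K) can come out slightly negative from round-off alone, and np.sqrt of a negative number is NaN. A small sketch of the failure mode and the fix, with made-up values:

import jax
import jax.numpy as np

diag_K = np.array([0.25, -1e-9, 0.09])     # middle entry is pure round-off noise
bad = np.sqrt(diag_K)                      # [0.5, nan, 0.3]
good = np.sqrt(np.clip(diag_K, a_min=0.))  # [0.5, 0.0, 0.3]

# as in predict(): scale a standard normal sample by the clipped std dev
rng = jax.random.PRNGKey(0)
sigma_noise = good * jax.random.normal(rng, diag_K.shape)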
numpyro/distributions/util.py (5 changes: 4 additions & 1 deletion)
@@ -10,6 +10,7 @@
 from jax.interpreters import ad, batching
 from jax.lib import xla_bridge
 from jax.numpy.lax_numpy import _promote_args_like
+from jax.scipy.linalg import solve_triangular
 from jax.scipy.special import gammaln
 from jax.util import partial
 
@@ -391,7 +392,9 @@ def cholesky_inverse(matrix):
     # which is more numerically stable.
     # Refer to:
     # https://nbviewer.jupyter.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril
-    return np.swapaxes(np.linalg.inv(np.linalg.cholesky(matrix[..., ::-1, ::-1])[..., ::-1, ::-1]), -2, -1)
+    tril_inv = np.swapaxes(np.linalg.cholesky(matrix[..., ::-1, ::-1])[..., ::-1, ::-1], -2, -1)
+    identity = np.broadcast_to(np.identity(matrix.shape[-1]), tril_inv.shape)
+    return solve_triangular(tril_inv, identity, lower=True)
 
 
 def entr(p):
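cholesky_inverse turns a precision matrix into the lower-triangular scale_tril of its inverse. The axis-reversal trick yields a lower-triangular tril_inv with matrix = tril_inv.T @ tril_inv; the old code then inverted that factor through a dense np.linalg.inv, while the new code recovers the same inverse with one triangular solve against the identity, which is cheaper and numerically better behaved. A toy check of the identity it computes (not part of the diff, values made up):

import jax.numpy as np
from jax.scipy.linalg import solve_triangular

P = np.array([[4.0, 1.0],
              [1.0, 3.0]])  # a small SPD precision matrix

# lower-triangular tril_inv with P = tril_inv.T @ tril_inv, via the reversal trick
tril_inv = np.swapaxes(np.linalg.cholesky(P[::-1, ::-1])[::-1, ::-1], -2, -1)

# new path: one triangular solve against I instead of a dense inverse
scale_tril = solve_triangular(tril_inv, np.identity(2), lower=True)

# scale_tril is lower triangular and scale_tril @ scale_tril.T reconstructs inv(P)
assert np.allclose(np.matmul(scale_tril, scale_tril.T), np.linalg.inv(P), atol=1e-5)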
