Skip to content

Commit

Permalink
Use Cholesky decomp instead of inverting kernel (#1688)
Browse files Browse the repository at this point in the history
* Use Cholesky decomp instead of inverting kernel

* Add a `use_cholesky` option

* Use dash in args for consistency
  • Loading branch information
DanWaxman committed Nov 28, 2023
1 parent 31c175b commit b16741c
Showing 1 changed file with 17 additions and 6 deletions.
23 changes: 17 additions & 6 deletions examples/gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,18 +98,28 @@ def run_inference(model, args, rng_key, X, Y):


# do GP prediction for a given set of hyperparameters. this makes use of the well-known
# formula for gaussian process predictions
def predict(rng_key, X, Y, X_test, var, length, noise):
# formula for Gaussian process predictions
def predict(rng_key, X, Y, X_test, var, length, noise, use_cholesky=True):
# compute kernels between train and test data, etc.
k_pp = kernel(X_test, X_test, var, length, noise, include_noise=True)
k_pX = kernel(X_test, X, var, length, noise, include_noise=False)
k_XX = kernel(X, X, var, length, noise, include_noise=True)
K_xx_inv = jnp.linalg.inv(k_XX)
K = k_pp - jnp.matmul(k_pX, jnp.matmul(K_xx_inv, jnp.transpose(k_pX)))

# since K_xx is symmetric positive-definite, we can use the more efficient and
# stable Cholesky decomposition instead of matrix inversion
if use_cholesky:
K_xx_cho = jax.scipy.linalg.cho_factor(k_XX)
K = k_pp - jnp.matmul(k_pX, jax.scipy.linalg.cho_solve(K_xx_cho, k_pX.T))
mean = jnp.matmul(k_pX, jax.scipy.linalg.cho_solve(K_xx_cho, Y))
else:
K_xx_inv = jnp.linalg.inv(k_XX)
K = k_pp - jnp.matmul(k_pX, jnp.matmul(K_xx_inv, jnp.transpose(k_pX)))
mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, Y))

sigma_noise = jnp.sqrt(jnp.clip(jnp.diag(K), a_min=0.0)) * jax.random.normal(
rng_key, X_test.shape[:1]
)
mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, Y))

# we return both the mean function and a sample from the posterior predictive for the
# given set of hyperparameters
return mean, mean + sigma_noise
Expand Down Expand Up @@ -148,7 +158,7 @@ def main(args):
)
means, predictions = vmap(
lambda rng_key, var, length, noise: predict(
rng_key, X, Y, X_test, var, length, noise
rng_key, X, Y, X_test, var, length, noise, use_cholesky=args.use_cholesky
)
)(*vmap_args)

Expand Down Expand Up @@ -184,6 +194,7 @@ def main(args):
type=str,
choices=["median", "feasible", "value", "uniform", "sample"],
)
parser.add_argument("--no-cholesky", dest="use_cholesky", action="store_false")
args = parser.parse_args()

numpyro.set_platform(args.device)
Expand Down

0 comments on commit b16741c

Please sign in to comment.