Skip to content

Commit

Permalink
Merge branch 'main' into mc/issue-59
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewcarbone committed Mar 29, 2024
2 parents 7d8156b + 541caea commit 3b7ce70
Show file tree
Hide file tree
Showing 16 changed files with 368 additions and 281 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ gp_model = gpax.ExactGP(1, kernel='RBF')
# Run Hamiltonian Monte Carlo to obtain posterior samples for the GP model parameters
gp_model.fit(rng_key, X, y) # X and y are numpy arrays with dimensions (n, d) and (n,)
```
In the fully Bayesian mode, we get a pair of predictive mean and covariance for each Hamiltonian Monte Carlo sample containing the GP parameters (in this case, the RBF kernel hyperparameters and model noise). Hence, a prediction on new inputs with a trained GP model returns the center of the mass of all the predictive means (```y_pred```) and samples from multivariate normal distributions for all the pairs of predictive means and covariances (```y_sampled```).
In the fully Bayesian mode, we get a pair of predictive mean and covariance for each Hamiltonian Monte Carlo sample containing the GP parameters (in this case, the RBF kernel hyperparameters and model noise). Hence, a prediction on new inputs with a trained GP model returns the center of the mass of all the predictive means (```posterior_mean```) and samples from multivariate normal distributions for all the pairs of predictive means and covariances (```f_samples```).
```python3
y_pred, y_sampled = gp_model.predict(rng_key_predict, X_test)
posterior_mean, f_samples = gp_model.predict(rng_key_predict, X_test)
```

<img src = "https://user-images.githubusercontent.com/34245227/167945293-8cb5b88a-1f64-4f7d-95ab-26863b90d1e5.jpg" height="60%" width="60%">
Expand Down Expand Up @@ -96,7 +96,7 @@ sgp_model = gpax.ExactGP(1, kernel='Matern', mean_fn=piecewise, mean_fn_prior=pi
# Run MCMC to obtain posterior samples
sgp_model.fit(rng_key, X, y)
# Get GP prediction on new/test data
y_pred, y_sampled = sgp_model.predict(rng_key_predict, X_test)
posterior_mean, f_samples = sgp_model.predict(rng_key_predict, X_test)
```

![GP_vs_sGP2](https://github.com/ziatdinovmax/gpax/assets/34245227/89de341c-f00c-468c-afe6-c0b1c1140725)
Expand Down Expand Up @@ -150,7 +150,7 @@ Note that X has (N, D+1) dimensions where the last column contains task/fidelity
X_unmeasured2 = np.column_stack((X_full_range, np.ones_like(X_full_range)))

# Make a prediction with the trained model
y_mean2, y_sampled2 = model.predict(key2, X_unmeasured2, noiseless=True)
posterior_mean2, f_samples2 = model.predict(key2, X_unmeasured2, noiseless=True)
```

![GP_vs_MTGP](https://github.com/ziatdinovmax/gpax/assets/34245227/5a36d3cd-c904-4345-abc3-b1bea5025cc8)
Expand Down
3 changes: 2 additions & 1 deletion docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@ GPax is a small Python package for physics-based Gaussian processes (GPs) built
:caption: Package Content

models
hypo
acquisition
kernels
priors
hypo
utils

.. toctree::
Expand Down
10 changes: 10 additions & 0 deletions docs/source/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,13 @@ Multi-Task Learning
:undoc-members:
:member-order: bysource
:show-inheritance:

Structured Probabilistic Models
-------------------------------
.. autoclass:: gpax.models.spm.sPM
:members:
:inherited-members:
:undoc-members:
:member-order: bysource
:show-inheritance:

23 changes: 23 additions & 0 deletions docs/source/priors.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
Priors
======

.. autofunction:: gpax.priors.normal_dist

.. autofunction:: gpax.priors.lognormal_dist

.. autofunction:: gpax.priors.halfnormal_dist

.. autofunction:: gpax.priors.gamma_dist

.. autofunction:: gpax.priors.uniform_dist

.. autofunction:: gpax.priors.auto_normal_priors

.. autofunction:: gpax.priors.auto_lognormal_priors

.. autofunction:: gpax.priors.auto_normal_kernel_priors

.. autofunction:: gpax.priors.auto_lognormal_kernel_priors



31 changes: 4 additions & 27 deletions docs/source/utils.rst
Original file line number Diff line number Diff line change
@@ -1,40 +1,17 @@
Utilities
=========

Priors
------
Automatic function setters
--------------------------

.. autofunction:: gpax.utils.normal_dist
.. autofunction:: gpax.utils.set_fn

.. autofunction:: gpax.utils.lognormal_dist

.. autofunction:: gpax.utils.halfnormal_dist

.. autofunction:: gpax.utils.gamma_dist

.. autofunction:: gpax.utils.uniform_dist

.. autofunction:: gpax.utils.place_normal_prior

.. autofunction:: gpax.utils.place_lognormal_prior

.. autofunction:: gpax.utils.place_halfnormal_prior

.. autofunction:: gpax.utils.place_uniform_prior

.. autofunction:: gpax.utils.place_gamma_prior
.. autofunction:: gpax.utils.set_kernel_fn


Other utilities
---------------

.. autoclass:: gpax.models.spm.sPM
:members:
:inherited-members:
:undoc-members:
:member-order: bysource
:show-inheritance:

.. autofunction:: gpax.utils.dviz

.. autofunction:: gpax.utils.get_keys
Expand Down
1 change: 1 addition & 0 deletions gpax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
)

__all__ = [
"priors",
"utils",
"kernels",
"mtkernels",
Expand Down
4 changes: 2 additions & 2 deletions gpax/models/bnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def _set_data(self, X: jnp.ndarray, y: Optional[jnp.ndarray] = None
) -> Union[Tuple[jnp.ndarray], jnp.ndarray]:
X = X if X.ndim > 1 else X[:, None]
if y is not None:
y = y[:, None] if y.ndim < 1 else y
y = y[:, None] if y.ndim < 2 else y
return X, y
return X

Expand All @@ -47,7 +47,7 @@ def sample_weights(name: str, in_channels: int, out_channels: int) -> jnp.ndarra

def sample_biases(name: str, channels: int) -> jnp.ndarray:
"""Sampling bias vector"""
b = numpyro.sample(name=name, fn=dist.Normal(
b = numpyro.sample(name=name, fn=dist.Cauchy(
loc=jnp.zeros((channels)), scale=jnp.ones((channels))))
return b

Expand Down
34 changes: 29 additions & 5 deletions gpax/models/spm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import jax
import jaxlib
import jax.numpy as jnp
import jax.random as jra
from jax import vmap
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive, init_to_median
Expand Down Expand Up @@ -144,19 +146,44 @@ def sample_from_prior(self, rng_key: jnp.ndarray,
prior_predictive = Predictive(self.model, num_samples=num_samples)
samples = prior_predictive(rng_key, X)
return samples['y']

def sample_single_posterior_predictive(self, rng_key, X_new, params, n_draws):
sigma = params["noise"]
loc = self._model(X_new, params)
sample = dist.Normal(loc, sigma).sample(rng_key, (n_draws,)).mean(0)
return loc, sample

def _vmap_predict(self, rng_key: jnp.ndarray, X_new: jnp.ndarray,
samples: Optional[Dict[str, jnp.ndarray]] = None,
n_draws: int = 1,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""
Helper method to vectorize predictions over posterior samples
"""
if samples is None:
samples = self.get_samples(chain_dim=False)
num_samples = len(next(iter(samples.values())))
vmap_args = (jra.split(rng_key, num_samples), samples)

predictive = lambda p1, p2: self.sample_single_posterior_predictive(p1, X_new, p2, n_draws)
loc, f_samples = vmap(predictive)(*vmap_args)

return loc, f_samples

def predict(self, rng_key: jnp.ndarray, X_new: jnp.ndarray,
samples: Optional[Dict[str, jnp.ndarray]] = None,
n: int = 1,
filter_nans: bool = False, take_point_predictions_mean: bool = True,
device: Type[jaxlib.xla_extension.Device] = None
) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""
Make prediction at X_new points using sampled GP hyperparameters
Make prediction at X_new points using posterior model parameters
Args:
rng_key: random number generator key
X_new: 2D vector with new/'test' data of :math:`n x num_features` dimensionality
samples: optional posterior samples
n: number of samples to draw from normal distribution per single HMC sample
filter_nans: filter out samples containing NaN values (if any)
take_point_predictions_mean: take a mean of point predictions (without sampling from the normal distribution)
device:
Expand All @@ -172,10 +199,7 @@ def predict(self, rng_key: jnp.ndarray, X_new: jnp.ndarray,
if device:
X_new = jax.device_put(X_new, device)
samples = jax.device_put(samples, device)
predictive = Predictive(
self.model, posterior_samples=samples, parallel=True)
y_pred = predictive(rng_key, X_new)
y_pred, y_sampled = y_pred["mu"], y_pred["y"]
y_pred, y_sampled = self._vmap_predict(rng_key, X_new, samples, n)
if filter_nans:
y_sampled_ = [y_i for y_i in y_sampled if not jnp.isnan(y_i).any()]
y_sampled = jnp.array(y_sampled_)
Expand Down
1 change: 1 addition & 0 deletions gpax/priors/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .priors import *
Loading

0 comments on commit 3b7ce70

Please sign in to comment.