Skip to content

Commit

Permalink
generalize sigma usage
Browse files Browse the repository at this point in the history
  • Loading branch information
katosh committed Feb 28, 2024
1 parent c29c38f commit 948b1c3
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 7 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
- Change logging setup to configuration dict
- allow setting `active_dims` for composit kernels, allowing more flexible covariance kernel specifications
- update jaxconfig impot for compatibility with newer jax versions
- generalize variing sigma in FunctionEstimator for higher dimensional functions

# v1.4.1

Expand Down
28 changes: 22 additions & 6 deletions mellon/conditional.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
from jax.numpy import dot, square, isnan, any, eye
from jax import vmap
from jax.numpy import dot, square, isnan, any, eye, zeros, arange, ndim
from jax.numpy import sum as arraysum
from jax.numpy import diag as diagonal
from jax.numpy.linalg import cholesky
Expand Down Expand Up @@ -56,11 +57,26 @@ def _sigma_to_y_cov_factor(sigma, y_cov_factor, n):
"One can specify either `sigma` or `y_cov_factor` to describe input noise, but not both."
)

if y_cov_factor is None:
try:
y_cov_factor = diagonal(sigma)
except ValueError:
y_cov_factor = eye(n) * sigma
if y_cov_factor is not None:
return y_cov_factor

sigma_ndim = ndim(sigma)
if sigma_ndim == 0:
y_cov_factor = eye(n) * sigma
elif sigma_ndim == 1:
y_cov_factor = diagonal(sigma)
elif sigma_ndim > 1:
# Extend sigma to higher dimensions, adding a leading dimension for the diagonal
y_cov_factor = zeros((n,) + sigma.shape)

def update_diag(i, ycf, val):
return ycf.at[i, ...].set(val)

y_cov_factor = vmap(update_diag, in_axes=(0, 0, 0), out_axes=0)(
arange(n), y_cov_factor, sigma
)
else:
raise ValueError(f"Unsupported `sigma` dimensions `{sigma_ndim}`.")

return y_cov_factor

Expand Down
2 changes: 1 addition & 1 deletion mellon/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ def set_verbosity(verbose: bool):
verbose : bool
If True, sets the logging level to INFO to show detailed logs.
If False, sets it to WARNING to show only important messages.
Notes
-----
This function provides a simplified interface for controlling logging
Expand Down
40 changes: 40 additions & 0 deletions tests/test_sigma_to_y_cov_factor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import jax.numpy as jnp
import pytest
from mellon.conditional import _sigma_to_y_cov_factor


def test_scalar_sigma():
sigma = 0.5
n = 3
expected = jnp.eye(n) * sigma
result = _sigma_to_y_cov_factor(sigma, None, n)
assert jnp.allclose(result, expected), "Failed for scalar sigma"


def test_vector_sigma():
sigma = jnp.array([1.0, 2.0, 3.0])
n = 3
expected = jnp.diag(sigma)
result = _sigma_to_y_cov_factor(sigma, None, n)
assert jnp.allclose(result, expected), "Failed for vector sigma"


def test_higher_dimensional_sigma():
sigma = jnp.array([[1.0, 2.0], [3.0, 4.0]])
n = 2
expected = jnp.array([[[1.0, 2.0], [0.0, 0.0]], [[0.0, 0.0], [3.0, 4.0]]])
result = _sigma_to_y_cov_factor(sigma, None, n)
assert jnp.allclose(result, expected), "Failed for higher-dimensional sigma"


def test_both_sigma_y_cov_factor_provided():
sigma = jnp.array([1.0, 2.0, 3.0])
y_cov_factor = jnp.eye(3)
n = 3
with pytest.raises(ValueError):
_sigma_to_y_cov_factor(sigma, y_cov_factor, n)


def test_neither_sigma_nor_y_cov_factor_provided():
with pytest.raises(ValueError):
_sigma_to_y_cov_factor(None, None, 3)

0 comments on commit 948b1c3

Please sign in to comment.