diff --git a/.github/workflows/ci-self-hosted.yaml b/.github/workflows/ci-self-hosted.yaml index b489e0c..1d7ee8b 100644 --- a/.github/workflows/ci-self-hosted.yaml +++ b/.github/workflows/ci-self-hosted.yaml @@ -26,8 +26,6 @@ jobs: - uses: actions/checkout@v2 - name: Install Conda environment from environment.yml uses: mamba-org/provision-with-micromamba@main - with: - cache-env: true - name: Check CUDA shell: bash -l {0} run: | diff --git a/preclinpack/blocks/gp.py b/preclinpack/blocks/gp.py index fc04925..ae49b5b 100644 --- a/preclinpack/blocks/gp.py +++ b/preclinpack/blocks/gp.py @@ -20,7 +20,12 @@ from preclinpack.nested_hierarchy_utils import slice_data_by_nested_hierarchies from preclinpack.utils import CenteredNormal, dims_to_tuple, make_sum_zero_hh -__all__ = ["WithinBlockGP", "WithinBlockCorrelatedGP"] +__all__ = [ + "WithinBlockGP", + "WithinBlockCorrelatedGP", + "WithinBlockLKJCorrelatedGP", + "WithinBlockCorrelatedGPBase", +] def make_centered_gp_eigendecomp( @@ -628,7 +633,7 @@ def plot_gp( return plots -class WithinBlockCorrelatedGP(WithinBlockGP): +class WithinBlockCorrelatedGPBase(WithinBlockGP): default_config_key = "within_block_correlated_gp" @classmethod @@ -775,6 +780,12 @@ def make_hierarchical_gps( return leaf_rvs + @classmethod + def build_correlator_matrix(cls, dim, name_prefix="", **kwargs): + raise NotImplementedError + + +class WithinBlockCorrelatedGP(WithinBlockCorrelatedGPBase): @classmethod def build_correlator_matrix(cls, dim, name_prefix="", alpha=1.0): model = pm.modelcontext(None) @@ -793,3 +804,24 @@ def build_correlator_matrix(cls, dim, name_prefix="", alpha=1.0): R = at.set_subtensor(R[inds_l], r[inds_l[1]]) R = at.set_subtensor(R[inds_u], r[inds_u[1] - 1]) return R + + +class WithinBlockLKJCorrelatedGP(WithinBlockCorrelatedGPBase): + @staticmethod + def pack_corr(corr_vals, n): + idx_upper = np.triu_indices(n, 1) + corr = at.ones((n, n)) + corr = at.set_subtensor(corr[idx_upper], corr_vals) + corr = at.set_subtensor(corr[idx_upper[::-1]], corr_vals) + return corr + + @classmethod + def chol_corr(cls, corr_vals, n): + return at.linalg.cholesky(cls.pack_corr(corr_vals, n)) + + @classmethod + def build_correlator_matrix(cls, dim, name_prefix="", eta=1): + model = pm.modelcontext(None) + N_correl = len(model.coords[dim]) + corr_vals = pm.LKJCorr(f"{name_prefix}correlator", N_correl, eta) + return cls.chol_corr(corr_vals, N_correl) diff --git a/preclinpack/jax_utils.py b/preclinpack/jax_utils.py index 7aedb24..408722e 100644 --- a/preclinpack/jax_utils.py +++ b/preclinpack/jax_utils.py @@ -8,8 +8,10 @@ if _jax_available: from aesara.graph import Constant from aesara.link.jax.dispatch import jax_funcify + from aesara.tensor.extra_ops import FillDiagonal from aesara.tensor.shape import Reshape from jax import numpy as jnp + from pymc.distributions.multivariate import PosDefMatrix from preclinpack.blocks.distributions import BallBackwardOp @@ -39,3 +41,18 @@ def f(value, *inputs): return jnp.where(radius < 1e-6, value, value / radius * jnp.arctan(radius)) * 2 / jnp.pi return f + + @jax_funcify.register(FillDiagonal) + def _jax_funcify_FillDiagonal(op, node, **kwargs): + def f(value, diagonal): + i, j = jnp.diag_indices(min(value.shape[-2:])) + return value.at[..., i, j].set(diagonal) + + return f + + @jax_funcify.register(PosDefMatrix) + def _jax_funcify_PosDefMatrix(op, node, **kwargs): + def f(m): + return ~jnp.isnan(jnp.linalg.cholesky(m)).any() + + return f diff --git a/preclinpack/models/gp.py b/preclinpack/models/gp.py index 98980a0..2f41d54 100644 --- a/preclinpack/models/gp.py +++ b/preclinpack/models/gp.py @@ -1,5 +1,5 @@ from preclinpack.blocks.base import SortNestedHierarchies -from preclinpack.blocks.gp import WithinBlockCorrelatedGP, WithinBlockGP +from preclinpack.blocks.gp import WithinBlockGP, WithinBlockLKJCorrelatedGP from preclinpack.blocks.likelihoods import NegBinLikelihood, ZeroInfNegBinLikelihood from preclinpack.blocks.linear import ( BaselineActivity, @@ -45,7 +45,7 @@ def NegBinFullCorrelatedModel(data, config, **kwargs): hierarchies=SortNestedHierarchies, baseline_activity=BaselineActivity, linear_long_term_trend=LinearLongTermTrend, - within_block_correlated_gp=WithinBlockCorrelatedGP, + within_block_correlated_gp=WithinBlockLKJCorrelatedGP, likelihood=NegBinLikelihood, ) return Model(blocks_config=blocks_config, data=data, config=config, **kwargs) @@ -56,7 +56,7 @@ def ZeroInfNegBinFullCorrelatedModel(data, config, **kwargs): hierarchies=SortNestedHierarchies, baseline_activity=BaselineActivity, linear_long_term_trend=LinearLongTermTrend, - within_block_correlated_gp=WithinBlockCorrelatedGP, + within_block_correlated_gp=WithinBlockLKJCorrelatedGP, likelihood=ZeroInfNegBinLikelihood, ) return Model(blocks_config=blocks_config, data=data, config=config, **kwargs) diff --git a/preclinpack/tests/blocks/test_gp.py b/preclinpack/tests/blocks/test_gp.py index f812a50..99b643b 100644 --- a/preclinpack/tests/blocks/test_gp.py +++ b/preclinpack/tests/blocks/test_gp.py @@ -14,7 +14,9 @@ from preclinpack.blocks.distributions import HyperballUniformRV from preclinpack.blocks.gp import ( WithinBlockCorrelatedGP, + WithinBlockCorrelatedGPBase, WithinBlockGP, + WithinBlockLKJCorrelatedGP, make_centered_gp_eigendecomp, make_covariance_kernel, make_hierarchical_gps_eigenfunctions, @@ -276,7 +278,9 @@ def test_repeated_coords(self): ) -@pytest.mark.parametrize("block", (WithinBlockGP, WithinBlockCorrelatedGP)) +@pytest.mark.parametrize( + "block", (WithinBlockGP, WithinBlockCorrelatedGP, WithinBlockLKJCorrelatedGP) +) class TestWithinBlockGPMakeHierarchicalGPs: @classmethod def setup_class(cls): @@ -515,7 +519,9 @@ def test_build_correlator_matrix(dims): np.testing.assert_allclose(r_eval[row], expected_row) -@pytest.mark.parametrize("block", (WithinBlockGP, WithinBlockCorrelatedGP)) +@pytest.mark.parametrize( + "block", (WithinBlockGP, WithinBlockCorrelatedGP, WithinBlockLKJCorrelatedGP) +) @pytest.mark.parametrize( "nested_hierarchies", ( @@ -527,7 +533,7 @@ def test_build_correlator_matrix(dims): @pytest.mark.parametrize("use_advanced_indexing", (False, True)) def test_make_latent_rvs(block, nested_hierarchies, coef_dims, use_advanced_indexing): # Get out with invalid parametrization - if block is WithinBlockCorrelatedGP and coef_dims is None: + if issubclass(block, WithinBlockCorrelatedGPBase) and coef_dims is None: return hierarchy_coords, hierarchy = slice_data_by_nested_hierarchies( diff --git a/preclinpack/tests/models/test_models.py b/preclinpack/tests/models/test_models.py index 7c6c01c..c1a903d 100644 --- a/preclinpack/tests/models/test_models.py +++ b/preclinpack/tests/models/test_models.py @@ -9,8 +9,8 @@ NegBinLikelihood, QuadraticLongTermTrend, SortNestedHierarchies, - WithinBlockCorrelatedGP, WithinBlockGP, + WithinBlockLKJCorrelatedGP, ZeroInfNegBinLikelihood, ) from preclinpack.models import ( @@ -69,7 +69,7 @@ hierarchies=SortNestedHierarchies, baseline_activity=BaselineActivity, linear_long_term_trend=LinearLongTermTrend, - within_block_correlated_gp=WithinBlockCorrelatedGP, + within_block_correlated_gp=WithinBlockLKJCorrelatedGP, likelihood=NegBinLikelihood, ), lambda config: config["likelihood"].update({"concentration_prior": (1,)}), @@ -80,7 +80,7 @@ hierarchies=SortNestedHierarchies, baseline_activity=BaselineActivity, linear_long_term_trend=LinearLongTermTrend, - within_block_correlated_gp=WithinBlockCorrelatedGP, + within_block_correlated_gp=WithinBlockLKJCorrelatedGP, likelihood=ZeroInfNegBinLikelihood, ), lambda config: config["likelihood"].update(