Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions .github/workflows/ci-self-hosted.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ jobs:
- uses: actions/checkout@v2
- name: Install Conda environment from environment.yml
uses: mamba-org/provision-with-micromamba@main
with:
cache-env: true
- name: Check CUDA
shell: bash -l {0}
run: |
Expand Down
36 changes: 34 additions & 2 deletions preclinpack/blocks/gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@
from preclinpack.nested_hierarchy_utils import slice_data_by_nested_hierarchies
from preclinpack.utils import CenteredNormal, dims_to_tuple, make_sum_zero_hh

__all__ = ["WithinBlockGP", "WithinBlockCorrelatedGP"]
__all__ = [
"WithinBlockGP",
"WithinBlockCorrelatedGP",
"WithinBlockLKJCorrelatedGP",
"WithinBlockCorrelatedGPBase",
]


def make_centered_gp_eigendecomp(
Expand Down Expand Up @@ -628,7 +633,7 @@ def plot_gp(
return plots


class WithinBlockCorrelatedGP(WithinBlockGP):
class WithinBlockCorrelatedGPBase(WithinBlockGP):
default_config_key = "within_block_correlated_gp"

@classmethod
Expand Down Expand Up @@ -775,6 +780,12 @@ def make_hierarchical_gps(

return leaf_rvs

@classmethod
def build_correlator_matrix(cls, dim, name_prefix="", **kwargs):
raise NotImplementedError


class WithinBlockCorrelatedGP(WithinBlockCorrelatedGPBase):
@classmethod
def build_correlator_matrix(cls, dim, name_prefix="", alpha=1.0):
model = pm.modelcontext(None)
Expand All @@ -793,3 +804,24 @@ def build_correlator_matrix(cls, dim, name_prefix="", alpha=1.0):
R = at.set_subtensor(R[inds_l], r[inds_l[1]])
R = at.set_subtensor(R[inds_u], r[inds_u[1] - 1])
return R


class WithinBlockLKJCorrelatedGP(WithinBlockCorrelatedGPBase):
@staticmethod
def pack_corr(corr_vals, n):
idx_upper = np.triu_indices(n, 1)
corr = at.ones((n, n))
corr = at.set_subtensor(corr[idx_upper], corr_vals)
corr = at.set_subtensor(corr[idx_upper[::-1]], corr_vals)
return corr

@classmethod
def chol_corr(cls, corr_vals, n):
return at.linalg.cholesky(cls.pack_corr(corr_vals, n))

@classmethod
def build_correlator_matrix(cls, dim, name_prefix="", eta=1):
model = pm.modelcontext(None)
N_correl = len(model.coords[dim])
corr_vals = pm.LKJCorr(f"{name_prefix}correlator", N_correl, eta)
return cls.chol_corr(corr_vals, N_correl)
17 changes: 17 additions & 0 deletions preclinpack/jax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
if _jax_available:
from aesara.graph import Constant
from aesara.link.jax.dispatch import jax_funcify
from aesara.tensor.extra_ops import FillDiagonal
from aesara.tensor.shape import Reshape
from jax import numpy as jnp
from pymc.distributions.multivariate import PosDefMatrix

from preclinpack.blocks.distributions import BallBackwardOp

Expand Down Expand Up @@ -39,3 +41,18 @@ def f(value, *inputs):
return jnp.where(radius < 1e-6, value, value / radius * jnp.arctan(radius)) * 2 / jnp.pi

return f

@jax_funcify.register(FillDiagonal)
def _jax_funcify_FillDiagonal(op, node, **kwargs):
def f(value, diagonal):
i, j = jnp.diag_indices(min(value.shape[-2:]))
return value.at[..., i, j].set(diagonal)

return f

@jax_funcify.register(PosDefMatrix)
def _jax_funcify_PosDefMatrix(op, node, **kwargs):
def f(m):
return ~jnp.isnan(jnp.linalg.cholesky(m)).any()

return f
6 changes: 3 additions & 3 deletions preclinpack/models/gp.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from preclinpack.blocks.base import SortNestedHierarchies
from preclinpack.blocks.gp import WithinBlockCorrelatedGP, WithinBlockGP
from preclinpack.blocks.gp import WithinBlockGP, WithinBlockLKJCorrelatedGP
from preclinpack.blocks.likelihoods import NegBinLikelihood, ZeroInfNegBinLikelihood
from preclinpack.blocks.linear import (
BaselineActivity,
Expand Down Expand Up @@ -45,7 +45,7 @@ def NegBinFullCorrelatedModel(data, config, **kwargs):
hierarchies=SortNestedHierarchies,
baseline_activity=BaselineActivity,
linear_long_term_trend=LinearLongTermTrend,
within_block_correlated_gp=WithinBlockCorrelatedGP,
within_block_correlated_gp=WithinBlockLKJCorrelatedGP,
likelihood=NegBinLikelihood,
)
return Model(blocks_config=blocks_config, data=data, config=config, **kwargs)
Expand All @@ -56,7 +56,7 @@ def ZeroInfNegBinFullCorrelatedModel(data, config, **kwargs):
hierarchies=SortNestedHierarchies,
baseline_activity=BaselineActivity,
linear_long_term_trend=LinearLongTermTrend,
within_block_correlated_gp=WithinBlockCorrelatedGP,
within_block_correlated_gp=WithinBlockLKJCorrelatedGP,
likelihood=ZeroInfNegBinLikelihood,
)
return Model(blocks_config=blocks_config, data=data, config=config, **kwargs)
12 changes: 9 additions & 3 deletions preclinpack/tests/blocks/test_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
from preclinpack.blocks.distributions import HyperballUniformRV
from preclinpack.blocks.gp import (
WithinBlockCorrelatedGP,
WithinBlockCorrelatedGPBase,
WithinBlockGP,
WithinBlockLKJCorrelatedGP,
make_centered_gp_eigendecomp,
make_covariance_kernel,
make_hierarchical_gps_eigenfunctions,
Expand Down Expand Up @@ -276,7 +278,9 @@ def test_repeated_coords(self):
)


@pytest.mark.parametrize("block", (WithinBlockGP, WithinBlockCorrelatedGP))
@pytest.mark.parametrize(
"block", (WithinBlockGP, WithinBlockCorrelatedGP, WithinBlockLKJCorrelatedGP)
)
class TestWithinBlockGPMakeHierarchicalGPs:
@classmethod
def setup_class(cls):
Expand Down Expand Up @@ -515,7 +519,9 @@ def test_build_correlator_matrix(dims):
np.testing.assert_allclose(r_eval[row], expected_row)


@pytest.mark.parametrize("block", (WithinBlockGP, WithinBlockCorrelatedGP))
@pytest.mark.parametrize(
"block", (WithinBlockGP, WithinBlockCorrelatedGP, WithinBlockLKJCorrelatedGP)
)
@pytest.mark.parametrize(
"nested_hierarchies",
(
Expand All @@ -527,7 +533,7 @@ def test_build_correlator_matrix(dims):
@pytest.mark.parametrize("use_advanced_indexing", (False, True))
def test_make_latent_rvs(block, nested_hierarchies, coef_dims, use_advanced_indexing):
# Get out with invalid parametrization
if block is WithinBlockCorrelatedGP and coef_dims is None:
if issubclass(block, WithinBlockCorrelatedGPBase) and coef_dims is None:
return

hierarchy_coords, hierarchy = slice_data_by_nested_hierarchies(
Expand Down
6 changes: 3 additions & 3 deletions preclinpack/tests/models/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
NegBinLikelihood,
QuadraticLongTermTrend,
SortNestedHierarchies,
WithinBlockCorrelatedGP,
WithinBlockGP,
WithinBlockLKJCorrelatedGP,
ZeroInfNegBinLikelihood,
)
from preclinpack.models import (
Expand Down Expand Up @@ -69,7 +69,7 @@
hierarchies=SortNestedHierarchies,
baseline_activity=BaselineActivity,
linear_long_term_trend=LinearLongTermTrend,
within_block_correlated_gp=WithinBlockCorrelatedGP,
within_block_correlated_gp=WithinBlockLKJCorrelatedGP,
likelihood=NegBinLikelihood,
),
lambda config: config["likelihood"].update({"concentration_prior": (1,)}),
Expand All @@ -80,7 +80,7 @@
hierarchies=SortNestedHierarchies,
baseline_activity=BaselineActivity,
linear_long_term_trend=LinearLongTermTrend,
within_block_correlated_gp=WithinBlockCorrelatedGP,
within_block_correlated_gp=WithinBlockLKJCorrelatedGP,
likelihood=ZeroInfNegBinLikelihood,
),
lambda config: config["likelihood"].update(
Expand Down