2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -32,7 +32,7 @@ jobs:
- tests/statespace/filters/test_kalman_filter.py
- tests/statespace --ignore tests/statespace/core/test_statespace.py --ignore tests/statespace/filters/test_kalman_filter.py
- tests/distributions
- tests --ignore tests/model --ignore tests/statespace --ignore tests/distributions
- tests --ignore tests/model --ignore tests/statespace --ignore tests/distributions --ignore tests/pathfinder
fail-fast: false
runs-on: ${{ matrix.os }}
env:
6 changes: 3 additions & 3 deletions conda-envs/environment-test.yml
@@ -1,10 +1,10 @@
name: pymc-extras-test
name: pymc-extras
channels:
- conda-forge
- nodefaults
dependencies:
- pymc>=5.24.1
- pytensor>=2.31.4
- pymc>=5.26.1
- pytensor>=2.35.1
- scikit-learn
- better-optimize>=0.1.5
- dask<2025.1.1
23 changes: 14 additions & 9 deletions pymc_extras/inference/laplace_approx/find_map.py
@@ -168,6 +168,7 @@ def find_MAP(
jitter_rvs: list[TensorVariable] | None = None,
progressbar: bool = True,
include_transformed: bool = True,
freeze_model: bool = True,
gradient_backend: GradientBackend = "pytensor",
compile_kwargs: dict | None = None,
compute_hessian: bool = False,
@@ -210,6 +211,10 @@
Whether to display a progress bar during optimization. Defaults to True.
include_transformed: bool, optional
Whether to include transformed variable values in the returned dictionary. Defaults to True.
freeze_model: bool, optional
If True, freeze_dims_and_data will be called on the model before compiling the loss functions. This is
sometimes necessary for JAX, and can sometimes improve performance by allowing constant folding. Defaults to
True.
gradient_backend: str, default "pytensor"
Which backend to use to compute gradients. Must be one of "pytensor" or "jax".
compute_hessian: bool
@@ -229,11 +234,13 @@
Results of Maximum A Posteriori (MAP) estimation, including the optimized point, inverse Hessian, transformed
latent variables, and optimizer results.
"""
model = pm.modelcontext(model) if model is None else model
frozen_model = freeze_dims_and_data(model)
compile_kwargs = {} if compile_kwargs is None else compile_kwargs
model = pm.modelcontext(model) if model is None else model

initial_params = _make_initial_point(frozen_model, initvals, random_seed, jitter_rvs)
if freeze_model:
model = freeze_dims_and_data(model)

initial_params = _make_initial_point(model, initvals, random_seed, jitter_rvs)

do_basinhopping = method == "basinhopping"
minimizer_kwargs = optimizer_kwargs.pop("minimizer_kwargs", {})
@@ -251,8 +258,8 @@
)

f_fused, f_hessp = scipy_optimize_funcs_from_loss(
loss=-frozen_model.logp(),
inputs=frozen_model.continuous_value_vars + frozen_model.discrete_value_vars,
loss=-model.logp(),
inputs=model.continuous_value_vars + model.discrete_value_vars,
initial_point_dict=DictToArrayBijection.rmap(initial_params),
use_grad=use_grad,
use_hess=use_hess,
@@ -316,12 +323,10 @@
}

idata = map_results_to_inference_data(
map_point=optimized_point, model=frozen_model, include_transformed=include_transformed
map_point=optimized_point, model=model, include_transformed=include_transformed
)

idata = add_fit_to_inference_data(
idata=idata, mu=raveled_optimized, H_inv=H_inv, model=frozen_model
)
idata = add_fit_to_inference_data(idata=idata, mu=raveled_optimized, H_inv=H_inv, model=model)

idata = add_optimizer_result_to_inference_data(
idata=idata, result=optimizer_result, method=method, mu=raveled_optimized, model=model
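Note (not part of the diff): the new freeze_model flag can be exercised as in the following minimal sketch; the toy model, data, and optimizer choice here are illustrative assumptions, not code from the PR.

import numpy as np
import pymc as pm

from pymc_extras.inference.laplace_approx.find_map import find_MAP

rng = np.random.default_rng(0)
observed = rng.normal(loc=1.0, scale=2.0, size=100)

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 10.0)
    sigma = pm.HalfNormal("sigma", 5.0)
    pm.Normal("y", mu=mu, sigma=sigma, observed=observed)

    # Default behaviour: dims and data are frozen before the loss is compiled,
    # which is sometimes required for JAX and can enable constant folding.
    # Pass freeze_model=False to skip the freeze, e.g. when the caller has
    # already frozen the model (as fit_laplace now does).
    idata = find_MAP(method="L-BFGS-B", freeze_model=True, progressbar=False)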
27 changes: 22 additions & 5 deletions pymc_extras/inference/laplace_approx/laplace.py
@@ -168,9 +168,13 @@ def _unconstrained_vector_to_constrained_rvs(model):
unconstrained_vector.name = "unconstrained_vector"

# Redo the names list to ensure it is sorted to match the return order
names = [*constrained_names, *unconstrained_names]
constrained_rvs_and_names = [(rv, name) for rv, name in zip(constrained_rvs, constrained_names)]
value_rvs_and_names = [
(rv, name) for rv, name in zip(value_rvs, names) for name in unconstrained_names
]
# names = [*constrained_names, *unconstrained_names]

return names, constrained_rvs, value_rvs, unconstrained_vector
return constrained_rvs_and_names, value_rvs_and_names, unconstrained_vector


def model_to_laplace_approx(
@@ -182,8 +186,11 @@

# temp_chain and temp_draw are a hack to allow sampling from the Laplace approximation. We only have one mu and cov,
# so we add batch dims (which correspond to chains and draws). But the names "chain" and "draw" are reserved.
names, constrained_rvs, value_rvs, unconstrained_vector = (
_unconstrained_vector_to_constrained_rvs(model)

# The model was frozen during the find_MAP procedure. To ensure we're operating on the same model, freeze it again.
frozen_model = freeze_dims_and_data(model)
constrained_rvs_and_names, _, unconstrained_vector = _unconstrained_vector_to_constrained_rvs(
frozen_model
)

coords = model.coords | {
@@ -204,12 +211,13 @@
)

cast_to_var = partial(type_cast, Variable)
constrained_rvs, constrained_names = zip(*constrained_rvs_and_names)
batched_rvs = vectorize_graph(
type_cast(list[Variable], constrained_rvs),
replace={cast_to_var(unconstrained_vector): cast_to_var(laplace_approximation)},
)

for name, batched_rv in zip(names, batched_rvs):
for name, batched_rv in zip(constrained_names, batched_rvs):
batch_dims = ("temp_chain", "temp_draw")
if batched_rv.ndim == 2:
dims = batch_dims
@@ -285,6 +293,7 @@ def fit_laplace(
jitter_rvs: list[pt.TensorVariable] | None = None,
progressbar: bool = True,
include_transformed: bool = True,
freeze_model: bool = True,
gradient_backend: GradientBackend = "pytensor",
chains: int = 2,
draws: int = 500,
@@ -328,6 +337,10 @@
include_transformed: bool, default True
Whether to include transformed variables in the output. If True, transformed variables will be included in the
output InferenceData object. If False, only the original variables will be included.
freeze_model: bool, optional
If True, freeze_dims_and_data will be called on the model before compiling the loss functions. This is
sometimes necessary for JAX, and can sometimes improve performance by allowing constant folding. Defaults to
True.
gradient_backend: str, default "pytensor"
The backend to use for gradient computations. Must be one of "pytensor" or "jax".
chains: int, default: 2
@@ -376,6 +389,9 @@
optimizer_kwargs = {} if optimizer_kwargs is None else optimizer_kwargs
model = pm.modelcontext(model) if model is None else model

if freeze_model:
model = freeze_dims_and_data(model)

idata = find_MAP(
method=optimize_method,
model=model,
@@ -387,6 +403,7 @@
jitter_rvs=jitter_rvs,
progressbar=progressbar,
include_transformed=include_transformed,
freeze_model=False,
gradient_backend=gradient_backend,
compile_kwargs=compile_kwargs,
compute_hessian=True,
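A matching sketch for the Laplace path (illustrative only; the model below is made up): fit_laplace now freezes dims and data once up front and passes freeze_model=False to the inner find_MAP call, so the freeze is not repeated.

import numpy as np
import pymc as pm

from pymc_extras.inference.laplace_approx.laplace import fit_laplace

rng = np.random.default_rng(1)
y = rng.normal(loc=0.5, scale=1.5, size=50)

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    sigma = pm.HalfNormal("sigma", 1.0)
    pm.Normal("obs", mu=mu, sigma=sigma, observed=y)

    # freeze_model=True (the default) freezes the model here; find_MAP is then
    # called with freeze_model=False so the same work is not done twice.
    idata = fit_laplace(chains=2, draws=500, freeze_model=True, progressbar=False)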
7 changes: 2 additions & 5 deletions pymc_extras/inference/pathfinder/pathfinder.py
@@ -22,7 +22,7 @@
from collections.abc import Callable, Iterator
from dataclasses import asdict, dataclass, field, replace
from enum import Enum, auto
from typing import Literal, TypeAlias
from typing import Literal, Self, TypeAlias

import arviz as az
import filelock
@@ -60,9 +60,6 @@
from rich.table import Table
from rich.text import Text

# TODO: change to typing.Self after Python versions greater than 3.10
from typing_extensions import Self

from pymc_extras.inference.laplace_approx.idata import add_data_to_inference_data
from pymc_extras.inference.pathfinder.importance_sampling import (
importance_sampling as _importance_sampling,
@@ -533,7 +530,7 @@ def bfgs_sample_sparse(

# qr_input: (L, N, 2J)
qr_input = inv_sqrt_alpha_diag @ beta
(Q, R), _ = pytensor.scan(fn=pt.nlinalg.qr, sequences=[qr_input], allow_gc=False)
(Q, R), _ = pytensor.scan(fn=pt.linalg.qr, sequences=[qr_input], allow_gc=False)

IdN = pt.eye(R.shape[1])[None, ...]
IdN += IdN * REGULARISATION_TERM
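For reference, pt.nlinalg.qr and pt.linalg.qr expose the same factorization; the change simply moves to the preferred pt.linalg namespace. A stand-alone sketch of the same scan pattern, with an arbitrary batch shape standing in for the (L, N, 2J) input used by pathfinder:

import numpy as np
import pytensor
import pytensor.tensor as pt

qr_input = pt.tensor3("qr_input")  # batch of matrices, one QR factorization per slice
(Q, R), _ = pytensor.scan(fn=pt.linalg.qr, sequences=[qr_input], allow_gc=False)

fn = pytensor.function([qr_input], [Q, R])
q_val, r_val = fn(np.random.normal(size=(4, 5, 3)))
print(q_val.shape, r_val.shape)  # (4, 5, 3) and (4, 3, 3) in the default "reduced" mode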
4 changes: 2 additions & 2 deletions pymc_extras/model/marginal/graph_analysis.py
@@ -6,8 +6,8 @@
from pymc import SymbolicRandomVariable
from pymc.model.fgraph import ModelVar
from pymc.variational.minibatch_rv import MinibatchRandomVariable
from pytensor.graph import Variable, ancestors
from pytensor.graph.basic import io_toposort
from pytensor.graph.basic import Variable
from pytensor.graph.traversal import ancestors, io_toposort
from pytensor.tensor import TensorType, TensorVariable
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
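The graph-walking helpers used here (ancestors, io_toposort, and, in the files below, graph_inputs and explicit_graph_inputs) live in pytensor.graph.traversal on the newer PyTensor releases this PR requires. A minimal sketch of the new-style imports on a toy graph (not code from the PR):

import pytensor.tensor as pt

from pytensor.graph.traversal import ancestors, graph_inputs, io_toposort

x = pt.vector("x")
y = (x ** 2).sum()

print(list(graph_inputs([y])))                       # root inputs feeding y (x plus constants)
print(len(list(ancestors([y]))))                     # number of variables y depends on
print([node.op for node in io_toposort([x], [y])])   # apply nodes in topological order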
2 changes: 1 addition & 1 deletion pymc_extras/statespace/core/compile.py
@@ -28,7 +28,7 @@ def compile_statespace(
x0, P0, c, d, T, Z, R, H, Q, steps=steps, sequence_names=sequence_names
)

inputs = list(pytensor.graph.basic.explicit_graph_inputs(outputs))
inputs = list(pytensor.graph.traversal.explicit_graph_inputs(outputs))

_f = pm.compile(inputs, outputs, on_unused_input="ignore", **compile_kwargs)

4 changes: 2 additions & 2 deletions pymc_extras/statespace/filters/kalman_filter.py
@@ -200,7 +200,7 @@ def build_graph(
self.n_endog = Z_shape[-2]

data, a0, P0, *params = self.check_params(data, a0, P0, c, d, T, Z, R, H, Q)

data = pt.specify_shape(data, (data.type.shape[0], self.n_endog))
sequences, non_sequences, seq_names, non_seq_names = split_vars_into_seq_and_nonseq(
params, PARAM_NAMES
)
@@ -658,7 +658,7 @@ def update(self, a, P, y, d, Z, H, all_nan_flag):
# Construct upper-triangular block matrix A = [[chol(H), Z @ L_pred],
# [0, L_pred]]
# The Schur decomposition of this matrix will be B (upper triangular). We are
# more insterested in B^T:
# more interested in B^T:
# Structure of B^T = [[chol(F), 0 ],
# [K @ chol(F), chol(P_filtered)]
zeros = pt.zeros((self.n_states, self.n_endog))
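The added specify_shape call pins the number of observed series on the data's static type, so later static-shape checks have something to work with. A minimal stand-alone sketch (the value 3 standing in for self.n_endog is made up):

import pytensor.tensor as pt

data = pt.matrix("data")                  # static shape (None, None)
data = pt.specify_shape(data, (None, 3))  # pin the trailing dim to n_endog
print(data.type.shape)                    # (None, 3)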
4 changes: 1 addition & 3 deletions pymc_extras/statespace/filters/kalman_smoother.py
@@ -1,8 +1,6 @@
import pytensor
import pytensor.tensor as pt

from pytensor.tensor.nlinalg import matrix_dot

from pymc_extras.statespace.filters.utilities import (
quad_form_sym,
split_vars_into_seq_and_nonseq,
@@ -105,7 +103,7 @@ def smoother_step(self, *args):
a_hat, P_hat = self.predict(a, P, T, R, Q)

# Use pinv, otherwise P_hat is singular when there is missing data
smoother_gain = matrix_dot(pt.linalg.pinv(P_hat, hermitian=True), T, P).T
smoother_gain = (pt.linalg.pinv(P_hat, hermitian=True) @ T @ P).mT
a_smooth_next = a + smoother_gain @ (a_smooth - a_hat)

P_smooth_next = P + quad_form_sym(smoother_gain, P_smooth - P_hat)
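The removed matrix_dot helper was just a chained matrix product, so the smoother gain is now written with the @ operator, with .mT transposing the last two axes. A small illustrative sketch of the replacement expression on random input (not from the PR):

import numpy as np
import pytensor
import pytensor.tensor as pt

P_hat = pt.matrix("P_hat")
T = pt.matrix("T")
P = pt.matrix("P")

# Old: matrix_dot(pt.linalg.pinv(P_hat, hermitian=True), T, P).T
# New: the same product, written with @ and the .mT matrix transpose.
smoother_gain = (pt.linalg.pinv(P_hat, hermitian=True) @ T @ P).mT

fn = pytensor.function([P_hat, T, P], smoother_gain)

rng = np.random.default_rng(2)
A = rng.normal(size=(3, 3))
print(fn(A @ A.T, rng.normal(size=(3, 3)), rng.normal(size=(3, 3))).shape)  # (3, 3)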
4 changes: 2 additions & 2 deletions pymc_extras/utils/model_equivalence.py
@@ -4,8 +4,8 @@
from pymc.model.fgraph import fgraph_from_model
from pytensor import Variable
from pytensor.compile import SharedVariable
from pytensor.graph import Constant, graph_inputs
from pytensor.graph.basic import equal_computations
from pytensor.graph.basic import Constant, equal_computations
from pytensor.graph.traversal import graph_inputs
from pytensor.tensor.random.type import RandomType


14 changes: 8 additions & 6 deletions pyproject.toml
@@ -34,8 +34,8 @@ keywords = [
license = {file = "LICENSE"}
dynamic = ["version"] # specify the version in the __init__.py file
dependencies = [
"pymc>=5.24.1",
"pytensor>=2.31.4",
"pymc>=5.26.1",
"pytensor>=2.35.1",
"scikit-learn",
"better-optimize>=0.1.5",
"pydantic>=2.0.0",
@@ -104,9 +104,11 @@ filterwarnings =[
# JAX issues an over-eager warning if os.fork() is called when the JAX module is loaded, even if JAX isn't being used
'ignore:os\.fork\(\) was called\.:RuntimeWarning',

# Silence warning emitted by pytensor when BLAS is not available
"ignore:\\n?Found Intel OpenMP \\('libiomp'\\) and LLVM OpenMP \\('libomp'\\).*:RuntimeWarning",
'ignore:PyTensor could not link to a BLAS installation:UserWarning'
# Preliz needs to update for pytensor > 2.35
'ignore:.*`pytensor\.graph\.basic\.ancestors`.*`pytensor\.graph\.traversal\.ancestors`.*:FutureWarning:^preliz(\.|$)',

# OpenMP library warning on windows CI
'ignore::RuntimeWarning:threadpoolctl'
]

[tool.coverage.report]
@@ -118,7 +120,7 @@ exclude_lines = [

[tool.ruff]
line-length = 100
target-version = "py310"
target-version = "py311"

[tool.ruff.format]
docstring-code-format = true
15 changes: 8 additions & 7 deletions tests/inference/laplace_approx/test_laplace.py
@@ -74,12 +74,13 @@ def test_fit_laplace_basic(mode, gradient_backend: GradientBackend):
assert idata.fit["mean_vector"].shape == (len(vars),)
assert idata.fit["covariance_matrix"].shape == (len(vars), len(vars))

bda_map = [y.mean(), np.log(y.std())]
bda_cov = np.array([[y.var() / n, 0], [0, 1 / (2 * n)]])
bda_map = [np.log(y.std()), y.mean()]
bda_cov = np.array([[1 / (2 * n), 0], [0, y.var() / n]])

np.testing.assert_allclose(idata.posterior["mu"].mean(), bda_map[0], atol=1)
np.testing.assert_allclose(idata.posterior["logsigma"].mean(), bda_map[1], rtol=1e-3)
np.testing.assert_allclose(idata.posterior["logsigma"].mean(), bda_map[0], rtol=1e-3)
np.testing.assert_allclose(idata.posterior["mu"].mean(), bda_map[1], atol=1)

np.testing.assert_allclose(idata.fit["mean_vector"].values, bda_map, atol=1, rtol=1e-3)
np.testing.assert_allclose(idata.fit["covariance_matrix"].values, bda_cov, rtol=1e-3, atol=1e-3)


@@ -138,12 +139,12 @@ def test_fit_laplace_coords(include_transformed, rng):
)

assert idata.fit.rows.values.tolist() == [
"mu[A]",
"mu[B]",
"mu[C]",
"sigma_log__[A]",
"sigma_log__[B]",
"sigma_log__[C]",
"mu[A]",
"mu[B]",
"mu[C]",
]

assert hasattr(idata, "unconstrained_posterior") == include_transformed
2 changes: 1 addition & 1 deletion tests/model/marginal/test_distributions.py
@@ -128,5 +128,5 @@ def test_marginalized_hmm_multiple_emissions(batch_chain, batch_emission1, batch
test_value_emission2 = np.broadcast_to(-test_value, emission2_shape)
test_point = {"emission_1": test_value_emission1, "emission_2": test_value_emission2}
res_logp, dummy_logp = logp_fn(test_point)
assert res_logp.shape == ((1, 3) if batch_chain else ())
assert res_logp.shape == ((3, 1) if batch_chain else ())
Review comment (Member): The "first dependent RV" that gets the full logp changed with our new toposort algo
np.testing.assert_allclose(res_logp.sum(), expected_logp)
2 changes: 1 addition & 1 deletion tests/model/marginal/test_graph_analysis.py
@@ -138,7 +138,7 @@ def test_blockwise(self):
with pytest.raises(ValueError, match="Use of known dimensions"):
subgraph_batch_dim_connection(inp, [invalid_out])

out = (inp[:, :, None, None] + pt.zeros((2, 3))) @ pt.ones((2, 3))
out = (inp[:, :, None, None] + pt.zeros((2, 3))) @ pt.ones((3, 2))
Review comment (Member): static shape check now reveals this error in the test
[dims] = subgraph_batch_dim_connection(inp, [out])
assert dims == (0, 1, None, None)

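For context on the shape fix above (a sketch of assumed behaviour on recent PyTensor, not part of the test file): the left operand of the matmul has static trailing shape (2, 3), so the right operand needs 3 rows; the old (2, 3) operand is rejected by the static shape check, while (3, 2) is accepted.

import pytensor.tensor as pt

inp = pt.matrix("inp")  # stand-in for the 2-D input used in the test

out = (inp[:, :, None, None] + pt.zeros((2, 3))) @ pt.ones((3, 2))
print(out.ndim)  # 4: the two original dims plus the (2, 2) matmul result

# (inp[:, :, None, None] + pt.zeros((2, 3))) @ pt.ones((2, 3)) now fails when the
# graph is built, because the inner dimensions (3 vs 2) do not match.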
6 changes: 3 additions & 3 deletions tests/model/marginal/test_marginal_model.py
@@ -230,7 +230,7 @@ def build_model(build_batched: bool) -> Model:

# Test initial_point
ips = make_initial_point_expression(
free_rvs=marginal_m.free_RVs,
free_rvs=[marginal_m["sigma"], marginal_m["dep"], marginal_m["sub_dep"]],
Review comment (Member): The order of free_RVs changed for similar reason
rvs_to_transforms=marginal_m.rvs_to_transforms,
initval_strategies={},
)
@@ -294,7 +294,7 @@ def test_interdependent_rvs():

# Test initial_point
ips = make_initial_point_expression(
free_rvs=marginal_m.free_RVs,
free_rvs=[marginal_m["x"], marginal_m["y"]],
Review comment (Member): same
rvs_to_transforms={},
initval_strategies={},
)
@@ -306,7 +306,7 @@
# Test custom initval strategy
ips = make_initial_point_expression(
# Test that order does not matter
free_rvs=marginal_m.free_RVs[::-1],
free_rvs=[marginal_m["y"], marginal_m["x"]],
Review comment (Member): same
rvs_to_transforms={},
initval_strategies={marginal_x: pt.constant(5.0)},
)
Empty file added tests/pathfinder/__init__.py
5 changes: 4 additions & 1 deletion tests/pathfinder/test_pathfinder.py
@@ -156,7 +156,10 @@ def test_pathfinder(inference_backend, reference_idata):
assert idata.posterior["theta"].shape == (1, 1000, 8)


@pytest.mark.parametrize("concurrent", ["thread", "process"])
@pytest.mark.parametrize(
"concurrent",
[pytest.param("thread", marks=pytest.mark.skip(reason="CI hangs on Windows")), "process"],
)
def test_concurrent_results(reference_idata, concurrent):
model = eight_schools_model()
with model:
2 changes: 1 addition & 1 deletion tests/statespace/core/test_statespace.py
@@ -14,7 +14,7 @@
from numpy.testing import assert_allclose
from pymc.testing import mock_sample_setup_and_teardown
from pytensor.compile import SharedVariable
from pytensor.graph.basic import graph_inputs
from pytensor.graph.traversal import graph_inputs

from pymc_extras.statespace.core.statespace import FILTER_FACTORY, PyMCStateSpace
from pymc_extras.statespace.models import structural as st