diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 37a56f8a06..df788f9729 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -85,6 +85,7 @@ jobs: tests/sampling/test_deterministic.py tests/sampling/test_forward.py tests/sampling/test_population.py + tests/sampling/test_background_sampling.py tests/stats/test_convergence.py tests/stats/test_log_density.py tests/distributions/test_distribution.py @@ -190,7 +191,7 @@ jobs: python-version: ["3.11"] test-subset: - tests/variational/test_approximations.py tests/variational/test_callbacks.py tests/variational/test_inference.py tests/variational/test_opvi.py tests/test_initial_point.py - - tests/model/test_core.py tests/sampling/test_mcmc.py + - tests/model/test_core.py tests/sampling/test_mcmc.py tests/sampling/test_background_sampling.py - tests/gp/test_cov.py tests/gp/test_gp.py tests/gp/test_mean.py tests/gp/test_util.py tests/ode/test_ode.py tests/ode/test_utils.py tests/smc/test_smc.py tests/sampling/test_parallel.py - tests/step_methods/test_metropolis.py tests/step_methods/test_slicer.py tests/step_methods/hmc/test_nuts.py tests/step_methods/test_compound.py tests/step_methods/hmc/test_hmc.py tests/step_methods/test_state.py @@ -247,6 +248,7 @@ jobs: - | tests/sampling/test_mcmc.py + tests/sampling/test_background_sampling.py - | tests/backends/test_arviz.py @@ -347,7 +349,7 @@ jobs: floatx: [float32] python-version: ["3.13"] test-subset: - - tests/sampling/test_mcmc.py tests/ode/test_ode.py tests/ode/test_utils.py tests/distributions/test_transform.py + - tests/sampling/test_mcmc.py tests/ode/test_ode.py tests/ode/test_utils.py tests/distributions/test_transform.py tests/sampling/test_background_sampling.py fail-fast: false runs-on: ${{ matrix.os }} env: diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index de341c68cd..e59da7479a 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -18,6 +18,7 @@ import logging import pickle import sys 
class BackgroundSampleHandle:
    """Handle for a callable that is executed on a background thread.

    The wrapped ``target`` runs on a daemon thread once :meth:`start` is
    called. Whatever it returns, or whatever it raises, is stored on the
    handle and can be retrieved later through :meth:`result` and
    :meth:`exception`.

    Parameters
    ----------
    target : callable
        Function to run in the background thread.
    args : iterable, optional
        Positional arguments forwarded to ``target``.
    kwargs : mapping, optional
        Keyword arguments forwarded to ``target``.
    """

    def __init__(self, target, args=None, kwargs=None):
        self._done = threading.Event()
        self._result = None
        self._exception = None
        args = () if args is None else args
        kwargs = {} if kwargs is None else kwargs

        def runner():
            try:
                self._result = target(*args, **kwargs)
            except BaseException as exc:
                # Capture *everything* (mirroring ``concurrent.futures``): an
                # uncaught ``BaseException`` would otherwise leave the handle
                # looking like a successful run that returned ``None``.
                self._exception = exc
            finally:
                # Always flag completion so waiters never block on a dead thread.
                self._done.set()

        self._thread = threading.Thread(target=runner, daemon=True)

    def start(self):
        """Start the background thread and return the handle itself."""
        self._thread.start()
        return self

    def done(self):
        """Return ``True`` once the target has finished, successfully or not."""
        return self._done.is_set()

    def _wait(self, timeout):
        """Join the worker thread for at most ``timeout`` seconds.

        Raises
        ------
        TimeoutError
            If the target has not finished when the wait ends.
        RuntimeError
            Raised by ``threading.Thread.join`` if the thread was never started.
        """
        self._thread.join(timeout=timeout)
        if not self._done.is_set():
            raise TimeoutError("Background sampling not finished yet")

    def result(self, timeout=None):
        """Wait for the target to finish and return its result.

        Parameters
        ----------
        timeout : float, optional
            Maximum number of seconds to wait. ``None`` waits indefinitely.

        Raises
        ------
        TimeoutError
            If the target is still running after ``timeout`` seconds.
        BaseException
            Whatever exception the target raised is re-raised here.
        """
        self._wait(timeout)
        # Identity check, not truthiness: an exception object may be falsy
        # (e.g. when its class defines ``__len__``) and must still be raised.
        if self._exception is not None:
            raise self._exception
        return self._result

    def exception(self, timeout=None):
        """Wait for the target to finish and return the exception it raised.

        Returns ``None`` when the target completed without raising.

        Parameters
        ----------
        timeout : float, optional
            Maximum number of seconds to wait. ``None`` waits indefinitely.

        Raises
        ------
        TimeoutError
            If the target is still running after ``timeout`` seconds, so that
            "still running" cannot be mistaken for "finished without error".
        """
        self._wait(timeout)
        return self._exception
- "split": A separate progress bar for each chain. Only timing information is shown. - - "combined+stats" or "stats+combined": A single progress bar displaying the total progress across all + - "combined+stats" or "stats+combined": A single progress bar displaying the total progress across chains. Aggregate sample statistics are also displayed. - "split+stats" or "stats+split": A separate progress bar for each chain. Sample statistics for each chain are also displayed. @@ -618,7 +661,14 @@ def sample( Model to sample from. The model needs to have free random variables. compile_kwargs: dict, optional Dictionary with keyword argument to pass to the functions compiled by the step methods. + You can find a full list of arguments in the docstring of the step methods. + Background mode + ---------------- + - Set ``background=True`` to run sampling in a background thread; this returns a handle. + - The handle supports ``done()``, ``result()``, and ``exception()``. + - Progress bars are suppressed in background mode. + - Currently limited to ``nuts_sampler="pymc"``; other samplers raise ``NotImplementedError``. Returns ------- @@ -629,65 +679,53 @@ def sample( ``ZarrTrace`` instance. Refer to :class:`~pymc.backends.zarr.ZarrTrace` for the benefits this backend provides. - Notes - ----- - Optional keyword arguments can be passed to ``sample`` to be delivered to the - ``step_method``\ s used during sampling. - - For example: - - 1. ``target_accept`` to NUTS: nuts={'target_accept':0.9} - 2. ``transit_p`` to BinaryGibbsMetropolis: binary_gibbs_metropolis={'transit_p':.7} - - Note that available step names are: - - ``nuts``, ``hmc``, ``metropolis``, ``binary_metropolis``, - ``binary_gibbs_metropolis``, ``categorical_gibbs_metropolis``, - ``DEMetropolis``, ``DEMetropolisZ``, ``slice`` - - The NUTS step method has several options including: - - * target_accept : float in [0, 1]. The step size is tuned such that we - approximate this acceptance rate. 
Higher values like 0.9 or 0.95 often - work better for problematic posteriors. This argument can be passed directly to sample. - * max_treedepth : The maximum depth of the trajectory tree - * step_scale : float, default 0.25 - The initial guess for the step size scaled down by :math:`1/n**(1/4)`, - where n is the dimensionality of the parameter space - - Alternatively, if you manually declare the ``step_method``\ s, within the ``step`` - kwarg, then you can address the ``step_method`` kwargs directly. - e.g. for a CompoundStep comprising NUTS and BinaryGibbsMetropolis, - you could send :: - - step = [ - pm.NUTS([freeRV1, freeRV2], target_accept=0.9), - pm.BinaryGibbsMetropolis([freeRV3], transit_p=0.7), - ] - - You can find a full list of arguments in the docstring of the step methods. - Examples -------- .. code-block:: ipython + """ + if background and not _background_internal: + if nuts_sampler != "pymc": + raise NotImplementedError("background=True currently supports nuts_sampler='pymc' only") + progressbar = False - In [1]: import pymc as pm - ...: n = 100 - ...: h = 61 - ...: alpha = 2 - ...: beta = 2 + # Resolve the model now so the background thread has a concrete model object. 
+ resolved_model = modelcontext(model) - In [2]: with pm.Model() as model: # context management - ...: p = pm.Beta("p", alpha=alpha, beta=beta) - ...: y = pm.Binomial("y", n=n, p=p, observed=h) - ...: idata = pm.sample() + def _run(): + return sample( + draws=draws, + tune=tune, + chains=chains, + cores=cores, + random_seed=random_seed, + progressbar=progressbar, + progressbar_theme=progressbar_theme, + step=step, + var_names=var_names, + nuts_sampler=nuts_sampler, + initvals=initvals, + init=init, + jitter_max_retries=jitter_max_retries, + n_init=n_init, + trace=trace, + discard_tuned_samples=discard_tuned_samples, + compute_convergence_checks=compute_convergence_checks, + keep_warning_stat=keep_warning_stat, + return_inferencedata=return_inferencedata, + idata_kwargs=idata_kwargs, + nuts_sampler_kwargs=nuts_sampler_kwargs, + callback=callback, + mp_ctx=mp_ctx, + blas_cores=blas_cores, + model=resolved_model, + compile_kwargs=compile_kwargs, + background=False, + _background_internal=True, + **kwargs, + ) - In [3]: az.summary(idata, kind="stats") + return BackgroundSampleHandle(target=_run).start() - Out[3]: - mean sd hdi_3% hdi_97% - p 0.609 0.047 0.528 0.699 - """ if "start" in kwargs: if initvals is not None: raise ValueError("Passing both `start` and `initvals` is not supported.") diff --git a/tests/sampling/test_background_sampling.py b/tests/sampling/test_background_sampling.py new file mode 100644 index 0000000000..72510353ff --- /dev/null +++ b/tests/sampling/test_background_sampling.py @@ -0,0 +1,47 @@ +# Copyright 2025 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest

import pymc as pm

# Upper bound (seconds) on waiting for the background thread. Without it a
# stuck sampler would hang the whole test run instead of failing one test.
_RESULT_TIMEOUT = 300


def test_background_sampling_happy_path():
    """``background=True`` returns a handle that eventually yields the trace."""
    with pm.Model():
        pm.Normal("x", 0, 1)
        handle = pm.sample(
            draws=20,
            tune=10,
            chains=1,
            cores=1,
            background=True,
            progressbar=False,
        )
        idata = handle.result(timeout=_RESULT_TIMEOUT)
    assert handle.done()
    assert handle.exception() is None
    assert hasattr(idata, "posterior")
    # One chain and 20 draws were requested, so pin the exact sizes.
    assert idata.posterior.sizes["chain"] == 1
    assert idata.posterior.sizes["draw"] == 20


def test_background_sampling_raises():
    """An error raised while sampling in the background surfaces via the handle."""
    with pm.Model():
        pm.Normal("x", 0, sigma=-1)
        handle = pm.sample(
            draws=10,
            tune=5,
            chains=1,
            cores=1,
            background=True,
            progressbar=False,
        )
        # NOTE(review): the concrete exception type raised for an invalid
        # sigma is not pinned here; narrow ``Exception`` once it is confirmed.
        with pytest.raises(Exception) as excinfo:
            handle.result(timeout=_RESULT_TIMEOUT)
    # The failure must come from sampling itself, not from the wait timing out.
    assert not isinstance(excinfo.value, TimeoutError)
    assert handle.done()
    assert handle.exception() is excinfo.value


def test_background_sampling_rejects_external_nuts_sampler():
    """Background mode is limited to ``nuts_sampler="pymc"``."""
    with pm.Model():
        pm.Normal("x", 0, 1)
        with pytest.raises(NotImplementedError, match="nuts_sampler"):
            pm.sample(
                draws=10,
                tune=5,
                chains=1,
                cores=1,
                background=True,
                nuts_sampler="numpyro",
                progressbar=False,
            )