Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ jobs:
tests/sampling/test_deterministic.py
tests/sampling/test_forward.py
tests/sampling/test_population.py
tests/sampling/test_background_sampling.py
tests/stats/test_convergence.py
tests/stats/test_log_density.py
tests/distributions/test_distribution.py
Expand Down Expand Up @@ -190,7 +191,7 @@ jobs:
python-version: ["3.11"]
test-subset:
- tests/variational/test_approximations.py tests/variational/test_callbacks.py tests/variational/test_inference.py tests/variational/test_opvi.py tests/test_initial_point.py
- tests/model/test_core.py tests/sampling/test_mcmc.py
- tests/model/test_core.py tests/sampling/test_mcmc.py tests/sampling/test_background_sampling.py
- tests/gp/test_cov.py tests/gp/test_gp.py tests/gp/test_mean.py tests/gp/test_util.py tests/ode/test_ode.py tests/ode/test_utils.py tests/smc/test_smc.py tests/sampling/test_parallel.py
- tests/step_methods/test_metropolis.py tests/step_methods/test_slicer.py tests/step_methods/hmc/test_nuts.py tests/step_methods/test_compound.py tests/step_methods/hmc/test_hmc.py tests/step_methods/test_state.py

Expand Down Expand Up @@ -247,6 +248,7 @@ jobs:

- |
tests/sampling/test_mcmc.py
tests/sampling/test_background_sampling.py

- |
tests/backends/test_arviz.py
Expand Down Expand Up @@ -347,7 +349,7 @@ jobs:
floatx: [float32]
python-version: ["3.13"]
test-subset:
- tests/sampling/test_mcmc.py tests/ode/test_ode.py tests/ode/test_utils.py tests/distributions/test_transform.py
- tests/sampling/test_mcmc.py tests/ode/test_ode.py tests/ode/test_utils.py tests/distributions/test_transform.py tests/sampling/test_background_sampling.py
fail-fast: false
runs-on: ${{ matrix.os }}
env:
Expand Down
144 changes: 91 additions & 53 deletions pymc/sampling/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import logging
import pickle
import sys
import threading
import time
import warnings

Expand Down Expand Up @@ -91,6 +92,44 @@
Step: TypeAlias = BlockedStep | CompoundStep


class BackgroundSampleHandle:
Copy link

Copilot AI Dec 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The BackgroundSampleHandle class is missing a docstring. As a public-facing API component (returned by pm.sample when background=True), it should have comprehensive documentation explaining:

  • The purpose of the class
  • How to use the done(), result(), and exception() methods
  • Parameters for the __init__ method
  • Return types and behavior of each method
  • Example usage
Suggested change
class BackgroundSampleHandle:
class BackgroundSampleHandle:
"""
Handle for background sampling threads returned by `pm.sample(background=True)`.
This class provides an interface to monitor and retrieve the result of a background
sampling process started by PyMC's `pm.sample` function with `background=True`.
It mimics the interface of `concurrent.futures.Future` for convenience.
Parameters
----------
target : callable
The function to execute in the background thread.
args : tuple, optional
Positional arguments to pass to the target function.
kwargs : dict, optional
Keyword arguments to pass to the target function.
Methods
-------
start()
Start the background thread.
done() -> bool
Return True if the background sampling has finished.
result(timeout=None)
Wait for the background sampling to finish and return the result.
Raises any exception raised by the target function.
exception(timeout=None)
Wait for the background sampling to finish and return the exception, if any.
Example
-------
>>> handle = pm.sample(..., background=True)
>>> # Do other work while sampling runs in the background
>>> if handle.done():
... trace = handle.result()
... else:
... print("Sampling still in progress")
"""

Copilot uses AI. Check for mistakes.
def __init__(self, target, args=None, kwargs=None):
self._done = threading.Event()
self._result = None
self._exception = None
args = args or ()
kwargs = kwargs or {}

def runner():
try:
self._result = target(*args, **kwargs)
except Exception as exc:
self._exception = exc
finally:
self._done.set()

self._thread = threading.Thread(target=runner, daemon=True)

def start(self):
self._thread.start()
return self

def done(self):
return self._done.is_set()

def result(self, timeout=None):
self._thread.join(timeout=timeout)
if not self._done.is_set():
Copy link

Copilot AI Dec 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Potential race condition: The result() method checks self._done.is_set() after join() returns. However, if a timeout is specified and expires, the thread may not have finished. In this case, join() returns without the thread being done, and the check correctly raises TimeoutError. But between the join() and the check, the thread could theoretically complete and set _done, making the check unreliable. Consider checking the join return value or using a more robust pattern like checking if the thread is still alive: if self._thread.is_alive(): raise TimeoutError(...)

Suggested change
if not self._done.is_set():
if self._thread.is_alive():

Copilot uses AI. Check for mistakes.
raise TimeoutError("Background sampling not finished yet")
if self._exception:
raise self._exception
return self._result

def exception(self, timeout=None):
self._thread.join(timeout=timeout)
return self._exception


class SamplingIteratorCallback(Protocol):
"""Signature of the callable that may be passed to `pm.sample(callable=...)`."""

Expand Down Expand Up @@ -439,6 +478,7 @@ def sample(
mp_ctx=None,
blas_cores: int | None | Literal["auto"] = "auto",
compile_kwargs: dict | None = None,
Copy link

Copilot AI Dec 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The function overloads don't account for background=True returning a BackgroundSampleHandle. The overloads should have additional signatures to specify that when background=True, the return type is BackgroundSampleHandle, not InferenceData or MultiTrace. For example:

@overload
def sample(
    ...,
    background: Literal[True],
    **kwargs,
) -> BackgroundSampleHandle: ...

Without this, type checkers won't correctly infer the return type when using background mode.

Suggested change
compile_kwargs: dict | None = None,
compile_kwargs: dict | None = None,
background: Literal[True],
**kwargs,
) -> BackgroundSampleHandle: ...
@overload
def sample(
draws: int = 1000,
*,
tune: int = 1000,
chains: int | None = None,
cores: int | None = None,
random_seed: RandomState = None,
progressbar: bool | ProgressBarType = True,
progressbar_theme: Theme | None = default_progress_theme,
step=None,
var_names: Sequence[str] | None = None,
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
initvals: StartDict | Sequence[StartDict | None] | None = None,
init: str = "auto",
jitter_max_retries: int = 10,
n_init: int = 200_000,
trace: TraceOrBackend | None = None,
discard_tuned_samples: bool = True,
compute_convergence_checks: bool = True,
keep_warning_stat: bool = False,
return_inferencedata: Literal[False],
idata_kwargs: dict[str, Any] | None = None,
nuts_sampler_kwargs: dict[str, Any] | None = None,
callback=None,
mp_ctx=None,
model: Model | None = None,
blas_cores: int | None | Literal["auto"] = "auto",
compile_kwargs: dict | None = None,
background: Literal[True],
**kwargs,
) -> BackgroundSampleHandle: ...
@overload
def sample(
draws: int = 1000,
*,
tune: int = 1000,
chains: int | None = None,
cores: int | None = None,
random_seed: RandomState = None,
progressbar: bool | ProgressBarType = True,
progressbar_theme: Theme | None = default_progress_theme,
step=None,
var_names: Sequence[str] | None = None,
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
initvals: StartDict | Sequence[StartDict | None] | None = None,
init: str = "auto",
jitter_max_retries: int = 10,
n_init: int = 200_000,
trace: TraceOrBackend | None = None,
discard_tuned_samples: bool = True,
compute_convergence_checks: bool = True,
keep_warning_stat: bool = False,
return_inferencedata: Literal[True] = True,
idata_kwargs: dict[str, Any] | None = None,
nuts_sampler_kwargs: dict[str, Any] | None = None,
callback=None,
mp_ctx=None,
blas_cores: int | None | Literal["auto"] = "auto",
compile_kwargs: dict | None = None,

Copilot uses AI. Check for mistakes.
background: bool = False,
**kwargs,
) -> InferenceData: ...

Expand Down Expand Up @@ -472,6 +512,7 @@ def sample(
model: Model | None = None,
blas_cores: int | None | Literal["auto"] = "auto",
compile_kwargs: dict | None = None,
Copy link

Copilot AI Dec 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same issue as the first overload: this overload doesn't account for background=True returning a BackgroundSampleHandle. Additional overloads are needed for the background mode case.

Suggested change
compile_kwargs: dict | None = None,
compile_kwargs: dict | None = None,
background: Literal[True],
**kwargs,
) -> BackgroundSampleHandle: ...
@overload
def sample(
draws: int = 1000,
*,
tune: int = 1000,
chains: int | None = None,
cores: int | None = None,
random_seed: RandomState = None,
progressbar: bool | ProgressBarType = True,
progressbar_theme: Theme | None = default_progress_theme,
step=None,
var_names: Sequence[str] | None = None,
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
initvals: StartDict | Sequence[StartDict | None] | None = None,
init: str = "auto",
jitter_max_retries: int = 10,
n_init: int = 200_000,
trace: TraceOrBackend | None = None,
discard_tuned_samples: bool = True,
compute_convergence_checks: bool = True,
keep_warning_stat: bool = False,
return_inferencedata: Literal[False],
idata_kwargs: dict[str, Any] | None = None,
nuts_sampler_kwargs: dict[str, Any] | None = None,
callback=None,
mp_ctx=None,
model: Model | None = None,
blas_cores: int | None | Literal["auto"] = "auto",
compile_kwargs: dict | None = None,

Copilot uses AI. Check for mistakes.
background: bool = False,
**kwargs,
) -> MultiTrace: ...

Expand Down Expand Up @@ -504,6 +545,8 @@ def sample(
blas_cores: int | None | Literal["auto"] = "auto",
model: Model | None = None,
compile_kwargs: dict | None = None,
background: bool = False,
_background_internal: bool = False,
**kwargs,
) -> InferenceData | MultiTrace | ZarrTrace:
Copy link

Copilot AI Dec 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The main function signature's return type should include BackgroundSampleHandle as a possible return type: -> InferenceData | MultiTrace | ZarrTrace | BackgroundSampleHandle

Suggested change
) -> InferenceData | MultiTrace | ZarrTrace:
) -> InferenceData | MultiTrace | ZarrTrace | BackgroundSampleHandle:

Copilot uses AI. Check for mistakes.
r"""Draw samples from the posterior using the given step methods.
Expand Down Expand Up @@ -540,7 +583,7 @@ def sample(
- "combined": A single progress bar that displays the total progress across all chains. Only timing
information is shown.
- "split": A separate progress bar for each chain. Only timing information is shown.
- "combined+stats" or "stats+combined": A single progress bar displaying the total progress across all
- "combined+stats" or "stats+combined": A single progress bar displaying the total progress across
chains. Aggregate sample statistics are also displayed.
- "split+stats" or "stats+split": A separate progress bar for each chain. Sample statistics for each chain
are also displayed.
Expand Down Expand Up @@ -618,7 +661,14 @@ def sample(
Model to sample from. The model needs to have free random variables.
compile_kwargs: dict, optional
Dictionary with keyword argument to pass to the functions compiled by the step methods.
You can find a full list of arguments in the docstring of the step methods.
Comment on lines 620 to +664
Copy link

Copilot AI Dec 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This sentence appears to be an incomplete addition. The original docstring already had "Dictionary with keyword argument to pass to the functions compiled by the step methods." Adding "You can find a full list of arguments in the docstring of the step methods." on a new line makes it seem disconnected. Consider either merging this into a single paragraph or ensuring proper formatting.

Suggested change
Dictionary with keyword argument to pass to the functions compiled by the step methods.
You can find a full list of arguments in the docstring of the step methods.
Dictionary with keyword arguments to pass to the functions compiled by the step methods. You can find a full list of arguments in the docstring of the step methods.

Copilot uses AI. Check for mistakes.

Background mode
----------------
- Set ``background=True`` to run sampling in a background thread; this returns a handle.
- The handle supports ``done()``, ``result()``, and ``exception()``.
- Progress bars are suppressed in background mode.
- Currently limited to ``nuts_sampler="pymc"``; other samplers raise ``NotImplementedError``.

Returns
-------
Comment on lines 665 to 674
Copy link

Copilot AI Dec 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The background parameter is missing from the "Parameters" section of the docstring. It should be documented there with a description like:

background : bool, default=False
    If True, run sampling in a background thread and return a BackgroundSampleHandle 
    instead of the usual trace/InferenceData. The handle provides `done()`, `result()`, 
    and `exception()` methods. Currently only supported with `nuts_sampler="pymc"`.

Additionally, the "Background mode" section should be moved to a "Notes" section or integrated into the parameter documentation, as parameters sections typically document the parameters themselves.

Suggested change
Background mode
----------------
- Set ``background=True`` to run sampling in a background thread; this returns a handle.
- The handle supports ``done()``, ``result()``, and ``exception()``.
- Progress bars are suppressed in background mode.
- Currently limited to ``nuts_sampler="pymc"``; other samplers raise ``NotImplementedError``.
Returns
-------
background : bool, default=False
If True, run sampling in a background thread and return a BackgroundSampleHandle
instead of the usual trace/InferenceData. The handle provides ``done()``, ``result()``,
and ``exception()`` methods. Currently only supported with ``nuts_sampler="pymc"``.
Notes
-----
When ``background=True``, sampling is run in a background thread and a handle is returned.
The handle supports ``done()``, ``result()``, and ``exception()`` methods.
Progress bars are suppressed in background mode.
Currently, background mode is limited to ``nuts_sampler="pymc"``; other samplers will raise
``NotImplementedError``.

Copilot uses AI. Check for mistakes.
Expand All @@ -629,65 +679,53 @@ def sample(
``ZarrTrace`` instance. Refer to :class:`~pymc.backends.zarr.ZarrTrace` for
the benefits this backend provides.

Notes
-----
Optional keyword arguments can be passed to ``sample`` to be delivered to the
``step_method``\ s used during sampling.

For example:

1. ``target_accept`` to NUTS: nuts={'target_accept':0.9}
2. ``transit_p`` to BinaryGibbsMetropolis: binary_gibbs_metropolis={'transit_p':.7}

Note that available step names are:

``nuts``, ``hmc``, ``metropolis``, ``binary_metropolis``,
``binary_gibbs_metropolis``, ``categorical_gibbs_metropolis``,
``DEMetropolis``, ``DEMetropolisZ``, ``slice``

The NUTS step method has several options including:

* target_accept : float in [0, 1]. The step size is tuned such that we
approximate this acceptance rate. Higher values like 0.9 or 0.95 often
work better for problematic posteriors. This argument can be passed directly to sample.
* max_treedepth : The maximum depth of the trajectory tree
* step_scale : float, default 0.25
The initial guess for the step size scaled down by :math:`1/n**(1/4)`,
where n is the dimensionality of the parameter space

Alternatively, if you manually declare the ``step_method``\ s, within the ``step``
kwarg, then you can address the ``step_method`` kwargs directly.
e.g. for a CompoundStep comprising NUTS and BinaryGibbsMetropolis,
you could send ::

step = [
pm.NUTS([freeRV1, freeRV2], target_accept=0.9),
pm.BinaryGibbsMetropolis([freeRV3], transit_p=0.7),
]

You can find a full list of arguments in the docstring of the step methods.

Examples
--------
.. code-block:: ipython
"""
if background and not _background_internal:
if nuts_sampler != "pymc":
raise NotImplementedError("background=True currently supports nuts_sampler='pymc' only")
Copy link

Copilot AI Dec 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The progressbar is forcibly set to False when background=True, overriding the user's explicit progressbar parameter. This could be surprising if a user passes progressbar=True and expects some output. Consider instead checking if the user explicitly set progressbar and either warning them or raising an error if they try to use progressbar with background mode, rather than silently overriding it.

Suggested change
raise NotImplementedError("background=True currently supports nuts_sampler='pymc' only")
raise NotImplementedError("background=True currently supports nuts_sampler='pymc' only")
# If the user explicitly requested progressbar=True, raise an error.
if (
("progressbar" in kwargs and kwargs["progressbar"] is True)
or (locals().get("progressbar", None) is True)
):
raise ValueError("Progress bars are not supported in background mode (background=True). Please set progressbar=False or omit it.")

Copilot uses AI. Check for mistakes.
progressbar = False

In [1]: import pymc as pm
...: n = 100
...: h = 61
...: alpha = 2
...: beta = 2
# Resolve the model now so the background thread has a concrete model object.
resolved_model = modelcontext(model)

In [2]: with pm.Model() as model: # context management
...: p = pm.Beta("p", alpha=alpha, beta=beta)
...: y = pm.Binomial("y", n=n, p=p, observed=h)
...: idata = pm.sample()
def _run():
return sample(
draws=draws,
tune=tune,
chains=chains,
cores=cores,
random_seed=random_seed,
progressbar=progressbar,
progressbar_theme=progressbar_theme,
step=step,
var_names=var_names,
nuts_sampler=nuts_sampler,
initvals=initvals,
init=init,
jitter_max_retries=jitter_max_retries,
n_init=n_init,
trace=trace,
discard_tuned_samples=discard_tuned_samples,
compute_convergence_checks=compute_convergence_checks,
keep_warning_stat=keep_warning_stat,
return_inferencedata=return_inferencedata,
idata_kwargs=idata_kwargs,
nuts_sampler_kwargs=nuts_sampler_kwargs,
callback=callback,
mp_ctx=mp_ctx,
blas_cores=blas_cores,
model=resolved_model,
compile_kwargs=compile_kwargs,
background=False,
_background_internal=True,
Comment on lines +686 to +723
Copy link

Copilot AI Dec 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The _background_internal flag is used to prevent infinite recursion when calling sample() from within the background thread. However, this pattern could be simplified. Consider using a dedicated internal function (e.g., _sample_impl) that contains the main logic, and have both the background mode and regular mode call that function. This would eliminate the need for the _background_internal flag and make the code clearer.

Copilot uses AI. Check for mistakes.
**kwargs,
)

In [3]: az.summary(idata, kind="stats")
return BackgroundSampleHandle(target=_run).start()

Out[3]:
mean sd hdi_3% hdi_97%
p 0.609 0.047 0.528 0.699
"""
if "start" in kwargs:
if initvals is not None:
raise ValueError("Passing both `start` and `initvals` is not supported.")
Expand Down
47 changes: 47 additions & 0 deletions tests/sampling/test_background_sampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright 2025 - present The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest

import pymc as pm


def test_background_sampling_happy_path():
with pm.Model():
pm.Normal("x", 0, 1)
handle = pm.sample(
draws=20,
tune=10,
chains=1,
cores=1,
background=True,
progressbar=False,
)
idata = handle.result()
assert hasattr(idata, "posterior")
assert idata.posterior.sizes["chain"] >= 1


def test_background_sampling_raises():
with pm.Model():
pm.Normal("x", 0, sigma=-1)
handle = pm.sample(
draws=10,
tune=5,
chains=1,
cores=1,
background=True,
progressbar=False,
)
with pytest.raises(Exception):
handle.result()