-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Add background sampling handle to pm.sample #7991
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -18,6 +18,7 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import logging | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import pickle | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import sys | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import threading | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import time | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import warnings | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -91,6 +92,44 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Step: TypeAlias = BlockedStep | CompoundStep | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class BackgroundSampleHandle: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def __init__(self, target, args=None, kwargs=None): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._done = threading.Event() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._result = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._exception = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| args = args or () | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| kwargs = kwargs or {} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def runner(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._result = target(*args, **kwargs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| except Exception as exc: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._exception = exc | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| finally: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._done.set() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._thread = threading.Thread(target=runner, daemon=True) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def start(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._thread.start() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return self | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def done(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return self._done.is_set() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def result(self, timeout=None): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._thread.join(timeout=timeout) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if not self._done.is_set(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if not self._done.is_set(): | |
| if self._thread.is_alive(): |
Copilot
AI
Dec 10, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The function overloads don't account for background=True returning a BackgroundSampleHandle. The overloads should have additional signatures to specify that when background=True, the return type is BackgroundSampleHandle, not InferenceData or MultiTrace. For example:
@overload
def sample(
...,
background: Literal[True],
**kwargs,
) -> BackgroundSampleHandle: ...Without this, type checkers won't correctly infer the return type when using background mode.
| compile_kwargs: dict | None = None, | |
| compile_kwargs: dict | None = None, | |
| background: Literal[True], | |
| **kwargs, | |
| ) -> BackgroundSampleHandle: ... | |
| @overload | |
| def sample( | |
| draws: int = 1000, | |
| *, | |
| tune: int = 1000, | |
| chains: int | None = None, | |
| cores: int | None = None, | |
| random_seed: RandomState = None, | |
| progressbar: bool | ProgressBarType = True, | |
| progressbar_theme: Theme | None = default_progress_theme, | |
| step=None, | |
| var_names: Sequence[str] | None = None, | |
| nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc", | |
| initvals: StartDict | Sequence[StartDict | None] | None = None, | |
| init: str = "auto", | |
| jitter_max_retries: int = 10, | |
| n_init: int = 200_000, | |
| trace: TraceOrBackend | None = None, | |
| discard_tuned_samples: bool = True, | |
| compute_convergence_checks: bool = True, | |
| keep_warning_stat: bool = False, | |
| return_inferencedata: Literal[False], | |
| idata_kwargs: dict[str, Any] | None = None, | |
| nuts_sampler_kwargs: dict[str, Any] | None = None, | |
| callback=None, | |
| mp_ctx=None, | |
| model: Model | None = None, | |
| blas_cores: int | None | Literal["auto"] = "auto", | |
| compile_kwargs: dict | None = None, | |
| background: Literal[True], | |
| **kwargs, | |
| ) -> BackgroundSampleHandle: ... | |
| @overload | |
| def sample( | |
| draws: int = 1000, | |
| *, | |
| tune: int = 1000, | |
| chains: int | None = None, | |
| cores: int | None = None, | |
| random_seed: RandomState = None, | |
| progressbar: bool | ProgressBarType = True, | |
| progressbar_theme: Theme | None = default_progress_theme, | |
| step=None, | |
| var_names: Sequence[str] | None = None, | |
| nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc", | |
| initvals: StartDict | Sequence[StartDict | None] | None = None, | |
| init: str = "auto", | |
| jitter_max_retries: int = 10, | |
| n_init: int = 200_000, | |
| trace: TraceOrBackend | None = None, | |
| discard_tuned_samples: bool = True, | |
| compute_convergence_checks: bool = True, | |
| keep_warning_stat: bool = False, | |
| return_inferencedata: Literal[True] = True, | |
| idata_kwargs: dict[str, Any] | None = None, | |
| nuts_sampler_kwargs: dict[str, Any] | None = None, | |
| callback=None, | |
| mp_ctx=None, | |
| blas_cores: int | None | Literal["auto"] = "auto", | |
| compile_kwargs: dict | None = None, |
Copilot
AI
Dec 10, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same issue as the first overload: this overload doesn't account for background=True returning a BackgroundSampleHandle. Additional overloads are needed for the background mode case.
| compile_kwargs: dict | None = None, | |
| compile_kwargs: dict | None = None, | |
| background: Literal[True], | |
| **kwargs, | |
| ) -> BackgroundSampleHandle: ... | |
| @overload | |
| def sample( | |
| draws: int = 1000, | |
| *, | |
| tune: int = 1000, | |
| chains: int | None = None, | |
| cores: int | None = None, | |
| random_seed: RandomState = None, | |
| progressbar: bool | ProgressBarType = True, | |
| progressbar_theme: Theme | None = default_progress_theme, | |
| step=None, | |
| var_names: Sequence[str] | None = None, | |
| nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc", | |
| initvals: StartDict | Sequence[StartDict | None] | None = None, | |
| init: str = "auto", | |
| jitter_max_retries: int = 10, | |
| n_init: int = 200_000, | |
| trace: TraceOrBackend | None = None, | |
| discard_tuned_samples: bool = True, | |
| compute_convergence_checks: bool = True, | |
| keep_warning_stat: bool = False, | |
| return_inferencedata: Literal[False], | |
| idata_kwargs: dict[str, Any] | None = None, | |
| nuts_sampler_kwargs: dict[str, Any] | None = None, | |
| callback=None, | |
| mp_ctx=None, | |
| model: Model | None = None, | |
| blas_cores: int | None | Literal["auto"] = "auto", | |
| compile_kwargs: dict | None = None, |
Copilot
AI
Dec 10, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The main function signature's return type should include BackgroundSampleHandle as a possible return type: -> InferenceData | MultiTrace | ZarrTrace | BackgroundSampleHandle
| ) -> InferenceData | MultiTrace | ZarrTrace: | |
| ) -> InferenceData | MultiTrace | ZarrTrace | BackgroundSampleHandle: |
Copilot
AI
Dec 10, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This sentence appears to be an incomplete addition. The original docstring already had "Dictionary with keyword argument to pass to the functions compiled by the step methods." Adding "You can find a full list of arguments in the docstring of the step methods." on a new line makes it seem disconnected. Consider either merging this into a single paragraph or ensuring proper formatting.
| Dictionary with keyword argument to pass to the functions compiled by the step methods. | |
| You can find a full list of arguments in the docstring of the step methods. | |
| Dictionary with keyword arguments to pass to the functions compiled by the step methods. You can find a full list of arguments in the docstring of the step methods. |
Copilot
AI
Dec 10, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The background parameter is missing from the "Parameters" section of the docstring. It should be documented there with a description like:
background : bool, default=False
If True, run sampling in a background thread and return a BackgroundSampleHandle
instead of the usual trace/InferenceData. The handle provides `done()`, `result()`,
and `exception()` methods. Currently only supported with `nuts_sampler="pymc"`.
Additionally, the "Background mode" section should be moved to a "Notes" section or integrated into the parameter documentation, as parameters sections typically document the parameters themselves.
| Background mode | |
| ---------------- | |
| - Set ``background=True`` to run sampling in a background thread; this returns a handle. | |
| - The handle supports ``done()``, ``result()``, and ``exception()``. | |
| - Progress bars are suppressed in background mode. | |
| - Currently limited to ``nuts_sampler="pymc"``; other samplers raise ``NotImplementedError``. | |
| Returns | |
| ------- | |
| background : bool, default=False | |
| If True, run sampling in a background thread and return a BackgroundSampleHandle | |
| instead of the usual trace/InferenceData. The handle provides ``done()``, ``result()``, | |
| and ``exception()`` methods. Currently only supported with ``nuts_sampler="pymc"``. | |
| Notes | |
| ----- | |
| When ``background=True``, sampling is run in a background thread and a handle is returned. | |
| The handle supports ``done()``, ``result()``, and ``exception()`` methods. | |
| Progress bars are suppressed in background mode. | |
| Currently, background mode is limited to ``nuts_sampler="pymc"``; other samplers will raise | |
| ``NotImplementedError``. |
Copilot
AI
Dec 10, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The progressbar is forcibly set to False when background=True, overriding the user's explicit progressbar parameter. This could be surprising if a user passes progressbar=True and expects some output. Consider instead checking if the user explicitly set progressbar and either warning them or raising an error if they try to use progressbar with background mode, rather than silently overriding it.
| raise NotImplementedError("background=True currently supports nuts_sampler='pymc' only") | |
| raise NotImplementedError("background=True currently supports nuts_sampler='pymc' only") | |
| # If the user explicitly requested progressbar=True, raise an error. | |
| if ( | |
| ("progressbar" in kwargs and kwargs["progressbar"] is True) | |
| or (locals().get("progressbar", None) is True) | |
| ): | |
| raise ValueError("Progress bars are not supported in background mode (background=True). Please set progressbar=False or omit it.") |
Copilot
AI
Dec 10, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The _background_internal flag is used to prevent infinite recursion when calling sample() from within the background thread. However, this pattern could be simplified. Consider using a dedicated internal function (e.g., _sample_impl) that contains the main logic, and have both the background mode and regular mode call that function. This would eliminate the need for the _background_internal flag and make the code clearer.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,47 @@ | ||
| # Copyright 2025 - present The PyMC Developers | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| import pytest | ||
|
|
||
| import pymc as pm | ||
|
|
||
|
|
||
| def test_background_sampling_happy_path(): | ||
| with pm.Model(): | ||
| pm.Normal("x", 0, 1) | ||
| handle = pm.sample( | ||
| draws=20, | ||
| tune=10, | ||
| chains=1, | ||
| cores=1, | ||
| background=True, | ||
| progressbar=False, | ||
| ) | ||
| idata = handle.result() | ||
| assert hasattr(idata, "posterior") | ||
| assert idata.posterior.sizes["chain"] >= 1 | ||
|
|
||
|
|
||
| def test_background_sampling_raises(): | ||
| with pm.Model(): | ||
| pm.Normal("x", 0, sigma=-1) | ||
| handle = pm.sample( | ||
| draws=10, | ||
| tune=5, | ||
| chains=1, | ||
| cores=1, | ||
| background=True, | ||
| progressbar=False, | ||
| ) | ||
| with pytest.raises(Exception): | ||
| handle.result() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
BackgroundSampleHandleclass is missing a docstring. As a public-facing API component (returned bypm.samplewhenbackground=True), it should have comprehensive documentation explaining:done(),result(), andexception()methods__init__method