Add background sampling handle to pm.sample #7991
Pull request overview
This PR adds an opt-in background sampling mode to pm.sample, allowing users to start sampling in a background thread and continue working while it runs. When background=True, the function returns a BackgroundSampleHandle with done(), result(), and exception() helper methods for managing the background execution.
- Added BackgroundSampleHandle class to manage background sampling threads
- Extended pm.sample with a background parameter (limited to nuts_sampler="pymc")
- Model resolution happens eagerly before thread creation to ensure valid context
- Tests cover successful execution and error propagation scenarios
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 14 comments.
| File | Description |
|---|---|
| pymc/sampling/mcmc.py | Adds BackgroundSampleHandle class, extends pm.sample with background parameter, includes recursive call prevention with _background_internal flag, and updates docstring with background mode documentation |
| tests/sampling/test_background_sampling.py | New test file with happy path test and error propagation test for background sampling functionality |
pymc/sampling/mcmc.py
Outdated
| import sys | ||
| import time | ||
| import warnings | ||
| import threading |
Copilot
AI
Dec 10, 2025
Consider adding a comment explaining why threading is imported, to help with code navigation. For example: import threading # For BackgroundSampleHandle
| import threading | |
| import threading # For BackgroundSampleHandle |
| Dictionary with keyword argument to pass to the functions compiled by the step methods. | ||
| You can find a full list of arguments in the docstring of the step methods. |
Copilot
AI
Dec 10, 2025
This sentence appears to be an incomplete addition. The original docstring already had "Dictionary with keyword argument to pass to the functions compiled by the step methods." Adding "You can find a full list of arguments in the docstring of the step methods." on a new line makes it seem disconnected. Consider either merging this into a single paragraph or ensuring proper formatting.
| Dictionary with keyword argument to pass to the functions compiled by the step methods. | |
| You can find a full list of arguments in the docstring of the step methods. | |
| Dictionary with keyword arguments to pass to the functions compiled by the step methods. You can find a full list of arguments in the docstring of the step methods. |
| if background and not _background_internal: | ||
| if nuts_sampler != "pymc": | ||
| raise NotImplementedError("background=True currently supports nuts_sampler='pymc' only") | ||
| progressbar = False | ||
|
|
||
| In [1]: import pymc as pm | ||
| ...: n = 100 | ||
| ...: h = 61 | ||
| ...: alpha = 2 | ||
| ...: beta = 2 | ||
| # Resolve the model now so the background thread has a concrete model object. | ||
| resolved_model = modelcontext(model) | ||
|
|
||
| In [2]: with pm.Model() as model: # context management | ||
| ...: p = pm.Beta("p", alpha=alpha, beta=beta) | ||
| ...: y = pm.Binomial("y", n=n, p=p, observed=h) | ||
| ...: idata = pm.sample() | ||
| def _run(): | ||
| return sample( | ||
| draws=draws, | ||
| tune=tune, | ||
| chains=chains, | ||
| cores=cores, | ||
| random_seed=random_seed, | ||
| progressbar=progressbar, | ||
| progressbar_theme=progressbar_theme, | ||
| step=step, | ||
| var_names=var_names, | ||
| nuts_sampler=nuts_sampler, | ||
| initvals=initvals, | ||
| init=init, | ||
| jitter_max_retries=jitter_max_retries, | ||
| n_init=n_init, | ||
| trace=trace, | ||
| discard_tuned_samples=discard_tuned_samples, | ||
| compute_convergence_checks=compute_convergence_checks, | ||
| keep_warning_stat=keep_warning_stat, | ||
| return_inferencedata=return_inferencedata, | ||
| idata_kwargs=idata_kwargs, | ||
| nuts_sampler_kwargs=nuts_sampler_kwargs, | ||
| callback=callback, | ||
| mp_ctx=mp_ctx, | ||
| blas_cores=blas_cores, | ||
| model=resolved_model, | ||
| compile_kwargs=compile_kwargs, | ||
| background=False, | ||
| _background_internal=True, |
Copilot
AI
Dec 10, 2025
The _background_internal flag is used to prevent infinite recursion when calling sample() from within the background thread. However, this pattern could be simplified. Consider using a dedicated internal function (e.g., _sample_impl) that contains the main logic, and have both the background mode and regular mode call that function. This would eliminate the need for the _background_internal flag and make the code clearer.
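One way to read this suggestion is sketched below. It is not the PR's code: `_sample_impl` is the hypothetical name from the comment, its body is a placeholder, and the standard library `Future` stands in for `BackgroundSampleHandle`. The point is only that the worker calls the implementation function rather than the public `sample`, so no recursion guard is required.

```python
from concurrent.futures import Future, ThreadPoolExecutor


def _sample_impl(draws: int, tune: int) -> dict:
    """Hypothetical stand-in for the real sampling logic."""
    return {"draws": draws, "tune": tune}


def sample(draws: int = 1000, *, tune: int = 1000, background: bool = False):
    """Public entry point: run _sample_impl directly or in a worker thread."""
    if not background:
        return _sample_impl(draws, tune)
    # The worker calls _sample_impl, never sample(), so a flag such as
    # _background_internal is not needed to stop recursion.
    executor = ThreadPoolExecutor(max_workers=1)
    future: Future = executor.submit(_sample_impl, draws, tune)
    # Do not block here; the submitted job still runs to completion.
    executor.shutdown(wait=False)
    return future
```

With this layout the public signature also stays free of private keyword arguments, which keeps the overloads and the docstring simpler.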
| Step: TypeAlias = BlockedStep | CompoundStep | ||
|
|
||
|
|
||
| class BackgroundSampleHandle: |
Copilot
AI
Dec 10, 2025
The BackgroundSampleHandle class is missing a docstring. As a public-facing API component (returned by pm.sample when background=True), it should have comprehensive documentation explaining:
- The purpose of the class
- How to use the done(), result(), and exception() methods
- Parameters for the __init__ method
- Return types and behavior of each method
- Example usage
| class BackgroundSampleHandle: | |
| class BackgroundSampleHandle: | |
| """ | |
| Handle for background sampling threads returned by `pm.sample(background=True)`. | |
| This class provides an interface to monitor and retrieve the result of a background | |
| sampling process started by PyMC's `pm.sample` function with `background=True`. | |
| It mimics the interface of `concurrent.futures.Future` for convenience. | |
| Parameters | |
| ---------- | |
| target : callable | |
| The function to execute in the background thread. | |
| args : tuple, optional | |
| Positional arguments to pass to the target function. | |
| kwargs : dict, optional | |
| Keyword arguments to pass to the target function. | |
| Methods | |
| ------- | |
| start() | |
| Start the background thread. | |
| done() -> bool | |
| Return True if the background sampling has finished. | |
| result(timeout=None) | |
| Wait for the background sampling to finish and return the result. | |
| Raises any exception raised by the target function. | |
| exception(timeout=None) | |
| Wait for the background sampling to finish and return the exception, if any. | |
| Example | |
| ------- | |
| >>> handle = pm.sample(..., background=True) | |
| >>> # Do other work while sampling runs in the background | |
| >>> if handle.done(): | |
| ... trace = handle.result() | |
| ... else: | |
| ... print("Sampling still in progress") | |
| """ |
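For readers who want to see the behavior this docstring describes, here is a minimal sketch of a Future-like handle built on `threading`. The class name `BackgroundHandle` and its internals are assumptions made for illustration; the PR's actual `BackgroundSampleHandle` implementation is not shown on this page and may differ.

```python
import threading


class BackgroundHandle:
    """Hypothetical Future-like wrapper around one worker thread."""

    def __init__(self, target, args=(), kwargs=None):
        self._result = None
        self._exception = None
        self._finished = threading.Event()
        self._thread = threading.Thread(
            target=self._run, args=(target, args, kwargs or {}), daemon=True
        )

    def _run(self, target, args, kwargs):
        # Store the outcome so the calling thread can inspect it later.
        try:
            self._result = target(*args, **kwargs)
        except Exception as exc:
            self._exception = exc
        finally:
            self._finished.set()

    def start(self):
        self._thread.start()

    def done(self):
        """Return True once the target has returned or raised."""
        return self._finished.is_set()

    def result(self, timeout=None):
        """Wait for completion, then return the value or re-raise the error."""
        if not self._finished.wait(timeout):
            raise TimeoutError("background task has not finished")
        if self._exception is not None:
            raise self._exception
        return self._result

    def exception(self, timeout=None):
        """Wait for completion, then return the captured exception or None."""
        if not self._finished.wait(timeout):
            raise TimeoutError("background task has not finished")
        return self._exception
```

Using an `Event` rather than `Thread.is_alive()` means `done()` is False before `start()` is called, which matches what a caller polling the handle would expect.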
| Background mode | ||
| ---------------- | ||
| - Set ``background=True`` to run sampling in a background thread; this returns a handle. | ||
| - The handle supports ``done()``, ``result()``, and ``exception()``. | ||
| - Progress bars are suppressed in background mode. | ||
| - Currently limited to ``nuts_sampler="pymc"``; other samplers raise ``NotImplementedError``. | ||
| Returns | ||
| ------- |
Copilot
AI
Dec 10, 2025
The background parameter is missing from the "Parameters" section of the docstring. It should be documented there with a description like:
background : bool, default=False
If True, run sampling in a background thread and return a BackgroundSampleHandle
instead of the usual trace/InferenceData. The handle provides `done()`, `result()`,
and `exception()` methods. Currently only supported with `nuts_sampler="pymc"`.
Additionally, the "Background mode" section should be moved to a "Notes" section or integrated into the parameter documentation, as parameters sections typically document the parameters themselves.
| Background mode | |
| ---------------- | |
| - Set ``background=True`` to run sampling in a background thread; this returns a handle. | |
| - The handle supports ``done()``, ``result()``, and ``exception()``. | |
| - Progress bars are suppressed in background mode. | |
| - Currently limited to ``nuts_sampler="pymc"``; other samplers raise ``NotImplementedError``. | |
| Returns | |
| ------- | |
| background : bool, default=False | |
| If True, run sampling in a background thread and return a BackgroundSampleHandle | |
| instead of the usual trace/InferenceData. The handle provides ``done()``, ``result()``, | |
| and ``exception()`` methods. Currently only supported with ``nuts_sampler="pymc"``. | |
| Notes | |
| ----- | |
| When ``background=True``, sampling is run in a background thread and a handle is returned. | |
| The handle supports ``done()``, ``result()``, and ``exception()`` methods. | |
| Progress bars are suppressed in background mode. | |
| Currently, background mode is limited to ``nuts_sampler="pymc"``; other samplers will raise | |
| ``NotImplementedError``. |
| """ | ||
| if background and not _background_internal: | ||
| if nuts_sampler != "pymc": | ||
| raise NotImplementedError("background=True currently supports nuts_sampler='pymc' only") |
Copilot
AI
Dec 10, 2025
The progressbar is forcibly set to False when background=True, overriding the user's explicit progressbar parameter. This could be surprising if a user passes progressbar=True and expects some output. Consider instead checking if the user explicitly set progressbar and either warning them or raising an error if they try to use progressbar with background mode, rather than silently overriding it.
| raise NotImplementedError("background=True currently supports nuts_sampler='pymc' only") | |
| raise NotImplementedError("background=True currently supports nuts_sampler='pymc' only") | |
| # If the user explicitly requested progressbar=True, raise an error. | |
| if ( | |
| ("progressbar" in kwargs and kwargs["progressbar"] is True) | |
| or (locals().get("progressbar", None) is True) | |
| ): | |
| raise ValueError("Progress bars are not supported in background mode (background=True). Please set progressbar=False or omit it.") |
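Because progressbar defaults to True in the signature shown in this PR, checking its value alone cannot tell an explicit progressbar=True from the default. A common way to make that distinction is a sentinel default. The sketch below is an assumption-laden illustration, not part of the PR: `_UNSET` and `resolve_progressbar` are hypothetical names, and it warns rather than raises.

```python
import warnings

_UNSET = object()  # hypothetical sentinel meaning "argument not passed"


def resolve_progressbar(progressbar=_UNSET, background=False):
    """Return the effective progress bar setting for a sampling call."""
    if progressbar is _UNSET:
        # Not passed: keep the usual default, but stay quiet in background mode.
        return not background
    if background and progressbar:
        # Explicitly requested: tell the user instead of silently overriding.
        warnings.warn(
            "progressbar is ignored when background=True",
            UserWarning,
            stacklevel=2,
        )
        return False
    return progressbar
```

Swapping the warning for a `ValueError` would give the stricter behavior proposed in the comment while still leaving default calls unaffected.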
| background: bool = False, | ||
| _background_internal: bool = False, | ||
| **kwargs, | ||
| ) -> InferenceData | MultiTrace | ZarrTrace: |
Copilot
AI
Dec 10, 2025
The main function signature's return type should include BackgroundSampleHandle as a possible return type: -> InferenceData | MultiTrace | ZarrTrace | BackgroundSampleHandle
| ) -> InferenceData | MultiTrace | ZarrTrace: | |
| ) -> InferenceData | MultiTrace | ZarrTrace | BackgroundSampleHandle: |
| import pymc as pm | ||
| import pytest | ||
|
|
||
|
|
||
| def test_background_sampling_happy_path(): | ||
| with pm.Model(): | ||
| pm.Normal("x", 0, 1) | ||
| handle = pm.sample( | ||
| draws=20, | ||
| tune=10, | ||
| chains=1, | ||
| cores=1, | ||
| background=True, | ||
| progressbar=False, | ||
| ) | ||
| idata = handle.result() | ||
| assert hasattr(idata, "posterior") | ||
| assert idata.posterior.sizes["chain"] >= 1 | ||
|
|
||
|
|
||
| def test_background_sampling_raises(): | ||
| with pm.Model(): | ||
| pm.Normal("x", 0, sigma=-1) | ||
| handle = pm.sample( | ||
| draws=10, | ||
| tune=5, | ||
| chains=1, | ||
| cores=1, | ||
| background=True, | ||
| progressbar=False, | ||
| ) | ||
| with pytest.raises(Exception): | ||
| handle.result() No newline at end of file |
Copilot
AI
Dec 10, 2025
Missing test coverage for the NotImplementedError that should be raised when using background=True with nuts_sampler other than "pymc" (e.g., "nutpie", "numpyro", or "blackjax"). This is an important edge case mentioned in the documentation.
| callback=None, | ||
| mp_ctx=None, | ||
| blas_cores: int | None | Literal["auto"] = "auto", | ||
| compile_kwargs: dict | None = None, |
Copilot
AI
Dec 10, 2025
The function overloads don't account for background=True returning a BackgroundSampleHandle. The overloads should have additional signatures to specify that when background=True, the return type is BackgroundSampleHandle, not InferenceData or MultiTrace. For example:
@overload
def sample(
...,
background: Literal[True],
**kwargs,
) -> BackgroundSampleHandle: ...
Without this, type checkers won't correctly infer the return type when using background mode.
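The overload pattern described here can be shown on a reduced signature. In this sketch `Handle` and `Result` are hypothetical stand-ins for BackgroundSampleHandle and InferenceData, and the function body is a placeholder; only the use of `typing.overload` with `Literal` is the point.

```python
from typing import Literal, Union, overload


class Handle:
    """Hypothetical stand-in for BackgroundSampleHandle."""


class Result:
    """Hypothetical stand-in for InferenceData."""


@overload
def sample(draws: int = 1000, *, background: Literal[False] = False) -> Result: ...
@overload
def sample(draws: int = 1000, *, background: Literal[True]) -> Handle: ...
def sample(draws: int = 1000, *, background: bool = False) -> Union[Result, Handle]:
    # Placeholder body: the overloads above only affect static type checking.
    return Handle() if background else Result()
```

A type checker should then infer `Handle` for `sample(background=True)` and `Result` for a plain `sample()` call, without the caller needing a cast.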
| compile_kwargs: dict | None = None, | |
| compile_kwargs: dict | None = None, | |
| background: Literal[True], | |
| **kwargs, | |
| ) -> BackgroundSampleHandle: ... | |
| @overload | |
| def sample( | |
| draws: int = 1000, | |
| *, | |
| tune: int = 1000, | |
| chains: int | None = None, | |
| cores: int | None = None, | |
| random_seed: RandomState = None, | |
| progressbar: bool | ProgressBarType = True, | |
| progressbar_theme: Theme | None = default_progress_theme, | |
| step=None, | |
| var_names: Sequence[str] | None = None, | |
| nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc", | |
| initvals: StartDict | Sequence[StartDict | None] | None = None, | |
| init: str = "auto", | |
| jitter_max_retries: int = 10, | |
| n_init: int = 200_000, | |
| trace: TraceOrBackend | None = None, | |
| discard_tuned_samples: bool = True, | |
| compute_convergence_checks: bool = True, | |
| keep_warning_stat: bool = False, | |
| return_inferencedata: Literal[False], | |
| idata_kwargs: dict[str, Any] | None = None, | |
| nuts_sampler_kwargs: dict[str, Any] | None = None, | |
| callback=None, | |
| mp_ctx=None, | |
| model: Model | None = None, | |
| blas_cores: int | None | Literal["auto"] = "auto", | |
| compile_kwargs: dict | None = None, | |
| background: Literal[True], | |
| **kwargs, | |
| ) -> BackgroundSampleHandle: ... | |
| @overload | |
| def sample( | |
| draws: int = 1000, | |
| *, | |
| tune: int = 1000, | |
| chains: int | None = None, | |
| cores: int | None = None, | |
| random_seed: RandomState = None, | |
| progressbar: bool | ProgressBarType = True, | |
| progressbar_theme: Theme | None = default_progress_theme, | |
| step=None, | |
| var_names: Sequence[str] | None = None, | |
| nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc", | |
| initvals: StartDict | Sequence[StartDict | None] | None = None, | |
| init: str = "auto", | |
| jitter_max_retries: int = 10, | |
| n_init: int = 200_000, | |
| trace: TraceOrBackend | None = None, | |
| discard_tuned_samples: bool = True, | |
| compute_convergence_checks: bool = True, | |
| keep_warning_stat: bool = False, | |
| return_inferencedata: Literal[True] = True, | |
| idata_kwargs: dict[str, Any] | None = None, | |
| nuts_sampler_kwargs: dict[str, Any] | None = None, | |
| callback=None, | |
| mp_ctx=None, | |
| blas_cores: int | None | Literal["auto"] = "auto", | |
| compile_kwargs: dict | None = None, |
| mp_ctx=None, | ||
| model: Model | None = None, | ||
| blas_cores: int | None | Literal["auto"] = "auto", | ||
| compile_kwargs: dict | None = None, |
Copilot
AI
Dec 10, 2025
Same issue as the first overload: this overload doesn't account for background=True returning a BackgroundSampleHandle. Additional overloads are needed for the background mode case.
| compile_kwargs: dict | None = None, | |
| compile_kwargs: dict | None = None, | |
| background: Literal[True], | |
| **kwargs, | |
| ) -> BackgroundSampleHandle: ... | |
| @overload | |
| def sample( | |
| draws: int = 1000, | |
| *, | |
| tune: int = 1000, | |
| chains: int | None = None, | |
| cores: int | None = None, | |
| random_seed: RandomState = None, | |
| progressbar: bool | ProgressBarType = True, | |
| progressbar_theme: Theme | None = default_progress_theme, | |
| step=None, | |
| var_names: Sequence[str] | None = None, | |
| nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc", | |
| initvals: StartDict | Sequence[StartDict | None] | None = None, | |
| init: str = "auto", | |
| jitter_max_retries: int = 10, | |
| n_init: int = 200_000, | |
| trace: TraceOrBackend | None = None, | |
| discard_tuned_samples: bool = True, | |
| compute_convergence_checks: bool = True, | |
| keep_warning_stat: bool = False, | |
| return_inferencedata: Literal[False], | |
| idata_kwargs: dict[str, Any] | None = None, | |
| nuts_sampler_kwargs: dict[str, Any] | None = None, | |
| callback=None, | |
| mp_ctx=None, | |
| model: Model | None = None, | |
| blas_cores: int | None | Literal["auto"] = "auto", | |
| compile_kwargs: dict | None = None, |
Description
This adds an opt-in background mode to pm.sample so users can start sampling and continue working while it runs in a background thread. When background=True, pm.sample returns a handle with done(), result(), and exception() helpers, and the progress bar is suppressed to keep output clean. Background mode is currently limited to the built-in nuts_sampler="pymc"; other samplers raise NotImplementedError to make the scope clear. The model is resolved eagerly so the background thread always has a valid context.

Docstring notes describe how to use background=True and the current limitations. Tests cover both a successful background run and error propagation (tests/sampling/test_background_sampling.py).

Summary of changes:
- Add a background flag to pm.sample and return a BackgroundSampleHandle when enabled.
- Update the pm.sample docstring.

Related Issue
Checklist
Type of change