Skip to content

Conversation

@XPE-7
Copy link

@XPE-7 XPE-7 commented Dec 10, 2025

Description

This adds an opt-in background mode to pm.sample so users can start sampling and continue working while it runs in a background thread. When background=True, pm.sample returns a handle with done(), result(), and exception() helpers, and the progress bar is suppressed to keep output clean. Background mode is currently limited to the built-in nuts_sampler="pymc"; other samplers raise NotImplementedError to make the scope clear. The model is resolved eagerly so the background thread always has a valid context.

Docstring notes describe how to use background=True and the current limitations. Tests cover both a successful background run and error propagation (tests/sampling/test_background_sampling.py).

Summary of changes:

  • Add background flag to pm.sample and return a BackgroundSampleHandle when enabled.
  • Suppress progress bars in background mode.
  • Resolve the model before launching the background thread.
  • Document background mode in the pm.sample docstring.
  • Add tests for background sampling (happy path and error propagation).

Related Issue

Checklist

  • Checked that pre-commit linting/style checks pass
  • Included tests that prove the fix is effective or that the new feature works
  • Added necessary documentation (docstrings and/or example notebooks)
  • If you are a pro: each commit corresponds to a relevant logical change

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance

Copilot AI review requested due to automatic review settings December 10, 2025 11:04
Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR adds an opt-in background sampling mode to pm.sample, allowing users to start sampling in a background thread and continue working while it runs. When background=True, the function returns a BackgroundSampleHandle with done(), result(), and exception() helper methods for managing the background execution.

  • Added BackgroundSampleHandle class to manage background sampling threads
  • Extended pm.sample with a background parameter (limited to nuts_sampler="pymc")
  • Model resolution happens eagerly before thread creation to ensure valid context
  • Tests cover successful execution and error propagation scenarios

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 14 comments.

File Description
pymc/sampling/mcmc.py Adds BackgroundSampleHandle class, extends pm.sample with background parameter, includes recursive call prevention with _background_internal flag, and updates docstring with background mode documentation
tests/sampling/test_background_sampling.py New test file with happy path test and error propagation test for background sampling functionality

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

import sys
import time
import warnings
import threading
Copy link

Copilot AI Dec 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider adding a comment explaining why threading is imported, to help with code navigation. For example: import threading # For BackgroundSampleHandle

Suggested change
import threading
import threading # For BackgroundSampleHandle

Copilot uses AI. Check for mistakes.
Comment on lines 620 to +664
Dictionary with keyword argument to pass to the functions compiled by the step methods.
You can find a full list of arguments in the docstring of the step methods.
Copy link

Copilot AI Dec 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This sentence appears to be an incomplete addition. The original docstring already had "Dictionary with keyword argument to pass to the functions compiled by the step methods." Adding "You can find a full list of arguments in the docstring of the step methods." on a new line makes it seem disconnected. Consider either merging this into a single paragraph or ensuring proper formatting.

Suggested change
Dictionary with keyword argument to pass to the functions compiled by the step methods.
You can find a full list of arguments in the docstring of the step methods.
Dictionary with keyword arguments to pass to the functions compiled by the step methods. You can find a full list of arguments in the docstring of the step methods.

Copilot uses AI. Check for mistakes.
Comment on lines +686 to +723
if background and not _background_internal:
if nuts_sampler != "pymc":
raise NotImplementedError("background=True currently supports nuts_sampler='pymc' only")
progressbar = False

In [1]: import pymc as pm
...: n = 100
...: h = 61
...: alpha = 2
...: beta = 2
# Resolve the model now so the background thread has a concrete model object.
resolved_model = modelcontext(model)

In [2]: with pm.Model() as model: # context management
...: p = pm.Beta("p", alpha=alpha, beta=beta)
...: y = pm.Binomial("y", n=n, p=p, observed=h)
...: idata = pm.sample()
def _run():
return sample(
draws=draws,
tune=tune,
chains=chains,
cores=cores,
random_seed=random_seed,
progressbar=progressbar,
progressbar_theme=progressbar_theme,
step=step,
var_names=var_names,
nuts_sampler=nuts_sampler,
initvals=initvals,
init=init,
jitter_max_retries=jitter_max_retries,
n_init=n_init,
trace=trace,
discard_tuned_samples=discard_tuned_samples,
compute_convergence_checks=compute_convergence_checks,
keep_warning_stat=keep_warning_stat,
return_inferencedata=return_inferencedata,
idata_kwargs=idata_kwargs,
nuts_sampler_kwargs=nuts_sampler_kwargs,
callback=callback,
mp_ctx=mp_ctx,
blas_cores=blas_cores,
model=resolved_model,
compile_kwargs=compile_kwargs,
background=False,
_background_internal=True,
Copy link

Copilot AI Dec 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The _background_internal flag is used to prevent infinite recursion when calling sample() from within the background thread. However, this pattern could be simplified. Consider using a dedicated internal function (e.g., _sample_impl) that contains the main logic, and have both the background mode and regular mode call that function. This would eliminate the need for the _background_internal flag and make the code clearer.

Copilot uses AI. Check for mistakes.
Step: TypeAlias = BlockedStep | CompoundStep


class BackgroundSampleHandle:
Copy link

Copilot AI Dec 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The BackgroundSampleHandle class is missing a docstring. As a public-facing API component (returned by pm.sample when background=True), it should have comprehensive documentation explaining:

  • The purpose of the class
  • How to use the done(), result(), and exception() methods
  • Parameters for the __init__ method
  • Return types and behavior of each method
  • Example usage
Suggested change
class BackgroundSampleHandle:
class BackgroundSampleHandle:
"""
Handle for background sampling threads returned by `pm.sample(background=True)`.
This class provides an interface to monitor and retrieve the result of a background
sampling process started by PyMC's `pm.sample` function with `background=True`.
It mimics the interface of `concurrent.futures.Future` for convenience.
Parameters
----------
target : callable
The function to execute in the background thread.
args : tuple, optional
Positional arguments to pass to the target function.
kwargs : dict, optional
Keyword arguments to pass to the target function.
Methods
-------
start()
Start the background thread.
done() -> bool
Return True if the background sampling has finished.
result(timeout=None)
Wait for the background sampling to finish and return the result.
Raises any exception raised by the target function.
exception(timeout=None)
Wait for the background sampling to finish and return the exception, if any.
Example
-------
>>> handle = pm.sample(..., background=True)
>>> # Do other work while sampling runs in the background
>>> if handle.done():
... trace = handle.result()
... else:
... print("Sampling still in progress")
"""

Copilot uses AI. Check for mistakes.
Comment on lines 665 to 674
Background mode
----------------
- Set ``background=True`` to run sampling in a background thread; this returns a handle.
- The handle supports ``done()``, ``result()``, and ``exception()``.
- Progress bars are suppressed in background mode.
- Currently limited to ``nuts_sampler="pymc"``; other samplers raise ``NotImplementedError``.
Returns
-------
Copy link

Copilot AI Dec 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The background parameter is missing from the "Parameters" section of the docstring. It should be documented there with a description like:

background : bool, default=False
    If True, run sampling in a background thread and return a BackgroundSampleHandle 
    instead of the usual trace/InferenceData. The handle provides `done()`, `result()`, 
    and `exception()` methods. Currently only supported with `nuts_sampler="pymc"`.

Additionally, the "Background mode" section should be moved to a "Notes" section or integrated into the parameter documentation, as parameters sections typically document the parameters themselves.

Suggested change
Background mode
----------------
- Set ``background=True`` to run sampling in a background thread; this returns a handle.
- The handle supports ``done()``, ``result()``, and ``exception()``.
- Progress bars are suppressed in background mode.
- Currently limited to ``nuts_sampler="pymc"``; other samplers raise ``NotImplementedError``.
Returns
-------
background : bool, default=False
If True, run sampling in a background thread and return a BackgroundSampleHandle
instead of the usual trace/InferenceData. The handle provides ``done()``, ``result()``,
and ``exception()`` methods. Currently only supported with ``nuts_sampler="pymc"``.
Notes
-----
When ``background=True``, sampling is run in a background thread and a handle is returned.
The handle supports ``done()``, ``result()``, and ``exception()`` methods.
Progress bars are suppressed in background mode.
Currently, background mode is limited to ``nuts_sampler="pymc"``; other samplers will raise
``NotImplementedError``.

Copilot uses AI. Check for mistakes.
"""
if background and not _background_internal:
if nuts_sampler != "pymc":
raise NotImplementedError("background=True currently supports nuts_sampler='pymc' only")
Copy link

Copilot AI Dec 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The progressbar is forcibly set to False when background=True, overriding the user's explicit progressbar parameter. This could be surprising if a user passes progressbar=True and expects some output. Consider instead checking if the user explicitly set progressbar and either warning them or raising an error if they try to use progressbar with background mode, rather than silently overriding it.

Suggested change
raise NotImplementedError("background=True currently supports nuts_sampler='pymc' only")
raise NotImplementedError("background=True currently supports nuts_sampler='pymc' only")
# If the user explicitly requested progressbar=True, raise an error.
if (
("progressbar" in kwargs and kwargs["progressbar"] is True)
or (locals().get("progressbar", None) is True)
):
raise ValueError("Progress bars are not supported in background mode (background=True). Please set progressbar=False or omit it.")

Copilot uses AI. Check for mistakes.
background: bool = False,
_background_internal: bool = False,
**kwargs,
) -> InferenceData | MultiTrace | ZarrTrace:
Copy link

Copilot AI Dec 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The main function signature's return type should include BackgroundSampleHandle as a possible return type: -> InferenceData | MultiTrace | ZarrTrace | BackgroundSampleHandle

Suggested change
) -> InferenceData | MultiTrace | ZarrTrace:
) -> InferenceData | MultiTrace | ZarrTrace | BackgroundSampleHandle:

Copilot uses AI. Check for mistakes.
Comment on lines 1 to 33
import pymc as pm
import pytest


def test_background_sampling_happy_path():
with pm.Model():
pm.Normal("x", 0, 1)
handle = pm.sample(
draws=20,
tune=10,
chains=1,
cores=1,
background=True,
progressbar=False,
)
idata = handle.result()
assert hasattr(idata, "posterior")
assert idata.posterior.sizes["chain"] >= 1


def test_background_sampling_raises():
with pm.Model():
pm.Normal("x", 0, sigma=-1)
handle = pm.sample(
draws=10,
tune=5,
chains=1,
cores=1,
background=True,
progressbar=False,
)
with pytest.raises(Exception):
handle.result() No newline at end of file
Copy link

Copilot AI Dec 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing test coverage for the NotImplementedError that should be raised when using background=True with nuts_sampler other than "pymc" (e.g., "nutpie", "numpyro", or "blackjax"). This is an important edge case mentioned in the documentation.

Copilot uses AI. Check for mistakes.
callback=None,
mp_ctx=None,
blas_cores: int | None | Literal["auto"] = "auto",
compile_kwargs: dict | None = None,
Copy link

Copilot AI Dec 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The function overloads don't account for background=True returning a BackgroundSampleHandle. The overloads should have additional signatures to specify that when background=True, the return type is BackgroundSampleHandle, not InferenceData or MultiTrace. For example:

@overload
def sample(
    ...,
    background: Literal[True],
    **kwargs,
) -> BackgroundSampleHandle: ...

Without this, type checkers won't correctly infer the return type when using background mode.

Suggested change
compile_kwargs: dict | None = None,
compile_kwargs: dict | None = None,
background: Literal[True],
**kwargs,
) -> BackgroundSampleHandle: ...
@overload
def sample(
draws: int = 1000,
*,
tune: int = 1000,
chains: int | None = None,
cores: int | None = None,
random_seed: RandomState = None,
progressbar: bool | ProgressBarType = True,
progressbar_theme: Theme | None = default_progress_theme,
step=None,
var_names: Sequence[str] | None = None,
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
initvals: StartDict | Sequence[StartDict | None] | None = None,
init: str = "auto",
jitter_max_retries: int = 10,
n_init: int = 200_000,
trace: TraceOrBackend | None = None,
discard_tuned_samples: bool = True,
compute_convergence_checks: bool = True,
keep_warning_stat: bool = False,
return_inferencedata: Literal[False],
idata_kwargs: dict[str, Any] | None = None,
nuts_sampler_kwargs: dict[str, Any] | None = None,
callback=None,
mp_ctx=None,
model: Model | None = None,
blas_cores: int | None | Literal["auto"] = "auto",
compile_kwargs: dict | None = None,
background: Literal[True],
**kwargs,
) -> BackgroundSampleHandle: ...
@overload
def sample(
draws: int = 1000,
*,
tune: int = 1000,
chains: int | None = None,
cores: int | None = None,
random_seed: RandomState = None,
progressbar: bool | ProgressBarType = True,
progressbar_theme: Theme | None = default_progress_theme,
step=None,
var_names: Sequence[str] | None = None,
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
initvals: StartDict | Sequence[StartDict | None] | None = None,
init: str = "auto",
jitter_max_retries: int = 10,
n_init: int = 200_000,
trace: TraceOrBackend | None = None,
discard_tuned_samples: bool = True,
compute_convergence_checks: bool = True,
keep_warning_stat: bool = False,
return_inferencedata: Literal[True] = True,
idata_kwargs: dict[str, Any] | None = None,
nuts_sampler_kwargs: dict[str, Any] | None = None,
callback=None,
mp_ctx=None,
blas_cores: int | None | Literal["auto"] = "auto",
compile_kwargs: dict | None = None,

Copilot uses AI. Check for mistakes.
mp_ctx=None,
model: Model | None = None,
blas_cores: int | None | Literal["auto"] = "auto",
compile_kwargs: dict | None = None,
Copy link

Copilot AI Dec 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same issue as the first overload: this overload doesn't account for background=True returning a BackgroundSampleHandle. Additional overloads are needed for the background mode case.

Suggested change
compile_kwargs: dict | None = None,
compile_kwargs: dict | None = None,
background: Literal[True],
**kwargs,
) -> BackgroundSampleHandle: ...
@overload
def sample(
draws: int = 1000,
*,
tune: int = 1000,
chains: int | None = None,
cores: int | None = None,
random_seed: RandomState = None,
progressbar: bool | ProgressBarType = True,
progressbar_theme: Theme | None = default_progress_theme,
step=None,
var_names: Sequence[str] | None = None,
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
initvals: StartDict | Sequence[StartDict | None] | None = None,
init: str = "auto",
jitter_max_retries: int = 10,
n_init: int = 200_000,
trace: TraceOrBackend | None = None,
discard_tuned_samples: bool = True,
compute_convergence_checks: bool = True,
keep_warning_stat: bool = False,
return_inferencedata: Literal[False],
idata_kwargs: dict[str, Any] | None = None,
nuts_sampler_kwargs: dict[str, Any] | None = None,
callback=None,
mp_ctx=None,
model: Model | None = None,
blas_cores: int | None | Literal["auto"] = "auto",
compile_kwargs: dict | None = None,

Copilot uses AI. Check for mistakes.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Non-blocking sampling

1 participant