Describe the issue:
Wrapping a random variable in pm.math.sum and observing the resulting Deterministic breaks pm.sample: PyMC cannot derive a log-probability term for the observed sum.
Reproducible code example:
import pymc as pm

with pm.Model() as m:
    x = pm.Normal("x", mu=0, sigma=1e6)
    y = pm.Normal.dist(x, shape=(5,))
    y_sum = pm.Deterministic("y_sum", pm.math.sum(y))

with pm.observe(m, {"y_sum": 2.0}):
    trace = pm.sample()

Error message:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[18], line 3
1 #%%
2 with pm.observe(m, {"y_sum": 2.0}):
----> 3 trace = pm.sample(nuts_sampler='nutpie')
File d:\code\examples\pymc0\.venv\Lib\site-packages\pymc\sampling\mcmc.py:782, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, compile_kwargs, **kwargs)
779 msg = f"Only {draws} samples per chain. Reliable r-hat and ESS diagnostics require longer chains for accurate estimate."
780 _log.warning(msg)
--> 782 provided_steps, selected_steps = assign_step_methods(model, step, methods=pm.STEP_METHODS)
783 exclusive_nuts = (
784 # User provided an instantiated NUTS step, and nothing else is needed
785 (not selected_steps and len(provided_steps) == 1 and isinstance(provided_steps[0], NUTS))
(...) 792 )
793 )
795 if nuts_sampler != "pymc":
File d:\code\examples\pymc0\.venv\Lib\site-packages\pymc\sampling\mcmc.py:245, in assign_step_methods(model, step, methods)
243 methods_list: list[type[BlockedStep]] = list(methods or pm.STEP_METHODS)
244 selected_steps: dict[type[BlockedStep], list] = {}
--> 245 model_logp = model.logp()
247 for var in model.value_vars:
248 if var not in assigned_vars:
249 # determine if a gradient can be computed
File d:\code\examples\pymc0\.venv\Lib\site-packages\pymc\model\core.py:714, in Model.logp(self, vars, jacobian, sum)
712 rv_logps: list[TensorVariable] = []
713 if rvs:
--> 714 rv_logps = transformed_conditional_logp(
715 rvs=rvs,
716 rvs_to_values=self.rvs_to_values,
717 rvs_to_transforms=self.rvs_to_transforms,
718 jacobian=jacobian,
719 )
720 assert isinstance(rv_logps, list)
722 # Replace random variables by their value variables in potential terms
File d:\code\examples\pymc0\.venv\Lib\site-packages\pymc\logprob\basic.py:574, in transformed_conditional_logp(rvs, rvs_to_values, rvs_to_transforms, jacobian, **kwargs)
571 transform_rewrite = TransformValuesRewrite(values_to_transforms) # type: ignore[arg-type]
573 kwargs.setdefault("warn_rvs", False)
--> 574 temp_logp_terms = conditional_logp(
575 rvs_to_values,
576 extra_rewrites=transform_rewrite,
577 use_jacobian=jacobian,
578 **kwargs,
579 )
581 # The function returns the logp for every single value term we provided to it.
582 # This includes the extra values we plugged in above, so we filter those we
583 # actually wanted in the same order they were given in.
584 logp_terms = {}
File d:\code\examples\pymc0\.venv\Lib\site-packages\pymc\logprob\basic.py:531, in conditional_logp(rv_values, warn_rvs, ir_rewriter, extra_rewrites, **kwargs)
529 missing_value_terms = set(original_values) - set(values_to_logprobs)
530 if missing_value_terms:
--> 531 raise RuntimeError(
532 f"The logprob terms of the following value variables could not be derived: {missing_value_terms}"
533 )
535 # Ensure same order as input
536 logprobs = cleanup_ir(tuple(values_to_logprobs[v] for v in original_values))
RuntimeError: The logprob terms of the following value variables could not be derived: {TensorConstant(TensorType(float64, shape=()), data=array(2.))}

PyMC version information:
5.26.1
Context for the issue:
Being able to observe derived quantities such as sums or maxima would be very helpful in many use cases.
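For this specific model there is an analytic workaround (a sketch only; it does not address the general feature request): since the sum of 5 i.i.d. Normal(x, 1) draws is itself distributed Normal(5 * x, sqrt(5)), the sum can be modeled and observed directly as its own closed-form distribution:

import numpy as np
import pymc as pm

with pm.Model() as m2:
    x = pm.Normal("x", mu=0, sigma=1e6)
    # Analytic marginalization: the sum of 5 iid Normal(x, 1) draws
    # is Normal(5 * x, sqrt(5)), so the observed sum gets its own
    # distribution instead of a Deterministic over pm.math.sum.
    y_sum = pm.Normal("y_sum", mu=5 * x, sigma=np.sqrt(5), observed=2.0)
    trace = pm.sample()

This only works when the pushforward distribution of the derived quantity has a known closed form, which is exactly why built-in support for observing sums/maxima would be valuable.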