Skip to content

Commit

Permalink
fix default prior variable names (#3591)
Browse files Browse the repository at this point in the history
* fix default prior variable names

* update docs and add test on pm.Data not being in prior

* Add model.potentials to prior_vars
  • Loading branch information
OriolAbril authored and aloctavodia committed Aug 16, 2019
1 parent 64d2b88 commit 3a2a765
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
11 changes: 8 additions & 3 deletions pymc3/sampling.py
Expand Up @@ -1292,7 +1292,7 @@ def sample_prior_predictive(samples=500,
samples. *DEPRECATED* - Use ``var_names`` argument instead.
var_names : Iterable[str]
A list of names of variables for which to compute the posterior predictive
samples. Defaults to ``model.named_vars``.
samples. Defaults to both observed and unobserved RVs.
random_seed : int
Seed for the random number generator.
Expand All @@ -1305,8 +1305,13 @@ def sample_prior_predictive(samples=500,
model = modelcontext(model)

if vars is None and var_names is None:
vars = set(model.named_vars.keys())
vars_ = model.named_vars
prior_pred_vars = model.observed_RVs
prior_vars = (
get_default_varnames(model.unobserved_RVs, include_transformed=True) +
model.potentials
)
vars_ = [var.name for var in prior_vars + prior_pred_vars]
vars = set(vars_)
elif vars is None:
vars = var_names
vars_ = vars
Expand Down
4 changes: 3 additions & 1 deletion pymc3/tests/test_sampling.py
Expand Up @@ -505,12 +505,14 @@ def test_ignores_observed(self):
observed = np.random.normal(10, 1, size=200)
with pm.Model():
# Use a prior that's way off to show we're ignoring the observed variables
observed_data = pm.Data("observed_data", observed)
mu = pm.Normal("mu", mu=-100, sigma=1)
positive_mu = pm.Deterministic("positive_mu", np.abs(mu))
z = -1 - positive_mu
pm.Normal("x_obs", mu=z, sigma=1, observed=observed)
pm.Normal("x_obs", mu=z, sigma=1, observed=observed_data)
prior = pm.sample_prior_predictive()

assert "observed_data" not in prior
assert (prior["mu"] < 90).all()
assert (prior["positive_mu"] > 90).all()
assert (prior["x_obs"] < 90).all()
Expand Down

0 comments on commit 3a2a765

Please sign in to comment.