Skip to content

Commit

Permalink
🔥 remove deprecated vars from sample_prior_predictive
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Dec 13, 2020
1 parent 60cf2cd commit 66f4ed6
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 31 deletions.
33 changes: 6 additions & 27 deletions pymc3/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from collections import defaultdict
from copy import copy
from typing import Any, Dict, Iterable, List, Optional, Union, cast
from typing import Any, Dict, Iterable, List, Optional, Set, Union, cast

import arviz
import numpy as np
Expand Down Expand Up @@ -56,7 +56,7 @@
Metropolis,
Slice,
)
from pymc3.step_methods.arraystep import PopulationArrayStepShared
from pymc3.step_methods.arraystep import BlockedStep, PopulationArrayStepShared
from pymc3.step_methods.hmc import quadpotential
from pymc3.util import (
chains_and_samples,
Expand Down Expand Up @@ -91,18 +91,7 @@
CategoricalGibbsMetropolis,
PGBART,
)
Step = Union[
NUTS,
HamiltonianMC,
Metropolis,
BinaryMetropolis,
BinaryGibbsMetropolis,
Slice,
CategoricalGibbsMetropolis,
PGBART,
CompoundStep,
]

Step = Union[BlockedStep, CompoundStep]

ArrayLike = Union[np.ndarray, List[float]]
PointType = Dict[str, np.ndarray]
Expand Down Expand Up @@ -1898,7 +1887,6 @@ def sample_posterior_predictive_w(
def sample_prior_predictive(
samples=500,
model: Optional[Model] = None,
vars: Optional[Iterable[str]] = None,
var_names: Optional[Iterable[str]] = None,
random_seed=None,
) -> Dict[str, np.ndarray]:
Expand All @@ -1909,9 +1897,6 @@ def sample_prior_predictive(
samples : int
Number of samples from the prior predictive to generate. Defaults to 500.
model : Model (optional if in ``with`` context)
vars : Iterable[str]
A list of names of variables for which to compute the posterior predictive
samples. *DEPRECATED* - Use ``var_names`` argument instead.
var_names : Iterable[str]
A list of names of variables for which to compute the posterior predictive
samples. Defaults to both observed and unobserved RVs.
Expand All @@ -1926,20 +1911,14 @@ def sample_prior_predictive(
"""
model = modelcontext(model)

if vars is None and var_names is None:
if var_names is None:
prior_pred_vars = model.observed_RVs
prior_vars = (
get_default_varnames(model.unobserved_RVs, include_transformed=True) + model.potentials
)
vars_: Iterable[str] = [var.name for var in prior_vars + prior_pred_vars]
elif vars is None:
assert var_names is not None # help mypy
vars_ = var_names
elif var_names is None:
warnings.warn("vars argument is deprecated in favor of var_names.", DeprecationWarning)
vars_ = vars
vars_: Set[str] = {var.name for var in prior_vars + prior_pred_vars}
else:
raise ValueError("Cannot supply both vars and var_names arguments.")
vars_ = set(var_names)

if random_seed is not None:
np.random.seed(random_seed)
Expand Down
3 changes: 2 additions & 1 deletion pymc3/step_methods/arraystep.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from numpy.random import uniform

from pymc3.blocking import ArrayOrdering, DictToArrayBijection
from pymc3.model import modelcontext
from pymc3.model import PyMC3Variable, modelcontext
from pymc3.step_methods.compound import CompoundStep
from pymc3.theanof import inputvars
from pymc3.util import get_var_name
Expand Down Expand Up @@ -48,6 +48,7 @@ class BlockedStep:

generates_stats = False
stats_dtypes: List[Dict[str, np.dtype]] = []
vars: List[PyMC3Variable] = []

def __new__(cls, *args, **kwargs):
blocked = kwargs.get("blocked")
Expand Down
5 changes: 2 additions & 3 deletions pymc3/tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -903,9 +903,8 @@ def test_respects_shape(self):
with pm.Model():
mu = pm.Gamma("mu", 3, 1, shape=1)
goals = pm.Poisson("goals", mu, shape=shape)
with pytest.warns(DeprecationWarning):
trace1 = pm.sample_prior_predictive(10, vars=["mu", "goals"])
trace2 = pm.sample_prior_predictive(10, var_names=["mu", "goals"])
trace1 = pm.sample_prior_predictive(10, var_names=["mu", "mu", "goals"])
trace2 = pm.sample_prior_predictive(10, var_names=["mu", "goals"])
if shape == 2: # want to test shape as an int
shape = (2,)
assert trace1["goals"].shape == (10,) + shape
Expand Down

0 comments on commit 66f4ed6

Please sign in to comment.