Skip to content

Commit

Permalink
🔥 remove vars from sample_posterior_predictive, other random refactor…
Browse files Browse the repository at this point in the history
…ings
  • Loading branch information
MarcoGorelli committed Dec 14, 2020
1 parent dbcc49e commit 64473a3
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 36 deletions.
2 changes: 1 addition & 1 deletion RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ This is the first release to support Python3.9 and to drop Python3.6.
- Fixed bug whereby partial traces returns after keyboard interrupt during parallel sampling had fewer draws than would've been available [#4318](https://github.com/pymc-devs/pymc3/pull/4318)
- Make `sample_shape` same across all contexts in `draw_values` (see [#4305](https://github.com/pymc-devs/pymc3/pull/4305)).
- Removed `theanof.set_theano_config` because it illegally touched Theano's privates (see [#4329](https://github.com/pymc-devs/pymc3/pull/4329)).

- In `sample_posterior_predictive` the `vars` kwarg was removed in favor of `var_names` (see [#4343](https://github.com/pymc-devs/pymc3/pull/4343)).

## PyMC3 3.10.0 (7 December 2020)

Expand Down
38 changes: 16 additions & 22 deletions pymc3/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@

from arviz import InferenceData
from fastprogress.fastprogress import progress_bar
from theano.tensor import Tensor

import pymc3 as pm

Expand Down Expand Up @@ -561,12 +560,11 @@ def sample(
_log.debug("Pickling error:", exec_info=True)
parallel = False
except AttributeError as e:
if str(e).startswith("AttributeError: Can't pickle"):
_log.warning("Could not pickle model, sampling singlethreaded.")
_log.debug("Pickling error:", exec_info=True)
parallel = False
else:
if not str(e).startswith("AttributeError: Can't pickle"):
raise
_log.warning("Could not pickle model, sampling singlethreaded.")
_log.debug("Pickling error:", exec_info=True)
parallel = False
if not parallel:
if has_population_samplers:
has_demcmc = np.any(
Expand Down Expand Up @@ -1140,10 +1138,12 @@ def _run_secondary(c, stepper_dumps, secondary_end):
# the stepper is not necessarily a PopulationArraySharedStep itself,
# but rather a CompoundStep. PopulationArrayStepShared.population
# has to be updated, therefore we identify the substeppers first.
population_steppers = []
for sm in stepper.methods if isinstance(stepper, CompoundStep) else [stepper]:
if isinstance(sm, PopulationArrayStepShared):
population_steppers.append(sm)
population_steppers = [
sm
for sm in (stepper.methods if isinstance(stepper, CompoundStep) else [stepper])
if isinstance(sm, PopulationArrayStepShared)
]

while True:
incoming = secondary_end.recv()
# receiving a None is the signal to exit
Expand Down Expand Up @@ -1602,7 +1602,6 @@ def sample_posterior_predictive(
trace,
samples: Optional[int] = None,
model: Optional[Model] = None,
vars: Optional[Iterable[Tensor]] = None,
var_names: Optional[List[str]] = None,
size: Optional[int] = None,
keep_size: Optional[bool] = False,
Expand Down Expand Up @@ -1696,14 +1695,9 @@ def sample_posterior_predictive(
model = modelcontext(model)

if var_names is not None:
if vars is not None:
raise IncorrectArgumentsError("Should not specify both vars and var_names arguments.")
else:
vars = [model[x] for x in var_names]
elif vars is not None: # var_names is None, and vars is not.
warnings.warn("vars argument is deprecated in favor of var_names.", DeprecationWarning)
if vars is None:
vars = model.observed_RVs
vars_ = [model[x] for x in var_names]
else:
vars_ = model.observed_RVs

if random_seed is not None:
np.random.seed(random_seed)
Expand All @@ -1729,8 +1723,8 @@ def sample_posterior_predictive(
else:
param = _trace[idx % len_trace]

values = draw_values(vars, point=param, size=size)
for k, v in zip(vars, values):
values = draw_values(vars_, point=param, size=size)
for k, v in zip(vars_, values):
ppc_trace_t.insert(k.name, v, idx)
except KeyboardInterrupt:
pass
Expand Down Expand Up @@ -1809,7 +1803,7 @@ def sample_posterior_predictive_w(
raise ValueError("The number of models and weights should be the same")

length_morv = len(models[0].observed_RVs)
if not all(len(i.observed_RVs) == length_morv for i in models):
if any(len(i.observed_RVs) != length_morv for i in models):
raise ValueError("The number of observed RVs should be the same for all models")

weights = np.asarray(weights)
Expand Down
14 changes: 1 addition & 13 deletions pymc3/tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,8 +406,7 @@ def test_normal_scalar(self):
ppc0 = pm.sample_posterior_predictive([model.test_point], samples=10)
ppc0 = pm.fast_sample_posterior_predictive([model.test_point], samples=10)
# deprecated argument is not introduced to fast version [2019/08/20:rpg]
with pytest.warns(DeprecationWarning):
ppc = pm.sample_posterior_predictive(trace, vars=[a])
ppc = pm.sample_posterior_predictive(trace, var_names=["a"])
# test empty ppc
ppc = pm.sample_posterior_predictive(trace, var_names=[])
assert len(ppc) == 0
Expand Down Expand Up @@ -518,8 +517,6 @@ def test_exceptions(self, caplog):
# Not for fast_sample_posterior_predictive
with pytest.raises(IncorrectArgumentsError):
ppc = pm.sample_posterior_predictive(trace, size=4, keep_size=True)
with pytest.raises(IncorrectArgumentsError):
ppc = pm.sample_posterior_predictive(trace, vars=[a], var_names=["a"])
# test wrong type argument
bad_trace = {"mu": stats.norm.rvs(size=1000)}
with pytest.raises(TypeError):
Expand Down Expand Up @@ -653,16 +650,7 @@ def test_deterministic_of_observed(self):

trace = pm.sample(100, chains=nchains)
np.random.seed(0)
with pytest.warns(DeprecationWarning):
ppc = pm.sample_posterior_predictive(
model=model,
trace=trace,
samples=len(trace) * nchains,
vars=(model.deterministics + model.basic_RVs),
)

rtol = 1e-5 if theano.config.floatX == "float64" else 1e-4
npt.assert_allclose(ppc["in_1"] + ppc["in_2"], ppc["out"], rtol=rtol)

np.random.seed(0)
ppc = pm.sample_posterior_predictive(
Expand Down

0 comments on commit 64473a3

Please sign in to comment.