diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index 13b24c1bc6c..7ad83a3ec87 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -33,6 +33,7 @@ - Rewrote `Multinomial._random` method to better handle shape broadcasting (#3271) - Fixed `Rice` distribution, which inconsistently mixed two parametrizations (#3286). - `Rice` distribution now accepts multiple parameters and observations and is usable with NUTS (#3289). +- `sample_posterior_predictive` no longer calls `draw_values` to initialize the shape of the ppc trace. This called could lead to `ValueError`'s when sampling the ppc from a model with `Flat` or `HalfFlat` prior distributions (Fix issue #3294). ### Deprecations diff --git a/pymc3/sampling.py b/pymc3/sampling.py index e4dcbe35be9..c11f2ef5d45 100644 --- a/pymc3/sampling.py +++ b/pymc3/sampling.py @@ -1123,17 +1123,9 @@ def sample_posterior_predictive(trace, samples=None, model=None, vars=None, size if progressbar: indices = tqdm(indices, total=samples) - varnames = [var.name for var in vars] - - # draw once to inspect the shape - var_values = list(zip(varnames, - draw_values(vars, point=model.test_point, size=size))) ppc_trace = defaultdict(list) - for varname, value in var_values: - ppc_trace[varname] = np.zeros((samples,) + value.shape, value.dtype) - try: - for slc, idx in enumerate(indices): + for idx in indices: if nchain > 1: chain_idx, point_idx = np.divmod(idx, len_trace) param = trace._straces[chain_idx % nchain].point(point_idx) @@ -1142,7 +1134,7 @@ def sample_posterior_predictive(trace, samples=None, model=None, vars=None, size values = draw_values(vars, point=param, size=size) for k, v in zip(vars, values): - ppc_trace[k.name][slc] = v + ppc_trace[k.name].append(v) except KeyboardInterrupt: pass @@ -1151,7 +1143,7 @@ def sample_posterior_predictive(trace, samples=None, model=None, vars=None, size if progressbar: indices.close() - return ppc_trace + return {k: np.asarray(v) for k, v in ppc_trace.items()} def sample_ppc(*args, **kwargs): diff --git a/pymc3/tests/test_sampling.py b/pymc3/tests/test_sampling.py index 73c61eef684..4b43bedad18 100644 --- a/pymc3/tests/test_sampling.py +++ b/pymc3/tests/test_sampling.py @@ -289,6 +289,21 @@ def test_sum_normal(self): _, pval = stats.kstest(ppc['b'], stats.norm(scale=scale).cdf) assert pval > 0.001 + def test_model_not_drawable_prior(self): + data = np.random.poisson(lam=10, size=200) + model = pm.Model() + with model: + mu = pm.HalfFlat('sigma') + pm.Poisson('foo', mu=mu, observed=data) + trace = pm.sample(tune=1000) + + with model: + with pytest.raises(ValueError) as excinfo: + pm.sample_prior_predictive(50) + assert "Cannot sample" in str(excinfo.value) + samples = pm.sample_posterior_predictive(trace, 50) + assert samples['foo'].shape == (50, 200) + class TestSamplePPCW(SeededTest): def test_sample_posterior_predictive_w(self):