From 6c5254fe8fe7cbcf7fdf775db67191beb1dd7c90 Mon Sep 17 00:00:00 2001
From: michaelosthege
Date: Wed, 11 Mar 2020 12:34:17 +0100
Subject: [PATCH] Include n_tune, n_draws and t_sampling in SamplerReport
 (#3827)

* include n_tune, n_draws and t_sampling in SamplerReport
* count tune/draw samples instead of trusting parameters (because of KeyboardInterrupt)
* fall back to tune and len(trace) if the tune stat is unavailable
* add test for SamplerReport n_tune and n_draws
* clarify that the n_tune samples are not necessarily in the trace
* use the actual number of chains to compute totals
* mention new SamplerReport properties in release notes

Co-authored-by: Michael Osthege
---
 RELEASE-NOTES.md             |  1 +
 pymc3/backends/report.py     | 25 ++++++++++++++++++++++++-
 pymc3/sampling.py            | 34 ++++++++++++++++++++++++++++++++--
 pymc3/tests/test_sampling.py | 15 +++++++++++++++
 4 files changed, 72 insertions(+), 3 deletions(-)

diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md
index 26d22622bf..3010f150ac 100644
--- a/RELEASE-NOTES.md
+++ b/RELEASE-NOTES.md
@@ -8,6 +8,7 @@
 - `DEMetropolisZ`, an improved variant of `DEMetropolis` brings better parallelization and higher efficiency with fewer chains with a slower initial convergence. This implementation is experimental. See [#3784](https://github.com/pymc-devs/pymc3/pull/3784) for more info.
 - Notebooks that give insight into `DEMetropolis`, `DEMetropolisZ` and the `DifferentialEquation` interface are now located in the [Tutorials/Deep Dive](https://docs.pymc.io/nb_tutorials/index.html) section.
 - Add `fast_sample_posterior_predictive`, a vectorized alternative to `sample_posterior_predictive`. This alternative is substantially faster for large models.
+- `SamplerReport` (`MultiTrace.report`) now has properties `n_tune`, `n_draws` and `t_sampling` for increased convenience (see [#3827](https://github.com/pymc-devs/pymc3/pull/3827)).
 
 ### Maintenance
 - Remove `sample_ppc` and `sample_ppc_w` that were deprecated in 3.6.
diff --git a/pymc3/backends/report.py b/pymc3/backends/report.py
index 667878af94..d632490b23 100644
--- a/pymc3/backends/report.py
+++ b/pymc3/backends/report.py
@@ -15,6 +15,7 @@
 from collections import namedtuple
 import logging
 import enum
+import typing
 
 from ..util import is_transformed_name, get_untransformed_name
 
@@ -51,11 +52,15 @@ class WarningType(enum.Enum):
 
 
 class SamplerReport:
+    """This object bundles warnings, convergence statistics and metadata of a sampling run."""
     def __init__(self):
         self._chain_warnings = {}
         self._global_warnings = []
         self._ess = None
         self._rhat = None
+        self._n_tune = None
+        self._n_draws = None
+        self._t_sampling = None
 
     @property
     def _warnings(self):
@@ -68,6 +73,25 @@ def ok(self):
         return all(_LEVELS[warn.level] < _LEVELS['warn']
                    for warn in self._warnings)
 
+    @property
+    def n_tune(self) -> typing.Optional[int]:
+        """Number of tune iterations - not necessarily kept in the trace!"""
+        return self._n_tune
+
+    @property
+    def n_draws(self) -> typing.Optional[int]:
+        """Number of draw iterations."""
+        return self._n_draws
+
+    @property
+    def t_sampling(self) -> typing.Optional[float]:
+        """
+        Number of seconds that the sampling procedure took.
+
+        (Includes parallelization overhead.)
+        """
+        return self._t_sampling
+
     def raise_ok(self, level='error'):
         errors = [warn for warn in self._warnings
                   if _LEVELS[warn.level] >= _LEVELS[level]]
@@ -151,7 +175,6 @@ def _add_warnings(self, warnings, chain=None):
             warn_list.extend(warnings)
 
     def _log_summary(self):
-
         def log_warning(warn):
             level = _LEVELS[warn.level]
             logger.log(level, warn.message)
diff --git a/pymc3/sampling.py b/pymc3/sampling.py
index 6115ffd6ff..0a418cd627 100644
--- a/pymc3/sampling.py
+++ b/pymc3/sampling.py
@@ -24,6 +24,7 @@
 from copy import copy
 import pickle
 import logging
+import time
 import warnings
 
 import numpy as np
@@ -488,6 +489,7 @@ def sample(
     )
 
     parallel = cores > 1 and chains > 1 and not has_population_samplers
+    t_start = time.time()
     if parallel:
         _log.info("Multiprocess sampling ({} chains in {} jobs)".format(chains, cores))
         _print_step_hierarchy(step)
@@ -533,8 +535,36 @@
             _print_step_hierarchy(step)
             trace = _sample_many(**sample_args)
 
-    discard = tune if discard_tuned_samples else 0
-    trace = trace[discard:]
+    t_sampling = time.time() - t_start
+    # count the number of tune/draw iterations that happened
+    # ideally via the "tune" statistic, but not all samplers record it!
+    if 'tune' in trace.stat_names:
+        stat = trace.get_sampler_stats('tune', chains=0)
+        # when CompoundStep is used, the stat is 2-dimensional!
+        if len(stat.shape) == 2:
+            stat = stat[:, 0]
+        stat = tuple(stat)
+        n_tune = stat.count(True)
+        n_draws = stat.count(False)
+    else:
+        # these may be wrong if a KeyboardInterrupt happened, but they're better than nothing
+        n_tune = min(tune, len(trace))
+        n_draws = max(0, len(trace) - n_tune)
+
+    if discard_tuned_samples:
+        trace = trace[n_tune:]
+
+    # save metadata in SamplerReport
+    trace.report._n_tune = n_tune
+    trace.report._n_draws = n_draws
+    trace.report._t_sampling = t_sampling
+
+    n_chains = len(trace.chains)
+    _log.info(
+        f'Sampling {n_chains} chain{"s" if n_chains > 1 else ""} for {n_tune:_d} tune and {n_draws:_d} draw iterations '
+        f'({n_tune*n_chains:_d} + {n_draws*n_chains:_d} draws total) '
+        f'took {trace.report.t_sampling:.0f} seconds.'
+    )
 
     if compute_convergence_checks:
         if draws - tune < 100:
diff --git a/pymc3/tests/test_sampling.py b/pymc3/tests/test_sampling.py
index 05c94f2709..0acf5ae72b 100644
--- a/pymc3/tests/test_sampling.py
+++ b/pymc3/tests/test_sampling.py
@@ -142,6 +142,21 @@ def test_sample_tune_len(self):
             trace = pm.sample(draws=100, tune=50, cores=4)
         assert len(trace) == 100
 
+    @pytest.mark.parametrize("step_cls", [pm.NUTS, pm.Metropolis, pm.Slice])
+    @pytest.mark.parametrize("discard", [True, False])
+    def test_trace_report(self, step_cls, discard):
+        with self.model:
+            # add more variables, because stats are 2D with CompoundStep!
+            pm.Uniform('uni')
+            trace = pm.sample(
+                draws=100, tune=50, cores=1,
+                discard_tuned_samples=discard,
+                step=step_cls()
+            )
+            assert trace.report.n_tune == 50
+            assert trace.report.n_draws == 100
+            assert isinstance(trace.report.t_sampling, float)
+
     @pytest.mark.parametrize('cores', [1, 2])
     def test_sampler_stat_tune(self, cores):
         with self.model:
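
-- 
Reviewer note, not part of the patch: a minimal usage sketch of the new
SamplerReport properties added here. The model, the variable name "x" and the
sample sizes are illustrative choices, not taken from the patch.

    import pymc3 as pm

    with pm.Model():
        pm.Normal("x", mu=0.0, sigma=1.0)
        trace = pm.sample(draws=200, tune=100, cores=1)

    report = trace.report
    print(report.n_tune)      # 100 tune iterations (not necessarily kept in the trace)
    print(report.n_draws)     # 200 draw iterations
    print(report.t_sampling)  # wall-clock seconds, including parallelization overhead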
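
For reference, the tune/draw counting logic added to sampling.py can be
sketched in isolation. The boolean array below is a made-up stand-in for what
trace.get_sampler_stats('tune', chains=0) returns: True while tuning, False
for regular draws.

    import numpy as np

    # hypothetical sampler stat: 50 tuning iterations followed by 100 draws
    stat = np.array([True] * 50 + [False] * 100)

    # with CompoundStep the stat is 2-dimensional (one column per step
    # method); the columns agree, so only the first is kept
    if len(stat.shape) == 2:
        stat = stat[:, 0]

    # counting actual True/False entries stays correct even if sampling
    # was cut short by a KeyboardInterrupt
    stat = tuple(stat)
    n_tune = stat.count(True)    # 50
    n_draws = stat.count(False)  # 100
    print(n_tune, n_draws)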