Skip to content

Commit

Permalink
Include n_tune, n_draws and t_sampling in SamplerReport (pymc-devs#3827)
Browse files Browse the repository at this point in the history
* include n_tune, n_draws and t_sampling in SamplerReport
* count tune/draw samples instead of trusting parameters (because of KeyboardInterrupt)
* fall back to tune and len(trace) if tune stat is unavailable
* add test for SamplerReport n_tune and n_draws
* clarify that n_tune are not necessarily in the trace
* use actual number of chains to compute totals
* mention new SamplerReport properties in release notes
Co-authored-by: Michael Osthege <m.osthege@fz-juelich.de>
  • Loading branch information
michaelosthege committed Mar 11, 2020
1 parent b5891be commit 6c5254f
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 3 deletions.
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
- `DEMetropolisZ`, an improved variant of `DEMetropolis` brings better parallelization and higher efficiency with fewer chains with a slower initial convergence. This implementation is experimental. See [#3784](https://github.com/pymc-devs/pymc3/pull/3784) for more info.
- Notebooks that give insight into `DEMetropolis`, `DEMetropolisZ` and the `DifferentialEquation` interface are now located in the [Tutorials/Deep Dive](https://docs.pymc.io/nb_tutorials/index.html) section.
- Add `fast_sample_posterior_predictive`, a vectorized alternative to `sample_posterior_predictive`. This alternative is substantially faster for large models.
- `SamplerReport` (`MultiTrace.report`) now has properties `n_tune`, `n_draws`, `t_sampling` for increased convenience (see [#3827](https://github.com/pymc-devs/pymc3/pull/3827))

### Maintenance
- Remove `sample_ppc` and `sample_ppc_w` that were deprecated in 3.6.
Expand Down
25 changes: 24 additions & 1 deletion pymc3/backends/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from collections import namedtuple
import logging
import enum
import typing
from ..util import is_transformed_name, get_untransformed_name


Expand Down Expand Up @@ -51,11 +52,15 @@ class WarningType(enum.Enum):


class SamplerReport:
"""This object bundles warnings, convergence statistics and metadata of a sampling run."""
def __init__(self):
self._chain_warnings = {}
self._global_warnings = []
self._ess = None
self._rhat = None
self._n_tune = None
self._n_draws = None
self._t_sampling = None

@property
def _warnings(self):
Expand All @@ -68,6 +73,25 @@ def ok(self):
return all(_LEVELS[warn.level] < _LEVELS['warn']
for warn in self._warnings)

@property
def n_tune(self) -> typing.Optional[int]:
"""Number of tune iterations - not necessarily kept in trace!"""
return self._n_tune

@property
def n_draws(self) -> typing.Optional[int]:
"""Number of draw iterations."""
return self._n_draws

@property
def t_sampling(self) -> typing.Optional[float]:
"""
Number of seconds that the sampling procedure took.
(Includes parallelization overhead.)
"""
return self._t_sampling

def raise_ok(self, level='error'):
errors = [warn for warn in self._warnings
if _LEVELS[warn.level] >= _LEVELS[level]]
Expand Down Expand Up @@ -151,7 +175,6 @@ def _add_warnings(self, warnings, chain=None):
warn_list.extend(warnings)

def _log_summary(self):

def log_warning(warn):
level = _LEVELS[warn.level]
logger.log(level, warn.message)
Expand Down
34 changes: 32 additions & 2 deletions pymc3/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from copy import copy
import pickle
import logging
import time
import warnings

import numpy as np
Expand Down Expand Up @@ -488,6 +489,7 @@ def sample(
)

parallel = cores > 1 and chains > 1 and not has_population_samplers
t_start = time.time()
if parallel:
_log.info("Multiprocess sampling ({} chains in {} jobs)".format(chains, cores))
_print_step_hierarchy(step)
Expand Down Expand Up @@ -533,8 +535,36 @@ def sample(
_print_step_hierarchy(step)
trace = _sample_many(**sample_args)

discard = tune if discard_tuned_samples else 0
trace = trace[discard:]
t_sampling = time.time() - t_start
# count the number of tune/draw iterations that happened
# ideally via the "tune" statistic, but not all samplers record it!
if 'tune' in trace.stat_names:
stat = trace.get_sampler_stats('tune', chains=0)
# when CompoundStep is used, the stat is 2 dimensional!
if len(stat.shape) == 2:
stat = stat[:,0]
stat = tuple(stat)
n_tune = stat.count(True)
n_draws = stat.count(False)
else:
# these may be wrong when KeyboardInterrupt happened, but they're better than nothing
n_tune = min(tune, len(trace))
n_draws = max(0, len(trace) - n_tune)

if discard_tuned_samples:
trace = trace[n_tune:]

# save metadata in SamplerReport
trace.report._n_tune = n_tune
trace.report._n_draws = n_draws
trace.report._t_sampling = t_sampling

n_chains = len(trace.chains)
_log.info(
f'Sampling {n_chains} chain{"s" if n_chains > 1 else ""} for {n_tune:_d} tune and {n_draws:_d} draw iterations '
f'({n_tune*n_chains:_d} + {n_draws*n_chains:_d} draws total) '
f'took {trace.report.t_sampling:.0f} seconds.'
)

if compute_convergence_checks:
if draws - tune < 100:
Expand Down
16 changes: 16 additions & 0 deletions pymc3/tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,22 @@ def test_sample_tune_len(self):
trace = pm.sample(draws=100, tune=50, cores=4)
assert len(trace) == 100

@pytest.mark.parametrize("step_cls", [pm.NUTS, pm.Metropolis, pm.Slice])
@pytest.mark.parametrize("discard", [True, False])
def test_trace_report(self, step_cls, discard):
with self.model:
# add more variables, because stats are 2D with CompoundStep!
pm.Uniform('uni')
trace = pm.sample(
draws=100, tune=50, cores=1,
discard_tuned_samples=discard,
step=step_cls()
)
assert trace.report.n_tune == 50
assert trace.report.n_draws == 100
assert isinstance(trace.report.t_sampling, float)
pass

@pytest.mark.parametrize('cores', [1, 2])
def test_sampler_stat_tune(self, cores):
with self.model:
Expand Down

0 comments on commit 6c5254f

Please sign in to comment.