Skip to content

Commit

Permalink
Merge 8576e04 into 859691b
Browse files Browse the repository at this point in the history
  • Loading branch information
aseyboldt committed Apr 7, 2017
2 parents 859691b + 8576e04 commit 0e649d0
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 12 deletions.
30 changes: 19 additions & 11 deletions pymc3/sampling.py
Expand Up @@ -324,19 +324,27 @@ def _iter_sample(draws, step, start=None, trace=None, chain=0, tune=None,
strace.setup(draws, chain, step.stats_dtypes)
else:
strace.setup(draws, chain)
for i in range(draws):
if i == tune:
step = stop_tuning(step)
if step.generates_stats:
point, states = step.step(point)
if strace.supports_sampler_stats:
strace.record(point, states)
try:
for i in range(draws):
if i == tune:
step = stop_tuning(step)
if step.generates_stats:
point, states = step.step(point)
if strace.supports_sampler_stats:
strace.record(point, states)
else:
strace.record(point)
else:
point = step.step(point)
strace.record(point)
else:
point = step.step(point)
strace.record(point)
yield strace
yield strace
except KeyboardInterrupt:
if hasattr(step, 'check_trace'):
step.check_trace(strace)
raise
else:
if hasattr(step, 'check_trace'):
step.check_trace(strace)


def _choose_backend(trace, chain, shortcuts=None, **kwds):
Expand Down
5 changes: 5 additions & 0 deletions pymc3/step_methods/compound.py
Expand Up @@ -30,3 +30,8 @@ def step(self, point):
for method in self.methods:
point = method.step(point)
return point

def check_trace(self, trace):
for method in self.methods:
if hasattr(method, 'check_trace'):
method.check_trace(trace)
63 changes: 62 additions & 1 deletion pymc3/step_methods/hmc/nuts.py
@@ -1,4 +1,5 @@
from collections import namedtuple
import warnings

from ..arraystep import Competence
from .base_hmc import BaseHMC
Expand All @@ -7,6 +8,7 @@

import numpy as np
import numpy.random as nr
from scipy import stats

__all__ = ['NUTS']

Expand Down Expand Up @@ -204,6 +206,63 @@ def competence(var):
return Competence.IDEAL
return Competence.INCOMPATIBLE

def check_trace(self, strace):
"""Print warnings for obviously problematic chains."""
n = len(strace)
chain = strace.chain

diverging = strace.get_sampler_stats('diverging')
if diverging.ndim == 2:
diverging = np.any(diverging, axis=0)

tuning = strace.get_sampler_stats('tune')
if tuning.ndim == 2:
tuning = np.any(tuning, axis=0)

accept = strace.get_sampler_stats('mean_tree_accept')
if accept.ndim == 2:
accept = np.mean(accept, axis=0)

depth = strace.get_sampler_stats('depth')
if depth.ndim == 2:
depth = np.max(depth, axis=0)

n_samples = n - (~tuning).sum()

if n < 1000:
warnings.warn('Chain %s contains only %s samples.' % (chain, n))
if np.all(diverging):
warnings.warn('Chain %s contains only diverging samples. '
'The model is probably misspecified.' % chain)
if np.all(tuning):
warnings.warn('Step size tuning was enabled throughout the whole '
'trace. You might want to specify the number of '
'tuning steps.')
if np.any(diverging[~tuning]):
warnings.warn("Chain %s contains diverging samples after tuning. "
"If increasing `target_accept` doesn't help, "
"try to reparameterize." % chain)
if n_samples > 0:
depth_samples = depth[~tuning]
else:
depth_samples = depth[n // 2:]
if np.mean(depth_samples == self.max_treedepth) > 0.05:
warnings.warn('Chain %s reached the maximum tree depth. Increase '
'max_treedepth, increase target_accept or '
'reparameterize.' % chain)

mean_accept = np.mean(accept[~tuning])
target_accept = self.target_accept
# Try to find a reasonable interval for acceptable acceptance
# probabilities. Finding this was mostry trial and error.
n_bound = min(100, n)
n_good, n_bad = mean_accept * n_bound, (1 - mean_accept) * n_bound
lower, upper = stats.beta(n_good + 1, n_bad + 1).interval(0.95)
if target_accept < lower or target_accept > upper:
warnings.warn('The acceptance probability in chain %s does not '
'match the target. It is %s, but should be close '
'to %s. Try to increase the number of tuning steps.'
% (chain, mean_accept, target_accept))

# A node in the NUTS tree that is at the far right or left of the tree
Edge = namedtuple("Edge", 'q, p, v, q_grad, energy')
Expand All @@ -212,7 +271,9 @@ def competence(var):
Proposal = namedtuple("Proposal", "q, energy, p_accept")

# A subtree of the binary tree built by nuts.
Subtree = namedtuple("Subtree", "left, right, p_sum, proposal, log_size, accept_sum, n_proposals")
Subtree = namedtuple(
"Subtree",
"left, right, p_sum, proposal, log_size, accept_sum, n_proposals")


class Tree(object):
Expand Down

0 comments on commit 0e649d0

Please sign in to comment.