Skip to content

Commit

Permalink
Fix sampler stats error in NUTS
Browse files Browse the repository at this point in the history
  • Loading branch information
ColCarroll committed Jul 19, 2017
1 parent 99e3cd3 commit 21dd6c2
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 16 deletions.
33 changes: 17 additions & 16 deletions pymc3/step_methods/hmc/nuts.py
Expand Up @@ -476,19 +476,20 @@ def _finalize(self, strace):
"""Print warnings for obviously problematic chains."""
self._chain_id = strace.chain

tuning = strace.get_sampler_stats('tune')
if tuning.ndim == 2:
tuning = np.any(tuning, axis=-1)

accept = strace.get_sampler_stats('mean_tree_accept')
if accept.ndim == 2:
accept = np.mean(accept, axis=-1)

depth = strace.get_sampler_stats('depth')
if depth.ndim == 2:
depth = np.max(depth, axis=-1)

self._check_len(tuning)
self._check_depth(depth[~tuning])
self._check_accept(accept[~tuning])
self._check_divergence()
if strace.supports_sampler_stats:
tuning = strace.get_sampler_stats('tune')
if tuning.ndim == 2:
tuning = np.any(tuning, axis=-1)

accept = strace.get_sampler_stats('mean_tree_accept')
if accept.ndim == 2:
accept = np.mean(accept, axis=-1)

depth = strace.get_sampler_stats('depth')
if depth.ndim == 2:
depth = np.max(depth, axis=-1)

self._check_len(tuning)
self._check_depth(depth[~tuning])
self._check_accept(accept[~tuning])
self._check_divergence()
14 changes: 14 additions & 0 deletions pymc3/tests/test_text_backend.py
@@ -1,9 +1,23 @@
import pymc3 as pm
from pymc3.tests import backend_fixtures as bf
from pymc3.backends import ndarray, text
import pytest
import theano


class TestTextSampling(object):
name = 'text-db'

def test_supports_sampler_stats(self):
with pm.Model():
pm.Normal("mu", mu=0, sd=1, shape=2)
db = text.Text(self.name)
pm.sample(20, tune=10, init=None, trace=db)

def teardown_method(self):
bf.remove_file_or_directory(self.name)


class TestText0dSampling(bf.SamplingTestCase):
backend = text.Text
name = 'text-db'
Expand Down

0 comments on commit 21dd6c2

Please sign in to comment.