diff --git a/pymc3/step_methods/hmc/nuts.py b/pymc3/step_methods/hmc/nuts.py index aeab733069..93495f3646 100644 --- a/pymc3/step_methods/hmc/nuts.py +++ b/pymc3/step_methods/hmc/nuts.py @@ -470,19 +470,20 @@ def _finalize(self, strace): """Print warnings for obviously problematic chains.""" self._chain_id = strace.chain - tuning = strace.get_sampler_stats('tune') - if tuning.ndim == 2: - tuning = np.any(tuning, axis=-1) - - accept = strace.get_sampler_stats('mean_tree_accept') - if accept.ndim == 2: - accept = np.mean(accept, axis=-1) - - depth = strace.get_sampler_stats('depth') - if depth.ndim == 2: - depth = np.max(depth, axis=-1) - - self._check_len(tuning) - self._check_depth(depth[~tuning]) - self._check_accept(accept[~tuning]) - self._check_divergence() + if strace.supports_sampler_stats: + tuning = strace.get_sampler_stats('tune') + if tuning.ndim == 2: + tuning = np.any(tuning, axis=-1) + + accept = strace.get_sampler_stats('mean_tree_accept') + if accept.ndim == 2: + accept = np.mean(accept, axis=-1) + + depth = strace.get_sampler_stats('depth') + if depth.ndim == 2: + depth = np.max(depth, axis=-1) + + self._check_len(tuning) + self._check_depth(depth[~tuning]) + self._check_accept(accept[~tuning]) + self._check_divergence() diff --git a/pymc3/tests/test_text_backend.py b/pymc3/tests/test_text_backend.py index 21fe89d915..0eec24574b 100644 --- a/pymc3/tests/test_text_backend.py +++ b/pymc3/tests/test_text_backend.py @@ -1,9 +1,23 @@ +import pymc3 as pm from pymc3.tests import backend_fixtures as bf from pymc3.backends import ndarray, text import pytest import theano +class TestTextSampling(object): + name = 'text-db' + + def test_supports_sampler_stats(self): + with pm.Model(): + pm.Normal("mu", mu=0, sd=1, shape=2) + db = text.Text(self.name) + pm.sample(20, tune=10, init=None, trace=db) + + def teardown_method(self): + bf.remove_file_or_directory(self.name) + + class TestText0dSampling(bf.SamplingTestCase): backend = text.Text name = 'text-db'