Skip to content

Commit

Permalink
Fix error in warn_treedepth when using multiple NUTS sampler
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Mar 1, 2024
1 parent 6c6fd13 commit 9bf2190
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 2 deletions.
2 changes: 1 addition & 1 deletion pymc/stats/convergence.py
Expand Up @@ -164,7 +164,7 @@ def warn_treedepth(idata: arviz.InferenceData) -> list[SamplerWarning]:

warnings = []
for c in rmtd.chain:
if sum(rmtd.sel(chain=c)) / rmtd.sizes["draw"] > 0.05:
if (rmtd.sel(chain=c).mean("draw") > 0.05).any():
warnings.append(
SamplerWarning(
WarningType.TREEDEPTH,
Expand Down
2 changes: 1 addition & 1 deletion pymc/step_methods/hmc/nuts.py
Expand Up @@ -198,7 +198,7 @@ def _hamiltonian_step(self, start, p0, step_size):

if divergence_info or turning:
break
else:
else: # no-break
reached_max_treedepth = not self.tune

stats = tree.stats()
Expand Down
16 changes: 16 additions & 0 deletions tests/stats/test_convergence.py
Expand Up @@ -42,6 +42,22 @@ def test_warn_treedepth():
assert "Chain 1 reached the maximum tree depth" in warns[0].message


def test_warn_treedepth_multiple_samplers():
"""Check we handle cases when sampling with multiple NUTS samplers, each of which reports max_treedepth."""
max_treedepth = np.zeros((3, 2, 2), dtype=bool)
max_treedepth[0, 0, 0] = True
max_treedepth[2, 1, 1] = True
idata = arviz.from_dict(
sample_stats={
"reached_max_treedepth": max_treedepth,
}
)
warns = convergence.warn_treedepth(idata)
assert len(warns) == 2
assert "Chain 0 reached the maximum tree depth" in warns[0].message
assert "Chain 2 reached the maximum tree depth" in warns[1].message


def test_log_warning_stats(caplog):
s1 = dict(warning="Temperature too low!")
s2 = dict(warning="Temperature too high!")
Expand Down

0 comments on commit 9bf2190

Please sign in to comment.