mcmc_pairs doesn't show max_treedepth warnings for rstan models #281

Closed
@dmphillippo

mcmc_pairs does not highlight max_treedepth warnings in the same way as rstan::pairs. As a result, setting the max_treedepth argument of mcmc_pairs to the same value passed to rstan highlights no iterations, even when rstan warns that transitions exceeded the maximum treedepth.

This is because mcmc_pairs only flags iterations where treedepth__ > max_treedepth:

gt_max_td <- (dplyr::filter(np, UQ(param) == "treedepth__") %>% pull(UQ(val))) > max_treedepth

whereas rstan::pairs flags iterations where treedepth__ >= max_treedepth:
https://github.com/stan-dev/rstan/blob/71ab1409531c2c5e412635482fee545c77e2a070/rstan/rstan/R/pairs.R#L78

The same condition is used in rstan to provide the max_treedepth warning after fitting a model:
https://github.com/stan-dev/rstan/blob/da2fc9c079534a82d3d26adda51ad17bf22f5e2b/rstan/rstan/R/check_hmc_diagnostics.R#L49
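
To make the difference between the two comparisons concrete, here is a minimal base R sketch. The treedepth values are made up for illustration, and it assumes (as the example in this issue suggests) that the sampler never reports a treedepth__ above max_treedepth:

```r
# Illustrative treedepth__ values for six post-warmup iterations,
# as if sampled with max_treedepth = 2 (not real sampler output).
max_treedepth <- 2
treedepth <- c(1, 2, 2, 1, 2, 2)

# Strict comparison, as currently used in mcmc_pairs
flag_strict <- treedepth > max_treedepth

# Inclusive comparison, as used in rstan::pairs
flag_inclusive <- treedepth >= max_treedepth

sum(flag_strict)     # number of iterations flagged by >
sum(flag_inclusive)  # number of iterations flagged by >=
```

With these values the strict comparison flags nothing, while the inclusive one flags the four iterations that reached the limit.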

For example, here's a trivial model fitted in rstan with max_treedepth = 2 to get lots of max_treedepth warnings:

library(rstan)

m <- stan_model(model_code = 'parameters {real y;} model {y ~ normal(0,1);}')
mfit <- sampling(m, iter = 1000, control = list(max_treedepth = 2))
#> Warning messages:
#> 1: There were 777 transitions after warmup that exceeded the maximum treedepth. Increase max_treedepth above 2. See
#> http://mc-stan.org/misc/warnings.html#maximum-treedepth-exceeded 
#> 2: Examine the pairs() plot to diagnose sampling problems

mcmc_pairs gives us:

library(bayesplot)

mcmc_pairs(mfit,
           np = nuts_params(mfit), 
           lp = log_posterior(mfit), 
           max_treedepth = mfit@stan_args[[1]]$control$max_treedepth,
           condition = pairs_condition(nuts = "accept_stat__"))

[image: output of the mcmc_pairs call above]

Whereas rstan::pairs gives us:

pairs(mfit, pars = c("y", "lp__"))

[image: output of the rstan::pairs call above]

This also affects downstream packages using bayesplot::mcmc_pairs like rstanarm and multinma, which set max_treedepth in mcmc_pairs from the stan control argument in this manner, and as a result don't seem to plot max_treedepth warnings.

I think a simple fix would be to change > to >= here:

gt_max_td <- (dplyr::filter(np, UQ(param) == "treedepth__") %>% pull(UQ(val))) > max_treedepth
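
As a rough sketch of what the changed comparison would do, here is a base R equivalent of the filter/pull step applied to a toy data frame. The helper name flag_max_treedepth is hypothetical, and the column names (Iteration, Parameter, Value, Chain) and the values are assumptions standing in for real nuts_params() output:

```r
# Toy stand-in for nuts_params() output in long format.
# Column names and values are assumptions for illustration only.
np <- data.frame(
  Iteration = rep(1:4, times = 2),
  Parameter = rep(c("treedepth__", "accept_stat__"), each = 4),
  Value     = c(1, 2, 2, 1, 0.91, 0.85, 0.99, 0.78),
  Chain     = 1L
)

# Hypothetical helper: base R equivalent of the filter/pull step,
# with the comparison changed from > to >=.
flag_max_treedepth <- function(np, max_treedepth) {
  td <- np$Value[np$Parameter == "treedepth__"]
  td >= max_treedepth
}

hits_max_td <- flag_max_treedepth(np, max_treedepth = 2)
which(hits_max_td)  # indices of iterations that reached max_treedepth
```

Here iterations 2 and 3 are flagged; with the current > comparison the same input would flag none.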

I'm happy to create a PR, if you agree.
