Skip to content

Commit

Permalink
use better ess in neff_ratio
Browse files Browse the repository at this point in the history
  • Loading branch information
paul-buerkner committed Nov 29, 2022
1 parent b7b27a0 commit e82b02f
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
2 changes: 1 addition & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,6 @@ importFrom(bayesplot,log_posterior)
importFrom(bayesplot,neff_ratio)
importFrom(bayesplot,nuts_params)
importFrom(bayesplot,pp_check)
importFrom(bayesplot,rhat)
importFrom(bayesplot,theme_default)
importFrom(bridgesampling,bayes_factor)
importFrom(bridgesampling,bridge_sampler)
Expand Down Expand Up @@ -640,6 +639,7 @@ importFrom(posterior,nchains)
importFrom(posterior,ndraws)
importFrom(posterior,niterations)
importFrom(posterior,nvariables)
importFrom(posterior,rhat)
importFrom(posterior,subset_draws)
importFrom(posterior,summarize_draws)
importFrom(posterior,variables)
Expand Down
12 changes: 7 additions & 5 deletions R/diagnostics.R
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,15 @@ nuts_params.brmsfit <- function(object, pars = NULL, ...) {
}

#' @rdname diagnostic-quantities
#' @importFrom bayesplot rhat
#' @importFrom posterior rhat
#' @export rhat
#' @export
rhat.brmsfit <- function(object, pars = NULL, ...) {
contains_draws(object)
# bayesplot uses outdated rhat code from rstan
# bayesplot::rhat(object$fit, pars = pars, ...)
draws <- as_draws_array(object, variable = pars, ...)
tmp <- posterior::summarize_draws(draws, rhat = posterior::rhat)
tmp <- posterior::summarise_draws(draws, rhat = posterior::rhat)
rhat <- tmp$rhat
names(rhat) <- tmp$variable
rhat
Expand All @@ -77,9 +77,11 @@ neff_ratio.brmsfit <- function(object, pars = NULL, ...) {
# bayesplot uses outdated ess code from rstan
# bayesplot::neff_ratio(object$fit, pars = pars, ...)
draws <- as_draws_array(object, variable = pars, ...)
# currently uses ess_bulk as ess estimate for the central tendency
tmp <- posterior::summarize_draws(draws, ess = posterior::ess_bulk)
ess <- tmp$ess
tmp <- posterior::summarise_draws(
draws, ess_bulk = posterior::ess_bulk, ess_tail = posterior::ess_tail
)
# min of ess_bulk and ess_tail mimics definition of posterior::rhat.default
ess <- matrixStats::rowMins(cbind(tmp$ess_bulk, tmp$ess_tail))
names(ess) <- tmp$variable
ess / ndraws(draws)
}
Expand Down

0 comments on commit e82b02f

Please sign in to comment.