Skip to content

Commit

Permalink
add extractor methods for CmdStanMCMC objects (from CmdStanR) (#227)
Browse files Browse the repository at this point in the history
* add extractor methods for CmdStanMCMC objects

* fix test

* Update bayesplot-extractors.R

* Update NEWS.md
  • Loading branch information
jgabry committed Aug 7, 2020
1 parent f4ed652 commit e54367b
Show file tree
Hide file tree
Showing 8 changed files with 123 additions and 18 deletions.
4 changes: 4 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@

S3method("[",neff_ratio)
S3method("[",rhat)
S3method(log_posterior,CmdStanMCMC)
S3method(log_posterior,stanfit)
S3method(log_posterior,stanreg)
S3method(neff_ratio,CmdStanMCMC)
S3method(neff_ratio,stanfit)
S3method(neff_ratio,stanreg)
S3method(nuts_params,CmdStanMCMC)
S3method(nuts_params,list)
S3method(nuts_params,stanfit)
S3method(nuts_params,stanreg)
Expand All @@ -15,6 +18,7 @@ S3method(pp_check,default)
S3method(print,bayesplot_function_list)
S3method(print,bayesplot_grid)
S3method(print,bayesplot_scheme)
S3method(rhat,CmdStanMCMC)
S3method(rhat,stanfit)
S3method(rhat,stanreg)
export(abline_01)
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
* Items for next release go here
-->

* CmdStanMCMC objects (from CmdStanR) can now be used with extractor
functions `nuts_params()`, `log_posterior()`, `rhat()`, and
`neff_ratio()`. (#227)

* Added missing `facet_args` argument to `mcmc_rank_overlay()`. (#221, @hhau)
* Size of points and interval lines can set in
`mcmc_intervals(..., outer_size, inner_size, point_size)`. (#215, #228, #229)
Expand Down
70 changes: 64 additions & 6 deletions R/bayesplot-extractors.R
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#' Extract quantities needed for plotting from model objects
#'
#' Generics and methods for extracting quantities needed for plotting from
#' various types of model objects. Currently methods are only provided for
#' stanfit (**rstan**) and stanreg (**rstanarm**) objects, but adding new
#' methods should be relatively straightforward.
#' various types of model objects. Currently methods are provided for stanfit
#' (**rstan**), CmdStanMCMC (**cmdstanr**), and stanreg (**rstanarm**) objects,
#' but adding new methods should be relatively straightforward.
#'
#' @name bayesplot-extractors
#' @param object The object to use.
Expand Down Expand Up @@ -87,7 +87,8 @@ log_posterior.stanfit <- function(object, inc_warmup = FALSE, ...) {
...)
lp <- lapply(lp, as.array)
lp <- set_names(reshape2::melt(lp), c("Iteration", "Value", "Chain"))
validate_df_classes(lp, c("integer", "numeric", "integer"))
validate_df_classes(lp[, c("Chain", "Iteration", "Value")],
c("integer", "integer", "numeric"))
}

#' @rdname bayesplot-extractors
Expand All @@ -98,11 +99,22 @@ log_posterior.stanreg <- function(object, inc_warmup = FALSE, ...) {
log_posterior.stanfit(object$stanfit, inc_warmup = inc_warmup, ...)
}

#' @rdname bayesplot-extractors
#' @export
#' @method log_posterior CmdStanMCMC
log_posterior.CmdStanMCMC <- function(object, inc_warmup = FALSE, ...) {
lp <- object$draws("lp__", inc_warmup = inc_warmup)
lp <- reshape2::melt(lp)
lp$variable <- NULL
lp <- dplyr::rename_with(lp, capitalize_first)
validate_df_classes(lp[, c("Chain", "Iteration", "Value")],
c("integer", "integer", "numeric"))
}


#' @rdname bayesplot-extractors
#' @export
#' @method nuts_params stanfit
#'
nuts_params.stanfit <-
function(object,
pars = NULL,
Expand Down Expand Up @@ -153,7 +165,23 @@ nuts_params.list <- function(object, pars = NULL, ...) {

out <- reshape2::melt(object)
out <- set_names(out, c("Iteration", "Parameter", "Value", "Chain"))
validate_df_classes(out, c("integer", "factor", "numeric", "integer"))
validate_df_classes(out[, c("Chain", "Iteration", "Parameter", "Value")],
c("integer", "integer", "factor", "numeric"))
}

#' @rdname bayesplot-extractors
#' @export
#' @method nuts_params CmdStanMCMC
nuts_params.CmdStanMCMC <- function(object, pars = NULL, ...) {
arr <- object$sampler_diagnostics()
if (!is.null(pars)) {
arr <- arr[,, pars]
}
out <- reshape2::melt(arr)
colnames(out)[colnames(out) == "variable"] <- "parameter"
out <- dplyr::rename_with(out, capitalize_first)
validate_df_classes(out[, c("Chain", "Iteration", "Parameter", "Value")],
c("integer", "integer", "factor", "numeric"))
}


Expand Down Expand Up @@ -188,6 +216,17 @@ rhat.stanreg <- function(object, pars = NULL, regex_pars = NULL, ...) {
r[!names(r) %in% c("mean_PPD", "log-posterior")]
}

#' @rdname bayesplot-extractors
#' @export
#' @method rhat CmdStanMCMC
rhat.CmdStanMCMC <- function(object, pars = NULL, ...) {
.rhat <- utils::getFromNamespace("rhat", "posterior")
s <- object$summary(pars, rhat = .rhat)[, c("variable", "rhat")]
r <- setNames(s$rhat, s$variable)
r <- validate_rhat(r)
r[!names(r) %in% "lp__"]
}


#' @rdname bayesplot-extractors
#' @export
Expand Down Expand Up @@ -223,6 +262,18 @@ neff_ratio.stanreg <- function(object, pars = NULL, regex_pars = NULL, ...) {
ratio[!names(ratio) %in% c("mean_PPD", "log-posterior")]
}

#' @rdname bayesplot-extractors
#' @export
#' @method neff_ratio CmdStanMCMC
neff_ratio.CmdStanMCMC <- function(object, pars = NULL, ...) {
s <- object$summary(pars, "n_eff" = "ess_basic")[, c("variable", "n_eff")]
ess <- setNames(s$n_eff, s$variable)
tss <- prod(dim(object$draws())[1:2])
ratio <- ess / tss
ratio <- validate_neff_ratio(ratio)
ratio[!names(ratio) %in% "lp__"]
}


# internals ---------------------------------------------------------------

Expand All @@ -245,3 +296,10 @@ validate_df_classes <- function(x, classes = character()) {
}
x
}

# capitalize first letter in a string only
capitalize_first <- function(name) {
name <- tolower(name) # in case whole string is capitalized
substr(name, 1, 1) <- toupper(substr(name, 1, 1))
name
}
8 changes: 4 additions & 4 deletions R/mcmc-diagnostics-nuts.R
Original file line number Diff line number Diff line change
Expand Up @@ -513,8 +513,8 @@ validate_nuts_data_frame <- function(x, lp) {
abort("NUTS parameters should be in a data frame.")
}

valid_cols <- c("Iteration", "Parameter", "Value", "Chain")
if (!identical(colnames(x), valid_cols)) {
valid_cols <- sort(c("Iteration", "Parameter", "Value", "Chain"))
if (!identical(sort(colnames(x)), valid_cols)) {
abort(paste(
"NUTS parameter data frame must have columns:",
paste(valid_cols, collapse = ", ")
Expand All @@ -529,8 +529,8 @@ validate_nuts_data_frame <- function(x, lp) {
abort("lp should be in a data frame.")
}

valid_lp_cols <- c("Iteration", "Value", "Chain")
if (!identical(colnames(lp), valid_lp_cols)) {
valid_lp_cols <- sort(c("Iteration", "Value", "Chain"))
if (!identical(sort(colnames(lp)), valid_lp_cols)) {
abort(paste(
"lp data frame must have columns:",
paste(valid_lp_cols, collapse = ", ")
Expand Down
18 changes: 15 additions & 3 deletions man/bayesplot-extractors.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

31 changes: 29 additions & 2 deletions tests/testthat/test-extractors.R
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ test_that("all nuts_params methods identical", {

test_that("nuts_params.stanreg returns correct structure", {
np <- nuts_params(fit)
expect_identical(colnames(np), c("Iteration", "Parameter", "Value", "Chain"))
expect_identical(colnames(np), c("Chain", "Iteration", "Parameter", "Value"))

np_names <- paste0(c("accept_stat", "stepsize", "treedepth", "n_leapfrog",
"divergent", "energy"), "__")
Expand All @@ -54,7 +54,7 @@ test_that("nuts_params.stanreg returns correct structure", {

test_that("log_posterior.stanreg returns correct structure", {
lp <- log_posterior(fit)
expect_identical(colnames(lp), c("Iteration", "Value", "Chain"))
expect_identical(colnames(lp), c("Chain", "Iteration", "Value"))
expect_equal(length(unique(lp$Iteration)), floor(ITER / 2))
expect_equal(length(unique(lp$Chain)), CHAINS)
})
Expand Down Expand Up @@ -100,3 +100,30 @@ test_that("neff_ratio.stanreg returns correct structure", {
ans2 <- summary(fit, pars = c("wt", "sigma"))[, "n_eff"] / denom
expect_equal(ratio2, ans2, tol = 0.001)
})

test_that("cmdstanr methods work", {
skip_on_cran()
skip_if_not_installed("cmdstanr")

fit <- cmdstanr::cmdstanr_example("logistic", iter_sampling = 500, chains = 2)
np <- nuts_params(fit)
np_names <- paste0(c("accept_stat", "stepsize", "treedepth", "n_leapfrog",
"divergent", "energy"), "__")
expect_identical(levels(np$Parameter), np_names)
expect_equal(range(np$Iteration), c(1, 500))
expect_equal(range(np$Chain), c(1, 2))
expect_true(all(np$Value[np$Parameter == "divergent__"] == 0))

lp <- log_posterior(fit)
expect_named(lp, c("Chain", "Iteration", "Value"))
expect_equal(range(np$Chain), c(1, 2))
expect_equal(range(np$Iteration), c(1, 500))

r <- rhat(fit)
expect_named(r, c("alpha", "beta[1]", "beta[2]", "beta[3]"))
expect_true(all(round(r) == 1))

ratio <- neff_ratio(fit)
expect_named(ratio, c("alpha", "beta[1]", "beta[2]", "beta[3]"))
expect_true(all(ratio > 0))
})
4 changes: 2 additions & 2 deletions tests/testthat/test-mcmc-nuts.R
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ test_that("validate_nuts_data_frame throws errors", {
)
expect_error(
validate_nuts_data_frame(data.frame(Iteration = 1, apple = 2)),
"NUTS parameter data frame must have columns: Iteration, Parameter, Value, Chain"
"NUTS parameter data frame must have columns: Chain, Iteration, Parameter, Value"
)
expect_error(
validate_nuts_data_frame(np, as.matrix(lp)),
Expand All @@ -69,7 +69,7 @@ test_that("validate_nuts_data_frame throws errors", {
colnames(lp2)[3] <- "Chains"
expect_error(
validate_nuts_data_frame(np, lp2),
"lp data frame must have columns: Iteration, Value, Chain"
"lp data frame must have columns: Chain, Iteration, Value"
)

lp2 <- subset(lp, Chain %in% 1:2)
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-mcmc-scatter-and-parcoord.R
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ test_that("mcmc_parcoord throws correct warnings and errors", {

expect_error(
mcmc_parcoord(post, np = np[, -1]),
"NUTS parameter data frame must have columns: Iteration, Parameter, Value, Chain",
"NUTS parameter data frame must have columns: Chain, Iteration, Parameter, Value",
fixed = TRUE
)

Expand Down

0 comments on commit e54367b

Please sign in to comment.