Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add extractor methods for CmdStanMCMC objects (from CmdStanR) #227

Merged
merged 4 commits into from
Aug 7, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@

S3method("[",neff_ratio)
S3method("[",rhat)
S3method(log_posterior,CmdStanMCMC)
S3method(log_posterior,stanfit)
S3method(log_posterior,stanreg)
S3method(neff_ratio,CmdStanMCMC)
S3method(neff_ratio,stanfit)
S3method(neff_ratio,stanreg)
S3method(nuts_params,CmdStanMCMC)
S3method(nuts_params,list)
S3method(nuts_params,stanfit)
S3method(nuts_params,stanreg)
Expand All @@ -15,6 +18,7 @@ S3method(pp_check,default)
S3method(print,bayesplot_function_list)
S3method(print,bayesplot_grid)
S3method(print,bayesplot_scheme)
S3method(rhat,CmdStanMCMC)
S3method(rhat,stanfit)
S3method(rhat,stanreg)
export(abline_01)
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
* Items for next release go here
-->

* CmdStanMCMC objects (from CmdStanR) can now be used with extractor
functions `nuts_params()`, `log_posterior()`, `rhat()`, and
`neff_ratio()`. (#227)

* Added missing `facet_args` argument to `mcmc_rank_overlay()`. (#221, @hhau)


Expand Down
70 changes: 64 additions & 6 deletions R/bayesplot-extractors.R
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#' Extract quantities needed for plotting from model objects
#'
#' Generics and methods for extracting quantities needed for plotting from
#' various types of model objects. Currently methods are only provided for
#' stanfit (**rstan**) and stanreg (**rstanarm**) objects, but adding new
#' methods should be relatively straightforward.
#' various types of model objects. Currently methods are provided for stanfit
#' (**rstan**), CmdStanMCMC (**cmdstanr**), and stanreg (**rstanarm**) objects,
#' but adding new methods should be relatively straightforward.
#'
#' @name bayesplot-extractors
#' @param object The object to use.
Expand Down Expand Up @@ -87,7 +87,8 @@ log_posterior.stanfit <- function(object, inc_warmup = FALSE, ...) {
...)
lp <- lapply(lp, as.array)
lp <- set_names(reshape2::melt(lp), c("Iteration", "Value", "Chain"))
validate_df_classes(lp, c("integer", "numeric", "integer"))
validate_df_classes(lp[, c("Chain", "Iteration", "Value")],
c("integer", "integer", "numeric"))
}

#' @rdname bayesplot-extractors
Expand All @@ -98,11 +99,22 @@ log_posterior.stanreg <- function(object, inc_warmup = FALSE, ...) {
log_posterior.stanfit(object$stanfit, inc_warmup = inc_warmup, ...)
}

#' @rdname bayesplot-extractors
#' @export
#' @method log_posterior CmdStanMCMC
log_posterior.CmdStanMCMC <- function(object, inc_warmup = FALSE, ...) {
  # Extract only the `lp__` draws from the CmdStanMCMC fit; `inc_warmup` is
  # forwarded unchanged to the fit's `$draws()` method.
  lp_draws <- object$draws("lp__", inc_warmup = inc_warmup)

  # Melt the draws into a long data frame. The `variable` column is removed
  # because only a single variable (`lp__`) was requested.
  lp_long <- reshape2::melt(lp_draws)
  lp_long$variable <- NULL

  # Capitalize the remaining column names so they match the
  # "Chain" / "Iteration" / "Value" naming selected below.
  lp_long <- dplyr::rename_with(lp_long, capitalize_first)

  # Fix the column order, then check the column classes before returning.
  wanted_cols <- c("Chain", "Iteration", "Value")
  wanted_classes <- c("integer", "integer", "numeric")
  validate_df_classes(lp_long[, wanted_cols], wanted_classes)
}


#' @rdname bayesplot-extractors
#' @export
#' @method nuts_params stanfit
#'
nuts_params.stanfit <-
function(object,
pars = NULL,
Expand Down Expand Up @@ -153,7 +165,23 @@ nuts_params.list <- function(object, pars = NULL, ...) {

out <- reshape2::melt(object)
out <- set_names(out, c("Iteration", "Parameter", "Value", "Chain"))
validate_df_classes(out, c("integer", "factor", "numeric", "integer"))
validate_df_classes(out[, c("Chain", "Iteration", "Parameter", "Value")],
c("integer", "integer", "factor", "numeric"))
}

#' @rdname bayesplot-extractors
#' @export
#' @method nuts_params CmdStanMCMC
nuts_params.CmdStanMCMC <- function(object, pars = NULL, ...) {
  # Sampler diagnostics as a 3-dimensional array; the third dimension is
  # indexed by diagnostic name (see the subsetting below).
  # NOTE(review): presumably iteration x chain x variable -- confirm against
  # the cmdstanr documentation.
  arr <- object$sampler_diagnostics()

  if (!is.null(pars)) {
    # Keep all three dimensions even when a single diagnostic is requested.
    # Without an explicit `drop = FALSE` the result depends on the default of
    # the array's `[` method; a plain array would collapse to a matrix and the
    # melted data frame would lose its parameter column.
    arr <- arr[, , pars, drop = FALSE]
  }

  # Long format: one row per (iteration, chain, diagnostic) combination.
  out <- reshape2::melt(arr)

  # Rename "variable" to "parameter" so that capitalizing the column names
  # yields the "Parameter" column selected below.
  colnames(out)[colnames(out) == "variable"] <- "parameter"
  out <- dplyr::rename_with(out, capitalize_first)

  # Fix the column order, then check the column classes before returning.
  validate_df_classes(out[, c("Chain", "Iteration", "Parameter", "Value")],
                      c("integer", "integer", "factor", "numeric"))
}


Expand Down Expand Up @@ -188,6 +216,17 @@ rhat.stanreg <- function(object, pars = NULL, regex_pars = NULL, ...) {
r[!names(r) %in% c("mean_PPD", "log-posterior")]
}

#' @rdname bayesplot-extractors
#' @export
#' @method rhat CmdStanMCMC
rhat.CmdStanMCMC <- function(object, pars = NULL, ...) {
  # Look up posterior's rhat() by namespace and hand it to the fit's
  # `$summary()` method as the summary function named "rhat".
  # NOTE(review): presumably fetched this way to avoid a clash with this
  # package's own `rhat` generic -- confirm.
  rhat_fun <- utils::getFromNamespace("rhat", "posterior")

  # Keep only the variable names and their rhat values.
  smry <- object$summary(pars, rhat = rhat_fun)
  smry <- smry[, c("variable", "rhat")]

  # Named vector of rhat values, validated before filtering.
  rhats <- validate_rhat(setNames(smry$rhat, smry$variable))

  # Exclude the log-posterior entry from the result.
  is_lp <- names(rhats) %in% "lp__"
  rhats[!is_lp]
}


#' @rdname bayesplot-extractors
#' @export
Expand Down Expand Up @@ -223,6 +262,18 @@ neff_ratio.stanreg <- function(object, pars = NULL, regex_pars = NULL, ...) {
ratio[!names(ratio) %in% c("mean_PPD", "log-posterior")]
}

#' @rdname bayesplot-extractors
#' @export
#' @method neff_ratio CmdStanMCMC
neff_ratio.CmdStanMCMC <- function(object, pars = NULL, ...) {
  # Ask the fit's `$summary()` method for "ess_basic", labelled "n_eff", and
  # keep only the variable names and that column.
  smry <- object$summary(pars, "n_eff" = "ess_basic")
  smry <- smry[, c("variable", "n_eff")]
  n_eff <- setNames(smry$n_eff, smry$variable)

  # Total sample size: product of the first two dimensions of the draws.
  # NOTE(review): presumably iterations x chains -- confirm.
  draws_dim <- dim(object$draws())
  total_draws <- prod(draws_dim[1:2])

  # Ratio of effective to total sample size, validated before filtering.
  ratios <- validate_neff_ratio(n_eff / total_draws)

  # Exclude the log-posterior entry from the result.
  ratios[!names(ratios) %in% "lp__"]
}


# internals ---------------------------------------------------------------

Expand All @@ -245,3 +296,10 @@ validate_df_classes <- function(x, classes = character()) {
}
x
}

# Convert each string to "Capitalized" form: everything is lower-cased first
# (so an all-caps input such as "VALUE" is handled), then only the first
# character is upper-cased. Vectorized over `name`.
capitalize_first <- function(name) {
  lowered <- tolower(name)
  initial <- toupper(substr(lowered, 1, 1))
  substr(lowered, 1, 1) <- initial
  lowered
}
8 changes: 4 additions & 4 deletions R/mcmc-diagnostics-nuts.R
Original file line number Diff line number Diff line change
Expand Up @@ -513,8 +513,8 @@ validate_nuts_data_frame <- function(x, lp) {
abort("NUTS parameters should be in a data frame.")
}

valid_cols <- c("Iteration", "Parameter", "Value", "Chain")
if (!identical(colnames(x), valid_cols)) {
valid_cols <- sort(c("Iteration", "Parameter", "Value", "Chain"))
if (!identical(sort(colnames(x)), valid_cols)) {
abort(paste(
"NUTS parameter data frame must have columns:",
paste(valid_cols, collapse = ", ")
Expand All @@ -529,8 +529,8 @@ validate_nuts_data_frame <- function(x, lp) {
abort("lp should be in a data frame.")
}

valid_lp_cols <- c("Iteration", "Value", "Chain")
if (!identical(colnames(lp), valid_lp_cols)) {
valid_lp_cols <- sort(c("Iteration", "Value", "Chain"))
if (!identical(sort(colnames(lp)), valid_lp_cols)) {
abort(paste(
"lp data frame must have columns:",
paste(valid_lp_cols, collapse = ", ")
Expand Down
18 changes: 15 additions & 3 deletions man/bayesplot-extractors.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

31 changes: 29 additions & 2 deletions tests/testthat/test-extractors.R
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ test_that("all nuts_params methods identical", {

test_that("nuts_params.stanreg returns correct structure", {
np <- nuts_params(fit)
expect_identical(colnames(np), c("Iteration", "Parameter", "Value", "Chain"))
expect_identical(colnames(np), c("Chain", "Iteration", "Parameter", "Value"))

np_names <- paste0(c("accept_stat", "stepsize", "treedepth", "n_leapfrog",
"divergent", "energy"), "__")
Expand All @@ -54,7 +54,7 @@ test_that("nuts_params.stanreg returns correct structure", {

test_that("log_posterior.stanreg returns correct structure", {
lp <- log_posterior(fit)
expect_identical(colnames(lp), c("Iteration", "Value", "Chain"))
expect_identical(colnames(lp), c("Chain", "Iteration", "Value"))
expect_equal(length(unique(lp$Iteration)), floor(ITER / 2))
expect_equal(length(unique(lp$Chain)), CHAINS)
})
Expand Down Expand Up @@ -100,3 +100,30 @@ test_that("neff_ratio.stanreg returns correct structure", {
ans2 <- summary(fit, pars = c("wt", "sigma"))[, "n_eff"] / denom
expect_equal(ratio2, ans2, tol = 0.001)
})

test_that("cmdstanr methods work", {
  # Requires a working cmdstanr installation, so never run on CRAN.
  skip_on_cran()
  skip_if_not_installed("cmdstanr")

  # Example logistic regression fit: 2 chains, 500 sampling iterations each.
  fit <- cmdstanr::cmdstanr_example("logistic", iter_sampling = 500, chains = 2)

  # nuts_params(): expected diagnostics, iteration/chain ranges, no divergences
  np <- nuts_params(fit)
  np_names <- paste0(c("accept_stat", "stepsize", "treedepth", "n_leapfrog",
                       "divergent", "energy"), "__")
  expect_identical(levels(np$Parameter), np_names)
  expect_equal(range(np$Iteration), c(1, 500))
  expect_equal(range(np$Chain), c(1, 2))
  expect_true(all(np$Value[np$Parameter == "divergent__"] == 0))

  # log_posterior(): column names and iteration/chain ranges.
  # These range checks must inspect `lp` (not `np`), otherwise the
  # log_posterior() output is never actually verified.
  lp <- log_posterior(fit)
  expect_named(lp, c("Chain", "Iteration", "Value"))
  expect_equal(range(lp$Chain), c(1, 2))
  expect_equal(range(lp$Iteration), c(1, 500))

  # rhat(): one value per model parameter (lp__ excluded), all close to 1
  r <- rhat(fit)
  expect_named(r, c("alpha", "beta[1]", "beta[2]", "beta[3]"))
  expect_true(all(round(r) == 1))

  # neff_ratio(): one positive ratio per model parameter (lp__ excluded)
  ratio <- neff_ratio(fit)
  expect_named(ratio, c("alpha", "beta[1]", "beta[2]", "beta[3]"))
  expect_true(all(ratio > 0))
})
4 changes: 2 additions & 2 deletions tests/testthat/test-mcmc-nuts.R
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ test_that("validate_nuts_data_frame throws errors", {
)
expect_error(
validate_nuts_data_frame(data.frame(Iteration = 1, apple = 2)),
"NUTS parameter data frame must have columns: Iteration, Parameter, Value, Chain"
"NUTS parameter data frame must have columns: Chain, Iteration, Parameter, Value"
)
expect_error(
validate_nuts_data_frame(np, as.matrix(lp)),
Expand All @@ -69,7 +69,7 @@ test_that("validate_nuts_data_frame throws errors", {
colnames(lp2)[3] <- "Chains"
expect_error(
validate_nuts_data_frame(np, lp2),
"lp data frame must have columns: Iteration, Value, Chain"
"lp data frame must have columns: Chain, Iteration, Value"
)

lp2 <- subset(lp, Chain %in% 1:2)
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-mcmc-scatter-and-parcoord.R
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ test_that("mcmc_parcoord throws correct warnings and errors", {

expect_error(
mcmc_parcoord(post, np = np[, -1]),
"NUTS parameter data frame must have columns: Iteration, Parameter, Value, Chain",
"NUTS parameter data frame must have columns: Chain, Iteration, Parameter, Value",
fixed = TRUE
)

Expand Down