Skip to content

Commit

Permalink
Merge pull request #1619 from venpopov/read_csv_as_brms_stanfit_v2
Browse files Browse the repository at this point in the history
Read csv as stanfit v2 (alternative version)
  • Loading branch information
paul-buerkner committed Mar 12, 2024
2 parents 701a699 + 8ec0b8c commit a31ad2a
Show file tree
Hide file tree
Showing 6 changed files with 134 additions and 37 deletions.
3 changes: 2 additions & 1 deletion DESCRIPTION
Expand Up @@ -16,7 +16,8 @@ Authors@R:
person("Mattan S.", "Ben-Shachar", role = c("ctb")),
person("Hayden", "Rabel", role = c("ctb")),
person("Simon C.", "Mills", role = c("ctb")),
person("Stephen", "Wild", role = c("ctb")))
person("Stephen", "Wild", role = c("ctb")),
person("Ven", "Popov", role = c("ctb")))
Depends:
R (>= 3.6.0),
Rcpp (>= 0.12.0),
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Expand Up @@ -547,6 +547,7 @@ export(ranef)
export(rasym_laplace)
export(rbeta_binomial)
export(rdirichlet)
export(read_csv_as_stanfit)
export(recompile_model)
export(reloo)
export(rename_pars)
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Expand Up @@ -19,6 +19,7 @@ the default priors, for their own objects building on brms. Thanks to Ven Popov
for helping with this. (#1604)
* No longer automatically canonicalize the Stan code if cmdstanr is used
as backend. (#1544)
* Export `read_csv_as_stanfit` thanks to Ven Popov. (#1619)
* Improve parameter class names in the `summary` output.
* Show histograms rather than densities in the `plot` method by default.
* Deprecate argument `N` in the `plot` method in favor of argument `nvariables`.
Expand Down
95 changes: 59 additions & 36 deletions R/backends.R
Expand Up @@ -290,22 +290,12 @@ fit_model <- function(model, backend, ...) {
} else {
stop2("Algorithm '", algorithm, "' is not supported.")
}
# not all metadata is stored by read_csv_as_stanfit
metadata <- cmdstanr::read_cmdstan_csv(
out$output_files(), variables = "", sampler_diagnostics = ""

out <- read_csv_as_stanfit(
out$output_files(), variables = out$metadata()$variables,
model = model, exclude = exclude
)
# ensure that only relevant variables are read from CSV
variables <- repair_variable_names(metadata$metadata$variables)
variables <- unique(sub("\\[.+", "", variables))
variables <- setdiff(variables, exclude)
# temp fix for cmdstanr not recognizing the variable names it produces #1473
variables <- ifelse(variables == "lp_approx__", "log_g__", variables)
# transform into stanfit object for consistent output structure
out <- read_csv_as_stanfit(out$output_files(), variables = variables)
out <- repair_stanfit(out)
# allow updating the model without recompilation
attributes(out)$CmdStanModel <- model
attributes(out)$metadata <- metadata

if (empty_model) {
# allow correct updating of an 'empty' model
out@sim <- list()
Expand Down Expand Up @@ -648,14 +638,51 @@ file_refit_options <- function() {
# paste0(out, collapse = "\n")
# }

# read in stan CSVs via cmdstanr and repackage into a stanfit object
# efficient replacement of rstan::read_stan_csv
# @param files character vector of CSV file names where draws are stored
# @param variables character vector of variables to extract draws for
# @param sampler_diagnostics character vector of diagnostics to extract
# @return a stanfit object
read_csv_as_stanfit <- function(files, variables = NULL,
sampler_diagnostics = NULL) {
#' Read CmdStan CSV files as a brms-formatted stanfit object
#'
#' \code{read_csv_as_stanfit} is used internally to read CmdStan CSV files into a
#' \code{stanfit} object that is consistent with the structure of the fit slot of a
#' brmsfit object.
#'
#' @param files Character vector of CSV file names where draws are stored.
#' @param variables Character vector of variables to extract from the CSV files.
#' @param sampler_diagnostics Character vector of sampler diagnostics to extract.
#' @param model A compiled cmdstanr model object (optional). Provide this argument
#' if you want to allow updating the model without recompilation.
#' @param exclude Character vector of variables to exclude from the stanfit. Only
#' used when \code{variables} is also specified.
#'
#' @return A stanfit object consistent with the structure of the \code{fit}
#' slot of a brmsfit object.
#'
#' @examples
#' \dontrun{
#' # fit a model manually via cmdstanr
#' scode <- stancode(count ~ Trt, data = epilepsy)
#' sdata <- standata(count ~ Trt, data = epilepsy)
#' mod <- cmdstanr::cmdstan_model(cmdstanr::write_stan_file(scode))
#' stanfit <- mod$sample(data = sdata)
#'
#' # feed the Stan model back into brms
#' fit <- brm(count ~ Trt, data = epilepsy, empty = TRUE, backend = 'cmdstanr')
#' fit$fit <- read_csv_as_stanfit(stanfit$output_files(), model = mod)
#' fit <- rename_pars(fit)
#' summary(fit)
#' }
#'
#' @export
read_csv_as_stanfit <- function(files, variables = NULL, sampler_diagnostics = NULL,
model = NULL, exclude = "") {
require_package("cmdstanr")

if (!is.null(variables)) {
# ensure that only relevant variables are read from CSV
variables <- repair_variable_names(variables)
variables <- unique(sub("\\[.+", "", variables))
variables <- setdiff(variables, exclude)
# temp fix for cmdstanr not recognizing the variable names it produces #1473
variables <- ifelse(variables == "lp_approx__", "log_g__", variables)
}

csfit <- cmdstanr::read_cmdstan_csv(
files = files, variables = variables,
Expand All @@ -667,11 +694,7 @@ read_csv_as_stanfit <- function(files, variables = NULL,
model_name = gsub(".csv", "", basename(files[[1]]))

# @model_pars
svars <- csfit$metadata$stan_variables
if (!is.null(variables)) {
variables_main <- unique(gsub("\\[.*\\]", "", variables))
svars <- intersect(variables_main, svars)
}
svars <- variables %||% csfit$metadata$stan_variables
if ("lp__" %in% svars) {
svars <- c(setdiff(svars, "lp__"), "lp__")
}
Expand Down Expand Up @@ -767,10 +790,6 @@ read_csv_as_stanfit <- function(files, variables = NULL,

fnames_oi <- colnames(samples)

colnames(samples) <- gsub("\\[", ".", colnames(samples))
colnames(samples) <- gsub("\\]", "", colnames(samples))
colnames(samples) <- gsub("\\,", ".", colnames(samples))

# split samples into chains
samples <- split(samples, chain_ids)
names(samples) <- NULL
Expand Down Expand Up @@ -808,16 +827,16 @@ read_csv_as_stanfit <- function(files, variables = NULL,
idx_samples <- (n_iter_warmup + 1):(n_iter_warmup + n_iter_sample)

for (i in seq_along(samples)) {
m <- colMeans(samples[[i]][idx_samples, , drop=FALSE])
m <- colMeans(samples[[i]][idx_samples, , drop = FALSE])
rownames(samples[[i]]) <- seq_rows(samples[[i]])
attr(samples[[i]], "sampler_params") <- diagnostics[[i]][rstan_diagn_order]
rownames(attr(samples[[i]], "sampler_params")) <- seq_rows(diagnostics[[i]])

# reformat back to text
if (is_equal(sampler_t, "NUTS(dense_e)")) {
mmatrix_txt <- "\n# Elements of inverse mass matrix:\n# "
mmat <- paste0(apply(csfit$inv_metric[[i]], 1, paste0, collapse=", "),
collapse="\n# ")
mmat <- paste0(apply(csfit$inv_metric[[i]], 1, paste0, collapse = ", "),
collapse = "\n# ")
} else {
mmatrix_txt <- "\n# Diagonal elements of inverse mass matrix:\n# "
mmat <- paste0(csfit$inv_metric[[i]], collapse = ", ")
Expand Down Expand Up @@ -943,7 +962,7 @@ read_csv_as_stanfit <- function(files, variables = NULL,
sdate <- do.call(max, lapply(files, function(csv) file.info(csv)$mtime))
sdate <- format(sdate, "%a %b %d %X %Y")

new(
out <- new(
"stanfit",
model_name = model_name,
model_pars = svars,
Expand All @@ -956,4 +975,8 @@ read_csv_as_stanfit <- function(files, variables = NULL,
date = sdate, # not the time of sampling
.MISC = new.env(parent = emptyenv())
)

attributes(out)$metadata <- csfit
attributes(out)$CmdStanModel <- model
out
}
19 changes: 19 additions & 0 deletions man/brms-package.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

52 changes: 52 additions & 0 deletions man/read_csv_as_stanfit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit a31ad2a

Please sign in to comment.