Skip to content

Commit

Permalink
add generate_inits function
Browse files Browse the repository at this point in the history
  • Loading branch information
venpopov committed Mar 14, 2024
1 parent 28329d7 commit d729b32
Show file tree
Hide file tree
Showing 5 changed files with 360 additions and 2 deletions.
5 changes: 3 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ URL: https://mc-stan.org/cmdstanr/, https://discourse.mc-stan.org
BugReports: https://github.com/stan-dev/cmdstanr/issues
Encoding: UTF-8
LazyData: true
RoxygenNote: 7.3.0
RoxygenNote: 7.3.1
Roxygen: list(markdown = TRUE, r6 = FALSE)
SystemRequirements: CmdStan (https://mc-stan.org/users/interfaces/cmdstan)
Depends:
Expand All @@ -42,7 +42,8 @@ Imports:
processx (>= 3.5.0),
R6 (>= 2.4.0),
withr (>= 2.5.0),
rlang (>= 0.4.7)
rlang (>= 0.4.7),
glue
Suggests:
bayesplot,
ggplot2,
Expand Down
4 changes: 4 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ S3method(as_draws,CmdStanMCMC)
S3method(as_draws,CmdStanMLE)
S3method(as_draws,CmdStanPathfinder)
S3method(as_draws,CmdStanVB)
S3method(generate_inits,CmdStanMCMC)
S3method(generate_inits,character)
S3method(generate_inits,draws)
export(as_cmdstan_fit)
export(as_draws)
export(as_mcmc.list)
Expand All @@ -19,6 +22,7 @@ export(cmdstan_version)
export(cmdstanr_example)
export(draws_to_csv)
export(eng_cmdstan)
export(generate_inits)
export(install_cmdstan)
export(num_threads)
export(print_example_program)
Expand Down
216 changes: 216 additions & 0 deletions R/inits.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
#' Generate initial values for Stan Models
#'
#' The `generate_inits()` methods generate a list of lists of initial values for
#' each chain to be used in initializing a model fit with Stan
#'
#' @name generate_inits
#' @param object An object from which to generate initial values
#' @param ... Additional arguments to be passed to the specific methods
#' @details The `generate_inits()` method is generic function to which specific
#' methods for different classes of objects can be written. In the `cmdstanr`
#' package, the following objects are supported:
#'
#' * A `CmdStanMCMC` object, which is the result of sampling from a Stan model with cmdstanr
#' * A vector of file paths to the CSV files containing the draws from a Stan model
#' * A draws object from the `posterior` package
#'
#' For these objects, the function specified in \code{FUN} is applied to the
#' draws to generate the inits. This can be very flexible - any function works
#' as long as it returns a scalar and can be applied to a vector.
#' @return A list of lists of initial values for each chain
#' @export
#' @examples
#' \dontrun{
#' # inits from a CmdStanMCMC object
#' stanfit <- cmdstanr::cmdstanr_example("logistic")
#' generate_inits(stanfit)
#' generate_inits(stanfit, FUN = mean)
#' generate_inits(stanfit, FUN = quantile, probs = 0.5)
#' generate_inits(stanfit, draws = "last")
#'
#' # inits from a vector of file paths
#' files <- stanfit$output_files()
#' generate_inits(files)
#'
#' # inits from a draws object
#' draws <- stanfit$draws()
#' generate_inits(draws)
#'
#' # warmup and then use the final draws for the inits of a separate sampling stage
#' warmup <- cmdstanr_example("logistic", parallel_chains = 4, iter_sampling = 0, save_warmup = T)
#' inits <- generate_inits(warmup, draws = "last")
#' mod <- cmdstan_model(exe_file = warmup$runset$exe_file())
#' fit <- mod$sample(warmup$data_file(),
#' parallel_chains = 4,
#' init = inits,
#' iter_warmup = 0,
#' inv_metric = warmup$inv_metric(matrix = FALSE),
#' step_size = warmup$metadata()$step_size_adaptation,
#' adapt_engaged = FALSE)
#'
#' # compare with standard fitting with combined warmup and sampling
#' fit_standard <- mod$sample(warmup$data_file(),
#' parallel_chains = 4)
#' }
generate_inits <- function(object, ...) {
UseMethod("generate_inits")
}

#' @rdname generate_inits
#' @param FUN A function to apply to the draws to generate the inits. Only used
#' if draws = "all" or "sampling". It should be a function name that takes a
#' vector as input and returns a scalar, such as mean or median. The function
#' will be applied to each parameter's draws to generate the inits. The
#' default is to sample 1 random draw from the posterior draws
#' @param variables A character vector of parameter names for which to generate
#' inits
#' @export
generate_inits.draws <- function(object, variables = NULL, FUN = sample1, ...) {
checkmate::assert_function(FUN)
checkmate::assert_character(variables, null.ok = TRUE)
draws <- posterior::as_draws_array(object)
checkmate::assert_scalar(
FUN(c(draws[,1,1]), ...),
.var.name = paste0('the return value of ', as.character(quote(FUN)))
)

# extract parameter information from draws
nchains <- length(dimnames(draws)$chain)
all_pars <- dimnames(draws)$variable
par_dims <- variable_dims(all_pars)
if (is.null(variables)) {
variables <- names(par_dims)
} else {
variables <- intersect(variables, names(par_dims))
}

# apply the function to the draws to select the inits
draws <- apply(draws, 2:3, FUN, ...)

# prepare init list
out <- vector('list', nchains)

for (i in 1:nchains) {
out[[i]] <- vector('list', length(variables))
names(out[[i]]) <- variables

# extract the draw for each parameter and store it in the proper format
for (par in variables) {
pattern <- paste0("^", par, "(\\[|$)")
idx <- grep(pattern, all_pars)
values <- draws[i,idx]
dims <- par_dims[[par]]
if (any(dims > 1)) {
out[[i]][[par]] <- array(values, dims)
} else {
out[[i]][[par]] <- as.numeric(values)
}
}
}

out
}

#' @param draws A character string. Either "last", "sampling" or "all". If
#' "last", only the last draw is used. If "sampling", all the draws from the
#' sampling phase are used. If "all", all the draws, including warmup, are
#' used.
#' @export
#' @rdname generate_inits
generate_inits.CmdStanMCMC <- function(object, variables = NULL, FUN = sample1,
draws = "sampling", ...) {
draws <- match.arg(draws, c("last", "sampling", "all"))
pars <- names(object$runset$args$model_variables$parameters)
if (!is.null(variables)) {
pars <- intersect(variables, pars)
}

# get the draws array
if (draws == "last") {
draws <- read_last_draws(object$output_files())
dimnames(draws)$variable <- object$metadata()$model_params
} else {
draws <- object$draws(variables = pars, inc_warmup = draws == "all",
format = "draws_array")
}

generate_inits(draws, FUN = FUN, variables = pars, ...)
}

#' @export
#' @rdname generate_inits
generate_inits.character <- function(object, variables = NULL, FUN = sample1,
draws = "sampling", ...) {
checkmate::assert_file_exists(object)
checkmate::assert_function(FUN)
draws <- match.arg(draws, c("sampling", "all"))
stanfit <- as_cmdstan_fit(object, format = "draws_array")
generate_inits(stanfit, variables = variables, FUN = FUN, draws = draws, ...)
}



sample1 <- function(x) {
if (length(x) == 1) {
return(x)
}
base::sample(x, size = 1)
}


# efficiently extract the last complete draw recorded in a cmdstan csv file
# @param csv_file A character string of the file path
# @param par_names A logical. If TRUE, the parameter names are included.
# @return A draws array
read_last_draws <- function(csv_files, par_names = FALSE) {
checkmate::assert_file_exists(csv_files)
checkmate::assert_logical(par_names)

out <- vector("list", length(csv_files))
for (i in seq_along(csv_files)) {
file <- csv_files[i]
tmpfile <- tempfile()
cmd <- glue::glue("tail -12 {file} | grep \"^[0-9-]\" | tail -2 > {tmpfile}")
switch(.Platform$OS.type,
windows = shell(cmd),
unix = system(cmd))
lines <- readLines(tmpfile)

if (length(lines) == 0) {
stop("No draws found in ", csv_files[i], call. = FALSE)
}

lines <- strsplit(lines, ",")
if (length(lines[[2]]) < length(lines[[1]])) {
message("The last draw is incomplete. The last complete draw will be used.")
res <- lines[[1]]
} else {
res <- lines[[2]]
}
out[[i]] <- as.data.frame(t(as.numeric(unlist(res))))
}

out <- do.call(posterior::as_draws_array, list(out))

if (par_names) {
tmpfile <- tempfile()
cmd <- glue::glue('grep \"^[a-zA-Z]\" {csv_files[i]} > {tmpfile}')
switch(.Platform$OS.type,
windows = shell(cmd),
unix = system(cmd))
pars <- readLines(tmpfile)
if (length(pars) == 0) {
message("No parameter names found in ", csv_files[i])
} else if (length(pars) > 1) {
stop("Could not identify the parameter names in ", csv_files[i], call. = FALSE)
} else {
pars <- strsplit(pars, ",")[[1]]
dimnames(out)$variable <- repair_variable_names(pars)
}
}

# remove diagnostics
out <- out[,,-c(2:7)]

out
}
103 changes: 103 additions & 0 deletions man/generate_inits.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

34 changes: 34 additions & 0 deletions tests/testthat/test-inits.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
test_that("generate inits works", {
fit_mcmc <- cmdstanr_example("logistic", chains = 2)
inits1 <- generate_inits(fit_mcmc)
inits2 <- generate_inits(fit_mcmc, draws = "last")
inits3 <- generate_inits(fit_mcmc, FUN = median)
inits4 <- generate_inits(fit_mcmc, FUN = quantile, probs = 0.5)

draws <- fit_mcmc$draws()
inits5 <- generate_inits(draws)
inits6 <- generate_inits(draws, variables = c('beta'))

files <- fit_mcmc$output_files()
inits7 <- generate_inits(files)

expect_length(inits1, 2)
expect_length(inits2, 2)
expect_length(inits3, 2)
expect_length(inits4, 2)
expect_length(inits5, 2)
expect_length(inits6, 2)
expect_length(inits7, 2)

expect_equal(names(inits1[[1]]), c('alpha','beta'))
expect_equal(names(inits2[[1]]), c('alpha','beta'))
expect_equal(names(inits3[[1]]), c('alpha','beta'))
expect_equal(names(inits4[[1]]), c('alpha','beta'))
expect_equal(names(inits5[[1]]), c('lp__','alpha','beta','log_lik'))
expect_equal(names(inits6[[1]]), c('beta'))
expect_equal(names(inits5[[1]]), c('lp__','alpha','beta','log_lik'))

dims <- variable_dims(fit_mcmc$metadata()$variables)
expect_equal(length(inits5[[1]]$alpha), dims$alpha)
expect_equal(length(inits5[[1]]$beta), dims$beta)
})

0 comments on commit d729b32

Please sign in to comment.