Skip to content

Commit

Permalink
Merge 09b5f03 into 1144409
Browse files Browse the repository at this point in the history
  • Loading branch information
rok-cesnovar committed Sep 15, 2021
2 parents 1144409 + 09b5f03 commit 4bbe4c3
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 24 deletions.
15 changes: 9 additions & 6 deletions R/args.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ CmdStanArgs <- R6::R6Class(
output_basename = NULL,
validate_csv = TRUE,
sig_figs = NULL,
opencl_ids = NULL) {
opencl_ids = NULL,
include_paths = NULL) {

self$model_name <- model_name
self$exe_file <- exe_file
Expand All @@ -59,7 +60,7 @@ CmdStanArgs <- R6::R6Class(
if (is.function(init)) {
init <- process_init_function(init, length(self$proc_ids), stan_file)
} else if (is.list(init) && !is.data.frame(init)) {
init <- process_init_list(init, length(self$proc_ids), stan_file)
init <- process_init_list(init, length(self$proc_ids), stan_file, include_paths)
}
self$init <- init
self$opencl_ids <- opencl_ids
Expand Down Expand Up @@ -767,8 +768,9 @@ validate_exe_file <- function(exe_file) {
#' @param init List of init lists.
#' @param num_procs Number of CmdStan processes.
#' @param stan_file Path to the Stan model file.
#' @param include_paths Folders with Stan files included in the Stan model file.
#' @return A character vector of file paths.
process_init_list <- function(init, num_procs, stan_file = NULL) {
process_init_list <- function(init, num_procs, stan_file = NULL, include_paths = NULL) {
if (!all(sapply(init, function(x) is.list(x) && !is.data.frame(x)))) {
stop("If 'init' is a list it must be a list of lists.", call. = FALSE)
}
Expand All @@ -782,7 +784,7 @@ process_init_list <- function(init, num_procs, stan_file = NULL) {
stan_file <- absolute_path(stan_file)
if (file.exists(stan_file)) {
missing_parameter_values <- list()
parameter_names <- names(model_variables(stan_file)$parameters)
parameter_names <- names(model_variables(stan_file, include_paths)$parameters)
for (i in seq_along(init)) {
is_parameter_value_supplied <- parameter_names %in% names(init[[i]])
if (!all(is_parameter_value_supplied)) {
Expand Down Expand Up @@ -832,8 +834,9 @@ process_init_list <- function(init, num_procs, stan_file = NULL) {
#' @param init Function generating a single list of initial values.
#' @param num_procs Number of CmdStan processes.
#' @param stan_file Path to the Stan model file.
#' @param include_paths Folders with Stan files included in the Stan model file.
#' @return A character vector of file paths.
process_init_function <- function(init, num_procs, stan_file = NULL) {
process_init_function <- function(init, num_procs, stan_file = NULL, include_paths = NULL) {
args <- formals(init)
if (is.null(args)) {
fn_test <- init()
Expand All @@ -849,7 +852,7 @@ process_init_function <- function(init, num_procs, stan_file = NULL) {
if (!is.list(fn_test) || is.data.frame(fn_test)) {
stop("If 'init' is a function it must return a single list.")
}
process_init_list(init_list, num_procs, stan_file)
process_init_list(init_list, num_procs, stan_file, include_paths)
}

#' Validate initial values
Expand Down
5 changes: 3 additions & 2 deletions R/data.R
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,9 @@ list_to_array <- function(x, name = NULL) {
#' required elements/Stan variables and to help differentiate between a
#' vector of length 1 and a scalar when genereting the JSON file. This
#' argument is ignored when a path to a data file is supplied for `data`.
#' @param include_paths Folders with Stan files included in the Stan model file.
#' @return Path to data file.
process_data <- function(data, stan_file = NULL) {
process_data <- function(data, stan_file = NULL, include_paths = NULL) {
if (length(data) == 0) {
data <- NULL
}
Expand All @@ -165,7 +166,7 @@ process_data <- function(data, stan_file = NULL) {
if (cmdstan_version() >= "2.27.0" && !is.null(stan_file)) {
stan_file <- absolute_path(stan_file)
if (file.exists(stan_file)) {
data_variables <- model_variables(stan_file)$data
data_variables <- model_variables(stan_file, include_paths)$data
is_data_supplied <- names(data_variables) %in% names(data)
if (!all(is_data_supplied)) {
missing <- names(data_variables[!is_data_supplied])
Expand Down
42 changes: 26 additions & 16 deletions R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -208,14 +208,17 @@ CmdStanModel <- R6::R6Class(
private$precompile_cpp_options_ <- args$cpp_options %||% list()
private$precompile_stanc_options_ <- assert_valid_stanc_options(args$stanc_options) %||% list()
private$precompile_include_paths_ <- args$include_paths
private$include_paths_ <- args$include_paths
private$dir_ <- args$dir

if (compile) {
self$compile(...)
}
invisible(self)
},

include_paths = function() {
private$include_paths_
},
code = function() {
readLines(self$stan_file())
},
Expand Down Expand Up @@ -372,6 +375,7 @@ compile <- function(quiet = TRUE,
if (is.null(include_paths) && !is.null(private$precompile_include_paths_)) {
include_paths <- private$precompile_include_paths_
}
private$include_paths_ <- include_paths
if (is.null(dir) && !is.null(private$dir_)) {
dir <- absolute_path(private$dir_)
} else if (!is.null(dir)) {
Expand Down Expand Up @@ -564,7 +568,7 @@ variables <- function() {
stop("$variables() is only supported for CmdStan 2.27 or newer.", call. = FALSE)
}
if (is.null(private$variables_)) {
private$variables_ <- model_variables(self$stan_file())
private$variables_ <- model_variables(self$stan_file(), self$include_paths())
}
private$variables_
}
Expand Down Expand Up @@ -824,7 +828,7 @@ sample <- function(data = NULL,
model_name = self$model_name(),
exe_file = self$exe_file(),
proc_ids = checkmate::assert_integerish(chain_ids, lower = 1, len = chains, unique = TRUE, null.ok = FALSE),
data_file = process_data(data, self$stan_file()),
data_file = process_data(data, self$stan_file(), self$include_paths()),
save_latent_dynamics = save_latent_dynamics,
seed = seed,
init = init,
Expand All @@ -833,7 +837,8 @@ sample <- function(data = NULL,
output_basename = output_basename,
sig_figs = sig_figs,
validate_csv = validate_csv,
opencl_ids = assert_valid_opencl(opencl_ids, self$cpp_options())
opencl_ids = assert_valid_opencl(opencl_ids, self$cpp_options()),
include_paths = self$include_paths()
)
runset <- CmdStanRun$new(args, procs)
runset$run_cmdstan()
Expand Down Expand Up @@ -960,15 +965,16 @@ sample_mpi <- function(data = NULL,
model_name = self$model_name(),
exe_file = self$exe_file(),
proc_ids = checkmate::assert_integerish(chain_ids, lower = 1, len = chains, unique = TRUE, null.ok = FALSE),
data_file = process_data(data, self$stan_file()),
data_file = process_data(data, self$stan_file(), self$include_paths()),
save_latent_dynamics = save_latent_dynamics,
seed = seed,
init = init,
refresh = refresh,
output_dir = output_dir,
output_basename = output_basename,
validate_csv = validate_csv,
sig_figs = sig_figs
sig_figs = sig_figs,
include_paths = self$include_paths()
)
runset <- CmdStanRun$new(args, procs)
runset$run_cmdstan_mpi(mpi_cmd, mpi_args)
Expand Down Expand Up @@ -1066,15 +1072,16 @@ optimize <- function(data = NULL,
model_name = self$model_name(),
exe_file = self$exe_file(),
proc_ids = 1,
data_file = process_data(data, self$stan_file()),
data_file = process_data(data, self$stan_file(), self$include_paths()),
save_latent_dynamics = save_latent_dynamics,
seed = seed,
init = init,
refresh = refresh,
output_dir = output_dir,
output_basename = output_basename,
sig_figs = sig_figs,
opencl_ids = assert_valid_opencl(opencl_ids, self$cpp_options())
opencl_ids = assert_valid_opencl(opencl_ids, self$cpp_options()),
include_paths = self$include_paths()
)
runset <- CmdStanRun$new(args, procs)
runset$run_cmdstan()
Expand Down Expand Up @@ -1177,15 +1184,16 @@ variational <- function(data = NULL,
model_name = self$model_name(),
exe_file = self$exe_file(),
proc_ids = 1,
data_file = process_data(data, self$stan_file()),
data_file = process_data(data, self$stan_file(), self$include_paths()),
save_latent_dynamics = save_latent_dynamics,
seed = seed,
init = init,
refresh = refresh,
output_dir = output_dir,
output_basename = output_basename,
sig_figs = sig_figs,
opencl_ids = assert_valid_opencl(opencl_ids, self$cpp_options())
opencl_ids = assert_valid_opencl(opencl_ids, self$cpp_options()),
include_paths = self$include_paths()
)
runset <- CmdStanRun$new(args, procs)
runset$run_cmdstan()
Expand Down Expand Up @@ -1280,12 +1288,13 @@ generate_quantities <- function(fitted_params,
model_name = self$model_name(),
exe_file = self$exe_file(),
proc_ids = seq_along(fitted_params_files),
data_file = process_data(data, self$stan_file()),
data_file = process_data(data, self$stan_file(), self$include_paths()),
seed = seed,
output_dir = output_dir,
output_basename = output_basename,
sig_figs = sig_figs,
opencl_ids = assert_valid_opencl(opencl_ids, self$cpp_options())
opencl_ids = assert_valid_opencl(opencl_ids, self$cpp_options()),
include_paths = self$include_paths()
)
runset <- CmdStanRun$new(args, procs)
runset$run_cmdstan()
Expand Down Expand Up @@ -1337,11 +1346,12 @@ diagnose_method <- function(data = NULL,
model_name = self$model_name(),
exe_file = self$exe_file(),
proc_ids = 1,
data_file = process_data(data, self$stan_file()),
data_file = process_data(data, self$stan_file(), self$include_paths()),
seed = seed,
init = init,
output_dir = output_dir,
output_basename = output_basename
output_basename = output_basename,
include_paths = self$include_paths()
)
runset <- CmdStanRun$new(args, procs)
runset$run_cmdstan()
Expand Down Expand Up @@ -1442,11 +1452,11 @@ include_paths_stanc3_args <- function(include_paths = NULL) {
stancflags
}

model_variables <- function(stan_file) {
model_variables <- function(stan_file, include_paths = NULL) {
out_file <- tempfile(fileext = ".json")
run_log <- processx::run(
command = stanc_cmd(),
args = c(stan_file, "--info"),
args = c(stan_file, "--info", include_paths_stanc3_args(include_paths)),
wd = cmdstan_path(),
echo = FALSE,
echo_cmd = FALSE,
Expand Down
27 changes: 27 additions & 0 deletions tests/testthat/test-fit-shared.R
Original file line number Diff line number Diff line change
Expand Up @@ -449,3 +449,30 @@ test_that("draws are returned for model with spaces", {
)
expect_equal(dim(fit$draws()), c(1000, 1, 1))
})

test_that("sampling with inits works with include_paths", {
skip_on_cran()

stan_program_w_include <- testing_stan_file("bernoulli_include")
exe <- cmdstan_ext(strip_ext(stan_program_w_include))
if(file.exists(exe)) {
file.remove(exe)
}

expect_interactive_message(
mod_w_include <- cmdstan_model(stan_file = stan_program_w_include, quiet = TRUE,
include_paths = test_path("resources", "stan")),
"Compiling Stan program"
)

data_list <- list(N = 10, y = c(0,1,0,0,0,0,0,0,0,1))

fit <- mod_w_include$sample(
data = data_list,
seed = 123,
chains = 4,
parallel_chains = 4,
refresh = 500,
init = list(list(theta = 0.25), list(theta = 0.25), list(theta = 0.25), list(theta = 0.25))
)
})

0 comments on commit 4bbe4c3

Please sign in to comment.