Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add threading for variational/optimize #369

Merged
merged 8 commits into from
Nov 30, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@

### New features

* Added support for native execution on the macOS with the M1 ARM-based CPU.
* Added support for native execution on the macOS with the M1 ARM-based CPU. (#375)

* Added `threads` argument for `$optimize()` and `$variational()`. (#369)

# cmdstanr 0.2.1

Expand Down
83 changes: 75 additions & 8 deletions R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -865,7 +865,7 @@ sample_method <- function(data = NULL,
}
if (!is.null(num_cores)) {
warning("'num_cores' is deprecated. Please use 'parallel_chains' instead.")
cores <- num_cores
parallel_chains <- num_cores
rok-cesnovar marked this conversation as resolved.
Show resolved Hide resolved
}
if (!is.null(num_chains)) {
warning("'num_chains' is deprecated. Please use 'chains' instead.")
Expand Down Expand Up @@ -988,6 +988,7 @@ CmdStanModel$set("public", name = "sample", value = sample_method)
#' init = NULL,
#' save_latent_dynamics = FALSE,
#' output_dir = NULL,
#' threads = NULL,
#' algorithm = NULL,
#' init_alpha = NULL,
#' iter = NULL,
Expand All @@ -1002,6 +1003,10 @@ CmdStanModel$set("public", name = "sample", value = sample_method)
#' in the CmdStan manual. Arguments left at `NULL` default to the default used
#' by the installed version of CmdStan.
#'
#' * `threads`: (positive integer) If the model was
#' [compiled][model-method-compile] with threading support, the number of
#' threads to use in parallelized sections (e.g., when
#' using the Stan functions `reduce_sum()` or `map_rect()`).
#' * `algorithm`: (string) The optimization algorithm. One of `"lbfgs"`,
#' `"bfgs"`, or `"newton"`.
#' * `iter`: (positive integer) The number of iterations.
Expand All @@ -1021,10 +1026,26 @@ optimize_method <- function(data = NULL,
init = NULL,
save_latent_dynamics = FALSE,
output_dir = NULL,
threads = NULL,
algorithm = NULL,
init_alpha = NULL,
iter = NULL,
sig_figs = NULL) {
checkmate::assert_integerish(threads, lower = 1, len = 1, null.ok = TRUE)
if (is.null(self$cpp_options()[["stan_threads"]])) {
if (!is.null(threads)) {
warning("'threads' is set but the model was not compiled with ",
"'cpp_options = list(stan_threads = TRUE)' so 'threads' will have no effect!",
rok-cesnovar marked this conversation as resolved.
Show resolved Hide resolved
call. = FALSE)
threads <- NULL
}
} else {
if (is.null(threads)) {
stop("The model was compiled with 'cpp_options = list(stan_threads = TRUE)' ",
"but 'threads' was not set!",
call. = FALSE)
}
}
optimize_args <- OptimizeArgs$new(
algorithm = algorithm,
init_alpha = init_alpha,
Expand All @@ -1044,7 +1065,11 @@ optimize_method <- function(data = NULL,
sig_figs = sig_figs
)

cmdstan_procs <- CmdStanProcs$new(num_procs = 1, show_stdout_messages = (is.null(refresh) || refresh != 0))
cmdstan_procs <- CmdStanProcs$new(
num_procs = 1,
show_stdout_messages = (is.null(refresh) || refresh != 0),
threads_per_proc = threads
)
runset <- CmdStanRun$new(args = cmdstan_args, procs = cmdstan_procs)
runset$run_cmdstan()
CmdStanMLE$new(runset)
Expand Down Expand Up @@ -1079,6 +1104,7 @@ CmdStanModel$set("public", name = "optimize", value = optimize_method)
#' init = NULL,
#' save_latent_dynamics = FALSE,
#' output_dir = NULL,
#' threads = NULL,
#' algorithm = NULL,
#' iter = NULL,
#' grad_samples = NULL,
Expand All @@ -1100,6 +1126,10 @@ CmdStanModel$set("public", name = "optimize", value = optimize_method)
#' in the CmdStan manual. Arguments left at `NULL` default to the default used
#' by the installed version of CmdStan.
#'
#' * `threads`: (positive integer) If the model was
#' [compiled][model-method-compile] with threading support, the number of
#' threads to use in parallelized sections (e.g., when
#' using the Stan functions `reduce_sum()` or `map_rect()`).
#' * `algorithm`: (string) The algorithm. Either `"meanfield"` or `"fullrank"`.
#' * `iter`: (positive integer) The _maximum_ number of iterations.
#' * `grad_samples`: (positive integer) The number of samples for Monte Carlo
Expand All @@ -1117,6 +1147,7 @@ CmdStanModel$set("public", name = "optimize", value = optimize_method)
#' * `output_samples:` (positive integer) Number of posterior samples to draw
#' and save.
#'
#'
#' @section Value: The `$variational()` method returns a [`CmdStanVB`] object.
#'
#' @template seealso-docs
Expand All @@ -1130,6 +1161,7 @@ variational_method <- function(data = NULL,
init = NULL,
save_latent_dynamics = FALSE,
output_dir = NULL,
threads = NULL,
algorithm = NULL,
iter = NULL,
grad_samples = NULL,
Expand All @@ -1141,6 +1173,21 @@ variational_method <- function(data = NULL,
eval_elbo = NULL,
output_samples = NULL,
sig_figs = NULL) {
checkmate::assert_integerish(threads, lower = 1, len = 1, null.ok = TRUE)
if (is.null(self$cpp_options()[["stan_threads"]])) {
if (!is.null(threads)) {
warning("'threads' is set but the model was not compiled with ",
"'cpp_options = list(stan_threads = TRUE)' so 'threads' will have no effect!",
call. = FALSE)
threads <- NULL
}
} else {
if (is.null(threads)) {
stop("The model was compiled with 'cpp_options = list(stan_threads = TRUE)' ",
"but 'threads' was not set!",
call. = FALSE)
}
}
variational_args <- VariationalArgs$new(
algorithm = algorithm,
iter = iter,
Expand All @@ -1167,7 +1214,11 @@ variational_method <- function(data = NULL,
sig_figs = sig_figs
)

cmdstan_procs <- CmdStanProcs$new(num_procs = 1, show_stdout_messages = (is.null(refresh) || refresh != 0))
cmdstan_procs <- CmdStanProcs$new(
num_procs = 1,
show_stdout_messages = (is.null(refresh) || refresh != 0),
threads_per_proc = threads
)
runset <- CmdStanRun$new(args = cmdstan_args, procs = cmdstan_procs)
runset$run_cmdstan()
CmdStanVB$new(runset)
Expand All @@ -1191,9 +1242,9 @@ CmdStanModel$set("public", name = "variational", value = variational_method)
#' data = NULL,
#' seed = NULL,
#' output_dir = NULL,
#' sig_figs = NULL,
#' parallel_chains = getOption("mc.cores", 1),
#' threads_per_chain = NULL
#' threads_per_chain = NULL,
#' sig_figs = NULL
#' )
#' ```
#'
Expand All @@ -1202,7 +1253,7 @@ CmdStanModel$set("public", name = "variational", value = variational_method)
#' - A [CmdStanMCMC] fitted model object.
#' - A character vector of paths to CmdStan CSV output files containing
#' parameter draws.
#' * `data`, `seed`, `output_dir`, `parallel_chains`, `threads_per_chain`:
#' * `data`, `seed`, `output_dir`, `parallel_chains`, `threads_per_chain`, `sig_figs`:
#' Same as for the [`$sample()`][model-method-sample] method.
#'
#' @section Value: The `$generate_quantities()` method returns a [`CmdStanGQ`] object.
Expand Down Expand Up @@ -1258,10 +1309,26 @@ generate_quantities_method <- function(fitted_params,
data = NULL,
seed = NULL,
output_dir = NULL,
sig_figs = NULL,
parallel_chains = getOption("mc.cores", 1),
threads_per_chain = NULL) {
threads_per_chain = NULL,
sig_figs = NULL) {
checkmate::assert_integerish(parallel_chains, lower = 1, null.ok = TRUE)
checkmate::assert_integerish(threads_per_chain, lower = 1, len = 1, null.ok = TRUE)
if (is.null(self$cpp_options()[["stan_threads"]])) {
if (!is.null(threads_per_chain)) {
warning("'threads_per_chain' is set but the model was not compiled with ",
"'cpp_options = list(stan_threads = TRUE)' so 'threads_per_chain' will have no effect!",
call. = FALSE)
threads_per_chain <- NULL
}
} else {
if (is.null(threads_per_chain)) {
stop("The model was compiled with 'cpp_options = list(stan_threads = TRUE)' ",
"but 'threads_per_chain' was not set!",
call. = FALSE)
}
}

fitted_params <- process_fitted_params(fitted_params)
chains <- length(fitted_params)
generate_quantities_args <- GenerateQuantitiesArgs$new(
Expand Down
8 changes: 6 additions & 2 deletions R/read_csv.R
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ read_cmdstan_csv <- function(files,
posterior::as_draws_array(draws[(num_warmup_draws+1):all_draws, variables, drop = FALSE]),
along="chain"
)
}
}
}
if (length(sampler_diagnostics) > 0) {
warmup_sampler_diagnostics_draws <- posterior::bind_draws(
Expand Down Expand Up @@ -483,7 +483,11 @@ read_csv_metadata <- function(csv_file) {
csv_file_info$step_size <- csv_file_info$stepsize
csv_file_info$iter_warmup <- csv_file_info$num_warmup
csv_file_info$iter_sampling <- csv_file_info$num_samples
csv_file_info$threads_per_chain <- csv_file_info$num_threads
if (csv_file_info$method == "variational" || csv_file_info$method == "optimize") {
csv_file_info$threads <- csv_file_info$num_threads
} else {
csv_file_info$threads_per_chain <- csv_file_info$num_threads
}
csv_file_info$model <- NULL
csv_file_info$engaged <- NULL
csv_file_info$delta <- NULL
Expand Down
3 changes: 3 additions & 0 deletions R/run.R
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,9 @@ CmdStanRun$set("private", name = "run_generate_quantities_", value = .run_genera
Sys.setenv(PATH = paste0(path_to_TBB, ";", Sys.getenv("PATH")))
}
}
if (!is.null(procs$threads_per_proc())) {
Sys.setenv("STAN_NUM_THREADS" = as.integer(procs$threads_per_proc()))
}
start_time <- Sys.time()
id <- 1
procs$new_proc(
Expand Down
6 changes: 3 additions & 3 deletions man/model-method-generate-quantities.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions man/model-method-optimize.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions man/model-method-variational.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions tests/testthat/test-model-generate_quantities.R
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ test_that("generate_quantities work for different chains and parallel_chains", {
expect_gq_output(
mod_gq$generate_quantities(data = data_list, fitted_params = fit, parallel_chains = 4)
)
mod_gq <- cmdstan_model(testing_stan_file("bernoulli_ppc"), cpp_options = list(stan_threads = TRUE))
expect_gq_output(
mod_gq$generate_quantities(data = data_list, fitted_params = fit_1_chain, threads_per_chain = 2)
)
Expand Down
Loading