Skip to content

Commit

Permalink
add output_dir argument to fitting methods
Browse files Browse the repository at this point in the history
closes #89
  • Loading branch information
jgabry committed Nov 19, 2019
1 parent afbec58 commit d48ee65
Show file tree
Hide file tree
Showing 14 changed files with 216 additions and 62 deletions.
39 changes: 19 additions & 20 deletions R/args.R
Expand Up @@ -34,7 +34,8 @@ CmdStanArgs <- R6::R6Class(
save_diagnostics = FALSE,
seed = NULL,
init = NULL,
refresh = NULL) {
refresh = NULL,
output_dir = NULL) {

self$model_name <- model_name
self$exe_file <- exe_file
Expand All @@ -46,12 +47,14 @@ CmdStanArgs <- R6::R6Class(
self$method_args <- method_args
self$method <- self$method_args$method
self$save_diagnostics <- save_diagnostics
self$output_dir <- output_dir %||% tempdir(check = TRUE)

self$method_args$validate(num_runs = length(self$run_ids))
self$validate()
},
validate = function() {
validate_cmdstan_args(self)
self$output_dir <- absolute_path(self$output_dir)
if (is.character(self$data_file)) {
self$data_file <- absolute_path(self$data_file)
}
Expand All @@ -63,27 +66,22 @@ CmdStanArgs <- R6::R6Class(
invisible(self)
},

tempfile_basename = function(type = c("output", "diagnostic")) {
new_file_names = function(type = c("output", "diagnostic")) {
basename <- self$model_name
type <- match.arg(type)
paste0(self$model_name, "-stan-", self$method, "-",
if (type == "diagnostic") "diagnostic-",
self$run_ids, "-")
},
new_output_files = function() {
files <- tempfile(
pattern = self$tempfile_basename("output"),
tmpdir = cmdstan_tempdir(),
fileext = ".csv"
if (type == "diagnostic") {
basename <- paste0(basename, "-diagnostic")
}
generate_file_names( # defined in utils.R
basename = basename,
ext = ".csv",
ids = self$run_ids,
timestamp = TRUE,
random = TRUE
)
invisible(file.create(files))
files
},
new_diagnostic_files = function() {
files <- tempfile(
pattern = self$tempfile_basename("diagnostic"),
tmpdir = cmdstan_tempdir(),
fileext = ".csv"
)
new_files = function(type = c("output", "diagnostic")) {
files <- file.path(self$output_dir, self$new_file_names(type))
invisible(file.create(files))
files
},
Expand Down Expand Up @@ -382,9 +380,10 @@ VariationalArgs <- R6::R6Class(
#' @param self A `CmdStanArgs` object.
#' @return `TRUE` invisibly unless an error is thrown.
validate_cmdstan_args = function(self) {
# TODO: validate that can write to output directory
validate_exe_file(self$exe_file)

checkmate::assert_directory_exists(self$output_dir, access = "rw")

# at least 1 run id (chain id)
checkmate::assert_integerish(self$run_ids,
lower = 1,
Expand Down
14 changes: 7 additions & 7 deletions R/fit.R
Expand Up @@ -131,12 +131,12 @@ NULL
#' @aliases fit-method-save_data_file fit-method-save_diagnostic_files
#' fit-method-output_files fit-method-data_file fit-method-diagnostic_files
#'
#' @description All fitted model objects have methods for saving (copying to a
#' specified location) the temporary files created by CmdStanR for CmdStan
#' output csv files and input data files. These methods move the files from
#' the CmdStanR temporary directory to a user-specified location. __The paths
#' stored in the fitted model object will also be updated to point to the new
#' file locations.__
#' @description All fitted model objects have methods for saving (moving to a
#' specified location) the files created by CmdStanR to hold CmdStan output
#' csv files and input data files. These methods move the files from their
#' current location (possibly the temporary directory) to a user-specified
#' location. __The paths stored in the fitted model object will also be
#' updated to point to the new file locations.__
#'
#' The versions without the `save_` prefix (e.g., `$output_files()`) return
#' the current file paths without moving any files.
Expand Down Expand Up @@ -166,7 +166,7 @@ NULL
#' * `basename` is the user's provided `basename` argument;
#' * `timestamp` is of the form `format(Sys.time(), "%Y%m%d%H%M")`;
#' * `id` is the MCMC chain id (or `1` for non MCMC);
#' * `random` contains five random alphanumeric characters/
#' * `random` contains six random alphanumeric characters.
#'
#' For `$save_diagnostic_files()` everything is the same as for
#' `$save_output_files()` except `"-diagnostic-"` is included in the new
Expand Down
15 changes: 12 additions & 3 deletions R/model.R
Expand Up @@ -278,6 +278,7 @@ CmdStanModel$set("public", name = "compile", value = compile_method)
#' refresh = NULL,
#' init = NULL,
#' save_diagnostics = FALSE,
#' output_dir = NULL,
#' num_chains = 4,
#' num_cores = getOption("mc.cores", 1),
#' num_warmup = NULL,
Expand Down Expand Up @@ -378,6 +379,7 @@ sample_method <- function(data = NULL,
refresh = NULL,
init = NULL,
save_diagnostics = FALSE,
output_dir = NULL,
num_chains = 4,
num_cores = getOption("mc.cores", 1),
num_warmup = NULL,
Expand Down Expand Up @@ -422,7 +424,8 @@ sample_method <- function(data = NULL,
save_diagnostics = save_diagnostics,
seed = seed,
init = init,
refresh = refresh
refresh = refresh,
output_dir = output_dir
)
cmdstan_procs <- CmdStanProcs$new(num_chains, num_cores)
runset <- CmdStanRun$new(cmdstan_args, cmdstan_procs)
Expand Down Expand Up @@ -457,6 +460,7 @@ CmdStanModel$set("public", name = "sample", value = sample_method)
#' refresh = NULL,
#' init = NULL,
#' save_diagnostics = FALSE,
#' output_dir = NULL,
#' algorithm = NULL,
#' init_alpha = NULL,
#' iter = NULL
Expand Down Expand Up @@ -488,6 +492,7 @@ optimize_method <- function(data = NULL,
refresh = NULL,
init = NULL,
save_diagnostics = FALSE,
output_dir = NULL,
algorithm = NULL,
init_alpha = NULL,
iter = NULL) {
Expand All @@ -505,7 +510,8 @@ optimize_method <- function(data = NULL,
save_diagnostics = save_diagnostics,
seed = seed,
init = init,
refresh = refresh
refresh = refresh,
output_dir = output_dir
)

cmdstan_procs <- CmdStanProcs$new(num_runs = 1, num_cores = 1)
Expand Down Expand Up @@ -546,6 +552,7 @@ CmdStanModel$set("public", name = "optimize", value = optimize_method)
#' refresh = NULL,
#' init = NULL,
#' save_diagnostics = FALSE,
#' output_dir = NULL,
#' algorithm = NULL,
#' iter = NULL,
#' grad_samples = NULL,
Expand Down Expand Up @@ -595,6 +602,7 @@ variational_method <- function(data = NULL,
refresh = NULL,
init = NULL,
save_diagnostics = FALSE,
output_dir = NULL,
algorithm = NULL,
iter = NULL,
grad_samples = NULL,
Expand Down Expand Up @@ -626,7 +634,8 @@ variational_method <- function(data = NULL,
save_diagnostics = save_diagnostics,
seed = seed,
init = init,
refresh = refresh
refresh = refresh,
output_dir = output_dir
)

cmdstan_procs <- CmdStanProcs$new(num_runs = 1, num_cores = 1)
Expand Down
8 changes: 6 additions & 2 deletions R/run.R
Expand Up @@ -27,8 +27,12 @@ CmdStanRun <- R6::R6Class(
model_name = function() self$args$model_name,
method = function() self$args$method,
data_file = function() self$args$data_file,
new_output_files = function() self$args$new_output_files(),
new_diagnostic_files = function() self$args$new_diagnostic_files(),
new_output_files = function() {
self$args$new_files(type = "output")
},
new_diagnostic_files = function() {
self$args$new_files(type = "diagnostic")
},
diagnostic_files = function() {
if (!length(private$diagnostic_files_)) {
stop(
Expand Down
57 changes: 37 additions & 20 deletions R/utils.R
Expand Up @@ -80,7 +80,7 @@ strip_ext <- function(file) {

# If a file/dir exists return its absolute path
# doesn't error if not found
absolute_path <- function(path) {
.absolute_path <- function(path) {
if (file.exists(path)) {
new_path <- repair_path(path)
} else {
Expand All @@ -92,7 +92,7 @@ absolute_path <- function(path) {
}
repair_path(file.path(getwd(), new_path))
}
absolute_path <- Vectorize(absolute_path, USE.NAMES = FALSE)
absolute_path <- Vectorize(.absolute_path, USE.NAMES = FALSE)



Expand Down Expand Up @@ -122,8 +122,37 @@ copy_temp_files <-
random = TRUE,
ext = ".csv") {
checkmate::assert_directory_exists(new_dir, access = "w")
destinations <- generate_file_names(
basename = new_basename,
ext = ext,
ids = ids,
timestamp = timestamp,
random = random
)
if (new_dir != ".") {
destinations <- file.path(new_dir, destinations)
}

copied <- file.copy(
from = current_paths,
to = destinations,
overwrite = TRUE
)
if (!all(copied)) {
destinations[!copied] <- NA_character_
}
absolute_path(destinations)
}

new_names <- new_basename
# generate new file names
# see doc above for copy_temp_files
generate_file_names <-
function(basename,
ext = ".csv",
ids = NULL,
timestamp = TRUE,
random = TRUE) {
new_names <- basename
if (timestamp) {
stamp <- format(Sys.time(), "%Y%m%d%H%M")
new_names <- paste0(new_names, "-", stamp)
Expand All @@ -134,27 +163,15 @@ copy_temp_files <-

if (random) {
tf <- tempfile()
rand <- substr(tf, nchar(tf) - 4, nchar(tf))
rand <- substr(tf, nchar(tf) - 5, nchar(tf))
new_names <- paste0(new_names, "-", rand)
}

ext <- if (startsWith(ext, ".")) ext else paste0(".", ext)
new_names <- paste0(new_names, ext)
if (new_dir == ".") {
destinations <- new_names
} else {
destinations <- file.path(new_dir, new_names)
}

copied <- file.copy(
from = current_paths,
to = destinations,
overwrite = TRUE
)
if (!all(copied)) {
destinations[!copied] <- NA_character_
if (length(ext)) {
ext <- if (startsWith(ext, ".")) ext else paste0(".", ext)
new_names <- paste0(new_names, ext)
}
absolute_path(destinations)
new_names
}


Expand Down
13 changes: 13 additions & 0 deletions man-roxygen/model-common-args.R
Expand Up @@ -28,3 +28,16 @@
#' `save_diagnostics=TRUE` see the
#' [`$save_diagnostic_files()`][fit-method-save_diagnostic_files] method.
#'
#' * `output_dir`: (string) A path to a directory where CmdStan should write
#' its output CSV files. For interactive use this can typically be left at
#' `NULL` (temporary directory) since CmdStanR makes the CmdStan output (e.g.,
#' posterior draws and diagnostics) available in \R via methods of the fitted
#' model objects. The behavior of `output_dir` is as follows:
#' - If `NULL` (the default) then the CSV files are written to a temporary
#' directory and only saved permanently if the user calls one of the
#' `$save_*` methods of the fitted model object (e.g.,
#' [`$save_output_files()`][fit-method-save_output_files]).
#' - If a path then the files are created in `output_dir` with names
#' corresponding to the defaults used by `$save_output_files()` (and similar
#' methods like `$save_diagnostic_files()`).
#'
14 changes: 7 additions & 7 deletions man/fit-method-save_output_files.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 15 additions & 0 deletions man/model-method-optimize.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 15 additions & 0 deletions man/model-method-sample.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit d48ee65

Please sign in to comment.