# example.R
#' Fit models for use in examples
#'
#' @export
#' @param example (string) The name of the example. The currently available
#' examples are
#' * `"logistic"`: logistic regression with intercept and 3 predictors.
#' * `"schools"`: the so-called "eight schools" model, a hierarchical
#' meta-analysis. Fitting this model will result in warnings about
#' divergences.
#' * `"schools_ncp"`: non-centered parameterization of the "eight schools"
#' model that fixes the problem with divergences.
#'
#' To print the Stan code for a given `example` use
#' `print_example_program(example)`.
#'
#' @param method (string) Which fitting method should be used? The default is
#' the `"sample"` method (MCMC).
#' @param ... Arguments passed to the chosen `method`. See the help pages for
#' the individual methods for details.
#' @param quiet (logical) If `TRUE` (the default) then fitting the model is
#' wrapped in [utils::capture.output()].
#' @param force_recompile Passed to the [$compile()][model-method-compile] method.
#'
#' @return
#' The fitted model object returned by the selected `method`.
#'
#' @examples
#' \dontrun{
#' print_example_program("logistic")
#' fit_logistic_mcmc <- cmdstanr_example("logistic", chains = 2)
#' fit_logistic_mcmc$summary()
#'
#' fit_logistic_optim <- cmdstanr_example("logistic", method = "optimize")
#' fit_logistic_optim$summary()
#'
#' fit_logistic_vb <- cmdstanr_example("logistic", method = "variational")
#' fit_logistic_vb$summary()
#'
#' print_example_program("schools")
#' fit_schools_mcmc <- cmdstanr_example("schools")
#' fit_schools_mcmc$summary()
#'
#' print_example_program("schools_ncp")
#' fit_schools_ncp_mcmc <- cmdstanr_example("schools_ncp")
#' fit_schools_ncp_mcmc$summary()
#'
#' # optimization fails for hierarchical model
#' cmdstanr_example("schools", "optimize", quiet = FALSE)
#' }
#'
cmdstanr_example <- function(example = c("logistic", "schools", "schools_ncp"),
                             method = c("sample", "optimize", "laplace", "variational", "pathfinder", "diagnose"),
                             ...,
                             quiet = TRUE,
                             force_recompile = getOption("cmdstanr_force_recompile", default = FALSE)) {
  # Reduce both enumerated arguments to one validated choice each.
  example <- match.arg(example)
  method <- match.arg(method)

  # The Stan program and its JSON data share the example's name as a stem.
  stan_name <- paste0(example, ".stan")
  data_file <- system.file(paste0(example, ".data.json"), package = "cmdstanr")

  # Create the executable in the temporary directory: the packaged Stan file
  # is copied there once and the model is built from that copy.
  stan_copy <- file.path(tempdir(), stan_name)
  if (!file.exists(stan_copy)) {
    file.copy(system.file(stan_name, package = "cmdstanr"), stan_copy)
  }
  mod <- cmdstan_model(stan_copy, force_recompile = force_recompile)

  # Run the requested fitting method, looked up by name on the model object.
  if (!quiet) {
    fit <- mod[[method]](data = data_file, ...)
  } else {
    # Printed output is captured and discarded; `fit` is still assigned in
    # this function's frame.
    captured <- utils::capture.output(fit <- mod[[method]](data = data_file, ...))
  }
  fit
}
#' @rdname cmdstanr_example
#' @export
print_example_program <- function(example = c("logistic", "schools", "schools_ncp")) {
  # Only the listed example names are accepted.
  example <- match.arg(example)
  # Locate the packaged Stan program and print it one line per row.
  stan_path <- system.file(paste0(example, ".stan"), package = "cmdstanr")
  cat(readLines(stan_path), sep = "\n")
}
# including write_stan_file in example.R since it will be mostly used
# in examples
#' Write Stan code to a file
#'
#' Convenience function for writing Stan code to a (possibly
#' [temporary][base::tempfile]) file with a `.stan` extension. By default, the
#' file name is chosen deterministically based on a [hash][rlang::hash()]
#' of the Stan code, and the file is not overwritten if it already has correct
#' contents. This means that calling this function multiple times with the same
#' Stan code will reuse the compiled model. This also however means that the
#' function is potentially not thread-safe. Using `hash_salt = Sys.getpid()`
#' should ensure thread-safety in the rare cases when it is needed.
#'
#' @export
#' @param code (character vector) The Stan code to write to the file. This can
#' be a character vector of length one (a string) containing the entire Stan
#' program or a character vector with each element containing one line of the
#' Stan program.
#' @param dir (string) An optional path to the directory where the file will be
#' written. If omitted, the global option `cmdstanr_write_stan_file_dir` is
#' used. If that option is not set, the [temporary directory][base::tempdir]
#' is used.
#' @param basename (string) If `dir` is specified, optionally the basename to
#' use for the file created. If not specified a file name is generated
#' from [hashing][rlang::hash()] the code.
#' @param force_overwrite (logical) If set to `TRUE` the file will always be
#' overwritten and thus the resulting model will always be recompiled.
#' @param hash_salt (string) Text to add to the model code prior to hashing to
#' determine the file name if `basename` is not set.
#' @return The path to the file.
#'
#' @examples
#' # stan program as a single string
#' stan_program <- "
#' data {
#' int<lower=0> N;
#' array[N] int<lower=0,upper=1> y;
#' }
#' parameters {
#' real<lower=0,upper=1> theta;
#' }
#' model {
#' y ~ bernoulli(theta);
#' }
#' "
#'
#' f <- write_stan_file(stan_program)
#' print(f)
#'
#' lines <- readLines(f)
#' print(lines)
#' cat(lines, sep = "\n")
#'
#' # stan program as character vector of lines
#' f2 <- write_stan_file(lines)
#' identical(readLines(f), readLines(f2))
#'
write_stan_file <- function(code,
                            dir = getOption("cmdstanr_write_stan_file_dir", tempdir()),
                            basename = NULL,
                            force_overwrite = FALSE,
                            hash_salt = "") {
  # Make sure the target directory exists before building the file path.
  dir <- absolute_path(dir)
  if (!dir.exists(dir)) {
    dir.create(dir, recursive = TRUE)
  }

  # A vector of lines and a single multi-line string are treated alike:
  # everything is joined into one string with newline separators.
  collapsed_code <- paste0(code, collapse = "\n")

  if (!is.null(basename)) {
    # User-supplied name; append the ".stan" extension if it is missing.
    if (!endsWith(basename, ".stan")) {
      basename <- paste0(basename, ".stan")
    }
    file <- file.path(dir, basename)
  } else {
    # Deterministic name derived from a hash of the (salted) code.
    require_suggested_package("rlang")
    hash <- rlang::hash(paste0(hash_salt, collapsed_code))
    file <- file.path(dir, paste0("model_", hash, ".stan"))
  }

  overwrite <- TRUE
  # Do not overwrite file if it has the correct contents (to avoid recompilation)
  if (!force_overwrite && file.exists(file)) {
    tryCatch({
      # Compare line by line. readLines() accepts LF, CRLF and CR as line
      # breaks and drops the final one, so the new code is split the same way
      # (strsplit() likewise drops a single trailing break). Comparing the
      # collapsed strings instead would make code with CRLF endings or a
      # trailing newline never match its own file.
      # warn = FALSE: files written below may lack a final newline.
      old_lines <- readLines(file, warn = FALSE)
      new_lines <- strsplit(collapsed_code, "\r\n?|\n")[[1]]
      if (identical(old_lines, new_lines)) {
        overwrite <- FALSE
      }
    },
    error = function(e) {
      # Best effort: report the problem and fall through to overwriting.
      warning("Error when checking old file contents: ", conditionMessage(e),
              call. = FALSE)
    })
  }

  if (overwrite) {
    cat(code, file = file, sep = "\n")
  }
  file
}
#' Write Stan code to a temporary file
#'
#' This function is deprecated. Please use [write_stan_file()] instead.
#' @keywords internal
#' @export
#' @inheritParams write_stan_file
write_stan_tempfile <- function(code, dir = tempdir()) {
  # Deprecated entry point: warn (without the call in the message), then
  # delegate to write_stan_file() with the same code and directory.
  deprecation_msg <- "write_stan_tempfile() is deprecated. Please use write_stan_file() instead."
  warning(deprecation_msg, call. = FALSE)
  write_stan_file(code = code, dir = dir)
}