/
stancode.R
499 lines (477 loc) · 17.6 KB
/
stancode.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
#' @title Stan Code for Bayesian models
#'
#' @description \code{stancode} is a generic function that can be used to
#' generate Stan code for Bayesian models. Its original use is
#' within the \pkg{brms} package, but new methods for use
#' with objects from other packages can be registered to the same generic.
#'
#' @param object An object whose class will determine which method to apply.
#' Usually, it will be some kind of symbolic description of the model
#' from which Stan code should be generated.
#' @param formula Synonym of \code{object} for use in \code{make_stancode}.
#' @param ... Further arguments passed to the specific method.
#'
#' @return Usually, a character string containing the generated Stan code.
#' For pretty printing, we recommend the returned object to be of class
#' \code{c("character", "brmsmodel")}.
#'
#' @details
#' See \code{\link[brms:stancode.default]{stancode.default}} for the default
#' method applied for \pkg{brms} models.
#' You can view the available methods by typing: \code{methods(stancode)}.
#' The \code{make_stancode} function is an alias of \code{stancode}.
#'
#' @seealso
#' \code{\link{stancode.default}}, \code{\link{stancode.brmsfit}}
#'
#' @examples
#' stancode(rating ~ treat + period + carry + (1|subject),
#' data = inhaler, family = "cumulative")
#'
#' @export
stancode <- function(object, ...) {
# S3 generic: the method is selected based on the class of 'object'
UseMethod("stancode")
}
#' @rdname stancode
#' @export
make_stancode <- function(formula, ...) {
# became an alias of 'stancode' in 2.20.14
# 'formula' is forwarded positionally so that it becomes the object
# on which the 'stancode' generic dispatches
stancode(formula, ...)
}
#' Stan Code for \pkg{brms} Models
#'
#' Generate Stan code for \pkg{brms} models
#'
#' @inheritParams brm
#' @param object An object of class \code{\link[stats:formula]{formula}},
#' \code{\link{brmsformula}}, or \code{\link{mvbrmsformula}} (or one that can
#' be coerced to one of these classes): A symbolic description of the model
#' to be fitted. The details of model specification are explained in
#' \code{\link{brmsformula}}.
#' @param ... Other arguments for internal usage only.
#'
#' @return A character string containing the fully commented \pkg{Stan} code
#' to fit a \pkg{brms} model. It is of class \code{c("character", "brmsmodel")}
#' to facilitate pretty printing.
#'
#' @examples
#' stancode(rating ~ treat + period + carry + (1|subject),
#' data = inhaler, family = "cumulative")
#'
#' stancode(count ~ zAge + zBase * Trt + (1|patient),
#' data = epilepsy, family = "poisson")
#'
#' @export
stancode.default <- function(object, data, family = gaussian(),
prior = NULL, autocor = NULL, data2 = NULL,
cov_ranef = NULL, sparse = NULL,
sample_prior = "no", stanvars = NULL,
stan_funs = NULL, knots = NULL,
drop_unused_levels = TRUE,
threads = getOption("brms.threads", NULL),
normalize = getOption("brms.normalize", TRUE),
save_model = NULL, ...) {
# validate the formula together with the family, autocor, sparse and
# cov_ranef settings; all later steps build on the validated formula
object <- validate_formula(
object, data = data, family = family,
autocor = autocor, sparse = sparse,
cov_ranef = cov_ranef
)
# parse the validated formula into the terms object used by all
# subsequent validation and code generation steps
bterms <- brmsterms(object)
# 'data2' is validated first because 'validate_data' receives it below;
# the two 'get_data2_*' results are passed positionally (i.e. via '...')
data2 <- validate_data2(
data2, bterms = bterms,
get_data2_autocor(object),
get_data2_cov_ranef(object)
)
data <- validate_data(
data, bterms = bterms,
data2 = data2, knots = knots,
drop_unused_levels = drop_unused_levels
)
# priors are validated against the already validated data
prior <- .validate_prior(
prior, bterms = bterms, data = data,
sample_prior = sample_prior
)
stanvars <- validate_stanvars(stanvars, stan_funs = stan_funs)
threads <- validate_threads(threads)
# hand the validated inputs to the internal work function; '...' may carry
# further internal arguments such as 'parse', 'backend' or 'silent'
.stancode(
bterms, data = data, prior = prior,
stanvars = stanvars, threads = threads,
normalize = normalize, save_model = save_model,
...
)
}
# internal work function of 'stancode.default'
# Assembles the complete Stan program from the code snippets returned by
# 'stan_predictor', 'stan_re', 'stan_Xme' and 'stan_global_defs' as well as
# from the user-defined 'stanvars'.
# @param bterms parsed model terms as created in 'stancode.default'
# @param data validated model data
# @param prior validated prior specification
# @param stanvars validated 'stanvars' inserted into the program blocks
# @param threads threading settings; checked via 'use_threading'
# @param normalize passed on to the code generating helpers; also controls
#   whether 'lprior' is defined and filled in the 'transformed parameters'
#   block (TRUE) or in the 'model' block (FALSE)
# @param parse parse the Stan model for automatic syntax checking
# @param backend name of the backend used for parsing
# @param silent silence parsing messages
# @param save_model if a character string, the generated code is
#   additionally written to this file via 'cat'
# @param ... not used in the function body
# @return the Stan code as an object of class c("character", "brmsmodel")
.stancode <- function(bterms, data, prior, stanvars,
threads = threading(),
normalize = getOption("brms.normalize", TRUE),
parse = getOption("brms.parse_stancode", FALSE),
backend = getOption("brms.backend", "rstan"),
silent = TRUE, save_model = NULL, ...) {
# validate scalar control arguments
normalize <- as_one_logical(normalize)
parse <- as_one_logical(parse)
backend <- match.arg(backend, backend_choices())
silent <- as_one_logical(silent)
# collect the pieces of Stan code from the specialized generators
ranef <- tidy_ranef(bterms, data = data)
meef <- tidy_meef(bterms, data = data)
scode_predictor <- stan_predictor(
bterms, data = data, prior = prior,
normalize = normalize, ranef = ranef, meef = meef,
stanvars = stanvars, threads = threads
)
scode_ranef <- stan_re(
ranef, prior = prior, threads = threads, normalize = normalize
)
scode_Xme <- stan_Xme(
meef, prior = prior, threads = threads, normalize = normalize
)
scode_global_defs <- stan_global_defs(
bterms, prior = prior, ranef = ranef, threads = threads
)
# extend Stan's likelihood part
if (use_threading(threads)) {
# threading is activated
# 'scode_predictor' is still a list with one element per name at this
# point; a separate partial log-likelihood function is built for each
for (i in seq_along(scode_predictor)) {
resp <- usc(names(scode_predictor)[i])
# arguments of the partial log-likelihood function; the result provides
# a 'typed' version (for the signature) and a 'plain' version (for the call)
pll_args <- stan_clean_pll_args(
scode_predictor[[i]][["pll_args"]],
scode_ranef[["pll_args"]],
scode_Xme[["pll_args"]],
collapse_stanvars_pll_args(stanvars)
)
# body of the partial log-likelihood function
partial_log_lik <- paste0(
scode_predictor[[i]][["pll_def"]],
scode_predictor[[i]][["model_def"]],
collapse_stanvars(stanvars, "likelihood", "start"),
scode_predictor[[i]][["model_comp_basic"]],
scode_predictor[[i]][["model_comp_eta_basic"]],
scode_predictor[[i]][["model_comp_eta_loop"]],
scode_predictor[[i]][["model_comp_dpar_link"]],
scode_predictor[[i]][["model_comp_dpar_trans"]],
scode_predictor[[i]][["model_comp_mix"]],
scode_predictor[[i]][["model_comp_arma"]],
scode_predictor[[i]][["model_comp_catjoin"]],
scode_predictor[[i]][["model_comp_mvjoin"]],
scode_predictor[[i]][["model_log_lik"]],
collapse_stanvars(stanvars, "likelihood", "end")
)
# inside the function, contributions are accumulated in the local
# variable 'ptarget' (declared and returned below) instead of 'target'
partial_log_lik <- gsub(" target \\+=", " ptarget +=", partial_log_lik)
# wrap the body into the definition of the Stan function
partial_log_lik <- paste0(
"// compute partial sums of the log-likelihood\n",
"real partial_log_lik", resp, "_lpmf(array[] int seq", resp,
", int start, int end", pll_args$typed, ") {\n",
" real ptarget = 0;\n",
" int N = end - start + 1;\n",
partial_log_lik,
" return ptarget;\n",
"}\n"
)
# presumably adjusts the leading whitespace of each line - see 'wsp_per_line'
partial_log_lik <- wsp_per_line(partial_log_lik, 2)
scode_predictor[[i]][["partial_log_lik"]] <- partial_log_lik
# the model block only calls reduce_sum (or reduce_sum_static if
# 'threads$static' is set) on the function defined above
static <- str_if(threads$static, "_static")
scode_predictor[[i]][["model_lik"]] <- paste0(
" target += reduce_sum", static, "(partial_log_lik", resp, "_lpmf",
", seq", resp, ", grainsize", pll_args$plain, ");\n"
)
# index sequence handed to reduce_sum as the object to be sliced
str_add(scode_predictor[[i]][["tdata_def"]]) <- glue(
" array[N{resp}] int seq{resp} = sequence(1, N{resp});\n"
)
}
# merge the per-name lists into a single list of code snippets
scode_predictor <- collapse_lists(ls = scode_predictor)
# the 'model_no_pll_*' snippets stay in the model block itself
scode_predictor[["model_lik"]] <- paste0(
scode_predictor[["model_no_pll_def"]],
scode_predictor[["model_no_pll_comp_basic"]],
scode_predictor[["model_no_pll_comp_mvjoin"]],
scode_predictor[["model_lik"]]
)
str_add(scode_predictor[["data"]]) <-
" int grainsize; // grainsize for threading\n"
} else {
# threading is not activated
# the whole likelihood is written directly into the model block
scode_predictor <- collapse_lists(ls = scode_predictor)
scode_predictor[["model_lik"]] <- paste0(
scode_predictor[["model_no_pll_def"]],
scode_predictor[["model_def"]],
collapse_stanvars(stanvars, "likelihood", "start"),
scode_predictor[["model_no_pll_comp_basic"]],
scode_predictor[["model_comp_basic"]],
scode_predictor[["model_comp_eta_basic"]],
scode_predictor[["model_comp_eta_loop"]],
scode_predictor[["model_comp_dpar_link"]],
scode_predictor[["model_comp_dpar_trans"]],
scode_predictor[["model_comp_mix"]],
scode_predictor[["model_comp_arma"]],
scode_predictor[["model_comp_catjoin"]],
scode_predictor[["model_no_pll_comp_mvjoin"]],
scode_predictor[["model_comp_mvjoin"]],
scode_predictor[["model_log_lik"]],
collapse_stanvars(stanvars, "likelihood", "end")
)
}
# the likelihood is later placed inside 'if (!prior_only) { ... }'
scode_predictor[["model_lik"]] <-
wsp_per_line(scode_predictor[["model_lik"]], 2)
# get all priors added to 'lprior'
scode_tpar_prior <- paste0(
scode_predictor[["tpar_prior"]],
scode_ranef[["tpar_prior"]],
scode_Xme[["tpar_prior"]]
)
# generate functions block
# the first line records the brms version that generated the code
scode_functions <- paste0(
"// generated with brms ", utils::packageVersion("brms"), "\n",
"functions {\n",
scode_global_defs[["fun"]],
collapse_stanvars(stanvars, "functions"),
scode_predictor[["partial_log_lik"]],
"}\n"
)
# generate data block
scode_data <- paste0(
"data {\n",
" int<lower=1> N; // total number of observations\n",
scode_predictor[["data"]],
scode_ranef[["data"]],
scode_Xme[["data"]],
" int prior_only; // should the likelihood be ignored?\n",
collapse_stanvars(stanvars, "data"),
"}\n"
)
# generate transformed data block
scode_transformed_data <- paste0(
"transformed data {\n",
scode_global_defs[["tdata_def"]],
scode_predictor[["tdata_def"]],
collapse_stanvars(stanvars, "tdata", "start"),
scode_predictor[["tdata_comp"]],
collapse_stanvars(stanvars, "tdata", "end"),
"}\n"
)
# generate parameters block
# only the declarations are collected here; the block is wrapped further
# below because 'stan_rngprior' needs the bare declarations as input
scode_parameters <- paste0(
scode_predictor[["par"]],
scode_ranef[["par"]],
scode_Xme[["par"]]
)
# prepare additional sampling from priors
scode_rngprior <- stan_rngprior(
tpar_prior = scode_tpar_prior,
par_declars = scode_parameters,
gen_quantities = scode_predictor[["gen_def"]],
special_prior = attr(prior, "special"),
sample_prior = get_sample_prior(prior)
)
scode_parameters <- paste0(
"parameters {\n",
scode_parameters,
scode_rngprior[["par"]],
collapse_stanvars(stanvars, "parameters"),
"}\n"
)
# generate transformed parameters block
# 'lprior' is defined exactly once: here if 'normalize' is TRUE,
# otherwise in the model block below
scode_lprior_def <- " real lprior = 0; // prior contributions to the log posterior\n"
scode_transformed_parameters <- paste0(
"transformed parameters {\n",
scode_predictor[["tpar_def"]],
scode_ranef[["tpar_def"]],
scode_Xme[["tpar_def"]],
str_if(normalize, scode_lprior_def),
collapse_stanvars(stanvars, "tparameters", "start"),
scode_predictor[["tpar_prior_const"]],
scode_ranef[["tpar_prior_const"]],
scode_Xme[["tpar_prior_const"]],
scode_predictor[["tpar_comp"]],
scode_predictor[["tpar_special_prior"]],
scode_ranef[["tpar_comp"]],
scode_Xme[["tpar_comp"]],
# lprior cannot contain _lupdf functions in transformed parameters
# as discussed on github.com/stan-dev/stan/issues/3094
str_if(normalize, scode_tpar_prior),
collapse_stanvars(stanvars, "tparameters", "end"),
"}\n"
)
# combine likelihood with prior part
# 'not_const' turns the generated comments into "... not including constants"
not_const <- str_if(!normalize, " not")
scode_model <- paste0(
"model {\n",
str_if(!normalize, scode_lprior_def),
collapse_stanvars(stanvars, "model", "start"),
" // likelihood", not_const, " including constants\n",
" if (!prior_only) {\n",
scode_predictor[["model_lik"]],
" }\n",
" // priors", not_const, " including constants\n",
str_if(!normalize, scode_tpar_prior),
" target += lprior;\n",
scode_predictor[["model_prior"]],
scode_ranef[["model_prior"]],
scode_Xme[["model_prior"]],
stan_unchecked_prior(prior),
collapse_stanvars(stanvars, "model", "end"),
"}\n"
)
# generate generated quantities block
scode_generated_quantities <- paste0(
"generated quantities {\n",
scode_predictor[["gen_def"]],
scode_ranef[["gen_def"]],
scode_Xme[["gen_def"]],
scode_rngprior[["gen_def"]],
collapse_stanvars(stanvars, "genquant", "start"),
scode_predictor[["gen_comp"]],
scode_ranef[["gen_comp"]],
scode_rngprior[["gen_comp"]],
scode_Xme[["gen_comp"]],
collapse_stanvars(stanvars, "genquant", "end"),
"}\n"
)
# combine all elements into a complete Stan model
scode <- paste0(
scode_functions,
scode_data,
scode_transformed_data,
scode_parameters,
scode_transformed_parameters,
scode_model,
scode_generated_quantities
)
# replace '#include' statements by the content of the included files
scode <- expand_include_statements(scode)
if (parse) {
scode <- parse_model(scode, backend, silent = silent)
}
# canonicalization via cmdstanr is currently disabled
# if (backend == "cmdstanr") {
# if (requireNamespace("cmdstanr", quietly = TRUE) &&
# cmdstanr::cmdstan_version() >= "2.29.0") {
# tmp_file <- cmdstanr::write_stan_file(scode)
# scode <- .canonicalize_stan_model(tmp_file, overwrite_file = FALSE)
# }
# }
# optionally write the generated code to the file named by 'save_model'
if (is.character(save_model)) {
cat(scode, file = save_model)
}
# the 'brmsmodel' class enables printing via 'print.brmsmodel'
class(scode) <- c("character", "brmsmodel")
scode
}
#' @export
print.brmsmodel <- function(x, ...) {
  # write the raw Stan code to the console so that embedded newlines are
  # rendered as line breaks; multiple elements are separated by a space
  cat(x, sep = " ")
  # return the input invisibly, as is conventional for print methods
  invisible(x)
}
#' Extract Stan code from \code{brmsfit} objects
#'
#' Extract Stan code from a fitted \pkg{brms} model.
#'
#' @param object An object of class \code{brmsfit}.
#' @param version Logical; indicates if the first line containing the \pkg{brms}
#' version number should be included. Defaults to \code{TRUE}.
#' @param regenerate Logical; indicates if the Stan code should be regenerated
#' with the current \pkg{brms} version. By default, \code{regenerate} will be
#' \code{FALSE} unless required to be \code{TRUE} by other arguments.
#' @param threads Controls whether the Stan code should be threaded. See
#' \code{\link{threading}} for details. Only used if the Stan code is
#' regenerated, that is, ignored if \code{regenerate = FALSE}.
#' @param backend Controls the Stan backend. See \code{\link{brm}} for details.
#' Only used if the Stan code is regenerated, that is, ignored if
#' \code{regenerate = FALSE}.
#' @param ... Further arguments passed to
#' \code{\link[brms:stancode.default]{stancode}} if the Stan code is
#' regenerated.
#'
#' @return Stan code for further processing.
#'
#' @export
stancode.brmsfit <- function(object, version = TRUE, regenerate = NULL,
                             threads = NULL, backend = NULL, ...) {
  # with 'regenerate = NULL' this function decides on its own whether
  # regenerating the Stan code is required
  auto_regenerate <- is.null(regenerate)
  regenerate <- if (auto_regenerate) FALSE else as_one_logical(regenerate)
  # user-supplied 'threads' and 'backend' are applied whenever the code may
  # be regenerated, including an explicit 'regenerate = TRUE'; previously
  # they were silently dropped in that case
  if (auto_regenerate || regenerate) {
    if (!missing(threads)) {
      threads <- validate_threads(threads)
      threading_changed <- xor(
        use_threading(threads), use_threading(object$threads)
      )
      if (auto_regenerate && threading_changed) {
        # threading changed; regenerate Stan code
        regenerate <- TRUE
      }
      object$threads <- threads
    }
    if (!missing(backend)) {
      backend <- match.arg(backend, backend_choices())
      # older Stan versions do not support array syntax
      if (auto_regenerate &&
          require_old_stan_syntax(object, backend, "2.29.0")) {
        regenerate <- TRUE
      }
      object$backend <- backend
    }
  }
  if (regenerate) {
    # rebuild the code from the information stored in the fitted model
    object <- restructure(object)
    out <- make_stancode(
      formula = object$formula,
      data = object$data,
      prior = object$prior,
      data2 = object$data2,
      stanvars = object$stanvars,
      sample_prior = get_sample_prior(object$prior),
      threads = object$threads,
      backend = object$backend,
      ...
    )
  } else {
    # extract Stan code unaltered
    out <- object$model
  }
  if (!version) {
    # drop the first line if it contains a version number (digit + dot)
    out <- sub("^[^\n]+[[:digit:]]\\.[^\n]+\n", "", out)
  }
  out
}
# expand '#include' statements
# This could also be done automatically by Stan at compilation time
# but would result in Stan code that is not self-contained until compilation
# @param model Stan code potentially including '#include' statements
# @return Stan code with '#include' statements expanded
expand_include_statements <- function(model) {
  chunk_dir <- system.file("chunks", package = "brms")
  includes <- get_matches("#include '[^']+'", model)
  # removal of duplicates could make code generation easier in the future
  includes <- unique(includes)
  # strip the '#include ' prefix and the quotes to obtain the file names
  files <- gsub("(#include )|(')", "", includes)
  for (i in seq_along(includes)) {
    chunk <- readLines(file.path(chunk_dir, files[i]))
    chunk <- paste0(chunk, collapse = "\n")
    # 'sub' interprets backslashes in the replacement as escapes or
    # backreferences; double them so the chunk is inserted verbatim
    chunk <- gsub("\\\\", "\\\\\\\\", chunk)
    # also consume the spaces preceding the include statement
    pattern <- paste0(" *", escape_all(includes[i]))
    model <- sub(pattern, chunk, model)
  }
  model
}
# check if Stan code includes normalization constants
# @param stancode character vector of Stan code
# @return logical vector which is FALSE for every element that contains
#   a call to an '_lupdf' or '_lupmf' function and TRUE otherwise
is_normalized <- function(stancode) {
  calls_unnormalized <- grepl("_lup(d|m)f\\(", stancode)
  !calls_unnormalized
}
# Normalizes Stan code to avoid triggering refit after whitespace and
# comment changes in the generated code.
# In some distant future, StanC3 may provide its own normalizing functions,
# until then this is a set of regex hacks.
# @param x a string containing the Stan code
# @return a string with comments removed, all runs of whitespace collapsed
#   to a single space and leading / trailing whitespace trimmed
normalize_stancode <- function(x) {
  x <- as_one_character(x)
  # Remove single-line comments
  x <- gsub("//[^\n\r]*[\n\r]", " ", x)
  x <- gsub("//[^\n\r]*$", " ", x)
  # Remove multi-line comments
  # After each run of '*' the comment either ends with '/' or continues with
  # a character that is neither '/' nor '*'. This way, a match can never
  # extend beyond the first '*/', even if the comment ends in '**/'.
  x <- gsub("/\\*[^*]*\\*+([^/*][^*]*\\*+)*/", " ", x)
  # Standardize whitespace (including newlines)
  x <- gsub("[[:space:]]+", " ", x)
  trimws(x)
}
# check if the currently installed Stan version requires older syntax
# than the Stan version with which the model was initially fitted
# @param object a 'brmsfit' object; its 'backend' and 'version' elements
#   describe the Stan interface used for fitting
# @param backend name of the backend whose installed version is checked
# @param version version threshold against which both versions are compared
# @return TRUE only if the model was fitted with a Stan version of at least
#   'version' while the installed version of 'backend' is below 'version';
#   FALSE otherwise, including when the condition evaluates to NA
require_old_stan_syntax <- function(object, backend, version) {
stopifnot(is.brmsfit(object))
# The short-circuit evaluation of '&&' and '||' is essential here: version
# information is only accessed for the backend that actually applies, so
# that 'cmdstanr' is not called unless 'backend' is "cmdstanr".
# NOTE(review): 'cmdstanr::cmdstan_version()' presumably returns a character
# string, in which case '<' compares strings rather than version numbers -
# TODO confirm
isTRUE(
(object$backend == "rstan" && object$version$rstan >= version ||
object$backend == "cmdstanr" && object$version$cmdstan >= version) &&
(backend == "rstan" && utils::packageVersion("rstan") < version ||
backend == "cmdstanr" && cmdstanr::cmdstan_version() < version)
)
}