Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Proposal: transform get_prior, make_stancode and make_standata into S3 methods #1604

Merged
merged 8 commits into from
Mar 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ Package: brms
Encoding: UTF-8
Type: Package
Title: Bayesian Regression Models using 'Stan'
Version: 2.20.13
Version: 2.20.14
Date: 2024-02-27
Authors@R:
c(person("Paul-Christian", "Bürkner", email = "paul.buerkner@gmail.com",
Expand Down Expand Up @@ -98,4 +98,4 @@ Additional_repositories:
VignetteBuilder:
knitr,
R.rsp
RoxygenNote: 7.2.3
RoxygenNote: 7.3.1
5 changes: 5 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ S3method(data_response,brmsterms)
S3method(data_response,mvbrmsterms)
S3method(def_scale_prior,brmsterms)
S3method(def_scale_prior,mvbrmsterms)
S3method(default_prior,brmsfit)
S3method(default_prior,default)
S3method(dpar_family,default)
S3method(dpar_family,mixfamily)
S3method(duplicated,brmsprior)
Expand Down Expand Up @@ -249,7 +251,9 @@ S3method(stan_predictor,btl)
S3method(stan_predictor,btnl)
S3method(stan_predictor,mvbrmsterms)
S3method(stancode,brmsfit)
S3method(stancode,default)
S3method(standata,brmsfit)
S3method(standata,default)
S3method(standata_basis,brmsterms)
S3method(standata_basis,btl)
S3method(standata_basis,btnl)
Expand Down Expand Up @@ -376,6 +380,7 @@ export(data_predictor)
export(data_response)
export(dbeta_binomial)
export(ddirichlet)
export(default_prior)
export(density_ratio)
export(dexgaussian)
export(dfrechet)
Expand Down
6 changes: 6 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ if potentially results-changing arguments are provided to the criterion method.

### Other Changes

* Change `make_stancode` and `make_standata` to be aliases of `stancode` and
`standata`, respectively. Change `get_prior` to be an alias of a new generic
method `default_prior`. This enable other packages to define new `stancode`,
`standata` and `default_prior` methods to generate Stan code and data, and extract
the default priors, for their own objects building on brms. Thanks to Ven Popov
for helping with this. (#1604)
* No longer automatically canonicalize the Stan code if cmdstanr is used
as backend. (#1544)
* Improve parameter class names in the `summary` output.
Expand Down
24 changes: 12 additions & 12 deletions R/brm.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
#' \code{family} might also be a list of families.
#' @param prior One or more \code{brmsprior} objects created by
#' \code{\link{set_prior}} or related functions and combined using the
#' \code{c} method or the \code{+} operator. See also \code{\link{get_prior}}
#' \code{c} method or the \code{+} operator. See also \code{\link[brms:default_prior.default]{default_prior}}
#' for more help.
#' @param data2 A named \code{list} of objects containing data, which
#' cannot be passed via argument \code{data}. Required for some objects
Expand Down Expand Up @@ -271,7 +271,7 @@
#' \code{\link[brms:set_prior]{set_prior}} function. Its documentation
#' contains detailed information on how to correctly specify priors. To find
#' out on which parameters or parameter classes priors can be defined, use
#' \code{\link[brms:get_prior]{get_prior}}. Default priors are chosen to be
#' \code{\link[brms:default_prior.default]{default_prior}}. Default priors are chosen to be
#' non or very weakly informative so that their influence on the results will
#' be negligible and you usually don't have to worry about them. However,
#' after getting more familiar with Bayesian statistics, I recommend you to
Expand Down Expand Up @@ -318,12 +318,12 @@
#' @examples
#' \dontrun{
#' # Poisson regression for the number of seizures in epileptic patients
#' # using normal priors for population-level effects
#' # and half-cauchy priors for standard deviations of group-level effects
#' prior1 <- prior(normal(0, 10), class = b) +
#' prior(cauchy(0, 2), class = sd)
#' fit1 <- brm(count ~ zBase * Trt + (1|patient), data = epilepsy,
#' family = poisson(), prior = prior1)
#' fit1 <- brm(
#' count ~ zBase * Trt + (1|patient),
#' data = epilepsy, family = poisson(),
#' prior = prior(normal(0, 10), class = b) +
#' prior(cauchy(0, 2), class = sd)
#' )
#'
#' # generate a summary of the results
#' summary(fit1)
Expand Down Expand Up @@ -418,8 +418,8 @@
#'
#'
#' # fit a model manually via rstan
#' scode <- make_stancode(count ~ Trt, data = epilepsy)
#' sdata <- make_standata(count ~ Trt, data = epilepsy)
#' scode <- stancode(count ~ Trt, data = epilepsy)
#' sdata <- standata(count ~ Trt, data = epilepsy)
#' stanfit <- rstan::stan(model_code = scode, data = sdata)
#' # feed the Stan model back into brms
#' fit8 <- brm(count ~ Trt, data = epilepsy, empty = TRUE)
Expand Down Expand Up @@ -537,7 +537,7 @@ brm <- function(formula, data, family = gaussian(), prior = NULL,
)
ranef <- tidy_ranef(bterms, data = data)
# generate Stan code
model <- .make_stancode(
model <- .stancode(
bterms, data = data, prior = prior,
stanvars = stanvars, save_model = save_model,
backend = backend, threads = threads, opencl = opencl,
Expand All @@ -556,7 +556,7 @@ brm <- function(formula, data, family = gaussian(), prior = NULL,
exclude <- exclude_pars(x)
# generate Stan data before compiling the model to avoid
# unnecessary compilations in case of invalid data
sdata <- .make_standata(
sdata <- .standata(
bterms, data = data, prior = prior, data2 = data2,
stanvars = stanvars, threads = threads
)
Expand Down
4 changes: 2 additions & 2 deletions R/brms-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
#' formula syntax to specify a wide range of complex Bayesian models
#' (see \code{\link{brmsformula}} for details). Based on the supplied
#' formulas, data, and additional information, it writes the Stan code
#' on the fly via \code{\link{make_stancode}}, prepares the data via
#' \code{\link{make_standata}}, and fits the model using
#' on the fly via \code{\link[brms:stancode.default]{stancode}}, prepares the data via
#' \code{\link[brms:standata.default]{standata}} and fits the model using
#' \pkg{\link[rstan:rstan]{Stan}}.
#'
#' Subsequently, a large number of post-processing methods can be applied:
Expand Down
4 changes: 2 additions & 2 deletions R/brmsfit-helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -871,9 +871,9 @@ validate_cores_post_processing <- function(cores) {
#' and scripts should not use it.
#'
#' @param fit Old \code{brmsfit} object (e.g., loaded from file).
#' @param sdata New Stan data (result of a call to \code{\link{make_standata}}).
#' @param sdata New Stan data (result of a call to \code{\link[brms:standata.default]{standata}}).
#' Pass \code{NULL} to avoid this data check.
#' @param scode New Stan code (result of a call to \code{\link{make_stancode}}).
#' @param scode New Stan code (result of a call to \code{\link[brms:stancode.default]{stancode}}).
#' Pass \code{NULL} to avoid this code check.
#' @param data New data to check consistency of factor level names.
#' Pass \code{NULL} to avoid this data check.
Expand Down
2 changes: 1 addition & 1 deletion R/families.R
Original file line number Diff line number Diff line change
Expand Up @@ -927,7 +927,7 @@ acat <- function(link = "logit", link_disc = "log",
#' pp_check(fit4)
#'
#' ## compare model fit
#' LOO(fit1, fit2, fit3, fit4)
#' loo(fit1, fit2, fit3, fit4)
#' }
#'
#' @export
Expand Down
2 changes: 1 addition & 1 deletion R/formula-gp.R
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@
#' plot(me3, ask = FALSE, points = TRUE)
#'
#' # compare model fit
#' LOO(fit1, fit2, fit3)
#' loo(fit1, fit2, fit3)
#'
#' # simulate data with a factor covariate
#' dat2 <- mgcv::gamSim(4, n = 90, scale = 2)
Expand Down
132 changes: 85 additions & 47 deletions R/make_stancode.R
Original file line number Diff line number Diff line change
@@ -1,44 +1,89 @@
#' @title Stan Code for Bayesian models
#'
#' @description \code{stancode} is a generic function that can be used to
#' generate Stan code for Bayesian models. It's original use is
#' within the \pkg{brms} package, but new methods for use
#' with objects from other packages can be registered to the same generic.
#'
#' @param object An object whose class will determine which method to apply.
#' Usually, it will be some kind of symbolic description of the model
#' form which Stan code should be generated.
#' @param formula Synonym of \code{object} for use in \code{make_stancode}.
#' @param ... Further arguments passed to the specific method.
#'
#' @return Usually, a character string containing the generated Stan code.
#' For pretty printing, we recommend the returned object to be of class
#' \code{c("character", "brmsmodel")}.
#'
#' @details
#' See \code{\link[brms:stancode.default]{stancode.default}} for the default
#' method applied for \pkg{brms} models.
#' You can view the available methods by typing: \code{methods(stancode)}
#' The \code{make_stancode} function is an alias of \code{stancode}.
#'
#' @seealso
#' \code{\link{stancode.default}}, \code{\link{stancode.brmsfit}}
#'
#' @examples
#' stancode(rating ~ treat + period + carry + (1|subject),
#' data = inhaler, family = "cumulative")
#'
#' @export
stancode <- function(object, ...) {
UseMethod("stancode")
}

#' @rdname stancode
#' @export
make_stancode <- function(formula, ...) {
# became an alias of 'stancode' in 2.20.14
stancode(formula, ...)
}

#' Stan Code for \pkg{brms} Models
#'
#' Generate Stan code for \pkg{brms} models
#'
#' @inheritParams brm
#' @param object An object of class \code{\link[stats:formula]{formula}},
#' \code{\link{brmsformula}}, or \code{\link{mvbrmsformula}} (or one that can
#' be coerced to that classes): A symbolic description of the model to be
#' fitted. The details of model specification are explained in
#' \code{\link{brmsformula}}.
#' @param ... Other arguments for internal usage only.
#'
#' @return A character string containing the fully commented \pkg{Stan} code
#' to fit a \pkg{brms} model.
#' to fit a \pkg{brms} model. It is of class \code{c("character", "brmsmodel")}
#' to facilitate pretty printing.
#'
#' @examples
#' make_stancode(rating ~ treat + period + carry + (1|subject),
#' data = inhaler, family = "cumulative")
#' stancode(rating ~ treat + period + carry + (1|subject),
#' data = inhaler, family = "cumulative")
#'
#' make_stancode(count ~ zAge + zBase * Trt + (1|patient),
#' data = epilepsy, family = "poisson")
#' stancode(count ~ zAge + zBase * Trt + (1|patient),
#' data = epilepsy, family = "poisson")
#'
#' @export
make_stancode <- function(formula, data, family = gaussian(),
prior = NULL, autocor = NULL, data2 = NULL,
cov_ranef = NULL, sparse = NULL,
sample_prior = "no", stanvars = NULL,
stan_funs = NULL, knots = NULL,
drop_unused_levels = TRUE,
threads = getOption("brms.threads", NULL),
normalize = getOption("brms.normalize", TRUE),
save_model = NULL, ...) {
stancode.default <- function(object, data, family = gaussian(),
prior = NULL, autocor = NULL, data2 = NULL,
cov_ranef = NULL, sparse = NULL,
sample_prior = "no", stanvars = NULL,
stan_funs = NULL, knots = NULL,
drop_unused_levels = TRUE,
threads = getOption("brms.threads", NULL),
normalize = getOption("brms.normalize", TRUE),
save_model = NULL, ...) {

if (is.brmsfit(formula)) {
stop2("Use 'stancode' to extract Stan code from 'brmsfit' objects.")
}
formula <- validate_formula(
formula, data = data, family = family,
object <- validate_formula(
object, data = data, family = family,
autocor = autocor, sparse = sparse,
cov_ranef = cov_ranef
)
bterms <- brmsterms(formula)
bterms <- brmsterms(object)
data2 <- validate_data2(
data2, bterms = bterms,
get_data2_autocor(formula),
get_data2_cov_ranef(formula)
get_data2_autocor(object),
get_data2_cov_ranef(object)
)
data <- validate_data(
data, bterms = bterms,
Expand All @@ -52,24 +97,25 @@ make_stancode <- function(formula, data, family = gaussian(),
stanvars <- validate_stanvars(stanvars, stan_funs = stan_funs)
threads <- validate_threads(threads)

.make_stancode(
.stancode(
bterms, data = data, prior = prior,
stanvars = stanvars, threads = threads,
normalize = normalize, save_model = save_model,
...
)
}

# internal work function of 'make_stancode'
# internal work function of 'stancode.default'
# @param parse parse the Stan model for automatic syntax checking
# @param backend name of the backend used for parsing
# @param silent silence parsing messages
.make_stancode <- function(bterms, data, prior, stanvars,
.stancode <- function(bterms, data, prior, stanvars,
threads = threading(),
normalize = getOption("brms.normalize", TRUE),
parse = getOption("brms.parse_stancode", FALSE),
backend = getOption("brms.backend", "rstan"),
silent = TRUE, save_model = NULL, ...) {

normalize <- as_one_logical(normalize)
parse <- as_one_logical(parse)
backend <- match.arg(backend, backend_choices())
Expand Down Expand Up @@ -329,31 +375,29 @@ print.brmsmodel <- function(x, ...) {
invisible(x)
}

#' Extract Stan model code
#'
#' Extract Stan code that was used to specify the model.
#' Extract Stan code from \code{brmsfit} objects
#'
#' @aliases stancode.brmsfit
#' Extract Stan code from a fitted \pkg{brms} model.
#'
#' @param object An object of class \code{brmsfit}.
#' @param version Logical; indicates if the first line containing
#' the \pkg{brms} version number should be included.
#' Defaults to \code{TRUE}.
#' @param regenerate Logical; indicates if the Stan code should
#' be regenerated with the current \pkg{brms} version.
#' By default, \code{regenerate} will be \code{FALSE} unless required
#' to be \code{TRUE} by other arguments.
#' @param threads Controls whether the Stan code should be threaded.
#' See \code{\link{threading}} for details.
#' @param version Logical; indicates if the first line containing the \pkg{brms}
#' version number should be included. Defaults to \code{TRUE}.
#' @param regenerate Logical; indicates if the Stan code should be regenerated
#' with the current \pkg{brms} version. By default, \code{regenerate} will be
#' \code{FALSE} unless required to be \code{TRUE} by other arguments.
#' @param threads Controls whether the Stan code should be threaded. See
#' \code{\link{threading}} for details.
#' @param backend Controls the Stan backend. See \code{\link{brm}} for details.
#' @param ... Further arguments passed to \code{\link{make_stancode}} if the
#' Stan code is regenerated.
#' @param ... Further arguments passed to
#' \code{\link[brms:stancode.default]{stancode}} if the Stan code is
#' regenerated.
#'
#' @return Stan model code for further processing.
#' @return Stan code for further processing.
#'
#' @export
stancode.brmsfit <- function(object, version = TRUE, regenerate = NULL,
threads = NULL, backend = NULL, ...) {

if (is.null(regenerate)) {
# determine whether regenerating the Stan code is required
regenerate <- FALSE
Expand Down Expand Up @@ -400,12 +444,6 @@ stancode.brmsfit <- function(object, version = TRUE, regenerate = NULL,
out
}

#' @rdname stancode.brmsfit
#' @export
stancode <- function(object, ...) {
UseMethod("stancode")
}

# expand '#include' statements
# This could also be done automatically by Stan at compilation time
# but would result in Stan code that is not self-contained until compilation
Expand Down