Skip to content

Commit

Permalink
extend scope of variables passed via 'data2'
Browse files Browse the repository at this point in the history
  • Loading branch information
paul-buerkner committed Feb 6, 2021
1 parent 12a7474 commit e45f64e
Show file tree
Hide file tree
Showing 12 changed files with 140 additions and 40 deletions.
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ Package: brms
Encoding: UTF-8
Type: Package
Title: Bayesian Regression Models using 'Stan'
Version: 2.14.10
Date: 2021-02-04
Version: 2.14.11
Date: 2021-02-06
Authors@R:
c(person("Paul-Christian", "Bürkner", email = "paul.buerkner@gmail.com",
role = c("aut", "cre")),
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ even if the required CDF or quantile functions are unavailable.
* Apply the R2-D2 shrinkage prior to population-level coefficients
via function `R2D2` to be used in `set_prior`.
* Extend support for `arma` correlation structures in non-normal families.
* Extend scope of variables passed via `data2` for use in the
evaluation of most model terms.

### Bug Fixes

Expand Down
9 changes: 6 additions & 3 deletions R/brm.R
Original file line number Diff line number Diff line change
Expand Up @@ -450,14 +450,17 @@ brm <- function(formula, data, family = gaussian(), prior = NULL,
)
family <- get_element(formula, "family")
bterms <- brmsterms(formula)
data_name <- substitute_name(data)
data <- validate_data(data, bterms = bterms, knots = knots)
attr(data, "data_name") <- data_name
data2 <- validate_data2(
data2, bterms = bterms,
get_data2_autocor(formula),
get_data2_cov_ranef(formula)
)
data_name <- substitute_name(data)
data <- validate_data(
data, bterms = bterms,
data2 = data2, knots = knots
)
attr(data, "data_name") <- data_name
prior <- .validate_prior(
prior, bterms = bterms, data = data,
sample_prior = sample_prior
Expand Down
58 changes: 46 additions & 12 deletions R/data-helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
# of variable names; fixes issue #73
# @param knots: a list of knot values for GAMMs
# @return model.frame for use in brms functions
validate_data <- function(data, bterms, na.action = na.omit2,
drop.unused.levels = TRUE, knots = NULL,
validate_data <- function(data, bterms, data2 = NULL, knots = NULL,
na.action = na.omit2, drop.unused.levels = TRUE,
attr_terms = NULL) {
if (missing(data)) {
stop2("Data must be specified using the 'data' argument.")
Expand All @@ -26,18 +26,32 @@ validate_data <- function(data, bterms, na.action = na.omit2,
if (!isTRUE(nrow(data) > 0L)) {
stop2("Argument 'data' does not contain observations.")
}
all_terms <- terms(bterms$allvars)
attributes(all_terms)[names(attr_terms)] <- attr_terms
data <- data_rsv_intercept(data, bterms = bterms)
missing_vars <- setdiff(all_vars(all_terms), names(data))
all_vars_formula <- bterms$allvars
missing_vars <- setdiff(all_vars(all_vars_formula), names(data))
if (length(missing_vars)) {
stop2("The following variables are missing in 'data':\n",
collapse_comma(missing_vars))
}
missing_vars2 <- setdiff(missing_vars, names(data2))
if (length(missing_vars2)) {
stop2("The following variables can neither be found in ",
"'data' nor in 'data2':\n", collapse_comma(missing_vars2))
}
# all initially missing variables can be found in 'data2'
# they are not necessarily of the length required for 'data'
# so need to be excluded from the evaluation of 'model.frame'
missing_vars_formula <- paste0(". ~ . ", collapse(" - ", missing_vars))
all_vars_formula <- update(all_vars_formula, missing_vars_formula)
}
all_vars_terms <- terms(all_vars_formula)
# ensure that 'data2' comes first in the search path
# during the evaluation of model.frame
terms_env <- environment(all_vars_terms)
environment(all_vars_terms) <- as.environment(data2)
parent.env(environment(all_vars_terms)) <- terms_env
attributes(all_vars_terms)[names(attr_terms)] <- attr_terms
# 'terms' prevents correct validation in 'model.frame'
attr(data, "terms") <- NULL
data <- model.frame(
all_terms, data, na.action = na.pass,
all_vars_terms, data, na.action = na.pass,
drop.unused.levels = drop.unused.levels
)
data <- na.action(data, bterms = bterms)
Expand All @@ -64,6 +78,7 @@ validate_data <- function(data, bterms, na.action = na.omit2,
# @return a validated named list of data objects
validate_data2 <- function(data2, bterms, ...) {
# TODO: specify spline-related matrices in 'data2'
# this requires adding another parser layer with bterms and data as input
if (is.null(data2)) {
data2 <- list()
}
Expand Down Expand Up @@ -376,8 +391,8 @@ get_data_name <- function(data) {
#' @export
validate_newdata <- function(
newdata, object, re_formula = NULL, allow_new_levels = FALSE,
resp = NULL, check_response = TRUE, incl_autocor = TRUE,
all_group_vars = NULL, req_vars = NULL, ...
newdata2 = NULL, resp = NULL, check_response = TRUE,
incl_autocor = TRUE, all_group_vars = NULL, req_vars = NULL, ...
) {
newdata <- try(as.data.frame(newdata), silent = TRUE)
if (is(newdata, "try-error")) {
Expand Down Expand Up @@ -551,6 +566,7 @@ validate_newdata <- function(
newdata <- validate_data(
newdata, bterms = bterms, na.action = na.pass,
drop.unused.levels = FALSE, attr_terms = attr_terms,
data2 = current_data2(object, newdata2),
knots = get_knots(object$data)
)
newdata
Expand Down Expand Up @@ -585,7 +601,14 @@ fill_newdata <- function(newdata, vars, olddata = NULL, n = 1L) {
newdata
}

# extract the current data set
# validate new data2
validate_newdata2 <- function(newdata2, object, ...) {
stopifnot(is.brmsfit(object))
bterms <- brmsterms(object$formula)
validate_data2(newdata2, bterms = bterms, ...)
}

# extract the current data
current_data <- function(object, newdata = NULL, ...) {
stopifnot(is.brmsfit(object))
if (is.null(newdata)) {
Expand All @@ -595,3 +618,14 @@ current_data <- function(object, newdata = NULL, ...) {
}
data
}

# extract the current data2
current_data2 <- function(object, newdata2 = NULL, ...) {
stopifnot(is.brmsfit(object))
if (is.null(newdata2)) {
data2 <- object$data2
} else {
data2 <- validate_newdata2(newdata2, object = object, ...)
}
data2
}
12 changes: 10 additions & 2 deletions R/make_stancode.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
#'
#' @export
make_stancode <- function(formula, data, family = gaussian(),
prior = NULL, autocor = NULL,
prior = NULL, autocor = NULL, data2 = NULL,
cov_ranef = NULL, sparse = NULL,
sample_prior = "no", stanvars = NULL,
stan_funs = NULL, knots = NULL,
Expand All @@ -33,7 +33,15 @@ make_stancode <- function(formula, data, family = gaussian(),
cov_ranef = cov_ranef
)
bterms <- brmsterms(formula)
data <- validate_data(data, bterms = bterms, knots = knots)
data2 <- validate_data2(
data2, bterms = bterms,
get_data2_autocor(formula),
get_data2_cov_ranef(formula)
)
data <- validate_data(
data, bterms = bterms,
data2 = data2, knots = knots
)
prior <- .validate_prior(
prior, bterms = bterms, data = data,
sample_prior = sample_prior
Expand Down
35 changes: 18 additions & 17 deletions R/make_standata.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,20 @@ make_standata <- function(formula, data, family = gaussian(), prior = NULL,
autocor = autocor, cov_ranef = cov_ranef
)
bterms <- brmsterms(formula)
data <- validate_data(data, bterms = bterms, knots = knots)
prior <- .validate_prior(
prior, bterms = bterms, data = data,
sample_prior = sample_prior,
require_nlpar_prior = FALSE
)
data2 <- validate_data2(
data2, bterms = bterms,
get_data2_autocor(formula),
get_data2_cov_ranef(formula)
)
data <- validate_data(
data, bterms = bterms,
knots = knots, data2 = data2
)
prior <- .validate_prior(
prior, bterms = bterms, data = data,
sample_prior = sample_prior,
require_nlpar_prior = FALSE
)
stanvars <- validate_stanvars(stanvars)
threads <- validate_threads(threads)
.make_standata(
Expand Down Expand Up @@ -137,18 +140,16 @@ standata.brmsfit <- function(object, newdata = NULL, re_formula = NULL,
on.exit(options(.brmsfit_version = NULL))

object <- exclude_terms(object, incl_autocor = incl_autocor)
newdata2 <- use_alias(newdata2, new_objects)
formula <- update_re_terms(object$formula, re_formula)
bterms <- brmsterms(formula)
data <- current_data(object, newdata, re_formula = re_formula, ...)
stanvars <- object$stanvars
threads <- object$threads
if (is.null(newdata2)) {
data2 <- object$data2
} else {
data2 <- validate_data2(newdata2, bterms = bterms)
stanvars <- add_newdata_stanvars(stanvars, data2)
}

newdata2 <- use_alias(newdata2, new_objects)
data2 <- current_data2(object, newdata2)
data <- current_data(
object, newdata, newdata2 = data2,
re_formula = re_formula, ...
)
stanvars <- add_newdata_stanvars(object$stanvars, data2)

basis <- NULL
if (!is.null(newdata)) {
Expand All @@ -159,7 +160,7 @@ standata.brmsfit <- function(object, newdata = NULL, re_formula = NULL,
.make_standata(
bterms, data = data, prior = object$prior,
data2 = data2, stanvars = stanvars,
threads = threads, basis = basis, ...
threads = object$threads, basis = basis, ...
)
}

Expand Down
20 changes: 16 additions & 4 deletions R/priors.R
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ prior_string <- function(prior, ...) {
#'
#' @export
get_prior <- function(formula, data, family = gaussian(), autocor = NULL,
knots = NULL, sparse = NULL, ...) {
data2 = NULL, knots = NULL, sparse = NULL, ...) {
if (is.brmsfit(formula)) {
stop2("Use 'prior_summary' to extract priors from 'brmsfit' objects.")
}
Expand All @@ -486,7 +486,14 @@ get_prior <- function(formula, data, family = gaussian(), autocor = NULL,
autocor = autocor, sparse = sparse
)
bterms <- brmsterms(formula)
data <- validate_data(data, bterms = bterms, knots = knots)
data2 <- validate_data2(
data2, bterms = bterms,
get_data2_autocor(formula)
)
data <- validate_data(
data, bterms = bterms,
data2 = data2, knots = knots
)
.get_prior(bterms, data, ...)
}

Expand Down Expand Up @@ -1118,10 +1125,15 @@ def_scale_prior.brmsterms <- function(x, data, center = TRUE, df = 3,
#'
#' @export
validate_prior <- function(prior, formula, data, family = gaussian(),
sample_prior = "no", knots = NULL, ...) {
sample_prior = "no", data2 = NULL, knots = NULL,
...) {
formula <- validate_formula(formula, data = data, family = family)
bterms <- brmsterms(formula)
data <- validate_data(data, bterms = bterms, knots = knots)
data2 <- validate_data2(data2, bterms = bterms)
data <- validate_data(
data, bterms = bterms,
data2 = data2, knots = knots
)
.validate_prior(
prior, bterms = bterms, data = data,
sample_prior = sample_prior, ...
Expand Down
6 changes: 6 additions & 0 deletions man/get_prior.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions man/make_stancode.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions man/validate_newdata.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions man/validate_prior.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 17 additions & 0 deletions tests/testthat/tests.make_standata.R
Original file line number Diff line number Diff line change
Expand Up @@ -985,3 +985,20 @@ test_that("information for threading is handled correctly", {
sdata <- make_standata(y ~ 1, dat, threads = threading(2, grainsize = 3))
expect_equal(sdata$grainsize, 3)
})

test_that("variables in data2 can be used in population-level effects", {
dat <- data.frame(y = 1:10, x1 = rnorm(10), x2 = rnorm(10), x3 = rnorm(10))
foo <- function(..., idx = NULL) {
out <- cbind(...)
if (!is.null(idx)) {
out <- out[, idx, drop = FALSE]
}
out
}
sdata <- make_standata(y ~ foo(x1, x2, x3, idx = id), data = dat,
data2 = list(id = c(3, 1)))
target <- c("Intercept", "foox1x2x3idxEQidx3", "foox1x2x3idxEQidx1")
expect_equal(colnames(sdata$X), target)
expect_equivalent(sdata$X[, 2], dat$x3)
expect_equivalent(sdata$X[, 3], dat$x1)
})

0 comments on commit e45f64e

Please sign in to comment.