Skip to content

Commit

Permalink
Merge pull request #1459 from fweber144/newdata_issue
Browse files Browse the repository at this point in the history
Fix #1457 (projpred `newdata` requiring more variables than necessary)
  • Loading branch information
paul-buerkner committed Feb 15, 2023
2 parents b31fcee + d3bf58c commit c793d9f
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 21 deletions.
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ Package: brms
Encoding: UTF-8
Type: Package
Title: Bayesian Regression Models using 'Stan'
Version: 2.18.7
Date: 2023-01-17
Version: 2.18.8
Date: 2023-02-14
Authors@R:
c(person("Paul-Christian", "Bürkner", email = "paul.buerkner@gmail.com",
role = c("aut", "cre")),
Expand Down
6 changes: 4 additions & 2 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,16 @@ via the `hurdle_cumulative` family thanks to Stephen Wild. (#1448)
in post-processing methods that require a compiled Stan model.
* Extend control over the `point_estimate` feature in `prepare_predictions`
via the new argument `ndraws_point_estimate`.
* Add support for the latent projection available in **projpred** versions >=
2.4.0.
* Add support for the latent projection available in
**projpred** versions >= 2.4.0. (#1451)

### Bug Fixes

* Fix a Stan syntax error in threaded models with `lasso` priors. (#1427)
* Fix Stan compilation issues for some of the more special
link functions such as `cauchit` or `softplus`.
* Fix a bug for predictions in **projpred**, previously requiring more variables
in `newdata` than necessary. (#1457, #1459)


# brms 2.18.0
Expand Down
45 changes: 30 additions & 15 deletions R/projpred.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
#' \code{\link[projpred:init_refmodel]{init_refmodel}}.
#'
#' @details Note that the \code{extract_model_data} function used internally by
#' \code{get_refmodel.brmsfit} ignores arguments \code{wrhs}, \code{orhs}, and
#' \code{extract_y}. This is relevant for
#' \code{get_refmodel.brmsfit} ignores arguments \code{wrhs} and \code{orhs}.
#' This is relevant for
#' \code{\link[projpred:predict.refmodel]{predict.refmodel}}, for example.
#'
#' @return A \code{refmodel} object to be used in conjunction with the
Expand Down Expand Up @@ -225,7 +225,8 @@ get_refmodel.brmsfit <- function(object, newdata = NULL, resp = NULL,

# auxiliary data required in predictions via projpred
# @return a named list with slots 'y', 'weights', and 'offset'
.extract_model_data <- function(object, newdata = NULL, resp = NULL, ...) {
.extract_model_data <- function(object, newdata = NULL, resp = NULL,
extract_y = TRUE, ...) {
stopifnot(is.brmsfit(object))
resp <- validate_resp(resp, object, multiple = FALSE)

Expand All @@ -235,28 +236,42 @@ get_refmodel.brmsfit <- function(object, newdata = NULL, resp = NULL,
if (!is.null(resp)) {
formula <- formula$forms[[resp]]
}
respform <- brmsterms(formula)$respform
data <- current_data(
object, newdata, resp = resp, check_response = TRUE,
allow_new_levels = TRUE
)
y <- unname(model.response(model.frame(respform, data, na.action = na.pass)))
bterms <- brmsterms(formula)
y <- NULL
if (extract_y) {
data <- current_data(
object, newdata, resp = resp, check_response = TRUE,
allow_new_levels = TRUE, req_vars = character()
)
y <- model.response(model.frame(bterms$respform, data, na.action = na.pass))
y <- unname(y)
}

# extract relevant auxiliary data
# extract relevant auxiliary data (offsets and weights (or numbers of trials))
# call standata to ensure the correct format of the data
# For this, we use `check_response = FALSE` and only include offsets and
# weights (or numbers of trials) in `req_vars` because of issue #1457 (note
# that all.vars(NULL) gives character(0), as desired).
req_vars <- unlist(lapply(bterms$dpars, function(x) all.vars(x[["offset"]])))
req_vars <- unique(req_vars)
c(req_vars) <- all.vars(bterms$adforms$weights)
c(req_vars) <- all.vars(bterms$adforms$trials)
args <- nlist(
object, newdata, resp,
allow_new_levels = TRUE,
check_response = TRUE,
internal = TRUE
check_response = FALSE,
internal = TRUE,
req_vars = req_vars
)
# NOTE: Missing weights don't cause an error here (see #1459)
sdata <- do_call(standata, args)

usc_resp <- usc(resp)
N <- sdata[[paste0("N", usc_resp)]]
weights <- as.vector(sdata[[paste0("weights", usc_resp)]])
trials <- as.vector(sdata[[paste0("trials", usc_resp)]])
if (is_binary(formula)) {
trials <- rep(1, length(y))
trials <- rep(1, N)
}
if (!is.null(trials)) {
if (!is.null(weights)) {
Expand All @@ -265,11 +280,11 @@ get_refmodel.brmsfit <- function(object, newdata = NULL, resp = NULL,
weights <- trials
}
if (is.null(weights)) {
weights <- rep(1, length(y))
weights <- rep(1, N)
}
offset <- as.vector(sdata[[paste0("offsets", usc_resp)]])
if (is.null(offset)) {
offset <- rep(0, length(y))
offset <- rep(0, N)
}
nlist(y, weights, offset)
}
Expand Down
4 changes: 2 additions & 2 deletions man/get_refmodel.brmsfit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit c793d9f

Please sign in to comment.