Skip to content

Commit

Permalink
Use correct dim size for BART predictions when chains = 1
Browse files Browse the repository at this point in the history
  • Loading branch information
vdorie committed Apr 27, 2023
1 parent 2e474c4 commit 6fc1a76
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 4 deletions.
2 changes: 2 additions & 0 deletions R/generics.R
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,8 @@ predict.stan4bartFit <-
if (is.null(object$sampler.bart))
stop("predict for bart components requires 'bart_args' to contain 'keepTrees' as 'TRUE'")
indiv.bart <- .Call(C_stan4bart_predictBART, object$sampler.bart, testData$X.bart, NULL)
if (length(dim(indiv.bart)) == 2L)
dim(indiv.bart) <- c(dim(indiv.bart), 1L)
dimnames(indiv.bart) <- list(observation = NULL, sample = NULL, chain = NULL)
if (!is_bernoulli) for (i_chain in seq_len(n_chains)) {
indiv.bart[,,i_chain] <- object$range.bart["min",i_chain] +
Expand Down
24 changes: 24 additions & 0 deletions R/misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,27 @@ strip_extra_terms <- function(terms, extra_terms) {
terms
}

delete.weights <- function(termobj, weights)
{
a <- attributes(termobj)
termobj <- strip_extra_terms_from_language(termobj, weights)

w <- which(sapply(a$variables, "==", weights)) - 1L
if (length(w) == 1L && w > 0L) {
a$variables <- a$variables[-(1 + w)]
a$predvars <- a$predvars[-(1 + w)]
if (length(a$factors) > 0L)
a$factors <- a$factors[-w,,drop = FALSE]
if (length(a$offset) > 0L)
a$offset <- ifelse(a$offset > w, a$offset - 1, a$offset)
if (length(a$specials) > 0L) {
for (i in seq_along(a$specials)) {
b <- a$specials[[i]]
a$specials[[i]] <- ifelse(b > w, b - 1, b)
}
}
attributes(termobj) <- a
}

termobj
}
8 changes: 6 additions & 2 deletions R/stan4bart.R
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,12 @@ stan4bart <-
result$test <- testDataFrames

if (!is.null(offset_test)) result$test$offset <- offset_test
if (!is.null(weights) && length(weights) > 0L)
result$test$frame[["(weights)"]] <- with(result$test$frame, eval(mc$weights))
if (!is.null(weights) && length(weights) > 0L) {
try_result <- tryCatch(result$test$frame[["(weights)"]] <- with(result$test$frame, eval(mc$weights)),
error = function(e) e)
if (inherits(try_result, "error"))
warning("weights specified but not found in test data - ignoring")
}

result$bartData@x.test <- testDataFrames$X.bart
result$bartData@testUsesRegularOffset <- FALSE
Expand Down
26 changes: 24 additions & 2 deletions R/test_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,31 @@ getTestDataFrames <- function(object, newdata, na.action = na.pass, weights = NU
formula[[2L]] <- NULL
environment(formula) <- environment()

delete_weights <- FALSE
weights_name <- NULL
mf_call <- quote(stats::model.frame(formula = formula, data = newdata, na.action = "na.pass"))
if (!is.null(object$weights)) {
weights_name <- deparse(object$call$weights)
if (weights_name %in% names(object$frame) && weights_name %not_in% names(newdata)) {
delete_weights <- TRUE

formula <- strip_extra_terms_from_language(formula, weights_name)
environment(formula) <- environment()
}
}

result <- list(frame = eval(mf_call))

# define the sub-model frames as applicable
if (type %in% c("all", "fixed") && !is.null(object$X)) {
orig.fixed.levs <- get.orig.levs(object, type = "fixed")

terms <- delete.response(terms(object, type = "fixed"))
if (delete_weights)
terms <- delete.weights(terms, weights_name)

mf.fixed <- suppressWarnings(
model.frame(delete.response(terms(object, type = "fixed")), newdata,
model.frame(terms, newdata,
na.action = na.action, xlev = orig.fixed.levs)
)

Expand All @@ -27,9 +42,13 @@ getTestDataFrames <- function(object, newdata, na.action = na.pass, weights = NU

if (type %in% c("all", "bart")) {
orig.bart.levs <- attr(terms(object), "levels.bart")

terms <- delete.response(terms(object, type = "bart"))
if (delete_weights)
terms <- delete.weights(terms, weights_name)

mf.bart <- suppressWarnings(
model.frame(delete.response(terms(object, type = "bart")), newdata,
model.frame(terms, newdata,
na.action = na.action, xlev = orig.bart.levs)
)

Expand All @@ -42,6 +61,9 @@ getTestDataFrames <- function(object, newdata, na.action = na.pass, weights = NU
form.random <- formula(object, type = "random")

tt <- delete.response(terms.random)
if (delete_weights)
tt <- delete.weights(tt, weights_name)

frame.random <- model.frame(object, type = "random")
orig.random.levs <- get.orig.levs(object, newdata = newdata, type = "random")
sparse <- !is.null(orig.random.levs) && max(lengths(orig.random.levs)) > 100
Expand Down
2 changes: 2 additions & 0 deletions inst/NEWS.Rd
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
\itemize{
\item \code{extract} argument now works correctly with varcount. Bug
report thanks to Joshua Bon.
\item \code{predict} now works with single chains. Bug report thanks to
github user Pentaonia (Louis).
}
}
}
Expand Down
15 changes: 15 additions & 0 deletions tests/testthat/test-01-continuous.R
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,21 @@ test_that("predict matches supplied data", {
expect_equal(samples.pred, samples.ev)
})

test_that("predict works with one chain", {
df.train <- df[seq_len(floor(0.8 * nrow(df))),]
df.test <- df[seq.int(floor(0.8 * nrow(df)) + 1L, nrow(df)),]

fit <- stan4bart(y ~ bart(. - g.1 - g.2 - X4 - z) + X4 + z + (1 + X4 | g.1) + (1 | g.2),
df.train,
test = df.test,
cores = 1, verbose = -1L, chains = 1, warmup = 7L, iter = 13L,
bart_args = list(n.trees = 11, keepTrees = TRUE))
expect_is(fit, "stan4bartFit")
predictions <- predict(fit, df.test)
expect_equal(dim(predictions), c(nrow(df.test), 13L - 7L))
expect_equal(names(dimnames(predictions)), c("observation", "iterations:chains"))
})

test_that("ppd has approximately right amount of noise", {
skip_if_not_installed("lme4")
df.train <- df
Expand Down

0 comments on commit 6fc1a76

Please sign in to comment.