Permalink
Browse files

correctly handle autocorrelated multivariate models in predict etc.

  • Loading branch information...
1 parent ff93efa commit 134cb4576c53884c5e9914fda89222a6a7561946 @paul-buerkner committed Jan 8, 2017
Showing with 41 additions and 25 deletions.
  1. +29 −1 R/brmsfit-helpers.R
  2. +10 −21 R/brmsfit-methods.R
  3. +1 −2 R/make_standata.R
  4. +1 −1 tests/testthat/tests.brmsfit-methods.R
View
@@ -669,6 +669,34 @@ prepare_family <- function(x) {
family
}
+reorder_obs <- function(eta, old_order = NULL, sort = FALSE) {
+ # reorder observations to be in the initial user-defined order
+ # currently only relevant for autocorrelation models
+ # Args:
+ # eta: Nsamples x Nobs matrix
+ # old_order: optional vector to retrieve the initial data order
+ # sort: keep the new order as defined by the time-series?
+ # Returns:
+ # eta with possibly reordered columns
+ if (!is.null(old_order) && !sort) {
+ N <- length(old_order)
+ if (ncol(eta) %% N != 0) {
+ # for compatibility with MV models fitted before brms 1.0.0
+ stopifnot(N %% ncol(eta) == 0)
+ old_order <- old_order[seq_len(ncol(eta))]
+ }
+ if (N < ncol(eta)) {
+ # should occur for multivariate models only
+ nresp <- ncol(eta) / N
+ old_order <- rep(old_order, nresp)
+ old_order <- old_order + rep(0:(nresp - 1) * N, each = N)
+ }
+ eta <- eta[, old_order, drop = FALSE]
+ colnames(eta) <- NULL
+ }
+ eta
+}
+
fixef_pars <- function() {
# regex to extract population-level coefficients
"^b(|cs|mo|me|m)_"
@@ -869,7 +897,7 @@ match_response <- function(models) {
out <- TRUE
} else {
out <- FALSE
- warning2("Model comparisons are most likely invalid as the response ",
+ warning2("Model comparisons are likely invalid as the response ",
"parts of at least two models do not match.")
}
}
View
@@ -1586,21 +1586,17 @@ predict.brmsfit <- function(object, newdata = NULL, re_formula = NULL,
warning2(round(pct_invalid * 100), "% of all predicted values ",
"were invalid. Increasing argument 'ntrys' may help.")
}
+
# reorder predicted responses in case of multivariate models
# as they are sorted after units first not after traits
if (grepl("_mv$", draws$f$family)) {
nresp <- draws$data$nresp
- reorder <- ulapply(seq_len(nresp), seq, to = N*nresp, by = nresp)
+ reorder <- ulapply(seq_len(nresp), seq, to = N * nresp, by = nresp)
out <- out[, reorder, drop = FALSE]
- colnames(out) <- seq_len(ncol(out))
}
- # reorder predicted responses to be in the initial user defined order
- # currently only relevant for autocorrelation models
old_order <- attr(draws$data, "old_order")
- if (!is.null(old_order) && !sort) {
- out <- out[, old_order, drop = FALSE]
- colnames(out) <- NULL
- }
+ out <- reorder_obs(out, old_order, sort = sort)
+ colnames(out) <- NULL
# transform predicted response samples before summarizing them
is_catordinal <- is_ordinal(draws$f) || is_categorical(draws$f)
if (!is.null(transform) && !is_catordinal) {
@@ -1699,7 +1695,7 @@ fitted.brmsfit <- function(object, newdata = NULL, re_formula = NULL,
mu <- get_eta(i = NULL, draws = draws)
if (grepl("_mv$", draws$f$family) && !is.null(draws[["mv"]])) {
# collapse over responses in linear MV models
- dim(mu) <- c(dim(mu)[1], prod(dim(mu)[2:3]))
+ dim(mu) <- c(dim(mu)[1], prod(dim(mu)[2:3]))
}
for (ap in intersect(auxpars(), names(draws))) {
if (is(draws[[ap]], "list")) {
@@ -1710,13 +1706,9 @@ fitted.brmsfit <- function(object, newdata = NULL, re_formula = NULL,
# see fitted.R
mu <- fitted_response(draws = draws, mu = mu)
}
- # reorder fitted values to be in the initial user defined order
- # currently only relevant for autocorrelation models
old_order <- attr(draws$data, "old_order")
- if (!is.null(old_order) && !sort) {
- mu <- mu[, old_order, drop = FALSE]
- colnames(mu) <- NULL
- }
+ out <- reorder_obs(mu, old_order, sort = sort)
+ colnames(out) <- NULL
if (summary) {
mu <- get_summary(mu, probs = probs, robust = robust)
rownames(mu) <- seq_len(nrow(mu))
@@ -2178,13 +2170,10 @@ log_lik.brmsfit <- function(object, newdata = NULL, re_formula = NULL,
} else {
draws$eta <- get_eta(i = NULL, draws = draws)
loglik <- do.call(cbind, lapply(seq_len(N), loglik_fun, draws = draws))
- # reorder loglik values to be in the initial user defined order
- # currently only relevant for autocorrelation models
- # that are not using covariance formulation
old_order <- attr(draws$data, "old_order")
- if (!is.null(old_order) && !isTRUE(object$autocor$cov)) {
- loglik <- loglik[, old_order[seq_len(N)]]
- }
+ # do not loglik reorder for ARMA covariance models
+ sort <- use_cov(object$autocor)
+ loglik <- reorder_obs(loglik, old_order, sort = sort)
colnames(loglik) <- NULL
}
loglik
View
@@ -347,8 +347,7 @@ make_standata <- function(formula, data, family = "gaussian",
standata$Kma <- Kma
standata$Karma <- max(Kar, Kma)
if (use_cov(autocor)) {
- # Modeling ARMA effects using a special covariance matrix
- # requires additional data
+ # data for covariance matrices of ARMA effects
standata$N_tg <- length(unique(standata$tg))
standata$begin_tg <- as.array(with(standata,
ulapply(unique(tgroup), match, tgroup)))
@@ -394,7 +394,7 @@ test_that("all S3 methods have reasonable ouputs", {
expect_true(is.numeric(waic2[["waic"]]))
waic_pointwise <- SW(WAIC(fit2, pointwise = TRUE))
expect_equal(waic2, waic_pointwise)
- expect_warning(WAIC(fit1, fit2), "Model comparisons are most likely invalid")
+ expect_warning(WAIC(fit1, fit2), "Model comparisons are likely invalid")
waic4 <- SW(WAIC(fit4))
expect_true(is.numeric(waic4[["waic"]]))

0 comments on commit 134cb45

Please sign in to comment.