Permalink
Browse files

refactor code related to the computation of LOO and WAIC

  • Loading branch information...
1 parent 0638b14 commit cad6398cca733e1a8198d13c6b96d3e985f1d3ca @paul-buerkner committed Jan 8, 2017
Showing with 181 additions and 79 deletions.
  1. +1 −0 NAMESPACE
  2. +62 −25 R/brmsfit-helpers.R
  3. +27 −21 R/brmsfit-methods.R
  4. +7 −6 R/generics.R
  5. +18 −12 R/misc-methods.R
  6. +5 −4 man/LOO.Rd
  7. +5 −4 man/WAIC.Rd
  8. +50 −0 man/compare_ic.Rd
  9. +6 −7 tests/testthat/tests.brmsfit-methods.R
View
@@ -79,6 +79,7 @@ export(brmdata)
export(brmsfamily)
export(brmsformula)
export(categorical)
+export(compare_ic)
export(cor.ar)
export(cor.arma)
export(cor.ma)
View
@@ -728,18 +728,71 @@ compute_ic <- function(x, ic = c("waic", "loo"), ll_args = list(), ...) {
}
IC <- do.call(eval(parse(text = paste0("loo::", ic))), args)
class(IC) <- c("ic", "loo")
- return(IC)
-}
-
-compare_ic <- function(x, ic = c("waic", "loo")) {
+ IC
+}
+
+#' Compare Information Criteria of Different Models
+#'
+#' Compare information criteria of different models fitted
+#' with \code{\link[brms:WAIC]{WAIC}} or \code{\link[brms:LOO]{LOO}}.
+#'
+#' @param ... At least two objects returned by
+#' \code{\link[brms:WAIC]{WAIC}} or \code{\link[brms:LOO]{LOO}}.
+#' @param x A list of at least two objects returned by
+#' \code{\link[brms:WAIC]{WAIC}} or \code{\link[brms:LOO]{LOO}}.
+#' This argument can be used as an alternative to specifying the
+#' models in \code{...}.
+#'
+#' @return An object of class \code{iclist}.
+#'
+#' @details For more details see \code{\link[loo:compare]{compare}}.
+#'
+#' @seealso
+#' \code{\link[brms:WAIC]{WAIC}},
+#' \code{\link[brms:LOO]{LOO}},
+#' \code{\link[loo:compare]{compare}}
+#'
+#' @examples
+#' \dontrun{
+#' # model with population-level effects only
+#' fit1 <- brm(rating ~ treat + period + carry,
+#' data = inhaler, family = "gaussian")
+#' w1 <- WAIC(fit1)
+#'
+#' # model with an additional varying intercept for subjects
+#' fit2 <- brm(rating ~ treat + period + carry + (1|subject),
+#' data = inhaler, family = "gaussian")
+#' w2 <- WAIC(fit2)
+#'
+#' # compare both models
+#' compare_ic(w1, w2)
+#' }
+#'
+#' @export
+compare_ic <- function(..., x = NULL) {
# compare information criteria of different models
# Args:
# x: A list containing loo objects
# ic: the information criterion to be computed
# Returns:
# A matrix with differences in the ICs
# as well as corresponding standard errors
- ic <- match.arg(ic)
+ if (!(is.null(x) || is.list(x))) {
+ stop2("Argument 'x' should be a list.")
+ }
+ x$ic_diffs__ <- NULL
+ x <- c(list(...), x)
+ if (!all(sapply(x, inherits, "ic"))) {
+ stop2("All inputs should have class 'ic'.")
+ }
+ if (length(x) < 2L) {
+ stop2("Expecting at least two objects.")
+ }
+ ics <- unname(sapply(x, function(y) names(y)[3]))
+ if (!all(sapply(ics, identical, ics[1]))) {
+ stop2("All inputs should be from the the same criterion.")
+ }
+ names(x) <- ulapply(x, "[[", "model_name")
n_models <- length(x)
ic_diffs <- matrix(0, nrow = n_models * (n_models - 1) / 2, ncol = 2)
rnames <- rep("", nrow(ic_diffs))
@@ -754,26 +807,10 @@ compare_ic <- function(x, ic = c("waic", "loo")) {
}
}
rownames(ic_diffs) <- rnames
- colnames(ic_diffs) <- c(toupper(ic), "SE")
- # compare all models at once to obtain weights
- all_compare <- do.call(loo::compare, x)
- if (n_models == 2L) {
- # weights are named differently when comparing only 2 models
- weights <- unname(all_compare[c("weight1", "weight2")])
- } else {
- # weights must be resorted as loo::compare sorts models after weights
- get_input_names <- function(...) {
- # mimic the way loo::compare defines model names
- as.character(match.call())[-1L]
- }
- if ("weights" %in% colnames(all_compare)) {
- weights <- unname(all_compare[do.call(get_input_names, x), "weights"])
- } else {
- # weights have been temporarily removed in loo 0.1.5
- weights <- rep(NA, n_models)
- }
- }
- nlist(ic_diffs, weights)
+ colnames(ic_diffs) <- c(toupper(ics[1]), "SE")
+ x$ic_diffs__ <- ic_diffs
+ class(x) <- c("iclist", "list")
+ x
}
set_pointwise <- function(x, newdata = NULL, subset = NULL, thres = 1e+08) {
View
@@ -2043,27 +2043,31 @@ WAIC.brmsfit <- function(x, ..., compare = TRUE, newdata = NULL,
re_formula = NULL, allow_new_levels = FALSE,
subset = NULL, nsamples = NULL, pointwise = NULL) {
models <- list(x, ...)
- names <- deparse(substitute(x))
- names <- c(names, sapply(substitute(list(...))[-1], deparse))
+ mnames <- deparse(substitute(x))
+ mnames <- c(mnames, sapply(substitute(list(...))[-1], deparse))
if (is.null(subset) && !is.null(nsamples)) {
subset <- sample(nsamples(x), nsamples)
}
if (is.null(pointwise)) {
pointwise <- set_pointwise(x, subset = subset, newdata = newdata)
}
ll_args = nlist(newdata, re_formula, allow_new_levels, subset, pointwise)
+ args <- nlist(ic = "waic", ll_args)
if (length(models) > 1L) {
- args <- nlist(X = models, FUN = compute_ic, ic = "waic", ll_args)
- out <- setNames(do.call(lapply, args), names)
- class(out) <- c("iclist", "list")
+ out <- named_list(mnames)
+ for (i in seq_along(models)) {
+ args[["x"]] <- models[[i]]
+ out[[i]] <- do.call(compute_ic, args)
+ out[[i]]$model_name <- mnames[i]
+ }
if (compare) {
match_response(models)
- comp <- compare_ic(out, ic = "waic")
- attr(out, "compare") <- comp$ic_diffs
- attr(out, "weights") <- comp$weights
+ out <- compare_ic(x = out)
}
- } else {
- out <- do.call(compute_ic, nlist(x, ic = "waic", ll_args))
+ class(out) <- c("iclist", "list")
+ } else {
+ out <- do.call(compute_ic, c(nlist(x), args))
+ out$model_name <- mnames
}
out
}
@@ -2086,29 +2090,31 @@ LOO.brmsfit <- function(x, ..., compare = TRUE, newdata = NULL,
subset = NULL, nsamples = NULL, pointwise = NULL,
cores = 1, wcp = 0.2, wtrunc = 3/4) {
models <- list(x, ...)
- names <- deparse(substitute(x))
- names <- c(names, sapply(substitute(list(...))[-1], deparse))
+ mnames <- deparse(substitute(x))
+ mnames <- c(mnames, sapply(substitute(list(...))[-1], deparse))
if (is.null(subset) && !is.null(nsamples)) {
subset <- sample(nsamples(x), nsamples)
}
if (is.null(pointwise)) {
pointwise <- set_pointwise(x, subset = subset, newdata = newdata)
}
ll_args = nlist(newdata, re_formula, allow_new_levels, subset, pointwise)
+ args <- nlist(ic = "loo", ll_args, wcp, wtrunc, cores)
if (length(models) > 1L) {
- args <- nlist(X = models, FUN = compute_ic, ic = "loo",
- ll_args, wcp, wtrunc, cores)
- out <- setNames(do.call(lapply, args), names)
- class(out) <- c("iclist", "list")
+ out <- named_list(mnames)
+ for (i in seq_along(models)) {
+ args[["x"]] <- models[[i]]
+ out[[i]] <- do.call(compute_ic, args)
+ out[[i]]$model_name <- mnames[i]
+ }
if (compare) {
match_response(models)
- comp <- compare_ic(out, ic = "loo")
- attr(out, "compare") <- comp$ic_diffs
- attr(out, "weights") <- comp$weights
+ out <- compare_ic(x = out)
}
+ class(out) <- c("iclist", "list")
} else {
- out <- do.call(compute_ic, nlist(x, ic = "loo", ll_args,
- wcp, wtrunc, cores))
+ out <- do.call(compute_ic, c(nlist(x), args))
+ out$model_name <- mnames
}
out
}
View
@@ -327,8 +327,9 @@ ngrps <- function(object, ...) {
#'
#' @param x A fitted model object typically of class \code{brmsfit}.
#' @param ... Optionally more fitted model objects.
-#' @param compare A flag indicating if the WAICs
-#' of the models should be compared to each other.
+#' @param compare A flag indicating if the information criteria
+#' of the models should be compared to each other
+#' via \code{\link[brms:compare_ic]{compare_ic}}.
#' @param pointwise A flag indicating whether to compute the full
#' log-likelihood matrix at once or separately for each observation.
#' The latter approach is usually considerably slower but
@@ -348,12 +349,12 @@ ngrps <- function(object, ...) {
#'
#' @examples
#' \dontrun{
-#' # model with fixed effects only
+#' # model with population-level effects only
#' fit1 <- brm(rating ~ treat + period + carry,
#' data = inhaler, family = "gaussian")
#' WAIC(fit1)
#'
-#' # model with an additional random intercept for subjects
+#' # model with an additional varying intercept for subjects
#' fit2 <- brm(rating ~ treat + period + carry + (1|subject),
#' data = inhaler, family = "gaussian")
#' # compare both models
@@ -402,12 +403,12 @@ WAIC <- function(x, ...) {
#'
#' @examples
#' \dontrun{
-#' # model with fixed effects only
+#' # model with population-level effects only
#' fit1 <- brm(rating ~ treat + period + carry,
#' data = inhaler, family = "gaussian")
#' LOO(fit1)
#'
-#' # model with an additional random intercept for subjects
+#' # model with an additional varying intercept for subjects
#' fit2 <- brm(rating ~ treat + period + carry + (1|subject),
#' data = inhaler, family = "gaussian")
#' # compare both models
View
@@ -176,21 +176,27 @@ print.ic <- function(x, digits = 2, ...) {
#' @export
print.iclist <- function(x, digits = 2, ...) {
- # print the output of LOO(x1, x2, ...) and WAIC(x1, x2, ...)
- ic <- names(x[[1]])[3]
- mat <- matrix(0, nrow = length(x), ncol = 2,
- dimnames = list(names(x), c(toupper(ic), "SE")))
- for (i in 1:length(x)) {
- mat[i, ] <- c(x[[i]][[ic]], x[[i]][[paste0("se_",ic)]])
+ # print the output of LOO and WAIC with multiple models
+ m <- x
+ m$ic_diffs__ <- NULL
+ if (length(m)) {
+ ic <- names(m[[1]])[3]
+ mat <- matrix(0, nrow = length(m), ncol = 2)
+ dimnames(mat) <- list(names(m), c(toupper(ic), "SE"))
+ for (i in seq_along(m)) {
+ mat[i, ] <- c(m[[i]][[ic]], m[[i]][[paste0("se_", ic)]])
+ }
+ } else {
+ mat <- NULL
}
+ ic_diffs <- x$ic_diffs__
if (is.matrix(attr(x, "compare"))) {
+ # deprecated as of brms 1.4.0
+ ic_diffs <- attr(x, "compare")
+ }
+ if (is.matrix(ic_diffs)) {
# models were compared using the compare_ic function
- mat <- rbind(mat, attr(x, "compare"))
- weights <- c(attr(x, "weights"), rep(NA, nrow(attr(x, "compare"))))
- if (length(na.omit(weights))) {
- # no need to show the weights column if all weights are NA
- mat <- cbind(mat, Weights = weights)
- }
+ mat <- rbind(mat, ic_diffs)
}
print(round(mat, digits = digits), na.print = "")
invisible(x)
View
@@ -19,8 +19,9 @@ LOO(x, ...)
\item{...}{Optionally more fitted model objects.}
-\item{compare}{A flag indicating if the WAICs
-of the models should be compared to each other.}
+\item{compare}{A flag indicating if the information criteria
+of the models should be compared to each other
+via \code{\link[brms:compare_ic]{compare_ic}}.}
\item{newdata}{An optional data.frame for which to evaluate predictions.
If \code{NULL} (default), the orginal data of the model is used.}
@@ -77,12 +78,12 @@ When comparing models fitted to the same data,
}}
\examples{
\dontrun{
-# model with fixed effects only
+# model with population-level effects only
fit1 <- brm(rating ~ treat + period + carry,
data = inhaler, family = "gaussian")
LOO(fit1)
-# model with an additional random intercept for subjects
+# model with an additional varying intercept for subjects
fit2 <- brm(rating ~ treat + period + carry + (1|subject),
data = inhaler, family = "gaussian")
# compare both models
View
@@ -18,8 +18,9 @@ WAIC(x, ...)
\item{...}{Optionally more fitted model objects.}
-\item{compare}{A flag indicating if the WAICs
-of the models should be compared to each other.}
+\item{compare}{A flag indicating if the information criteria
+of the models should be compared to each other
+via \code{\link[brms:compare_ic]{compare_ic}}.}
\item{newdata}{An optional data.frame for which to evaluate predictions.
If \code{NULL} (default), the orginal data of the model is used.}
@@ -69,12 +70,12 @@ When comparing models fitted to the same data,
}}
\examples{
\dontrun{
-# model with fixed effects only
+# model with population-level effects only
fit1 <- brm(rating ~ treat + period + carry,
data = inhaler, family = "gaussian")
WAIC(fit1)
-# model with an additional random intercept for subjects
+# model with an additional varying intercept for subjects
fit2 <- brm(rating ~ treat + period + carry + (1|subject),
data = inhaler, family = "gaussian")
# compare both models
View
@@ -0,0 +1,50 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/brmsfit-helpers.R
+\name{compare_ic}
+\alias{compare_ic}
+\title{Compare Information Criteria of Different Models}
+\usage{
+compare_ic(..., x = NULL)
+}
+\arguments{
+\item{...}{At least two objects returned by
+\code{\link[brms:WAIC]{WAIC}} or \code{\link[brms:LOO]{LOO}}.}
+
+\item{x}{A list of at least two objects returned by
+\code{\link[brms:WAIC]{WAIC}} or \code{\link[brms:LOO]{LOO}}.
+This argument can be used as an alternative to specifying the
+models in \code{...}.}
+}
+\value{
+An object of class \code{iclist}.
+}
+\description{
+Compare information criteria of different models fitted
+with \code{\link[brms:WAIC]{WAIC}} or \code{\link[brms:LOO]{LOO}}.
+}
+\details{
+For more details see \code{\link[loo:compare]{compare}}.
+}
+\examples{
+\dontrun{
+# model with population-level effects only
+fit1 <- brm(rating ~ treat + period + carry,
+ data = inhaler, family = "gaussian")
+w1 <- WAIC(fit1)
+
+# model with an additional varying intercept for subjects
+fit2 <- brm(rating ~ treat + period + carry + (1|subject),
+ data = inhaler, family = "gaussian")
+w2 <- WAIC(fit2)
+
+# compare both models
+compare_ic(w1, w2)
+}
+
+}
+\seealso{
+\code{\link[brms:WAIC]{WAIC}},
+ \code{\link[brms:LOO]{LOO}},
+ \code{\link[loo:compare]{compare}}
+}
+
Oops, something went wrong.

0 comments on commit cad6398

Please sign in to comment.