Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

compare --> loo_compare #93

Merged
merged 19 commits into from
Mar 6, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ S3method(dim,waic)
S3method(loo,"function")
S3method(loo,array)
S3method(loo,matrix)
S3method(loo_compare,default)
S3method(loo_model_weights,default)
S3method(plot,loo)
S3method(plot,psis)
Expand All @@ -24,6 +25,7 @@ S3method(print,psis)
S3method(print,psis_loo)
S3method(print,stacking_weights)
S3method(print,waic)
S3method(print_dims,kfold)
S3method(print_dims,psis)
S3method(print_dims,psis_loo)
S3method(print_dims,waic)
Expand All @@ -43,14 +45,22 @@ export(compare)
export(example_loglik_array)
export(example_loglik_matrix)
export(extract_log_lik)
export(find_model_names)
export(gpdfit)
export(is.kfold)
export(is.loo)
export(is.psis)
export(is.psis_loo)
export(is.waic)
export(kfold)
export(kfold_split_grouped)
export(kfold_split_random)
export(kfold_split_stratified)
export(loo)
export(loo.array)
export(loo.function)
export(loo.matrix)
export(loo_compare)
export(loo_i)
export(loo_model_weights)
export(loo_model_weights.default)
Expand Down
8 changes: 4 additions & 4 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

* New vignette on LOO for non-factorizable joint Gaussian models. (#75)

* When comparing more than two models there is now also
an `se_diff` column in the results. (#78)
* New `se_diff` column in model comparison results. (#78)

* Fix for `psis()` when `log_ratios` are very small. (#74)
* Improved behavior of `psis()` when `log_ratios` are very small. (#74)

* Allow `r_eff=NA` to suppress warning when specifying `r_eff` is not applicable (i.e., draws not from MCMC). (#72)
* Allow `r_eff=NA` to suppress warning when specifying `r_eff` is not applicable
(i.e., draws not from MCMC). (#72)

* Update effective sample size calculations to match RStan's version. (#85)

Expand Down
33 changes: 3 additions & 30 deletions R/compare.R
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
#' }
#'
compare <- function(..., x = list()) {
# .Deprecated("loo_compare")
dots <- list(...)
if (length(dots)) {
if (length(x)) {
Expand Down Expand Up @@ -87,6 +88,7 @@ compare <- function(..., x = list()) {
loo1 <- dots[[1]]
loo2 <- dots[[2]]
comp <- compare_two_models(loo1, loo2)
class(comp) <- c(class(comp), "old_compare.loo")
return(comp)
} else {
Ns <- sapply(dots, function(x) nrow(x$pointwise))
Expand Down Expand Up @@ -115,29 +117,11 @@ compare <- function(..., x = list()) {
se_diff <- apply(diffs, 2, se_elpd_diff)
comp <- cbind(elpd_diff = elpd_diff, se_diff = se_diff, comp)
rownames(comp) <- rnms
class(comp) <- c("compare.loo", class(comp))
class(comp) <- c("compare.loo", class(comp), "old_compare.loo")
comp
}
}

#' @rdname compare
#' @export
#' @param digits For the print method only, the number of digits to use when
#' printing.
#' @param simplify For the print method only, should only the essential columns
#' of the summary matrix be printed when comparing more than two models? The
#' entire matrix is always returned, but by default only the most important
#' columns are printed.
print.compare.loo <- function(x, ..., digits = 1, simplify = TRUE) {
xcopy <- x
if (NCOL(xcopy) >= 2 && simplify) {
patts <- "^elpd_|^se_diff|^p_|^waic$|^looic$"
xcopy <- xcopy[, grepl(patts, colnames(xcopy))]
}
print(.fr(xcopy, digits), quote = FALSE)
invisible(x)
}



# internal ----------------------------------------------------------------
Expand All @@ -154,14 +138,3 @@ compare_two_models <- function(loo_a, loo_b, return = c("elpd_diff", "se"), chec
comp <- c(elpd_diff = sum(diffs), se = se_elpd_diff(diffs))
structure(comp, class = "compare.loo")
}

elpd_diffs <- function(loo_a, loo_b) {
pt_a <- loo_a$pointwise
pt_b <- loo_b$pointwise
elpd <- grep("^elpd", colnames(pt_a))
pt_b[, elpd] - pt_a[, elpd]
}
se_elpd_diff <- function(diffs) {
N <- length(diffs)
sqrt(N) * sd(diffs)
}
15 changes: 0 additions & 15 deletions R/helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,21 +37,6 @@ table_of_estimates <- function(x) {
}


# checking classes --------------------------------------------------------
is.psis <- function(x) {
inherits(x, "psis") && is.list(x)
}
is.loo <- function(x) {
inherits(x, "loo")
}
is.psis_loo <- function(x) {
inherits(x, "psis_loo") && is.loo(x)
}
is.waic <- function(x) {
inherits(x, "waic") && is.loo(x)
}


# validating and reshaping arrays/matrices -------------------------------

#' Check for NAs and non-finite values in log-lik (or log-ratios)
Expand Down
34 changes: 34 additions & 0 deletions R/kfold-generic.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#' Generic function for K-fold cross-validation for developers
#'
#' For developers of modeling packages, \pkg{loo} includes a generic function
#' \code{kfold} so that methods may be defined for K-fold CV without name
#' conflicts between packages. See, e.g., the \code{kfold.stanreg} method in
#' \pkg{rstanarm} and the \code{kfold.brmsfit} method in \pkg{brms}.
#'
#' @name kfold-generic
#' @param x A fitted model object.
#' @param ... Arguments to pass to specific methods.
#'
#' @return For developers defining a \code{kfold} method for a class
#' \code{"foo"}, the \code{kfold.foo} function should return a list with class
#' \code{c("kfold", "loo")} with at least the elements
#' \itemize{
#' \item \code{"estimates"}: a 1x2 matrix with column names "Estimate" and "SE"
#' containing the ELPD estimate and its standard error.
#' \item \code{"pointwise"}: an Nx1 matrix with column name "elpd_kfold" containing
#' the pointwise contributions for each data point.
#' }
#'
NULL

#' @rdname kfold-generic
#' @export
kfold <- function(x, ...) {
UseMethod("kfold")
}

#' @rdname kfold-generic
#' @export
is.kfold <- function(x) {
inherits(x, "kfold") && is.loo(x)
}
9 changes: 4 additions & 5 deletions R/kfold-helpers.R
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
#' Helper functions for K-fold cross-validation
#'
#' These functions can be used to generate indexes for use with K-fold
#' cross-validation.
#' @description These functions can be used to generate indexes for use with
#' K-fold cross-validation. See the \strong{Details} section for explanations.
#'
#' @name kfold-helpers
#' @param K The number of folds to use.
#' @param N The number of observations in the data.
#' @param x A discrete variable of length \code{N} with at least \code{K} levels
#' (unique values). Will be coerced to \code{\link{factor}}.
#' .
#' @return An integer vector of length \code{N} where each element is an index
#' in \code{1:K}.
#'
#' @return An integer vector of length \code{N} where each element is an index in \code{1:K}.
#'
#' @details
#' \code{kfold_split_random} splits the data into \code{K} groups
Expand Down
12 changes: 12 additions & 0 deletions R/loo.R
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,18 @@ dim.psis_loo <- function(x) {
}


#' @rdname loo
#' @export
is.loo <- function(x) {
inherits(x, "loo")
}

#' @rdname loo
#' @export
is.psis_loo <- function(x) {
inherits(x, "psis_loo") && is.loo(x)
}


# internal ----------------------------------------------------------------

Expand Down
178 changes: 178 additions & 0 deletions R/loo_compare.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
#' Model comparison
#'
#' Compare fitted models on LOO or WAIC.
#'
#' @export
#' @param x An object of class \code{"loo"} or a list of such objects.
#' @param ... Additional objects of class \code{"loo"}.
#'
#' @return A matrix with class \code{"compare.loo"} that has its own
#' print method. See the \strong{Details} section for more .
#'
#' @details
#' When comparing two fitted models, we can estimate the difference in their
#' expected predictive accuracy by the difference in \code{elpd_loo} or
#' \code{elpd_waic} (or multiplied by \eqn{-2}, if desired, to be on the
#' deviance scale).
#'
#' When using \code{loo_compare()}, the returned matrix will have one row per
#' model and several columns of estimates. The values in the \code{elpd_diff}
#' and \code{se_diff} columns of the returned matrix are computed by making
#' pairwise comparisons between each model and the model with the largest ELPD
#' (the model in the first row). For this reason the \code{elpd_diff} column
#' will always have the value \code{0} in the first row (i.e., the difference
#' between the preferred model and itself) and negative values in subsequent
#' rows for the remaining models.
#'
#' To compute the standard error of the difference in ELPD --- which should
#' not be expected to equal the difference of the standard errors --- we use a
#' paired estimate to take advantage of the fact that the same set of \eqn{N}
#' data points was used to fit both models. These calculations should be most
#' useful when \eqn{N} is large, because then non-normality of the
#' distribution is not such an issue when estimating the uncertainty in these
#' sums. These standard errors, for all their flaws, should give a better
#' sense of uncertainty than what is obtained using the current standard
#' approach of comparing differences of deviances to a Chi-squared
#' distribution, a practice derived for Gaussian linear models or
#' asymptotically, and which only applies to nested models in any case.
#'
#' @template loo-and-psis-references
#'
#' @examples
#' \dontrun{
#' loo1 <- loo(log_lik1)
#' loo2 <- loo(log_lik2)
#' print(loo_compare(loo1, loo2), digits = 3)
#' print(loo_compare(x = list(loo1, loo2)))
#'
#' waic1 <- waic(log_lik1)
#' waic2 <- waic(log_lik2)
#' loo_compare(waic1, waic2)
#' }
#'
loo_compare <- function(x, ...) {
UseMethod("loo_compare")
}

#' @rdname loo_compare
#' @export
loo_compare.default <- function(x, ...) {
if (is.loo(x)) {
dots <- list(...)
loos <- c(list(x), dots)
} else {
if (!is.list(x) || !length(x)) {
stop("'x' must be a list if not a 'loo' object.")
}
if (length(list(...))) {
stop("If 'x' is a list then '...' should not be specified.")
}
loos <- x
}

if (!all(sapply(loos, is.loo))) {
stop("All inputs should have class 'loo'.")
}
if (length(loos) <= 1L) {
stop("'loo_compare' requires at least two models.")
}

Ns <- sapply(loos, function(x) nrow(x$pointwise))
if (!all(Ns == Ns[1L])) {
stop("Not all models have the same number of data points.")
}

tmp <- sapply(loos, function(x) {
est <- x$estimates
setNames(c(est), nm = c(rownames(est), paste0("se_", rownames(est))) )
})

colnames(tmp) <- find_model_names(loos)
rnms <- rownames(tmp)
comp <- tmp
ord <- order(tmp[grep("^elpd", rnms), ], decreasing = TRUE)
comp <- t(comp)[ord, ]
patts <- c("elpd", "p_", "^waic$|^looic$", "^se_waic$|^se_looic$")
col_ord <- unlist(sapply(patts, function(p) grep(p, colnames(comp))),
use.names = FALSE)
comp <- comp[, col_ord]

# compute elpd_diff and se_elpd_diff relative to best model
rnms <- rownames(comp)
diffs <- mapply(FUN = elpd_diffs, loos[ord[1]], loos[ord])
elpd_diff <- apply(diffs, 2, sum)
se_diff <- apply(diffs, 2, se_elpd_diff)
comp <- cbind(elpd_diff = elpd_diff, se_diff = se_diff, comp)
rownames(comp) <- rnms

class(comp) <- c("compare.loo", class(comp))
return(comp)
}

#' @rdname loo_compare
#' @export
#' @param digits For the print method only, the number of digits to use when
#' printing.
#' @param simplify For the print method only, should only the essential columns
#' of the summary matrix be printed? The entire matrix is always returned, but
#' by default only the most important columns are printed.
print.compare.loo <- function(x, ..., digits = 1, simplify = TRUE) {
xcopy <- x
if (inherits(xcopy, "old_compare.loo")) {
if (NCOL(xcopy) >= 2 && simplify) {
patts <- "^elpd_|^se_diff|^p_|^waic$|^looic$"
xcopy <- xcopy[, grepl(patts, colnames(xcopy))]
}
} else if (NCOL(xcopy) >= 2 && simplify) {
xcopy <- xcopy[, c("elpd_diff", "se_diff")]
}
print(.fr(xcopy, digits), quote = FALSE)
invisible(x)
}



# internal ----------------------------------------------------------------
elpd_diffs <- function(loo_a, loo_b) {
pt_a <- loo_a$pointwise
pt_b <- loo_b$pointwise
elpd <- grep("^elpd", colnames(pt_a))
pt_b[, elpd] - pt_a[, elpd]
}
se_elpd_diff <- function(diffs) {
N <- length(diffs)
sqrt(N) * sd(diffs)
}



#' Find the model names associated with loo objects
#'
#' @export
#' @keywords internal
#' @param x List of loo objects.
#' @return Character vector of model names the same length as x.
#'
find_model_names <- function(x) {
stopifnot(is.list(x))
out_names <- character(length(x))

names1 <- names(x)
names2 <- lapply(x, "attr", "model_name", exact = TRUE)
names3 <- lapply(x, "[[", "model_name")
names4 <- paste0("model", seq_along(x))

for (j in seq_along(x)) {
if (isTRUE(nzchar(names1[j]))) {
out_names[j] <- names1[j]
} else if (length(names2[[j]])) {
out_names[j] <- names2[[j]]
} else if (length(names3[[j]])) {
out_names[j] <- names3[[j]]
} else {
out_names[j] <- names4[j]
}
}

return(out_names)
}
Loading