Skip to content

Commit

Permalink
feature issue #510
Browse files Browse the repository at this point in the history
  • Loading branch information
paul-buerkner committed Aug 28, 2018
1 parent 8800f32 commit e12c861
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 71 deletions.
5 changes: 4 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,15 @@ argument `nlpar` in method `fitted`.
* Disable automatic cell-mean coding in model formulas without
an intercept via argument `cmc` of `brmsformula` and related
functions thanks to Marie Beisemann.
* Improve method `kfold` to offer more options for specifying
omitted subsets. (#510)

### Other changes

* Ignore argument `resp` when post-processing
univariate models thanks to Ruben Arslan. (#488)
* Deprecate argument `ordinal` in `marginal_effects`. (#491)
* Deprecate argument `ordinal` of `marginal_effects`. (#491)
* Deprecate argument `exact_loo` of `kfold`. (#510)
* Deprecate usage of `binomial` families without specifying `trials`.

### Bug fixes
Expand Down
10 changes: 7 additions & 3 deletions R/brmsfit-methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -2656,12 +2656,16 @@ loo.brmsfit <- function(x, ..., compare = TRUE, resp = NULL,
#' @export
#' @describeIn kfold \code{kfold} method for \code{brmsfit} objects
kfold.brmsfit <- function(x, ..., compare = TRUE, K = 10, Ksub = NULL,
exact_loo = FALSE, group = NULL, resp = NULL,
model_names = NULL, save_fits = FALSE) {
folds = NULL, group = NULL, exact_loo = NULL,
resp = NULL, model_names = NULL, save_fits = FALSE) {
args <- split_dots(x, ..., model_names = model_names)
use_stored_ic <- ulapply(args$models, function(x) is_equal(x$kfold$K, K))
if (!is.null(exact_loo) && as_one_logical(exact_loo)) {
warning2("'exact_loo' is deprecated. Please use folds = 'loo' instead.")
folds <- "loo"
}
c(args) <- nlist(
ic = "kfold", compare, K, Ksub, exact_loo,
ic = "kfold", compare, K, Ksub, folds,
group, resp, save_fits, use_stored_ic
)
do.call(compute_ics, args)
Expand Down
135 changes: 92 additions & 43 deletions R/loo-helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ WAIC <- function(x, ...) {
#'
#' @inheritParams loo.brmsfit
#' @param K The number of subsets of equal (if possible) size
#' into which the data will be randomly partitioned for performing
#' into which the data will be partitioned for performing
#' \eqn{K}-fold cross-validation. The model is refit \code{K} times, each time
#' leaving out one of the \code{K} subsets. If \code{K} is equal to the total
#' number of observations in the data then \eqn{K}-fold cross-validation is
Expand All @@ -28,19 +28,16 @@ WAIC <- function(x, ...) {
#' (created via \code{as.array}) potentially of length one, the corresponding
#' subsets will be used. This argument is primarily useful, if evaluation of
#' all subsets is infeasible for some reason.
#' @param exact_loo Logical; If \code{TRUE}, exact leave-one-out cross-validation
#' will be performed and \code{K} will be ignored. This argument alters
#' the way argument \code{group} is handled as described below.
#' Defaults to \code{FALSE}.
#' @param folds Determines how the subsets are being constructed.
#' Possible values are \code{NULL} (the default), \code{"stratified"},
#' \code{"balanced"}, or \code{"loo"}. May also be a vector of length
#' equal to the number of observations in the data. Alters the way
#' \code{group} is handled. More information is provided in the 'Details'
#' section.
#' @param group Optional name of a grouping variable or factor in the model.
#' How this variable is handled depends on argument \code{exact_loo}.
#' If \code{exact_loo} is \code{FALSE}, the data is split
#' up into subsets, each time omitting all observations of one of the
#' factor levels, while ignoring argument \code{K}.
#' If \code{exact_loo} is \code{TRUE}, all observations corresponding
#' to the factor level of the currently predicted single value are omitted.
#' Thus, in this case, the predicted values are only a subset of the
#' omitted ones.
#' What exactly is done with this variable depends on argument \code{folds}.
#' More information is provided in the 'Details' section.
#' @param exact_loo Deprecated! Please use \code{folds = "loo"} instead.
#' @param save_fits If \code{TRUE}, a component \code{fits} is added to
#' the returned object to store the cross-validated \code{brmsfit}
#' objects and the indices of the omitted observations for each fold.
Expand All @@ -50,14 +47,40 @@ WAIC <- function(x, ...) {
#' objects returned by the \code{loo} and \code{waic} methods.
#'
#' @details The \code{kfold} function performs exact \eqn{K}-fold
#' cross-validation. First the data are randomly partitioned into \eqn{K}
#' subsets of equal (or as close to equal as possible) size. Then the model is
#' refit \eqn{K} times, each time leaving out one of the \code{K} subsets. If
#' \eqn{K} is equal to the total number of observations in the data then
#' \eqn{K}-fold cross-validation is equivalent to exact leave-one-out
#' cross-validation (to which \code{loo} is an efficient approximation). The
#' \code{compare_ic} function is also compatible with the objects returned
#' by \code{kfold}.
#' cross-validation. First the data are partitioned into \eqn{K} folds
#' (i.e. subsets) of equal (or as close to equal as possible) size by default.
#' Then the model is refit \eqn{K} times, each time leaving out one of the
#' \code{K} subsets. If \eqn{K} is equal to the total number of observations
#' in the data then \eqn{K}-fold cross-validation is equivalent to exact
#' leave-one-out cross-validation (to which \code{loo} is an efficient
#' approximation). The \code{compare_ic} function is also compatible with
#' the objects returned by \code{kfold}.
#'
#' The subsets can be constructed in multiple different ways:
#' \itemize{
#' \item If both \code{folds} and \code{group} are \code{NULL}, the subsets
#' are randomly chosen so that they have equal (or as close to equal as
#' possible) size.
#' \item If \code{folds} is \code{NULL} but \code{group} is specified, the
#' data is split up into subsets, each time omitting all observations of one
#' of the factor levels, while ignoring argument \code{K}.
#' \item If \code{folds = "stratified"} the subsets are stratified after
#' \code{group} using \code{\link[loo:kfold-helpers]{loo::kfold_split_stratified}}.
#' \item If \code{folds = "balanced"} the subsets are balanced by
#' \code{group} using \code{\link[loo:kfold-helpers]{loo::kfold_split_balanced}}.
#' \item If \code{folds = "loo"} exact leave-one-out cross-validation
#' will be performed and \code{K} will be ignored. Further, if \code{group}
#' is specified, all observations corresponding to the factor level of the
#' currently predicted single value are omitted. Thus, in this case, the
#' predicted values are only a subset of the omitted ones.
#' \item If \code{folds} is a numeric vector, it must contain one element per
#' observation in the data. Each element of the vector is an integer in
#' \code{1:K} indicating to which of the \code{K} folds the corresponding
#' observation belongs. There are some convenience functions available in
#' the \pkg{loo} package that create integer vectors to use for this purpose
#' (see the Examples section below and also the
#' \link[loo:kfold-helpers]{kfold-helpers} page).
#' }
#'
#' @examples
#' \dontrun{
Expand All @@ -67,7 +90,7 @@ WAIC <- function(x, ...) {
#' # throws warning about some pareto k estimates being too high
#' (loo1 <- loo(fit1))
#' # perform 10-fold cross validation
#' (kfold1 <- kfold(fit1, chains = 2, cores = 2)
#' (kfold1 <- kfold(fit1, chains = 1)
#' }
#'
#' @seealso \code{\link{loo}}, \code{\link{reloo}}
Expand Down Expand Up @@ -587,7 +610,7 @@ reloo.loo <- function(x, fit, k_threshold = 0.7, check = TRUE,
x
}

kfold_internal <- function(x, K = 10, Ksub = NULL, exact_loo = FALSE,
kfold_internal <- function(x, K = 10, Ksub = NULL, folds = NULL,
group = NULL, newdata = NULL, resp = NULL,
save_fits = FALSE, ...) {
# most of the code is taken from rstanarm::kfold
Expand All @@ -601,31 +624,54 @@ kfold_internal <- function(x, K = 10, Ksub = NULL, exact_loo = FALSE,
mf <- as.data.frame(newdata)
}
N <- nrow(mf)
if (exact_loo) {
K <- N
message("Setting 'K' to the number of observations (", K, ")")
}
if (is.null(group)) {
if (K < 1 || K > N) {
stop2("'K' must be greater than one and smaller or ",
"equal to the number of observations in the model.")
}
bin <- loo::kfold_split_random(K, N)
} else {
# validate argument 'group'
# validate argument 'group'
if (!is.null(group)) {
valid_groups <- get_cat_vars(x)
if (length(group) != 1L || !group %in% valid_groups) {
stop2("Group '", group, "' is not a valid grouping factor. ",
"Valid groups are: \n", collapse_comma(valid_groups))
}
gvar <- factor(get(group, mf))
bin <- as.numeric(gvar)
if (!exact_loo) {
# K was already set to N if exact_loo is TRUE
}
# validate argument 'folds'
if (is.null(folds)) {
if (is.null(group)) {
fold_type <- "random"
folds <- loo::kfold_split_random(K, N)
} else {
fold_type <- "group"
folds <- as.numeric(gvar)
K <- length(levels(gvar))
message("Setting 'K' to the number of levels of '", group, "' (", K, ")")
message("Setting 'K' to the number of levels of '", group, "' (", K, ")")
}
} else if (is.numeric(folds) || length(folds) > 1L) {
fold_type <- "custom"
folds <- as.numeric(factor(folds))
if (length(folds) != N) {
stop2("If 'folds' is a vector, it must be of length N.")
}
K <- max(folds)
message("Setting 'K' to the number of folds (", K, ")")
} else {
opts <- c("stratified", "balanced", "loo")
fold_type <- match.arg(folds, opts)
if (fold_type == "loo") {
folds <- seq_len(N)
K <- N
message("Setting 'K' to the number of observations (", K, ")")
} else if (fold_type == "stratified") {
if (is.null(group)) {
stop2("Argument 'group' is required for stratified folds.")
}
folds <- loo::kfold_split_stratified(K, gvar)
} else if (fold_type == "balanced") {
if (is.null(group)) {
stop2("Argument 'group' is required for balanced folds.")
}
folds <- loo::kfold_split_balanced(K, gvar)
}
}
# validate argument 'Ksub'
if (is.null(Ksub)) {
Ksub <- seq_len(K)
} else {
Expand All @@ -642,6 +688,7 @@ kfold_internal <- function(x, K = 10, Ksub = NULL, exact_loo = FALSE,
}
Ksub <- sort(Ksub)
}

if (save_fits) {
fits <- array(
list(), dim = c(length(Ksub), 2),
Expand All @@ -652,11 +699,11 @@ kfold_internal <- function(x, K = 10, Ksub = NULL, exact_loo = FALSE,
for (k in Ksub) {
message("Fitting model ", k, " out of ", K)
ks <- match(k, Ksub)
if (exact_loo && !is.null(group)) {
omitted <- which(bin == bin[k])
if (fold_type == "loo" && !is.null(group)) {
omitted <- which(folds == folds[k])
predicted <- k
} else {
omitted <- predicted <- which(bin == k)
omitted <- predicted <- which(folds == k)
}
mf_omitted <- mf[-omitted, , drop = FALSE]
fit_k <- subset_autocor(x, -omitted)
Expand All @@ -671,6 +718,7 @@ kfold_internal <- function(x, K = 10, Ksub = NULL, exact_loo = FALSE,
allow_new_levels = TRUE, resp = resp
)
}

elpds <- ulapply(lppds, function(x) apply(x, 2, log_mean_exp))
# make sure elpds are put back in the right order
elpds <- elpds[order(unlist(obs_order))]
Expand All @@ -683,7 +731,8 @@ kfold_internal <- function(x, K = 10, Ksub = NULL, exact_loo = FALSE,
estimates[3, ] <- c(- 2 * elpd_kfold, 2 * se_elpd_kfold)
out <- nlist(
estimates, pointwise = cbind(elpd_kfold = elpds),
model_name = deparse(substitute(x)), K, Ksub, exact_loo, group
model_name = deparse(substitute(x)), K, Ksub,
group, folds, fold_type
)
if (save_fits) {
out$fits <- fits
Expand Down
2 changes: 2 additions & 0 deletions brms.Rproj
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,5 @@ PackageBuildArgs: --no-build-vignettes
PackageBuildBinaryArgs: --preclean
PackageCheckArgs: --as-cran --ignore-vignettes
PackageRoxygenize: rd,namespace

QuitChildProcessesOnExit: Yes
72 changes: 48 additions & 24 deletions man/kfold.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit e12c861

Please sign in to comment.