New argument folds of run_cvfun() #480

Merged
merged 7 commits into from
Nov 23, 2023
1 change: 1 addition & 0 deletions NEWS.md
@@ -30,6 +30,7 @@ If you read this from a place other than <https://mc-stan.org/projpred/news/inde
* Added a new performance statistic, the geometric mean predictive density (GMPD). This is particularly useful for discrete outcomes because there, the GMPD is a geometric mean of probabilities and hence bounded by zero and one. For details, see argument `stats` of the `?summary.vsel` help. (GitHub: #476)
* `project()`'s argument `verbose` now gets passed to argument `verbose_divmin` (not `projpred_verbose`) of the divergence minimizer function (see argument `div_minimizer` of `init_refmodel()`).
* Arguments `lambda_min_ratio`, `nlambda`, and `thresh` of `varsel()` and `cv_varsel()` have been deprecated. Instead, `varsel()` and `cv_varsel()` have gained a new argument called `search_control` which accepts control arguments for the search as a `list`. Thus, former arguments `lambda_min_ratio`, `nlambda`, and `thresh` should now be specified via `search_control` (but note that `search_control` is more general because it also accepts control arguments for a *forward* search). (GitHub: #477)
+* `run_cvfun()` has gained a new argument `folds`, accepting a vector of fold indices (the default is `NULL`, meaning that the folds are constructed internally, as before). This new argument is helpful, for example, to perform a stratified K-fold CV in a convenient manner (an example of this has been added to the `?run_cvfun` help). (GitHub: #480)
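
For readers unfamiliar with what a stratified fold vector looks like, here is a minimal base-R sketch. The helper `stratified_folds()` is hypothetical (it is not part of projpred or loo, and it is not necessarily how `loo::kfold_split_stratified()` works internally); it only illustrates the kind of integer vector that argument `folds` accepts.

```r
# Hypothetical helper (not part of projpred or loo): assign fold indices so
# that the observations of each stratum are split as evenly as possible
# across the K folds.
stratified_folds <- function(strata, K) {
  strata <- as.factor(strata)
  folds <- integer(length(strata))
  for (lvl in levels(strata)) {
    idx <- which(strata == lvl)
    # Shuffle within the stratum, then deal out the fold labels 1..K in turn:
    idx <- idx[sample.int(length(idx))]
    folds[idx] <- rep_len(seq_len(K), length(idx))
  }
  folds
}

set.seed(692)
# Some made-up strata for 100 observations:
strata <- sample(c("lvl1", "lvl2", "lvl3"), size = 100, replace = TRUE)
folds_strat <- stratified_folds(strata, K = 2)
table(folds_strat, strata)
# With a reference model object `ref` (assumed to exist, with 100
# observations), the vector would then be passed as:
# cv_fits_strat <- run_cvfun(ref, folds = folds_strat)
```

Within each stratum, the fold sizes differ by at most one observation, which is the property that makes the folds "stratified".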

## Bug fixes

29 changes: 26 additions & 3 deletions R/cv_varsel.R
@@ -1462,7 +1462,11 @@ get_kfold <- function(refmodel, K, cvfits, verbose) {
#' [init_refmodel()]) or an object that can be passed to argument `object` of
#' [get_refmodel()].
#' @param K Number of folds. Must be at least 2 and not exceed the number of
-#' observations.
+#' observations. Ignored if `folds` is not `NULL`.
+#' @param folds Either `NULL` for determining the CV folds automatically via
+#' [cv_folds()] (using argument `K`) or a numeric (in fact, integer) vector
+#' giving the fold index for each observation. In the latter case, argument
+#' `K` is ignored.
#' @param seed Pseudorandom number generation (PRNG) seed by which the same
#' results can be obtained again if needed. Passed to argument `seed` of
#' [set.seed()], but can also be `NA` to not call [set.seed()] at all. If not
@@ -1513,6 +1517,23 @@ get_kfold <- function(refmodel, K, cvfits, verbose) {
#' cvfits = cv_fits, nterms_max = 3, nclusters = 5,
#' nclusters_pred = 10, seed = 5555)
#'
+#' # Stratified K-fold CV is straightforward:
+#' n_strat <- 3L
+#' set.seed(692)
+#' # Some example strata:
+#' strat_fac <- sample(paste0("lvl", seq_len(n_strat)), size = nrow(dat_gauss),
+#'                     replace = TRUE,
+#'                     prob = diff(c(0, pnorm(seq_len(n_strat - 1L) - 0.5), 1)))
+#' table(strat_fac)
+#' # Use loo::kfold_split_stratified() to create the folds vector:
+#' folds_strat <- loo::kfold_split_stratified(K = 2, x = strat_fac)
+#' table(folds_strat, strat_fac)
+#' # Call run_cvfun(), but this time with argument `folds` instead of `K` (here,
+#' # specifying argument `seed` would not be necessary because of the set.seed()
+#' # call above, but we specify it nonetheless for the sake of generality):
+#' cv_fits_strat <- run_cvfun(ref, folds = folds_strat, seed = 391)
+#' # Now use `cv_fits_strat` analogously to `cv_fits` from above.
+#'
#' @export
run_cvfun <- function(object, ...) {
UseMethod("run_cvfun")
@@ -1529,7 +1550,7 @@ run_cvfun.default <- function(object, ...) {
#' @export
run_cvfun.refmodel <- function(object,
K = if (!inherits(object, "datafit")) 5 else 10,
-seed = NA, ...) {
+folds = NULL, seed = NA, ...) {
if (exists(".Random.seed", envir = .GlobalEnv)) {
rng_state_old <- get(".Random.seed", envir = .GlobalEnv)
}
@@ -1544,7 +1565,9 @@ run_cvfun.refmodel <- function(object,
refmodel <- object
stopifnot(!is.null(refmodel$cvfun))

-folds <- cv_folds(refmodel$nobs, K = K)
+if (is.null(folds)) {
+  folds <- cv_folds(refmodel$nobs, K = K)
+}
if (getOption("projpred.warn_kfold_refits", TRUE)) {
cvfits <- refmodel$cvfun(folds)
} else {
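
The fallback introduced in this hunk (use `folds` if supplied, otherwise construct the folds from `K`) can be illustrated in isolation. The following is a sketch under stated assumptions: `make_folds()` is a hypothetical stand-in for projpred's `cv_folds()`, whose details may differ, and the input checks in `resolve_folds()` are an addition of this sketch, not something taken from the PR.

```r
# Hypothetical stand-in for projpred's fold construction; the real cv_folds()
# may differ in its details.
make_folds <- function(n, K) {
  stopifnot(K >= 2, K <= n)
  # Balanced fold labels in random order:
  sample(rep_len(seq_len(K), n))
}

# Use `folds` if supplied (ignoring `K`), otherwise construct them from `K`.
resolve_folds <- function(n, K = 5, folds = NULL) {
  if (is.null(folds)) {
    return(make_folds(n, K = K))
  }
  # These checks are an assumption of this sketch, not taken from the PR:
  stopifnot(is.numeric(folds), length(folds) == n, !anyNA(folds))
  as.integer(folds)
}

set.seed(1)
f_auto <- resolve_folds(n = 10, K = 5)
f_user <- resolve_folds(n = 10, K = 5, folds = rep(1:2, each = 5))
```

Here, `f_auto` is constructed from `K`, whereas `f_user` is the user-supplied vector returned unchanged (with `K = 5` being ignored), mirroring the documented behavior that `K` is ignored if `folds` is not `NULL`.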
25 changes: 24 additions & 1 deletion man/run_cvfun.Rd

4 changes: 2 additions & 2 deletions tests/testthat/setup.R
@@ -1247,10 +1247,10 @@ if (run_cvvs) {
args_cvvs_kfold <- args_cvvs[
sapply(lapply(args_cvvs, "[[", "cv_method"), identical, "kfold")
]
-tstsetups_cvvs_ref_kfold <- setNames(nm = unique(unname(
+tstsetups_ref_kfold <- setNames(nm = unique(unname(
sapply(args_cvvs_kfold, "[[", "tstsetup_ref")
)))
-cvfitss <- lapply(tstsetups_cvvs_ref_kfold, function(tstsetup_ref) {
+cvfitss <- lapply(tstsetups_ref_kfold, function(tstsetup_ref) {
# Due to rstanarm:::kfold.stanreg() failing sometimes, we have to wrap the
# call to run_cvfun() in try():
return(try(run_cvfun(object = refmods[[tstsetup_ref]], K = K_tst,
21 changes: 21 additions & 0 deletions tests/testthat/test_varsel.R
@@ -2820,3 +2820,24 @@ test_that(paste(
)
}
})
+
+# run_cvfun() -------------------------------------------------------------
+
+test_that("argument `folds` of run_cvfun() works", {
+  skip_if_not(run_cvvs)
+  tstsetups <- names(cvfitss)
+  if (!run_more) {
+    tstsetups <- head(tstsetups, 1)
+  }
+  if (exists(".Random.seed", envir = .GlobalEnv)) {
+    rng_old <- get(".Random.seed", envir = .GlobalEnv)
+  }
+  for (tstsetup in tstsetups) {
+    set.seed(seed3_tst)
+    folds_sep <- cv_folds(nobsv, K = K_tst)
+    cvfits_sep <- run_cvfun(object = refmods[[tstsetup]], folds = folds_sep)
+    expect_identical(lapply(cvfits_sep, as.matrix),
+                     lapply(cvfitss[[tstsetup]], as.matrix), info = tstsetup)
+  }
+  if (exists("rng_old")) assign(".Random.seed", rng_old, envir = .GlobalEnv)
+})
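
The test above saves the global RNG state before calling `set.seed()` and restores it afterwards, so that it does not disturb the random number stream of subsequent tests. A minimal base-R sketch of that pattern follows; the wrapper `with_restored_rng()` is hypothetical and not part of projpred or its test suite.

```r
# Hypothetical wrapper: evaluate `expr`, then put the global RNG state back
# to what it was before (or remove it if it did not exist before).
with_restored_rng <- function(expr) {
  had_seed <- exists(".Random.seed", envir = .GlobalEnv, inherits = FALSE)
  if (had_seed) {
    rng_old <- get(".Random.seed", envir = .GlobalEnv, inherits = FALSE)
  }
  on.exit({
    if (had_seed) {
      assign(".Random.seed", rng_old, envir = .GlobalEnv)
    } else if (exists(".Random.seed", envir = .GlobalEnv, inherits = FALSE)) {
      rm(".Random.seed", envir = .GlobalEnv)
    }
  })
  # `expr` is evaluated lazily, i.e., only here (after the state was saved):
  force(expr)
}

set.seed(42)
perm <- with_restored_rng({
  set.seed(1)
  sample(10)
})
x1 <- runif(1)

# For comparison: the first draw directly after set.seed(42).
set.seed(42)
x2 <- runif(1)
identical(x1, x2)
```

Because the state is restored on exit, the draw `x1` is the same as if the wrapped code (including its own `set.seed()` call) had never run.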