Skip to content

Commit

Permalink
Merge pull request #465 from fweber144/search_out_reuse_switch_cvmeth
Browse files Browse the repository at this point in the history
Amend #461, #463: If `validate_search = TRUE`, allow switching the CV method
  • Loading branch information
fweber144 committed Oct 23, 2023
2 parents abfd12f + d3ba3c1 commit 5d6ac95
Show file tree
Hide file tree
Showing 6 changed files with 440 additions and 69 deletions.
2 changes: 1 addition & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ If you read this from a place other than <https://mc-stan.org/projpred/news/inde

## Major changes

* Search results generated in an earlier `varsel()` or `cv_varsel()` call can now be re-used by the help of the new `varsel.vsel()` and `cv_varsel.vsel()` methods (i.e., by applying `varsel()` or `cv_varsel()` to the output of the earlier `varsel()` or `cv_varsel()` call). This can save a lot of time when re-running the predictive performance evaluation part multiple times based on the same search results. An illustration may be found in the updated main vignette (section ["Preliminary `cv_varsel()` run"](https://mc-stan.org/projpred/articles/projpred.html#preliminary-cv_varsel-run); a more general description may also be found in section ["Speed"](https://mc-stan.org/projpred/articles/projpred.html#speed)). (GitHub: #461, #463)
* Search results generated in an earlier `varsel()` or `cv_varsel()` call can now be re-used with the help of the new `varsel.vsel()` and `cv_varsel.vsel()` methods (i.e., by applying `varsel()` or `cv_varsel()` to the output of the earlier `varsel()` or `cv_varsel()` call). This can save a lot of time when re-running the predictive performance evaluation part multiple times based on the same search results. An illustration may be found in the updated main vignette (section ["Preliminary `cv_varsel()` run"](https://mc-stan.org/projpred/articles/projpred.html#preliminary-cv_varsel-run); a more general description may also be found in section ["Speed"](https://mc-stan.org/projpred/articles/projpred.html#speed)). (GitHub: #461, #463, #465)
* K-fold CV can now be combined with `validate_search = FALSE`. Related to this is an internal change which may cause LOO subsampling (see argument `nloo`) with clustered projection during the search (i.e., `1 < nclusters && nclusters < S`, where `S` denotes the number of posterior draws in the reference model) to yield slightly different results due to different internal pseudorandom number generator (PRNG) states. Furthermore, if `is.na(seed)`, then the PRNG state for code downstream of such a `cv_varsel()` call will be different due to this internal change. (GitHub: #464)

# projpred 2.7.0
Expand Down
117 changes: 61 additions & 56 deletions R/cv_varsel.R
Original file line number Diff line number Diff line change
Expand Up @@ -159,16 +159,23 @@ cv_varsel.vsel <- function(
validate_search = object$validate_search %||% TRUE,
...
) {
rk_foldwise <- ranking(object)[["foldwise"]]
if (validate_search) {
if (!identical(cv_method, object[["cv_method"]])) {
stop("In case of `validate_search = TRUE`, cv_varsel.vsel() requires ",
"`cv_method` to be the same as `object$cv_method`.")
# When switching the CV method (which could also mean using varsel()
# output in cv_varsel.vsel()), previous fold-wise predictor rankings
# cannot be re-used for a `validate_search = TRUE` run:
rk_foldwise <- NULL
}
if (!identical(K, object[["K"]])) {
if (identical(cv_method, "kfold") &&
identical(object[["cv_method"]], "kfold") &&
!identical(K, object[["K"]])) {
stop("In case of `validate_search = TRUE`, cv_varsel.vsel() requires ",
"`K` to be the same as `object$K`.")
}
if (!identical(cvfits, object[["cvfits"]])) {
if (identical(cv_method, "kfold") &&
identical(object[["cv_method"]], "kfold") &&
!identical(cvfits, object[["cvfits"]])) {
stop("In case of `validate_search = TRUE`, cv_varsel.vsel() requires ",
"`cvfits` to be the same as `object$cvfits`.")
}
Expand All @@ -189,8 +196,7 @@ cv_varsel.vsel <- function(
K = K,
cvfits = cvfits,
validate_search = validate_search,
search_out = list(search_path = object[["search_path"]],
ranking = ranking(object)),
search_out = nlist(search_path = object[["search_path"]], rk_foldwise),
...
))
}
Expand Down Expand Up @@ -237,27 +243,6 @@ cv_varsel.refmodel <- function(
refmodel <- object
nterms_all <- count_terms_in_formula(refmodel$formula) - 1L

# Restrictions in case of old search results which should be re-used:
if (!is.null(search_out)) {
if (cv_method %in% c("loo", "LOO") && !is.null(nloo) &&
nloo != refmodel[["nobs"]]) {
stop("Currently, `nloo == n` is needed to re-use old search results.")
}
if (validate_search) {
# We will need the fold-wise predictor rankings later:
if (is.null(search_out[["ranking"]][["foldwise"]])) {
stop("For `validate_search = TRUE`, old search results can only be ",
"re-used if the old search was performed in a fold-wise manner.")
}
# For `refit_prj = FALSE`, we would need the fold-wise submodel fits
# (along the fold-wise predictor rankings), which are currently not
# available:
if (!refit_prj) {
stop("Currently, for `validate_search = TRUE`, old search results can ",
"only be re-used if `refit_prj` is `TRUE`.")
}
}
}
# Parse arguments which also exist in varsel():
args <- parse_args_varsel(
refmodel = refmodel, method = method, refit_prj = refit_prj,
Expand All @@ -272,8 +257,9 @@ cv_varsel.refmodel <- function(
search_terms_was_null <- args$search_terms_was_null
# Parse arguments specific to cv_varsel():
args <- parse_args_cv_varsel(
refmodel = refmodel, cv_method = cv_method, K = K, cvfits = cvfits,
validate_search = validate_search, refit_prj = refit_prj
refmodel = refmodel, cv_method = cv_method, nloo = nloo, K = K,
cvfits = cvfits, validate_search = validate_search, refit_prj = refit_prj,
search_out = search_out
)
cv_method <- args$cv_method
K <- args$K
Expand Down Expand Up @@ -307,13 +293,15 @@ cv_varsel.refmodel <- function(
# Extract the fold-wise predictor rankings (to avoid passing the large
# object `search_out` itself) and coerce them to a `list` (in a row-wise
# manner) which is needed for the K-fold CV parallelization:
search_out_rk <- search_out[["ranking"]][["foldwise"]]
n_folds <- nrow(search_out_rk)
search_out_rk <- lapply(seq_len(n_folds), function(row_idx) {
search_out_rk[row_idx, ]
})
search_out_rks <- search_out[["rk_foldwise"]]
if (!is.null(search_out_rks)) {
n_folds <- nrow(search_out_rks)
search_out_rks <- lapply(seq_len(n_folds), function(row_idx) {
search_out_rks[row_idx, ]
})
}
} else {
search_out_rk <- NULL
search_out_rks <- NULL
}

if (cv_method == "LOO") {
Expand All @@ -332,7 +320,7 @@ cv_varsel.refmodel <- function(
},
search_terms = search_terms,
search_terms_was_null = search_terms_was_null,
search_out_rk = search_out_rk, parallel = parallel, ...
search_out_rks = search_out_rks, parallel = parallel, ...
)
} else if (cv_method == "kfold") {
sel_cv <- kfold_varsel(
Expand All @@ -350,7 +338,7 @@ cv_varsel.refmodel <- function(
# `refit_prj = FALSE`, so all that we need is element `solution_terms`:
search_path_fulldata["solution_terms"]
},
search_terms = search_terms, search_out_rk = search_out_rk,
search_terms = search_terms, search_out_rks = search_out_rks,
parallel = parallel, ...
)
}
Expand Down Expand Up @@ -422,8 +410,8 @@ cv_varsel.refmodel <- function(
# @param validate_search See argument `validate_search` of cv_varsel().
#
# @return A list with the processed elements `cv_method`, `K`, and `cvfits`.
parse_args_cv_varsel <- function(refmodel, cv_method, K, cvfits,
validate_search, refit_prj) {
parse_args_cv_varsel <- function(refmodel, cv_method, nloo, K, cvfits,
validate_search, refit_prj, search_out) {
stopifnot(!is.null(cv_method))
if (cv_method == "loo") {
cv_method <- toupper(cv_method)
Expand Down Expand Up @@ -467,6 +455,23 @@ parse_args_cv_varsel <- function(refmodel, cv_method, K, cvfits,
}
}

# Restrictions in case of previous search results which should be re-used:
if (!is.null(search_out)) {
if (cv_method == "LOO" && !is.null(nloo) && nloo != refmodel[["nobs"]]) {
# It would be hard (if not impossible) to ensure that the same PSIS-LOO CV
# folds (i.e., observations) are subsampled:
stop("Subsampled PSIS-LOO CV (see argument `nloo`) cannot be combined ",
"with the re-use of previous search results.")
}
if (validate_search && !refit_prj) {
# For `validate_search = TRUE` and `refit_prj = FALSE`, we would need the
# fold-wise submodel fits (along the fold-wise predictor rankings), which
# are currently not available:
stop("If `validate_search = TRUE`, then `refit_prj = FALSE` cannot be ",
"combined with the re-use of previous search results.")
}
}

return(nlist(cv_method, K, cvfits))
}

Expand All @@ -481,7 +486,7 @@ loo_varsel <- function(refmodel, method, nterms_max, ndraws,
nclusters, ndraws_pred, nclusters_pred, refit_prj,
penalty, verbose, opt, nloo, validate_search,
search_path_fulldata, search_terms,
search_terms_was_null, search_out_rk, parallel, ...) {
search_terms_was_null, search_out_rks, parallel, ...) {
## Pre-processing ---------------------------------------------------------

has_grp <- formula_contains_group_terms(refmodel$formula)
Expand Down Expand Up @@ -870,8 +875,8 @@ loo_varsel <- function(refmodel, method, nterms_max, ndraws,
} else {
## Case `validate_search = TRUE` ------------------------------------------

search_out_rk_was_null <- is.null(search_out_rk)
if (search_out_rk_was_null) {
search_out_rks_was_null <- is.null(search_out_rks)
if (search_out_rks_was_null) {
cl_sel <- get_refdist(refmodel, ndraws = ndraws, nclusters = nclusters)$cl
}
if (refit_prj) {
Expand All @@ -881,7 +886,7 @@ loo_varsel <- function(refmodel, method, nterms_max, ndraws,

if (verbose) {
verb_txt_start <- "-----\nRunning "
if (!search_out_rk_was_null) {
if (!search_out_rks_was_null) {
verb_txt_mid <- ""
} else {
verb_txt_mid <- "the search and "
Expand All @@ -900,8 +905,8 @@ loo_varsel <- function(refmodel, method, nterms_max, ndraws,
# *reweighted* fitted response values from the reference model act as
# artificial response values in the projection (or L1-penalized
# projection)):
if (!search_out_rk_was_null) {
search_path <- list(solution_terms = search_out_rk[[run_index]])
if (!search_out_rks_was_null) {
search_path <- list(solution_terms = search_out_rks[[run_index]])
} else {
search_path <- select(
refmodel = refmodel, ndraws = ndraws, nclusters = nclusters,
Expand Down Expand Up @@ -1173,23 +1178,23 @@ warn_pareto <- function(n07, n05, warn_txt_start, warn_txt_mid_common,
# Needed to avoid a NOTE in `R CMD check`:
if (getRversion() >= package_version("2.15.1")) {
utils::globalVariables("list_cv_k")
utils::globalVariables("search_out_rk_k")
utils::globalVariables("search_out_rks_k")
}

kfold_varsel <- function(refmodel, method, nterms_max, ndraws, nclusters,
ndraws_pred, nclusters_pred, refit_prj, penalty,
verbose, opt, K, cvfits, validate_search,
search_path_fulldata, search_terms, search_out_rk,
search_path_fulldata, search_terms, search_out_rks,
parallel, ...) {
# Fetch the K reference model fits (or fit them now if not already done) and
# create objects of class `refmodel` from them (and also store the `omitted`
# indices):
list_cv <- get_kfold(refmodel, K = K, cvfits = cvfits, verbose = verbose)
K <- length(list_cv)

search_out_rk_was_null <- is.null(search_out_rk)
if (search_out_rk_was_null) {
search_out_rk <- replicate(K, NULL, simplify = FALSE)
search_out_rks_was_null <- is.null(search_out_rks)
if (search_out_rks_was_null) {
search_out_rks <- replicate(K, NULL, simplify = FALSE)
}

if (refmodel$family$for_latent) {
Expand All @@ -1204,7 +1209,7 @@ kfold_varsel <- function(refmodel, method, nterms_max, ndraws, nclusters,

if (verbose) {
verb_txt_start <- "-----\nRunning "
if (!search_out_rk_was_null || !validate_search) {
if (!search_out_rks_was_null || !validate_search) {
verb_txt_mid <- ""
} else {
verb_txt_mid <- "the search and "
Expand All @@ -1220,7 +1225,7 @@ kfold_varsel <- function(refmodel, method, nterms_max, ndraws, nclusters,
# Run the search for the current fold:
if (!validate_search) {
search_path <- search_path_fulldata
} else if (!search_out_rk_was_null) {
} else if (!search_out_rks_was_null) {
search_path <- list(solution_terms = rk)
} else {
search_path <- select(
Expand Down Expand Up @@ -1272,7 +1277,7 @@ kfold_varsel <- function(refmodel, method, nterms_max, ndraws, nclusters,
if (verbose) {
on.exit(utils::setTxtProgressBar(pb, k))
}
one_fold(fold = list_cv[[k]], rk = search_out_rk[[k]], ...)
one_fold(fold = list_cv[[k]], rk = search_out_rks[[k]], ...)
})
if (verbose) {
close(pb)
Expand All @@ -1289,11 +1294,11 @@ kfold_varsel <- function(refmodel, method, nterms_max, ndraws, nclusters,
`%do_projpred%` <- doRNG::`%dorng%`
res_cv <- foreach::foreach(
list_cv_k = list_cv,
search_out_rk_k = search_out_rk,
search_out_rks_k = search_out_rks,
.export = c("one_fold", "dot_args"),
.noexport = c("list_cv", "search_out_rk")
.noexport = c("list_cv", "search_out_rks")
) %do_projpred% {
do.call(one_fold, c(list(fold = list_cv_k, rk = search_out_rk_k,
do.call(one_fold, c(list(fold = list_cv_k, rk = search_out_rks_k,
verbose_search = FALSE),
dot_args))
}
Expand Down
3 changes: 1 addition & 2 deletions R/varsel.R
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,7 @@ varsel.vsel <- function(object, ...) {
thresh = object[["args_search"]][["thresh"]],
penalty = object[["args_search"]][["penalty"]],
search_terms = object[["args_search"]][["search_terms"]],
search_out = list(search_path = object[["search_path"]],
ranking = ranking(object)),
search_out = list(search_path = object[["search_path"]]),
...
))
}
Expand Down
8 changes: 7 additions & 1 deletion tests/testthat/helpers/testers.R
Original file line number Diff line number Diff line change
Expand Up @@ -2144,8 +2144,14 @@ vsel_tester <- function(
c(n_folds, solterms_len_expected),
info = info_str)
# We need the addition of `NA_character_` because of subsampled PSIS-LOO CV:
soltrms_cv <- unique(as.vector(vs$solution_terms_cv))
for (soltrms_cv_plus in grep("\\+", soltrms_cv, value = TRUE)) {
soltrms_cv <- setdiff(soltrms_cv, soltrms_cv_plus)
soltrms_cv <- c(soltrms_cv,
labels(terms(as.formula(paste(". ~", soltrms_cv_plus)))))
}
expect_true(
all(vs$solution_terms_cv %in% c(trms_universe_split, NA_character_)),
all(soltrms_cv %in% c(trms_universe_split, NA_character_)),
info = info_str
)
} else {
Expand Down
Loading

0 comments on commit 5d6ac95

Please sign in to comment.