Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for 3-way validation split interface #701

Merged
merged 7 commits into from
Jul 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ Imports:
purrr (>= 1.0.0),
recipes (>= 1.0.4),
rlang (>= 1.1.0),
rsample (>= 1.0.0),
rsample (>= 1.1.1.9001),
tibble (>= 3.1.0),
tidyr (>= 1.2.0),
tidyselect (>= 1.1.2),
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

* A method for rsample's `int_pctl()` function that will compute percentile confidence intervals on performance metrics for objects produced by `fit_resamples()`, `tune_*()`, and `last_fit()`.

* `last_fit()` now works with the 3-way validation split objects from `rsample::initial_validation_split()`. `last_fit()` and `fit_best()` now have a new argument `add_validation_set` to include or exclude the validation set in the dataset used to fit the model (#701).

# tune 1.1.1

* Fixed a bug introduced in tune 1.1.0 in `collect_()` functions where the
Expand Down
31 changes: 30 additions & 1 deletion R/fit_best.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@
#' If `NULL`, this argument will be set to
#' [`select_best(metric)`][tune::select_best.tune_results].
#' @param verbose A logical for printing logging.
#' @param add_validation_set When the resamples embedded in `x` are a split into
#' training set and validation set, should the validation set be included in the
#' data set used to train the model? If not, only the training set is used. If
#' `NULL`, the validation set is not used for resamples originating from
#' [rsample::validation_set()] while it is used for resamples originating
#' from [rsample::validation_split()].
#' @param ... Not currently used.
#' @details
#' This function is a shortcut for the manual steps of:
Expand Down Expand Up @@ -84,6 +90,7 @@ fit_best.tune_results <- function(x,
metric = NULL,
parameters = NULL,
verbose = FALSE,
add_validation_set = NULL,
...) {
if (length(list(...))) {
cli::cli_abort(c("x" = "The `...` are not used by this function."))
Expand Down Expand Up @@ -120,7 +127,29 @@ fit_best.tune_results <- function(x,

# ----------------------------------------------------------------------------

dat <- x$splits[[1]]$data
if (inherits(x$splits[[1]], "val_split")) {
if (is.null(add_validation_set)) {
rset_info <- attr(x, "rset_info")
originate_from_3way_split <- rset_info$att$origin_3way %||% FALSE
if (originate_from_3way_split) {
add_validation_set <- FALSE
} else {
add_validation_set <- TRUE
}
}
if (add_validation_set) {
dat <- x$splits[[1]]$data
} else {
dat <- rsample::training(x$splits[[1]])
}
} else {
if (!is.null(add_validation_set)) {
rlang::warn(
"The option `add_validation_set` is being ignored because the resampling object does not include a validation set."
)
}
dat <- x$splits[[1]]$data
}
if (verbose) {
cli::cli_inform(c("i" = "Fitting using {nrow(dat)} data points..."))
}
Expand Down
51 changes: 46 additions & 5 deletions R/last_fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
#' @param preprocessor A traditional model formula or a recipe created using
#' [recipes::recipe()].
#'
#' @param split An `rsplit` object created from [rsample::initial_split()].
#' @param split An `rsplit` object created from [rsample::initial_split()] or
#' [rsample::initial_validation_split()].
#'
#' @param metrics A [yardstick::metric_set()], or `NULL` to compute a standard
#' set of metrics.
Expand All @@ -25,6 +26,11 @@
#' values should be non-negative and should probably be no greater than the
#' largest event time in the training set.
#'
#' @param add_validation_set For 3-way splits into training, validation, and test
#' set via [rsample::initial_validation_split()], should the validation set be
#' included in the data set used to train the model? If not, only the training
#' set is used.
#'
#' @param ... Currently unused.
#'
#' @details
Expand Down Expand Up @@ -113,7 +119,8 @@ last_fit.model_fit <- function(object, ...) {
#' @export
#' @rdname last_fit
last_fit.model_spec <- function(object, preprocessor, split, ..., metrics = NULL,
control = control_last_fit(), eval_time = NULL) {
control = control_last_fit(), eval_time = NULL,
add_validation_set = FALSE) {
if (rlang::is_missing(preprocessor) || !is_preprocessor(preprocessor)) {
rlang::abort(paste(
"To tune a model spec, you must preprocess",
Expand All @@ -133,19 +140,20 @@ last_fit.model_spec <- function(object, preprocessor, split, ..., metrics = NULL
wflow <- add_formula(wflow, preprocessor)
}

last_fit_workflow(wflow, split, metrics, control, eval_time)
last_fit_workflow(wflow, split, metrics, control, eval_time, add_validation_set)
}


#' @rdname last_fit
#' @export
last_fit.workflow <- function(object, split, ..., metrics = NULL,
control = control_last_fit(), eval_time = NULL) {
control = control_last_fit(), eval_time = NULL,
add_validation_set = FALSE) {
empty_ellipses(...)

control <- parsnip::condense_control(control, control_last_fit())

last_fit_workflow(object, split, metrics, control, eval_time)
last_fit_workflow(object, split, metrics, control, eval_time, add_validation_set)
}


Expand All @@ -154,6 +162,7 @@ last_fit_workflow <- function(object,
metrics,
control,
eval_time = NULL,
add_validation_set = FALSE,
...,
call = rlang::caller_env()) {
rlang::check_dots_empty()
Expand All @@ -166,6 +175,9 @@ last_fit_workflow <- function(object,
)
}

if (inherits(split, "initial_validation_split")) {
hfrick marked this conversation as resolved.
Show resolved Hide resolved
split <- prepare_validation_split(split, add_validation_set)
}
splits <- list(split)
resamples <- rsample::manual_rset(splits, ids = "train/test split")

Expand All @@ -190,3 +202,32 @@ last_fit_workflow <- function(object,
.stash_last_result(res)
res
}


# Convert a 3-way split (with elements `data`, `train_id` and `val_id`) into a
# 2-way object of class "rsplit".
#
# @param split A list-like object carrying `data` (a data frame), `train_id`
#   and `val_id` (row numbers of `data`).
# @param add_validation_set A single logical. If `TRUE`, the validation rows
#   are added to `in_id` and `data` is kept whole. If `FALSE`, the validation
#   rows are removed from `data` and `in_id` holds the positions of the
#   training rows within the reduced data.
# @return An object of class "rsplit" with elements `data`, `in_id` and
#   `out_id` (always `NA`).
prepare_validation_split <- function(split, add_validation_set) {
  if (add_validation_set) {
    data <- split$data
    in_id <- c(split$train_id, split$val_id)
  } else {
    # Use a logical mask instead of negative indexing: `x[-integer(0)]` selects
    # nothing, so an empty `val_id` would otherwise discard every row.
    keep <- !(seq_len(nrow(split$data)) %in% split$val_id)
    data <- split$data[keep, , drop = FALSE]
    # Re-express the training row numbers as positions in the reduced data.
    in_id <- match(split$train_id, which(keep))
  }

  # equivalent to (unexported) rsample:::rsplit() without checks
  # NOTE(review): `out_id = NA` presumably makes rsample treat all rows not in
  # `in_id` as the assessment set -- confirm against rsample.
  structure(
    list(
      data = data,
      in_id = in_id,
      out_id = NA
    ),
    class = "rsplit"
  )
}
16 changes: 15 additions & 1 deletion man/fit_best.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 11 additions & 3 deletions man/last_fit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

65 changes: 65 additions & 0 deletions tests/testthat/test-fit_best.R
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,68 @@ test_that("fit_best", {
fit_best(ames_iter_search)
)
})

# Resamples built with `validation_set()` from a 3-way split: with the default
# `add_validation_set`, the predictions of `fit_best()` are compared against a
# model fit on the training set only.
test_that("fit_best() works with validation split: 3-way split", {
  skip_if_not_installed("kknn")
  skip_if_not_installed("modeldata")
  data(ames, package = "modeldata", envir = rlang::current_env())

  set.seed(23598723)
  three_way <- rsample::initial_validation_split(ames)
  val_rset <- validation_set(three_way)

  model_formula <- Sale_Price ~ Gr_Liv_Area + Year_Built
  knn_spec <- set_mode(nearest_neighbor(neighbors = tune()), "regression")
  knn_wflow <- workflow(model_formula, knn_spec)

  tuned <- suppressWarnings(
    tune_grid(
      knn_wflow,
      grid = tibble(neighbors = c(1, 5)),
      resamples = val_rset,
      control = control_grid(save_workflow = TRUE)
    )
  )

  # Same seed before both fits so they are comparable.
  set.seed(3)
  best_fit <- fit_best(tuned)
  observed <- predict(best_fit, testing(three_way))

  # Reference model: the test assumes `neighbors = 5` is the selected value.
  set.seed(3)
  reference_spec <- set_mode(nearest_neighbor(neighbors = 5), "regression")
  reference_fit <- fit(reference_spec, model_formula, training(three_way))
  expected <- predict(reference_fit, testing(three_way))

  expect_equal(observed, expected)
})

# Resamples built with `validation_split()` on the training portion of a 2-way
# split: with the default `add_validation_set`, the predictions of `fit_best()`
# are compared against a model fit on training and validation data together.
test_that("fit_best() works with validation split: 2x 2-way splits", {
  skip_if_not_installed("kknn")
  skip_if_not_installed("modeldata")
  data(ames, package = "modeldata", envir = rlang::current_env())

  set.seed(23598723)
  split <- rsample::initial_split(ames)
  not_test <- training(split)
  val_rset <- rsample::validation_split(not_test)

  model_formula <- Sale_Price ~ Gr_Liv_Area + Year_Built
  knn_spec <- set_mode(nearest_neighbor(neighbors = tune()), "regression")
  knn_wflow <- workflow(model_formula, knn_spec)

  tuned <- tune_grid(
    knn_wflow,
    grid = tibble(neighbors = c(1, 5)),
    resamples = val_rset,
    control = control_grid(save_workflow = TRUE)
  )

  # Same seed before both fits so they are comparable.
  set.seed(3)
  best_fit <- fit_best(tuned)
  observed <- predict(best_fit, testing(split))

  # Reference model: the test assumes `neighbors = 5` is the selected value.
  set.seed(3)
  reference_spec <- set_mode(nearest_neighbor(neighbors = 5), "regression")
  reference_fit <- fit(reference_spec, model_formula, not_test)
  expected <- predict(reference_fit, testing(split))

  expect_equal(observed, expected)
})
65 changes: 65 additions & 0 deletions tests/testthat/test-last-fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,68 @@ test_that("`last_fit()` when objects need tuning", {
expect_snapshot_error(last_fit(wflow_2, split))
expect_snapshot_error(last_fit(wflow_3, split))
})

# Default `last_fit()` on a 3-way split: results are compared against an `lm()`
# fit on the training set only and evaluated on the test set.
test_that("last_fit() excludes validation set for initial_validation_split objects", {
  skip_if_not_installed("modeldata")
  data(ames, package = "modeldata", envir = rlang::current_env())

  set.seed(23598723)
  split <- rsample::initial_validation_split(ames)

  model_formula <- Sale_Price ~ Gr_Liv_Area + Year_Built
  reference_lm <- lm(model_formula, data = rsample::training(split))
  reference_pred <- predict(reference_lm, rsample::testing(split))
  # R-squared of the reference predictions on the test set.
  rsq_test <- yardstick::rsq_vec(
    rsample::testing(split) %>% pull(Sale_Price),
    reference_pred
  )

  lm_spec <- parsnip::set_engine(parsnip::linear_reg(), "lm")
  res <- last_fit(lm_spec, model_formula, split)

  expect_equal(res, .Last.tune.result)

  final_wflow <- res$.workflow[[1]]
  expect_equal(
    coef(extract_fit_engine(final_wflow)),
    coef(reference_lm),
    ignore_attr = TRUE
  )
  # NOTE(review): assumes the second metric row is R-squared -- confirm.
  expect_equal(res$.metrics[[1]]$.estimate[[2]], rsq_test)
  expect_equal(res$.predictions[[1]]$.pred, unname(reference_pred))
  expect_true(final_wflow$trained)
  expect_equal(
    nrow(predict(final_wflow, rsample::testing(split))),
    nrow(rsample::testing(split))
  )
})

# `last_fit(add_validation_set = TRUE)` on a 3-way split: results are compared
# against an `lm()` fit on training and validation rows combined and evaluated
# on the test set.
test_that("last_fit() can include validation set for initial_validation_split objects", {
  skip_if_not_installed("modeldata")
  data(ames, package = "modeldata", envir = rlang::current_env())

  set.seed(23598723)
  split <- rsample::initial_validation_split(ames)

  model_formula <- Sale_Price ~ Gr_Liv_Area + Year_Built
  fit_data <- rbind(rsample::training(split), rsample::validation(split))
  reference_lm <- lm(model_formula, data = fit_data)
  reference_pred <- predict(reference_lm, rsample::testing(split))
  # R-squared of the reference predictions on the test set.
  rsq_test <- yardstick::rsq_vec(
    rsample::testing(split) %>% pull(Sale_Price),
    reference_pred
  )

  lm_spec <- parsnip::set_engine(parsnip::linear_reg(), "lm")
  res <- last_fit(lm_spec, model_formula, split, add_validation_set = TRUE)

  expect_equal(res, .Last.tune.result)

  final_wflow <- res$.workflow[[1]]
  expect_equal(
    coef(extract_fit_engine(final_wflow)),
    coef(reference_lm),
    ignore_attr = TRUE
  )
  # NOTE(review): assumes the second metric row is R-squared -- confirm.
  expect_equal(res$.metrics[[1]]$.estimate[[2]], rsq_test)
  expect_equal(res$.predictions[[1]]$.pred, unname(reference_pred))
  expect_true(final_wflow$trained)
  expect_equal(
    nrow(predict(final_wflow, rsample::testing(split))),
    nrow(rsample::testing(split))
  )
})
Loading