Skip to content

Commit

Permalink
Merge pull request #273 from DavisVaughan/last-fit-randomness
Browse files Browse the repository at this point in the history
Construct a "manual" rset for usage in `last_fit()`
  • Loading branch information
DavisVaughan committed Sep 14, 2020
2 parents 6161326 + b813461 commit 273be89
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 25 deletions.
4 changes: 3 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ Imports:
cli (>= 2.0.0),
crayon,
yardstick,
rsample,
rsample (>= 0.0.7.9000),
tidyr,
GPfit,
foreach,
Expand All @@ -48,3 +48,5 @@ LazyData: true
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.1.1
Language: en-US
Remotes:
tidymodels/rsample
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

## Bug Fixes

* `last_fit()` no longer accidentally adjusts the random seed (#264).

* Fixed two bugs in the acquisition function calculations.

# tune 0.1.1
Expand Down
13 changes: 3 additions & 10 deletions R/last_fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -122,29 +122,22 @@ last_fit.workflow <- function(object, split, ..., metrics = NULL) {
last_fit_workflow(object, split, metrics)
}

split_to_rset <- function(x) {
prop <- length(x$in_id)/nrow(x$data)
res <- rsample::mc_cv(x$data, times = 1, prop = prop)
res$splits[[1]] <- x
res
}

last_fit_workflow <- function(object, split, metrics) {
extr <- function(x)
x
ctrl <- control_resamples(save_pred = TRUE, extract = extr)
splits <- list(split)
resamples <- rsample::manual_rset(splits, ids = "train/test split")
res <-
fit_resamples(
object,
resamples = split_to_rset(split),
resamples = resamples,
metrics = metrics,
control = ctrl
)
res$id[[1]] <- "train/test split"
res$.workflow <- res$.extracts[[1]][[1]]
res$.extracts <- NULL
class(res) <- c("last_fit", class(res))
class(res) <- unique(class(res))
res
}

14 changes: 0 additions & 14 deletions tests/testthat/test-last-fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,20 +29,6 @@ test_that("recipe method", {
expect_equal(res$.predictions[[1]]$.pred, unname(test_pred))
})

test_that("split_to_rset", {

res <- tune:::split_to_rset(split)
expect_true(inherits(res, "mc_cv"))
expect_true(nrow(res) == 1)
expect_true(nrow(res) == 1)

res <- linear_reg() %>% set_engine("lm") %>% last_fit(f, split)
expect_true(is.list(res$.workflow))
expect_true(inherits(res$.workflow[[1]], "workflow"))
expect_true(is.list(res$.predictions))
expect_true(inherits(res$.predictions[[1]], "tbl_df"))
})

test_that("collect metrics of last fit", {

res <- linear_reg() %>% set_engine("lm") %>% last_fit(f, split)
Expand Down

0 comments on commit 273be89

Please sign in to comment.