control_last_fit() #399
We make the control object internally with a few specific arguments set (see lines 104 to 106 in commit a1d276d).
Maybe
@topepo @juliasilge Would you be open to a PR for this? I think the inability to set the
library(tidyverse)
library(tidymodels)
set.seed(12345)
data <-
  tibble(
    y = factor(c(rep(0, 8000), rep(1, 2000)), levels = 0:1),
    x1 = rnorm(10000),
    x2 = rnorm(10000)
  )
tt_split <- initial_split(data, strata = y)
train <- training(tt_split)
test <- testing(tt_split)
wflow <-
  workflow(
    preprocessor = y ~ .,
    spec = logistic_reg(engine = 'glmnet', penalty = tune(), mixture = 1)
  )
res <-
  tune_grid(
    object = wflow,
    resamples = vfold_cv(train, strata = y),
    grid = 10,
    metrics = metric_set(roc_auc, pr_auc),
    control = control_grid(event_level = 'second')
  )
show_best(res, 'roc_auc', n = 1)
#> # A tibble: 1 × 7
#> penalty .metric .estimator mean n std_err .config
#> <dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
#> 1 0.0224 roc_auc binary 0.5 10 0 Preprocessor1_Model09
show_best(res, 'pr_auc', n = 1)
#> # A tibble: 1 × 7
#> penalty .metric .estimator mean n std_err .config
#> <dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
#> 1 0.0224 pr_auc binary 0.6 10 0 Preprocessor1_Model09
final_fit <-
  last_fit(
    finalize_workflow(wflow, select_best(res, metric = 'roc_auc')),
    split = tt_split,
    metrics = metric_set(roc_auc, pr_auc)
  )
final_model <- extract_workflow(final_fit)
collect_metrics(final_fit)
#> # A tibble: 2 × 4
#> .metric .estimator .estimate .config
#> <chr> <chr> <dbl> <chr>
#> 1 roc_auc binary 0.5 Preprocessor1_Model1
#> 2 pr_auc binary 0.9 Preprocessor1_Model1
preds <- bind_cols(
  test, predict(final_model, new_data = test, type = 'prob')
)
pr_auc(preds, y, .pred_1, event_level = 'second')
#> # A tibble: 1 × 3
#> .metric .estimator .estimate
#> <chr> <chr> <dbl>
#> 1 pr_auc binary 0.6
roc_auc(preds, y, .pred_1, event_level = 'second')
#> # A tibble: 1 × 3
#> .metric .estimator .estimate
#> <chr> <chr> <dbl>
#> 1 roc_auc binary 0.5
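The gap between the pr_auc reported by collect_metrics() (0.9) and the one computed by hand with event_level = 'second' (0.6) is consistent with last_fit() scoring the metrics with yardstick's default event level, "first", which treats level 0 as the event. A minimal sketch with a made-up four-row tibble (invented for illustration, not data from this issue) shows that the event level alone changes the result of roc_auc():

```r
library(tibble)
library(yardstick)

# Toy data invented for illustration: two observations per class,
# with .pred_1 holding the predicted probability of class "1".
toy <- tibble(
  y = factor(c("0", "0", "1", "1"), levels = c("0", "1")),
  .pred_1 = c(0.10, 0.40, 0.35, 0.80)
)

# Treat the second level ("1") as the event, matching the .pred_1 column.
auc_second <- roc_auc(toy, y, .pred_1, event_level = "second")$.estimate

# yardstick's default event level is "first" ("0"), so the same .pred_1
# scores are read as evidence for the other class.
auc_first <- roc_auc(toy, y, .pred_1)$.estimate

auc_second  # 0.75
auc_first   # 0.25
```

With no tied scores, swapping the event level while keeping the same probability column gives 1 minus the AUC, so the two values sum to 1.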
I think a simple fix is:

control_last_fit <- function(
  verbose = FALSE,
  pkgs = NULL,
  event_level = "first"
) {
  # Identity extractor: keeps the fitted workflow from the single resample
  extr <- function(x) x
  control_resamples(
    verbose = verbose,
    pkgs = pkgs,
    event_level = event_level,
    save_pred = TRUE,
    save_workflow = FALSE,
    extract = extr
  )
}

last_fit_workflow <- function(object, split, metrics, control = control_last_fit()) {
  # Wrap the initial split as a one-element rset so it can be "resampled"
  splits <- list(split)
  resamples <- rsample::manual_rset(splits, ids = "train/test split")
  rng <- FALSE
  res <- resample_workflow(
    workflow = object,
    resamples = resamples,
    metrics = metrics,
    control = control,
    rng = rng
  )
  # Pull the fitted workflow out of the extracts and drop the extracts column
  res$.workflow <- res$.extracts[[1]][[1]]
  res$.extracts <- NULL
  class(res) <- c("last_fit", class(res))
  class(res) <- unique(class(res))
  res
}

Then it would just be a matter of modifying the argument signature for
The other tune functions have control arguments/functions. That's where we set options like event_level. See this SO post.
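Because last_fit() as discussed here has no control argument for event_level, one possible workaround is to recompute the test-set metrics from the predictions that last_fit() stores. The sketch below uses simulated data and the default glm engine (both stand-ins, not the glmnet setup from the reprex) and assumes collect_predictions() returns a .pred_1 column for the second factor level:

```r
library(tidymodels)

set.seed(1)
# Simulated data invented for illustration; the event of interest, "1",
# is the second factor level.
dat <- tibble(
  y = factor(sample(c("0", "1"), 500, replace = TRUE, prob = c(0.8, 0.2)),
             levels = c("0", "1")),
  x1 = rnorm(500),
  x2 = rnorm(500)
)
dat_split <- initial_split(dat, strata = y)

wflow <- workflow(y ~ ., logistic_reg())

# last_fit() computes its metrics with the default event level ("first").
final_fit <- last_fit(wflow, split = dat_split,
                      metrics = metric_set(roc_auc, pr_auc))

# The test-set predictions are saved, so the metrics can be recomputed
# with the intended event level without refitting the model.
test_preds <- collect_predictions(final_fit)

event_metrics <- metric_set(roc_auc, pr_auc)
fixed <- event_metrics(test_preds, truth = y, .pred_1, event_level = "second")
fixed
```

These recomputed values should correspond to what passing event_level = "second" through a control object would give for the same metrics.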