control_last_fit() #399

Closed
topepo opened this issue Aug 3, 2021 · 4 comments · Fixed by #481
Labels: feature (a feature request or enhancement)

Comments

@topepo
Member

topepo commented Aug 3, 2021

The other tune functions have control arguments and control functions (e.g., control_grid() for tune_grid()). That's where we set options like event_level, so last_fit() should have one too.

See this SO post.

@DavisVaughan
Member

We make the control object inside last_fit(), with a few specific arguments set:

From tune/R/last_fit.R, lines 104 to 106 in a1d276d:

```r
extr <- function(x)
  x
control <- control_resamples(save_pred = TRUE, extract = extr)
```

Maybe control_last_fit() just wouldn't have those arguments (save_pred and extract).

juliasilge added the "feature" (a feature request or enhancement) label on Mar 3, 2022.
@mattwarkentin
Contributor

@topepo @juliasilge Would you be open to a PR for this? I think the inability to set event_level for last_fit() is quite problematic. Here is an example, which I think represents a fairly standard modeling approach, where last_fit() confusingly computes ROC AUC in the correct direction but PR AUC in the opposite direction (i.e., with the wrong event level). This could be quite misleading if the user doesn't catch it.

last_fit() reports a PR AUC of 0.9, whereas the actual PR AUC (with event_level set properly) is 0.6.

```r
library(tidyverse)
library(tidymodels)
set.seed(12345)

data <-
  tibble(
    y = factor(c(rep(0, 8000), rep(1, 2000)), levels = 0:1),
    x1 = rnorm(10000),
    x2 = rnorm(10000)
  )

tt_split <- initial_split(data, strata = y)
train <- training(tt_split)
test <- testing(tt_split)

wflow <-
  workflow(
    preprocessor = y ~ .,
    spec = logistic_reg(engine = 'glmnet', penalty = tune(), mixture = 1)
  )

res <-
  tune_grid(
    object = wflow,
    resamples = vfold_cv(train, strata = y),
    grid = 10,
    metrics = metric_set(roc_auc, pr_auc),
    control = control_grid(event_level = 'second')
  )

show_best(res, 'roc_auc', n = 1)
#> # A tibble: 1 × 7
#>   penalty .metric .estimator  mean     n std_err .config              
#>     <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>                
#> 1  0.0224 roc_auc binary       0.5    10       0 Preprocessor1_Model09

show_best(res, 'pr_auc', n = 1)
#> # A tibble: 1 × 7
#>   penalty .metric .estimator  mean     n std_err .config              
#>     <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>                
#> 1  0.0224 pr_auc  binary       0.6    10       0 Preprocessor1_Model09

final_fit <- 
  last_fit(
    finalize_workflow(wflow, select_best(res, metric = 'roc_auc')),
    split = tt_split,
    metrics = metric_set(roc_auc, pr_auc)
  )

final_model <- extract_workflow(final_fit)

collect_metrics(final_fit)
#> # A tibble: 2 × 4
#>   .metric .estimator .estimate .config             
#>   <chr>   <chr>          <dbl> <chr>               
#> 1 roc_auc binary           0.5 Preprocessor1_Model1
#> 2 pr_auc  binary           0.9 Preprocessor1_Model1

preds <- bind_cols(
  test, predict(final_model, new_data = test, type = 'prob')
)

pr_auc(preds, y, .pred_1, event_level = 'second')
#> # A tibble: 1 × 3
#>   .metric .estimator .estimate
#>   <chr>   <chr>          <dbl>
#> 1 pr_auc  binary           0.6
roc_auc(preds, y, .pred_1, event_level = 'second')
#> # A tibble: 1 × 3
#>   .metric .estimator .estimate
#>   <chr>   <chr>          <dbl>
#> 1 roc_auc binary           0.5
```
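The 0.9 versus 0.6 gap can be reproduced in base R. The ROC AUC of exactly 0.5 with a standard error of 0 suggests that the selected lasso penalty shrinks every coefficient to zero, so the model predicts the same probability for every row. Assuming the PR curve is anchored at (recall 0, precision 1) and integrated with the trapezoid rule, which is my understanding of yardstick's convention, a constant score gives a single threshold and an area of (1 + prevalence) / 2. `pr_auc_trapezoid()` is a hypothetical, simplified helper for illustration, not yardstick's implementation.

```r
# Hypothetical helper: trapezoidal area under the precision-recall curve.
# Assumes the curve starts at (recall = 0, precision = 1); tied scores are
# treated as a single threshold.
pr_auc_trapezoid <- function(truth, score, event) {
  is_event <- truth == event
  thresholds <- sort(unique(score), decreasing = TRUE)
  recall <- 0
  precision <- 1
  for (t in thresholds) {
    predicted <- score >= t
    tp <- sum(predicted & is_event)
    recall <- c(recall, tp / sum(is_event))
    precision <- c(precision, tp / sum(predicted))
  }
  # Trapezoid rule over successive recall values
  sum(diff(recall) * (head(precision, -1) + tail(precision, -1)) / 2)
}

# Same class balance as the reprex: 80% "0", 20% "1"
truth <- factor(c(rep(0, 8000), rep(1, 2000)), levels = 0:1)
pred_1 <- rep(0.2, 10000)  # constant P(y = 1), as from an all-zero lasso
pred_0 <- 1 - pred_1       # constant P(y = 0)

pr_auc_trapezoid(truth, pred_1, event = "1")  # 0.6, i.e. (1 + 0.2) / 2
pr_auc_trapezoid(truth, pred_0, event = "0")  # 0.9, i.e. (1 + 0.8) / 2
```

For an uninformative model, PR AUC under this convention mostly reflects the prevalence of whichever class is treated as the event, so using the wrong event level for the minority class inflates it.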

@mattwarkentin
Contributor

I think a simple fix is:

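```r
control_last_fit <- function(
    verbose = FALSE,
    pkgs = NULL,
    event_level = "first"
) {
  # last_fit() needs the test-set predictions and the fitted workflow,
  # so save_pred and extract are fixed here instead of being user-facing
  extr <- function(x) x
  control_resamples(
    verbose = verbose,
    pkgs = pkgs,
    event_level = event_level,
    save_pred = TRUE,
    save_workflow = FALSE,
    extract = extr
  )
}

# `control` becomes an argument, defaulting to control_last_fit()
last_fit_workflow <- function(object, split, metrics, control = control_last_fit()) {
  # Treat the single train/test split as a one-element resampling object
  splits <- list(split)
  resamples <- rsample::manual_rset(splits, ids = "train/test split")

  rng <- FALSE

  res <- resample_workflow(
    workflow = object,
    resamples = resamples,
    metrics = metrics,
    control = control,
    rng = rng
  )

  # Move the extracted (training-set) workflow into its own column
  res$.workflow <- res$.extracts[[1]][[1]]
  res$.extracts <- NULL
  class(res) <- c("last_fit", class(res))
  class(res) <- unique(class(res))
  res
}
```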

Then it would just be a matter of modifying the argument signatures of last_fit.model_spec() and last_fit.workflow() to accept a control argument and pass it forward.
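Here is a minimal base-R sketch of that S3 plumbing. Every name here (`last_fit_sketch()`, the `workflow_sketch` and `model_spec_sketch` classes, and `control_last_fit_sketch()`) is hypothetical and stands in for tune's real generic, classes, and control function. Fitting is omitted: the workflow method just returns what it received so the forwarding of `control` can be checked.

```r
# Hypothetical stand-in for control_last_fit(): user-facing options plus
# the fixed options that last_fit() relies on
control_last_fit_sketch <- function(verbose = FALSE, event_level = "first") {
  list(verbose = verbose, event_level = event_level,
       save_pred = TRUE, extract = function(x) x)
}

# Hypothetical generic standing in for last_fit()
last_fit_sketch <- function(object, ...) {
  UseMethod("last_fit_sketch")
}

# Workflow method: the place where `control` is finally consumed
last_fit_sketch.workflow_sketch <- function(object, split,
                                            control = control_last_fit_sketch(),
                                            ...) {
  list(workflow = object, split = split, control = control)
}

# Model-spec method: wraps the spec in a workflow and forwards `control`
last_fit_sketch.model_spec_sketch <- function(object, preprocessor, split,
                                              control = control_last_fit_sketch(),
                                              ...) {
  wflow <- structure(list(spec = object, preprocessor = preprocessor),
                     class = "workflow_sketch")
  last_fit_sketch(wflow, split = split, control = control, ...)
}

spec <- structure(list(engine = "glmnet"), class = "model_spec_sketch")
res <- last_fit_sketch(spec, preprocessor = y ~ ., split = "train/test split",
                       control = control_last_fit_sketch(event_level = "second"))
res$control$event_level  # "second"
```

Giving both methods the same `control = control_last_fit_sketch()` default means existing calls that never pass `control` keep the default event level.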

@github-actions

github-actions bot commented May 4, 2022

This issue has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.

github-actions bot locked and limited conversation to collaborators on May 4, 2022.

4 participants