
better detection of groups/tuning parameters #92

Open
topepo opened this issue Mar 21, 2023 · 2 comments
Labels
bug an unexpected problem or unintended behavior

topepo commented Mar 21, 2023

When the input is a data frame of predictions from the tune_*() functions, we silently pool all tuning configurations into a single plot/analysis if the user doesn't explicitly ask for grouping.

We should detect this (when there is more than one value of .config) and produce a meaningful error.
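The check described above could look roughly like the following base-R sketch. The helper name check_config_groups() and its group argument are hypothetical and not part of probably; the toy data frame only mimics the .config and .pred_class_1 columns that collect_predictions() returns.

```r
# Hypothetical helper, not part of probably: stop with an informative error
# when tuning predictions hold several `.config` values but no grouping was given.
check_config_groups <- function(data, group = NULL) {
  if (is.null(group) && ".config" %in% names(data)) {
    n_configs <- length(unique(data[[".config"]]))
    if (n_configs > 1) {
      stop(
        "The data have ", n_configs, " values of '.config' but no grouping ",
        "was requested; this would inappropriately pool the data.",
        call. = FALSE
      )
    }
  }
  invisible(data)
}

# Toy predictions standing in for collect_predictions() output (made-up values)
preds <- data.frame(
  .config = rep(c("Preprocessor1_Model1", "Preprocessor1_Model2"), each = 3),
  .pred_class_1 = c(0.9, 0.2, 0.7, 0.8, 0.3, 0.6)
)

# Without grouping, the check fails; with grouping, the data pass through.
msg <- tryCatch(check_config_groups(preds), error = function(e) conditionMessage(e))
print(msg)
ok <- check_config_groups(preds, group = ".config")
```

Erroring (rather than guessing) keeps the user in control of how the configurations are split, at the cost of an extra argument in the common tuning case.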

Also, the plot functions take a group argument, while the estimation functions require group_by(). Having two different grouping mechanisms is confusing.

Example:

library(tidymodels)
library(probably)
#> 
#> Attaching package: 'probably'
#> The following objects are masked from 'package:base':
#> 
#>     as.factor, as.ordered
library(bonsai)
tidymodels_prefer()
theme_set(theme_bw())
options(pillar.advice = FALSE, pillar.min_title_chars = Inf)
set.seed(1345)
cls_train <- sim_classification(1000)
cls_test  <- sim_classification( 500)
cls_calib <- sim_classification( 500)

set.seed(7378)
cls_rs <- vfold_cv(cls_train)
lgb_spec <- boost_tree() %>% set_mode("classification") %>% set_engine("lightgbm")
cls_metrics <- metric_set(brier_class, roc_auc)

set.seed(6929)
lgb_tune_res <-
  boost_tree(min_n = tune()) %>%
  set_mode("classification") %>%
  set_engine("lightgbm") %>%
  tune_grid(
    class ~ .,
    resamples = cls_rs,
    control = control_resamples(save_pred = TRUE),
    metrics = cls_metrics,
    grid = tibble(min_n = c(2, 50))
  )
df_pred_tune_res <- lgb_tune_res %>% collect_predictions()
df_tune_new <- df_pred_tune_res %>% dplyr::slice(1:5, .by = .config)
# Plotting issues

# This produces 1 plot; should be two
df_pred_tune_res %>%
  cal_plot_windowed(truth = class, estimate = .pred_class_1,
                    window_size = 0.1, step_size = 0.025)

# Using `group` makes two plots
df_pred_tune_res %>%
  cal_plot_windowed(truth = class, estimate = .pred_class_1, group = .config,
                    window_size = 0.1, step_size = 0.025)

# Estimation issues

# Should have two groups
df_pred_tune_res %>%
  cal_estimate_logistic(truth = class)
#> 
#> ── Probability Calibration
#> Method: Logistic Spline
#> Type: Binary
#> Source class: Data Frame
#> Data points: 2,000
#> Truth variable: `class`
#> Estimate variables:
#> `.pred_class_1` ==> class_1
#> `.pred_class_2` ==> class_2

# Has two groups via a different "by" mechanism:
df_pred_tune_res %>%
  group_by(.config) %>% 
  cal_estimate_logistic(truth = class)
#> 
#> ── Probability Calibration
#> Method: Logistic Spline
#> Type: Binary
#> Source class: Data Frame
#> Data points: 2,000, split in 2 groups
#> Truth variable: `class`
#> Estimate variables:
#> `.pred_class_1` ==> class_1
#> `.pred_class_2` ==> class_2

Created on 2023-03-21 by the reprex package (v2.0.1)

topepo added the bug label (an unexpected problem or unintended behavior) Mar 21, 2023
topepo added a commit that referenced this issue Mar 28, 2023
topepo added a commit that referenced this issue Apr 28, 2023
* plot changes for #92

* add groups to estimate function

* Revert "add groups to estimate function"

This reverts commit 558fa43.

* tuning the quosure code and update tests

* update with new pillar
topepo self-assigned this May 3, 2023

topepo commented May 3, 2023

Currently, if there are two or more .config values and no grouping is specified, we get this error:

Error: ! The data have several values of '.config' but no 'groups' argument was passed. This will inappropriately pool the data.

topepo commented May 3, 2023

Now, the plan for the estimate functions is to:

  • look for .config and, if it has more than one value, internally group_by(.config) and report a warning
  • add a .by argument so users can group on any column.
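A rough base-R sketch of that plan follows. The function estimate_by_group() and its fit_fn argument are hypothetical stand-ins for an estimate function and its model fit; nrow is used as a trivial "fit" so the grouping is easy to check. This is not how probably implements the change.

```r
# Hypothetical sketch of the planned grouping rule (not probably's actual code).
# `fit_fn` stands in for whatever calibration model is fit to each group.
estimate_by_group <- function(data, fit_fn, .by = NULL) {
  has_configs <- ".config" %in% names(data) &&
    length(unique(data[[".config"]])) > 1
  if (is.null(.by) && has_configs) {
    # Rule 1: several configurations and no explicit grouping -> group by .config
    warning("Multiple '.config' values found; grouping by '.config'.",
            call. = FALSE)
    .by <- ".config"
  }
  if (is.null(.by)) {
    return(list(fit_fn(data)))  # no grouping: one fit on all rows
  }
  # Rule 2: one fit per level of the grouping column (.config or user-chosen)
  lapply(split(data, data[[.by]], drop = TRUE), fit_fn)
}

# Toy predictions standing in for collect_predictions() output (made-up values)
preds <- data.frame(
  .config = rep(c("Preprocessor1_Model1", "Preprocessor1_Model2"), each = 3),
  class = rep(c("class_1", "class_2", "class_1"), 2),
  .pred_class_1 = c(0.9, 0.2, 0.7, 0.8, 0.3, 0.6)
)

by_config <- suppressWarnings(estimate_by_group(preds, nrow))  # 2 groups of 3 rows
by_class <- estimate_by_group(preds, nrow, .by = "class")      # explicit column
```

An explicit .by always wins, so the automatic .config grouping (and its warning) only kicks in when the user hasn't said how to split the data.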
