-
Notifications
You must be signed in to change notification settings - Fork 47
Closed
Labels
bug (an unexpected problem or unintended behavior), next release 🚀
Description
Underlying issue of tidymodels/yardstick#166
It seems that `estimate_class()` needs to respect `options(yardstick.event_first = FALSE)` when computing the metric. Currently it appears to always select the first `.pred_*` column, regardless of that option.
Lines 67 to 82 in fa4dc53
| prob_cols <- paste0(".pred_", lvl) | |
| if (length(prob_cols) == 2) { | |
| prob_cols <- prob_cols[1] | |
| } | |
| if (all(types == "prob")) { | |
| res <- | |
| dat %>% | |
| dplyr::group_by(!!!rlang::syms(params)) %>% | |
| metric(truth = !!sym(outcomes), !!!prob_cols) | |
| } else { | |
| res <- | |
| dat %>% | |
| dplyr::group_by(!!!rlang::syms(params)) %>% | |
| metric(truth = !!sym(outcomes), !!!prob_cols, estimate = .pred_class) | |
| } | |
| } |
# Reprex: with `yardstick.event_first = FALSE`, the pr_auc stored by
# tune::fit_resamples() differs from the pr_auc computed directly on the saved
# predictions with `.pred_pos`, and equals the value obtained with `.pred_neg`.
suppressPackageStartupMessages({
library(tidymodels)
library(MLmetrics)
library(mlbench)
library(dplyr)
})
# Flip option!
# Set before any fitting so it is in effect when the resampling metrics are
# computed.
options(yardstick.event_first = FALSE)
data("PimaIndiansDiabetes2")
set.seed(73487)
# Keep only rows with no missing values before fitting with ranger
# (complete.cases() filters out incomplete rows; it does not remove duplicates).
data <- PimaIndiansDiabetes2[complete.cases(PimaIndiansDiabetes2),]
# Recipe with no preprocessing steps: `diabetes` is the outcome and every other
# column is a predictor.
xgb_recipe <- recipe(diabetes ~ ., data = data)
# NOTE: despite the `xgb_` prefix, this is a random forest using the ranger
# engine, not an xgboost model.
xgb_model <- parsnip::rand_forest() %>%
set_engine("ranger") %>%
set_mode("classification")
xgb_workflow <- workflows::workflow() %>%
workflows::add_model(xgb_model) %>%
workflows::add_recipe(xgb_recipe)
# Two-fold cross-validation, stratified on the outcome.
cv_folds <- rsample::vfold_cv(data, strata = diabetes, v = 2)
# pr_auc is the only metric requested; save_pred = TRUE keeps the predictions
# so the metric can be recomputed by hand below.
xgb_fit <- tune::fit_resamples(
xgb_workflow,
resamples = cv_folds,
metrics = yardstick::metric_set(pr_auc),
control = tune::control_resamples(save_pred = TRUE)
)
# Saved predictions for the first resample.
predictions_resample1 <- xgb_fit$.predictions[[1]]
# The 2nd level is the event!
head(predictions_resample1$diabetes)
#> [1] neg pos pos pos pos pos
#> Levels: neg pos
# So this is how you compute PR AUC
pr_auc(predictions_resample1, diabetes, .pred_pos)
#> # A tibble: 1 x 3
#> .metric .estimator .estimate
#> <chr> <chr> <dbl>
#> 1 pr_auc binary 0.686
# But tune computed it by passing in `.pred_neg`!
# The metric stored by fit_resamples() for the same (first) resample:
xgb_fit$.metrics[[1]]
#> # A tibble: 1 x 3
#> .metric .estimator .estimate
#> <chr> <chr> <dbl>
#> 1 pr_auc binary 0.205
# i.e. it did this incorrectly
pr_auc(predictions_resample1, diabetes, .pred_neg)
#> # A tibble: 1 x 3
#> .metric .estimator .estimate
#> <chr> <chr> <dbl>
#> 1 pr_auc binary 0.205

Created on 2020-06-29 by the reprex package (v0.3.0)
Metadata
Metadata
Assignees
Labels
bug (an unexpected problem or unintended behavior), next release 🚀