Skip to content

tune:::estimate_class() doesn't respect yardstick.event_first = FALSE #240

@DavisVaughan

Description

@DavisVaughan

Underlying issue of tidymodels/yardstick#166

It seems that estimate_class() needs to somehow respect options(yardstick.event_first = FALSE) when computing the metric. For a two-class outcome it currently always selects the first .pred_* column, regardless of that option.

prob_cols <- paste0(".pred_", lvl)
if (length(prob_cols) == 2) {
prob_cols <- prob_cols[1]
}
if (all(types == "prob")) {
res <-
dat %>%
dplyr::group_by(!!!rlang::syms(params)) %>%
metric(truth = !!sym(outcomes), !!!prob_cols)
} else {
res <-
dat %>%
dplyr::group_by(!!!rlang::syms(params)) %>%
metric(truth = !!sym(outcomes), !!!prob_cols, estimate = .pred_class)
}
}

# Reprex: with options(yardstick.event_first = FALSE), the PR AUC stored by
# tune::fit_resamples() matches a manual call using `.pred_neg` rather than
# `.pred_pos`. Lines starting with `#>` are the captured output of the run.
suppressPackageStartupMessages({
  library(tidymodels)
  library(MLmetrics)
  library(mlbench)
  library(dplyr)
})

# Flip option: the second factor level is now treated as the event
options(yardstick.event_first = FALSE) 

data("PimaIndiansDiabetes2")

# Fix the RNG so the resampling split and model fit are reproducible
set.seed(73487)

# Keep only rows with no missing values (complete.cases() drops any row
# containing an NA) -- presumably needed because ranger cannot handle NAs
data <- PimaIndiansDiabetes2[complete.cases(PimaIndiansDiabetes2),]

# NOTE: the objects below are named `xgb_*`, but the model is a random forest
# fitted with the ranger engine, not xgboost
xgb_recipe <- recipe(diabetes ~ ., data = data)

xgb_model <- parsnip::rand_forest() %>%
  set_engine("ranger") %>%
  set_mode("classification")

xgb_workflow <- workflows::workflow() %>%
  workflows::add_model(xgb_model) %>%
  workflows::add_recipe(xgb_recipe)

# 2-fold cross-validation, stratified on the outcome
cv_folds <- rsample::vfold_cv(data, strata = diabetes, v = 2)

# save_pred = TRUE keeps the per-resample predictions so the metric can be
# recomputed by hand below and compared with what tune stored
xgb_fit <- tune::fit_resamples(
  xgb_workflow,
  resamples = cv_folds,
  metrics = yardstick::metric_set(pr_auc),
  control = tune::control_resamples(save_pred = TRUE)
)

# Predictions from the first resample only
predictions_resample1 <- xgb_fit$.predictions[[1]]

# The 2nd level is the event!
head(predictions_resample1$diabetes)
#> [1] neg pos pos pos pos pos
#> Levels: neg pos

# So this is how you compute PR AUC
pr_auc(predictions_resample1, diabetes, .pred_pos)
#> # A tibble: 1 x 3
#>   .metric .estimator .estimate
#>   <chr>   <chr>          <dbl>
#> 1 pr_auc  binary         0.686

# But tune computed it by passing in `.pred_neg`!
xgb_fit$.metrics[[1]]
#> # A tibble: 1 x 3
#>   .metric .estimator .estimate
#>   <chr>   <chr>          <dbl>
#> 1 pr_auc  binary         0.205

# i.e. it did this incorrectly
pr_auc(predictions_resample1, diabetes, .pred_neg)
#> # A tibble: 1 x 3
#>   .metric .estimator .estimate
#>   <chr>   <chr>          <dbl>
#> 1 pr_auc  binary         0.205

Created on 2020-06-29 by the reprex package (v0.3.0)

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug (an unexpected problem or unintended behavior), next release 🚀

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions