
XGBoost models the second factor level, while tidymodels treats the first as the event. Is this a problem? #249

@jcpsantiago


The problem

I'm concerned about the validity of results when training XGBoost through tidymodels.
tidymodels treats the first level of a factor as the event (positive) class by default, while XGBoost appears to model the second level (I couldn't find this in the docs and didn't check the source code). Setting options(yardstick.event_first = FALSE) is not a viable workaround because of #240 and tidymodels/yardstick#166: with it, select_best() may not pick the best set of parameters based on PR-AUC.
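A minimal sketch of the mismatch, assuming parsnip's xgboost wrapper encodes a two-level outcome as `as.integer(y) - 1` (0 for the first level, 1 for the second) before training with `binary:logistic`, which models the probability that the label is 1. Under that assumption the booster's raw output is the probability of the *second* level, while yardstick's default (`event_first = TRUE`) treats the *first* level as the event.

```r
# Outcome with "pos" releveled to come first, as in the reprex below
y <- factor(c("pos", "neg", "neg", "pos"), levels = c("pos", "neg"))

# Assumed encoding before handing labels to xgboost:
# first level -> 0, second level -> 1
label <- as.integer(y) - 1L

# binary:logistic models P(label == 1), i.e. the second level
modelled_level <- levels(y)[2]

# yardstick's default event is the first level
event_level <- levels(y)[1]

print(data.frame(y = y, label = label))
cat(sprintf("xgboost models P(%s); yardstick's default event is %s\n",
            modelled_level, event_level))
```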

The problem arises when I extract the model fit and run it through butcher::butcher() for deployment.
At that point, the underlying XGBoost model returns the probability of the negative class.
Yes, I can simply compute 1 - prediction to get the probability of the positive class, but I'm not sure that is reliable.
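A minimal sketch of what `1 - prediction` amounts to, assuming the butchered booster returns the probability of the second factor level (as the reprex output suggests). The helper `raw_to_class_probs()` is hypothetical; it only rebuilds columns named like the `.pred_<level>` columns tidymodels returns. For a two-class `binary:logistic` model there is a single predicted probability, so the other class's probability is its complement by definition.

```r
# Hypothetical helper: turn the booster's P(second level) into one column
# per level, named like tidymodels' `.pred_<level>` columns
raw_to_class_probs <- function(p_second, lvls) {
  stopifnot(length(lvls) == 2, all(p_second >= 0 & p_second <= 1))
  out <- data.frame(1 - p_second, p_second)
  names(out) <- paste0(".pred_", lvls)
  out
}

probs <- raw_to_class_probs(c(0.9, 0.25), lvls = c("pos", "neg"))
print(probs)
```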

Was XGBoost in fact trained to predict the negative class? If so, how reliable are the metrics reported by tidymodels, the tuning process, and the final predictions?

Also note that on my actual data, which is severely imbalanced, the XGBoost evaluation log reports a PR-AUC of 98 on the first boosting iteration, which also suggests it is targeting the wrong class.
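One way to see why a near-perfect PR-AUC on the first boosting round is suspicious: for an uninformative score, PR-AUC (computed here as average precision) is roughly the prevalence of whichever class is treated as the event. The sketch below uses simulated data with a ~5% minority class. `average_precision()` is a hypothetical helper, not xgboost's `aucpr` implementation, and the link to the log above assumes `aucpr` scores label 1 as the positive class.

```r
# Average precision: mean of precision@k taken at the rank of each event
average_precision <- function(score, is_event) {
  ord <- order(score, decreasing = TRUE)
  hits <- is_event[ord]
  precision_at_k <- cumsum(hits) / seq_along(hits)
  sum(precision_at_k[hits]) / sum(hits)
}

set.seed(1)
n <- 10000
is_rare <- runif(n) < 0.05   # ~5% minority class
noise_score <- runif(n)      # scores carry no information about the class

ap_rare     <- average_precision(noise_score, is_rare)
ap_majority <- average_precision(noise_score, !is_rare)

cat(sprintf("AP with minority class as event: %.3f\n", ap_rare))
cat(sprintf("AP with majority class as event: %.3f\n", ap_majority))
```

With pure noise, treating the majority class as the event already gives a PR-AUC near its prevalence, which is consistent with the suspicion that label 1 is the majority class here.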

Posting here because of #240, but this may not be a tune bug (or an issue at all, for that matter).

Reproducible example

suppressPackageStartupMessages({
  library(tidymodels)
  library(MLmetrics)
  library(mlbench)
  library(dplyr)
  library(forcats)
  library(butcher)
})

data("PimaIndiansDiabetes2")

set.seed(73487)

# keep complete cases only (drop rows with missing values)
data <- PimaIndiansDiabetes2[complete.cases(PimaIndiansDiabetes2),] %>% 
  # put "pos" first so it is the event level (works around #240)
  mutate(diabetes = forcats::fct_relevel(diabetes, c("pos", "neg")))

xgb_recipe <- recipe(diabetes ~ ., data = data)

xgb_model <- parsnip::boost_tree(
  trees = tune::tune(),
  tree_depth = tune::tune(), min_n = tune::tune(),
  loss_reduction = tune::tune(),
  sample_size = tune::tune(), mtry = tune::tune(),
  learn_rate = tune::tune()
) %>%
  parsnip::set_engine("xgboost", params = list(eval_metric = "aucpr")) %>%
  parsnip::set_mode("classification")

xgb_grid <- dials::grid_max_entropy(
  dials::trees(c(200, 1500)),
  dials::min_n(),
  dials::tree_depth(c(3, 10)),
  dials::loss_reduction(),
  sample_size = dials::sample_prop(c(0.5, 1)),
  dials::finalize(dials::mtry(), recipes::prep(xgb_recipe) %>% recipes::juice()),
  dials::learn_rate(c(0.01, 0.3), trans = NULL),
  size = 10
)

xgb_workflow <- workflows::workflow() %>%
  workflows::add_model(xgb_model) %>%
  workflows::add_recipe(xgb_recipe)

cv_folds <- rsample::vfold_cv(
  data,
  strata = diabetes,
  v = 3
)

xgb_tuned_fit <- tune::tune_grid(
  xgb_workflow,
  resamples = cv_folds,
  grid = xgb_grid,
  metrics = yardstick::metric_set(pr_auc),
  control = tune::control_grid(save_pred = TRUE, verbose = TRUE)
)
#> i Creating pre-processing data to finalize unknown parameter: mtry
#> i Fold1: recipe
#> ✓ Fold1: recipe
#> i Fold1: model  1/10
#> ✓ Fold1: model  1/10
#> i Fold1: model  1/10 (predictions)
#> i Fold1: model  2/10
#> ✓ Fold1: model  2/10
#> i Fold1: model  2/10 (predictions)
#> i Fold1: model  3/10
#> ✓ Fold1: model  3/10
#> i Fold1: model  3/10 (predictions)
#> i Fold1: model  4/10
#> ✓ Fold1: model  4/10
#> i Fold1: model  4/10 (predictions)
#> i Fold1: model  5/10
#> ✓ Fold1: model  5/10
#> i Fold1: model  5/10 (predictions)
#> i Fold1: model  6/10
#> ✓ Fold1: model  6/10
#> i Fold1: model  6/10 (predictions)
#> i Fold1: model  7/10
#> ✓ Fold1: model  7/10
#> i Fold1: model  7/10 (predictions)
#> i Fold1: model  8/10
#> ✓ Fold1: model  8/10
#> i Fold1: model  8/10 (predictions)
#> i Fold1: model  9/10
#> ✓ Fold1: model  9/10
#> i Fold1: model  9/10 (predictions)
#> i Fold1: model 10/10
#> ✓ Fold1: model 10/10
#> i Fold1: model 10/10 (predictions)
#> i Fold2: recipe
#> ✓ Fold2: recipe
#> i Fold2: model  1/10
#> ✓ Fold2: model  1/10
#> i Fold2: model  1/10 (predictions)
#> i Fold2: model  2/10
#> ✓ Fold2: model  2/10
#> i Fold2: model  2/10 (predictions)
#> i Fold2: model  3/10
#> ✓ Fold2: model  3/10
#> i Fold2: model  3/10 (predictions)
#> i Fold2: model  4/10
#> ✓ Fold2: model  4/10
#> i Fold2: model  4/10 (predictions)
#> i Fold2: model  5/10
#> ✓ Fold2: model  5/10
#> i Fold2: model  5/10 (predictions)
#> i Fold2: model  6/10
#> ✓ Fold2: model  6/10
#> i Fold2: model  6/10 (predictions)
#> i Fold2: model  7/10
#> ✓ Fold2: model  7/10
#> i Fold2: model  7/10 (predictions)
#> i Fold2: model  8/10
#> ✓ Fold2: model  8/10
#> i Fold2: model  8/10 (predictions)
#> i Fold2: model  9/10
#> ✓ Fold2: model  9/10
#> i Fold2: model  9/10 (predictions)
#> i Fold2: model 10/10
#> ✓ Fold2: model 10/10
#> i Fold2: model 10/10 (predictions)
#> i Fold3: recipe
#> ✓ Fold3: recipe
#> i Fold3: model  1/10
#> ✓ Fold3: model  1/10
#> i Fold3: model  1/10 (predictions)
#> i Fold3: model  2/10
#> ✓ Fold3: model  2/10
#> i Fold3: model  2/10 (predictions)
#> i Fold3: model  3/10
#> ✓ Fold3: model  3/10
#> i Fold3: model  3/10 (predictions)
#> i Fold3: model  4/10
#> ✓ Fold3: model  4/10
#> i Fold3: model  4/10 (predictions)
#> i Fold3: model  5/10
#> ✓ Fold3: model  5/10
#> i Fold3: model  5/10 (predictions)
#> i Fold3: model  6/10
#> ✓ Fold3: model  6/10
#> i Fold3: model  6/10 (predictions)
#> i Fold3: model  7/10
#> ✓ Fold3: model  7/10
#> i Fold3: model  7/10 (predictions)
#> i Fold3: model  8/10
#> ✓ Fold3: model  8/10
#> i Fold3: model  8/10 (predictions)
#> i Fold3: model  9/10
#> ✓ Fold3: model  9/10
#> i Fold3: model  9/10 (predictions)
#> i Fold3: model 10/10
#> ✓ Fold3: model 10/10
#> i Fold3: model 10/10 (predictions)

tune::collect_metrics(xgb_tuned_fit) %>% 
  select(mean) %>% 
  head(2)
#> # A tibble: 2 x 1
#>    mean
#>   <dbl>
#> 1 0.666
#> 2 0.665

best_model_params <- tune::select_best(xgb_tuned_fit, "pr_auc")

final_xgb <- tune::finalize_workflow(
  xgb_workflow,
  best_model_params
)

final_fit <- final_xgb %>%
  fit(data = data)

xgb_fit <- pull_workflow_fit(final_fit) %>% 
  butcher::butcher()

n <- xgb_fit$feature_names

predict(final_fit, data, type = "prob") %>% 
  head()
#> # A tibble: 6 x 2
#>   .pred_pos .pred_neg
#>       <dbl>     <dbl>
#> 1  0.000576   0.999  
#> 2  0.997      0.00278
#> 3  0.960      0.0405 
#> 4  0.994      0.00607
#> 5  0.994      0.00624
#> 6  0.989      0.0112

# the butchered xgboost fit returns the probability of the second level, "neg"
predict(xgb_fit, data.matrix(data %>% select(all_of(n)))) %>% 
  head()
#> [1] 0.999423981 0.002777288 0.040465232 0.006066449 0.006243575 0.011152640

Created on 2020-07-09 by the reprex package (v0.3.0)
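A quick base-R cross-check of the two outputs above, with values copied from the reprex (tibble printing rounds to about three significant digits, hence the relative tolerance). It suggests the raw booster output matches `.pred_neg` and its complement matches `.pred_pos` for these six rows; it does not by itself prove the behaviour in general. `close_to()` is a hypothetical helper.

```r
# Values copied from the reprex output above
raw_xgb  <- c(0.999423981, 0.002777288, 0.040465232,
              0.006066449, 0.006243575, 0.011152640)
pred_pos <- c(0.000576, 0.997, 0.960, 0.994, 0.994, 0.989)
pred_neg <- c(0.999, 0.00278, 0.0405, 0.00607, 0.00624, 0.0112)

# Element-wise relative difference, tolerant of the display rounding
close_to <- function(x, y, rel_tol = 0.01) all(abs(x - y) / abs(y) < rel_tol)

raw_matches_neg        <- close_to(raw_xgb, pred_neg)
complement_matches_pos <- close_to(1 - raw_xgb, pred_pos)
raw_matches_pos        <- close_to(raw_xgb, pred_pos)

cat("raw matches .pred_neg:    ", raw_matches_neg, "\n")
cat("1 - raw matches .pred_pos:", complement_matches_pos, "\n")
cat("raw matches .pred_pos:    ", raw_matches_pos, "\n")
```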

Session info
devtools::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value                       
#>  version  R version 4.0.0 (2020-04-24)
#>  os       macOS Catalina 10.15.5      
#>  system   x86_64, darwin17.0          
#>  ui       X11                         
#>  language (EN)                        
#>  collate  en_US.UTF-8                 
#>  ctype    en_US.UTF-8                 
#>  tz       Europe/Berlin               
#>  date     2020-07-09                  
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package       * version    date       lib source                             
#>  assertthat      0.2.1      2019-03-21 [1] CRAN (R 4.0.0)                     
#>  backports       1.1.7      2020-05-13 [1] CRAN (R 4.0.0)                     
#>  base64enc       0.1-3      2015-07-28 [1] CRAN (R 4.0.0)                     
#>  bayesplot       1.7.2      2020-05-28 [1] CRAN (R 4.0.0)                     
#>  boot            1.3-25     2020-04-26 [1] CRAN (R 4.0.0)                     
#>  broom         * 0.5.6      2020-04-20 [1] CRAN (R 4.0.0)                     
#>  butcher       * 0.1.2      2020-01-23 [1] CRAN (R 4.0.0)                     
#>  callr           3.4.3      2020-03-28 [1] CRAN (R 4.0.0)                     
#>  class           7.3-17     2020-04-26 [1] CRAN (R 4.0.0)                     
#>  cli             2.0.2      2020-02-28 [1] CRAN (R 4.0.0)                     
#>  codetools       0.2-16     2018-12-24 [1] CRAN (R 4.0.0)                     
#>  colorspace      1.4-1      2019-03-18 [1] CRAN (R 4.0.0)                     
#>  colourpicker    1.0        2017-09-27 [1] CRAN (R 4.0.0)                     
#>  crayon          1.3.4      2017-09-16 [1] CRAN (R 4.0.0)                     
#>  crosstalk       1.1.0.1    2020-03-13 [1] CRAN (R 4.0.0)                     
#>  data.table      1.12.8     2019-12-09 [1] CRAN (R 4.0.0)                     
#>  desc            1.2.0      2018-05-01 [1] CRAN (R 4.0.0)                     
#>  devtools        2.3.0      2020-04-10 [1] CRAN (R 4.0.0)                     
#>  dials         * 0.0.7      2020-06-10 [1] CRAN (R 4.0.0)                     
#>  DiceDesign      1.8-1      2019-07-31 [1] CRAN (R 4.0.0)                     
#>  digest          0.6.25     2020-02-23 [1] CRAN (R 4.0.0)                     
#>  dplyr         * 1.0.0      2020-05-29 [1] CRAN (R 4.0.0)                     
#>  DT              0.13       2020-03-23 [1] CRAN (R 4.0.0)                     
#>  dygraphs        1.1.1.6    2018-07-11 [1] CRAN (R 4.0.0)                     
#>  ellipsis        0.3.1      2020-05-15 [1] CRAN (R 4.0.0)                     
#>  evaluate        0.14       2019-05-28 [1] CRAN (R 4.0.0)                     
#>  fansi           0.4.1      2020-01-08 [1] CRAN (R 4.0.0)                     
#>  fastmap         1.0.1      2019-10-08 [1] CRAN (R 4.0.0)                     
#>  forcats       * 0.5.0      2020-03-01 [1] CRAN (R 4.0.0)                     
#>  foreach         1.5.0      2020-03-30 [1] CRAN (R 4.0.0)                     
#>  fs              1.4.1      2020-04-04 [1] CRAN (R 4.0.0)                     
#>  furrr           0.1.0      2018-05-16 [1] CRAN (R 4.0.0)                     
#>  future          1.17.0     2020-04-18 [1] CRAN (R 4.0.0)                     
#>  generics        0.0.2      2018-11-29 [1] CRAN (R 4.0.0)                     
#>  ggplot2       * 3.3.2      2020-06-19 [1] CRAN (R 4.0.0)                     
#>  ggridges        0.5.2      2020-01-12 [1] CRAN (R 4.0.0)                     
#>  globals         0.12.5     2019-12-07 [1] CRAN (R 4.0.0)                     
#>  glue            1.4.1      2020-05-13 [1] CRAN (R 4.0.0)                     
#>  gower           0.2.1      2019-05-14 [1] CRAN (R 4.0.0)                     
#>  GPfit           1.0-8      2019-02-08 [1] CRAN (R 4.0.0)                     
#>  gridExtra       2.3        2017-09-09 [1] CRAN (R 4.0.0)                     
#>  gtable          0.3.0      2019-03-25 [1] CRAN (R 4.0.0)                     
#>  gtools          3.8.2      2020-03-31 [1] CRAN (R 4.0.0)                     
#>  hardhat         0.1.3      2020-05-20 [1] CRAN (R 4.0.0)                     
#>  highr           0.8        2019-03-20 [1] CRAN (R 4.0.0)                     
#>  htmltools       0.4.0      2019-10-04 [1] CRAN (R 4.0.0)                     
#>  htmlwidgets     1.5.1      2019-10-08 [1] CRAN (R 4.0.0)                     
#>  httpuv          1.5.4      2020-06-06 [1] CRAN (R 4.0.0)                     
#>  igraph          1.2.5      2020-03-19 [1] CRAN (R 4.0.0)                     
#>  infer         * 0.5.2      2020-06-14 [1] CRAN (R 4.0.0)                     
#>  inline          0.3.15     2018-05-18 [1] CRAN (R 4.0.0)                     
#>  ipred           0.9-9      2019-04-28 [1] CRAN (R 4.0.0)                     
#>  iterators       1.0.12     2019-07-26 [1] CRAN (R 4.0.0)                     
#>  janeaustenr     0.1.5      2017-06-10 [1] CRAN (R 4.0.0)                     
#>  knitr           1.28       2020-02-06 [1] CRAN (R 4.0.0)                     
#>  later           1.1.0.1    2020-06-05 [1] CRAN (R 4.0.0)                     
#>  lattice         0.20-41    2020-04-02 [1] CRAN (R 4.0.0)                     
#>  lava            1.6.7      2020-03-05 [1] CRAN (R 4.0.0)                     
#>  lhs             1.0.2      2020-04-13 [1] CRAN (R 4.0.0)                     
#>  lifecycle       0.2.0      2020-03-06 [1] CRAN (R 4.0.0)                     
#>  listenv         0.8.0      2019-12-05 [1] CRAN (R 4.0.0)                     
#>  lme4            1.1-23     2020-04-07 [1] CRAN (R 4.0.0)                     
#>  loo             2.2.0      2019-12-19 [1] CRAN (R 4.0.0)                     
#>  lubridate       1.7.9      2020-06-08 [1] CRAN (R 4.0.0)                     
#>  magrittr        1.5        2014-11-22 [1] CRAN (R 4.0.0)                     
#>  markdown        1.1        2019-08-07 [1] CRAN (R 4.0.0)                     
#>  MASS            7.3-51.6   2020-04-26 [1] CRAN (R 4.0.0)                     
#>  Matrix          1.2-18     2019-11-27 [1] CRAN (R 4.0.0)                     
#>  matrixStats     0.56.0     2020-03-13 [1] CRAN (R 4.0.0)                     
#>  memoise         1.1.0      2017-04-21 [1] CRAN (R 4.0.0)                     
#>  mime            0.9        2020-02-04 [1] CRAN (R 4.0.0)                     
#>  miniUI          0.1.1.1    2018-05-18 [1] CRAN (R 4.0.0)                     
#>  minqa           1.2.4      2014-10-09 [1] CRAN (R 4.0.0)                     
#>  mlbench       * 2.1-1      2012-07-10 [1] CRAN (R 4.0.0)                     
#>  MLmetrics     * 1.1.1      2016-05-13 [1] CRAN (R 4.0.0)                     
#>  munsell         0.5.0      2018-06-12 [1] CRAN (R 4.0.0)                     
#>  nlme            3.1-148    2020-05-24 [1] CRAN (R 4.0.0)                     
#>  nloptr          1.2.2.1    2020-03-11 [1] CRAN (R 4.0.0)                     
#>  nnet            7.3-14     2020-04-26 [1] CRAN (R 4.0.0)                     
#>  parsnip       * 0.1.1.9000 2020-06-19 [1] Github (tidymodels/parsnip@3671e19)
#>  pillar          1.4.4      2020-05-05 [1] CRAN (R 4.0.0)                     
#>  pkgbuild        1.0.8      2020-05-07 [1] CRAN (R 4.0.0)                     
#>  pkgconfig       2.0.3      2019-09-22 [1] CRAN (R 4.0.0)                     
#>  pkgload         1.1.0      2020-05-29 [1] CRAN (R 4.0.0)                     
#>  plyr            1.8.6      2020-03-03 [1] CRAN (R 4.0.0)                     
#>  prettyunits     1.1.1      2020-01-24 [1] CRAN (R 4.0.0)                     
#>  pROC            1.16.2     2020-03-19 [1] CRAN (R 4.0.0)                     
#>  processx        3.4.2      2020-02-09 [1] CRAN (R 4.0.0)                     
#>  prodlim         2019.11.13 2019-11-17 [1] CRAN (R 4.0.0)                     
#>  promises        1.1.1      2020-06-09 [1] CRAN (R 4.0.0)                     
#>  ps              1.3.3      2020-05-08 [1] CRAN (R 4.0.0)                     
#>  purrr         * 0.3.4      2020-04-17 [1] CRAN (R 4.0.0)                     
#>  R6              2.4.1      2019-11-12 [1] CRAN (R 4.0.0)                     
#>  Rcpp            1.0.4.6    2020-04-09 [1] CRAN (R 4.0.0)                     
#>  RcppParallel    5.0.1      2020-05-06 [1] CRAN (R 4.0.0)                     
#>  recipes       * 0.1.12     2020-05-01 [1] CRAN (R 4.0.0)                     
#>  remotes         2.1.1      2020-02-15 [1] CRAN (R 4.0.0)                     
#>  reshape2        1.4.4      2020-04-09 [1] CRAN (R 4.0.0)                     
#>  rlang           0.4.6      2020-05-02 [1] CRAN (R 4.0.0)                     
#>  rmarkdown       2.3        2020-06-18 [1] CRAN (R 4.0.0)                     
#>  rpart           4.1-15     2019-04-12 [1] CRAN (R 4.0.0)                     
#>  rprojroot       1.3-2      2018-01-03 [1] CRAN (R 4.0.0)                     
#>  rsample       * 0.0.7      2020-06-04 [1] CRAN (R 4.0.0)                     
#>  rsconnect       0.8.16     2019-12-13 [1] CRAN (R 4.0.0)                     
#>  rstan           2.19.3     2020-02-11 [1] CRAN (R 4.0.0)                     
#>  rstanarm        2.19.3     2020-02-11 [1] CRAN (R 4.0.0)                     
#>  rstantools      2.0.0      2019-09-15 [1] CRAN (R 4.0.0)                     
#>  rstudioapi      0.11       2020-02-07 [1] CRAN (R 4.0.0)                     
#>  scales        * 1.1.1      2020-05-11 [1] CRAN (R 4.0.0)                     
#>  sessioninfo     1.1.1      2018-11-05 [1] CRAN (R 4.0.0)                     
#>  shiny           1.4.0.2    2020-03-13 [1] CRAN (R 4.0.0)                     
#>  shinyjs         1.1        2020-01-13 [1] CRAN (R 4.0.0)                     
#>  shinystan       2.5.0      2018-05-01 [1] CRAN (R 4.0.0)                     
#>  shinythemes     1.1.2      2018-11-06 [1] CRAN (R 4.0.0)                     
#>  SnowballC       0.7.0      2020-04-01 [1] CRAN (R 4.0.0)                     
#>  StanHeaders     2.21.0-5   2020-06-09 [1] CRAN (R 4.0.0)                     
#>  statmod         1.4.34     2020-02-17 [1] CRAN (R 4.0.0)                     
#>  stringi         1.4.6      2020-02-17 [1] CRAN (R 4.0.0)                     
#>  stringr         1.4.0      2019-02-10 [1] CRAN (R 4.0.0)                     
#>  survival        3.2-3      2020-06-13 [1] CRAN (R 4.0.0)                     
#>  testthat        2.3.2      2020-03-02 [1] CRAN (R 4.0.0)                     
#>  threejs         0.3.3      2020-01-21 [1] CRAN (R 4.0.0)                     
#>  tibble        * 3.0.1      2020-04-20 [1] CRAN (R 4.0.0)                     
#>  tidymodels    * 0.1.0      2020-02-16 [1] CRAN (R 4.0.0)                     
#>  tidyposterior   0.0.3      2020-06-11 [1] CRAN (R 4.0.0)                     
#>  tidypredict     0.4.5      2020-02-10 [1] CRAN (R 4.0.0)                     
#>  tidyr           1.1.0      2020-05-20 [1] CRAN (R 4.0.0)                     
#>  tidyselect      1.1.0      2020-05-11 [1] CRAN (R 4.0.0)                     
#>  tidytext        0.2.4      2020-04-17 [1] CRAN (R 4.0.0)                     
#>  timeDate        3043.102   2018-02-21 [1] CRAN (R 4.0.0)                     
#>  tokenizers      0.2.1      2018-03-29 [1] CRAN (R 4.0.0)                     
#>  tune          * 0.1.0      2020-04-02 [1] CRAN (R 4.0.0)                     
#>  usethis         1.6.1      2020-04-29 [1] CRAN (R 4.0.0)                     
#>  utf8            1.1.4      2018-05-24 [1] CRAN (R 4.0.0)                     
#>  vctrs           0.3.1      2020-06-05 [1] CRAN (R 4.0.0)                     
#>  withr           2.2.0      2020-04-20 [1] CRAN (R 4.0.0)                     
#>  workflows     * 0.1.1      2020-03-17 [1] CRAN (R 4.0.0)                     
#>  xfun            0.14       2020-05-20 [1] CRAN (R 4.0.0)                     
#>  xgboost         1.1.1.1    2020-06-14 [1] CRAN (R 4.0.0)                     
#>  xtable          1.8-4      2019-04-21 [1] CRAN (R 4.0.0)                     
#>  xts             0.12-0     2020-01-19 [1] CRAN (R 4.0.0)                     
#>  yaml            2.2.1      2020-02-01 [1] CRAN (R 4.0.0)                     
#>  yardstick     * 0.0.6      2020-03-17 [1] CRAN (R 4.0.0)                     
#>  zoo             1.8-8      2020-05-02 [1] CRAN (R 4.0.0)                     
#> 
#> [1] /Library/Frameworks/R.framework/Versions/4.0/Resources/library

Labels: bug (an unexpected problem or unintended behavior), next release 🚀