The problem
I'm concerned with the validity of results from training XGBoost using tidymodels.
tidymodels uses the first level of a factor as the default positive (event) class, while XGBoost defaults to the second (although I couldn't find this in the docs and didn't check the source code). Using options(yardstick.event_first = FALSE) is not a viable solution because of #240 and tidymodels/yardstick#166: with it, select_best() may not pick the best set of parameters based on the PR-AUC.
The problem arises when I apply butcher::butcher() to the model fit for deployment.
After that, the underlying XGBoost model predicts the probability of the negative class.
Yes, I can simply compute 1 - prediction to get the probability of the positive class, but I'm not sure whether that is reliable.
Was XGBoost in fact trained to predict the negative class? If so, how reliable are the metrics generated by tidymodels, the tuning process, and the final predictions?
Also note that, on my actual data with severe class imbalance, the evaluation log from XGBoost shows a PR-AUC of 0.98 on the first boosting iteration, which also tells me it's targeting the wrong class.
Posting here because of #240, but this may not be a tune bug (or an issue at all, for that matter).
Reproducible example
suppressPackageStartupMessages({
library(tidymodels)
library(MLmetrics)
library(mlbench)
library(dplyr)
library(forcats)
library(butcher)
})
data("PimaIndiansDiabetes2")
set.seed(73487)
# Keep complete cases only: drop every row that has an NA in any column.
# (complete.cases() removes rows with missing values; it does not remove
# duplicate rows.)
data <- PimaIndiansDiabetes2[complete.cases(PimaIndiansDiabetes2), ] %>%
  # Make "pos" the first factor level. yardstick treats the first level as
  # the event by default, so PR-AUC is computed for "pos" without having to
  # set options(yardstick.event_first = FALSE) (see tune issue #240).
  mutate(diabetes = forcats::fct_relevel(diabetes, c("pos", "neg")))
# Preprocessing: model `diabetes` from every other column, no extra steps.
xgb_recipe <- recipe(diabetes ~ ., data = data)

# Boosted-tree spec with every main hyperparameter marked for tuning.
# (No trailing comma after the last argument: R would pass an empty argument
# and silently match it, as missing, to the first formal not already
# matched by name.)
xgb_model <- parsnip::boost_tree(
  trees = tune::tune(),
  tree_depth = tune::tune(),
  min_n = tune::tune(),
  loss_reduction = tune::tune(),
  sample_size = tune::tune(),
  mtry = tune::tune(),
  learn_rate = tune::tune()
) %>%
  # Pass `eval_metric` as a bare engine argument: xgboost::xgb.train() merges
  # its `...` into `params`. Supplying a whole `params = list(...)` here can
  # replace the `params` list parsnip builds from the tuned values (eta,
  # max_depth, subsample, ...), silently untuning the model.
  # NOTE(review): confirm against the installed parsnip's xgb_train().
  #
  # xgboost scores "aucpr" with label 1 as the positive class. In this reprex
  # the raw booster predictions match `.pred_neg`, so label 1 is the *second*
  # level ("neg"). The eval-log aucpr therefore measures "neg", not "pos", and
  # is not comparable with yardstick's pr_auc. Presumably only the logged
  # metric is affected (no early stopping is configured) — verify if
  # early_stop is added later.
  parsnip::set_engine("xgboost", eval_metric = "aucpr") %>%
  parsnip::set_mode("classification")
# Space-filling (maximum-entropy) grid of 10 candidate hyperparameter sets.
xgb_grid <- dials::grid_max_entropy(
  dials::trees(c(200, 1500)),
  dials::min_n(),
  dials::tree_depth(c(3, 10)),
  dials::loss_reduction(),
  sample_size = dials::sample_prop(c(0.5, 1)),
  # finalize() sets mtry's upper bound from the number of columns it is given.
  # Pass the predictors only: a bare juice() also returns the outcome column,
  # which would let mtry exceed the number of predictors by one.
  dials::finalize(
    dials::mtry(),
    recipes::prep(xgb_recipe) %>% recipes::juice(all_predictors())
  ),
  # trans = NULL: the range c(0.01, 0.3) is on the natural scale rather than
  # the log10 scale learn_rate() uses by default.
  dials::learn_rate(c(0.01, 0.3), trans = NULL),
  size = 10
)
# Bundle the model spec and the recipe so tuning handles preprocessing and
# model fitting together.
xgb_workflow <- workflows::workflow() %>%
  workflows::add_model(xgb_model) %>%
  workflows::add_recipe(xgb_recipe)

# 3-fold cross-validation, stratified on the outcome so each fold has a
# similar pos/neg mix.
cv_folds <- rsample::vfold_cv(data = data, v = 3, strata = diabetes)
# Resampled grid search: fit every grid candidate on each fold and score the
# held-out rows with PR-AUC (yardstick's default event = first factor level).
# save_pred = TRUE keeps the out-of-fold predictions; verbose = TRUE logs
# progress per fold and model.
xgb_tuned_fit <- tune::tune_grid(
  xgb_workflow,
  grid = xgb_grid,
  resamples = cv_folds,
  metrics = yardstick::metric_set(pr_auc),
  control = tune::control_grid(verbose = TRUE, save_pred = TRUE)
)
#> i Creating pre-processing data to finalize unknown parameter: mtry
#> i Fold1: recipe
#> ✓ Fold1: recipe
#> i Fold1: model 1/10
#> ✓ Fold1: model 1/10
#> i Fold1: model 1/10 (predictions)
#> i Fold1: model 2/10
#> ✓ Fold1: model 2/10
#> i Fold1: model 2/10 (predictions)
#> i Fold1: model 3/10
#> ✓ Fold1: model 3/10
#> i Fold1: model 3/10 (predictions)
#> i Fold1: model 4/10
#> ✓ Fold1: model 4/10
#> i Fold1: model 4/10 (predictions)
#> i Fold1: model 5/10
#> ✓ Fold1: model 5/10
#> i Fold1: model 5/10 (predictions)
#> i Fold1: model 6/10
#> ✓ Fold1: model 6/10
#> i Fold1: model 6/10 (predictions)
#> i Fold1: model 7/10
#> ✓ Fold1: model 7/10
#> i Fold1: model 7/10 (predictions)
#> i Fold1: model 8/10
#> ✓ Fold1: model 8/10
#> i Fold1: model 8/10 (predictions)
#> i Fold1: model 9/10
#> ✓ Fold1: model 9/10
#> i Fold1: model 9/10 (predictions)
#> i Fold1: model 10/10
#> ✓ Fold1: model 10/10
#> i Fold1: model 10/10 (predictions)
#> i Fold2: recipe
#> ✓ Fold2: recipe
#> i Fold2: model 1/10
#> ✓ Fold2: model 1/10
#> i Fold2: model 1/10 (predictions)
#> i Fold2: model 2/10
#> ✓ Fold2: model 2/10
#> i Fold2: model 2/10 (predictions)
#> i Fold2: model 3/10
#> ✓ Fold2: model 3/10
#> i Fold2: model 3/10 (predictions)
#> i Fold2: model 4/10
#> ✓ Fold2: model 4/10
#> i Fold2: model 4/10 (predictions)
#> i Fold2: model 5/10
#> ✓ Fold2: model 5/10
#> i Fold2: model 5/10 (predictions)
#> i Fold2: model 6/10
#> ✓ Fold2: model 6/10
#> i Fold2: model 6/10 (predictions)
#> i Fold2: model 7/10
#> ✓ Fold2: model 7/10
#> i Fold2: model 7/10 (predictions)
#> i Fold2: model 8/10
#> ✓ Fold2: model 8/10
#> i Fold2: model 8/10 (predictions)
#> i Fold2: model 9/10
#> ✓ Fold2: model 9/10
#> i Fold2: model 9/10 (predictions)
#> i Fold2: model 10/10
#> ✓ Fold2: model 10/10
#> i Fold2: model 10/10 (predictions)
#> i Fold3: recipe
#> ✓ Fold3: recipe
#> i Fold3: model 1/10
#> ✓ Fold3: model 1/10
#> i Fold3: model 1/10 (predictions)
#> i Fold3: model 2/10
#> ✓ Fold3: model 2/10
#> i Fold3: model 2/10 (predictions)
#> i Fold3: model 3/10
#> ✓ Fold3: model 3/10
#> i Fold3: model 3/10 (predictions)
#> i Fold3: model 4/10
#> ✓ Fold3: model 4/10
#> i Fold3: model 4/10 (predictions)
#> i Fold3: model 5/10
#> ✓ Fold3: model 5/10
#> i Fold3: model 5/10 (predictions)
#> i Fold3: model 6/10
#> ✓ Fold3: model 6/10
#> i Fold3: model 6/10 (predictions)
#> i Fold3: model 7/10
#> ✓ Fold3: model 7/10
#> i Fold3: model 7/10 (predictions)
#> i Fold3: model 8/10
#> ✓ Fold3: model 8/10
#> i Fold3: model 8/10 (predictions)
#> i Fold3: model 9/10
#> ✓ Fold3: model 9/10
#> i Fold3: model 9/10 (predictions)
#> i Fold3: model 10/10
#> ✓ Fold3: model 10/10
#> i Fold3: model 10/10 (predictions)
# Peek at the first two resampled PR-AUC estimates (`mean` = the metric
# averaged over resamples for one grid candidate).
head(dplyr::select(tune::collect_metrics(xgb_tuned_fit), mean), n = 2)
#> # A tibble: 2 x 1
#> mean
#> <dbl>
#> 1 0.666
#> 2 0.665
# Take the grid candidate with the best mean PR-AUC, plug its values into the
# workflow's tune() placeholders, and refit on the full data set.
best_model_params <- tune::select_best(xgb_tuned_fit, metric = "pr_auc")
final_xgb <- tune::finalize_workflow(xgb_workflow, best_model_params)
final_fit <- fit(final_xgb, data = data)

# Slim the fitted model down for deployment.
# NOTE(review): predict() on `xgb_fit` below returns a bare numeric vector,
# so after butcher() it behaves like the raw xgboost booster — confirm this
# against the installed butcher/parsnip versions.
xgb_fit <- pull_workflow_fit(final_fit) %>%
  butcher::butcher()
# Feature names recorded by the booster; used to build a matching predictor
# matrix below.
n <- xgb_fit$feature_names

predict(final_fit, data, type = "prob") %>%
  head()
#> # A tibble: 6 x 2
#> .pred_pos .pred_neg
#> <dbl> <dbl>
#> 1 0.000576 0.999
#> 2 0.997 0.00278
#> 3 0.960 0.0405
#> 4 0.994 0.00607
#> 5 0.994 0.00624
#> 6 0.989 0.0112

# The raw booster output matches `.pred_neg` above, i.e. it is the
# probability of the *second* factor level ("neg"). 1 - prediction therefore
# reproduces `.pred_pos` (0.000576, 0.997, 0.960, 0.994, 0.994, 0.989).
# all_of() marks `n` as an external character vector. A bare `select(n)` is
# ambiguous and would silently pick a column literally named `n` if one
# existed. It keeps the columns in `n`'s order, matching the booster.
predict(xgb_fit, data.matrix(data %>% select(all_of(n)))) %>%
  head()
#> [1] 0.999423981 0.002777288 0.040465232 0.006066449 0.006243575 0.011152640
Created on 2020-07-09 by the reprex package (v0.3.0)
Session info
# Record the R version, platform, and package versions used for this reprex
# (printed below as `#>` lines).
devtools::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#> setting value
#> version R version 4.0.0 (2020-04-24)
#> os macOS Catalina 10.15.5
#> system x86_64, darwin17.0
#> ui X11
#> language (EN)
#> collate en_US.UTF-8
#> ctype en_US.UTF-8
#> tz Europe/Berlin
#> date 2020-07-09
#>
#> ─ Packages ───────────────────────────────────────────────────────────────────
#> package * version date lib source
#> assertthat 0.2.1 2019-03-21 [1] CRAN (R 4.0.0)
#> backports 1.1.7 2020-05-13 [1] CRAN (R 4.0.0)
#> base64enc 0.1-3 2015-07-28 [1] CRAN (R 4.0.0)
#> bayesplot 1.7.2 2020-05-28 [1] CRAN (R 4.0.0)
#> boot 1.3-25 2020-04-26 [1] CRAN (R 4.0.0)
#> broom * 0.5.6 2020-04-20 [1] CRAN (R 4.0.0)
#> butcher * 0.1.2 2020-01-23 [1] CRAN (R 4.0.0)
#> callr 3.4.3 2020-03-28 [1] CRAN (R 4.0.0)
#> class 7.3-17 2020-04-26 [1] CRAN (R 4.0.0)
#> cli 2.0.2 2020-02-28 [1] CRAN (R 4.0.0)
#> codetools 0.2-16 2018-12-24 [1] CRAN (R 4.0.0)
#> colorspace 1.4-1 2019-03-18 [1] CRAN (R 4.0.0)
#> colourpicker 1.0 2017-09-27 [1] CRAN (R 4.0.0)
#> crayon 1.3.4 2017-09-16 [1] CRAN (R 4.0.0)
#> crosstalk 1.1.0.1 2020-03-13 [1] CRAN (R 4.0.0)
#> data.table 1.12.8 2019-12-09 [1] CRAN (R 4.0.0)
#> desc 1.2.0 2018-05-01 [1] CRAN (R 4.0.0)
#> devtools 2.3.0 2020-04-10 [1] CRAN (R 4.0.0)
#> dials * 0.0.7 2020-06-10 [1] CRAN (R 4.0.0)
#> DiceDesign 1.8-1 2019-07-31 [1] CRAN (R 4.0.0)
#> digest 0.6.25 2020-02-23 [1] CRAN (R 4.0.0)
#> dplyr * 1.0.0 2020-05-29 [1] CRAN (R 4.0.0)
#> DT 0.13 2020-03-23 [1] CRAN (R 4.0.0)
#> dygraphs 1.1.1.6 2018-07-11 [1] CRAN (R 4.0.0)
#> ellipsis 0.3.1 2020-05-15 [1] CRAN (R 4.0.0)
#> evaluate 0.14 2019-05-28 [1] CRAN (R 4.0.0)
#> fansi 0.4.1 2020-01-08 [1] CRAN (R 4.0.0)
#> fastmap 1.0.1 2019-10-08 [1] CRAN (R 4.0.0)
#> forcats * 0.5.0 2020-03-01 [1] CRAN (R 4.0.0)
#> foreach 1.5.0 2020-03-30 [1] CRAN (R 4.0.0)
#> fs 1.4.1 2020-04-04 [1] CRAN (R 4.0.0)
#> furrr 0.1.0 2018-05-16 [1] CRAN (R 4.0.0)
#> future 1.17.0 2020-04-18 [1] CRAN (R 4.0.0)
#> generics 0.0.2 2018-11-29 [1] CRAN (R 4.0.0)
#> ggplot2 * 3.3.2 2020-06-19 [1] CRAN (R 4.0.0)
#> ggridges 0.5.2 2020-01-12 [1] CRAN (R 4.0.0)
#> globals 0.12.5 2019-12-07 [1] CRAN (R 4.0.0)
#> glue 1.4.1 2020-05-13 [1] CRAN (R 4.0.0)
#> gower 0.2.1 2019-05-14 [1] CRAN (R 4.0.0)
#> GPfit 1.0-8 2019-02-08 [1] CRAN (R 4.0.0)
#> gridExtra 2.3 2017-09-09 [1] CRAN (R 4.0.0)
#> gtable 0.3.0 2019-03-25 [1] CRAN (R 4.0.0)
#> gtools 3.8.2 2020-03-31 [1] CRAN (R 4.0.0)
#> hardhat 0.1.3 2020-05-20 [1] CRAN (R 4.0.0)
#> highr 0.8 2019-03-20 [1] CRAN (R 4.0.0)
#> htmltools 0.4.0 2019-10-04 [1] CRAN (R 4.0.0)
#> htmlwidgets 1.5.1 2019-10-08 [1] CRAN (R 4.0.0)
#> httpuv 1.5.4 2020-06-06 [1] CRAN (R 4.0.0)
#> igraph 1.2.5 2020-03-19 [1] CRAN (R 4.0.0)
#> infer * 0.5.2 2020-06-14 [1] CRAN (R 4.0.0)
#> inline 0.3.15 2018-05-18 [1] CRAN (R 4.0.0)
#> ipred 0.9-9 2019-04-28 [1] CRAN (R 4.0.0)
#> iterators 1.0.12 2019-07-26 [1] CRAN (R 4.0.0)
#> janeaustenr 0.1.5 2017-06-10 [1] CRAN (R 4.0.0)
#> knitr 1.28 2020-02-06 [1] CRAN (R 4.0.0)
#> later 1.1.0.1 2020-06-05 [1] CRAN (R 4.0.0)
#> lattice 0.20-41 2020-04-02 [1] CRAN (R 4.0.0)
#> lava 1.6.7 2020-03-05 [1] CRAN (R 4.0.0)
#> lhs 1.0.2 2020-04-13 [1] CRAN (R 4.0.0)
#> lifecycle 0.2.0 2020-03-06 [1] CRAN (R 4.0.0)
#> listenv 0.8.0 2019-12-05 [1] CRAN (R 4.0.0)
#> lme4 1.1-23 2020-04-07 [1] CRAN (R 4.0.0)
#> loo 2.2.0 2019-12-19 [1] CRAN (R 4.0.0)
#> lubridate 1.7.9 2020-06-08 [1] CRAN (R 4.0.0)
#> magrittr 1.5 2014-11-22 [1] CRAN (R 4.0.0)
#> markdown 1.1 2019-08-07 [1] CRAN (R 4.0.0)
#> MASS 7.3-51.6 2020-04-26 [1] CRAN (R 4.0.0)
#> Matrix 1.2-18 2019-11-27 [1] CRAN (R 4.0.0)
#> matrixStats 0.56.0 2020-03-13 [1] CRAN (R 4.0.0)
#> memoise 1.1.0 2017-04-21 [1] CRAN (R 4.0.0)
#> mime 0.9 2020-02-04 [1] CRAN (R 4.0.0)
#> miniUI 0.1.1.1 2018-05-18 [1] CRAN (R 4.0.0)
#> minqa 1.2.4 2014-10-09 [1] CRAN (R 4.0.0)
#> mlbench * 2.1-1 2012-07-10 [1] CRAN (R 4.0.0)
#> MLmetrics * 1.1.1 2016-05-13 [1] CRAN (R 4.0.0)
#> munsell 0.5.0 2018-06-12 [1] CRAN (R 4.0.0)
#> nlme 3.1-148 2020-05-24 [1] CRAN (R 4.0.0)
#> nloptr 1.2.2.1 2020-03-11 [1] CRAN (R 4.0.0)
#> nnet 7.3-14 2020-04-26 [1] CRAN (R 4.0.0)
#> parsnip * 0.1.1.9000 2020-06-19 [1] Github (tidymodels/parsnip@3671e19)
#> pillar 1.4.4 2020-05-05 [1] CRAN (R 4.0.0)
#> pkgbuild 1.0.8 2020-05-07 [1] CRAN (R 4.0.0)
#> pkgconfig 2.0.3 2019-09-22 [1] CRAN (R 4.0.0)
#> pkgload 1.1.0 2020-05-29 [1] CRAN (R 4.0.0)
#> plyr 1.8.6 2020-03-03 [1] CRAN (R 4.0.0)
#> prettyunits 1.1.1 2020-01-24 [1] CRAN (R 4.0.0)
#> pROC 1.16.2 2020-03-19 [1] CRAN (R 4.0.0)
#> processx 3.4.2 2020-02-09 [1] CRAN (R 4.0.0)
#> prodlim 2019.11.13 2019-11-17 [1] CRAN (R 4.0.0)
#> promises 1.1.1 2020-06-09 [1] CRAN (R 4.0.0)
#> ps 1.3.3 2020-05-08 [1] CRAN (R 4.0.0)
#> purrr * 0.3.4 2020-04-17 [1] CRAN (R 4.0.0)
#> R6 2.4.1 2019-11-12 [1] CRAN (R 4.0.0)
#> Rcpp 1.0.4.6 2020-04-09 [1] CRAN (R 4.0.0)
#> RcppParallel 5.0.1 2020-05-06 [1] CRAN (R 4.0.0)
#> recipes * 0.1.12 2020-05-01 [1] CRAN (R 4.0.0)
#> remotes 2.1.1 2020-02-15 [1] CRAN (R 4.0.0)
#> reshape2 1.4.4 2020-04-09 [1] CRAN (R 4.0.0)
#> rlang 0.4.6 2020-05-02 [1] CRAN (R 4.0.0)
#> rmarkdown 2.3 2020-06-18 [1] CRAN (R 4.0.0)
#> rpart 4.1-15 2019-04-12 [1] CRAN (R 4.0.0)
#> rprojroot 1.3-2 2018-01-03 [1] CRAN (R 4.0.0)
#> rsample * 0.0.7 2020-06-04 [1] CRAN (R 4.0.0)
#> rsconnect 0.8.16 2019-12-13 [1] CRAN (R 4.0.0)
#> rstan 2.19.3 2020-02-11 [1] CRAN (R 4.0.0)
#> rstanarm 2.19.3 2020-02-11 [1] CRAN (R 4.0.0)
#> rstantools 2.0.0 2019-09-15 [1] CRAN (R 4.0.0)
#> rstudioapi 0.11 2020-02-07 [1] CRAN (R 4.0.0)
#> scales * 1.1.1 2020-05-11 [1] CRAN (R 4.0.0)
#> sessioninfo 1.1.1 2018-11-05 [1] CRAN (R 4.0.0)
#> shiny 1.4.0.2 2020-03-13 [1] CRAN (R 4.0.0)
#> shinyjs 1.1 2020-01-13 [1] CRAN (R 4.0.0)
#> shinystan 2.5.0 2018-05-01 [1] CRAN (R 4.0.0)
#> shinythemes 1.1.2 2018-11-06 [1] CRAN (R 4.0.0)
#> SnowballC 0.7.0 2020-04-01 [1] CRAN (R 4.0.0)
#> StanHeaders 2.21.0-5 2020-06-09 [1] CRAN (R 4.0.0)
#> statmod 1.4.34 2020-02-17 [1] CRAN (R 4.0.0)
#> stringi 1.4.6 2020-02-17 [1] CRAN (R 4.0.0)
#> stringr 1.4.0 2019-02-10 [1] CRAN (R 4.0.0)
#> survival 3.2-3 2020-06-13 [1] CRAN (R 4.0.0)
#> testthat 2.3.2 2020-03-02 [1] CRAN (R 4.0.0)
#> threejs 0.3.3 2020-01-21 [1] CRAN (R 4.0.0)
#> tibble * 3.0.1 2020-04-20 [1] CRAN (R 4.0.0)
#> tidymodels * 0.1.0 2020-02-16 [1] CRAN (R 4.0.0)
#> tidyposterior 0.0.3 2020-06-11 [1] CRAN (R 4.0.0)
#> tidypredict 0.4.5 2020-02-10 [1] CRAN (R 4.0.0)
#> tidyr 1.1.0 2020-05-20 [1] CRAN (R 4.0.0)
#> tidyselect 1.1.0 2020-05-11 [1] CRAN (R 4.0.0)
#> tidytext 0.2.4 2020-04-17 [1] CRAN (R 4.0.0)
#> timeDate 3043.102 2018-02-21 [1] CRAN (R 4.0.0)
#> tokenizers 0.2.1 2018-03-29 [1] CRAN (R 4.0.0)
#> tune * 0.1.0 2020-04-02 [1] CRAN (R 4.0.0)
#> usethis 1.6.1 2020-04-29 [1] CRAN (R 4.0.0)
#> utf8 1.1.4 2018-05-24 [1] CRAN (R 4.0.0)
#> vctrs 0.3.1 2020-06-05 [1] CRAN (R 4.0.0)
#> withr 2.2.0 2020-04-20 [1] CRAN (R 4.0.0)
#> workflows * 0.1.1 2020-03-17 [1] CRAN (R 4.0.0)
#> xfun 0.14 2020-05-20 [1] CRAN (R 4.0.0)
#> xgboost 1.1.1.1 2020-06-14 [1] CRAN (R 4.0.0)
#> xtable 1.8-4 2019-04-21 [1] CRAN (R 4.0.0)
#> xts 0.12-0 2020-01-19 [1] CRAN (R 4.0.0)
#> yaml 2.2.1 2020-02-01 [1] CRAN (R 4.0.0)
#> yardstick * 0.0.6 2020-03-17 [1] CRAN (R 4.0.0)
#> zoo 1.8-8 2020-05-02 [1] CRAN (R 4.0.0)
#>
#> [1] /Library/Frameworks/R.framework/Versions/4.0/Resources/library
The problem
I'm concerned with the validity of results from training XGBoost using tidymodels.
tidymodels uses the first level of a factor as the default positive (event) class, while XGBoost defaults to the second (although I couldn't find this in the docs and didn't check the source code). Using `options(yardstick.event_first = FALSE)` is not a viable solution because of #240 and tidymodels/yardstick#166: with it, `select_best()` may not pick the best set of parameters based on the PR-AUC.
The problem arises when I apply `butcher::butcher()` to the model fit for deployment.
After that, the underlying XGBoost model predicts the probability of the negative class.
Yes, I can simply compute `1 - prediction` to get the probability of the positive class, but I'm not sure whether that is reliable.
Was XGBoost in fact trained to predict the negative class? If so, how reliable are the metrics generated by tidymodels, the tuning process, and the final predictions?
Also note that, on my actual data with severe class imbalance, the evaluation log from XGBoost shows a PR-AUC of 0.98 on the first boosting iteration, which also tells me it's targeting the wrong class.
Posting here because of #240, but this may not be a `tune` bug (or an issue at all, for that matter).
Reproducible example
# Reprex from the issue, restored to one statement per line. The scraped copy
# was collapsed onto three lines, so everything after the first `#` was a
# comment and the script could not run. The code is kept as originally run so
# that the recorded `#>` output still matches it; review notes flag issues.
suppressPackageStartupMessages({
  library(tidymodels)
  library(MLmetrics)
  library(mlbench)
  library(dplyr)
  library(forcats)
  library(butcher)
})
data("PimaIndiansDiabetes2")
set.seed(73487)
# keep complete cases only (drops rows with any NA; does not remove duplicates)
data <- PimaIndiansDiabetes2[complete.cases(PimaIndiansDiabetes2),] %>%
  # flip levels to avoid issue 240
  mutate(diabetes = forcats::fct_relevel(diabetes, c("pos", "neg")))
xgb_recipe <- recipe(diabetes ~ ., data = data)
# NOTE(review): the trailing comma after `learn_rate` passes an empty
# argument. Also, `params = list(eval_metric = "aucpr")` may replace the
# `params` list parsnip builds from the tuned values; writing
# `set_engine("xgboost", eval_metric = "aucpr")` avoids that. Confirm against
# parsnip's xgb_train().
xgb_model <- parsnip::boost_tree(
  trees = tune::tune(),
  tree_depth = tune::tune(), min_n = tune::tune(),
  loss_reduction = tune::tune(),
  sample_size = tune::tune(), mtry = tune::tune(),
  learn_rate = tune::tune(),
) %>%
  parsnip::set_engine("xgboost", params = list(eval_metric = "aucpr")) %>%
  parsnip::set_mode(., mode = "classification")
# NOTE(review): juice() also returns the outcome column, so finalize() sets
# mtry's upper bound one above the number of predictors;
# `recipes::juice(all_predictors())` restricts it to the predictors.
xgb_grid <- dials::grid_max_entropy(
  dials::trees(c(200, 1500)),
  dials::min_n(),
  dials::tree_depth(c(3, 10)),
  dials::loss_reduction(),
  sample_size = dials::sample_prop(c(0.5, 1)),
  dials::finalize(dials::mtry(), recipes::prep(xgb_recipe) %>% recipes::juice()),
  dials::learn_rate(c(0.01, 0.3), trans = NULL),
  size = 10
)
xgb_workflow <- workflows::workflow() %>%
  workflows::add_model(xgb_model) %>%
  workflows::add_recipe(xgb_recipe)
cv_folds <- rsample::vfold_cv(
  data,
  strata = diabetes,
  v = 3
)
xgb_tuned_fit <- tune::tune_grid(
  xgb_workflow,
  resamples = cv_folds,
  grid = xgb_grid,
  metrics = yardstick::metric_set(pr_auc),
  control = tune::control_grid(save_pred = TRUE, verbose = TRUE)
)
#> i Creating pre-processing data to finalize unknown parameter: mtry
#> i Fold1: recipe
#> ✓ Fold1: recipe
#> i Fold1: model 1/10
#> ✓ Fold1: model 1/10
#> i Fold1: model 1/10 (predictions)
#> i Fold1: model 2/10
#> ✓ Fold1: model 2/10
#> i Fold1: model 2/10 (predictions)
#> i Fold1: model 3/10
#> ✓ Fold1: model 3/10
#> i Fold1: model 3/10 (predictions)
#> i Fold1: model 4/10
#> ✓ Fold1: model 4/10
#> i Fold1: model 4/10 (predictions)
#> i Fold1: model 5/10
#> ✓ Fold1: model 5/10
#> i Fold1: model 5/10 (predictions)
#> i Fold1: model 6/10
#> ✓ Fold1: model 6/10
#> i Fold1: model 6/10 (predictions)
#> i Fold1: model 7/10
#> ✓ Fold1: model 7/10
#> i Fold1: model 7/10 (predictions)
#> i Fold1: model 8/10
#> ✓ Fold1: model 8/10
#> i Fold1: model 8/10 (predictions)
#> i Fold1: model 9/10
#> ✓ Fold1: model 9/10
#> i Fold1: model 9/10 (predictions)
#> i Fold1: model 10/10
#> ✓ Fold1: model 10/10
#> i Fold1: model 10/10 (predictions)
#> i Fold2: recipe
#> ✓ Fold2: recipe
#> i Fold2: model 1/10
#> ✓ Fold2: model 1/10
#> i Fold2: model 1/10 (predictions)
#> i Fold2: model 2/10
#> ✓ Fold2: model 2/10
#> i Fold2: model 2/10 (predictions)
#> i Fold2: model 3/10
#> ✓ Fold2: model 3/10
#> i Fold2: model 3/10 (predictions)
#> i Fold2: model 4/10
#> ✓ Fold2: model 4/10
#> i Fold2: model 4/10 (predictions)
#> i Fold2: model 5/10
#> ✓ Fold2: model 5/10
#> i Fold2: model 5/10 (predictions)
#> i Fold2: model 6/10
#> ✓ Fold2: model 6/10
#> i Fold2: model 6/10 (predictions)
#> i Fold2: model 7/10
#> ✓ Fold2: model 7/10
#> i Fold2: model 7/10 (predictions)
#> i Fold2: model 8/10
#> ✓ Fold2: model 8/10
#> i Fold2: model 8/10 (predictions)
#> i Fold2: model 9/10
#> ✓ Fold2: model 9/10
#> i Fold2: model 9/10 (predictions)
#> i Fold2: model 10/10
#> ✓ Fold2: model 10/10
#> i Fold2: model 10/10 (predictions)
#> i Fold3: recipe
#> ✓ Fold3: recipe
#> i Fold3: model 1/10
#> ✓ Fold3: model 1/10
#> i Fold3: model 1/10 (predictions)
#> i Fold3: model 2/10
#> ✓ Fold3: model 2/10
#> i Fold3: model 2/10 (predictions)
#> i Fold3: model 3/10
#> ✓ Fold3: model 3/10
#> i Fold3: model 3/10 (predictions)
#> i Fold3: model 4/10
#> ✓ Fold3: model 4/10
#> i Fold3: model 4/10 (predictions)
#> i Fold3: model 5/10
#> ✓ Fold3: model 5/10
#> i Fold3: model 5/10 (predictions)
#> i Fold3: model 6/10
#> ✓ Fold3: model 6/10
#> i Fold3: model 6/10 (predictions)
#> i Fold3: model 7/10
#> ✓ Fold3: model 7/10
#> i Fold3: model 7/10 (predictions)
#> i Fold3: model 8/10
#> ✓ Fold3: model 8/10
#> i Fold3: model 8/10 (predictions)
#> i Fold3: model 9/10
#> ✓ Fold3: model 9/10
#> i Fold3: model 9/10 (predictions)
#> i Fold3: model 10/10
#> ✓ Fold3: model 10/10
#> i Fold3: model 10/10 (predictions)
tune::collect_metrics(xgb_tuned_fit) %>%
  select(mean) %>%
  head(2)
#> # A tibble: 2 x 1
#> mean
#> <dbl>
#> 1 0.666
#> 2 0.665
best_model_params <- tune::select_best(xgb_tuned_fit, "pr_auc")
final_xgb <- tune::finalize_workflow(
  xgb_workflow,
  best_model_params
)
final_fit <- final_xgb %>%
  fit(data = data)
xgb_fit <- pull_workflow_fit(final_fit) %>%
  butcher::butcher()
n <- xgb_fit$feature_names
predict(final_fit, data, type = "prob") %>%
  head()
#> # A tibble: 6 x 2
#> .pred_pos .pred_neg
#> <dbl> <dbl>
#> 1 0.000576 0.999
#> 2 0.997 0.00278
#> 3 0.960 0.0405
#> 4 0.994 0.00607
#> 5 0.994 0.00624
#> 6 0.989 0.0112
# xgboost predicts second class "neg": the raw output below matches
# `.pred_neg`, so 1 - prediction reproduces `.pred_pos`.
# NOTE(review): prefer `select(all_of(n))`, as the tidyselect note suggests.
predict(xgb_fit, data.matrix(data %>% select(n))) %>%
  head()
#> Note: Using an external vector in selections is ambiguous.
#> ℹ Use `all_of(n)` instead of `n` to silence this message.
#> ℹ See <https://tidyselect.r-lib.org/reference/faq-external-vector.html>.
#> This message is displayed once per session.
#> [1] 0.999423981 0.002777288 0.040465232 0.006066449 0.006243575 0.011152640
# Created on 2020-07-09 by the reprex package (v0.3.0)
Session info