parsnip doesn't patch params argument with xgboost engine in boost_tree #774

@john-b-edwards

Description

The problem

I'm having trouble passing parameters to set_engine() for xgboost and getting tune_bayes() to work. I hadn't encountered this kind of issue before despite implementing this kind of logic in other code, so I suspect it's due either to a new version of xgboost or to some change in parsnip.

Reproducible example

This code works:

library(tidymodels)
rec <- recipe(vs ~ ., data = mtcars |>
                mutate(vs = as.factor(vs)))

folds <- vfold_cv(mtcars |>
                    mutate(vs = as.factor(vs)), v = 2)

xgb_spec <- boost_tree(
  trees = 25,
  tree_depth = tune(),
  min_n = tune(),
  loss_reduction = tune(),
  sample_size = tune(),
  mtry = tune(),
  learn_rate = tune()
)

metrics <- metric_set(mn_log_loss)

std_xgb_spec <- xgb_spec |>
  set_engine("xgboost") |>
  set_mode("classification")

wf <- workflow() |>
  add_model(std_xgb_spec) |>
  add_recipe(rec)

params <- hardhat::extract_parameter_set_dials(wf) |>
  finalize(mtcars)

set.seed(123)
res <- wf |>
  tune_bayes(
    resamples = folds,
    param_info = params,
    initial = 3,
    iter = 10,
    metrics = metrics,
    control = control_bayes(
      no_improve = 10,
      verbose = FALSE,
      time_limit = 240
    )
)
#> ! There are 6 tuning parameters and 3 grid points were requested.
#> • There are more tuning parameters than there are initial points. This is likely to cause numerical issues in the first few search iterations.
#> ! Fold1: preprocessor 1/1, model 2/3: 38 samples were requested but there were 16 rows in the data. 16 will be...
#> 
#> ! Fold1: preprocessor 1/1, model 3/3: 22 samples were requested but there were 16 rows in the data. 16 will be...
#> 
#> ! Fold2: preprocessor 1/1, model 2/3: 38 samples were requested but there were 16 rows in the data. 16 will be...
#> 
#> ! Fold2: preprocessor 1/1, model 3/3: 22 samples were requested but there were 16 rows in the data. 16 will be...
#> 
#> ! The Gaussian process model is being fit using 6 features but only has 3
#>   data points to do so. This may cause errors or a poor model fit.
#> 
#> ! The Gaussian process model is being fit using 6 features but only has 4
#>   data points to do so. This may cause errors or a poor model fit.
#> 
#> ! Fold1: preprocessor 1/1, model 1/1: 34 samples were requested but there were 16 rows in the data. 16 will be...
#> 
#> ! Fold2: preprocessor 1/1, model 1/1: 34 samples were requested but there were 16 rows in the data. 16 will be...
#> 
#> ! The Gaussian process model is being fit using 6 features but only has 5
#>   data points to do so. This may cause errors or a poor model fit.
#> 
#> ! The Gaussian process model is being fit using 6 features but only has 6
#>   data points to do so. This may cause errors or a poor model fit.
#> 
#> ! The Gaussian process model is being fit using 6 features but only has 7
#>   data points to do so. This may cause errors or a poor model fit.

But when the engine parameters are set explicitly, the Gaussian process fails:

library(tidymodels)
rec <- recipe(vs ~ ., data = mtcars |>
                mutate(vs = as.factor(vs)))

folds <- vfold_cv(mtcars |>
                    mutate(vs = as.factor(vs)), v = 2)

xgb_spec <- boost_tree(
  trees = 25,
  tree_depth = tune(),
  min_n = tune(),
  loss_reduction = tune(),
  sample_size = tune(),
  mtry = tune(),
  learn_rate = tune()
)

metrics <- metric_set(mn_log_loss)

std_xgb_spec <- xgb_spec |>
  set_engine("xgboost",
             params = list(
               tree_method = "hist",
               eval_metric = "logloss",
               objective = "binary:logistic")) |>
  set_mode("classification")

wf <- workflow() |>
  add_model(std_xgb_spec) |>
  add_recipe(rec)

params <- hardhat::extract_parameter_set_dials(wf) |>
  finalize(mtcars)

set.seed(123)
res <- wf |>
  tune_bayes(
    resamples = folds,
    param_info = params,
    initial = 3,
    iter = 10,
    metrics = metrics,
    control = control_bayes(
      no_improve = 10,
      verbose = FALSE,
      time_limit = 240
    )
)
#> ! There are 6 tuning parameters and 3 grid points were requested.
#> • There are more tuning parameters than there are initial points. This is likely to cause numerical issues in the first few search iterations.
#> ! Fold1: preprocessor 1/1, model 2/3: 38 samples were requested but there were 16 rows in the data. 16 will be...
#> 
#> ! Fold1: preprocessor 1/1, model 3/3: 22 samples were requested but there were 16 rows in the data. 16 will be...
#> 
#> ! Fold2: preprocessor 1/1, model 2/3: 38 samples were requested but there were 16 rows in the data. 16 will be...
#> 
#> ! Fold2: preprocessor 1/1, model 3/3: 22 samples were requested but there were 16 rows in the data. 16 will be...
#> 
#> ! All of the mn_log_loss values were identical. The Gaussian process model
#>   cannot be fit to the data. Try expanding the range of the tuning
#>   parameters.
#> 
#> ! The Gaussian process model is being fit using 6 features but only has 3
#>   data points to do so. This may cause errors or a poor model fit.
#> 
#> x Gaussian process model:
#>   Error in GP_deviance(beta = row, X = X, Y = Y, nug_thres = nug_thres, ...
#>               unable to find optimum parameters
#> Error in `check_gp_failure()`:
#> ! Gaussian process model was not fit.
#> ✖ Optimization stopped prematurely; returning current results.
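
For reference, parsnip appears to carry the params list through as a verbatim engine argument. Here is a quick diagnostic sketch (not part of the reprex above) that uses parsnip's translate() to print the fit template:

# Diagnostic sketch: inspect how parsnip maps engine arguments into the
# xgboost fit call; translate() prints the model fit template.
library(parsnip)
boost_tree(trees = 25) |>
  set_engine("xgboost", params = list(eval_metric = "logloss")) |>
  set_mode("classification") |>
  translate()
# The template should show `params = list(eval_metric = "logloss")` handed to
# parsnip::xgb_train() as-is, which is where I suspect the patching goes wrong.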

This is my session info:

> sessionInfo()
R version 4.2.1 (2022-06-23 ucrt)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows 10 x64 (build 22000)

Matrix products: default

locale:
[1] LC_COLLATE=English_United States.utf8  LC_CTYPE=English_United States.utf8   
[3] LC_MONETARY=English_United States.utf8 LC_NUMERIC=C                          
[5] LC_TIME=English_United States.utf8    

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] xgboost_1.6.0.1       yardstick_1.0.0       workflowsets_1.0.0    workflows_1.0.0      
 [5] tune_1.0.0            tidyr_1.2.0           tibble_3.1.7          rsample_1.0.0        
 [9] recipes_1.0.1         purrr_0.3.4           parsnip_1.0.0         modeldata_1.0.0      
[13] infer_1.0.2           ggplot2_3.3.6         dplyr_1.0.9           dials_1.0.0          
[17] scales_1.2.0          broom_1.0.0           tidymodels_1.0.0.9000

loaded via a namespace (and not attached):
 [1] fs_1.5.2           usethis_2.1.6      lubridate_1.8.0    devtools_2.4.3    
 [5] DiceDesign_1.9     rprojroot_2.0.3    tools_4.2.1        backports_1.4.1   
 [9] utf8_1.2.2         R6_2.5.1           rpart_4.1.16       DBI_1.1.3         
[13] colorspace_2.0-3   nnet_7.3-17        withr_2.5.0        tidyselect_1.1.2  
[17] prettyunits_1.1.1  processx_3.7.0     curl_4.3.2         compiler_4.2.1    
[21] cli_3.3.0          callr_3.7.1        digest_0.6.29      rmarkdown_2.14    
[25] htmltools_0.5.3    pkgconfig_2.0.3    parallelly_1.32.1  sessioninfo_1.2.2 
[29] lhs_1.1.5          highr_0.9          fastmap_1.1.0      rlang_1.0.4       
[33] rstudioapi_0.13    generics_0.1.3     jsonlite_1.8.0     magrittr_2.0.3    
[37] Matrix_1.4-1       Rcpp_1.0.9         munsell_0.5.0      fansi_1.0.3       
[41] GPfit_1.0-8        clipr_0.8.0        lifecycle_1.0.1    furrr_0.3.0       
[45] yaml_2.3.5         MASS_7.3-57        pkgbuild_1.3.1     grid_4.2.1        
[49] parallel_4.2.1     listenv_0.8.0      crayon_1.5.1       lattice_0.20-45   
[53] splines_4.2.1      knitr_1.39         ps_1.7.1           pillar_1.8.0      
[57] future.apply_1.9.0 codetools_0.2-18   pkgload_1.3.0      reprex_2.0.1      
[61] glue_1.6.2         evaluate_0.15      data.table_1.14.2  remotes_2.4.2     
[65] vctrs_0.4.1        foreach_1.5.2      gtable_0.3.0       future_1.26.1     
[69] assertthat_0.2.1   cachem_1.0.6       xfun_0.31          gower_1.0.0       
[73] prodlim_2019.11.13 class_7.3-20       survival_3.3-1     timeDate_4021.104 
[77] iterators_1.0.14   memoise_2.0.1      hardhat_1.2.0      lava_1.6.10       
[81] globals_0.15.1     ellipsis_0.3.2     ipred_0.9-13   

If it's at all relevant, I am using the GPU-enabled Windows version of xgboost 1.6.0 from here.
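
For completeness, the alternative calling convention would be to pass the options to set_engine() directly rather than wrapping them in a params list. This is a sketch of that form (I haven't verified that it avoids the GP failure):

# Sketch: pass engine options straight to set_engine() instead of inside
# `params`; parsnip forwards extra engine arguments to xgb.train()'s params.
# The objective is omitted because parsnip picks it from the mode.
std_xgb_spec_direct <- xgb_spec |>
  set_engine("xgboost",
             tree_method = "hist",
             eval_metric = "logloss") |>
  set_mode("classification")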
