Skip to content

multi_predict does not use recipe #1250

@joscani

Description

@joscani

Hi. I'm trying to use multi_predict(). Here is my code:

library(tidymodels) # modelling framework
library(workflows)
library(bonsai)     # models like lightgbm

# Randomly discretise a vector into an ordered set of coarse bins.
#
# Draws a random number of cutpoints (1 to 4) from the values of `x`, counts
# for every element how many cutpoints lie strictly above it, and returns that
# count as a factor labelled "level_1" ... "level_<k + 1>" (k = number of
# cutpoints drawn). "level_1" therefore holds the largest values of `x`.
#
# Arguments:
#   x  A vector whose elements can be compared with `<` and sampled from.
#
# Returns:
#   A factor of the same length as `x` that always has k + 1 levels, even when
#   some of the bins are empty.
#
# Note: the result depends on the RNG state; call set.seed() beforehand for
# reproducible output.
bin_roughly <- function(x) {
  n_levels <- sample(1:4, 1)
  cutpoints <- sort(sample(x, n_levels))

  # outer() always yields a length(x) x length(cutpoints) matrix, so rowSums()
  # works no matter how many cutpoints were drawn.
  n_above <- rowSums(outer(x, cutpoints, `<`))

  # Pin the level set explicitly. Without `levels`, factor() derives the levels
  # from the counts that actually occur, and errors with "invalid 'labels'"
  # whenever a bin is empty (e.g. ties in `x`, or the minimum of `x` being
  # drawn as a cutpoint).
  factor(
    n_above,
    levels = 0:n_levels,
    labels = paste0("level_", seq_len(n_levels + 1L))
  )
}

# Simulate a regression data set, drop predictor_16 through predictor_20, and
# pass every remaining column whose name contains "_1" through bin_roughly().
#
# Arguments:
#   n_rows  Forwarded to modeldata::sim_regression() as its first argument.
#
# Returns:
#   The simulated data with the selected columns replaced by their binned
#   versions.
simulate_regression <- function(n_rows) {
  simulated <- modeldata::sim_regression(n_rows)
  trimmed <- select(simulated, -c(predictor_16:predictor_20))
  mutate(trimmed, across(contains("_1"), bin_roughly))
}

# Simulate a classification data set with 12 linear predictors and pass every
# column whose name contains "_1" through bin_roughly().
#
# Arguments:
#   n_rows    Forwarded to modeldata::sim_classification() as its first
#             argument.
#   n_levels  Accepted for interface compatibility but never referenced in the
#             body, so callers may omit it.
#
# Returns:
#   The simulated data with the selected columns replaced by their binned
#   versions.
simulate_classification <- function(n_rows, n_levels) {
  simulated <- modeldata::sim_classification(n_rows, num_linear = 12)
  mutate(simulated, across(contains("_1"), bin_roughly))
}






# Reproducible example: simulate data, preprocess it with a recipe, and fit a
# boosted-tree classifier on the baked training set.

# Fix the RNG seed; the simulation and the random binning that follows both
# draw from it, so statement order matters for reproducibility.
set.seed(1)
# Simulate 1000 rows. `n_levels` is not supplied; the function body never
# references it, so the call still succeeds.
d <- simulate_classification(1e3)
# Auto-print the simulated data for inspection.
d
# Split into training and testing sets using initial_split()'s defaults.
d_split <- initial_split(d)
d_train <- training(d_split)
d_test <- testing(d_split)



# Model specification: boosted trees (100 trees, learning rate 0.1) for
# classification, fitted with the xgboost engine.
mod1_spec <- 
  boost_tree( trees = 100, learn_rate = 0.1)  |> 
  set_mode("classification")  |> 
  set_engine(engine = "xgboost")

# Recipe: model `class` on all other columns, dummy-encode every nominal
# predictor, and estimate (prep) the recipe on the training data.
recipe1 <- recipe(
                  class ~ ., 
                  data = d_train)  |>
        step_dummy(all_nominal_predictors())  |> 
        prep()

# Apply the prepped recipe to the training data.
d_train_bake <- bake(recipe1, d_train)


# Fit the model specification directly on the baked data via the formula
# interface. NOTE: despite the `wf1_` prefix, no workflow() object is built in
# this script; `wf1_fit` is the result of fit() on the model spec.
wf1_fit <- fit(mod1_spec,formula = class ~ .,  d_train_bake)

I can predict with

predict(wf1_fit, new_data = d_train_bake, type = "prob")
# A tibble: 750 × 2
   .pred_class_1 .pred_class_2
           <dbl>         <dbl>
 1        0.624         0.376 
 2        0.984         0.0160
 3        0.950         0.0496
 4        0.0104        0.990 
 5        0.931         0.0690
 6        0.980         0.0204
 7        0.0976        0.902 
 8        0.150         0.850 
 9        0.983         0.0165
10        0.577         0.423 
# ℹ 740 more rows
# ℹ Use `print(n = ...)` to see more rows

But I get an error using multi_predict

pred_10_trees <- multi_predict(wf1_fit, new_data = d_train_bake, trees = 10 )
Error in `map()`:In index: 1.
Caused by error in `maybe_matrix()` at parsnip/R/boost_tree.R:397:5:
! Some columns are non-numeric. The data cannot be converted to numeric matrix: 'class'.
Run `rlang::last_trace()` to see where the error occurred.

Any idea? Thanks

Session info

> sessionInfo()
R version 4.4.2 (2024-10-31)
Platform: x86_64-pc-linux-gnu
Running under: Linux Mint 21.3

Matrix products: default
BLAS/LAPACK: /usr/lib/x86_64-linux-gnu/libmkl_rt.so;  LAPACK version 3.8.0

locale:
 [1] LC_CTYPE=es_ES.UTF-8       LC_NUMERIC=C              
 [3] LC_TIME=es_ES.UTF-8        LC_COLLATE=es_ES.UTF-8    
 [5] LC_MONETARY=es_ES.UTF-8    LC_MESSAGES=es_ES.UTF-8   
 [7] LC_PAPER=es_ES.UTF-8       LC_NAME=C                 
 [9] LC_ADDRESS=C               LC_TELEPHONE=C            
[11] LC_MEASUREMENT=es_ES.UTF-8 LC_IDENTIFICATION=C       

time zone: Europe/Madrid
tzcode source: system (glibc)

attached base packages:
[1] stats     graphics  grDevices datasets  utils     methods   base     

other attached packages:
 [1] xgboost_1.7.8.1    bonsai_0.3.1       yardstick_1.3.1   
 [4] workflowsets_1.1.0 workflows_1.1.4    tune_1.2.1        
 [7] tidyr_1.3.1        tibble_3.2.1       rsample_1.2.1     
[10] recipes_1.1.0      purrr_1.0.2        parsnip_1.2.1     
[13] modeldata_1.4.0    infer_1.0.7        ggplot2_3.5.1     
[16] dplyr_1.1.4        dials_1.3.0        scales_1.3.0      
[19] broom_1.0.7        tidymodels_1.2.0   nvimcom_0.9.50    

loaded via a namespace (and not attached):
 [1] gtable_0.3.6        lattice_0.22-6      vctrs_0.6.5        
 [4] tools_4.4.2         generics_0.1.3      parallel_4.4.2     
 [7] pkgconfig_2.0.3     Matrix_1.7-1        data.table_1.16.4  
[10] lhs_1.2.0           GPfit_1.0-8         lifecycle_1.0.4    
[13] compiler_4.4.2      tictoc_1.2.1        munsell_0.5.1      
[16] codetools_0.2-20    DiceDesign_1.10     class_7.3-23       
[19] yaml_2.3.10         prodlim_2024.06.25  modelenv_0.2.0     
[22] pillar_1.10.1       furrr_0.3.1         MASS_7.3-61        
[25] gower_1.0.2         iterators_1.0.14    rpart_4.1.23       
[28] foreach_1.5.2       parallelly_1.41.0   lava_1.8.0         
[31] tidyselect_1.2.1    digest_0.6.37       future_1.34.0      
[34] listenv_0.9.1       splines_4.4.2       grid_4.4.2         
[37] colorspace_2.1-1    cli_3.6.3           magrittr_2.0.3     
[40] utf8_1.2.4          survival_3.8-3      future.apply_1.11.3
[43] withr_3.0.2         backports_1.5.0     lubridate_1.9.4    
[46] timechange_0.3.0    globals_0.16.3      nnet_7.3-19        
[49] timeDate_4041.110   hardhat_1.4.0       rlang_1.1.4        
[52] Rcpp_1.0.13-1       glue_1.8.0          BiocManager_1.30.25
[55] renv_1.0.11         ipred_0.9-15        jsonlite_1.8.9     
[58] rstudioapi_0.17.1   R6_2.5.1           

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions