-
Notifications
You must be signed in to change notification settings - Fork 93
Closed
Description
Hi. I'm trying multi_predict . I copy my code
library(tidymodels) # modelling framework
library(workflows)
library(bonsai) # models like lightgm
bin_roughly <- function(x) {
n_levels <- sample(1:4, 1)
cutpoints <- sort(sample(x, n_levels))
x <- rowSums(vapply(cutpoints, `>`, logical(length(x)), x))
factor(x, labels = paste0("level_", 1:(n_levels+1)))
}
simulate_regression <- function(n_rows) {
modeldata::sim_regression(n_rows) %>%
select(-c(predictor_16:predictor_20)) %>%
mutate(across(contains("_1"), bin_roughly))
}
simulate_classification <- function(n_rows, n_levels) {
modeldata::sim_classification(n_rows, num_linear = 12) %>%
mutate(across(contains("_1"), bin_roughly))
}
set.seed(1)
d <- simulate_classification(1e3)
d
d_split <- initial_split(d)
d_train <- training(d_split)
d_test <- testing(d_split)
mod1_spec <-
boost_tree( trees = 100, learn_rate = 0.1) |>
set_mode("classification") |>
set_engine(engine = "xgboost")
recipe1 <- recipe(
class ~ .,
data = d_train) |>
step_dummy(all_nominal_predictors()) |>
prep()
d_train_bake <- bake(recipe1, d_train)
wf1_fit <- fit(mod1_spec,formula = class ~ ., d_train_bake)
I can predict with
predict(wf1_fit, new_data = d_train_bake, type = "prob")
# A tibble: 750 × 2
.pred_class_1 .pred_class_2
<dbl> <dbl>
1 0.624 0.376
2 0.984 0.0160
3 0.950 0.0496
4 0.0104 0.990
5 0.931 0.0690
6 0.980 0.0204
7 0.0976 0.902
8 0.150 0.850
9 0.983 0.0165
10 0.577 0.423
# ℹ 740 more rows
# ℹ Use `print(n = ...)` to see more rows
But I get an error using multi_predict
pred_10_trees <- multi_predict(wf1_fit, new_data = d_train_bake, trees = 10 )
Error in `map()`:
ℹ In index: 1.
Caused by error in `maybe_matrix()` at parsnip/R/boost_tree.R:397:5:
! Some columns are non-numeric. The data cannot be converted to numeric matr
ix: 'class'.
Run `rlang::last_trace()` to see where the error occurred.
Any idea? Thanks
Session info
> sessionInfo()
R version 4.4.2 (2024-10-31)
Platform: x86_64-pc-linux-gnu
Running under: Linux Mint 21.3
Matrix products: default
BLAS/LAPACK: /usr/lib/x86_64-linux-gnu/libmkl_rt.so; LAPACK version 3.8.0
locale:
[1] LC_CTYPE=es_ES.UTF-8 LC_NUMERIC=C
[3] LC_TIME=es_ES.UTF-8 LC_COLLATE=es_ES.UTF-8
[5] LC_MONETARY=es_ES.UTF-8 LC_MESSAGES=es_ES.UTF-8
[7] LC_PAPER=es_ES.UTF-8 LC_NAME=C
[9] LC_ADDRESS=C LC_TELEPHONE=C
[11] LC_MEASUREMENT=es_ES.UTF-8 LC_IDENTIFICATION=C
time zone: Europe/Madrid
tzcode source: system (glibc)
attached base packages:
[1] stats graphics grDevices datasets utils methods base
other attached packages:
[1] xgboost_1.7.8.1 bonsai_0.3.1 yardstick_1.3.1
[4] workflowsets_1.1.0 workflows_1.1.4 tune_1.2.1
[7] tidyr_1.3.1 tibble_3.2.1 rsample_1.2.1
[10] recipes_1.1.0 purrr_1.0.2 parsnip_1.2.1
[13] modeldata_1.4.0 infer_1.0.7 ggplot2_3.5.1
[16] dplyr_1.1.4 dials_1.3.0 scales_1.3.0
[19] broom_1.0.7 tidymodels_1.2.0 nvimcom_0.9.50
loaded via a namespace (and not attached):
[1] gtable_0.3.6 lattice_0.22-6 vctrs_0.6.5
[4] tools_4.4.2 generics_0.1.3 parallel_4.4.2
[7] pkgconfig_2.0.3 Matrix_1.7-1 data.table_1.16.4
[10] lhs_1.2.0 GPfit_1.0-8 lifecycle_1.0.4
[13] compiler_4.4.2 tictoc_1.2.1 munsell_0.5.1
[16] codetools_0.2-20 DiceDesign_1.10 class_7.3-23
[19] yaml_2.3.10 prodlim_2024.06.25 modelenv_0.2.0
[22] pillar_1.10.1 furrr_0.3.1 MASS_7.3-61
[25] gower_1.0.2 iterators_1.0.14 rpart_4.1.23
[28] foreach_1.5.2 parallelly_1.41.0 lava_1.8.0
[31] tidyselect_1.2.1 digest_0.6.37 future_1.34.0
[34] listenv_0.9.1 splines_4.4.2 grid_4.4.2
[37] colorspace_2.1-1 cli_3.6.3 magrittr_2.0.3
[40] utf8_1.2.4 survival_3.8-3 future.apply_1.11.3
[43] withr_3.0.2 backports_1.5.0 lubridate_1.9.4
[46] timechange_0.3.0 globals_0.16.3 nnet_7.3-19
[49] timeDate_4041.110 hardhat_1.4.0 rlang_1.1.4
[52] Rcpp_1.0.13-1 glue_1.8.0 BiocManager_1.30.25
[55] renv_1.0.11 ipred_0.9-15 jsonlite_1.8.9
[58] rstudioapi_0.17.1 R6_2.5.1
Metadata
Metadata
Assignees
Labels
No labels