Description
The problem
Hi,
Thanks for a great package, I really appreciate all the work that y'all have put into this and the other tidymodels packages!
I believe I've found a bug in finalize_workflow() that gives incorrect values when finalizing multiple recipe steps of the same type. I ran into this when tuning the polynomial degree for several variables using step_poly(). Here's a minimal reprex.
Reproducible example
Sorry, couldn't get reprex::reprex() to work, but this code should run.
```r
library(dplyr)
library(purrr)
library(recipes)
library(parsnip)
library(dials)
library(tune)
library(workflows)

# Tune the polynomial degree of three predictors, each with its own parameter id
lm_rec <-
  recipe(mpg ~ ., data = mtcars) %>%
  step_poly(drat, degree = tune("drat_deg")) %>%
  step_poly(disp, degree = tune("disp_deg")) %>%
  step_poly(hp, degree = tune("hp_deg"))

lm_spec <- linear_reg(penalty = 0, mixture = 0) %>% set_engine("lm")

lm_wf <-
  workflow() %>%
  add_recipe(lm_rec) %>%
  add_model(lm_spec)

# Pick a single candidate (row 8) from a regular grid of the three degrees
lm_grid <-
  parameters(lm_rec) %>%
  grid_regular() %>%
  slice(8)

lm_grid

# Gives incorrect output
lm_wf %>%
  finalize_workflow(lm_grid) %>%
  pull_workflow_preprocessor() %>%
  prep(mtcars) %>%
  bake(new_data = mtcars) %>%
  glimpse()

# This works though
lm_rec %>%
  merge(lm_grid) %>%
  .$x %>%
  pluck(1) %>%
  prep(mtcars) %>%
  bake(new_data = mtcars) %>%
  glimpse()
```

Here's a screenshot of the output: finalize_workflow() results in 3 poly terms for drat (instead of 2) and 2 poly terms for disp (instead of 3). The output when I merge the grid object into the recipe is correct, however.

I believe that the params are being applied in the sorted order of the recipe step names, rather than by matching on the name of the parameter.
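To make this hypothesis concrete, here is a toy base-R sketch (not tune's actual code) of how handing out grid values positionally to the steps sorted by step id would produce exactly the drat/disp swap described above. The step ids (`poly_mxqke`, etc.) are made up; recipes assigns random ids of roughly this form, so which steps get swapped would depend on those ids. The grid values match the expected term counts above, and the value for `hp_deg` is an assumption.

```r
# Hypothetical sketch of the suspected bug; this is NOT tune's source code.
# One row of the tuning grid (hp_deg's value is assumed for illustration).
grid_row <- c(drat_deg = 2L, disp_deg = 3L, hp_deg = 1L)

# The step_poly() steps in the order they were added to the recipe.
# These step ids are made up; recipes generates random ids like these.
steps <- data.frame(
  id       = c("poly_mxqke", "poly_bfczr", "poly_tuwfa"),
  variable = c("drat", "disp", "hp"),
  param_id = c("drat_deg", "disp_deg", "hp_deg"),
  stringsAsFactors = FALSE
)

# Expected behaviour: look up each step's value by its parameter id.
by_name <- setNames(unname(grid_row[steps$param_id]), steps$variable)

# Suspected behaviour: sort the steps by id, then hand out the grid values
# positionally, ignoring the parameter names entirely.
sorted_steps <- steps[order(steps$id), ]
by_position <- setNames(unname(grid_row), sorted_steps$variable)[steps$variable]

by_name      # drat = 2, disp = 3, hp = 1
by_position  # drat = 3, disp = 2, hp = 1
```

Because `poly_bfczr` (disp) sorts before `poly_mxqke` (drat), the positional version gives disp the value meant for drat and vice versa, which is the same symptom seen with finalize_workflow().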
Session Info:
> sessionInfo()
R version 3.4.1 (2017-06-30)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Red Hat Enterprise Linux
Matrix products: default
BLAS: /opt/microsoft/mlserver/9.2.1/runtime/R/lib/libRblas.so
LAPACK: /opt/microsoft/mlserver/9.2.1/runtime/R/lib/libRlapack.so
locale:
[1] LC_CTYPE=en_US.UTF-8 LC_NUMERIC=C LC_TIME=en_US.UTF-8 LC_COLLATE=en_US.UTF-8
[5] LC_MONETARY=en_US.UTF-8 LC_MESSAGES=en_US.UTF-8 LC_PAPER=en_US.UTF-8 LC_NAME=C
[9] LC_ADDRESS=C LC_TELEPHONE=C LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C
attached base packages:
[1] stats graphics grDevices utils datasets methods base
other attached packages:
[1] reprex_0.3.0 workflows_0.2.0 tune_0.1.1 dials_0.0.9 scales_1.1.1 parsnip_0.1.3
[7] recipes_0.1.13 purrr_0.3.4 dplyr_1.0.2 NAVpackrat_0.5.5
loaded via a namespace (and not attached):
[1] Rcpp_1.0.5 lubridate_1.7.9 lattice_0.20-38 tidyr_1.1.2 ps_1.3.4
[6] class_7.3-14 digest_0.6.25 assertthat_0.2.1 RevoUtilsMath_10.0.0 packrat_0.5.0
[11] ipred_0.9-9 foreach_1.5.0 utf8_1.1.4 R6_2.4.1 plyr_1.8.6
[16] evaluate_0.14 ggplot2_3.3.2 pillar_1.4.6 rlang_0.4.7 rstudioapi_0.11
[21] callr_3.4.4 whisker_0.4 DiceDesign_1.8-1 rpart_4.1-13 Matrix_1.2-10
[26] rmarkdown_2.3 devtools_1.13.3 splines_3.4.1 RevoUtils_10.0.5 gower_0.2.2
[31] munsell_0.5.0 xfun_0.17 compiler_3.4.1 pkgconfig_2.0.3 clipr_0.7.0
[36] htmltools_0.5.0 nnet_7.3-12 tidyselect_1.1.0 tibble_3.0.3 mrupdate_1.0.1
[41] prodlim_2019.11.13 codetools_0.2-15 GPfit_1.0-8 fansi_0.4.1 crayon_1.3.4
[46] withr_2.2.0 MASS_7.3-47 grid_3.4.1 jsonlite_1.7.1 gtable_0.3.0
[51] lifecycle_0.2.0 magrittr_1.5 pROC_1.16.2 cli_2.0.2 fs_1.5.0
[56] timeDate_3043.102 ellipsis_0.3.1 lhs_1.0.2 generics_0.0.2 vctrs_0.3.4
[61] lava_1.6.7 iterators_1.0.12 yardstick_0.0.7 tools_3.4.1 glue_1.4.2
[66] processx_3.4.4 survival_2.42-6 yaml_2.2.1 colorspace_1.4-1 memoise_1.1.0
[71] knitr_1.29