step_embed not working with parallel processing in caret #21

tmastny opened this issue Aug 26, 2019

tmastny commented Aug 26, 2019

Hello, step_embed does not seem to work with parallel processing in caret. I think it may be related to topepo/caret#860

I am getting this error, which is very similar to the error in the previous issue:

Error in {: task 1 failed - "$ operator is invalid for atomic vectors"

Here is a reproducible example:

library(caret)
library(recipes)
library(embed)
library(doParallel)

#> Loading required package: lattice
#> Loading required package: ggplot2
#> Loading required package: dplyr
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#>     filter, lag
#> The following objects are masked from 'package:base':
#>     intersect, setdiff, setequal, union
#> Attaching package: 'recipes'
#> The following object is masked from 'package:stats':
#>     step
#> Registered S3 method overwritten by 'xts':
#>   method     from
#>   as.zoo.xts zoo
#> Loading required package: foreach
#> Loading required package: iterators
#> Loading required package: parallel

sessionInfo()

#> R version 3.6.0 (2019-04-26)
#> Platform: x86_64-apple-darwin15.6.0 (64-bit)
#> Running under: macOS Mojave 10.14.6
#> Matrix products: default
#> BLAS:   /Library/Frameworks/R.framework/Versions/3.6/Resources/lib/libRblas.0.dylib
#> LAPACK: /Library/Frameworks/R.framework/Versions/3.6/Resources/lib/libRlapack.dylib
#> locale:
#> [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
#> attached base packages:
#> [1] parallel  stats     graphics  grDevices utils     datasets  methods  
#> [8] base     
#> other attached packages:
#> [1] doParallel_1.0.15 iterators_1.0.12  foreach_1.4.7     embed_0.0.3      
#> [5] recipes_0.1.6     dplyr_0.8.3       caret_6.0-84      ggplot2_3.2.1    
#> [9] lattice_0.20-38  
#> loaded via a namespace (and not attached):
#>   [1] minqa_1.2.4           colorspace_1.4-1      class_7.3-15         
#>   [4] ggridges_0.5.1        rsconnect_0.8.15      markdown_1.0         
#>   [7] base64enc_0.1-3       rstan_2.19.2          DT_0.7               
#>  [10] prodlim_2018.04.18    lubridate_1.7.4       codetools_0.2-16     
#>  [13] splines_3.6.0         knitr_1.23            shinythemes_1.1.2    
#>  [16] zeallot_0.1.0         bayesplot_1.7.0       jsonlite_1.6         
#>  [19] nloptr_1.2.1          tfruns_1.4            uwot_0.1.3           
#>  [22] shiny_1.3.2           compiler_3.6.0        backports_1.1.4      
#>  [25] assertthat_0.2.1      Matrix_1.2-17         lazyeval_0.2.2       
#>  [28] cli_1.1.0             later_0.8.0           htmltools_0.3.6      
#>  [31] prettyunits_1.0.2     tools_3.6.0           igraph_1.2.4.1       
#>  [34] gtable_0.3.0          glue_1.3.1            reshape2_1.4.3       
#>  [37] Rcpp_1.0.2            vctrs_0.2.0           nlme_3.1-139         
#>  [40] crosstalk_1.0.0       timeDate_3043.102     gower_0.2.1          
#>  [43] xfun_0.8              stringr_1.4.0         ps_1.3.0             
#>  [46] lme4_1.1-21           lifecycle_0.1.0       mime_0.7             
#>  [49] miniUI_0.1.1.1        gtools_3.8.1          MASS_7.3-51.4        
#>  [52] zoo_1.8-6             scales_1.0.0          ipred_0.9-9          
#>  [55] rstanarm_2.18.2       colourpicker_1.0      promises_1.0.1       
#>  [58] inline_0.3.15         shinystan_2.5.0       yaml_2.2.0           
#>  [61] reticulate_1.13       gridExtra_2.3         loo_2.1.0            
#>  [64] StanHeaders_2.18.1-10 keras_2.2.4.1         rpart_4.1-15         
#>  [67] stringi_1.4.3         highr_0.8             tensorflow_1.13.1    
#>  [70] dygraphs_1.1.1.6      boot_1.3-22           pkgbuild_1.0.3       
#>  [73] lava_1.6.6            rlang_0.4.0           pkgconfig_2.0.2      
#>  [76] matrixStats_0.54.0    evaluate_0.14         purrr_0.3.2          
#>  [79] rstantools_1.5.1      htmlwidgets_1.3       tidyselect_0.2.5     
#>  [82] processx_3.4.1        plyr_1.8.4            magrittr_1.5.0.9000  
#>  [85] R6_2.4.0              generics_0.0.2        pillar_1.4.2         
#>  [88] whisker_0.3-2         withr_2.1.2           xts_0.11-2           
#>  [91] survival_2.44-1.1     nnet_7.3-12           tibble_2.1.3         
#>  [94] crayon_1.3.4          rmarkdown_1.13.6      grid_3.6.0           
#>  [97] data.table_1.12.2     callr_3.3.1           ModelMetrics_1.2.2   
#> [100] threejs_0.3.1         digest_0.6.20         xtable_1.8-4         
#> [103] tidyr_0.8.99.9000     httpuv_1.5.1          RcppParallel_4.4.3   
#> [106] stats4_3.6.0          munsell_0.5.0         shinyjs_1.0

mtcars2 <- as_tibble(mtcars)
mtcars2 <- mtcars2 %>%
  mutate(cyl = as.factor(paste0("num_", cyl))) %>%
  mutate(am = as.factor(ifelse(am == 1, "am", "not_am")))

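rec <- recipe(am ~ cyl + hp, mtcars2) %>%
  # learn an embedding for the factor predictor (cyl is the only factor in the formula)
  step_embed(
    cyl,
    outcome = vars(am),
    options = embed_control(epochs = 75, validation_split = 0.2)
  )

ctrl <- trainControl(
  method = 'cv',
  number = 5,
  savePredictions = 'final',
  classProbs = TRUE,
  summaryFunction = twoClassSummary,
  sampling = NULL,
  returnData = FALSE
)

# start four PSOCK workers and register them as the foreach backend
cl <- makePSOCKcluster(4)
registerDoParallel(cl)

train(
  rec,
  data = mtcars2,
  method = "glm",
  metric = "ROC",
  trControl = ctrl
)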
#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?
#> Set session seed to 6913 (disabled GPU, CPU parallelism)
#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?
#> Error in {: task 1 failed - "$ operator is invalid for atomic vectors"


Created on 2019-08-26 by the reprex package (v0.3.0)

And here is the same example, working without parallel processing:

library(caret)
library(recipes)
library(embed)

#> Loading required package: lattice
#> Loading required package: ggplot2
#> Loading required package: dplyr
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#>     filter, lag
#> The following objects are masked from 'package:base':
#>     intersect, setdiff, setequal, union
#> Attaching package: 'recipes'
#> The following object is masked from 'package:stats':
#>     step
#> Registered S3 method overwritten by 'xts':
#>   method     from
#>   as.zoo.xts zoo

sessionInfo()

#> R version 3.6.0 (2019-04-26)
#> Platform: x86_64-apple-darwin15.6.0 (64-bit)
#> Running under: macOS Mojave 10.14.6
#> Matrix products: default
#> BLAS:   /Library/Frameworks/R.framework/Versions/3.6/Resources/lib/libRblas.0.dylib
#> LAPACK: /Library/Frameworks/R.framework/Versions/3.6/Resources/lib/libRlapack.dylib
#> locale:
#> [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
#> attached base packages:
#> [1] stats     graphics  grDevices utils     datasets  methods   base     
#> other attached packages:
#> [1] embed_0.0.3     recipes_0.1.6   dplyr_0.8.3     caret_6.0-84   
#> [5] ggplot2_3.2.1   lattice_0.20-38
#> loaded via a namespace (and not attached):
#>   [1] minqa_1.2.4           colorspace_1.4-1      class_7.3-15         
#>   [4] ggridges_0.5.1        rsconnect_0.8.15      markdown_1.0         
#>   [7] base64enc_0.1-3       rstan_2.19.2          DT_0.7               
#>  [10] prodlim_2018.04.18    lubridate_1.7.4       codetools_0.2-16     
#>  [13] splines_3.6.0         knitr_1.23            shinythemes_1.1.2    
#>  [16] zeallot_0.1.0         bayesplot_1.7.0       jsonlite_1.6         
#>  [19] nloptr_1.2.1          tfruns_1.4            uwot_0.1.3           
#>  [22] shiny_1.3.2           compiler_3.6.0        backports_1.1.4      
#>  [25] assertthat_0.2.1      Matrix_1.2-17         lazyeval_0.2.2       
#>  [28] cli_1.1.0             later_0.8.0           htmltools_0.3.6      
#>  [31] prettyunits_1.0.2     tools_3.6.0           igraph_1.2.4.1       
#>  [34] gtable_0.3.0          glue_1.3.1            reshape2_1.4.3       
#>  [37] Rcpp_1.0.2            vctrs_0.2.0           nlme_3.1-139         
#>  [40] iterators_1.0.12      crosstalk_1.0.0       timeDate_3043.102    
#>  [43] gower_0.2.1           xfun_0.8              stringr_1.4.0        
#>  [46] ps_1.3.0              lme4_1.1-21           lifecycle_0.1.0      
#>  [49] mime_0.7              miniUI_0.1.1.1        gtools_3.8.1         
#>  [52] MASS_7.3-51.4         zoo_1.8-6             scales_1.0.0         
#>  [55] ipred_0.9-9           rstanarm_2.18.2       colourpicker_1.0     
#>  [58] promises_1.0.1        parallel_3.6.0        inline_0.3.15        
#>  [61] shinystan_2.5.0       yaml_2.2.0            reticulate_1.13      
#>  [64] gridExtra_2.3         loo_2.1.0             StanHeaders_2.18.1-10
#>  [67] keras_2.2.4.1         rpart_4.1-15          stringi_1.4.3        
#>  [70] highr_0.8             tensorflow_1.13.1     dygraphs_1.1.1.6     
#>  [73] foreach_1.4.7         boot_1.3-22           pkgbuild_1.0.3       
#>  [76] lava_1.6.6            rlang_0.4.0           pkgconfig_2.0.2      
#>  [79] matrixStats_0.54.0    evaluate_0.14         purrr_0.3.2          
#>  [82] rstantools_1.5.1      htmlwidgets_1.3       tidyselect_0.2.5     
#>  [85] processx_3.4.1        plyr_1.8.4            magrittr_1.5.0.9000  
#>  [88] R6_2.4.0              generics_0.0.2        pillar_1.4.2         
#>  [91] whisker_0.3-2         withr_2.1.2           xts_0.11-2           
#>  [94] survival_2.44-1.1     nnet_7.3-12           tibble_2.1.3         
#>  [97] crayon_1.3.4          rmarkdown_1.13.6      grid_3.6.0           
#> [100] data.table_1.12.2     callr_3.3.1           ModelMetrics_1.2.2   
#> [103] threejs_0.3.1         digest_0.6.20         xtable_1.8-4         
#> [106] tidyr_0.8.99.9000     httpuv_1.5.1          RcppParallel_4.4.3   
#> [109] stats4_3.6.0          munsell_0.5.0         shinyjs_1.0

mtcars2 <- as_tibble(mtcars)
mtcars2 <- mtcars2 %>%
  mutate(cyl = as.factor(paste0("num_", cyl))) %>%
  mutate(am = as.factor(ifelse(am == 1, "am", "not_am")))

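rec <- recipe(am ~ cyl + hp, mtcars2) %>%
  # learn an embedding for the factor predictor (cyl is the only factor in the formula)
  step_embed(
    cyl,
    outcome = vars(am),
    options = embed_control(epochs = 75, validation_split = 0.2)
  )

ctrl <- trainControl(
  method = 'cv',
  number = 5,
  savePredictions = 'final',
  classProbs = TRUE,
  summaryFunction = twoClassSummary,
  sampling = NULL,
  returnData = FALSE
)

# no parallel backend is registered, so resampling runs sequentially
train(
  rec,
  data = mtcars2,
  method = "glm",
  metric = "ROC",
  trControl = ctrl
)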
#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?
#> Set session seed to 8761 (disabled GPU, CPU parallelism)
#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?
#> Set session seed to 7367 (disabled GPU, CPU parallelism)
#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?
#> Set session seed to 1120 (disabled GPU, CPU parallelism)
#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?
#> Set session seed to 3630 (disabled GPU, CPU parallelism)
#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?
#> Set session seed to 6134 (disabled GPU, CPU parallelism)
#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?
#> Set session seed to 6598 (disabled GPU, CPU parallelism)
#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?
#> Set session seed to 8840 (disabled GPU, CPU parallelism)
#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?
#> Generalized Linear Model 
#> Recipe steps: embed 
#> Resampling: Cross-Validated (5 fold) 
#> Summary of sample sizes: 26, 25, 25, 25, 27 
#> Resampling results:
#>   ROC    Sens  Spec     
#>   0.775  0.7   0.7333333

Created on 2019-08-26 by the reprex package (v0.3.0)

tmastny commented Aug 26, 2019

After more investigating, I believe the error originates here:

When I debug into that line, I see that pkgs is

[1] "methods" "caret"   "recipes"

which I think means that foreach doesn't have access to embed. From ?foreach:

.packages: character vector of packages that the tasks depend on. If ex requires a R package to be loaded, this option can be used to load that package on each of the workers. Ignored when used with %do%.

So I think this issue is probably a caret problem.
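
The effect of .packages can be seen in a small sketch that does not involve caret at all. It assumes foreach and doParallel are installed, uses the base package tools purely as a stand-in for embed, and picks two PSOCK workers arbitrarily:

```r
library(foreach)
library(doParallel)

# Two PSOCK workers: fresh R sessions that attach only the default packages.
cl <- makePSOCKcluster(2)
registerDoParallel(cl)

# Without .packages, 'tools' (the stand-in for embed) is not attached on the workers.
# This loop runs first because a package attached on a worker stays attached.
without_pkgs <- foreach(i = 1:2, .combine = c) %dopar% {
  "package:tools" %in% search()
}

# With .packages, the package is attached on each worker before the task runs.
with_pkgs <- foreach(i = 1:2, .combine = c, .packages = "tools") %dopar% {
  "package:tools" %in% search()
}

# Clean up and restore the sequential backend.
stopCluster(cl)
registerDoSEQ()

print(without_pkgs)
print(with_pkgs)
```

If the workers behave as described in ?foreach, the first vector should be all FALSE and the second all TRUE, which matches the situation above where pkgs lists only "methods", "caret" and "recipes".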

topepo commented Aug 28, 2019

TensorFlow models can't really be parallelized the way caret parallelizes other models and preprocessing methods. For models using TF/keras, a check is done and the user is prohibited from running them in (non-GPU) parallel.

There's no real way to perform that check for recipes, but I will add a note in the embed help files about this limitation.
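
Until then, one way to keep such a recipe out of parallel workers is to stop the cluster and register foreach's sequential backend before calling train(). A minimal sketch, assuming foreach and doParallel are installed (the worker count is arbitrary):

```r
library(foreach)
library(doParallel)

# Register a parallel backend, as in the failing example.
cl <- makePSOCKcluster(2)
registerDoParallel(cl)
parallel_workers <- getDoParWorkers()

# Shut the cluster down and fall back to the sequential backend,
# so that any later %dopar% loop runs in the main R process.
stopCluster(cl)
registerDoSEQ()
sequential_workers <- getDoParWorkers()

# Every task now reports the process id of the main session.
pids <- foreach(i = 1:3, .combine = c) %dopar% Sys.getpid()
print(all(pids == Sys.getpid()))
```

caret's trainControl() also has an allowParallel argument; setting allowParallel = FALSE should likewise keep resampling in the main process even when a parallel backend is registered.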

github-actions bot commented Mar 6, 2021

This issue has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex) and link to this issue.

@github-actions github-actions bot locked and limited conversation to collaborators Mar 6, 2021
