step_embed not working with parallel processing in caret #21

Closed
tmastny opened this issue Aug 26, 2019 · 3 comments

Comments

@tmastny
Contributor

tmastny commented Aug 26, 2019

Hello, step_embed does not seem to work with parallel processing in caret. I think it may be related to topepo/caret#860.

I am getting this error, which is very similar to the error in the previous issue:

Error in {: task 1 failed - "$ operator is invalid for atomic vectors"

Here is a reproducible example:

library(caret)
#> Loading required package: lattice
#> Loading required package: ggplot2
library(recipes)
#> Loading required package: dplyr
#> 
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#> 
#>     filter, lag
#> The following objects are masked from 'package:base':
#> 
#>     intersect, setdiff, setequal, union
#> 
#> Attaching package: 'recipes'
#> The following object is masked from 'package:stats':
#> 
#>     step
library(embed)
#> Registered S3 method overwritten by 'xts':
#>   method     from
#>   as.zoo.xts zoo
library(doParallel)
#> Loading required package: foreach
#> Loading required package: iterators
#> Loading required package: parallel

sessionInfo()
#> R version 3.6.0 (2019-04-26)
#> Platform: x86_64-apple-darwin15.6.0 (64-bit)
#> Running under: macOS Mojave 10.14.6
#> 
#> Matrix products: default
#> BLAS:   /Library/Frameworks/R.framework/Versions/3.6/Resources/lib/libRblas.0.dylib
#> LAPACK: /Library/Frameworks/R.framework/Versions/3.6/Resources/lib/libRlapack.dylib
#> 
#> locale:
#> [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
#> 
#> attached base packages:
#> [1] parallel  stats     graphics  grDevices utils     datasets  methods  
#> [8] base     
#> 
#> other attached packages:
#> [1] doParallel_1.0.15 iterators_1.0.12  foreach_1.4.7     embed_0.0.3      
#> [5] recipes_0.1.6     dplyr_0.8.3       caret_6.0-84      ggplot2_3.2.1    
#> [9] lattice_0.20-38  
#> 
#> loaded via a namespace (and not attached):
#>   [1] minqa_1.2.4           colorspace_1.4-1      class_7.3-15         
#>   [4] ggridges_0.5.1        rsconnect_0.8.15      markdown_1.0         
#>   [7] base64enc_0.1-3       rstan_2.19.2          DT_0.7               
#>  [10] prodlim_2018.04.18    lubridate_1.7.4       codetools_0.2-16     
#>  [13] splines_3.6.0         knitr_1.23            shinythemes_1.1.2    
#>  [16] zeallot_0.1.0         bayesplot_1.7.0       jsonlite_1.6         
#>  [19] nloptr_1.2.1          tfruns_1.4            uwot_0.1.3           
#>  [22] shiny_1.3.2           compiler_3.6.0        backports_1.1.4      
#>  [25] assertthat_0.2.1      Matrix_1.2-17         lazyeval_0.2.2       
#>  [28] cli_1.1.0             later_0.8.0           htmltools_0.3.6      
#>  [31] prettyunits_1.0.2     tools_3.6.0           igraph_1.2.4.1       
#>  [34] gtable_0.3.0          glue_1.3.1            reshape2_1.4.3       
#>  [37] Rcpp_1.0.2            vctrs_0.2.0           nlme_3.1-139         
#>  [40] crosstalk_1.0.0       timeDate_3043.102     gower_0.2.1          
#>  [43] xfun_0.8              stringr_1.4.0         ps_1.3.0             
#>  [46] lme4_1.1-21           lifecycle_0.1.0       mime_0.7             
#>  [49] miniUI_0.1.1.1        gtools_3.8.1          MASS_7.3-51.4        
#>  [52] zoo_1.8-6             scales_1.0.0          ipred_0.9-9          
#>  [55] rstanarm_2.18.2       colourpicker_1.0      promises_1.0.1       
#>  [58] inline_0.3.15         shinystan_2.5.0       yaml_2.2.0           
#>  [61] reticulate_1.13       gridExtra_2.3         loo_2.1.0            
#>  [64] StanHeaders_2.18.1-10 keras_2.2.4.1         rpart_4.1-15         
#>  [67] stringi_1.4.3         highr_0.8             tensorflow_1.13.1    
#>  [70] dygraphs_1.1.1.6      boot_1.3-22           pkgbuild_1.0.3       
#>  [73] lava_1.6.6            rlang_0.4.0           pkgconfig_2.0.2      
#>  [76] matrixStats_0.54.0    evaluate_0.14         purrr_0.3.2          
#>  [79] rstantools_1.5.1      htmlwidgets_1.3       tidyselect_0.2.5     
#>  [82] processx_3.4.1        plyr_1.8.4            magrittr_1.5.0.9000  
#>  [85] R6_2.4.0              generics_0.0.2        pillar_1.4.2         
#>  [88] whisker_0.3-2         withr_2.1.2           xts_0.11-2           
#>  [91] survival_2.44-1.1     nnet_7.3-12           tibble_2.1.3         
#>  [94] crayon_1.3.4          rmarkdown_1.13.6      grid_3.6.0           
#>  [97] data.table_1.12.2     callr_3.3.1           ModelMetrics_1.2.2   
#> [100] threejs_0.3.1         digest_0.6.20         xtable_1.8-4         
#> [103] tidyr_0.8.99.9000     httpuv_1.5.1          RcppParallel_4.4.3   
#> [106] stats4_3.6.0          munsell_0.5.0         shinyjs_1.0

mtcars2 <- as_tibble(mtcars)
mtcars2 <- mtcars2 %>%
  mutate(cyl = as.factor(paste0("num_", cyl))) %>%
  mutate(am = as.factor(ifelse(am == 1, "am", "not_am")))

rec <- recipe(am ~ cyl + hp, mtcars2) %>%
  step_embed(
    cyl,
    outcome = vars(am),
    options = embed_control(epochs = 75, validation_split = 0.2)
  )

ctrl <- trainControl(
  method = 'cv',
  number = 5,
  savePredictions = 'final',
  classProbs = TRUE,
  summaryFunction = twoClassSummary,
  sampling = NULL,
  returnData = FALSE
)

cl <- makePSOCKcluster(4)
registerDoParallel(cl)

train(
  rec,
  mtcars2,
  method = "glm",
  metric = "ROC",
  trControl = ctrl
)
#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?
#> Set session seed to 6913 (disabled GPU, CPU parallelism)
#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?
#> Error in {: task 1 failed - "$ operator is invalid for atomic vectors"

stopCluster(cl)

Created on 2019-08-26 by the reprex package (v0.3.0)

And here is the same example, working without parallel processing:

library(caret)
#> Loading required package: lattice
#> Loading required package: ggplot2
library(recipes)
#> Loading required package: dplyr
#> 
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#> 
#>     filter, lag
#> The following objects are masked from 'package:base':
#> 
#>     intersect, setdiff, setequal, union
#> 
#> Attaching package: 'recipes'
#> The following object is masked from 'package:stats':
#> 
#>     step
library(embed)
#> Registered S3 method overwritten by 'xts':
#>   method     from
#>   as.zoo.xts zoo

sessionInfo()
#> R version 3.6.0 (2019-04-26)
#> Platform: x86_64-apple-darwin15.6.0 (64-bit)
#> Running under: macOS Mojave 10.14.6
#> 
#> Matrix products: default
#> BLAS:   /Library/Frameworks/R.framework/Versions/3.6/Resources/lib/libRblas.0.dylib
#> LAPACK: /Library/Frameworks/R.framework/Versions/3.6/Resources/lib/libRlapack.dylib
#> 
#> locale:
#> [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
#> 
#> attached base packages:
#> [1] stats     graphics  grDevices utils     datasets  methods   base     
#> 
#> other attached packages:
#> [1] embed_0.0.3     recipes_0.1.6   dplyr_0.8.3     caret_6.0-84   
#> [5] ggplot2_3.2.1   lattice_0.20-38
#> 
#> loaded via a namespace (and not attached):
#>   [1] minqa_1.2.4           colorspace_1.4-1      class_7.3-15         
#>   [4] ggridges_0.5.1        rsconnect_0.8.15      markdown_1.0         
#>   [7] base64enc_0.1-3       rstan_2.19.2          DT_0.7               
#>  [10] prodlim_2018.04.18    lubridate_1.7.4       codetools_0.2-16     
#>  [13] splines_3.6.0         knitr_1.23            shinythemes_1.1.2    
#>  [16] zeallot_0.1.0         bayesplot_1.7.0       jsonlite_1.6         
#>  [19] nloptr_1.2.1          tfruns_1.4            uwot_0.1.3           
#>  [22] shiny_1.3.2           compiler_3.6.0        backports_1.1.4      
#>  [25] assertthat_0.2.1      Matrix_1.2-17         lazyeval_0.2.2       
#>  [28] cli_1.1.0             later_0.8.0           htmltools_0.3.6      
#>  [31] prettyunits_1.0.2     tools_3.6.0           igraph_1.2.4.1       
#>  [34] gtable_0.3.0          glue_1.3.1            reshape2_1.4.3       
#>  [37] Rcpp_1.0.2            vctrs_0.2.0           nlme_3.1-139         
#>  [40] iterators_1.0.12      crosstalk_1.0.0       timeDate_3043.102    
#>  [43] gower_0.2.1           xfun_0.8              stringr_1.4.0        
#>  [46] ps_1.3.0              lme4_1.1-21           lifecycle_0.1.0      
#>  [49] mime_0.7              miniUI_0.1.1.1        gtools_3.8.1         
#>  [52] MASS_7.3-51.4         zoo_1.8-6             scales_1.0.0         
#>  [55] ipred_0.9-9           rstanarm_2.18.2       colourpicker_1.0     
#>  [58] promises_1.0.1        parallel_3.6.0        inline_0.3.15        
#>  [61] shinystan_2.5.0       yaml_2.2.0            reticulate_1.13      
#>  [64] gridExtra_2.3         loo_2.1.0             StanHeaders_2.18.1-10
#>  [67] keras_2.2.4.1         rpart_4.1-15          stringi_1.4.3        
#>  [70] highr_0.8             tensorflow_1.13.1     dygraphs_1.1.1.6     
#>  [73] foreach_1.4.7         boot_1.3-22           pkgbuild_1.0.3       
#>  [76] lava_1.6.6            rlang_0.4.0           pkgconfig_2.0.2      
#>  [79] matrixStats_0.54.0    evaluate_0.14         purrr_0.3.2          
#>  [82] rstantools_1.5.1      htmlwidgets_1.3       tidyselect_0.2.5     
#>  [85] processx_3.4.1        plyr_1.8.4            magrittr_1.5.0.9000  
#>  [88] R6_2.4.0              generics_0.0.2        pillar_1.4.2         
#>  [91] whisker_0.3-2         withr_2.1.2           xts_0.11-2           
#>  [94] survival_2.44-1.1     nnet_7.3-12           tibble_2.1.3         
#>  [97] crayon_1.3.4          rmarkdown_1.13.6      grid_3.6.0           
#> [100] data.table_1.12.2     callr_3.3.1           ModelMetrics_1.2.2   
#> [103] threejs_0.3.1         digest_0.6.20         xtable_1.8-4         
#> [106] tidyr_0.8.99.9000     httpuv_1.5.1          RcppParallel_4.4.3   
#> [109] stats4_3.6.0          munsell_0.5.0         shinyjs_1.0

mtcars2 <- as_tibble(mtcars)
mtcars2 <- mtcars2 %>%
  mutate(cyl = as.factor(paste0("num_", cyl))) %>%
  mutate(am = as.factor(ifelse(am == 1, "am", "not_am")))

rec <- recipe(am ~ cyl + hp, mtcars2) %>%
  step_embed(
    cyl,
    outcome = vars(am),
    options = embed_control(epochs = 75, validation_split = 0.2)
  )

ctrl <- trainControl(
  method = 'cv',
  number = 5,
  savePredictions = 'final',
  classProbs = TRUE,
  summaryFunction = twoClassSummary,
  sampling = NULL,
  returnData = FALSE
)

train(
  rec,
  mtcars2,
  method = "glm",
  metric = "ROC",
  trControl = ctrl
)
#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?
#> Set session seed to 8761 (disabled GPU, CPU parallelism)
#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?
#> Set session seed to 7367 (disabled GPU, CPU parallelism)
#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?
#> Set session seed to 1120 (disabled GPU, CPU parallelism)
#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?
#> Set session seed to 3630 (disabled GPU, CPU parallelism)
#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?
#> Set session seed to 6134 (disabled GPU, CPU parallelism)
#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?
#> Set session seed to 6598 (disabled GPU, CPU parallelism)
#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?
#> Set session seed to 8840 (disabled GPU, CPU parallelism)
#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?

#> Warning: All elements of `...` must be named.
#> Did you want `data = c(type, role, source)`?
#> Generalized Linear Model 
#> 
#> Recipe steps: embed 
#> Resampling: Cross-Validated (5 fold) 
#> Summary of sample sizes: 26, 25, 25, 25, 27 
#> Resampling results:
#> 
#>   ROC    Sens  Spec     
#>   0.775  0.7   0.7333333

Created on 2019-08-26 by the reprex package (v0.3.0)

@tmastny
Contributor Author

tmastny commented Aug 26, 2019

After more investigation, I believe the error originates here:
https://github.com/topepo/caret/blob/c69d61dd7c968b50a87fb2f337424d08ea4a980a/pkg/caret/R/train_recipes.R#L412-L413

When I step into that line in the debugger, pkgs is:

[1] "methods" "caret"   "recipes"

which I think means that embed is never loaded on the foreach workers. From ?foreach:

.packages: character vector of packages that the tasks depend on. If ex requires a R package to be loaded, this option can be used to load that package on each of the workers. Ignored when used with %do%.

So I think this issue is probably a caret problem.
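A minimal sketch of the `.packages` behaviour quoted above, assuming a local PSOCK cluster as in the reprex. Base R's `tools` package stands in for embed here because it is always installed but not attached by default; that substitution and the variable names are illustrative assumptions, not what caret does internally. Attaching a package on the master session does not make it visible on the workers, and only the run that passes `.packages` finds it there.

```r
# PSOCK workers are fresh R sessions, so packages attached on the master
# (like embed in the reprex) are not on their search path by default.
library(doParallel)  # also attaches foreach, iterators and parallel

cl <- makePSOCKcluster(2)
registerDoParallel(cl)

# `tools` is a stand-in for embed (assumption for illustration only):
# attach it on the master, just as the reprex attaches embed.
library(tools)

# Without `.packages`, the workers do not have the package attached.
loaded_without <- foreach(i = 1:2, .combine = c) %dopar% {
  "package:tools" %in% search()
}

# With `.packages`, foreach loads the package on each worker first.
loaded_with <- foreach(i = 1:2, .combine = c, .packages = "tools") %dopar% {
  "package:tools" %in% search()
}

stopCluster(cl)

print(loaded_without)
print(loaded_with)
```

If the `pkgs` vector built in train_recipes.R does not include embed, the workers are in the same position as the `loaded_without` case.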

@topepo
Member

topepo commented Aug 28, 2019

TensorFlow models can't really be parallelized the way caret parallelizes other models and preprocessing methods. For models that use TensorFlow/Keras, caret runs a check and prevents the user from running them in (non-GPU) parallel.

There's no real way to do this check for recipes, but I will add a note to the embed help files about this limitation.
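Until such a check exists for recipes, one user-side workaround consistent with the note above is to keep resampling sequential whenever a recipe contains step_embed(). This is a sketch, not something caret or embed prescribes: `allowParallel = FALSE` is an existing trainControl() argument and registerDoSEQ() registers foreach's sequential backend, and either one alone should be enough for caret to fall back to `%do%`. The `rec` and `mtcars2` objects are assumed to come from the reprex at the top of the issue, so the train() call is left commented out to keep the block runnable without TensorFlow installed.

```r
library(caret)
library(foreach)

# Option 1: keep the cluster registered but tell caret not to use it.
ctrl_seq <- trainControl(
  method = "cv",
  number = 5,
  savePredictions = "final",
  classProbs = TRUE,
  summaryFunction = twoClassSummary,
  returnData = FALSE,
  allowParallel = FALSE  # caret then runs resamples with %do% in this session
)

# Option 2: register the sequential backend, so %dopar% has one worker
# and caret also falls back to sequential resampling.
registerDoSEQ()

# Assumes `rec` and `mtcars2` from the reprex (requires keras/tensorflow):
# fit <- train(rec, mtcars2, method = "glm", metric = "ROC", trControl = ctrl_seq)
```

Running sequentially keeps every step_embed() fit in the master session, where embed is already attached, which matches the working non-parallel reprex.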

@github-actions

github-actions bot commented Mar 6, 2021

This issue has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.

@github-actions github-actions bot locked and limited conversation to collaborators Mar 6, 2021