Error when using rerank = TRUE in rfeControl using rfFuncs #942

Closed
and-jonas opened this issue Sep 24, 2018 · 1 comment

Comments

@and-jonas

Hi,

I want to use rerank = TRUE in rfeControl.
Therefore, in rfFuncs$fit I set importance = TRUE.
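
For readers following along, the setup described above might look roughly like the sketch below. It only illustrates the configuration (a copy of rfFuncs whose fit function always requests importance, passed to rfeControl with rerank = TRUE); it is not a fix for the error. The names rf_rerank and ctrl, and the choice of 5-fold cross-validation, are assumptions made for illustration and do not come from the original post.

```r
library(caret)

# Work on a copy so the built-in rfFuncs list stays untouched.
# `rf_rerank` is a hypothetical name chosen for this sketch.
rf_rerank <- rfFuncs

# Override the fit function so variable importance is computed on every
# call, which re-ranking needs in order to re-score the remaining predictors.
rf_rerank$fit <- function(x, y, first, last, ...) {
  randomForest::randomForest(x, y, importance = TRUE, ...)
}

# Control object with re-ranking enabled; the resampling settings here
# (5-fold CV) are assumed, not taken from the post.
ctrl <- rfeControl(functions = rf_rerank,
                   rerank = TRUE,
                   method = "cv",
                   number = 5)
```

The resulting ctrl object would then be passed to rfe() through its rfeControl argument, together with the predictors and the outcome.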

At the end, I get:
Error in { : task 1 failed - "replacement has 1 row, data has 0"

This looks like the same problem as in #543,
but I could not solve it using the help provided there.

Any ideas? Thank you!

Minimal dataset:

data <- structure(list(var1 = c(440.59, 454.24, 416.38, 482.57, 418.7, 
                                333.78, 443.74, 491.32, 456.27, 434.27, 381.45, 495.49, 455.87, 
                                478.33, 468.03, 435.53, 463.55, 430.28, 395.33, 382.79, 470.37, 
                                372.13, 484.36, 387.55, 507.5, 484.46, 459.93, 442.99, 380.4, 
                                475.33, 496.01, 421.69, 512.15, 475.5, 460.19, 392.09, 447.71, 
                                393.12, 390.58, 439.78, 375.01, 458.71, 399.3, 417.32, 491.85, 
                                467.28, 439.73, 481.09, 458.89, 432.39), var2 = c(447.82, 529.54, 
                                                                                  490.07, 549.5, 381.26, 450.87, 530.59, 555.85, 545.82, 458.93, 
                                                                                  514.91, 554.22, 518.94, 528.15, 495.04, 566.62, 471.33, 570.04, 
                                                                                  470.19, 482.26, 539.63, 445.02, 573.07, 519.23, 589.06, 568.69, 
                                                                                  464.69, 564.87, 451.9, 576.05, 391.76, 469.55, 526.91, 544, 500.8, 
                                                                                  461.4, 541.59, 500.22, 447.9, 493.13, 532.88, 574.8, 469.89, 
                                                                                  504.82, 521.22, 560.03, 502.26, 558.53, 510.04, 523.54), var3 = c(444, 
                                                                                                                                                    461.06, 560.69, 581.73, 400.02, 472.99, 455.89, 587.76, 538.19, 
                                                                                                                                                    484.92, 488.86, 556.57, 619.19, 506.69, 572.85, 550.2, 460.49, 
                                                                                                                                                    608.51, 489.55, 546.56, 557.11, 528.99, 528.63, 530.05, 621.3, 
                                                                                                                                                    541.67, 509.11, 521.47, 445.03, 615.56, 474.33, 581.92, 543.26, 
                                                                                                                                                    541.32, 540.45, 483.97, 359.09, 536.44, 480.41, 544.43, 496.73, 
                                                                                                                                                    530.64, 525.84, 486.07, 561.89, 551.39, 529.03, 584.78, 529.72, 
                                                                                                                                                    544.47), var4 = c(510.73, 541.79, 541.69, 576.08, 471.35, 483.9, 
                                                                                                                                                                      574.37, 567.08, 543.56, 515.72, 503.09, 572.73, 555.06, 538.8, 
                                                                                                                                                                      490, 522.21, 504.2, 552.13, 496.06, 524.55, 544.28, 569.36, 567.47, 
                                                                                                                                                                      530.92, 609.23, 582.03, 502.22, 545.05, 499.96, 573.47, 586.27, 
                                                                                                                                                                      552.11, 583.24, 567.69, 599.15, 498.31, 595.78, 532.17, 470.2, 
                                                                                                                                                                      564.76, 522.98, 564.19, 490.34, 490.51, 567.32, 563.1, 535.02, 
                                                                                                                                                                      542.82, 540.52, 551.64), var5 = c(485.48, 468.71, 441.72, 515.3, 
                                                                                                                                                                                                        464.79, 379.61, 486.36, 524.52, 494.43, 467.36, 435.94, 516.92, 
                                                                                                                                                                                                        517.81, 483.98, 468.43, 448.72, 475.7, 446.31, 466.88, 438.12, 
                                                                                                                                                                                                        494.97, 415.71, 518.69, 437.63, 541.35, 503.54, 491.77, 488.64, 
                                                                                                                                                                                                        417.34, 567.54, 527.06, 477.82, 513.75, 519.73, 486.93, 422.77, 
                                                                                                                                                                                                        516.71, 443.31, 410.17, 493.06, 433.67, 495, 417.87, 430.05, 
                                                                                                                                                                                                        540, 503.55, 481.59, 512.53, 484.05, 456.1), var6 = c(444, 461.06, 
                                                                                                                                                                                                                                                              560.69, 581.73, 400.02, 472.99, 455.89, 587.76, 538.19, 484.92, 
                                                                                                                                                                                                                                                              488.86, 556.57, 619.19, 506.69, 572.85, 550.2, 460.49, 608.51, 
                                                                                                                                                                                                                                                              489.55, 546.56, 557.11, 528.99, 528.63, 530.05, 621.3, 541.67, 
                                                                                                                                                                                                                                                              509.11, 521.47, 445.03, 615.56, 474.33, 581.92, 543.26, 541.32, 
                                                                                                                                                                                                                                                              540.45, 483.97, 359.09, 536.44, 480.41, 544.43, 496.73, 530.64, 
                                                                                                                                                                                                                                                              525.84, 486.07, 561.89, 551.39, 529.03, 584.78, 529.72, 544.47
                                                                                                                                                                                                        ), var7 = c(496.39, 498.73, 486.09, 532.65, 465.05, 431.33, 521.97, 
                                                                                                                                                                                                                    533.82, 532.72, 493.37, 472.2, 538.3, 541.75, 497.99, 480.41, 
                                                                                                                                                                                                                    488.62, 491.22, 504.81, 481.64, 497.05, 510.97, 503.75, 536.8, 
                                                                                                                                                                                                                    489.68, 596.18, 538.31, 494.15, 511.7, 453.99, 571.81, 547.25, 
                                                                                                                                                                                                                    528.76, 536.13, 549.71, 546.35, 447.79, 554.71, 487.59, 441.03, 
                                                                                                                                                                                                                    541.01, 477.99, 511.94, 446.64, 461.2, 539.42, 537.35, 501.15, 
                                                                                                                                                                                                                    514.02, 508.43, 508.46), var8 = c(457.14, 500.02, 470.58, 539.01, 
                                                                                                                                                                                                                                                      464.7, 378.35, 493.35, 523.03, 510.84, 454.19, 415.43, 521.59, 
                                                                                                                                                                                                                                                      549.99, 503.25, 477.39, 470.67, 463.54, 453.28, 430.14, 479.98, 
                                                                                                                                                                                                                                                      504.24, 415.54, 496.72, 455.69, 505.7, 518.49, 475.23, 477.16, 
                                                                                                                                                                                                                                                      459.61, 592.19, 566.81, 499.43, 525.1, 486.8, 499.59, 392, 490.81, 
                                                                                                                                                                                                                                                      483.33, 461.79, 454.72, 427.26, 510.78, 442.76, 442.97, 527.41, 
                                                                                                                                                                                                                                                      491.17, 493.89, 504.11, 459.74, 464.43), var9 = c(560.45, 601.68, 
                                                                                                                                                                                                                                                                                                        548.03, 631.11, 490.05, 488.21, 575.58, 637.52, 607.51, 574.68, 
                                                                                                                                                                                                                                                                                                        562.69, 625.5, 604.74, 592.41, 593.14, 688.23, 598.56, 656.48, 
                                                                                                                                                                                                                                                                                                        403.25, 608.19, 613.85, 558.05, 557.01, 586.56, 688.48, 639.31, 
                                                                                                                                                                                                                                                                                                        618.84, 565.86, 530.69, 586.07, 629.82, 708.42, 614.81, 601.63, 
                                                                                                                                                                                                                                                                                                        572.11, 509.94, 682, 566.84, 596.62, 550.1, 509.24, 639.61, 571.13, 
                                                                                                                                                                                                                                                                                                        549.11, 625.37, 639.66, 597.01, 639.97, 563.66, 619.61), var10 = c(638.06, 
                                                                                                                                                                                                                                                                                                                                                                           591.63, 588, 611.45, 467.03, 571.62, 609.37, 633.13, 611.21, 
                                                                                                                                                                                                                                                                                                                                                                           572.62, 653.49, 632.63, 598.95, 595.31, 622.3, 663.32, 599.18, 
                                                                                                                                                                                                                                                                                                                                                                           686.13, 556.94, 664.56, 572.48, 679.93, 625.19, 698.98, 713.04, 
                                                                                                                                                                                                                                                                                                                                                                           620.76, 540.78, 651.8, 533.99, 638.79, 609.49, 714.42, 628.05, 
                                                                                                                                                                                                                                                                                                                                                                           597.43, 610.64, 516.04, 671.79, 598.74, 558.34, 607.58, 584.2, 
                                                                                                                                                                                                                                                                                                                                                                           634.32, 535.16, 542.7, 616.54, 675.02, 610.78, 629.06, 590.41, 
                                                                                                                                                                                                                                                                                                                                                                           618.39), var11 = c(487.31, 473.1, 481.22, 520.67, 431.87, 421.47, 
                                                                                                                                                                                                                                                                                                                                                                                              509.64, 532.78, 454.39, 471.5, 452.32, 522.56, 518.12, 495.63, 
                                                                                                                                                                                                                                                                                                                                                                                              457.51, 466.29, 444.2, 472.23, 444.57, 457.18, 467.75, 505, 482.13, 
                                                                                                                                                                                                                                                                                                                                                                                              447.9, 537.81, 501.76, 496.21, 492.55, 479.21, 534.91, 403.1, 
                                                                                                                                                                                                                                                                                                                                                                                              489.7, 561.54, 501.77, 566.12, 466.33, 545.68, 454.73, 392.47, 
                                                                                                                                                                                                                                                                                                                                                                                              501.58, 471.96, 488.29, 479.05, 421.2, 500.54, 491.24, 490.11, 
                                                                                                                                                                                                                                                                                                                                                                                              478.09, 531.9, 430.16), var12 = c(511.33, 541.14, 629.86, 618.29, 
                                                                                                                                                                                                                                                                                                                                                                                                                                487.83, 554.07, 602.17, 651.43, 579.79, 510.29, 536.32, 613.74, 
                                                                                                                                                                                                                                                                                                                                                                                                                                623.75, 565.36, 579.78, 616.81, 551.34, 634.76, 472.55, 619.07, 
                                                                                                                                                                                                                                                                                                                                                                                                                                588.53, 586.81, 583.92, 616.7, 682.79, 603.31, 531.3, 596.27, 
                                                                                                                                                                                                                                                                                                                                                                                                                                516.1, 632.27, 598.21, 596.95, 619.49, 587.85, 530.98, 503.62, 
                                                                                                                                                                                                                                                                                                                                                                                                                                662.77, 601.36, 529.96, 573.91, 522.34, 601.81, 566.06, 546.01, 
                                                                                                                                                                                                                                                                                                                                                                                                                                587.28, 597.92, 572.5, 642.19, 571.43, 566.28), var13 = c(480.91, 
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                          477.94, 506.87, 531.96, 448.02, 439.94, 533.62, 525.34, 461.53, 
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                          490.76, 441.13, 515.98, 538.51, 505.74, 494.03, 508.24, 460.39, 
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                          515.45, 446.25, 484.72, 483.76, 525.98, 496.28, 462.74, 553.51, 
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                          544.76, 442.58, 495.08, 477.35, 569.84, 563.07, 507.48, 565.31, 
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                          519.74, 571.7, 461.69, 585.58, 458.47, 450.19, 503.06, 462.84, 
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                          500.92, 439.47, 431.2, 519.78, 474.72, 462.26, 497.18, 522.09, 
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                          488.6), var14 = c(661.57, 683.23, 667.02, 687.32, 584.9, 655.7, 
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            694.56, 751.54, 673.66, 680.37, 664.52, 741.34, 711.15, 688.7, 
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            676.49, 665.07, 636.99, 713.37, 635.01, 702.29, 680.21, 690.08, 
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            667.89, 724.56, 767.23, 766.12, 716.75, 695.17, 659.01, 694.13, 
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            691.49, 721.73, 709.14, 684.07, 677.3, 636.14, 705.73, 676.02, 
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            672.58, 705.08, 668.27, 716.78, 729.22, 710.52, 679.64, 729.03, 
    695.14, 714.42, 677.54, 661.78), var15 = c(587.36, 604.52, 587.99, 
    598.13, 500.03, 593.37, 628.36, 645.44, 594.52, 564.22, 569.66, 
    643.58, 619.32, 605.17, 554.83, 619.98, 557.51, 633.53, 546.87, 
    592.16, 581.08, 644.17, 631.94, 611.84, 669.75, 640.75, 537.87, 
    614.35, 537.43, 650.96, 621.44, 636.54, 648.42, 588.16, 660.62, 
    548.63, 660.18, 590.05, 536.43, 596.55, 562.18, 611.03, 588.21, 
    558.64, 624.76, 625.61, 600.49, 633.57, 576.51, 621), var16 = c(6.27, 
    5.74, 6.29, 6.82, 5.83, 5.94, 8.22, 6.71, 5.75, 5.41, 6.59, 6.17, 
    8.82, 5.36, 6.92, 6.01, 5.53, 5.44, 6.26, 6.54, 6.46, 6.45, 5.93, 
    6.58, 6.91, 5.88, 6.47, 6.82, 5, 7.49, 5.65, 6.43, 6.45, 6.2, 
    6.71, 5.18, 6.99, 6.2, 6.27, 6.93, 6.06, 7.83, 5.79, 6.9, 6.68, 
    6.37, 5.95, 5.48, 6, 6.19)), .Names = c("var1", "var2", "var3", 
    "var4", "var5", "var6", "var7", "var8", "var9", "var10", "var11", 
    "var12", "var13", "var14", "var15", "var16"), row.names = c(NA, 
    50L), class = "data.frame")

Minimal, runnable code:

library(caret)
library(dplyr)

x <- data %>% 
  select(-var16) %>% 
  as.matrix()

y <- data %>% 
  select(var16) %>% 
  as.matrix()  

rfFuncs <-  list(summary = defaultSummary,
                 fit = function(x, y, first, last, ...) {
                   loadNamespace("randomForest")
                   randomForest::randomForest(x, y, importance = TRUE, ntree = 10,...) #change to importance = TRUE for rerank = TRUE
                 },
                 pred = function(object, x)  {
                   tmp <- predict(object, x)
                   if(is.factor(object$y)) {
                     out <- cbind(data.frame(pred = tmp),
                                  as.data.frame(predict(object, x, type = "response"))) #change to response from class probabilities
                   } else out <- tmp
                   out
                 },
                 rank = function(object, x, y) {
                   vimp <- varImp(object)
                   
                   if(is.factor(y)) {
                     if(all(levels(y) %in% colnames(vimp))) {
                       avImp <- apply(vimp[, levels(y), drop = TRUE], 1, mean)
                       vimp$Overall <- avImp
                     }
                   }
                   
                   vimp <- vimp[order(vimp$Overall, decreasing = TRUE),, drop = FALSE]
                   
                   vimp$var <- rownames(vimp)
                   vimp
                 },
                 selectSize = pickSizeBest,
                 selectVar = pickVars)


ctrl <- rfeControl(functions = rfFuncs,
                   method = "repeatedcv",
                   repeats = 2,
                   number = 3,
                   rerank = TRUE, #recalculate importance after each round
                   returnResamp = "final",
                   verbose = TRUE)

subsets <- c(1:7, 10)

rfProfile <- rfe(x, y, sizes = subsets, rfeControl = ctrl)

Session Info:

> sessionInfo()
R version 3.4.4 (2018-03-15)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows >= 8 x64 (build 9200)

Matrix products: default

locale:
[1] LC_COLLATE=German_Switzerland.1252  LC_CTYPE=German_Switzerland.1252    LC_MONETARY=German_Switzerland.1252 LC_NUMERIC=C                       
[5] LC_TIME=German_Switzerland.1252    

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] Hmisc_4.1-1         Formula_1.2-3       survival_2.41-3     mlbench_2.1-1       bindrcpp_0.2.2      dplyr_0.7.6         randomForest_4.6-14
 [8] caret_6.0-80        ggplot2_3.0.0       lattice_0.20-35    

loaded via a namespace (and not attached):
 [1] httr_1.3.1          magic_1.5-8         ddalpha_1.3.4       tidyr_0.8.1         sfsmisc_1.1-2       splines_3.4.4       foreach_1.4.4      
 [8] prodlim_2018.04.18  assertthat_0.2.0    stats4_3.4.4        latticeExtra_0.6-28 DRR_0.0.3           robustbase_0.93-1.1 ipred_0.9-6        
[15] pillar_1.3.0        backports_1.1.2     glue_1.3.0          digest_0.6.15       checkmate_1.8.5     RColorBrewer_1.1-2  colorspace_1.3-2   
[22] recipes_0.1.3       htmltools_0.3.6     Matrix_1.2-12       plyr_1.8.4          timeDate_3043.102   pkgconfig_2.0.1     devtools_1.13.6    
[29] CVST_0.2-2          broom_0.5.0         purrr_0.2.5         scales_0.5.0        gower_0.1.2         lava_1.6.2          htmlTable_1.12     
[36] git2r_0.23.0        tibble_1.4.2        withr_2.1.2         nnet_7.3-12         lazyeval_0.2.1      magrittr_1.5        crayon_1.3.4       
[43] memoise_1.1.0       nlme_3.1-137        MASS_7.3-49         foreign_0.8-69      dimRed_0.1.0        class_7.3-14        data.table_1.11.4  
[50] tools_3.4.4         stringr_1.3.1       kernlab_0.9-26      munsell_0.5.0       cluster_2.0.6       e1071_1.7-0         pls_2.6-0          
[57] compiler_3.4.4      RcppRoll_0.3.0      rlang_0.2.1         grid_3.4.4          iterators_1.0.10    rstudioapi_0.7      htmlwidgets_1.2    
[64] base64enc_0.1-3     geometry_0.3-6      gtable_0.2.0        ModelMetrics_1.1.0  codetools_0.2-15    abind_1.4-5         curl_3.2           
[71] reshape2_1.4.3      R6_2.2.2            gridExtra_2.3       lubridate_1.7.4     knitr_1.20          bindr_0.1.1         stringi_1.1.7      
[78] Rcpp_0.12.17        rpart_4.1-13        acepack_1.4.1       DEoptimR_1.0-8      tidyselect_0.2.4   
topepo added a commit that referenced this issue Nov 16, 2018
Owner

topepo commented Nov 16, 2018

By default, rfFuncs and rfe already compute the class probabilities and the RF importance scores, so there's no need to modify rfFuncs (and I think the modified version contains an error).
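
To check what the stock helpers do before overriding them, the list can be inspected directly. This is only a sketch: it assumes caret is installed, and the function bodies it prints depend on the installed caret version, so no output is shown here.

```r
library(caret)

# rfFuncs is a plain named list of functions shipped with caret.
names(rfFuncs)

# Print the stock fit and rank functions to see how variable importance
# is requested and used before deciding whether an override is needed.
rfFuncs$fit
rfFuncs$rank
```

If an override does turn out to be necessary, copying the list first (for example `myFuncs <- rfFuncs`, where `myFuncs` is a hypothetical name) and replacing a single element keeps the remaining defaults intact.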

Also, the y argument should not be a matrix. ?rfe says "y: a vector of training set outcomes (either numeric or factor)".
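
A small base-R sketch of the difference, using only the first four var16 values from the report as a stand-in data frame (the names `df`, `y_mat`, `y_vec` and `y_fixed` are illustrative, not from the original code):

```r
# Stand-in for the outcome column; first four var16 values from the report.
df <- data.frame(var16 = c(6.27, 5.74, 6.29, 6.82))

y_mat <- as.matrix(df["var16"])  # one-column matrix, as in the original report
y_vec <- df[["var16"]]           # plain numeric vector, as ?rfe asks for

is.matrix(y_mat)  # TRUE
is.matrix(y_vec)  # FALSE

# If only the matrix is at hand, as.vector() drops its dim attribute.
y_fixed <- as.vector(y_mat)
stopifnot(is.numeric(y_fixed), is.null(dim(y_fixed)), identical(y_fixed, y_vec))
```

With dplyr loaded, `pull(var16)` gives the same vector without the intermediate matrix.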

However, there are two bugs in rfFuncs that this example exposed.

Using the default rfFuncs and the updated version:

> library(caret)
> library(dplyr)
> 
> x <- data %>% 
+   select(-var16) %>% 
+   as.matrix()
> 
> y <- data %>% 
+   pull(var16) 
> 
> ctrl <- rfeControl(functions = rfFuncs,
+                    method = "repeatedcv",
+                    repeats = 2,
+                    number = 3,
+                    rerank = TRUE, #recalculate importance after each round
+                    returnResamp = "final",
+                    verbose = FALSE)
> 
> subsets <- c(1:7, 10)
> 
> rfProfile <- 
+   rfe(x, y, sizes = subsets, rfeControl = ctrl)
> rfProfile

Recursive feature selection

Outer resampling method: Cross-Validated (3 fold, repeated 2 times) 

Resampling performance over subset size:

 Variables   RMSE Rsquared    MAE RMSESD RsquaredSD  MAESD Selected
         1 0.8213   0.1164 0.6636 0.1767    0.07516 0.1297         
         2 0.7961   0.1069 0.6268 0.2012    0.09828 0.1360         
         3 0.7362   0.1472 0.5741 0.1583    0.12812 0.1120         
         4 0.7152   0.1564 0.5494 0.1531    0.13126 0.1252         
         5 0.6932   0.1937 0.5224 0.1611    0.16194 0.1171         
         6 0.6876   0.2089 0.5227 0.1764    0.17427 0.1319         
         7 0.6871   0.1891 0.5229 0.1523    0.14117 0.1062         
        10 0.6961   0.1703 0.5326 0.1476    0.13660 0.1117         
        15 0.6852   0.1815 0.5150 0.1470    0.14243 0.1093        *

The top 5 variables (out of 15):
   var11, var4, var7, var1, var12

I'm patching a number of bugs over the next day, so you might want to wait a day or so before reinstalling.

topepo closed this as completed Nov 16, 2018