Difference between formula and "x=, y=" in train() #913

Closed
proto4426 opened this issue Jul 12, 2018 · 3 comments

proto4426 commented Jul 12, 2018

Hi,
I thought it was equivalent to pass the input data to train() either as a formula (y ~ x, data = ) or as y = y, x = x, but I just found a difference when retrieving variable importances for the rf method.

  • If train() is called with a formula, the importances are reported separately for each level of the categorical variable.
  • If train() is called without a formula, the importance is reported for the factor variable as a whole, as expected.

The randomForest package, on which train() relies for this method, reports only the overall importance of each variable.

Where does this difference come from?

library(tidyverse)
data("chickwts")
# Add a numeric predictor in the data
dt <- data.frame(chickwts, xx = rnorm(length(chickwts$feed)))

rfCARET <- caret::train(
  y          = dt$weight,
  x          = dt %>% select(-weight),
  method     = "rf",
  importance = TRUE
)
caret::varImp(rfCARET$finalModel)
#>         Overall
#> feed 46.4034574
#> xx   -0.4728728

rfCARET_form <- caret::train(
  weight ~ .,
  data       = dt,
  method     = "rf",
  importance = TRUE
)
caret::varImp(rfCARET_form$finalModel)
#>                  Overall
#> feedhorsebean 29.3780596
#> feedlinseed   18.8974110
#> feedmeatmeal   1.8305809
#> feedsoybean   11.5654634
#> feedsunflower 13.5071325
#> xx            -0.6113386

## Try with the randomForest package (x/y interface).

library(randomForest)
rf <- randomForest(x = dt %>% select(-weight),
                   y = dt$weight)
importance(rf)
#>      IncNodePurity
#> feed      213583.0
#> xx        129131.1
Session info
devtools::session_info()
#> Session info -------------------------------------------------------------
#>  setting  value                       
#>  version  R version 3.4.2 (2017-09-28)
#>  system   x86_64, mingw32             
#>  ui       RTerm                       
#>  language (EN)                        
#>  collate  English_United States.1252  
#>  tz       Europe/Paris                
#>  date     2018-07-12
#> Packages -----------------------------------------------------------------
#>  package      * version  date       source        
#>  abind          1.4-5    2016-07-21 CRAN (R 3.4.1)
#>  assertthat     0.2.0    2017-04-11 CRAN (R 3.4.2)
#>  backports      1.1.2    2017-12-13 CRAN (R 3.4.3)
#>  base         * 3.4.2    2017-09-28 local         
#>  bindr          0.1.1    2018-03-13 CRAN (R 3.4.4)
#>  bindrcpp       0.2.2    2018-03-29 CRAN (R 3.4.4)
#>  broom          0.4.4    2018-03-29 CRAN (R 3.4.4)
#>  caret        * 6.0-79   2018-03-29 CRAN (R 3.4.4)
#>  cellranger     1.1.0    2016-07-27 CRAN (R 3.4.2)
#>  class          7.3-14   2015-08-30 CRAN (R 3.4.2)
#>  cli            1.0.0    2017-11-05 CRAN (R 3.4.2)
#>  codetools      0.2-15   2016-10-05 CRAN (R 3.4.2)
#>  colorspace     1.3-2    2016-12-14 CRAN (R 3.4.2)
#>  compiler       3.4.2    2017-09-28 local         
#>  crayon         1.3.4    2017-09-16 CRAN (R 3.4.2)
#>  CVST           0.2-1    2013-12-10 CRAN (R 3.4.2)
#>  datasets     * 3.4.2    2017-09-28 local         
#>  ddalpha        1.3.2    2018-04-08 CRAN (R 3.4.2)
#>  DEoptimR       1.0-8    2016-11-19 CRAN (R 3.4.1)
#>  devtools       1.13.5   2018-02-18 CRAN (R 3.4.3)
#>  digest         0.6.15   2018-01-28 CRAN (R 3.4.3)
#>  dimRed         0.1.0    2017-05-04 CRAN (R 3.4.2)
#>  dplyr        * 0.7.4    2017-09-28 CRAN (R 3.4.2)
#>  DRR            0.0.3    2018-01-06 CRAN (R 3.4.3)
#>  evaluate       0.10.1   2017-06-24 CRAN (R 3.4.2)
#>  forcats      * 0.3.0    2018-02-19 CRAN (R 3.4.3)
#>  foreach        1.4.4    2017-12-12 CRAN (R 3.4.3)
#>  foreign        0.8-69   2017-06-22 CRAN (R 3.4.2)
#>  geometry       0.3-6    2015-09-09 CRAN (R 3.4.4)
#>  ggplot2      * 2.2.1    2016-12-30 CRAN (R 3.4.2)
#>  glue           1.2.0    2017-10-29 CRAN (R 3.4.2)
#>  gower          0.1.2    2017-02-23 CRAN (R 3.4.2)
#>  graphics     * 3.4.2    2017-09-28 local         
#>  grDevices    * 3.4.2    2017-09-28 local         
#>  grid           3.4.2    2017-09-28 local         
#>  gtable         0.2.0    2016-02-26 CRAN (R 3.4.2)
#>  haven          1.1.1    2018-01-18 CRAN (R 3.4.3)
#>  hms            0.4.2    2018-03-10 CRAN (R 3.4.3)
#>  htmltools      0.3.6    2017-04-28 CRAN (R 3.4.2)
#>  httr           1.3.1    2017-08-20 CRAN (R 3.4.2)
#>  ipred          0.9-6    2017-03-01 CRAN (R 3.4.2)
#>  iterators      1.0.9    2017-12-12 CRAN (R 3.4.3)
#>  jsonlite       1.5      2017-06-01 CRAN (R 3.4.2)
#>  kernlab        0.9-25   2016-10-03 CRAN (R 3.4.1)
#>  knitr          1.20     2018-02-20 CRAN (R 3.4.3)
#>  lattice      * 0.20-35  2017-03-25 CRAN (R 3.4.2)
#>  lava           1.6.1    2018-03-28 CRAN (R 3.4.4)
#>  lazyeval       0.2.1    2017-10-29 CRAN (R 3.4.2)
#>  lubridate      1.7.3    2018-02-27 CRAN (R 3.4.2)
#>  magic          1.5-8    2018-01-26 CRAN (R 3.4.3)
#>  magrittr       1.5      2014-11-22 CRAN (R 3.4.2)
#>  MASS           7.3-50   2018-04-30 CRAN (R 3.4.4)
#>  Matrix         1.2-11   2017-08-21 CRAN (R 3.4.2)
#>  memoise        1.1.0    2017-04-21 CRAN (R 3.4.2)
#>  methods      * 3.4.2    2017-09-28 local         
#>  mnormt         1.5-5    2016-10-15 CRAN (R 3.4.1)
#>  ModelMetrics   1.1.0    2016-08-26 CRAN (R 3.4.2)
#>  modelr         0.1.1    2017-07-24 CRAN (R 3.4.2)
#>  munsell        0.4.3    2016-02-13 CRAN (R 3.4.2)
#>  nlme           3.1-131  2017-02-06 CRAN (R 3.4.2)
#>  nnet           7.3-12   2016-02-02 CRAN (R 3.4.2)
#>  parallel       3.4.2    2017-09-28 local         
#>  pillar         1.2.1    2018-02-27 CRAN (R 3.4.3)
#>  pkgconfig      2.0.1    2017-03-21 CRAN (R 3.4.2)
#>  plyr           1.8.4    2016-06-08 CRAN (R 3.4.2)
#>  prodlim        1.6.1    2017-03-06 CRAN (R 3.4.2)
#>  psych          1.8.3.3  2018-03-30 CRAN (R 3.4.4)
#>  purrr        * 0.2.4    2017-10-18 CRAN (R 3.4.2)
#>  R6             2.2.2    2017-06-17 CRAN (R 3.4.2)
#>  randomForest * 4.6-14   2018-03-25 CRAN (R 3.4.4)
#>  Rcpp           0.12.16  2018-03-13 CRAN (R 3.4.4)
#>  RcppRoll       0.2.2    2015-04-05 CRAN (R 3.4.2)
#>  readr        * 1.1.1    2017-05-16 CRAN (R 3.4.2)
#>  readxl         1.0.0    2017-04-18 CRAN (R 3.4.2)
#>  recipes        0.1.2    2018-01-11 CRAN (R 3.4.3)
#>  reshape2       1.4.3    2017-12-11 CRAN (R 3.4.3)
#>  rlang          0.2.0    2018-02-20 CRAN (R 3.4.3)
#>  rmarkdown      1.9      2018-03-01 CRAN (R 3.4.3)
#>  robustbase     0.92-8   2017-11-01 CRAN (R 3.4.2)
#>  rpart          4.1-11   2017-03-13 CRAN (R 3.4.2)
#>  rprojroot      1.3-2    2018-01-03 CRAN (R 3.4.3)
#>  rstudioapi     0.7      2017-09-07 CRAN (R 3.4.2)
#>  rvest          0.3.2    2016-06-17 CRAN (R 3.4.3)
#>  scales         0.5.0    2017-08-24 CRAN (R 3.4.2)
#>  sfsmisc        1.1-2    2018-03-05 CRAN (R 3.4.3)
#>  splines        3.4.2    2017-09-28 local         
#>  stats        * 3.4.2    2017-09-28 local         
#>  stats4         3.4.2    2017-09-28 local         
#>  stringi        1.1.7    2018-03-12 CRAN (R 3.4.4)
#>  stringr      * 1.3.0    2018-02-19 CRAN (R 3.4.3)
#>  survival       2.41-3   2017-04-04 CRAN (R 3.4.2)
#>  tibble       * 1.4.2    2018-01-22 CRAN (R 3.4.3)
#>  tidyr        * 0.8.0    2018-01-29 CRAN (R 3.4.3)
#>  tidyselect     0.2.4    2018-02-26 CRAN (R 3.4.3)
#>  tidyverse    * 1.2.1    2017-11-14 CRAN (R 3.4.2)
#>  timeDate       3043.102 2018-02-21 CRAN (R 3.4.3)
#>  tools          3.4.2    2017-09-28 local         
#>  utils        * 3.4.2    2017-09-28 local         
#>  withr          2.1.2    2018-03-15 CRAN (R 3.4.2)
#>  xml2           1.2.0    2018-01-24 CRAN (R 3.4.3)
#>  yaml           2.1.18   2018-03-08 CRAN (R 3.4.3)
@topepo
Owner

topepo commented Jul 13, 2018

  • If train() is called with a formula, the importances are reported separately for each level of the categorical variable.
  • If train() is called without a formula, the importance is reported for the factor variable as a whole, as expected.

Your expectation is pretty reasonable, but 99.9% of the time a formula method will generate indicator variables for qualitative predictors. train is consistent with the majority of functions that use formulas.

However, there are a variety of package functions whose models do not require that all of the predictors be encoded as numbers. Trees, rule-based models, naive Bayes, and others fall into this bucket.

So, if you want to keep factors as factors, use the non-formula method for train.
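To see where per-level names such as feedhorsebean come from, the sketch below compares, in base R only, the predictor layout a formula produces with the one passed through x/y. It assumes that caret's formula method builds its predictors with model.matrix() and drops the intercept column, which is the usual formula behavior described above; caret's internals may differ in detail, and the seed and object names are illustrative.

```r
# Base R sketch (no caret): how the chickwts predictors look under each interface.
data("chickwts", package = "datasets")

set.seed(1)  # illustrative seed; xx is pure noise
dt <- data.frame(chickwts, xx = rnorm(nrow(chickwts)))

# Formula style: model.matrix() expands the factor `feed` into indicator
# columns, one per level except the reference level ("casein").
mm <- model.matrix(weight ~ ., data = dt)
mm <- mm[, colnames(mm) != "(Intercept)", drop = FALSE]  # drop the intercept column
print(colnames(mm))
# feedhorsebean, feedlinseed, feedmeatmeal, feedsoybean, feedsunflower, xx

# x/y style: the predictors stay a data frame, so `feed` remains one factor column.
x <- dt[, setdiff(names(dt), "weight"), drop = FALSE]
print(sapply(x, class))
# feed is "factor", xx is "numeric"
```

The indicator columns match the row names in the formula-based varImp() output above, which is why the importances are split by level there.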

@ciberger

People could still benefit from the flexibility of formulas while keeping factors as factors by using the model.frame function. See the workaround in #803.
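A minimal sketch of the model.frame idea, assuming the workaround in #803 works along these lines (its exact code may differ): the formula is evaluated with model.frame(), which selects columns but does not create dummy variables, and the result is handed to the x/y interface. formula_to_xy() is a hypothetical helper, not a caret or base R function.

```r
# Hypothetical helper (not part of caret or base R): evaluate a formula with
# model.frame(), which keeps factors as factors, and split the result into the
# x and y arguments expected by train()'s non-formula interface.
formula_to_xy <- function(formula, data) {
  mf <- model.frame(formula, data = data)  # response first, then predictors; no dummy coding
  y  <- model.response(mf)                 # the left-hand side of the formula
  x  <- mf[, -1, drop = FALSE]             # the right-hand side columns
  list(x = x, y = y)
}

data("chickwts", package = "datasets")
set.seed(1)
dt <- data.frame(chickwts, xx = rnorm(nrow(chickwts)))

xy <- formula_to_xy(weight ~ feed + xx, data = dt)
str(xy$x)  # feed is still a single factor with 6 levels

# With caret and randomForest installed, this should then report one
# importance value for feed as a whole (not run here):
# rf_fit <- caret::train(x = xy$x, y = xy$y, method = "rf", importance = TRUE)
# caret::varImp(rf_fit$finalModel)
```

Note that model.frame() applies the default na.action (usually na.omit), so rows with missing values are dropped before x and y are split.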

@topepo
Owner

topepo commented Aug 16, 2018

I'm not in favor of adding yet another option to trainControl, and changing the default behavior would break a lot of reverse dependencies.

The recipe interface would solve these issues too.
