Look into creating sparse matrix interface #42
Moving to
After an initial look, it appears that there are two places where changes are required. Before going into those, our first pass at this will constrain that
These constraints are so that the core

tune/workflows changes

First,
forged <- forge_from_workflow(split, workflow)
x_vals <- forged$predictors
y_vals <- forged$outcomes

For sparse matrices, we want To do this, we probably need a user-facing control option so that Also, we'll need to check to see that there are no factor or character predictors before doing so (which is easy with a recipe).

parsnip changes

This may be the more invasive change. Right now, the main input formats in parsnip are "data.frame", "matrix", and "formula". This is specified in the
set_fit(
  model = "logistic_reg",
  eng = "glmnet",
  mode = "classification",
  value = list(
    interface = "matrix",
    protect = c("x", "y", "weights"),
    func = c(pkg = "glmnet", fun = "glmnet"),
    defaults = list(family = "binomial")
  )
)
For sparsity with
as_either_matrix <- function(x) {
  if (is.data.frame(x)) {
    x <- as.matrix(x)
  }
  # leave alone if matrix or sparse matrix
  x
}
instead of

For other models, additional changes might be required. For example,
set_fit(
  model = "rand_forest",
  eng = "ranger",
  mode = "classification",
  value = list(
    interface = "formula",
    protect = c("formula", "data", "case.weights"),
    func = c(pkg = "ranger", fun = "ranger"),
    defaults =
      list(
        num.threads = 1,
        verbose = FALSE,
        seed = expr(sample.int(10 ^ 5, 1))
      )
  )
)
This formula interface is inconsistent with sparse matrices as inputs. An
Although completely undocumented, reading the code shows that For this or other models/engines, we might have to fall back on using a wrapper in case the data frame and sparse matrix input interfaces are discordant or inconsistent.

Models/packages that we will look into for this format are:

Finally, once sparse matrices can be used, some benchmarking is needed to see if there is a big penalty for using small data sets and/or dense data in this format.
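As a minimal sketch of the two ideas in this comment, the snippet below exercises the as_either_matrix() helper on a data frame and on a sparse matrix, and adds a check that no predictor is a factor or character column. The name all_numeric_predictors() is hypothetical and used only for illustration; it is not a parsnip, hardhat, or tune function, and the toy data are made up.

```r
library(Matrix)

# Helper quoted in the comment above: coerce data frames, leave matrices alone.
as_either_matrix <- function(x) {
  if (is.data.frame(x)) {
    x <- as.matrix(x)
  }
  # leave alone if matrix or sparse matrix
  x
}

# Hypothetical check (illustration only): TRUE when no column is a
# factor or character vector, so conversion to a numeric matrix is safe.
all_numeric_predictors <- function(x) {
  !any(vapply(x, function(col) is.factor(col) || is.character(col), logical(1)))
}

dense_df <- data.frame(a = c(0, 1, 0), b = c(2, 0, 0))
sparse_x <- Matrix(as.matrix(dense_df), sparse = TRUE)

# A data frame becomes a base matrix; a sparse matrix keeps its class.
from_df <- as_either_matrix(dense_df)
from_sparse <- as_either_matrix(sparse_x)

class(from_df)
class(from_sparse)

# The check passes for numeric columns and fails once a character column appears.
all_numeric_predictors(dense_df)
all_numeric_predictors(data.frame(a = 1:2, g = c("x", "y")))
```

The point of the pass-through is that a sparse input is never densified by accident; only data frames are coerced.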
Should we add a flag to the model/engine specification that allows for sparse
set_encoding(
  model = "logistic_reg",
  eng = "glmnet",
  mode = "classification",
  options = list(
    predictor_indicators = "traditional",
    compute_intercept = TRUE,
    remove_intercept = TRUE,
    allow_sparse_x = TRUE # <- new option
  )
)
That is probably a good way to catch errors for now (when we are going to let users opt in) and then find the best option (sparse vs. not) later.
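One way such a flag could catch errors is sketched below. This is an assumption about how an opt-in might be enforced, not how parsnip does it: check_sparse_x() is a hypothetical guard written for this example, and its allow_sparse_x argument simply mirrors the option name proposed above.

```r
library(Matrix)

# Hypothetical guard (illustration only, not a parsnip function).
# `allow_sparse_x` mirrors the proposed set_encoding() option.
check_sparse_x <- function(x, allow_sparse_x = FALSE) {
  if (methods::is(x, "sparseMatrix") && !allow_sparse_x) {
    stop("This model/engine does not accept sparse predictors.", call. = FALSE)
  }
  invisible(x)
}

x <- Matrix(c(0, 1, 0, 0, 2, 0), nrow = 3, sparse = TRUE)

# An engine that opted in: the sparse matrix passes through unchanged.
ok <- check_sparse_x(x, allow_sparse_x = TRUE)

# An engine that did not opt in: the error is caught and its message kept.
msg <- tryCatch(
  check_sparse_x(x, allow_sparse_x = FALSE),
  error = function(e) conditionMessage(e)
)
msg
```

Failing early like this keeps a sparse input from silently reaching an engine whose interface cannot handle it.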
I think that the
remotes::install_github("tidymodels/parsnip@sparsity")
Once we finalize things, I'll go to the I looked at
Based on these two PRs, we should be good to go
# See https://github.com/tidymodels/hardhat/pull/150
# remotes::install_github("tidymodels/hardhat@recipe-blueprint-composition")
# See https://github.com/tidymodels/parsnip/pull/373
# remotes::install_github("tidymodels/parsnip@sparsity")
library(tidymodels)
#> ── Attaching packages ─────────────────────────────────────────────────────────────────── tidymodels 0.1.1 ──
#> ✓ broom 0.7.0 ✓ recipes 0.1.13
#> ✓ dials 0.0.9 ✓ rsample 0.0.8
#> ✓ dplyr 1.0.2 ✓ tibble 3.0.3
#> ✓ ggplot2 3.3.2 ✓ tidyr 1.1.2
#> ✓ infer 0.5.2 ✓ tune 0.1.1
#> ✓ modeldata 0.0.2 ✓ workflows 0.2.0
#> ✓ parsnip 0.1.3.9000 ✓ yardstick 0.0.7
#> ✓ purrr 0.3.4
#> ── Conflicts ────────────────────────────────────────────────────────────────────── tidymodels_conflicts() ──
#> x purrr::discard() masks scales::discard()
#> x dplyr::filter() masks stats::filter()
#> x dplyr::lag() masks stats::lag()
#> x recipes::step() masks stats::step()
library(hardhat)
data(ames)
ames <-
  ames %>%
  mutate(Sale_Price = log10(Sale_Price)) %>%
  select(Sale_Price, Longitude, Latitude, Neighborhood)
rec <-
  recipe(Sale_Price ~ ., data = ames) %>%
  step_dummy(Neighborhood) %>%
  step_zv(all_predictors()) %>%
  step_normalize(all_predictors())
sparse_bp <- default_recipe_blueprint(composition = "dgCMatrix")
processed <- mold(rec, ames, blueprint = sparse_bp)
class(forge(ames, blueprint = processed$blueprint)$predictors)
#> [1] "dgCMatrix"
#> attr(,"package")
#> [1] "Matrix"
lm_spec <- linear_reg(penalty = 0.001, mixture = .5) %>% set_engine("glmnet")
lm_wflow <-
  workflow() %>%
  add_recipe(rec, blueprint = sparse_bp) %>%
  add_model(lm_spec)
lm_wflow_1 <- .fit_pre(lm_wflow, data = ames)
class(lm_wflow_1$pre$mold$predictors)
#> [1] "dgCMatrix"
#> attr(,"package")
#> [1] "Matrix"
lm_wflow_2 <- .fit_model(lm_wflow_1, control = control_workflow())
fit_resamples(lm_wflow, vfold_cv(ames))
#> # Resampling results
#> # 10-fold cross-validation
#> # A tibble: 10 x 4
#> splits id .metrics .notes
#> <list> <chr> <list> <list>
#> 1 <split [2.6K/293]> Fold01 <tibble [2 × 3]> <tibble [0 × 1]>
#> 2 <split [2.6K/293]> Fold02 <tibble [2 × 3]> <tibble [0 × 1]>
#> 3 <split [2.6K/293]> Fold03 <tibble [2 × 3]> <tibble [0 × 1]>
#> 4 <split [2.6K/293]> Fold04 <tibble [2 × 3]> <tibble [0 × 1]>
#> 5 <split [2.6K/293]> Fold05 <tibble [2 × 3]> <tibble [0 × 1]>
#> 6 <split [2.6K/293]> Fold06 <tibble [2 × 3]> <tibble [0 × 1]>
#> 7 <split [2.6K/293]> Fold07 <tibble [2 × 3]> <tibble [0 × 1]>
#> 8 <split [2.6K/293]> Fold08 <tibble [2 × 3]> <tibble [0 × 1]>
#> 9 <split [2.6K/293]> Fold09 <tibble [2 × 3]> <tibble [0 × 1]>
#> 10 <split [2.6K/293]> Fold10 <tibble [2 × 3]> <tibble [0 × 1]>

Created on 2020-10-01 by the reprex package (v0.3.0)

Session info
devtools::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#> setting value
#> version R version 4.0.2 (2020-06-22)
#> os macOS Catalina 10.15.5
#> system x86_64, darwin17.0
#> ui X11
#> language (EN)
#> collate en_US.UTF-8
#> ctype en_US.UTF-8
#> tz America/New_York
#> date 2020-10-01
#>
#> ─ Packages ───────────────────────────────────────────────────────────────────
#> package * version date lib source
#> assertthat 0.2.1 2019-03-21 [1] CRAN (R 4.0.0)
#> backports 1.1.10 2020-09-15 [1] CRAN (R 4.0.2)
#> broom * 0.7.0 2020-07-09 [1] CRAN (R 4.0.0)
#> callr 3.4.4 2020-09-07 [1] CRAN (R 4.0.2)
#> class 7.3-17 2020-04-26 [1] CRAN (R 4.0.2)
#> cli 2.0.2 2020-02-28 [1] CRAN (R 4.0.0)
#> codetools 0.2-16 2018-12-24 [1] CRAN (R 4.0.2)
#> colorspace 1.4-1 2019-03-18 [1] CRAN (R 4.0.0)
#> crayon 1.3.4.9000 2020-08-18 [1] Github (r-lib/crayon@6b3f0c6)
#> desc 1.2.0 2018-05-01 [1] CRAN (R 4.0.0)
#> devtools 2.3.1 2020-07-21 [1] CRAN (R 4.0.2)
#> dials * 0.0.9 2020-09-16 [1] CRAN (R 4.0.2)
#> DiceDesign 1.8-1 2019-07-31 [1] CRAN (R 4.0.0)
#> digest 0.6.25 2020-02-23 [1] CRAN (R 4.0.0)
#> dplyr * 1.0.2 2020-08-18 [1] CRAN (R 4.0.0)
#> ellipsis 0.3.1 2020-05-15 [1] CRAN (R 4.0.0)
#> evaluate 0.14 2019-05-28 [1] CRAN (R 4.0.0)
#> fansi 0.4.1 2020-01-08 [1] CRAN (R 4.0.0)
#> foreach 1.5.0 2020-03-30 [1] CRAN (R 4.0.2)
#> fs 1.5.0 2020-07-31 [1] CRAN (R 4.0.2)
#> furrr 0.1.0 2018-05-16 [1] CRAN (R 4.0.0)
#> future 1.19.1 2020-09-22 [1] CRAN (R 4.0.2)
#> generics 0.0.2 2018-11-29 [1] CRAN (R 4.0.0)
#> ggplot2 * 3.3.2 2020-06-19 [1] CRAN (R 4.0.0)
#> glmnet 4.0-2 2020-06-16 [1] CRAN (R 4.0.2)
#> globals 0.13.0 2020-09-17 [1] CRAN (R 4.0.2)
#> glue 1.4.2 2020-08-27 [1] CRAN (R 4.0.2)
#> gower 0.2.2 2020-06-23 [1] CRAN (R 4.0.0)
#> GPfit 1.0-8 2019-02-08 [1] CRAN (R 4.0.0)
#> gtable 0.3.0 2019-03-25 [1] CRAN (R 4.0.0)
#> hardhat * 0.1.4.9000 2020-10-01 [1] Github (tidymodels/hardhat@4329fa9)
#> highr 0.8 2019-03-20 [1] CRAN (R 4.0.0)
#> htmltools 0.5.0 2020-06-16 [1] CRAN (R 4.0.0)
#> infer * 0.5.2 2020-06-14 [1] CRAN (R 4.0.0)
#> ipred 0.9-9 2019-04-28 [1] CRAN (R 4.0.2)
#> iterators 1.0.12 2019-07-26 [1] CRAN (R 4.0.0)
#> knitr 1.30 2020-09-22 [1] CRAN (R 4.0.2)
#> lattice 0.20-41 2020-04-02 [1] CRAN (R 4.0.2)
#> lava 1.6.8 2020-09-26 [1] CRAN (R 4.0.2)
#> lhs 1.1.0 2020-09-29 [1] CRAN (R 4.0.2)
#> lifecycle 0.2.0 2020-03-06 [1] CRAN (R 4.0.0)
#> listenv 0.8.0 2019-12-05 [1] CRAN (R 4.0.0)
#> lubridate 1.7.9 2020-06-08 [1] CRAN (R 4.0.2)
#> magrittr 1.5 2014-11-22 [1] CRAN (R 4.0.0)
#> MASS 7.3-51.6 2020-04-26 [1] CRAN (R 4.0.2)
#> Matrix 1.2-18 2019-11-27 [1] CRAN (R 4.0.2)
#> memoise 1.1.0 2017-04-21 [1] CRAN (R 4.0.0)
#> modeldata * 0.0.2 2020-06-22 [1] CRAN (R 4.0.2)
#> munsell 0.5.0 2018-06-12 [1] CRAN (R 4.0.0)
#> nnet 7.3-14 2020-04-26 [1] CRAN (R 4.0.2)
#> parsnip * 0.1.3.9000 2020-09-30 [1] Github (tidymodels/parsnip@2f2737a)
#> pillar 1.4.6 2020-07-10 [1] CRAN (R 4.0.0)
#> pkgbuild 1.1.0 2020-07-13 [1] CRAN (R 4.0.2)
#> pkgconfig 2.0.3 2019-09-22 [1] CRAN (R 4.0.0)
#> pkgload 1.1.0 2020-05-29 [1] CRAN (R 4.0.0)
#> plyr 1.8.6 2020-03-03 [1] CRAN (R 4.0.2)
#> prettyunits 1.1.1 2020-01-24 [1] CRAN (R 4.0.0)
#> pROC 1.16.2 2020-03-19 [1] CRAN (R 4.0.2)
#> processx 3.4.4 2020-09-03 [1] CRAN (R 4.0.2)
#> prodlim 2019.11.13 2019-11-17 [1] CRAN (R 4.0.0)
#> ps 1.3.4 2020-08-11 [1] CRAN (R 4.0.2)
#> purrr * 0.3.4 2020-04-17 [1] CRAN (R 4.0.0)
#> R6 2.4.1 2019-11-12 [1] CRAN (R 4.0.0)
#> Rcpp 1.0.5 2020-07-06 [1] CRAN (R 4.0.0)
#> recipes * 0.1.13 2020-06-23 [1] CRAN (R 4.0.2)
#> remotes 2.2.0 2020-07-21 [1] CRAN (R 4.0.2)
#> rlang 0.4.7 2020-07-09 [1] CRAN (R 4.0.0)
#> rmarkdown 2.3 2020-06-18 [1] CRAN (R 4.0.2)
#> rpart 4.1-15 2019-04-12 [1] CRAN (R 4.0.2)
#> rprojroot 1.3-2 2018-01-03 [1] CRAN (R 4.0.0)
#> rsample * 0.0.8 2020-09-23 [1] CRAN (R 4.0.2)
#> rstudioapi 0.11 2020-02-07 [1] CRAN (R 4.0.0)
#> scales * 1.1.1 2020-05-11 [1] CRAN (R 4.0.2)
#> sessioninfo 1.1.1 2018-11-05 [1] CRAN (R 4.0.2)
#> shape 1.4.5 2020-09-13 [1] CRAN (R 4.0.2)
#> stringi 1.5.3 2020-09-09 [1] CRAN (R 4.0.2)
#> stringr 1.4.0 2019-02-10 [1] CRAN (R 4.0.0)
#> survival 3.1-12 2020-04-10 [1] CRAN (R 4.0.2)
#> testthat 2.3.2 2020-03-02 [1] CRAN (R 4.0.2)
#> tibble * 3.0.3 2020-07-10 [1] CRAN (R 4.0.0)
#> tidymodels * 0.1.1 2020-07-14 [1] CRAN (R 4.0.0)
#> tidyr * 1.1.2 2020-08-27 [1] CRAN (R 4.0.2)
#> tidyselect 1.1.0 2020-05-11 [1] CRAN (R 4.0.0)
#> timeDate 3043.102 2018-02-21 [1] CRAN (R 4.0.0)
#> tune * 0.1.1 2020-07-08 [1] CRAN (R 4.0.2)
#> usethis 1.9.0.9000 2020-09-30 [1] Github (r-lib/usethis@b993e83)
#> utf8 1.1.4 2018-05-24 [1] CRAN (R 4.0.0)
#> vctrs 0.3.4 2020-08-29 [1] CRAN (R 4.0.2)
#> withr 2.3.0 2020-09-22 [1] CRAN (R 4.0.2)
#> workflows * 0.2.0 2020-09-15 [1] CRAN (R 4.0.2)
#> xfun 0.17 2020-09-09 [1] CRAN (R 4.0.2)
#> yaml 2.2.1 2020-02-01 [1] CRAN (R 4.0.0)
#> yardstick * 0.0.7 2020-07-13 [1] CRAN (R 4.0.2)
#>
#> [1] /Library/Frameworks/R.framework/Versions/4.0/Resources/library
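The reprex above confirms that the predictors reach the model as a dgCMatrix. As a rough illustration of why that composition matters for indicator columns, the sketch below builds dummy variables for a many-level factor in both dense and sparse form and compares their storage. The data are synthetic (a stand-in for a column like Neighborhood, not the ames data), and it uses stats::model.matrix() and Matrix::sparse.model.matrix() rather than the recipe/blueprint path shown in the reprex.

```r
library(Matrix)

set.seed(42)
# Synthetic factor with 30 levels; an assumption made for this example only.
dat <- data.frame(
  group = factor(sample(paste0("g", 1:30), 3000, replace = TRUE))
)

# Same indicator encoding, stored densely and sparsely.
dense_dummies <- model.matrix(~ group, data = dat)
sparse_dummies <- sparse.model.matrix(~ group, data = dat)

dim(dense_dummies)
dim(sparse_dummies)

# Share of zero cells in the dense encoding.
mean(dense_dummies == 0)

# Bytes used by each representation.
object.size(dense_dummies)
object.size(sparse_dummies)
```

Most cells in an indicator encoding are zero, which is the case the sparse format is designed for.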
This issue has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.
Can the issue of passing a sparse format to models, laid out in tidymodels/hardhat#100, be solved (in the short term or perhaps forever) by creating a sparse_matrix interface for models, similar to the matrix interface?

Would it hurt performance with models like glmnet for not-actually-sparse data to be passed in via sparse format? If so, can that be worked around?
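A rough way to probe the performance question is sketched below, under stated assumptions: the data are simulated with no zeros at all (the "not-actually-sparse" case), the sizes n and p are arbitrary, and the timing step runs only if glmnet is installed. It is a starting point for benchmarking, not a result.

```r
library(Matrix)

set.seed(1)
n <- 500
p <- 20
x_dense <- matrix(rnorm(n * p), nrow = n, ncol = p)  # no zero entries
y <- rnorm(n)

# Same values, forced into compressed sparse column storage.
x_sparse <- Matrix(x_dense, sparse = TRUE)

# Storage cost of each representation, in bytes.
size_dense <- object.size(x_dense)
size_sparse <- object.size(x_sparse)
size_dense
size_sparse

# Fit-time comparison, attempted only when glmnet is available.
if (requireNamespace("glmnet", quietly = TRUE)) {
  time_dense <- system.time(glmnet::glmnet(x_dense, y))
  time_sparse <- system.time(glmnet::glmnet(x_sparse, y))
  print(c(dense = time_dense[["elapsed"]], sparse = time_sparse[["elapsed"]]))
}
```

For fully dense values the sparse format stores a row index alongside every entry, so it takes more memory than the base matrix; whether fitting time suffers as well is what the timing step is meant to show on real data sizes.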