Add sparse matrix support for LiblineaR engine #447
Conversation
The way that the LiblineaR algorithm works means that we don't necessarily get performance gains from using sparse data:

``` r
library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#> method from
#> required_pkgs.model_spec parsnip
library(textrecipes)
data("small_fine_foods")
sparse_bp <- hardhat::default_recipe_blueprint(composition = "dgCMatrix")
text_rec <-
recipe(score ~ review, data = training_data) %>%
step_tokenize(review) %>%
step_stopwords(review) %>%
step_tokenfilter(review, max_tokens = 1e3) %>%
step_tfidf(review)
liblinear_spec <- logistic_reg(penalty = 0.02, mixture = 1) %>%
set_engine("LiblineaR") %>%
set_mode("classification")
glmnet_spec <- logistic_reg(penalty = 0.02, mixture = 1) %>%
set_engine("glmnet") %>%
set_mode("classification")
liblinear_sparse <-
workflow() %>%
add_recipe(text_rec, blueprint = sparse_bp) %>%
add_model(liblinear_spec)
liblinear_default <-
workflow() %>%
add_recipe(text_rec) %>%
add_model(liblinear_spec)
glmnet_sparse <-
workflow() %>%
add_recipe(text_rec, blueprint = sparse_bp) %>%
add_model(glmnet_spec)
glmnet_default <-
workflow() %>%
add_recipe(text_rec) %>%
add_model(glmnet_spec)
set.seed(123)
food_folds <- vfold_cv(training_data, v = 3)
library(bench)
results <- mark(
iterations = 10, check = FALSE,
liblinear_sparse = fit_resamples(liblinear_sparse, food_folds),
liblinear_default = fit_resamples(liblinear_default, food_folds),
glmnet_sparse = fit_resamples(glmnet_sparse, food_folds),
glmnet_default = fit_resamples(glmnet_default, food_folds),
)
#> Warning: Some expressions had a GC in every iteration; so filtering is disabled.
results
#> # A tibble: 4 x 6
#> expression min median `itr/sec` mem_alloc `gc/sec`
#> <bch:expr> <bch:tm> <bch:tm> <dbl> <bch:byt> <dbl>
#> 1 liblinear_sparse 2.7s 2.73s 0.362 867MB 1.20
#> 2 liblinear_default 2.71s 2.73s 0.366 804MB 1.10
#> 3 glmnet_sparse 7.67s 7.71s 0.130 794MB 0.389
#> 4 glmnet_default 1.18m 1.21m 0.0138 933MB 0.0828
autoplot(results, type = "ridge")
#> Picking joint bandwidth of 0.00186
```

Created on 2021-03-18 by the reprex package (v1.0.0)

But we still want to set up the possibility, in case we add the ability to have sparse representations inside of recipes later, or someone brings a sparse matrix straight to parsnip, etc.
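For illustration, here is a minimal sketch of what "bringing a sparse matrix straight to parsnip" could look like via `fit_xy()` once an engine allows sparse input; the simulated `dgCMatrix` and the object names are only for the example, not part of this PR:

``` r
library(tidymodels)
library(Matrix)

# Simulate a sparse predictor matrix (most entries are zero)
set.seed(123)
x <- rsparsematrix(nrow = 500, ncol = 50, density = 0.05)
colnames(x) <- paste0("feature_", seq_len(ncol(x)))
y <- factor(sample(c("great", "other"), size = 500, replace = TRUE))

# Same kind of spec as above; LiblineaR reports allow_sparse_x = TRUE
liblinear_spec <- logistic_reg(penalty = 0.02, mixture = 1) %>%
  set_engine("LiblineaR") %>%
  set_mode("classification")

# fit_xy() can take the dgCMatrix directly, so the predictors never have
# to be densified into a plain data frame on the way to the engine
liblinear_fit <- fit_xy(liblinear_spec, x = x, y = y)
```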
Is there a place in the parsnip documentation that tells you if an engine accepts sparse matrices?
This is what we landed on for now: each engine's encoding metadata records whether it accepts a sparse predictor matrix, and you can call `get_encoding()` to look at the `allow_sparse_x` column:

``` r
library(parsnip)
get_encoding("linear_reg")
#> # A tibble: 5 x 7
#> model engine mode predictor_indicato… compute_interce… remove_intercept allow_sparse_x
#> <chr> <chr> <chr> <chr> <lgl> <lgl> <lgl>
#> 1 linear_… lm regress… traditional TRUE TRUE FALSE
#> 2 linear_… glmnet regress… traditional TRUE TRUE TRUE
#> 3 linear_… stan regress… traditional TRUE TRUE FALSE
#> 4 linear_… spark regress… traditional TRUE TRUE FALSE
#> 5 linear_… keras regress… traditional TRUE TRUE FALSE
get_encoding("svm_linear")
#> # A tibble: 4 x 7
#> model engine mode predictor_indica… compute_interce… remove_intercept allow_sparse_x
#> <chr> <chr> <chr> <chr> <lgl> <lgl> <lgl>
#> 1 svm_lin… Libline… regress… none FALSE FALSE TRUE
#> 2 svm_lin… Libline… classif… none FALSE FALSE TRUE
#> 3 svm_lin… kernlab regress… none FALSE FALSE FALSE
#> 4 svm_lin… kernlab classif… none FALSE FALSE FALSE
```
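If you want a programmatic answer rather than scanning that table, one small convenience (just `dplyr` on top of `get_encoding()`, not a dedicated parsnip helper) is to filter on that column:

``` r
library(parsnip)
library(dplyr)

# Keep only the engine/mode combinations that accept a sparse predictor matrix
get_encoding("logistic_reg") %>%
  filter(allow_sparse_x) %>%
  select(model, engine, mode)
```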
This pull request has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.
This PR adds sparse matrix support for the LiblineaR engine models. Closes #434.
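For anyone doing the same in an extension package, the capability lives in the engine's encoding metadata. A rough sketch of the registration using parsnip's developer tools on a hypothetical model name (everything called `my_sparse_model` here is a placeholder, not part of this PR):

``` r
library(parsnip)

# Register a hypothetical model/engine/mode...
set_new_model("my_sparse_model")
set_model_mode("my_sparse_model", mode = "classification")
set_model_engine("my_sparse_model", mode = "classification", eng = "LiblineaR")
set_dependency("my_sparse_model", eng = "LiblineaR", pkg = "LiblineaR")

# ...and declare its encodings, including that it accepts a dgCMatrix;
# allow_sparse_x is the flag that get_encoding() reports above
set_encoding(
  model = "my_sparse_model",
  mode = "classification",
  eng = "LiblineaR",
  options = list(
    predictor_indicators = "none",
    compute_intercept = FALSE,
    remove_intercept = FALSE,
    allow_sparse_x = TRUE
  )
)
```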