Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ Suggests:
tailor (>= 0.0.0.9001),
covr,
dials (>= 1.0.0),
glmnet,
knitr,
magrittr,
Matrix,
Expand All @@ -54,6 +55,7 @@ Config/Needs/website:
yardstick
Remotes:
tidymodels/rsample,
tidymodels/recipes,
tidymodels/parsnip,
tidymodels/tailor,
r-lib/sparsevctrs
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@

* New `extract_fit_time()` method has been added that return the time it took to train the workflow (#191).

* `fit()` can now take dgCMatrix and sparse tibbles as data values when `add_recipe()` or `add_variables()` is used (#245, #258).

* `predict()` can now take dgCMatrix and sparse tibble input for `new_data` argument (#261).

# workflows 1.1.4

* While `augment.workflow()` previously never returned a `.resid` column, the
Expand Down
4 changes: 4 additions & 0 deletions R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ predict.workflow <- function(object, new_data, type = NULL, opts = list(), ...)
))
}

if (is_sparse_matrix(new_data)) {
new_data <- sparsevctrs::coerce_to_sparse_tibble(new_data)
}

fit <- extract_fit_parsnip(workflow)
new_data <- forge_predictors(new_data, workflow)

Expand Down
49 changes: 49 additions & 0 deletions tests/testthat/test-sparsevctrs.R
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,52 @@ test_that("sparse matrices can be passed to `fit() - xy", {
# We expect 1 materialization - the outcome
expect_snapshot(wf_fit <- fit(wf_spec, hotel_data))
})

test_that("sparse tibble can be passed to `predict()`", {
skip_if_not_installed("glmnet")
# Make materialization of sparse vectors throw an error
# https://r-lib.github.io/sparsevctrs/dev/reference/sparsevctrs_options.html
withr::local_options("sparsevctrs.verbose_materialize" = 3)

hotel_data <- sparse_hotel_rates(tibble = TRUE)

spec <- parsnip::linear_reg(penalty = 0) %>%
parsnip::set_mode("regression") %>%
parsnip::set_engine("glmnet")

rec <- recipes::recipe(avg_price_per_room ~ ., data = hotel_data)

wf_spec <- workflow() %>%
add_recipe(rec) %>%
add_model(spec)

wf_fit <- fit(wf_spec, hotel_data)

expect_no_error(predict(wf_fit, hotel_data))
})

test_that("sparse matrix can be passed to `predict()`", {
skip_if_not_installed("glmnet")
# Make materialization of sparse vectors throw a warning
# https://r-lib.github.io/sparsevctrs/dev/reference/sparsevctrs_options.html
withr::local_options("sparsevctrs.verbose_materialize" = 2)

hotel_data <- sparse_hotel_rates()

spec <- parsnip::linear_reg(penalty = 0) %>%
parsnip::set_mode("regression") %>%
parsnip::set_engine("glmnet")

rec <- recipes::recipe(avg_price_per_room ~ ., data = hotel_data)

wf_spec <- workflow() %>%
add_recipe(rec) %>%
add_model(spec)

# We know that this will cause 1 warning due to the outcome
suppressWarnings(
wf_fit <- fit(wf_spec, hotel_data)
)

expect_no_warning(predict(wf_fit, hotel_data))
})