Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ Imports:
generics,
glue,
hardhat (>= 0.1.4),
parsnip (>= 0.1.2),
parsnip (>= 0.1.3),
rlang (>= 0.4.1)
Suggests:
covr,
Expand Down
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# workflows (development version)

* A test has been updated to reflect a change in parsnip 0.1.3 regarding how
intercept columns are removed during prediction (#65).

# workflows 0.1.2

* When using a formula preprocessor with `add_formula()`, workflows now uses
Expand Down
52 changes: 42 additions & 10 deletions tests/testthat/test-predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -95,25 +95,57 @@ test_that("`new_data` must have all of the original predictors", {
})

test_that("blueprint will get passed on to hardhat::forge()", {
train <- data.frame(
y = c(1L, 5L, 3L, 4L),
x = factor(c("x", "y", "x", "y"))
)

test <- data.frame(
x = factor(c("x", "y", "z"))
)

spec <- parsnip::linear_reg()
spec <- parsnip::set_engine(spec, "lm")

bp1 <- hardhat::default_formula_blueprint(intercept = TRUE, allow_novel_levels = FALSE)
bp2 <- hardhat::default_formula_blueprint(intercept = TRUE, allow_novel_levels = TRUE)

workflow <- workflow()
workflow <- add_model(workflow, spec)

workflow1 <- add_formula(workflow, y ~ x, blueprint = bp1)
workflow2 <- add_formula(workflow, y ~ x, blueprint = bp2)

mod1 <- fit(workflow1, train)
mod2 <- fit(workflow2, train)

expect_warning(pred1 <- predict(mod1, test))
expect_warning(pred2 <- predict(mod2, test), NA)

expect_identical(
pred1[[".pred"]],
c(2, 4.5, NA)
)

expect_identical(
pred2[[".pred"]],
c(2, 4.5, 2)
)
})

test_that("monitoring: known that parsnip removes blueprint intercept for some models (tidymodels/parsnip#353)", {
mod <- parsnip::linear_reg()
mod <- parsnip::set_engine(mod, "lm")

# Pass formula explicitly to keep `lm()` from auto-generating an intercept
workflow <- workflow()
workflow <- add_model(workflow, mod, formula = mpg ~ . + 0)

blueprint_no_intercept <- hardhat::default_formula_blueprint(intercept = FALSE)
workflow_no_intercept <- add_formula(workflow, mpg ~ hp + disp, blueprint = blueprint_no_intercept)
fit_no_intercept <- fit(workflow_no_intercept, mtcars)
prediction_no_intercept <- predict(fit_no_intercept, mtcars)

blueprint_with_intercept <- hardhat::default_formula_blueprint(intercept = TRUE)
workflow_with_intercept <- add_formula(workflow, mpg ~ hp + disp, blueprint = blueprint_with_intercept)
fit_with_intercept <- fit(workflow_with_intercept, mtcars)
prediction_with_intercept <- predict(fit_with_intercept, mtcars)

expect_false(fit_no_intercept$pre$mold$blueprint$intercept)
expect_true(fit_with_intercept$pre$mold$blueprint$intercept)

expect_false(identical(prediction_with_intercept, prediction_no_intercept))
# `parsnip:::prepare_data()` will remove the intercept, so it won't be
# there when the `lm()` `predict()` method is called.
expect_error(predict(fit_with_intercept, mtcars))
})