Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ Imports:
tidyr (>= 1.0.0),
globals,
prettyunits,
vctrs (>= 0.2.0)
vctrs (>= 0.2.0),
hardhat (>= 0.1.5.9000)
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.1.1.9001
Suggests:
Expand Down Expand Up @@ -59,4 +60,6 @@ Suggests:
dials (>= 0.0.9.9000)
Remotes:
tidymodels/dials,
topepo/C5.0
topepo/C5.0,
tidymodels/hardhat

6 changes: 6 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Generated by roxygen2: do not edit by hand

S3method(augment,model_fit)
S3method(extract_fit_engine,model_fit)
S3method(extract_spec_parsnip,model_fit)
S3method(fit,model_spec)
S3method(fit_xy,gen_additive_mod)
S3method(fit_xy,model_spec)
Expand Down Expand Up @@ -138,6 +140,8 @@ export(control_parsnip)
export(convert_stan_interval)
export(decision_tree)
export(eval_args)
export(extract_fit_engine)
export(extract_spec_parsnip)
export(find_engine_files)
export(fit)
export(fit.model_spec)
Expand Down Expand Up @@ -259,6 +263,8 @@ importFrom(generics,required_pkgs)
importFrom(generics,tidy)
importFrom(generics,varying_args)
importFrom(glue,glue_collapse)
importFrom(hardhat,extract_fit_engine)
importFrom(hardhat,extract_spec_parsnip)
importFrom(magrittr,"%>%")
importFrom(purrr,as_vector)
importFrom(purrr,imap)
Expand Down
67 changes: 67 additions & 0 deletions R/extract.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#' Extract elements of a parsnip model object
#'
#' @description
#' These functions extract various elements from a parsnip object. If they do
#' not exist yet, an error is thrown.
#'
#' - `extract_spec_parsnip()` returns the parsnip model specification.
#'
#' - `extract_fit_engine()` returns the engine specific fit embedded within
#' a parsnip model fit. For example, when using [parsnip::linear_reg()]
#' with the `"lm"` engine, this returns the underlying `lm` object.
#'
#' @param x A parsnip `model_fit` object.
#' @param ... Not currently used.
#' @details
#' Extracting the underlying engine fit can be helpful for describing the
#' model (via `print()`, `summary()`, `plot()`, etc.) or for variable
#' importance/explainers.
#'
#' However, users should not invoke the `predict()` method on an extracted
#' model. There may be preprocessing operations that `parsnip` has executed on
#' the data prior to giving it to the model. Bypassing these can lead to errors
#' or silently generating incorrect predictions.
#'
#' **Good**:
#' ```r
#' parsnip_fit %>% predict(new_data)
#' ```
#'
#' **Bad**:
#' ```r
#' parsnip_fit %>% extract_fit_engine() %>% predict(new_data)
#' ```
#' @return
#' The extracted value from the parsnip object, `x`, as described in the description
#' section.
#'
#' @name extract-parsnip
#' @examples
#' lm_spec <- linear_reg() %>% set_engine("lm")
#' lm_fit <- fit(lm_spec, mpg ~ ., data = mtcars)
#'
#' lm_spec
#' extract_spec_parsnip(lm_fit)
#'
#' extract_fit_engine(lm_fit)
#' lm(mpg ~ ., data = mtcars)
NULL

#' @export
#' @rdname extract-parsnip
extract_spec_parsnip.model_fit <- function(x, ...) {
if (any(names(x) == "spec")) {
return(x$spec)
}
rlang::abort("Internal error: The model fit does not have a model spec.")
}


#' @export
#' @rdname extract-parsnip
extract_fit_engine.model_fit <- function(x, ...) {
if (any(names(x) == "fit")) {
return(x$fit)
}
rlang::abort("Internal error: The model fit does not have an engine fit.")
}
8 changes: 8 additions & 0 deletions R/reexports.R
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,11 @@ generics::augment
#' @importFrom generics required_pkgs
#' @export
generics::required_pkgs

#' @importFrom hardhat extract_spec_parsnip
#' @export
hardhat::extract_spec_parsnip

#' @importFrom hardhat extract_fit_engine
#' @export
hardhat::extract_fit_engine
57 changes: 57 additions & 0 deletions man/extract-parsnip.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions man/reexports.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 19 additions & 0 deletions tests/testthat/test-extract.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@

context("model extraction")

# ------------------------------------------------------------------------------

test_that('extract', {
x <- linear_reg() %>% set_engine("lm") %>% fit(mpg ~ ., data = mtcars)
x_no_spec <- x
x_no_spec$spec <- NULL
x_no_fit <- x
x_no_fit$fit <- NULL

expect_true(inherits(extract_spec_parsnip(x), "model_spec"))
expect_true(inherits(extract_fit_engine(x), "lm"))

expect_error(extract_spec_parsnip(x_no_spec), "Internal error")
expect_error(extract_fit_engine(x_no_fit), "Internal error")
})