diff --git a/DESCRIPTION b/DESCRIPTION index b8a008f26..87fc49044 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -31,7 +31,8 @@ Imports: tidyr (>= 1.0.0), globals, prettyunits, - vctrs (>= 0.2.0) + vctrs (>= 0.2.0), + hardhat (>= 0.1.5.9000) Roxygen: list(markdown = TRUE) RoxygenNote: 7.1.1.9001 Suggests: @@ -59,4 +60,6 @@ Suggests: dials (>= 0.0.9.9000) Remotes: tidymodels/dials, - topepo/C5.0 + topepo/C5.0, + tidymodels/hardhat + diff --git a/NAMESPACE b/NAMESPACE index d3b33bb6a..962fc8df3 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -1,6 +1,8 @@ # Generated by roxygen2: do not edit by hand S3method(augment,model_fit) +S3method(extract_fit_engine,model_fit) +S3method(extract_spec_parsnip,model_fit) S3method(fit,model_spec) S3method(fit_xy,gen_additive_mod) S3method(fit_xy,model_spec) @@ -138,6 +140,8 @@ export(control_parsnip) export(convert_stan_interval) export(decision_tree) export(eval_args) +export(extract_fit_engine) +export(extract_spec_parsnip) export(find_engine_files) export(fit) export(fit.model_spec) @@ -259,6 +263,8 @@ importFrom(generics,required_pkgs) importFrom(generics,tidy) importFrom(generics,varying_args) importFrom(glue,glue_collapse) +importFrom(hardhat,extract_fit_engine) +importFrom(hardhat,extract_spec_parsnip) importFrom(magrittr,"%>%") importFrom(purrr,as_vector) importFrom(purrr,imap) diff --git a/R/extract.R b/R/extract.R new file mode 100644 index 000000000..c26c47fb0 --- /dev/null +++ b/R/extract.R @@ -0,0 +1,67 @@ +#' Extract elements of a parsnip model object +#' +#' @description +#' These functions extract various elements from a parsnip object. If they do +#' not exist yet, an error is thrown. +#' +#' - `extract_spec_parsnip()` returns the parsnip model specification. +#' +#' - `extract_fit_engine()` returns the engine specific fit embedded within +#' a parsnip model fit. For example, when using [parsnip::linear_reg()] +#' with the `"lm"` engine, this returns the underlying `lm` object. 
+#' +#' @param x A parsnip `model_fit` object. +#' @param ... Not currently used. +#' @details +#' Extracting the underlying engine fit can be helpful for describing the +#' model (via `print()`, `summary()`, `plot()`, etc.) or for variable +#' importance/explainers. +#' +#' However, users should not invoke the `predict()` method on an extracted +#' model. There may be preprocessing operations that `parsnip` has executed on +#' the data prior to giving it to the model. Bypassing these can lead to errors +#' or silently generating incorrect predictions. +#' +#' **Good**: +#' ```r +#' parsnip_fit %>% predict(new_data) +#' ``` +#' +#' **Bad**: +#' ```r +#' parsnip_fit %>% extract_fit_engine() %>% predict(new_data) +#' ``` +#' @return +#' The extracted value from the parsnip object, `x`, as described in the description +#' section. +#' +#' @name extract-parsnip +#' @examples +#' lm_spec <- linear_reg() %>% set_engine("lm") +#' lm_fit <- fit(lm_spec, mpg ~ ., data = mtcars) +#' +#' lm_spec +#' extract_spec_parsnip(lm_fit) +#' +#' extract_fit_engine(lm_fit) +#' lm(mpg ~ ., data = mtcars) +NULL + +#' @export +#' @rdname extract-parsnip +extract_spec_parsnip.model_fit <- function(x, ...) { + if (any(names(x) == "spec")) { + return(x$spec) + } + rlang::abort("Internal error: The model fit does not have a model spec.") +} + + +#' @export +#' @rdname extract-parsnip +extract_fit_engine.model_fit <- function(x, ...) 
{ + if (any(names(x) == "fit")) { + return(x$fit) + } + rlang::abort("Internal error: The model fit does not have an engine fit.") +} diff --git a/R/reexports.R b/R/reexports.R index c6fa3d9f3..65a0025ef 100644 --- a/R/reexports.R +++ b/R/reexports.R @@ -26,3 +26,11 @@ generics::augment #' @importFrom generics required_pkgs #' @export generics::required_pkgs + +#' @importFrom hardhat extract_spec_parsnip +#' @export +hardhat::extract_spec_parsnip + +#' @importFrom hardhat extract_fit_engine +#' @export +hardhat::extract_fit_engine diff --git a/man/extract-parsnip.Rd b/man/extract-parsnip.Rd new file mode 100644 index 000000000..f1e30a0ea --- /dev/null +++ b/man/extract-parsnip.Rd @@ -0,0 +1,57 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/extract.R +\name{extract-parsnip} +\alias{extract-parsnip} +\alias{extract_spec_parsnip.model_fit} +\alias{extract_fit_engine.model_fit} +\title{Extract elements of a parsnip model object} +\usage{ +\method{extract_spec_parsnip}{model_fit}(x, ...) + +\method{extract_fit_engine}{model_fit}(x, ...) +} +\arguments{ +\item{x}{A parsnip \code{model_fit} object.} + +\item{...}{Not currently used.} +} +\value{ +The extracted value from the parsnip object, \code{x}, as described in the description +section. +} +\description{ +These functions extract various elements from a parsnip object. If they do +not exist yet, an error is thrown. +\itemize{ +\item \code{extract_spec_parsnip()} returns the parsnip model specification. +\item \code{extract_fit_engine()} returns the engine specific fit embedded within +a parsnip model fit. For example, when using \code{\link[=linear_reg]{linear_reg()}} +with the \code{"lm"} engine, this would return the underlying \code{lm} object. +} +} +\details{ +Extracting the underlying engine fit can be helpful for describing the +model (via \code{print()}, \code{summary()}, \code{plot()}, etc.) or for variable +importance/explainers. 
+ +However, users should not invoke the \code{predict()} method on an extracted +model. There may be preprocessing operations that \code{parsnip} has executed on +the data prior to giving it to the model. Bypassing these can lead to errors +or silently generating incorrect predictions. + +\strong{Good}:\if{html}{\out{<div class="sourceCode r">}}\preformatted{ parsnip_fit \%>\% predict(new_data) +}\if{html}{\out{</div>}} + +\strong{Bad}:\if{html}{\out{<div class="sourceCode r">}}\preformatted{ parsnip_fit \%>\% extract_fit_engine() \%>\% predict(new_data) +}\if{html}{\out{</div>}} +} +\examples{ +lm_spec <- linear_reg() \%>\% set_engine("lm") +lm_fit <- fit(lm_spec, mpg ~ ., data = mtcars) + +lm_spec +extract_spec_parsnip(lm_fit) + +extract_fit_engine(lm_fit) +lm(mpg ~ ., data = mtcars) +} diff --git a/man/reexports.Rd b/man/reexports.Rd index 614b0b3b9..fccef00df 100644 --- a/man/reexports.Rd +++ b/man/reexports.Rd @@ -10,6 +10,8 @@ \alias{glance} \alias{augment} \alias{required_pkgs} +\alias{extract_spec_parsnip} +\alias{extract_fit_engine} \alias{varying_args} \title{Objects exported from other packages} \keyword{internal} @@ -20,6 +22,8 @@ below to see their documentation. \describe{ \item{generics}{\code{\link[generics]{augment}}, \code{\link[generics]{fit}}, \code{\link[generics]{fit_xy}}, \code{\link[generics]{glance}}, \code{\link[generics]{required_pkgs}}, \code{\link[generics]{tidy}}, \code{\link[generics]{varying_args}}} + \item{hardhat}{\code{\link[hardhat:hardhat-extract]{extract_fit_engine}}, \code{\link[hardhat:hardhat-extract]{extract_spec_parsnip}}} + \item{magrittr}{\code{\link[magrittr:pipe]{\%>\%}}} }} diff --git a/tests/testthat/test-extract.R b/tests/testthat/test-extract.R new file mode 100644 index 000000000..9a03b1f70 --- /dev/null +++ b/tests/testthat/test-extract.R @@ -0,0 +1,19 @@ + +context("model extraction") + +# ------------------------------------------------------------------------------ + +test_that('extract', { + x <- linear_reg() %>% set_engine("lm") %>% fit(mpg ~ ., data = mtcars) + x_no_spec <- x + x_no_spec$spec <- NULL + x_no_fit <- x + x_no_fit$fit <- NULL + + expect_true(inherits(extract_spec_parsnip(x), "model_spec")) + expect_true(inherits(extract_fit_engine(x), "lm")) + + expect_error(extract_spec_parsnip(x_no_spec), "Internal error") + expect_error(extract_fit_engine(x_no_fit), "Internal error") +}) +