diff --git a/DESCRIPTION b/DESCRIPTION index b8a008f26..87fc49044 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -31,7 +31,8 @@ Imports: tidyr (>= 1.0.0), globals, prettyunits, - vctrs (>= 0.2.0) + vctrs (>= 0.2.0), + hardhat (>= 0.1.5.9000) Roxygen: list(markdown = TRUE) RoxygenNote: 7.1.1.9001 Suggests: @@ -59,4 +60,6 @@ Suggests: dials (>= 0.0.9.9000) Remotes: tidymodels/dials, - topepo/C5.0 + topepo/C5.0, + tidymodels/hardhat + diff --git a/NAMESPACE b/NAMESPACE index d3b33bb6a..962fc8df3 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -1,6 +1,8 @@ # Generated by roxygen2: do not edit by hand S3method(augment,model_fit) +S3method(extract_fit_engine,model_fit) +S3method(extract_spec_parsnip,model_fit) S3method(fit,model_spec) S3method(fit_xy,gen_additive_mod) S3method(fit_xy,model_spec) @@ -138,6 +140,8 @@ export(control_parsnip) export(convert_stan_interval) export(decision_tree) export(eval_args) +export(extract_fit_engine) +export(extract_spec_parsnip) export(find_engine_files) export(fit) export(fit.model_spec) @@ -259,6 +263,8 @@ importFrom(generics,required_pkgs) importFrom(generics,tidy) importFrom(generics,varying_args) importFrom(glue,glue_collapse) +importFrom(hardhat,extract_fit_engine) +importFrom(hardhat,extract_spec_parsnip) importFrom(magrittr,"%>%") importFrom(purrr,as_vector) importFrom(purrr,imap) diff --git a/R/extract.R b/R/extract.R new file mode 100644 index 000000000..c26c47fb0 --- /dev/null +++ b/R/extract.R @@ -0,0 +1,67 @@ +#' Extract elements of a parsnip model object +#' +#' @description +#' These functions extract various elements from a parsnip object. If they do +#' not exist yet, an error is thrown. +#' +#' - `extract_spec_parsnip()` returns the parsnip model specification. +#' +#' - `extract_fit_engine()` returns the engine specific fit embedded within +#' a parsnip model fit. For example, when using [parsnip::linear_reg()] +#' with the `"lm"` engine, this returns the underlying `lm` object. +#' +#' @param x A parsnip `model_fit` object. +#' @param ... Not currently used. +#' @details +#' Extracting the underlying engine fit can be helpful for describing the +#' model (via `print()`, `summary()`, `plot()`, etc.) or for variable +#' importance/explainers. +#' +#' However, users should not invoke the `predict()` method on an extracted +#' model. There may be preprocessing operations that `parsnip` has executed on +#' the data prior to giving it to the model. Bypassing these can lead to errors +#' or silently generating incorrect predictions. +#' +#' **Good**: +#' ```r +#' parsnip_fit %>% predict(new_data) +#' ``` +#' +#' **Bad**: +#' ```r +#' parsnip_fit %>% extract_fit_engine() %>% predict(new_data) +#' ``` +#' @return +#' The extracted value from the parsnip object, `x`, as described in the description +#' section. +#' +#' @name extract-parsnip +#' @examples +#' lm_spec <- linear_reg() %>% set_engine("lm") +#' lm_fit <- fit(lm_spec, mpg ~ ., data = mtcars) +#' +#' lm_spec +#' extract_spec_parsnip(lm_fit) +#' +#' extract_fit_engine(lm_fit) +#' lm(mpg ~ ., data = mtcars) +NULL + +#' @export +#' @rdname extract-parsnip +extract_spec_parsnip.model_fit <- function(x, ...) { + if (any(names(x) == "spec")) { + return(x$spec) + } + rlang::abort("Internal error: The model fit does not have a model spec.") +} + + +#' @export +#' @rdname extract-parsnip +extract_fit_engine.model_fit <- function(x, ...) { + if (any(names(x) == "fit")) { + return(x$fit) + } + rlang::abort("Internal error: The model fit does not have an engine fit.") +} diff --git a/R/reexports.R b/R/reexports.R index c6fa3d9f3..65a0025ef 100644 --- a/R/reexports.R +++ b/R/reexports.R @@ -26,3 +26,11 @@ generics::augment #' @importFrom generics required_pkgs #' @export generics::required_pkgs + +#' @importFrom hardhat extract_spec_parsnip +#' @export +hardhat::extract_spec_parsnip + +#' @importFrom hardhat extract_fit_engine +#' @export +hardhat::extract_fit_engine diff --git a/man/extract-parsnip.Rd b/man/extract-parsnip.Rd new file mode 100644 index 000000000..f1e30a0ea --- /dev/null +++ b/man/extract-parsnip.Rd @@ -0,0 +1,57 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/extract.R +\name{extract-parsnip} +\alias{extract-parsnip} +\alias{extract_spec_parsnip.model_fit} +\alias{extract_fit_engine.model_fit} +\title{Extract elements of a parsnip model object} +\usage{ +\method{extract_spec_parsnip}{model_fit}(x, ...) + +\method{extract_fit_engine}{model_fit}(x, ...) +} +\arguments{ +\item{x}{A parsnip \code{model_fit} object.} + +\item{...}{Not currently used.} +} +\value{ +The extracted value from the parsnip object, \code{x}, as described in the description +section. +} +\description{ +These functions extract various elements from a parsnip object. If they do +not exist yet, an error is thrown. +\itemize{ +\item \code{extract_spec_parsnip()} returns the parsnip model specification. +\item \code{extract_fit_engine()} returns the engine specific fit embedded within +a parsnip model fit. For example, when using \code{\link[=linear_reg]{linear_reg()}} +with the \code{"lm"} engine, this would return the underlying \code{lm} object. +} +} +\details{ +Extracting the underlying engine fit can be helpful for describing the +model (via \code{print()}, \code{summary()}, \code{plot()}, etc.) or for variable +importance/explainers. + +However, users should not invoke the \code{predict()} method on an extracted +model. There may be preprocessing operations that \code{parsnip} has executed on +the data prior to giving it to the model. Bypassing these can lead to errors +or silently generating incorrect predictions. + +\strong{Good}:\if{html}{\out{