diff --git a/NAMESPACE b/NAMESPACE index a214f738d..87f703ad2 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -2,6 +2,9 @@ S3method(fit,model_spec) S3method(fit_xy,model_spec) +S3method(has_multi_pred,default) +S3method(has_multi_pred,model_fit) +S3method(has_multi_pred,workflow) S3method(multi_predict,"_C5.0") S3method(multi_predict,"_earth") S3method(multi_predict,"_elnet") @@ -91,6 +94,7 @@ export(get_fit) export(get_from_env) export(get_model_env) export(get_pred_type) +export(has_multi_pred) export(keras_mlp) export(linear_reg) export(logistic_reg) @@ -210,4 +214,5 @@ importFrom(utils,capture.output) importFrom(utils,getFromNamespace) importFrom(utils,globalVariables) importFrom(utils,head) +importFrom(utils,methods) importFrom(vctrs,vec_unique) diff --git a/R/predict.R b/R/predict.R index 414096d1a..c01c7f5e2 100644 --- a/R/predict.R +++ b/R/predict.R @@ -261,3 +261,44 @@ multi_predict.default <- function(object, ...) predict.model_spec <- function(object, ...) { stop("You must use `fit()` on your model specification before you can use `predict()`.", call. = FALSE) } + + + +#' Determine if a model can make predictions on sub-models +#' +#' @param object An object to test. +#' @param ... Not currently used. +#' @return A single logical value. +#' @keywords internal +#' @examples +#' model_idea <- linear_reg() %>% set_engine("lm") +#' has_multi_pred(model_idea) +#' model_fit <- fit(model_idea, mpg ~ ., data = mtcars) +#' has_multi_pred(model_fit) +#' @importFrom utils methods +#' @export +has_multi_pred <- function(object, ...) { + UseMethod("has_multi_pred") +} + +#' @export +#' @rdname has_multi_pred +has_multi_pred.default <- function(object, ...) { + FALSE +} + +#' @export +#' @rdname has_multi_pred +has_multi_pred.model_fit <- function(object, ...) { + existing_mthds <- utils::methods("multi_predict") + tst <- paste0("multi_predict.", class(object)) + any(tst %in% existing_mthds) +} + +#' @export +#' @rdname has_multi_pred +has_multi_pred.workflow <- function(object, ...) { + has_multi_pred(object$fit$model$model) +} + + diff --git a/man/has_multi_pred.Rd b/man/has_multi_pred.Rd new file mode 100644 index 000000000..99f850dd9 --- /dev/null +++ b/man/has_multi_pred.Rd @@ -0,0 +1,35 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/predict.R +\name{has_multi_pred} +\alias{has_multi_pred} +\alias{has_multi_pred.default} +\alias{has_multi_pred.model_fit} +\alias{has_multi_pred.workflow} +\title{Determine if a model can make predictions on sub-models} +\usage{ +has_multi_pred(object, ...) + +\method{has_multi_pred}{default}(object, ...) + +\method{has_multi_pred}{model_fit}(object, ...) + +\method{has_multi_pred}{workflow}(object, ...) +} +\arguments{ +\item{object}{An object to test.} + +\item{...}{Not currently used.} +} +\value{ +A single logical value. +} +\description{ +Determine if a model can make predictions on sub-models +} +\examples{ +model_idea <- linear_reg() \%>\% set_engine("lm") +has_multi_pred(model_idea) +model_fit <- fit(model_idea, mpg ~ ., data = mtcars) +has_multi_pred(model_fit) +} +\keyword{internal} diff --git a/tests/testthat/test_misc.R b/tests/testthat/test_misc.R new file mode 100644 index 000000000..089431941 --- /dev/null +++ b/tests/testthat/test_misc.R @@ -0,0 +1,31 @@ + +# ------------------------------------------------------------------------------ + +context("checking for multi_predict") + +test_that('parsnip objects', { + + lm_idea <- linear_reg() %>% set_engine("lm") + expect_false(has_multi_pred(lm_idea)) + + lm_fit <- fit(lm_idea, mpg ~ ., data = mtcars) + expect_false(has_multi_pred(lm_fit)) + expect_false(has_multi_pred(lm_fit$fit)) + + mars_fit <- + mars(mode = "regression") %>% + set_engine("earth") %>% + fit(mpg ~ ., data = mtcars) + expect_true(has_multi_pred(mars_fit)) + expect_false(has_multi_pred(mars_fit$fit)) +}) + +test_that('other objects', { + + expect_false(has_multi_pred(NULL)) + expect_false(has_multi_pred(NA)) + +}) + +# ------------------------------------------------------------------------------ +