From 2f58384130c0401d4b90684e7c437215e6f86226 Mon Sep 17 00:00:00 2001 From: topepo Date: Mon, 15 Jul 2019 14:00:56 -0400 Subject: [PATCH] added a function to get multi_predict arg names --- NAMESPACE | 12 ++- R/aaa_multi_predict.R | 99 +++++++++++++++++++++ R/nearest_neighbor_data.R | 2 +- R/predict.R | 70 --------------- man/has_multi_pred.Rd | 35 -------- man/has_multi_predict.Rd | 65 ++++++++++++++ tests/testthat/test_boost_tree_C50.R | 3 + tests/testthat/test_boost_tree_xgboost.R | 3 + tests/testthat/test_linear_reg_glmnet.R | 5 +- tests/testthat/test_linear_reg_spark.R | 3 + tests/testthat/test_linear_reg_stan.R | 3 + tests/testthat/test_logistic_reg_glmnet.R | 5 +- tests/testthat/test_mars.R | 3 + tests/testthat/test_misc.R | 14 +-- tests/testthat/test_mlp_keras.R | 4 + tests/testthat/test_mlp_nnet.R | 4 + tests/testthat/test_multinom_reg_glmnet.R | 5 +- tests/testthat/test_nearest_neighbor_kknn.R | 5 +- tests/testthat/test_rand_forest_ranger.R | 3 + tests/testthat/test_surv_reg_flexsurv.R | 3 + tests/testthat/test_svm_poly.R | 5 +- tests/testthat/test_svm_rbf.R | 4 +- 22 files changed, 232 insertions(+), 123 deletions(-) delete mode 100644 man/has_multi_pred.Rd create mode 100644 man/has_multi_predict.Rd diff --git a/NAMESPACE b/NAMESPACE index ba4ce863d..37745ade0 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -2,9 +2,9 @@ S3method(fit,model_spec) S3method(fit_xy,model_spec) -S3method(has_multi_pred,default) -S3method(has_multi_pred,model_fit) -S3method(has_multi_pred,workflow) +S3method(has_multi_predict,default) +S3method(has_multi_predict,model_fit) +S3method(has_multi_predict,workflow) S3method(multi_predict,"_C5.0") S3method(multi_predict,"_earth") S3method(multi_predict,"_elnet") @@ -13,6 +13,9 @@ S3method(multi_predict,"_multnet") S3method(multi_predict,"_train.kknn") S3method(multi_predict,"_xgb.Booster") S3method(multi_predict,default) +S3method(multi_predict_args,default) +S3method(multi_predict_args,model_fit) +S3method(multi_predict_args,workflow) 
S3method(nullmodel,default) S3method(predict,"_elnet") S3method(predict,"_lognet") @@ -95,7 +98,7 @@ export(get_fit) export(get_from_env) export(get_model_env) export(get_pred_type) -export(has_multi_pred) +export(has_multi_predict) export(keras_mlp) export(linear_reg) export(logistic_reg) @@ -104,6 +107,7 @@ export(mars) export(mlp) export(model_printer) export(multi_predict) +export(multi_predict_args) export(multinom_reg) export(nearest_neighbor) export(null_model) diff --git a/R/aaa_multi_predict.R b/R/aaa_multi_predict.R index 400cdf199..46c005365 100644 --- a/R/aaa_multi_predict.R +++ b/R/aaa_multi_predict.R @@ -33,3 +33,102 @@ multi_predict.default <- function(object, ...) predict.model_spec <- function(object, ...) { stop("You must use `fit()` on your model specification before you can use `predict()`.", call. = FALSE) } + +#' Tools for models that predict on sub-models +#' +#' `has_multi_predict()` tests to see if an object can make multiple +#' predictions on submodels from the same object. `multi_predict_args()` +#' returns the names of the arguments to `multi_predict()` for this model +#' (if any). +#' @param object An object to test. +#' @param ... Not currently used. +#' @return `has_multi_predict()` returns a single logical value while +#' `multi_predict_args()` returns a character vector of argument names (or `NA` +#' if none exist). +#' @keywords internal +#' @examples +#' lm_model_idea <- linear_reg() %>% set_engine("lm") +#' has_multi_predict(lm_model_idea) +#' lm_model_fit <- fit(lm_model_idea, mpg ~ ., data = mtcars) +#' has_multi_predict(lm_model_fit) +#' +#' multi_predict_args(lm_model_fit) +#' +#' library(kknn) +#' +#' knn_fit <- +#' nearest_neighbor(mode = "regression", neighbors = 5) %>% +#' set_engine("kknn") %>% +#' fit(mpg ~ ., mtcars) +#' +#' multi_predict_args(knn_fit) +#' +#' multi_predict(knn_fit, mtcars[1, -1], neighbors = 1:4)$.pred +#' @importFrom utils methods +#' @export +has_multi_predict <- function(object, ...) 
{ + UseMethod("has_multi_predict") +} + +#' @export +#' @rdname has_multi_predict +has_multi_predict.default <- function(object, ...) { + FALSE +} + +#' @export +#' @rdname has_multi_predict +has_multi_predict.model_fit <- function(object, ...) { + existing_mthds <- utils::methods("multi_predict") + tst <- paste0("multi_predict.", class(object)) + any(tst %in% existing_mthds) +} + +#' @export +#' @rdname has_multi_predict +has_multi_predict.workflow <- function(object, ...) { + has_multi_predict(object$fit$model$model) +} + + +#' @rdname has_multi_predict +#' @export +#' @rdname has_multi_predict +multi_predict_args <- function(object, ...) { + UseMethod("multi_predict_args") +} + +#' @export +#' @rdname has_multi_predict +multi_predict_args.default <- function(object, ...) { + if (inherits(object, "model_fit")) { + res <- multi_predict_args.model_fit(object, ...) + } else { + res <- NA_character_ + } + res +} + +#' @export +#' @rdname has_multi_predict +multi_predict_args.model_fit <- function(object, ...) { + existing_mthds <- methods("multi_predict") + cls <- class(object) + tst <- paste0("multi_predict.", cls) + .fn <- tst[tst %in% existing_mthds] + if (length(.fn) == 0) { + return(NA_character_) + } + + .fn <- getFromNamespace(.fn, ns = "parsnip") + omit <- c('object', 'new_data', 'type', '...') + args <- names(formals(.fn)) + args[!(args %in% omit)] +} + +#' @export +#' @rdname has_multi_predict +multi_predict_args.workflow <- function(object, ...) 
{ + object <- object$fit$model$model + multi_predict_args(object, ...) +} diff --git a/R/nearest_neighbor_data.R b/R/nearest_neighbor_data.R index 2a85f70d0..7221eccb5 100644 --- a/R/nearest_neighbor_data.R +++ b/R/nearest_neighbor_data.R @@ -16,7 +16,7 @@ set_model_arg( parsnip = "neighbors", original = "ks", func = list(pkg = "dials", fun = "neighbors"), - has_submodel = FALSE + has_submodel = TRUE ) set_model_arg( model = "nearest_neighbor", diff --git a/R/predict.R b/R/predict.R index 8339d886c..c57adb7be 100644 --- a/R/predict.R +++ b/R/predict.R @@ -231,73 +231,3 @@ prepare_data <- function(object, new_data) { new_data } -# Define a generic to make multiple predictions for the same model object ------ - -#' Model predictions across many sub-models -#' -#' For some models, predictions can be made on sub-models in the model object. -#' @param object A `model_fit` object. -#' @param ... Optional arguments to pass to `predict.model_fit(type = "raw")` -#' such as `type`. -#' @return A tibble with the same number of rows as the data being predicted. -#' Mostly likely, there is a list-column named `.pred` that is a tibble with -#' multiple rows per sub-model. -#' @export -multi_predict <- function(object, ...) { - if (inherits(object$fit, "try-error")) { - warning("Model fit failed; cannot make predictions.", call. = FALSE) - return(NULL) - } - UseMethod("multi_predict") -} - -#' @export -#' @rdname multi_predict -multi_predict.default <- function(object, ...) - stop("No `multi_predict` method exists for objects with classes ", - paste0("'", class(), "'", collapse = ", "), call. = FALSE) - -#' @export -predict.model_spec <- function(object, ...) { - stop("You must use `fit()` on your model specification before you can use `predict()`.", call. = FALSE) -} - - - -#' Determine if a model can make predictions on sub-models -#' -#' @param object An object to test. -#' @param ... Not currently used. -#' @return A single logical value. 
-#' @keywords internal -#' @examples -#' model_idea <- linear_reg() %>% set_engine("lm") -#' has_multi_pred(model_idea) -#' model_fit <- fit(model_idea, mpg ~ ., data = mtcars) -#' has_multi_pred(model_fit) -#' @importFrom utils methods -#' @export -has_multi_pred <- function(object, ...) { - UseMethod("has_multi_pred") -} - -#' @export -#' @rdname has_multi_pred -has_multi_pred.default <- function(object, ...) { - FALSE -} - -#' @export -#' @rdname has_multi_pred -has_multi_pred.model_fit <- function(object, ...) { - existing_mthds <- utils::methods("multi_predict") - tst <- paste0("multi_predict.", class(object)) - any(tst %in% existing_mthds) -} - -#' @export -#' @rdname has_multi_pred -has_multi_pred.workflow <- function(object, ...) { - has_multi_pred(object$fit$model$model) -} - diff --git a/man/has_multi_pred.Rd b/man/has_multi_pred.Rd deleted file mode 100644 index 99f850dd9..000000000 --- a/man/has_multi_pred.Rd +++ /dev/null @@ -1,35 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/predict.R -\name{has_multi_pred} -\alias{has_multi_pred} -\alias{has_multi_pred.default} -\alias{has_multi_pred.model_fit} -\alias{has_multi_pred.workflow} -\title{Determine if a model can make predictions on sub-models} -\usage{ -has_multi_pred(object, ...) - -\method{has_multi_pred}{default}(object, ...) - -\method{has_multi_pred}{model_fit}(object, ...) - -\method{has_multi_pred}{workflow}(object, ...) -} -\arguments{ -\item{object}{An object to test.} - -\item{...}{Not currently used.} -} -\value{ -A single logical value. 
-} -\description{ -Determine if a model can make predictions on sub-models -} -\examples{ -model_idea <- linear_reg() \%>\% set_engine("lm") -has_multi_pred(model_idea) -model_fit <- fit(model_idea, mpg ~ ., data = mtcars) -has_multi_pred(model_fit) -} -\keyword{internal} diff --git a/man/has_multi_predict.Rd b/man/has_multi_predict.Rd new file mode 100644 index 000000000..64da8f233 --- /dev/null +++ b/man/has_multi_predict.Rd @@ -0,0 +1,65 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/aaa_multi_predict.R +\name{has_multi_predict} +\alias{has_multi_predict} +\alias{has_multi_predict.default} +\alias{has_multi_predict.model_fit} +\alias{has_multi_predict.workflow} +\alias{multi_predict_args} +\alias{multi_predict_args.default} +\alias{multi_predict_args.model_fit} +\alias{multi_predict_args.workflow} +\title{Tools for models that predict on sub-models} +\usage{ +has_multi_predict(object, ...) + +\method{has_multi_predict}{default}(object, ...) + +\method{has_multi_predict}{model_fit}(object, ...) + +\method{has_multi_predict}{workflow}(object, ...) + +multi_predict_args(object, ...) + +\method{multi_predict_args}{default}(object, ...) + +\method{multi_predict_args}{model_fit}(object, ...) + +\method{multi_predict_args}{workflow}(object, ...) +} +\arguments{ +\item{object}{An object to test.} + +\item{...}{Not currently used.} +} +\value{ +\code{has_multi_predict()} returns a single logical value while +\code{multi_predict_args()} returns a character vector of argument names (or \code{NA} +if none exist). +} +\description{ +\code{has_multi_predict()} tests to see if an object can make multiple +predictions on submodels from the same object. \code{multi_predict_args()} +returns the names of the arguments to \code{multi_predict()} for this model +(if any). 
+} +\examples{ +lm_model_idea <- linear_reg() \%>\% set_engine("lm") +has_multi_predict(lm_model_idea) +lm_model_fit <- fit(lm_model_idea, mpg ~ ., data = mtcars) +has_multi_predict(lm_model_fit) + +multi_predict_args(lm_model_fit) + +library(kknn) + +knn_fit <- + nearest_neighbor(mode = "regression", neighbors = 5) \%>\% + set_engine("kknn") \%>\% + fit(mpg ~ ., mtcars) + +multi_predict_args(knn_fit) + +multi_predict(knn_fit, mtcars[1, -1], neighbors = 1:4)$.pred +} +\keyword{internal} diff --git a/tests/testthat/test_boost_tree_C50.R b/tests/testthat/test_boost_tree_C50.R index cf92602fc..b295011ea 100644 --- a/tests/testthat/test_boost_tree_C50.R +++ b/tests/testthat/test_boost_tree_C50.R @@ -46,6 +46,9 @@ test_that('C5.0 execution', { regexp = NA ) + expect_true(has_multi_predict(res)) + expect_equal(multi_predict_args(res), "trees") + # outcome is not a factor: expect_error( res <- fit( diff --git a/tests/testthat/test_boost_tree_xgboost.R b/tests/testthat/test_boost_tree_xgboost.R index ed65bcaad..27f15ee9b 100644 --- a/tests/testthat/test_boost_tree_xgboost.R +++ b/tests/testthat/test_boost_tree_xgboost.R @@ -40,6 +40,9 @@ test_that('xgboost execution, classification', { regexp = NA ) + expect_true(has_multi_predict(res)) + expect_equal(multi_predict_args(res), "trees") + expect_error( res <- parsnip::fit( iris_xgboost, diff --git a/tests/testthat/test_linear_reg_glmnet.R b/tests/testthat/test_linear_reg_glmnet.R index 92506ff7f..46ff3658d 100644 --- a/tests/testthat/test_linear_reg_glmnet.R +++ b/tests/testthat/test_linear_reg_glmnet.R @@ -25,7 +25,7 @@ test_that('glmnet execution', { skip_if_not_installed("glmnet") expect_error( - fit_xy( + res <- fit_xy( iris_basic, control = ctrl, x = iris[, num_pred], @@ -34,6 +34,9 @@ test_that('glmnet execution', { regexp = NA ) + expect_true(has_multi_predict(res)) + expect_equal(multi_predict_args(res), "penalty") + expect_error( fit( iris_basic, diff --git a/tests/testthat/test_linear_reg_spark.R 
b/tests/testthat/test_linear_reg_spark.R index 4b7432d80..86b8f3e68 100644 --- a/tests/testthat/test_linear_reg_spark.R +++ b/tests/testthat/test_linear_reg_spark.R @@ -36,6 +36,9 @@ test_that('spark execution', { regexp = NA ) + expect_false(has_multi_predict(spark_fit)) + expect_equal(multi_predict_args(spark_fit), NA_character_) + expect_error( spark_pred <- predict(spark_fit, iris_linreg_te), regexp = NA diff --git a/tests/testthat/test_linear_reg_stan.R b/tests/testthat/test_linear_reg_stan.R index e039d10b3..d628fa443 100644 --- a/tests/testthat/test_linear_reg_stan.R +++ b/tests/testthat/test_linear_reg_stan.R @@ -39,6 +39,9 @@ test_that('stan_glm execution', { regexp = NA ) + expect_false(has_multi_predict(res)) + expect_equal(multi_predict_args(res), NA_character_) + expect_error( res <- fit( iris_basic, diff --git a/tests/testthat/test_logistic_reg_glmnet.R b/tests/testthat/test_logistic_reg_glmnet.R index 74165c6dc..d2db3d5ed 100644 --- a/tests/testthat/test_logistic_reg_glmnet.R +++ b/tests/testthat/test_logistic_reg_glmnet.R @@ -26,7 +26,7 @@ test_that('glmnet execution', { skip_if_not_installed("glmnet") expect_error( - fit_xy( + res <- fit_xy( lc_basic, control = ctrl, x = lending_club[, num_pred], @@ -35,6 +35,9 @@ test_that('glmnet execution', { regexp = NA ) + expect_true(has_multi_predict(res)) + expect_equal(multi_predict_args(res), "penalty") + expect_error( glmnet_xy_catch <- fit_xy( lc_basic, diff --git a/tests/testthat/test_mars.R b/tests/testthat/test_mars.R index ef1b16254..b17d46ac7 100644 --- a/tests/testthat/test_mars.R +++ b/tests/testthat/test_mars.R @@ -139,6 +139,9 @@ test_that('mars execution', { regexp = NA ) + expect_true(has_multi_predict(res)) + expect_equal(multi_predict_args(res), "num_terms") + expect_error( res <- fit( iris_basic, diff --git a/tests/testthat/test_misc.R b/tests/testthat/test_misc.R index 089431941..7843b5ca8 100644 --- a/tests/testthat/test_misc.R +++ b/tests/testthat/test_misc.R @@ -6,24 +6,24 @@ 
context("checking for multi_predict") test_that('parsnip objects', { lm_idea <- linear_reg() %>% set_engine("lm") - expect_false(has_multi_pred(lm_idea)) + expect_false(has_multi_predict(lm_idea)) lm_fit <- fit(lm_idea, mpg ~ ., data = mtcars) - expect_false(has_multi_pred(lm_fit)) - expect_false(has_multi_pred(lm_fit$fit)) + expect_false(has_multi_predict(lm_fit)) + expect_false(has_multi_predict(lm_fit$fit)) mars_fit <- mars(mode = "regression") %>% set_engine("earth") %>% fit(mpg ~ ., data = mtcars) - expect_true(has_multi_pred(mars_fit)) - expect_false(has_multi_pred(mars_fit$fit)) + expect_true(has_multi_predict(mars_fit)) + expect_false(has_multi_predict(mars_fit$fit)) }) test_that('other objects', { - expect_false(has_multi_pred(NULL)) - expect_false(has_multi_pred(NA)) + expect_false(has_multi_predict(NULL)) + expect_false(has_multi_predict(NA)) }) diff --git a/tests/testthat/test_mlp_keras.R b/tests/testthat/test_mlp_keras.R index e143c0464..9798811e7 100644 --- a/tests/testthat/test_mlp_keras.R +++ b/tests/testthat/test_mlp_keras.R @@ -32,6 +32,10 @@ test_that('keras execution, classification', { regexp = NA ) + + expect_false(has_multi_predict(res)) + expect_equal(multi_predict_args(res), NA_character_) + keras::backend()$clear_session() expect_error( diff --git a/tests/testthat/test_mlp_nnet.R b/tests/testthat/test_mlp_nnet.R index 4172364b1..61a792b1b 100644 --- a/tests/testthat/test_mlp_nnet.R +++ b/tests/testthat/test_mlp_nnet.R @@ -174,6 +174,10 @@ test_that('multivariate nnet formula', { cbind(V1, V2, V3) ~ ., data = nn_dat[-(1:5),] ) + + expect_false(has_multi_predict(nnet_form)) + expect_equal(multi_predict_args(nnet_form), NA_character_) + expect_equal(length(nnet_form$fit$wts), 24) nnet_form_pred <- predict(nnet_form, new_data = nn_dat[1:5, -(1:3)]) expect_equal(names(nnet_form_pred), paste0(".pred_", c("V1", "V2", "V3"))) diff --git a/tests/testthat/test_multinom_reg_glmnet.R b/tests/testthat/test_multinom_reg_glmnet.R index 
65d774194..4115abbda 100644 --- a/tests/testthat/test_multinom_reg_glmnet.R +++ b/tests/testthat/test_multinom_reg_glmnet.R @@ -20,7 +20,7 @@ test_that('glmnet execution', { skip_if_not_installed("glmnet") expect_error( - fit_xy( + res <- fit_xy( multinom_reg() %>% set_engine("glmnet"), control = ctrl, x = iris[, 1:4], @@ -29,6 +29,9 @@ test_that('glmnet execution', { regexp = NA ) + expect_true(has_multi_predict(res)) + expect_equal(multi_predict_args(res), "penalty") + expect_error( glmnet_xy_catch <- fit_xy( multinom_reg() %>% set_engine("glmnet"), diff --git a/tests/testthat/test_nearest_neighbor_kknn.R b/tests/testthat/test_nearest_neighbor_kknn.R index b56544ebd..3a0039b70 100644 --- a/tests/testthat/test_nearest_neighbor_kknn.R +++ b/tests/testthat/test_nearest_neighbor_kknn.R @@ -39,7 +39,7 @@ test_that('kknn execution', { # nominal # expect no error expect_error( - fit_xy( + res <- fit_xy( iris_basic, control = ctrl, x = iris[, c("Sepal.Length", "Petal.Width")], @@ -48,6 +48,9 @@ test_that('kknn execution', { regexp = NA ) + expect_true(has_multi_predict(res)) + expect_equal(multi_predict_args(res), "neighbors") + expect_error( fit( iris_basic, diff --git a/tests/testthat/test_rand_forest_ranger.R b/tests/testthat/test_rand_forest_ranger.R index 3523613e7..937a82563 100644 --- a/tests/testthat/test_rand_forest_ranger.R +++ b/tests/testthat/test_rand_forest_ranger.R @@ -370,6 +370,9 @@ test_that('ranger classification prediction', { control = ctrl ) + expect_false(has_multi_predict(xy_class_fit)) + expect_equal(multi_predict_args(xy_class_fit), NA_character_) + xy_class_pred <- predict(xy_class_fit$fit, data = iris[c(1, 51, 101), 1:4])$prediction xy_class_pred <- colnames(xy_class_pred)[apply(xy_class_pred, 1, which.max)] xy_class_pred <- factor(xy_class_pred, levels = levels(iris$Species)) diff --git a/tests/testthat/test_surv_reg_flexsurv.R b/tests/testthat/test_surv_reg_flexsurv.R index f2985464a..85f3051a4 100644 --- 
a/tests/testthat/test_surv_reg_flexsurv.R +++ b/tests/testthat/test_surv_reg_flexsurv.R @@ -37,6 +37,9 @@ test_that('flexsurv execution', { ), regexp = NA ) + expect_false(has_multi_predict(res)) + expect_equal(multi_predict_args(res), NA_character_) + expect_error( res <- fit_xy( surv_basic, diff --git a/tests/testthat/test_svm_poly.R b/tests/testthat/test_svm_poly.R index bb5b2a65c..3f0cdc9ac 100644 --- a/tests/testthat/test_svm_poly.R +++ b/tests/testthat/test_svm_poly.R @@ -115,7 +115,7 @@ test_that('svm poly regression', { skip_if_not_installed("kernlab") expect_error( - fit_xy( + res <- fit_xy( reg_mod, control = ctrl, x = iris[,2:4], @@ -124,6 +124,9 @@ test_that('svm poly regression', { regexp = NA ) + expect_false(has_multi_predict(res)) + expect_equal(multi_predict_args(res), NA_character_) + expect_error( fit( reg_mod, diff --git a/tests/testthat/test_svm_rbf.R b/tests/testthat/test_svm_rbf.R index 41523e0a7..088f3cd32 100644 --- a/tests/testthat/test_svm_rbf.R +++ b/tests/testthat/test_svm_rbf.R @@ -94,7 +94,7 @@ test_that('svm poly regression', { skip_if_not_installed("kernlab") expect_error( - fit_xy( + res <- fit_xy( reg_mod, control = ctrl, x = iris[,2:4], @@ -102,6 +102,8 @@ test_that('svm poly regression', { ), regexp = NA ) + expect_false(has_multi_predict(res)) + expect_equal(multi_predict_args(res), NA_character_) expect_error( fit(