Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

S3method(fit,model_spec)
S3method(fit_xy,model_spec)
S3method(has_multi_pred,default)
S3method(has_multi_pred,model_fit)
S3method(has_multi_pred,workflow)
S3method(has_multi_predict,default)
S3method(has_multi_predict,model_fit)
S3method(has_multi_predict,workflow)
S3method(multi_predict,"_C5.0")
S3method(multi_predict,"_earth")
S3method(multi_predict,"_elnet")
Expand All @@ -13,6 +13,9 @@ S3method(multi_predict,"_multnet")
S3method(multi_predict,"_train.kknn")
S3method(multi_predict,"_xgb.Booster")
S3method(multi_predict,default)
S3method(multi_predict_args,default)
S3method(multi_predict_args,model_fit)
S3method(multi_predict_args,workflow)
S3method(nullmodel,default)
S3method(predict,"_elnet")
S3method(predict,"_lognet")
Expand Down Expand Up @@ -95,7 +98,7 @@ export(get_fit)
export(get_from_env)
export(get_model_env)
export(get_pred_type)
export(has_multi_pred)
export(has_multi_predict)
export(keras_mlp)
export(linear_reg)
export(logistic_reg)
Expand All @@ -104,6 +107,7 @@ export(mars)
export(mlp)
export(model_printer)
export(multi_predict)
export(multi_predict_args)
export(multinom_reg)
export(nearest_neighbor)
export(null_model)
Expand Down
99 changes: 99 additions & 0 deletions R/aaa_multi_predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,102 @@ multi_predict.default <- function(object, ...)
predict.model_spec <- function(object, ...) {
stop("You must use `fit()` on your model specification before you can use `predict()`.", call. = FALSE)
}

#' Tools for models that predict on sub-models
#'
#' `has_multi_predict()` tests to see if an object can make multiple
#' predictions on submodels from the same object. `multi_predict_args()`
#' returns the names of the argments to `multi_predict()` for this model
#' (if any).
#' @param object An object to test.
#' @param ... Not currently used.
#' @return `has_multi_predict()` returns single logical value while
#' `multi_predict()` returns a character vector of argument names (or `NA`
#' if none exist).
#' @keywords internal
#' @examples
#' lm_model_idea <- linear_reg() %>% set_engine("lm")
#' has_multi_predict(lm_model_idea)
#' lm_model_fit <- fit(lm_model_idea, mpg ~ ., data = mtcars)
#' has_multi_predict(lm_model_fit)
#'
#' multi_predict_args(lm_model_fit)
#'
#' library(kknn)
#'
#' knn_fit <-
#' nearest_neighbor(mode = "regression", neighbors = 5) %>%
#' set_engine("kknn") %>%
#' fit(mpg ~ ., mtcars)
#'
#' multi_predict_args(knn_fit)
#'
#' multi_predict(knn_fit, mtcars[1, -1], neighbors = 1:4)$.pred
#' @importFrom utils methods
#' @export
has_multi_predict <- function(object, ...) {
UseMethod("has_multi_predict")
}

#' @export
#' @rdname has_multi_predict
has_multi_predict.default <- function(object, ...) {
FALSE
}

#' @export
#' @rdname has_multi_predict
has_multi_predict.model_fit <- function(object, ...) {
existing_mthds <- utils::methods("multi_predict")
tst <- paste0("multi_predict.", class(object))
any(tst %in% existing_mthds)
}

#' @export
#' @rdname has_multi_predict
has_multi_predict.workflow <- function(object, ...) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think this method should live in workflows? Can you even test this function in parsnip without suggesting workflows? It's always hard to know where to put what.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe has_multi_predict.workflow() should but the other bits should probably stay here. The workflow method is more of a place holder for when that package is more mature

It's always hard to know where to put what.

Yeah. That is becoming increasingly true.

has_multi_predict(object$fit$model$model)
}


#' @rdname has_multi_predict
#' @export
#' @rdname has_multi_predict
multi_predict_args <- function(object, ...) {
UseMethod("multi_predict_args")
}

#' @export
#' @rdname has_multi_predict
multi_predict_args.default <- function(object, ...) {
if (inherits(object, "model_fit")) {
res <- multi_predict_args.model_fit(object, ...)
} else {
res <- NA_character_
}
res
}

#' @export
#' @rdname has_multi_predict
multi_predict_args.model_fit <- function(object, ...) {
existing_mthds <- methods("multi_predict")
cls <- class(object)
tst <- paste0("multi_predict.", cls)
.fn <- tst[tst %in% existing_mthds]
if (length(.fn) == 0) {
return(NA_character_)
}

.fn <- getFromNamespace(.fn, ns = "parsnip")
omit <- c('object', 'new_data', 'type', '...')
args <- names(formals(.fn))
args[!(args %in% omit)]
}

#' @export
#' @rdname has_multi_predict
multi_predict_args.workflow <- function(object, ...) {
object <- object$fit$model$model

}
2 changes: 1 addition & 1 deletion R/nearest_neighbor_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ set_model_arg(
parsnip = "neighbors",
original = "ks",
func = list(pkg = "dials", fun = "neighbors"),
has_submodel = FALSE
has_submodel = TRUE
)
set_model_arg(
model = "nearest_neighbor",
Expand Down
70 changes: 0 additions & 70 deletions R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -231,73 +231,3 @@ prepare_data <- function(object, new_data) {
new_data
}

# Define a generic to make multiple predictions for the same model object ------

#' Model predictions across many sub-models
#'
#' For some models, predictions can be made on sub-models in the model object.
#' @param object A `model_fit` object.
#' @param ... Optional arguments to pass to `predict.model_fit(type = "raw")`
#' such as `type`.
#' @return A tibble with the same number of rows as the data being predicted.
#' Mostly likely, there is a list-column named `.pred` that is a tibble with
#' multiple rows per sub-model.
#' @export
multi_predict <- function(object, ...) {
if (inherits(object$fit, "try-error")) {
warning("Model fit failed; cannot make predictions.", call. = FALSE)
return(NULL)
}
UseMethod("multi_predict")
}

#' @export
#' @rdname multi_predict
multi_predict.default <- function(object, ...)
stop("No `multi_predict` method exists for objects with classes ",
paste0("'", class(), "'", collapse = ", "), call. = FALSE)

#' @export
predict.model_spec <- function(object, ...) {
stop("You must use `fit()` on your model specification before you can use `predict()`.", call. = FALSE)
}



#' Determine if a model can make predictions on sub-models
#'
#' @param object An object to test.
#' @param ... Not currently used.
#' @return A single logical value.
#' @keywords internal
#' @examples
#' model_idea <- linear_reg() %>% set_engine("lm")
#' has_multi_pred(model_idea)
#' model_fit <- fit(model_idea, mpg ~ ., data = mtcars)
#' has_multi_pred(model_fit)
#' @importFrom utils methods
#' @export
has_multi_pred <- function(object, ...) {
UseMethod("has_multi_pred")
}

#' @export
#' @rdname has_multi_pred
has_multi_pred.default <- function(object, ...) {
FALSE
}

#' @export
#' @rdname has_multi_pred
has_multi_pred.model_fit <- function(object, ...) {
existing_mthds <- utils::methods("multi_predict")
tst <- paste0("multi_predict.", class(object))
any(tst %in% existing_mthds)
}

#' @export
#' @rdname has_multi_pred
has_multi_pred.workflow <- function(object, ...) {
has_multi_pred(object$fit$model$model)
}

35 changes: 0 additions & 35 deletions man/has_multi_pred.Rd

This file was deleted.

65 changes: 65 additions & 0 deletions man/has_multi_predict.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions tests/testthat/test_boost_tree_C50.R
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ test_that('C5.0 execution', {
regexp = NA
)

expect_true(has_multi_predict(res))
expect_equal(multi_predict_args(res), "trees")

# outcome is not a factor:
expect_error(
res <- fit(
Expand Down
3 changes: 3 additions & 0 deletions tests/testthat/test_boost_tree_xgboost.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ test_that('xgboost execution, classification', {
regexp = NA
)

expect_true(has_multi_predict(res))
expect_equal(multi_predict_args(res), "trees")

expect_error(
res <- parsnip::fit(
iris_xgboost,
Expand Down
5 changes: 4 additions & 1 deletion tests/testthat/test_linear_reg_glmnet.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ test_that('glmnet execution', {
skip_if_not_installed("glmnet")

expect_error(
fit_xy(
res <- fit_xy(
iris_basic,
control = ctrl,
x = iris[, num_pred],
Expand All @@ -34,6 +34,9 @@ test_that('glmnet execution', {
regexp = NA
)

expect_true(has_multi_predict(res))
expect_equal(multi_predict_args(res), "penalty")

expect_error(
fit(
iris_basic,
Expand Down
3 changes: 3 additions & 0 deletions tests/testthat/test_linear_reg_spark.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ test_that('spark execution', {
regexp = NA
)

expect_false(has_multi_predict(spark_fit))
expect_equal(multi_predict_args(spark_fit), NA_character_)

expect_error(
spark_pred <- predict(spark_fit, iris_linreg_te),
regexp = NA
Expand Down
3 changes: 3 additions & 0 deletions tests/testthat/test_linear_reg_stan.R
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ test_that('stan_glm execution', {
regexp = NA
)

expect_false(has_multi_predict(res))
expect_equal(multi_predict_args(res), NA_character_)

expect_error(
res <- fit(
iris_basic,
Expand Down
Loading