diff --git a/NAMESPACE b/NAMESPACE index 87f703ad2..ba4ce863d 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -10,6 +10,7 @@ S3method(multi_predict,"_earth") S3method(multi_predict,"_elnet") S3method(multi_predict,"_lognet") S3method(multi_predict,"_multnet") +S3method(multi_predict,"_train.kknn") S3method(multi_predict,"_xgb.Booster") S3method(multi_predict,default) S3method(nullmodel,default) diff --git a/NEWS.md b/NEWS.md index b135f4295..53875047a 100644 --- a/NEWS.md +++ b/NEWS.md @@ -11,6 +11,7 @@ * `add_rowindex()` can create a column called `.row` to a data frame. * If a computational engine is not explicitly set, a default will be used. Each default is documented on the corresponding model page. A warning is issued at fit time unless verbosity is zero. + * `nearest_neighbor` gained a `multi_predict` method. The `multi_predict()` documentation is a little better organized. # parsnip 0.0.2 diff --git a/R/aaa.R b/R/aaa.R index 230c2b08c..a49415299 100644 --- a/R/aaa.R +++ b/R/aaa.R @@ -23,5 +23,6 @@ convert_stan_interval <- function(x, level = 0.95, lower = TRUE) { #' @importFrom utils globalVariables utils::globalVariables( c('.', '.label', '.pred', '.row', 'data', 'engine', 'engine2', 'group', - 'lab', 'original', 'predicted_label', 'prediction', 'value', 'type') + 'lab', 'original', 'predicted_label', 'prediction', 'value', 'type', + "neighbors") ) diff --git a/R/aaa_multi_predict.R b/R/aaa_multi_predict.R new file mode 100644 index 000000000..400cdf199 --- /dev/null +++ b/R/aaa_multi_predict.R @@ -0,0 +1,35 @@ +# Define a generic to make multiple predictions for the same model object ------ + +#' Model predictions across many sub-models +#' +#' For some models, predictions can be made on sub-models in the model object. +#' @param object A `model_fit` object. +#' @param new_data A rectangular data object, such as a data frame. +#' @param type A single character value or `NULL`. 
Possible values +#' are "numeric", "class", "prob", "conf_int", "pred_int", "quantile", +#' or "raw". When `NULL`, `predict()` will choose an appropriate value +#' based on the model's mode. +#' @param ... Optional arguments to pass to `predict.model_fit(type = "raw")` +#' such as `type`. +#' @return A tibble with the same number of rows as the data being predicted. +#' Most likely, there is a list-column named `.pred` that is a tibble with +#' multiple rows per sub-model. +#' @export +multi_predict <- function(object, ...) { + if (inherits(object$fit, "try-error")) { + warning("Model fit failed; cannot make predictions.", call. = FALSE) + return(NULL) + } + UseMethod("multi_predict") +} + +#' @export +#' @rdname multi_predict +multi_predict.default <- function(object, ...) + stop("No `multi_predict` method exists for objects with classes ", + paste0("'", class(object), "'", collapse = ", "), call. = FALSE) + +#' @export +predict.model_spec <- function(object, ...) { + stop("You must use `fit()` on your model specification before you can use `predict()`.", call. = FALSE) +} diff --git a/R/boost_tree.R b/R/boost_tree.R index 97a042ba1..2963de1a4 100644 --- a/R/boost_tree.R +++ b/R/boost_tree.R @@ -366,6 +366,8 @@ xgb_pred <- function(object, newdata, ...) { #' @importFrom purrr map_df #' @export +#' @rdname multi_predict +#' @param trees An integer vector for the number of trees in the ensemble. multi_predict._xgb.Booster <- function(object, new_data, type = NULL, trees = NULL, ...) { if (any(names(enquos(...)) == "newdata")) { @@ -474,6 +476,7 @@ C5.0_train <- } #' @export +#' @rdname multi_predict multi_predict._C5.0 <- function(object, new_data, type = NULL, trees = NULL, ...) { if (any(names(enquos(...)) == "newdata")) diff --git a/R/linear_reg.R b/R/linear_reg.R index d3bc72c70..f9e5a8e74 100644 --- a/R/linear_reg.R +++ b/R/linear_reg.R @@ -302,6 +302,8 @@ predict_raw._elnet <- function(object, new_data, opts = list(), ...)
{ #' @importFrom dplyr full_join as_tibble arrange #' @importFrom tidyr gather #' @export +#' @rdname multi_predict +#' @param penalty A numeric vector of penalty values. multi_predict._elnet <- function(object, new_data, type = NULL, penalty = NULL, ...) { if (any(names(enquos(...)) == "newdata")) diff --git a/R/logistic_reg.R b/R/logistic_reg.R index 1bc9062b6..7264680aa 100644 --- a/R/logistic_reg.R +++ b/R/logistic_reg.R @@ -276,6 +276,7 @@ predict._lognet <- function (object, new_data, type = NULL, opts = list(), penal #' @importFrom dplyr full_join as_tibble arrange #' @importFrom tidyr gather #' @export +#' @rdname multi_predict multi_predict._lognet <- function(object, new_data, type = NULL, penalty = NULL, ...) { if (any(names(enquos(...)) == "newdata")) diff --git a/R/mars.R b/R/mars.R index bfe7f6cbf..f83e56e29 100644 --- a/R/mars.R +++ b/R/mars.R @@ -205,6 +205,8 @@ earth_reg_updater <- function(num, object, new_data, ...) { #' @importFrom purrr map_df #' @importFrom dplyr arrange +#' @rdname multi_predict +#' @param num_terms An integer vector for the number of MARS terms to retain. #' @export multi_predict._earth <- function(object, new_data, type = NULL, num_terms = NULL, ...) { diff --git a/R/multinom_reg.R b/R/multinom_reg.R index 71e367489..b8bc0a479 100644 --- a/R/multinom_reg.R +++ b/R/multinom_reg.R @@ -232,6 +232,7 @@ predict._multnet <- #' @importFrom dplyr full_join as_tibble arrange #' @importFrom tidyr gather #' @export +#' @rdname multi_predict multi_predict._multnet <- function(object, new_data, type = NULL, penalty = NULL, ...) { if (any(names(enquos(...)) == "newdata")) diff --git a/R/nearest_neighbor.R b/R/nearest_neighbor.R index 4f0cbb165..b28a3df46 100644 --- a/R/nearest_neighbor.R +++ b/R/nearest_neighbor.R @@ -178,3 +178,43 @@ translate.nearest_neighbor <- function(x, engine = x$engine, ...)
{ } x } + + +# multi_predict() for kknn fits: predictions for several `neighbors` values ---- + +#' @importFrom purrr map_df +#' @importFrom dplyr starts_with +#' @rdname multi_predict +#' @param neighbors An integer vector for the number of nearest neighbors. +#' @export +multi_predict._train.kknn <- + function(object, new_data, type = NULL, neighbors = NULL, ...) { + if (any(names(enquos(...)) == "newdata")) + stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) + + if (is.null(neighbors)) + neighbors <- rlang::eval_tidy(object$fit$call$ks) # default: the `ks` stored in the fitted call + neighbors <- sort(neighbors) # evaluate the neighbor counts in ascending order + + if (is.null(type)) { # no type supplied: choose one from the model mode + if (object$spec$mode == "classification") + type <- "class" + else + type <- "numeric" + } + + res <- + purrr::map_df(neighbors, knn_by_k, object = object, + new_data = new_data, type = type, ...) + res <- dplyr::arrange(res, .row, neighbors) + res <- split(res[, -1], res$.row) # drop the leading `.row` column; one data frame per `.row` value + names(res) <- NULL # discard the names that split() adds + dplyr::tibble(.pred = res) + } + +knn_by_k <- function(k, object, new_data, type, ...) { + object$fit$call$ks <- k # replace `ks` in the stored call with this k before predicting + predict(object, new_data = new_data, type = type, ...) %>% + dplyr::mutate(neighbors = k, .row = dplyr::row_number()) %>% + dplyr::select(.row, neighbors, dplyr::starts_with(".pred")) +} diff --git a/R/predict.R b/R/predict.R index c01c7f5e2..8339d886c 100644 --- a/R/predict.R +++ b/R/predict.R @@ -301,4 +301,3 @@ has_multi_pred.workflow <- function(object, ...)
{ has_multi_pred(object$fit$model$model) } - diff --git a/man/multi_predict.Rd b/man/multi_predict.Rd index c12d9ee7a..1dab63d59 100644 --- a/man/multi_predict.Rd +++ b/man/multi_predict.Rd @@ -1,19 +1,64 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/predict.R +% Please edit documentation in R/aaa_multi_predict.R, R/boost_tree.R, +% R/linear_reg.R, R/logistic_reg.R, R/mars.R, R/multinom_reg.R, +% R/nearest_neighbor.R \name{multi_predict} \alias{multi_predict} \alias{multi_predict.default} +\alias{multi_predict._xgb.Booster} +\alias{multi_predict._C5.0} +\alias{multi_predict._elnet} +\alias{multi_predict._lognet} +\alias{multi_predict._earth} +\alias{multi_predict._multnet} +\alias{multi_predict._train.kknn} \title{Model predictions across many sub-models} \usage{ multi_predict(object, ...) \method{multi_predict}{default}(object, ...) + +\method{multi_predict}{_xgb.Booster}(object, new_data, type = NULL, + trees = NULL, ...) + +\method{multi_predict}{_C5.0}(object, new_data, type = NULL, + trees = NULL, ...) + +\method{multi_predict}{_elnet}(object, new_data, type = NULL, + penalty = NULL, ...) + +\method{multi_predict}{_lognet}(object, new_data, type = NULL, + penalty = NULL, ...) + +\method{multi_predict}{_earth}(object, new_data, type = NULL, + num_terms = NULL, ...) + +\method{multi_predict}{_multnet}(object, new_data, type = NULL, + penalty = NULL, ...) + +\method{multi_predict}{_train.kknn}(object, new_data, type = NULL, + neighbors = NULL, ...) } \arguments{ \item{object}{A \code{model_fit} object.} \item{...}{Optional arguments to pass to \code{predict.model_fit(type = "raw")} such as \code{type}.} + +\item{new_data}{A rectangular data object, such as a data frame.} + +\item{type}{A single character value or \code{NULL}. Possible values +are "numeric", "class", "prob", "conf_int", "pred_int", "quantile", +or "raw". 
When \code{NULL}, \code{predict()} will choose an appropriate value +based on the model's mode.} + +\item{trees}{An integer vector for the number of trees in the ensemble.} + +\item{penalty}{A numeric vector of penalty values.} + +\item{num_terms}{An integer vector for the number of MARS terms to retain.} + +\item{neighbors}{An integer vector for the number of nearest neighbors.} } \value{ A tibble with the same number of rows as the data being predicted. diff --git a/tests/testthat/test_nearest_neighbor_kknn.R b/tests/testthat/test_nearest_neighbor_kknn.R index 37ea2e262..b56544ebd 100644 --- a/tests/testthat/test_nearest_neighbor_kknn.R +++ b/tests/testthat/test_nearest_neighbor_kknn.R @@ -113,3 +113,78 @@ test_that('kknn prediction', { expect_equal(form_pred, predict(res_form, iris[1:5, c("Sepal.Width", "Species")])$.pred) }) + + +test_that('kknn multi-predict', { + + skip_if_not_installed("kknn") + library(kknn) + + iris_te <- c(1:2, 50:51, 100:101) + k_vals <- 1:10 + + res_xy <- fit_xy( + nearest_neighbor(mode = "classification", neighbors = 3) %>% + set_engine("kknn"), + control = ctrl, + x = iris[-iris_te, num_pred], + y = iris$Species[-iris_te] + ) + + pred_multi <- multi_predict(res_xy, iris[iris_te, num_pred], neighbors = k_vals) + expect_equal(pred_multi %>% unnest() %>% nrow(), length(iris_te) * length(k_vals)) + expect_equal(pred_multi %>% nrow(), length(iris_te)) + + pred_uni <- predict(res_xy, iris[iris_te, num_pred]) + pred_uni_obs <- + pred_multi %>% + mutate(.rows = row_number()) %>% + unnest() %>% + dplyr::filter(neighbors == 3) %>% + arrange(.rows) %>% + dplyr::select(.pred_class) + expect_equal(pred_uni, pred_uni_obs) + + + prob_multi <- multi_predict(res_xy, iris[iris_te, num_pred], + neighbors = k_vals, type = "prob") + expect_equal(prob_multi %>% unnest() %>% nrow(), length(iris_te) * length(k_vals)) + expect_equal(prob_multi %>% nrow(), length(iris_te)) + + prob_uni <- predict(res_xy, iris[iris_te, num_pred], type = "prob") + prob_uni_obs
<- + prob_multi %>% + mutate(.rows = row_number()) %>% + unnest() %>% + dplyr::filter(neighbors == 3) %>% + arrange(.rows) %>% + dplyr::select(!!names(prob_uni)) + expect_equal(prob_uni, prob_uni_obs) + + # ---------------------------------------------------------------------------- + # regression + + cars_te <- 1:5 + k_vals <- 1:10 + + res_xy <- fit( + nearest_neighbor(mode = "regression", neighbors = 3) %>% + set_engine("kknn"), + control = ctrl, + mpg ~ ., data = mtcars[-cars_te, ] + ) + + pred_multi <- multi_predict(res_xy, mtcars[cars_te, -1], neighbors = k_vals) + expect_equal(pred_multi %>% unnest() %>% nrow(), length(cars_te) * length(k_vals)) + expect_equal(pred_multi %>% nrow(), length(cars_te)) + + pred_uni <- predict(res_xy, mtcars[cars_te, -1]) + pred_uni_obs <- + pred_multi %>% + mutate(.rows = row_number()) %>% + unnest() %>% + dplyr::filter(neighbors == 3) %>% + arrange(.rows) %>% + dplyr::select(.pred) + expect_equal(pred_uni, pred_uni_obs) +})