Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ S3method(multi_predict,"_earth")
S3method(multi_predict,"_elnet")
S3method(multi_predict,"_lognet")
S3method(multi_predict,"_multnet")
S3method(multi_predict,"_train.kknn")
S3method(multi_predict,"_xgb.Booster")
S3method(multi_predict,default)
S3method(nullmodel,default)
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
* `add_rowindex()` can add a column called `.row` to a data frame.

* If a computational engine is not explicitly set, a default will be used. Each default is documented on the corresponding model page. A warning is issued at fit time unless verbosity is zero.
* `nearest_neighbor()` gained a `multi_predict()` method. The `multi_predict()` documentation is a little better organized.


# parsnip 0.0.2
Expand Down
3 changes: 2 additions & 1 deletion R/aaa.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,6 @@ convert_stan_interval <- function(x, level = 0.95, lower = TRUE) {
#' @importFrom utils globalVariables
utils::globalVariables(
c('.', '.label', '.pred', '.row', 'data', 'engine', 'engine2', 'group',
'lab', 'original', 'predicted_label', 'prediction', 'value', 'type')
'lab', 'original', 'predicted_label', 'prediction', 'value', 'type',
"neighbors")
)
35 changes: 35 additions & 0 deletions R/aaa_multi_predict.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Define a generic to make multiple predictions for the same model object ------

#' Model predictions across many sub-models
#'
#' For some models, predictions can be made on sub-models in the model object.
#' @param object A `model_fit` object.
#' @param new_data A rectangular data object, such as a data frame.
#' @param type A single character value or `NULL`. Possible values
#' are "numeric", "class", "prob", "conf_int", "pred_int", "quantile",
#' or "raw". When `NULL`, `predict()` will choose an appropriate value
#' based on the model's mode.
#' @param ... Optional arguments to pass to `predict.model_fit(type = "raw")`
#' such as `type`.
#' @return A tibble with the same number of rows as the data being predicted.
#' Most likely, there is a list-column named `.pred` that is a tibble with
#' multiple rows per sub-model.
#' @export
multi_predict <- function(object, ...) {
  # Dispatch only when the stored fit is usable. `UseMethod()` never returns
  # to this frame, so the code below runs solely for failed fits. No local
  # variables are created before dispatch on purpose.
  if (!inherits(object$fit, "try-error")) {
    UseMethod("multi_predict")
  }

  # The underlying fit is a captured error: warn and hand back NULL.
  warning("Model fit failed; cannot make predictions.", call. = FALSE)
  NULL
}

#' @export
#' @rdname multi_predict
multi_predict.default <- function(object, ...) {
  # Fallback method: no sub-model predictions are available for this object.
  # Report every class of `object` so the caller can see what was dispatched
  # on. `class()` must be given `object`; calling it with no argument raises
  # an unrelated error and hides this message.
  stop(
    "No `multi_predict` method exists for objects with classes ",
    paste0("'", class(object), "'", collapse = ", "),
    call. = FALSE
  )
}

#' @export
predict.model_spec <- function(object, ...) {
  # This method always signals an error: its message tells the user to call
  # `fit()` first. The erroring call is omitted from the message
  # (`call. = FALSE`).
  stop(
    "You must use `fit()` on your model specification before you can use `predict()`.",
    call. = FALSE
  )
}
3 changes: 3 additions & 0 deletions R/boost_tree.R
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,8 @@ xgb_pred <- function(object, newdata, ...) {

#' @importFrom purrr map_df
#' @export
#' @rdname multi_predict
#' @param trees An integer vector for the number of trees in the ensemble.
multi_predict._xgb.Booster <-
function(object, new_data, type = NULL, trees = NULL, ...) {
if (any(names(enquos(...)) == "newdata")) {
Expand Down Expand Up @@ -474,6 +476,7 @@ C5.0_train <-
}

#' @export
#' @rdname multi_predict
multi_predict._C5.0 <-
function(object, new_data, type = NULL, trees = NULL, ...) {
if (any(names(enquos(...)) == "newdata"))
Expand Down
2 changes: 2 additions & 0 deletions R/linear_reg.R
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,8 @@ predict_raw._elnet <- function(object, new_data, opts = list(), ...) {
#' @importFrom dplyr full_join as_tibble arrange
#' @importFrom tidyr gather
#' @export
#' @rdname multi_predict
#' @param penalty A numeric vector of penalty values.
multi_predict._elnet <-
function(object, new_data, type = NULL, penalty = NULL, ...) {
if (any(names(enquos(...)) == "newdata"))
Expand Down
1 change: 1 addition & 0 deletions R/logistic_reg.R
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ predict._lognet <- function (object, new_data, type = NULL, opts = list(), penal
#' @importFrom dplyr full_join as_tibble arrange
#' @importFrom tidyr gather
#' @export
#' @rdname multi_predict
multi_predict._lognet <-
function(object, new_data, type = NULL, penalty = NULL, ...) {
if (any(names(enquos(...)) == "newdata"))
Expand Down
2 changes: 2 additions & 0 deletions R/mars.R
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,8 @@ earth_reg_updater <- function(num, object, new_data, ...) {

#' @importFrom purrr map_df
#' @importFrom dplyr arrange
#' @rdname multi_predict
#' @param num_terms An integer vector for the number of MARS terms to retain.
#' @export
multi_predict._earth <-
function(object, new_data, type = NULL, num_terms = NULL, ...) {
Expand Down
1 change: 1 addition & 0 deletions R/multinom_reg.R
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ predict._multnet <-
#' @importFrom dplyr full_join as_tibble arrange
#' @importFrom tidyr gather
#' @export
#' @rdname multi_predict
multi_predict._multnet <-
function(object, new_data, type = NULL, penalty = NULL, ...) {
if (any(names(enquos(...)) == "newdata"))
Expand Down
40 changes: 40 additions & 0 deletions R/nearest_neighbor.R
Original file line number Diff line number Diff line change
Expand Up @@ -178,3 +178,43 @@ translate.nearest_neighbor <- function(x, engine = x$engine, ...) {
}
x
}


# ------------------------------------------------------------------------------

#' @importFrom purrr map_df
#' @importFrom dplyr starts_with
#' @rdname multi_predict
#' @param neighbors An integer vector for the number of nearest neighbors.
#' @export
multi_predict._train.kknn <-
  function(object, new_data, type = NULL, neighbors = NULL, ...) {
    # Catch the common misspelling of the `new_data` argument.
    if ("newdata" %in% names(enquos(...))) {
      stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
    }

    # Fall back to the `ks` value stored in the fitted call when no neighbor
    # values are supplied; predictions are produced in increasing order of k.
    if (is.null(neighbors)) {
      neighbors <- rlang::eval_tidy(object$fit$call$ks)
    }
    neighbors <- sort(neighbors)

    # Default prediction type follows the model's mode.
    if (is.null(type)) {
      type <- if (object$spec$mode == "classification") "class" else "numeric"
    }

    # One block of predictions per k, stacked row-wise.
    stacked <-
      purrr::map_df(
        neighbors,
        knn_by_k,
        object = object,
        new_data = new_data,
        type = type,
        ...
      )
    stacked <- dplyr::arrange(stacked, .row, neighbors)

    # Regroup into one tibble per row of `new_data`, dropping the leading
    # `.row` column, and return them as an unnamed list-column.
    per_row <- unname(split(stacked[, -1], stacked$.row))
    dplyr::tibble(.pred = per_row)
  }

# Predict with a single value of `k`: overwrite the `ks` element of the stored
# call, predict on `new_data`, and tag each prediction with `k` (as
# `neighbors`) and its row position (as `.row`). Only `.row`, `neighbors` and
# the `.pred*` columns are returned.
knn_by_k <- function(k, object, new_data, type, ...) {
  object$fit$call$ks <- k
  preds <- predict(object, new_data = new_data, type = type, ...)
  preds <- dplyr::mutate(preds, neighbors = k, .row = dplyr::row_number())
  dplyr::select(preds, .row, neighbors, dplyr::starts_with(".pred"))
}
1 change: 0 additions & 1 deletion R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -301,4 +301,3 @@ has_multi_pred.workflow <- function(object, ...) {
has_multi_pred(object$fit$model$model)
}


47 changes: 46 additions & 1 deletion man/multi_predict.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

75 changes: 75 additions & 0 deletions tests/testthat/test_nearest_neighbor_kknn.R
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,78 @@ test_that('kknn prediction', {

expect_equal(form_pred, predict(res_form, iris[1:5, c("Sepal.Width", "Species")])$.pred)
})


test_that('kknn multi-predict', {

  skip_if_not_installed("kknn")
  library(kknn)

  # From a `multi_predict()` result, recover the predictions of the single
  # sub-model with `n_neighbors` neighbors, ordered by original row and
  # restricted to the columns named in `cols`.
  pick_sub_model <- function(multi, n_neighbors, cols) {
    multi %>%
      mutate(.rows = row_number()) %>%
      unnest() %>%
      dplyr::filter(neighbors == !!n_neighbors) %>%
      arrange(.rows) %>%
      dplyr::select(!!cols)
  }

  # ----------------------------------------------------------------------------
  # classification

  holdout <- c(1:2, 50:51, 100:101)
  k_grid <- 1:10

  cls_fit <- fit_xy(
    nearest_neighbor(mode = "classification", neighbors = 3) %>%
      set_engine("kknn"),
    control = ctrl,
    x = iris[-holdout, num_pred],
    y = iris$Species[-holdout]
  )

  # Hard class predictions: one nested row per (observation, k) pair.
  cls_multi <- multi_predict(cls_fit, iris[holdout, num_pred], neighbors = k_grid)
  expect_equal(nrow(unnest(cls_multi)), length(holdout) * length(k_grid))
  expect_equal(nrow(cls_multi), length(holdout))

  # The k = 3 slice must agree with the ordinary predict() of the k = 3 fit.
  cls_single <- predict(cls_fit, iris[holdout, num_pred])
  expect_equal(cls_single, pick_sub_model(cls_multi, 3, ".pred_class"))

  # Class probabilities.
  prob_multi <- multi_predict(cls_fit, iris[holdout, num_pred],
                              neighbors = k_grid, type = "prob")
  expect_equal(nrow(unnest(prob_multi)), length(holdout) * length(k_grid))
  expect_equal(nrow(prob_multi), length(holdout))

  prob_single <- predict(cls_fit, iris[holdout, num_pred], type = "prob")
  expect_equal(prob_single, pick_sub_model(prob_multi, 3, names(prob_single)))

  # ----------------------------------------------------------------------------
  # regression

  car_rows <- 1:5
  k_grid <- 1:10

  reg_fit <- fit(
    nearest_neighbor(mode = "regression", neighbors = 3) %>%
      set_engine("kknn"),
    control = ctrl,
    mpg ~ ., data = mtcars[-car_rows, ]
  )

  reg_multi <- multi_predict(reg_fit, mtcars[car_rows, -1], neighbors = k_grid)
  expect_equal(nrow(unnest(reg_multi)), length(car_rows) * length(k_grid))
  expect_equal(nrow(reg_multi), length(car_rows))

  reg_single <- predict(reg_fit, mtcars[car_rows, -1])
  expect_equal(reg_single, pick_sub_model(reg_multi, 3, ".pred"))
})