diff --git a/NAMESPACE b/NAMESPACE index 70f252d81..ecb1212a3 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -17,19 +17,13 @@ S3method(predict,model_fit) S3method(predict,model_spec) S3method(predict,nullmodel) S3method(predict_class,"_lognet") -S3method(predict_class,model_fit) +S3method(predict_class,"_multnet") S3method(predict_classprob,"_lognet") S3method(predict_classprob,"_multnet") -S3method(predict_classprob,model_fit) -S3method(predict_confint,model_fit) S3method(predict_numeric,"_elnet") -S3method(predict_numeric,model_fit) -S3method(predict_predint,model_fit) -S3method(predict_quantile,model_fit) S3method(predict_raw,"_elnet") S3method(predict_raw,"_lognet") S3method(predict_raw,"_multnet") -S3method(predict_raw,model_fit) S3method(print,boost_tree) S3method(print,decision_tree) S3method(print,fit_control) @@ -103,20 +97,6 @@ export(nearest_neighbor) export(null_model) export(nullmodel) export(predict.model_fit) -export(predict_class) -export(predict_class.model_fit) -export(predict_classprob) -export(predict_classprob.model_fit) -export(predict_confint) -export(predict_confint.model_fit) -export(predict_numeric) -export(predict_numeric.model_fit) -export(predict_predint) -export(predict_predint.model_fit) -export(predict_quantile) -export(predict_quantile.model_fit) -export(predict_raw) -export(predict_raw.model_fit) export(rand_forest) export(rpart_train) export(set_args) diff --git a/NEWS.md b/NEWS.md index 775ee3b28..16b00685d 100644 --- a/NEWS.md +++ b/NEWS.md @@ -3,6 +3,7 @@ ## New Features * A "null model" is now available that fits a predictor-free model (using the mean of the outcome for regression or the mode for classification). + * `fit_xy()` can take a single column data frame or matrix for `y` without error ## Other Changes @@ -13,6 +14,8 @@ that are actually varying). * `fit_control()` not returns an S3 method. +* The prediction modules (e.g. `predict_class`, `predict_numeric`, etc) were de-exported. 
These were internal functions that were never intended to be called directly, but users had started relying on them. + ## Bug Fixes * `varying_args()` now uses the version from the `generics` package. This means @@ -33,6 +36,7 @@ column names once (#107). * For multinomial regression using glmnet, `multi_predict()` now pulls the correct default penalty (#108). +* Confidence and prediction intervals for logistic regression were previously computed for only a single level of the outcome. Intervals for both levels are now computed. (#156) # parsnip 0.0.1 diff --git a/R/linear_reg.R b/R/linear_reg.R index 456f4bc9c..91ed2d125 100644 --- a/R/linear_reg.R +++ b/R/linear_reg.R @@ -63,20 +63,20 @@ #' \pkg{spark} #' #' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "spark")} -#' +#' #' \pkg{keras} #' #' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "keras")} #' #' When using `glmnet` models, there is the option to pass -#' multiple values (or no values) to the `penalty` argument. -#' This can have an effect on the model object results. When using -#' the `predict()` method in these cases, the return object type -#' depends on the value of `penalty`. If a single value is -#' given, the results will be a simple numeric vector. When -#' multiple values or no values for `penalty` are used in -#' `linear_reg()`, the `predict()` method will return a data frame with -#' columns `values` and `lambda`. +#' multiple values (or no values) to the `penalty` argument. This +#' can have an effect on the model object results. When using the +#' `predict()` method in these cases, the return value depends on +#' the value of `penalty`. When using `predict()`, only a single +#' value of the penalty can be used. When predicting on multiple +#' penalties, the `multi_predict()` function can be used. It +#' returns a tibble with a list column called `.pred` that contains +#' a tibble with all of the penalty results. 
#' #' For prediction, the `stan` engine can compute posterior #' intervals analogous to confidence and prediction intervals. In @@ -130,7 +130,7 @@ print.linear_reg <- function(x, ...) { cat("Linear Regression Model Specification (", x$mode, ")\n\n", sep = "") model_printer(x, ...) - if(!is.null(x$method$fit$args)) { + if (!is.null(x$method$fit$args)) { cat("Model fit template:\n") print(show_call(x)) } @@ -216,12 +216,66 @@ organize_glmnet_pred <- function(x, object) { # ------------------------------------------------------------------------------ +# For `predict` methods that use `glmnet`, we have specific methods. +# Only one value of the penalty should be allowed when called by `predict()`: + +check_penalty <- function(penalty = NULL, object, multi = FALSE) { + + if (is.null(penalty)) { + penalty <- object$fit$lambda + } + + # when using `predict()`, allow for a single lambda + if (!multi) { + if (length(penalty) != 1) + stop("`penalty` should be a single numeric value. ", + "`multi_predict()` can be used to get multiple predictions ", + "per row of data.", call. = FALSE) + } + + if (length(object$fit$lambda) == 1 && penalty != object$fit$lambda) + stop("The glmnet model was fit with a single penalty value of ", + object$fit$lambda, ". Predicting with a value of ", + penalty, " will give incorrect results from `glmnet()`.", + call. = FALSE) + + penalty +} + +# ------------------------------------------------------------------------------ +# glmnet call stack for linear regression using `predict` when object has +# classes "_elnet" and "model_fit": +# +# predict() +# predict._elnet(penalty = NULL) <-- checks and sets penalty +# predict.model_fit() <-- checks for extra vars in ... 
+# predict_numeric() +# predict_numeric._elnet() +# predict_numeric.model_fit() +# predict.elnet() + + +# glmnet call stack for linear regression using `multi_predict` when object has +# classes "_elnet" and "model_fit": +# +# multi_predict() +# multi_predict._elnet(penalty = NULL) +# predict._elnet(multi = TRUE) <-- checks and sets penalty +# predict.model_fit() <-- checks for extra vars in ... +# predict_raw() +# predict_raw._elnet() +# predict_raw.model_fit(opts = list(s = penalty)) +# predict.elnet() + + #' @export predict._elnet <- - function(object, new_data, type = NULL, opts = list(), ...) { + function(object, new_data, type = NULL, opts = list(), penalty = NULL, multi = FALSE, ...) { if (any(names(enquos(...)) == "newdata")) stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) - + + object$spec$args$penalty <- check_penalty(penalty, object, multi) + object$spec <- eval_args(object$spec) predict.model_fit(object, new_data = new_data, type = type, opts = opts, ...) } @@ -230,7 +284,7 @@ predict._elnet <- predict_numeric._elnet <- function(object, new_data, ...) { if (any(names(enquos(...)) == "newdata")) stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) - + object$spec <- eval_args(object$spec) predict_numeric.model_fit(object, new_data = new_data, ...) } @@ -239,8 +293,9 @@ predict_numeric._elnet <- function(object, new_data, ...) { predict_raw._elnet <- function(object, new_data, opts = list(), ...) { if (any(names(enquos(...)) == "newdata")) stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) - + object$spec <- eval_args(object$spec) + opts$s <- object$spec$args$penalty predict_raw.model_fit(object, new_data = new_data, opts = opts, ...) } @@ -251,14 +306,17 @@ multi_predict._elnet <- function(object, new_data, type = NULL, penalty = NULL, ...) { if (any(names(enquos(...)) == "newdata")) stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) - + dots <- list(...) 
- if (is.null(penalty)) - penalty <- object$fit$lambda - dots$s <- penalty object$spec <- eval_args(object$spec) - pred <- predict(object, new_data = new_data, type = "raw", opts = dots) + + if (is.null(penalty)) { + penalty <- object$fit$lambda + } + + pred <- predict._elnet(object, new_data = new_data, type = "raw", + opts = dots, penalty = penalty, multi = TRUE) param_key <- tibble(group = colnames(pred), penalty = penalty) pred <- as_tibble(pred) pred$.row <- 1:nrow(pred) diff --git a/R/logistic_reg.R b/R/logistic_reg.R index 10a4df1db..1c5a76d0b 100644 --- a/R/logistic_reg.R +++ b/R/logistic_reg.R @@ -67,14 +67,14 @@ #' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::logistic_reg(), "keras")} #' #' When using `glmnet` models, there is the option to pass -#' multiple values (or no values) to the `penalty` argument. -#' This can have an effect on the model object results. When using -#' the `predict()` method in these cases, the return object type -#' depends on the value of `penalty`. If a single value is -#' given, the results will be a simple numeric vector. When -#' multiple values or no values for `penalty` are used in -#' `logistic_reg()`, the `predict()` method will return a data frame with -#' columns `values` and `lambda`. +#' multiple values (or no values) to the `penalty` argument. This +#' can have an effect on the model object results. When using the +#' `predict()` method in these cases, the return value depends on +#' the value of `penalty`. When using `predict()`, only a single +#' value of the penalty can be used. When predicting on multiple +#' penalties, the `multi_predict()` function can be used. It +#' returns a tibble with a list column called `.pred` that contains +#' a tibble with all of the penalty results. #' #' For prediction, the `stan` engine can compute posterior #' intervals analogous to confidence and prediction intervals. 
In @@ -235,41 +235,41 @@ organize_glmnet_prob <- function(x, object) { } # ------------------------------------------------------------------------------ +# glmnet call stack for logistic regression using `predict` when object has +# classes "_lognet" and "model_fit" (for class predictions): +# +# predict() +# predict._lognet(penalty = NULL) <-- checks and sets penalty +# predict.model_fit() <-- checks for extra vars in ... +# predict_class() +# predict_class._lognet() +# predict_class.model_fit() +# predict.lognet() + + +# glmnet call stack for logistic regression using `multi_predict` when object has +# classes "_lognet" and "model_fit" (for class predictions): +# +# multi_predict() +# multi_predict._lognet(penalty = NULL) +# predict._lognet(multi = TRUE) <-- checks and sets penalty +# predict.model_fit() <-- checks for extra vars in ... +# predict_raw() +# predict_raw._lognet() +# predict_raw.model_fit(opts = list(s = penalty)) +# predict.lognet() -#' @export -predict._lognet <- function (object, new_data, type = NULL, opts = list(), ...) { - if (any(names(enquos(...)) == "newdata")) - stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) - - object$spec <- eval_args(object$spec) - predict.model_fit(object, new_data = new_data, type = type, opts = opts, ...) -} - -#' @export -predict_class._lognet <- function (object, new_data, ...) { - if (any(names(enquos(...)) == "newdata")) - stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) - - object$spec <- eval_args(object$spec) - predict_class.model_fit(object, new_data = new_data, ...) -} +# ------------------------------------------------------------------------------ #' @export -predict_classprob._lognet <- function (object, new_data, ...) { +predict._lognet <- function (object, new_data, type = NULL, opts = list(), penalty = NULL, multi = FALSE, ...) { if (any(names(enquos(...)) == "newdata")) stop("Did you mean to use `new_data` instead of `newdata`?", call. 
= FALSE) - object$spec <- eval_args(object$spec) - predict_classprob.model_fit(object, new_data = new_data, ...) -} - -#' @export -predict_raw._lognet <- function (object, new_data, opts = list(), ...) { - if (any(names(enquos(...)) == "newdata")) - stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) + object$spec$args$penalty <- check_penalty(penalty, object, multi) object$spec <- eval_args(object$spec) - predict_raw.model_fit(object, new_data = new_data, opts = opts, ...) + predict.model_fit(object, new_data = new_data, type = type, opts = opts, ...) } @@ -281,15 +281,18 @@ multi_predict._lognet <- if (any(names(enquos(...)) == "newdata")) stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) + if (is_quosure(penalty)) + penalty <- eval_tidy(penalty) + dots <- list(...) if (is.null(penalty)) - penalty <- object$fit$lambda + penalty <- eval_tidy(object$fit$lambda) dots$s <- penalty if (is.null(type)) type <- "class" - if (!(type %in% c("class", "prob", "link"))) { - stop ("`type` should be either 'class', 'link', or 'prob'.", call. = FALSE) + if (!(type %in% c("class", "prob", "link", "raw"))) { + stop ("`type` should be either 'class', 'link', 'raw', or 'prob'.", call. = FALSE) } if (type == "prob") dots$type <- "response" @@ -297,7 +300,7 @@ multi_predict._lognet <- dots$type <- type object$spec <- eval_args(object$spec) - pred <- predict(object, new_data = new_data, type = "raw", opts = dots) + pred <- predict.model_fit(object, new_data = new_data, type = "raw", opts = dots) param_key <- tibble(group = colnames(pred), penalty = penalty) pred <- as_tibble(pred) pred$.row <- 1:nrow(pred) @@ -321,6 +324,38 @@ multi_predict._lognet <- tibble(.pred = pred) } + + + + +#' @export +predict_class._lognet <- function (object, new_data, ...) { + if (any(names(enquos(...)) == "newdata")) + stop("Did you mean to use `new_data` instead of `newdata`?", call. 
= FALSE) + + object$spec <- eval_args(object$spec) + predict_class.model_fit(object, new_data = new_data, ...) +} + +#' @export +predict_classprob._lognet <- function (object, new_data, ...) { + if (any(names(enquos(...)) == "newdata")) + stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) + + object$spec <- eval_args(object$spec) + predict_classprob.model_fit(object, new_data = new_data, ...) +} + +#' @export +predict_raw._lognet <- function (object, new_data, opts = list(), ...) { + if (any(names(enquos(...)) == "newdata")) + stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) + + object$spec <- eval_args(object$spec) + predict_raw.model_fit(object, new_data = new_data, opts = opts, ...) +} + + # ------------------------------------------------------------------------------ #' @importFrom utils globalVariables diff --git a/R/logistic_reg_data.R b/R/logistic_reg_data.R index a5aef8bfb..6fa1e0c94 100644 --- a/R/logistic_reg_data.R +++ b/R/logistic_reg_data.R @@ -4,7 +4,7 @@ logistic_reg_arg_key <- data.frame( glmnet = c( "lambda", "alpha"), spark = c("reg_param", "elastic_net_param"), stan = c( NA, NA), - keras = c( "decay", NA), + keras = c( "decay", NA), stringsAsFactors = FALSE, row.names = c("penalty", "mixture") ) @@ -77,12 +77,20 @@ logistic_reg_glm_data <- const <- qt(hf_lvl, df = object$fit$df.residual, lower.tail = FALSE) trans <- object$fit$family$linkinv - res <- + res_2 <- tibble( - .pred_lower = trans(results$fit - const * results$se.fit), - .pred_upper = trans(results$fit + const * results$se.fit) + lo = trans(results$fit - const * results$se.fit), + hi = trans(results$fit + const * results$se.fit) ) - if(object$spec$method$confint$extras$std_error) + res_1 <- res_2 + res_1$lo <- 1 - res_2$hi + res_1$hi <- 1 - res_2$lo + res <- bind_cols(res_1, res_2) + lo_nms <- paste0(".pred_lower_", object$lvl) + hi_nms <- paste0(".pred_upper_", object$lvl) + colnames(res) <- c(lo_nms[1], hi_nms[1], lo_nms[2], 
hi_nms[2]) + + if (object$spec$method$confint$extras$std_error) res$.std_error <- results$se.fit res }, @@ -199,21 +207,29 @@ logistic_reg_stan_data <- confint = list( pre = NULL, post = function(results, object) { - res <- + res_2 <- tibble( - .pred_lower = + lo = convert_stan_interval( results, level = object$spec$method$confint$extras$level ), - .pred_upper = + hi = convert_stan_interval( results, level = object$spec$method$confint$extras$level, lower = FALSE ), ) - if(object$spec$method$confint$extras$std_error) + res_1 <- res_2 + res_1$lo <- 1 - res_2$hi + res_1$hi <- 1 - res_2$lo + res <- bind_cols(res_1, res_2) + lo_nms <- paste0(".pred_lower_", object$lvl) + hi_nms <- paste0(".pred_upper_", object$lvl) + colnames(res) <- c(lo_nms[1], hi_nms[1], lo_nms[2], hi_nms[2]) + + if (object$spec$method$confint$extras$std_error) res$.std_error <- apply(results, 2, sd, na.rm = TRUE) res }, @@ -229,21 +245,29 @@ logistic_reg_stan_data <- predint = list( pre = NULL, post = function(results, object) { - res <- + res_2 <- tibble( - .pred_lower = + lo = convert_stan_interval( results, level = object$spec$method$predint$extras$level ), - .pred_upper = + hi = convert_stan_interval( results, level = object$spec$method$predint$extras$level, lower = FALSE ), ) - if(object$spec$method$predint$extras$std_error) + res_1 <- res_2 + res_1$lo <- 1 - res_2$hi + res_1$hi <- 1 - res_2$lo + res <- bind_cols(res_1, res_2) + lo_nms <- paste0(".pred_lower_", object$lvl) + hi_nms <- paste0(".pred_upper_", object$lvl) + colnames(res) <- c(lo_nms[1], hi_nms[1], lo_nms[2], hi_nms[2]) + + if (object$spec$method$predint$extras$std_error) res$.std_error <- apply(results, 2, sd, na.rm = TRUE) res }, @@ -327,4 +351,4 @@ logistic_reg_keras_data <- x = quote(as.matrix(new_data)) ) ) - ) \ No newline at end of file + ) diff --git a/R/multinom_reg.R b/R/multinom_reg.R index 33fc37918..8bc0deed8 100644 --- a/R/multinom_reg.R +++ b/R/multinom_reg.R @@ -58,14 +58,14 @@ #' 
\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::multinom_reg(), "keras")} #' #' When using `glmnet` models, there is the option to pass -#' multiple values (or no values) to the `penalty` argument. -#' This can have an effect on the model object results. When using -#' the `predict()` method in these cases, the return object type -#' depends on the value of `penalty`. If a single value is -#' given, the results will be a simple numeric vector. When -#' multiple values or no values for `penalty` are used in -#' `multinom_reg()`, the `predict()` method will return a data frame with -#' columns `values` and `lambda`. +#' multiple values (or no values) to the `penalty` argument. This +#' can have an effect on the model object results. When using the +#' `predict()` method in these cases, the return value depends on +#' the value of `penalty`. When using `predict()`, only a single +#' value of the penalty can be used. When predicting on multiple +#' penalties, the `multi_predict()` function can be used. It +#' returns a tibble with a list column called `.pred` that contains +#' a tibble with all of the penalty results. #' #' @note For models created using the spark engine, there are #' several differences to consider. First, only the formula @@ -188,54 +188,46 @@ organize_multnet_prob <- function(x, object) { } # ------------------------------------------------------------------------------ +# glmnet call stack for multinomial regression using `predict` when object has +# classes "_multnet" and "model_fit" (for class predictions): +# +# predict() +# predict._multnet(penalty = NULL) <-- checks and sets penalty +# predict.model_fit() <-- checks for extra vars in ... 
+# predict_class() +# predict_class._multnet() +# predict.multnet() + + +# glmnet call stack for multinomial regression using `multi_predict` when object has +# classes "_multnet" and "model_fit" (for class predictions): +# +# multi_predict() +# multi_predict._multnet(penalty = NULL) +# predict._multnet(multi = TRUE) <-- checks and sets penalty +# predict.model_fit() <-- checks for extra vars in ... +# predict_raw() +# predict_raw._multnet() +# predict_raw.model_fit(opts = list(s = penalty)) +# predict.multnet() -#' @export -predict._lognet <- function (object, new_data, type = NULL, opts = list(), ...) { - object$spec <- eval_args(object$spec) - predict.model_fit(object, new_data = new_data, type = type, opts = opts, ...) -} - -#' @export -predict_class._lognet <- function (object, new_data, ...) { - object$spec <- eval_args(object$spec) - predict_class.model_fit(object, new_data = new_data, ...) -} - -#' @export -predict_classprob._multnet <- function (object, new_data, ...) { - object$spec <- eval_args(object$spec) - predict_classprob.model_fit(object, new_data = new_data, ...) -} - -#' @export -predict_raw._multnet <- function (object, new_data, opts = list(), ...) { - object$spec <- eval_args(object$spec) - predict_raw.model_fit(object, new_data = new_data, opts = opts, ...) -} - +# ------------------------------------------------------------------------------ #' @export predict._multnet <- - function(object, new_data, type = NULL, opts = list(), penalty = NULL, ...) { - dots <- list(...) - if (is.null(penalty)) - penalty <- object$fit$lambda + function(object, new_data, type = NULL, opts = list(), penalty = NULL, multi = FALSE, ...) { + + object$spec$args$penalty <- check_penalty(penalty, object, multi) - if (length(penalty) != 1) - stop("`penalty` should be a single numeric value. ", - "`multi_predict()` can be used to get multiple predictions ", - "per row of data.", call. 
= FALSE) object$spec <- eval_args(object$spec) res <- predict.model_fit( object = object, new_data = new_data, type = type, - opts = opts, - penalty = penalty + opts = opts ) - res -} - + res + } #' @importFrom dplyr full_join as_tibble arrange #' @importFrom tidyr gather @@ -255,8 +247,8 @@ multi_predict._multnet <- if (is.null(type)) type <- "class" - if (!(type %in% c("class", "prob", "link"))) { - stop ("`type` should be either 'class', 'link', or 'prob'.", call. = FALSE) + if (!(type %in% c("class", "prob", "link", "raw"))) { + stop ("`type` should be either 'class', 'link', 'raw', or 'prob'.", call. = FALSE) } if (type == "prob") dots$type <- "response" @@ -296,6 +288,29 @@ multi_predict._multnet <- tibble(.pred = pred) } +#' @export +predict_class._multnet <- function (object, new_data, ...) { + object$spec <- eval_args(object$spec) + predict_class.model_fit(object, new_data = new_data, ...) +} + +#' @export +predict_classprob._multnet <- function (object, new_data, ...) { + object$spec <- eval_args(object$spec) + predict_classprob.model_fit(object, new_data = new_data, ...) +} + +#' @export +predict_raw._multnet <- function (object, new_data, opts = list(), ...) { + object$spec <- eval_args(object$spec) + predict_raw.model_fit(object, new_data = new_data, opts = opts, ...) +} + + + +# ------------------------------------------------------------------------------ + +# This checks as a pre-processor in the model data object check_glmnet_lambda <- function(dat, object) { if (length(object$fit$lambda) > 1) stop( diff --git a/R/predict.R b/R/predict.R index 7b498bc6b..d0dc1f735 100644 --- a/R/predict.R +++ b/R/predict.R @@ -14,9 +14,23 @@ #' predict function that will be used when `type = "raw"`. The #' list should not include options for the model object or the #' new data being predicted. -#' @param ... Ignored. To pass arguments to pass to the underlying -#' function when `predict.model_fit(type = "raw")`, -#' use the `opts` argument. +#' @param ... 
Arguments to the underlying model's prediction +#' function cannot be passed here (see `opts`). There are some +#' `parsnip` related options that can be passed, depending on the +#' value of `type`. Possible arguments are: +#' \itemize{ +#' \item `level`: for `type`s of "conf_int" and "pred_int" this +#' is the parameter for the tail area of the intervals +#' (e.g. confidence level for confidence intervals). +#' Default value is 0.95. +#' \item `std_error`: add the standard error of fit or +#' prediction for `type`s of "conf_int" and "pred_int". +#' Default value is `FALSE`. +#' \item `quantile`: the quantile(s) for quantile regression +#' (not implemented yet) +#' \item `time`: the time(s) for hazard probability estimates +#' (not implemented yet) +#' } #' @details If "type" is not supplied to `predict()`, then a choice #' is made (`type = "numeric"` for regression models and #' `type = "class"` for classification). @@ -49,15 +63,18 @@ #' a list-column. Each list element contains a tibble with columns #' `.pred` and `.quantile` (and perhaps other columns). #' -#' Using `type = "raw"` with `predict.model_fit()` (or using -#' `predict_raw()`) will return the unadulterated results of the -#' prediction function. +#' Using `type = "raw"` with `predict.model_fit()` will return +#' the unadulterated results of the prediction function. #' #' In the case of Spark-based models, since table columns cannot #' contain dots, the same convention is used except 1) no dots #' appear in names and 2) vectors are never returned but #' type-specific prediction functions. #' +#' When the model fit failed and the error was captured, the +#' `predict()` function will return the same structure as above but +#' filled with missing values. This does not currently work for +#' multivariate models. 
#' @examples #' library(dplyr) #' @@ -90,10 +107,26 @@ #' @method predict model_fit #' @export predict.model_fit #' @export -predict.model_fit <- function (object, new_data, type = NULL, opts = list(), ...) { - if (any(names(enquos(...)) == "newdata")) +predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...) { + the_dots <- enquos(...) + if (any(names(the_dots) == "newdata")) stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) + if (inherits(object$fit, "try-error")) { + warning("Model fit failed; cannot make predictions.", call. = FALSE) + return(NULL) + } + + other_args <- c("level", "std_error", "quantile") # "time" for survival probs later + is_pred_arg <- names(the_dots) %in% other_args + if (any(!is_pred_arg)) { + bad_args <- names(the_dots)[!is_pred_arg] + bad_args <- paste0("`", bad_args, "`", collapse = ", ") + stop("The ellipses are not used to pass args to the model function's ", + "predict function. These arguments cannot be used: ", + bad_args, call. = FALSE) + } + type <- check_pred_type(object, type) if (type != "raw" && length(opts) > 0) warning("`opts` is only used with `type = 'raw'` and was ignored.") @@ -214,14 +247,19 @@ prepare_data <- function(object, new_data) { #' multiple rows per sub-model. #' @keywords internal #' @export -multi_predict <- function(object, ...) +multi_predict <- function(object, ...) { + if (inherits(object$fit, "try-error")) { + warning("Model fit failed; cannot make predictions.", call. = FALSE) + return(NULL) + } UseMethod("multi_predict") +} #' @keywords internal #' @export #' @rdname multi_predict multi_predict.default <- function(object, ...) - stop ("No `multi_predict` method exists for objects with classes ", + stop("No `multi_predict` method exists for objects with classes ", paste0("'", class(), "'", collapse = ", "), call. 
= FALSE) #' @export diff --git a/R/predict_class.R b/R/predict_class.R index b1cd24d8e..0e292a8b3 100644 --- a/R/predict_class.R +++ b/R/predict_class.R @@ -1,21 +1,26 @@ -#' Other predict methods. -#' -#' These are internal functions not meant to be directly called by the user. -#' -#' @keywords internal -#' @rdname other_predict -#' @inheritParams predict.model_fit -#' @method predict_class model_fit -#' @export predict_class.model_fit -#' @export -predict_class.model_fit <- function (object, new_data, ...) { - if(object$spec$mode != "classification") +# Other predict methods. +# +# These are internal functions not meant to be directly called by the user. +# +# @keywords internal +# @rdname other_predict +# @inheritParams predict.model_fit +# @method predict_class model_fit +# @export predict_class.model_fit +# @export +predict_class.model_fit <- function(object, new_data, ...) { + if (object$spec$mode != "classification") stop("`predict.model_fit()` is for predicting factor outcomes.", call. = FALSE) if (!any(names(object$spec$method) == "class")) stop("No class prediction module defined for this model.", call. = FALSE) + if (inherits(object$fit, "try-error")) { + warning("Model fit failed; cannot make predictions.", call. = FALSE) + return(NULL) + } + new_data <- prepare_data(object, new_data) # preprocess data @@ -28,7 +33,7 @@ predict_class.model_fit <- function (object, new_data, ...) { res <- eval_tidy(pred_call) # post-process the predictions - if(!is.null(object$spec$method$class$post)) { + if (!is.null(object$spec$method$class$post)) { res <- object$spec$method$class$post(res, object) } @@ -43,9 +48,10 @@ predict_class.model_fit <- function (object, new_data, ...) { res } -#' @export -#' @keywords internal -#' @rdname other_predict -#' @inheritParams predict.model_fit -predict_class <- function (object, ...) +# @export +# @keywords internal +# @rdname other_predict +# @inheritParams predict.model_fit +predict_class <- function(object, ...) 
UseMethod("predict_class") + diff --git a/R/predict_classprob.R b/R/predict_classprob.R index de816c190..d902e7735 100644 --- a/R/predict_classprob.R +++ b/R/predict_classprob.R @@ -1,18 +1,23 @@ -#' @keywords internal -#' @rdname other_predict -#' @inheritParams predict.model_fit -#' @method predict_classprob model_fit -#' @export predict_classprob.model_fit -#' @export +# @keywords internal +# @rdname other_predict +# @inheritParams predict.model_fit +# @method predict_classprob model_fit +# @export predict_classprob.model_fit +# @export #' @importFrom tibble as_tibble is_tibble tibble -predict_classprob.model_fit <- function (object, new_data, ...) { - if(object$spec$mode != "classification") +predict_classprob.model_fit <- function(object, new_data, ...) { + if (object$spec$mode != "classification") stop("`predict.model_fit()` is for predicting factor outcomes.", call. = FALSE) if (!any(names(object$spec$method) == "classprob")) stop("No class probability module defined for this model.", call. = FALSE) + if (inherits(object$fit, "try-error")) { + warning("Model fit failed; cannot make predictions.", call. = FALSE) + return(NULL) + } + new_data <- prepare_data(object, new_data) # preprocess data @@ -25,7 +30,7 @@ predict_classprob.model_fit <- function (object, new_data, ...) { res <- eval_tidy(pred_call) # post-process the predictions - if(!is.null(object$spec$method$classprob$post)) { + if (!is.null(object$spec$method$classprob$post)) { res <- object$spec$method$classprob$post(res, object) } @@ -39,9 +44,9 @@ predict_classprob.model_fit <- function (object, new_data, ...) { res } -#' @export -#' @keywords internal -#' @rdname other_predict -#' @inheritParams predict.model_fit -predict_classprob <- function (object, ...) +# @export +# @keywords internal +# @rdname other_predict +# @inheritParams predict.model_fit +predict_classprob <- function(object, ...) 
UseMethod("predict_classprob") diff --git a/R/predict_interval.R b/R/predict_interval.R index 390e97936..f38838e00 100644 --- a/R/predict_interval.R +++ b/R/predict_interval.R @@ -1,20 +1,24 @@ -#' @keywords internal -#' @rdname other_predict -#' @param level A single numeric value between zero and one for the -#' interval estimates. -#' @param std_error A single logical for wether the standard error should be -#' returned (assuming that the model can compute it). -#' @inheritParams predict.model_fit -#' @method predict_confint model_fit -#' @export predict_confint.model_fit -#' @export -predict_confint.model_fit <- - function (object, new_data, level = 0.95, std_error = FALSE, ...) { +# @keywords internal +# @rdname other_predict +# @param level A single numeric value between zero and one for the +# interval estimates. +# @param std_error A single logical for whether the standard error should be +# returned (assuming that the model can compute it). +# @inheritParams predict.model_fit +# @method predict_confint model_fit +# @export predict_confint.model_fit +# @export +predict_confint.model_fit <- function(object, new_data, level = 0.95, std_error = FALSE, ...) { if (is.null(object$spec$method$confint)) stop("No confidence interval method defined for this ", "engine.", call. = FALSE) + if (inherits(object$fit, "try-error")) { + warning("Model fit failed; cannot make predictions.", call. = FALSE) + return(NULL) + } + new_data <- prepare_data(object, new_data) # preprocess data @@ -29,7 +33,7 @@ predict_confint.model_fit <- res <- eval_tidy(pred_call) # post-process the predictions - if(!is.null(object$spec$method$confint$post)) { + if (!is.null(object$spec$method$confint$post)) { res <- object$spec$method$confint$post(res, object) } @@ -38,28 +42,32 @@ predict_confint.model_fit <- res } -#' @export -#' @keywords internal -#' @rdname other_predict -#' @inheritParams predict.model_fit -predict_confint <- function (object, ...) 
+# @export +# @keywords internal +# @rdname other_predict +# @inheritParams predict.model_fit +predict_confint <- function(object, ...) UseMethod("predict_confint") -################################################################## +# ------------------------------------------------------------------------------ -#' @keywords internal -#' @rdname other_predict -#' @inheritParams predict.model_fit -#' @method predict_predint model_fit -#' @export predict_predint.model_fit -#' @export -predict_predint.model_fit <- - function (object, new_data, level = 0.95, std_error = FALSE, ...) { +# @keywords internal +# @rdname other_predict +# @inheritParams predict.model_fit +# @method predict_predint model_fit +# @export predict_predint.model_fit +# @export +predict_predint.model_fit <- function(object, new_data, level = 0.95, std_error = FALSE, ...) { if (is.null(object$spec$method$predint)) stop("No prediction interval method defined for this ", "engine.", call. = FALSE) + if (inherits(object$fit, "try-error")) { + warning("Model fit failed; cannot make predictions.", call. = FALSE) + return(NULL) + } + new_data <- prepare_data(object, new_data) # preprocess data @@ -75,7 +83,7 @@ predict_predint.model_fit <- res <- eval_tidy(pred_call) # post-process the predictions - if(!is.null(object$spec$method$predint$post)) { + if (!is.null(object$spec$method$predint$post)) { res <- object$spec$method$predint$post(res, object) } @@ -84,14 +92,10 @@ predict_predint.model_fit <- res } -#' @export -#' @keywords internal -#' @rdname other_predict -#' @inheritParams predict.model_fit -predict_predint <- function (object, ...) +# @export +# @keywords internal +# @rdname other_predict +# @inheritParams predict.model_fit +predict_predint <- function(object, ...) 
UseMethod("predict_predint") - - - - diff --git a/R/predict_numeric.R b/R/predict_numeric.R index 054fc3eb5..3a509546b 100644 --- a/R/predict_numeric.R +++ b/R/predict_numeric.R @@ -1,11 +1,11 @@ -#' @keywords internal -#' @rdname other_predict -#' @inheritParams predict.model_fit -#' @method predict_numeric model_fit -#' @export predict_numeric.model_fit -#' @export - -predict_numeric.model_fit <- function (object, new_data, ...) { +# @keywords internal +# @rdname other_predict +# @inheritParams predict.model_fit +# @method predict_numeric model_fit +# @export predict_numeric.model_fit +# @export + +predict_numeric.model_fit <- function(object, new_data, ...) { if (object$spec$mode != "regression") stop("`predict_numeric()` is for predicting numeric outcomes. ", "Use `predict_class()` or `predict_classprob()` for ", @@ -14,6 +14,11 @@ predict_numeric.model_fit <- function (object, new_data, ...) { if (!any(names(object$spec$method) == "numeric")) stop("No prediction module defined for this model.", call. = FALSE) + if (inherits(object$fit, "try-error")) { + warning("Model fit failed; cannot make predictions.", call. = FALSE) + return(NULL) + } + new_data <- prepare_data(object, new_data) # preprocess data @@ -40,9 +45,9 @@ predict_numeric.model_fit <- function (object, new_data, ...) { } -#' @export -#' @keywords internal -#' @rdname other_predict -#' @inheritParams predict_numeric.model_fit -predict_numeric <- function (object, ...) +# @export +# @keywords internal +# @rdname other_predict +# @inheritParams predict_numeric.model_fit +predict_numeric <- function(object, ...) UseMethod("predict_numeric") diff --git a/R/predict_quantile.R b/R/predict_quantile.R index ed8cfdbe3..698ddb4c8 100644 --- a/R/predict_quantile.R +++ b/R/predict_quantile.R @@ -1,41 +1,46 @@ -#' @keywords internal -#' @rdname other_predict -#' @param quant A vector of numbers between 0 and 1 for the quantile being -#' predicted. 
-#' @inheritParams predict.model_fit -#' @method predict_quantile model_fit -#' @export predict_quantile.model_fit -#' @export +# @keywords internal +# @rdname other_predict +# @param quantile A vector of numbers between 0 and 1 for the quantile being +# predicted. +# @inheritParams predict.model_fit +# @method predict_quantile model_fit +# @export predict_quantile.model_fit +# @export predict_quantile.model_fit <- function (object, new_data, quantile = (1:9)/10, ...) { - + if (is.null(object$spec$method$quantile)) stop("No quantile prediction method defined for this ", "engine.", call. = FALSE) - + + if (inherits(object$fit, "try-error")) { + warning("Model fit failed; cannot make predictions.", call. = FALSE) + return(NULL) + } + new_data <- prepare_data(object, new_data) - + # preprocess data if (!is.null(object$spec$method$quantile$pre)) new_data <- object$spec$method$quantile$pre(new_data, object) - + # Pass some extra arguments to be used in post-processor object$spec$method$quantile$args$p <- quantile pred_call <- make_pred_call(object$spec$method$quantile) - + res <- eval_tidy(pred_call) - + # post-process the predictions if(!is.null(object$spec$method$quantile$post)) { res <- object$spec$method$quantile$post(res, object) } - + res } -#' @export -#' @keywords internal -#' @rdname other_predict -#' @inheritParams predict.model_fit +# @export +# @keywords internal +# @rdname other_predict +# @inheritParams predict.model_fit predict_quantile <- function (object, ...) UseMethod("predict_quantile") diff --git a/R/predict_raw.R b/R/predict_raw.R index c3f7dee25..315c9dd0a 100644 --- a/R/predict_raw.R +++ b/R/predict_raw.R @@ -1,9 +1,9 @@ -#' @rdname predict.model_fit -#' @inheritParams predict.model_fit -#' @method predict_raw model_fit -#' @export predict_raw.model_fit -#' @export -predict_raw.model_fit <- function (object, new_data, opts = list(), ...) 
{ +# @rdname predict.model_fit +# @inheritParams predict.model_fit +# @method predict_raw model_fit +# @export predict_raw.model_fit +# @export +predict_raw.model_fit <- function(object, new_data, opts = list(), ...) { protected_args <- names(object$spec$method$raw$args) dup_args <- names(opts) %in% protected_args if (any(dup_args)) { @@ -13,27 +13,32 @@ predict_raw.model_fit <- function (object, new_data, opts = list(), ...) { object$spec$method$raw$args <- c(object$spec$method$raw$args, opts) } - + if (!any(names(object$spec$method) == "raw")) stop("No raw prediction module defined for this model.", call. = FALSE) - + + if (inherits(object$fit, "try-error")) { + warning("Model fit failed; cannot make predictions.", call. = FALSE) + return(NULL) + } + new_data <- prepare_data(object, new_data) - + # preprocess data if (!is.null(object$spec$method$raw$pre)) new_data <- object$spec$method$raw$pre(new_data, object) - + # create prediction call pred_call <- make_pred_call(object$spec$method$raw) - + res <- eval_tidy(pred_call) res } -#' @export -#' @rdname predict.model_fit -#' @inheritParams predict_raw.model_fit -predict_raw <- function (object, ...) +# @export +# @rdname predict.model_fit +# @inheritParams predict_raw.model_fit +predict_raw <- function(object, ...) UseMethod("predict_raw") diff --git a/man/linear_reg.Rd b/man/linear_reg.Rd index c4539fd37..c009ba8a3 100644 --- a/man/linear_reg.Rd +++ b/man/linear_reg.Rd @@ -106,14 +106,14 @@ model, the template of the fit calls are: \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "keras")} When using \code{glmnet} models, there is the option to pass -multiple values (or no values) to the \code{penalty} argument. -This can have an effect on the model object results. When using -the \code{predict()} method in these cases, the return object type -depends on the value of \code{penalty}. If a single value is -given, the results will be a simple numeric vector. 
When -multiple values or no values for \code{penalty} are used in -\code{linear_reg()}, the \code{predict()} method will return a data frame with -columns \code{values} and \code{lambda}. +multiple values (or no values) to the \code{penalty} argument. This +can have an effect on the model object results. When using the +\code{predict()} method in these cases, the return value depends on +the value of \code{penalty}. When using \code{predict()}, only a single +value of the penalty can be used. When predicting on multiple +penalties, the \code{multi_predict()} function can be used. It +returns a tibble with a list column called \code{.pred} that contains +a tibble with all of the penalty results. For prediction, the \code{stan} engine can compute posterior intervals analogous to confidence and prediction intervals. In diff --git a/man/logistic_reg.Rd b/man/logistic_reg.Rd index e7ac89673..0d36b3b73 100644 --- a/man/logistic_reg.Rd +++ b/man/logistic_reg.Rd @@ -104,14 +104,14 @@ model, the template of the fit calls are: \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::logistic_reg(), "keras")} When using \code{glmnet} models, there is the option to pass -multiple values (or no values) to the \code{penalty} argument. -This can have an effect on the model object results. When using -the \code{predict()} method in these cases, the return object type -depends on the value of \code{penalty}. If a single value is -given, the results will be a simple numeric vector. When -multiple values or no values for \code{penalty} are used in -\code{logistic_reg()}, the \code{predict()} method will return a data frame with -columns \code{values} and \code{lambda}. +multiple values (or no values) to the \code{penalty} argument. This +can have an effect on the model object results. When using the +\code{predict()} method in these cases, the return value depends on +the value of \code{penalty}. When using \code{predict()}, only a single +value of the penalty can be used. 
When predicting on multiple +penalties, the \code{multi_predict()} function can be used. It +returns a tibble with a list column called \code{.pred} that contains +a tibble with all of the penalty results. For prediction, the \code{stan} engine can compute posterior intervals analogous to confidence and prediction intervals. In diff --git a/man/multinom_reg.Rd b/man/multinom_reg.Rd index ebcc03f49..f85f7a99c 100644 --- a/man/multinom_reg.Rd +++ b/man/multinom_reg.Rd @@ -95,14 +95,14 @@ model, the template of the fit calls are: \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::multinom_reg(), "keras")} When using \code{glmnet} models, there is the option to pass -multiple values (or no values) to the \code{penalty} argument. -This can have an effect on the model object results. When using -the \code{predict()} method in these cases, the return object type -depends on the value of \code{penalty}. If a single value is -given, the results will be a simple numeric vector. When -multiple values or no values for \code{penalty} are used in -\code{multinom_reg()}, the \code{predict()} method will return a data frame with -columns \code{values} and \code{lambda}. +multiple values (or no values) to the \code{penalty} argument. This +can have an effect on the model object results. When using the +\code{predict()} method in these cases, the return value depends on +the value of \code{penalty}. When using \code{predict()}, only a single +value of the penalty can be used. When predicting on multiple +penalties, the \code{multi_predict()} function can be used. It +returns a tibble with a list column called \code{.pred} that contains +a tibble with all of the penalty results. 
} \examples{ diff --git a/man/other_predict.Rd b/man/other_predict.Rd deleted file mode 100644 index 57e3bf3f2..000000000 --- a/man/other_predict.Rd +++ /dev/null @@ -1,67 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/predict_class.R, R/predict_classprob.R, -% R/predict_interval.R, R/predict_numeric.R, R/predict_quantile.R -\name{predict_class.model_fit} -\alias{predict_class.model_fit} -\alias{predict_class} -\alias{predict_classprob.model_fit} -\alias{predict_classprob} -\alias{predict_confint.model_fit} -\alias{predict_confint} -\alias{predict_predint.model_fit} -\alias{predict_predint} -\alias{predict_numeric.model_fit} -\alias{predict_numeric} -\alias{predict_quantile.model_fit} -\alias{predict_quantile} -\title{Other predict methods.} -\usage{ -\method{predict_class}{model_fit}(object, new_data, ...) - -predict_class(object, ...) - -\method{predict_classprob}{model_fit}(object, new_data, ...) - -predict_classprob(object, ...) - -\method{predict_confint}{model_fit}(object, new_data, level = 0.95, - std_error = FALSE, ...) - -predict_confint(object, ...) - -\method{predict_predint}{model_fit}(object, new_data, level = 0.95, - std_error = FALSE, ...) - -predict_predint(object, ...) - -\method{predict_numeric}{model_fit}(object, new_data, ...) - -predict_numeric(object, ...) - -\method{predict_quantile}{model_fit}(object, new_data, - quantile = (1:9)/10, ...) - -predict_quantile(object, ...) -} -\arguments{ -\item{object}{An object of class \code{model_fit}} - -\item{new_data}{A rectangular data object, such as a data frame.} - -\item{...}{Ignored. 
To pass arguments to pass to the underlying -function when \code{predict.model_fit(type = "raw")}, -use the \code{opts} argument.} - -\item{level}{A single numeric value between zero and one for the -interval estimates.} - -\item{std_error}{A single logical for wether the standard error should be -returned (assuming that the model can compute it).} - -\item{quant}{A vector of numbers between 0 and 1 for the quantile being -predicted.} -} -\description{ -These are internal functions not meant to be directly called by the user. -} -\keyword{internal} diff --git a/man/predict.model_fit.Rd b/man/predict.model_fit.Rd index 78999cbd0..9999429d8 100644 --- a/man/predict.model_fit.Rd +++ b/man/predict.model_fit.Rd @@ -1,17 +1,11 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/predict.R, R/predict_raw.R +% Please edit documentation in R/predict.R \name{predict.model_fit} \alias{predict.model_fit} -\alias{predict_raw.model_fit} -\alias{predict_raw} \title{Model predictions} \usage{ \method{predict}{model_fit}(object, new_data, type = NULL, opts = list(), ...) - -\method{predict_raw}{model_fit}(object, new_data, opts = list(), ...) - -predict_raw(object, ...) } \arguments{ \item{object}{An object of class \code{model_fit}} @@ -28,9 +22,23 @@ predict function that will be used when \code{type = "raw"}. The list should not include options for the model object or the new data being predicted.} -\item{...}{Ignored. To pass arguments to pass to the underlying -function when \code{predict.model_fit(type = "raw")}, -use the \code{opts} argument.} +\item{...}{Arguments to the underlying model's prediction +function cannot be passed here (see \code{opts}). There are some +\code{parsnip} related options that can be passed, depending on the +value of \code{type}. Possible arguments are: +\itemize{ +\item \code{level}: for \code{type}s of "conf_int" and "pred_int" this +is the parameter for the tail area of the intervals +(e.g. 
confidence level for confidence intervals). +Default value is 0.95. +\item \code{std_error}: add the standard error of fit or +prediction for \code{type}s of "conf_int" and "pred_int". +Default value is \code{FALSE}. +\item \code{quantile}: the quantile(s) for quantile regression +(not implemented yet) +\item \code{time}: the time(s) for hazard probability estimates +(not implemented yet) +}} } \value{ With the exception of \code{type = "raw"}, the results of @@ -54,14 +62,18 @@ Quantile predictions return a tibble with a column \code{.pred}, which is a list-column. Each list element contains a tibble with columns \code{.pred} and \code{.quantile} (and perhaps other columns). -Using \code{type = "raw"} with \code{predict.model_fit()} (or using -\code{predict_raw()}) will return the unadulterated results of the -prediction function. +Using \code{type = "raw"} with \code{predict.model_fit()} will return +the unadulterated results of the prediction function. In the case of Spark-based models, since table columns cannot contain dots, the same convention is used except 1) no dots appear in names and 2) vectors are never returned but type-specific prediction functions. + +When the model fit failed and the error was captured, the +\code{predict()} function will return the same structure as above but +filled with missing values. This does not currently work for +multivariate models. } \description{ Apply a model to create different types of predictions. 
diff --git a/tests/testthat/test_boost_tree_C50.R b/tests/testthat/test_boost_tree_C50.R index 81e30fd62..3d1f0e911 100644 --- a/tests/testthat/test_boost_tree_C50.R +++ b/tests/testthat/test_boost_tree_C50.R @@ -80,7 +80,7 @@ test_that('C5.0 prediction', { ) xy_pred <- predict(classes_xy$fit, newdata = lending_club[1:7, num_pred]) - expect_equal(xy_pred, predict_class(classes_xy, lending_club[1:7, num_pred])) + expect_equal(xy_pred, parsnip:::predict_class(classes_xy, lending_club[1:7, num_pred])) }) @@ -97,9 +97,9 @@ test_that('C5.0 probabilities', { xy_pred <- predict(classes_xy$fit, newdata = as.data.frame(lending_club[1:7, num_pred]), type = "prob") xy_pred <- as_tibble(xy_pred) - expect_equal(xy_pred, predict_classprob(classes_xy, lending_club[1:7, num_pred])) + expect_equal(xy_pred, parsnip:::predict_classprob(classes_xy, lending_club[1:7, num_pred])) - one_row <- predict_classprob(classes_xy, lending_club[1, num_pred]) + one_row <- parsnip:::predict_classprob(classes_xy, lending_club[1, num_pred]) expect_equal(xy_pred[1,], one_row) }) diff --git a/tests/testthat/test_boost_tree_spark.R b/tests/testthat/test_boost_tree_spark.R index ce148620c..49e9892ec 100644 --- a/tests/testthat/test_boost_tree_spark.R +++ b/tests/testthat/test_boost_tree_spark.R @@ -58,7 +58,7 @@ test_that('spark execution', { ) expect_error( - spark_reg_pred_num <- predict_numeric(spark_reg_fit, iris_bt_te), + spark_reg_pred_num <- parsnip:::predict_numeric(spark_reg_fit, iris_bt_te), regexp = NA ) @@ -68,7 +68,7 @@ test_that('spark execution', { ) expect_error( - spark_reg_num_dup <- predict_numeric(spark_reg_fit_dup, iris_bt_te), + spark_reg_num_dup <- parsnip:::predict_numeric(spark_reg_fit_dup, iris_bt_te), regexp = NA ) @@ -124,7 +124,7 @@ test_that('spark execution', { ) expect_error( - spark_class_pred_class <- predict_class(spark_class_fit, churn_bt_te), + spark_class_pred_class <- parsnip:::predict_class(spark_class_fit, churn_bt_te), regexp = NA ) @@ -134,7 +134,7 @@ 
test_that('spark execution', { ) expect_error( - spark_class_dup_class <- predict_class(spark_class_fit_dup, churn_bt_te), + spark_class_dup_class <- parsnip:::predict_class(spark_class_fit_dup, churn_bt_te), regexp = NA ) @@ -156,7 +156,7 @@ test_that('spark execution', { ) expect_error( - spark_class_prob_classprob <- predict_classprob(spark_class_fit, churn_bt_te), + spark_class_prob_classprob <- parsnip:::predict_classprob(spark_class_fit, churn_bt_te), regexp = NA ) @@ -166,7 +166,7 @@ test_that('spark execution', { ) expect_error( - spark_class_dup_classprob <- predict_classprob(spark_class_fit_dup, churn_bt_te), + spark_class_dup_classprob <- parsnip:::predict_classprob(spark_class_fit_dup, churn_bt_te), regexp = NA ) diff --git a/tests/testthat/test_boost_tree_xgboost.R b/tests/testthat/test_boost_tree_xgboost.R index 0c6be9417..e740cfa12 100644 --- a/tests/testthat/test_boost_tree_xgboost.R +++ b/tests/testthat/test_boost_tree_xgboost.R @@ -66,7 +66,7 @@ test_that('xgboost classification prediction', { xy_pred <- predict(xy_fit$fit, newdata = xgb.DMatrix(data = as.matrix(iris[1:8, num_pred])), type = "class") xy_pred <- matrix(xy_pred, ncol = 3, byrow = TRUE) xy_pred <- factor(levels(iris$Species)[apply(xy_pred, 1, which.max)], levels = levels(iris$Species)) - expect_equal(xy_pred, predict_class(xy_fit, new_data = iris[1:8, num_pred])) + expect_equal(xy_pred, parsnip:::predict_class(xy_fit, new_data = iris[1:8, num_pred])) form_fit <- fit( iris_xgboost, @@ -78,7 +78,7 @@ test_that('xgboost classification prediction', { form_pred <- predict(form_fit$fit, newdata = xgb.DMatrix(data = as.matrix(iris[1:8, num_pred])), type = "class") form_pred <- matrix(form_pred, ncol = 3, byrow = TRUE) form_pred <- factor(levels(iris$Species)[apply(form_pred, 1, which.max)], levels = levels(iris$Species)) - expect_equal(form_pred, predict_class(form_fit, new_data = iris[1:8, num_pred])) + expect_equal(form_pred, parsnip:::predict_class(form_fit, new_data = iris[1:8, 
num_pred])) }) @@ -141,7 +141,7 @@ test_that('xgboost regression prediction', { ) xy_pred <- predict(xy_fit$fit, newdata = xgb.DMatrix(data = as.matrix(mtcars[1:8, -1]))) - expect_equal(xy_pred, predict_numeric(xy_fit, new_data = mtcars[1:8, -1])) + expect_equal(xy_pred, parsnip:::predict_numeric(xy_fit, new_data = mtcars[1:8, -1])) form_fit <- fit( car_basic, @@ -151,7 +151,7 @@ test_that('xgboost regression prediction', { ) form_pred <- predict(form_fit$fit, newdata = xgb.DMatrix(data = as.matrix(mtcars[1:8, -1]))) - expect_equal(form_pred, predict_numeric(form_fit, new_data = mtcars[1:8, -1])) + expect_equal(form_pred, parsnip:::predict_numeric(form_fit, new_data = mtcars[1:8, -1])) }) diff --git a/tests/testthat/test_failed_models.R b/tests/testthat/test_failed_models.R new file mode 100644 index 000000000..39fd4a88a --- /dev/null +++ b/tests/testthat/test_failed_models.R @@ -0,0 +1,74 @@ +library(testthat) +library(parsnip) +library(dplyr) +library(rlang) + +# ------------------------------------------------------------------------------ + +context("prediciton with failed models") + +# ------------------------------------------------------------------------------ + +iris_bad <- + iris %>% + mutate(big_num = Inf) + +data("lending_club") + +lending_club <- + lending_club %>% + dplyr::slice(1:200) %>% + mutate(big_num = Inf) + +lvl <- levels(lending_club$Class) + +# ------------------------------------------------------------------------------ + +ctrl <- fit_control(catch = TRUE) + +# ------------------------------------------------------------------------------ + +test_that('numeric model', { + lm_mod <- + linear_reg() %>% + set_engine("lm") %>% + fit(Sepal.Length ~ ., data = iris_bad, control = ctrl) + + expect_warning(num_res <- predict(lm_mod, iris_bad[1:11, -1])) + expect_equal(num_res, NULL) + + expect_warning(ci_res <- predict(lm_mod, iris_bad[1:11, -1], type = "conf_int")) + expect_equal(ci_res, NULL) + + expect_warning(pi_res <- predict(lm_mod, 
iris_bad[1:11, -1], type = "pred_int")) + expect_equal(pi_res, NULL) + +}) + +# ------------------------------------------------------------------------------ + +test_that('classification model', { + log_reg <- + logistic_reg() %>% + set_engine("glm") %>% + fit(Class ~ log(funded_amnt) + int_rate + big_num, data = lending_club, control = ctrl) + + expect_warning( + cls_res <- + predict(log_reg, lending_club %>% dplyr::slice(1:7) %>% dplyr::select(-Class)) + ) + expect_equal(cls_res, NULL) + + expect_warning( + prb_res <- + predict(log_reg, lending_club %>% dplyr::slice(1:7) %>% dplyr::select(-Class), type = "prob") + ) + expect_equal(prb_res, NULL) + + expect_warning( + ci_res <- + predict(log_reg, lending_club %>% dplyr::slice(1:7) %>% dplyr::select(-Class), type = "conf_int") + ) + expect_equal(ci_res, NULL) +}) + diff --git a/tests/testthat/test_linear_reg.R b/tests/testthat/test_linear_reg.R index 86d2e2bed..ba5f804b0 100644 --- a/tests/testthat/test_linear_reg.R +++ b/tests/testthat/test_linear_reg.R @@ -269,7 +269,7 @@ test_that('lm prediction', { control = ctrl ) - expect_equal(uni_pred, predict_numeric(res_xy, iris[1:5, num_pred])) + expect_equal(uni_pred, parsnip:::predict_numeric(res_xy, iris[1:5, num_pred])) res_form <- fit( iris_basic, @@ -277,7 +277,7 @@ test_that('lm prediction', { data = iris, control = ctrl ) - expect_equal(inl_pred, predict_numeric(res_form, iris[1:5, ])) + expect_equal(inl_pred, parsnip:::predict_numeric(res_form, iris[1:5, ])) res_mv <- fit( iris_basic, @@ -285,7 +285,7 @@ test_that('lm prediction', { data = iris, control = ctrl ) - expect_equal(mv_pred, predict_numeric(res_mv, iris[1:5,])) + expect_equal(mv_pred, parsnip:::predict_numeric(res_mv, iris[1:5,])) }) test_that('lm intervals', { diff --git a/tests/testthat/test_linear_reg_glmnet.R b/tests/testthat/test_linear_reg_glmnet.R index a22b6e73d..f03fcb38a 100644 --- a/tests/testthat/test_linear_reg_glmnet.R +++ b/tests/testthat/test_linear_reg_glmnet.R @@ -69,7 +69,7 @@ 
test_that('glmnet prediction, single lambda', { s = iris_basic$spec$args$penalty) uni_pred <- unname(uni_pred[,1]) - expect_equal(uni_pred, predict_numeric(res_xy, iris[1:5, num_pred])) + expect_equal(uni_pred, parsnip:::predict_numeric(res_xy, iris[1:5, num_pred])) res_form <- fit( iris_basic, @@ -87,7 +87,7 @@ test_that('glmnet prediction, single lambda', { s = res_form$spec$spec$args$penalty) form_pred <- unname(form_pred[,1]) - expect_equal(form_pred, predict_numeric(res_form, iris[1:5, c("Sepal.Width", "Species")])) + expect_equal(form_pred, parsnip:::predict_numeric(res_form, iris[1:5, c("Sepal.Width", "Species")])) }) @@ -115,7 +115,7 @@ test_that('glmnet prediction, multiple lambda', { mult_pred$lambda <- rep(lams, each = 5) mult_pred <- mult_pred[,-2] - expect_equal(mult_pred, predict_numeric(res_xy, iris[1:5, num_pred])) + expect_equal(mult_pred, parsnip:::predict_numeric(res_xy, iris[1:5, num_pred])) res_form <- fit( iris_mult, @@ -135,7 +135,7 @@ test_that('glmnet prediction, multiple lambda', { form_pred$lambda <- rep(lams, each = 5) form_pred <- form_pred[,-2] - expect_equal(form_pred, predict_numeric(res_form, iris[1:5, c("Sepal.Width", "Species")])) + expect_equal(form_pred, parsnip:::predict_numeric(res_form, iris[1:5, c("Sepal.Width", "Species")])) }) test_that('glmnet prediction, all lambda', { @@ -157,7 +157,7 @@ test_that('glmnet prediction, all lambda', { all_pred$lambda <- rep(res_xy$fit$lambda, each = 5) all_pred <- all_pred[,-2] - expect_equal(all_pred, predict_numeric(res_xy, iris[1:5, num_pred])) + expect_equal(all_pred, parsnip:::predict_numeric(res_xy, iris[1:5, num_pred])) # test that the lambda seq is in the right order (since no docs on this) tmp_pred <- predict(res_xy$fit, newx = as.matrix(iris[1:5, num_pred]), @@ -180,7 +180,7 @@ test_that('glmnet prediction, all lambda', { form_pred$lambda <- rep(res_form$fit$lambda, each = 5) form_pred <- form_pred[,-2] - expect_equal(form_pred, predict_numeric(res_form, iris[1:5, 
c("Sepal.Width", "Species")])) + expect_equal(form_pred, parsnip:::predict_numeric(res_form, iris[1:5, c("Sepal.Width", "Species")])) }) @@ -198,10 +198,56 @@ test_that('submodel prediction', { mp_res <- multi_predict(reg_fit, new_data = mtcars[1:4, -1], penalty = .1) mp_res <- do.call("rbind", mp_res$.pred) expect_equal(mp_res[[".pred"]], unname(pred_glmn[,1])) - + expect_error( - multi_predict(reg_fit, newdata = mtcars[1:4, -1], penalty = .1), + multi_predict(reg_fit, newdata = mtcars[1:4, -1], penalty = .1), "Did you mean" ) + + reg_fit <- + linear_reg(penalty = c(0, 0.01, 0.1)) %>% + set_engine("glmnet") %>% + fit(mpg ~ ., data = mtcars[-(1:4), ]) + + + pred_glmn_all <- + predict(reg_fit$fit, as.matrix(mtcars[1:2, -1])) %>% + as.data.frame() %>% + stack() %>% + dplyr::arrange(ind) + + + mp_res_all <- + multi_predict(reg_fit, new_data = mtcars[1:2, -1]) %>% + tidyr::unnest() + + expect_equal(sort(mp_res_all$.pred), sort(pred_glmn_all$values)) + +}) + + +test_that('error traps', { + + skip_if_not_installed("glmnet") + + expect_error( + linear_reg(penalty = .1) %>% + set_engine("glmnet") %>% + fit(mpg ~ ., data = mtcars[-(1:4), ]) %>% + predict(mtcars[-(1:4), ], penalty = .2) + ) + expect_error( + linear_reg() %>% + set_engine("glmnet") %>% + fit(mpg ~ ., data = mtcars[-(1:4), ]) %>% + predict(mtcars[-(1:4), ], penalty = 0:1) + ) + expect_error( + linear_reg() %>% + set_engine("glmnet") %>% + fit(mpg ~ ., data = mtcars[-(1:4), ]) %>% + predict(mtcars[-(1:4), ]) + ) + }) diff --git a/tests/testthat/test_linear_reg_spark.R b/tests/testthat/test_linear_reg_spark.R index f1a033beb..28859a030 100644 --- a/tests/testthat/test_linear_reg_spark.R +++ b/tests/testthat/test_linear_reg_spark.R @@ -42,7 +42,7 @@ test_that('spark execution', { ) expect_error( - spark_pred_num <- predict_numeric(spark_fit, iris_linreg_te), + spark_pred_num <- parsnip:::predict_numeric(spark_fit, iris_linreg_te), regexp = NA ) diff --git a/tests/testthat/test_linear_reg_stan.R 
b/tests/testthat/test_linear_reg_stan.R index 4891e65e1..8fff7084e 100644 --- a/tests/testthat/test_linear_reg_stan.R +++ b/tests/testthat/test_linear_reg_stan.R @@ -67,7 +67,7 @@ test_that('stan prediction', { control = quiet_ctrl ) - expect_equal(uni_pred, predict_numeric(res_xy, iris[1:5, num_pred]), tolerance = 0.001) + expect_equal(uni_pred, parsnip:::predict_numeric(res_xy, iris[1:5, num_pred]), tolerance = 0.001) res_form <- fit( iris_basic, @@ -75,7 +75,7 @@ test_that('stan prediction', { data = iris, control = quiet_ctrl ) - expect_equal(inl_pred, predict_numeric(res_form, iris[1:5, ]), tolerance = 0.001) + expect_equal(inl_pred, parsnip:::predict_numeric(res_form, iris[1:5, ]), tolerance = 0.001) }) diff --git a/tests/testthat/test_logistic_reg.R b/tests/testthat/test_logistic_reg.R index c74b0c492..7971e1c40 100644 --- a/tests/testthat/test_logistic_reg.R +++ b/tests/testthat/test_logistic_reg.R @@ -275,7 +275,7 @@ test_that('glm prediction', { xy_pred <- ifelse(xy_pred >= 0.5, "good", "bad") xy_pred <- factor(xy_pred, levels = levels(lending_club$Class)) xy_pred <- unname(xy_pred) - expect_equal(xy_pred, predict_class(classes_xy, lending_club[1:7, num_pred])) + expect_equal(xy_pred, parsnip:::predict_class(classes_xy, lending_club[1:7, num_pred])) }) @@ -289,9 +289,9 @@ test_that('glm probabilities', { xy_pred <- predict(classes_xy$fit, newdata = lending_club[1:7, num_pred], type = "response") xy_pred <- tibble(bad = 1 - xy_pred, good = xy_pred) - expect_equal(xy_pred, predict_classprob(classes_xy, lending_club[1:7, num_pred])) + expect_equal(xy_pred, parsnip:::predict_classprob(classes_xy, lending_club[1:7, num_pred])) - one_row <- predict_classprob(classes_xy, lending_club[1, num_pred]) + one_row <- parsnip:::predict_classprob(classes_xy, lending_club[1, num_pred]) expect_equal(xy_pred[1,], one_row) }) @@ -323,8 +323,10 @@ test_that('glm intervals', { level = 0.93, std_error = TRUE) - expect_equivalent(confidence_parsnip$.pred_lower, lower_glm) - 
expect_equivalent(confidence_parsnip$.pred_upper, upper_glm) + expect_equivalent(confidence_parsnip$.pred_lower_good, lower_glm) + expect_equivalent(confidence_parsnip$.pred_upper_good, upper_glm) + expect_equivalent(confidence_parsnip$.pred_lower_bad, 1 - upper_glm) + expect_equivalent(confidence_parsnip$.pred_upper_bad, 1 - lower_glm) expect_equivalent(confidence_parsnip$.std_error, pred_glm$se.fit) }) diff --git a/tests/testthat/test_logistic_reg_glmnet.R b/tests/testthat/test_logistic_reg_glmnet.R index e59533209..510ace9c9 100644 --- a/tests/testthat/test_logistic_reg_glmnet.R +++ b/tests/testthat/test_logistic_reg_glmnet.R @@ -64,7 +64,7 @@ test_that('glmnet prediction, one lambda', { uni_pred <- factor(uni_pred, levels = levels(lending_club$Class)) uni_pred <- unname(uni_pred) - expect_equal(uni_pred, predict_class(xy_fit, lending_club[1:7, num_pred])) + expect_equal(uni_pred, parsnip:::predict_class(xy_fit, lending_club[1:7, num_pred])) res_form <- fit( logistic_reg(penalty = 0.1) %>% set_engine("glmnet"), @@ -84,7 +84,7 @@ test_that('glmnet prediction, one lambda', { form_pred <- factor(form_pred, levels = levels(lending_club$Class)) form_pred <- unname(form_pred) - expect_equal(form_pred, predict_class(res_form, lending_club[1:7, c("funded_amnt", "int_rate")])) + expect_equal(form_pred, parsnip:::predict_class(res_form, lending_club[1:7, c("funded_amnt", "int_rate")])) }) @@ -112,7 +112,7 @@ test_that('glmnet prediction, mulitiple lambda', { mult_pred$lambda <- rep(lams, each = 7) mult_pred <- mult_pred[, -2] - expect_equal(mult_pred, predict_class(xy_fit, lending_club[1:7, num_pred])) + expect_equal(mult_pred, parsnip:::predict_class(xy_fit, lending_club[1:7, num_pred])) res_form <- fit( logistic_reg(penalty = lams) %>% set_engine("glmnet"), @@ -134,7 +134,7 @@ test_that('glmnet prediction, mulitiple lambda', { form_pred$lambda <- rep(lams, each = 7) form_pred <- form_pred[, -2] - expect_equal(form_pred, predict_class(res_form, lending_club[1:7, 
c("funded_amnt", "int_rate")])) + expect_equal(form_pred, parsnip:::predict_class(res_form, lending_club[1:7, c("funded_amnt", "int_rate")])) }) @@ -159,7 +159,7 @@ test_that('glmnet prediction, no lambda', { mult_pred$lambda <- rep(xy_fit$fit$lambda, each = 7) mult_pred <- mult_pred[, -2] - expect_equal(mult_pred, predict_class(xy_fit, lending_club[1:7, num_pred])) + expect_equal(mult_pred, parsnip:::predict_class(xy_fit, lending_club[1:7, num_pred])) res_form <- fit( logistic_reg() %>% set_engine("glmnet", nlambda = 11), @@ -180,7 +180,7 @@ test_that('glmnet prediction, no lambda', { form_pred$values <- factor(form_pred$values, levels = levels(lending_club$Class)) form_pred$lambda <- rep(res_form$fit$lambda, each = 7) form_pred <- form_pred[, -2] - expect_equal(form_pred, predict_class(res_form, lending_club[1:7, c("funded_amnt", "int_rate")])) + expect_equal(form_pred, parsnip:::predict_class(res_form, lending_club[1:7, c("funded_amnt", "int_rate")])) }) @@ -202,7 +202,7 @@ test_that('glmnet probabilities, one lambda', { s = 0.1, type = "response")[,1] uni_pred <- tibble(bad = 1 - uni_pred, good = uni_pred) - expect_equal(uni_pred, predict_classprob(xy_fit, lending_club[1:7, num_pred])) + expect_equal(uni_pred, parsnip:::predict_classprob(xy_fit, lending_club[1:7, num_pred])) res_form <- fit( logistic_reg(penalty = 0.1) %>% set_engine("glmnet"), @@ -219,9 +219,9 @@ test_that('glmnet probabilities, one lambda', { newx = form_mat, s = 0.1, type = "response")[, 1] form_pred <- tibble(bad = 1 - form_pred, good = form_pred) - expect_equal(form_pred, predict_classprob(res_form, lending_club[1:7, c("funded_amnt", "int_rate")])) + expect_equal(form_pred, parsnip:::predict_classprob(res_form, lending_club[1:7, c("funded_amnt", "int_rate")])) - one_row <- predict_classprob(res_form, lending_club[1, c("funded_amnt", "int_rate")]) + one_row <- parsnip:::predict_classprob(res_form, lending_club[1, c("funded_amnt", "int_rate")]) expect_equal(form_pred[1,], one_row) }) @@ 
-247,7 +247,7 @@ test_that('glmnet probabilities, mulitiple lambda', { mult_pred <- tibble(bad = 1 - mult_pred$values, good = mult_pred$values) mult_pred$lambda <- rep(lams, each = 7) - expect_equal(mult_pred, predict_classprob(xy_fit, lending_club[1:7, num_pred])) + expect_equal(mult_pred, parsnip:::predict_classprob(xy_fit, lending_club[1:7, num_pred])) res_form <- fit( logistic_reg(penalty = lams) %>% set_engine("glmnet"), @@ -267,7 +267,7 @@ test_that('glmnet probabilities, mulitiple lambda', { form_pred <- tibble(bad = 1 - form_pred$values, good = form_pred$values) form_pred$lambda <- rep(lams, each = 7) - expect_equal(form_pred, predict_classprob(res_form, lending_club[1:7, c("funded_amnt", "int_rate")])) + expect_equal(form_pred, parsnip:::predict_classprob(res_form, lending_club[1:7, c("funded_amnt", "int_rate")])) }) @@ -291,7 +291,7 @@ test_that('glmnet probabilities, no lambda', { mult_pred <- tibble(bad = 1 - mult_pred$values, good = mult_pred$values) mult_pred$lambda <- rep(xy_fit$fit$lambda, each = 7) - expect_equal(mult_pred, predict_classprob(xy_fit, lending_club[1:7, num_pred])) + expect_equal(mult_pred, parsnip:::predict_classprob(xy_fit, lending_club[1:7, num_pred])) res_form <- fit( logistic_reg() %>% set_engine("glmnet"), @@ -311,7 +311,7 @@ test_that('glmnet probabilities, no lambda', { form_pred <- tibble(bad = 1 - form_pred$values, good = form_pred$values) form_pred$lambda <- rep(res_form$fit$lambda, each = 7) - expect_equal(form_pred, predict_classprob(res_form, lending_club[1:7, c("funded_amnt", "int_rate")])) + expect_equal(form_pred, parsnip:::predict_classprob(res_form, lending_club[1:7, c("funded_amnt", "int_rate")])) }) diff --git a/tests/testthat/test_logistic_reg_keras.R b/tests/testthat/test_logistic_reg_keras.R index b338f4ea6..eb3238902 100644 --- a/tests/testthat/test_logistic_reg_keras.R +++ b/tests/testthat/test_logistic_reg_keras.R @@ -161,7 +161,7 @@ test_that('classification probabilities', { ) keras_pred <- - 
predict_proba(lr_fit$fit, as.matrix(te_dat[, -1])) %>% + keras::predict_proba(lr_fit$fit, as.matrix(te_dat[, -1])) %>% as_tibble() %>% setNames(paste0(".pred_", lr_fit$lvl)) @@ -178,7 +178,7 @@ test_that('classification probabilities', { ) keras_pred <- - predict_proba(plrfit$fit, as.matrix(te_dat[, -1])) %>% + keras::predict_proba(plrfit$fit, as.matrix(te_dat[, -1])) %>% as_tibble() %>% setNames(paste0(".pred_", lr_fit$lvl)) parsnip_pred <- predict(plrfit, te_dat[, -1], type = "prob") diff --git a/tests/testthat/test_logistic_reg_spark.R b/tests/testthat/test_logistic_reg_spark.R index 9d435dc85..b9ac7a1fa 100644 --- a/tests/testthat/test_logistic_reg_spark.R +++ b/tests/testthat/test_logistic_reg_spark.R @@ -56,7 +56,7 @@ test_that('spark execution', { ) expect_error( - spark_class_pred_class <- predict_class(spark_class_fit, churn_logit_te), + spark_class_pred_class <- parsnip:::predict_class(spark_class_fit, churn_logit_te), regexp = NA ) @@ -73,7 +73,7 @@ test_that('spark execution', { ) expect_error( - spark_class_prob_classprob <- predict_classprob(spark_class_fit, churn_logit_te), + spark_class_prob_classprob <- parsnip:::predict_classprob(spark_class_fit, churn_logit_te), regexp = NA ) diff --git a/tests/testthat/test_logistic_reg_stan.R b/tests/testthat/test_logistic_reg_stan.R index 0a322d1e4..3373b3ce3 100644 --- a/tests/testthat/test_logistic_reg_stan.R +++ b/tests/testthat/test_logistic_reg_stan.R @@ -63,7 +63,7 @@ test_that('stan_glm prediction', { xy_pred <- factor(xy_pred, levels = levels(lending_club$Class)) xy_pred <- unname(xy_pred) - expect_equal(xy_pred, predict_class(xy_fit, lending_club[1:7, num_pred])) + expect_equal(xy_pred, parsnip:::predict_class(xy_fit, lending_club[1:7, num_pred])) res_form <- fit( logistic_reg() %>% @@ -80,7 +80,7 @@ test_that('stan_glm prediction', { form_pred <- unname(form_pred) form_pred <- ifelse(form_pred >= 0.5, "good", "bad") form_pred <- factor(form_pred, levels = levels(lending_club$Class)) - 
expect_equal(form_pred, predict_class(res_form, lending_club[1:7, c("funded_amnt", "int_rate")])) + expect_equal(form_pred, parsnip:::predict_class(res_form, lending_club[1:7, c("funded_amnt", "int_rate")])) }) @@ -102,7 +102,7 @@ test_that('stan_glm probability', { xy_pred <- xy_fit$fit$family$linkinv(xy_pred) xy_pred <- tibble(bad = 1 - xy_pred, good = xy_pred) - expect_equal(xy_pred, predict_classprob(xy_fit, lending_club[1:7, num_pred])) + expect_equal(xy_pred, parsnip:::predict_classprob(xy_fit, lending_club[1:7, num_pred])) res_form <- fit( logistic_reg() %>% @@ -117,7 +117,7 @@ test_that('stan_glm probability', { newdata = lending_club[1:7, c("funded_amnt", "int_rate")]) form_pred <- xy_fit$fit$family$linkinv(form_pred) form_pred <- tibble(bad = 1 - form_pred, good = form_pred) - expect_equal(form_pred, predict_classprob(res_form, lending_club[1:7, c("funded_amnt", "int_rate")])) + expect_equal(form_pred, parsnip:::predict_classprob(res_form, lending_club[1:7, c("funded_amnt", "int_rate")])) }) @@ -156,8 +156,10 @@ test_that('stan intervals', { stan_upper <- apply(stan_post, 2, quantile, prob = 0.965) stan_std <- apply(stan_post, 2, sd) - expect_equivalent(confidence_parsnip$.pred_lower, stan_lower) - expect_equivalent(confidence_parsnip$.pred_upper, stan_upper) + expect_equivalent(confidence_parsnip$.pred_lower_good, stan_lower) + expect_equivalent(confidence_parsnip$.pred_upper_good, stan_upper) + expect_equivalent(confidence_parsnip$.pred_lower_bad, 1 - stan_upper) + expect_equivalent(confidence_parsnip$.pred_upper_bad, 1 - stan_lower) expect_equivalent(confidence_parsnip$.std_error, stan_std) stan_pred_post <- @@ -168,8 +170,10 @@ test_that('stan intervals', { stan_pred_upper <- apply(stan_pred_post, 2, quantile, prob = 0.965) stan_pred_std <- apply(stan_pred_post, 2, sd) - expect_equivalent(prediction_parsnip$.pred_lower, stan_pred_lower) - expect_equivalent(prediction_parsnip$.pred_upper, stan_pred_upper) + 
expect_equivalent(prediction_parsnip$.pred_lower_good, stan_pred_lower) + expect_equivalent(prediction_parsnip$.pred_upper_good, stan_pred_upper) + expect_equivalent(prediction_parsnip$.pred_lower_bad, 1 - stan_pred_upper) + expect_equivalent(prediction_parsnip$.pred_upper_bad, 1 - stan_pred_lower) expect_equivalent(prediction_parsnip$.std_error, stan_pred_std, tolerance = 0.1) }) diff --git a/tests/testthat/test_mars.R b/tests/testthat/test_mars.R index b23330516..7ea9213a3 100644 --- a/tests/testthat/test_mars.R +++ b/tests/testthat/test_mars.R @@ -185,7 +185,7 @@ test_that('mars prediction', { control = ctrl ) - expect_equal(uni_pred, predict_numeric(res_xy, iris[1:5, num_pred])) + expect_equal(uni_pred, parsnip:::predict_numeric(res_xy, iris[1:5, num_pred])) res_form <- fit( iris_basic, @@ -193,7 +193,7 @@ test_that('mars prediction', { data = iris, control = ctrl ) - expect_equal(inl_pred, predict_numeric(res_form, iris[1:5, ])) + expect_equal(inl_pred, parsnip:::predict_numeric(res_form, iris[1:5, ])) res_mv <- fit( iris_basic, @@ -201,7 +201,7 @@ test_that('mars prediction', { data = iris, control = ctrl ) - expect_equal(mv_pred, predict_numeric(res_mv, iris[1:5,])) + expect_equal(mv_pred, parsnip:::predict_numeric(res_mv, iris[1:5,])) }) @@ -270,7 +270,7 @@ test_that('classification', { regexp = NA ) expect_true(!is.null(glm_mars$fit$glm.list)) - parsnip_pred <- predict_classprob(glm_mars, new_data = lending_club[1:5, -ncol(lending_club)]) + parsnip_pred <- parsnip:::predict_classprob(glm_mars, new_data = lending_club[1:5, -ncol(lending_club)]) earth_pred <- c(0.95631355972526, 0.971917781277731, 0.894245392500336, 0.962667553751077, diff --git a/tests/testthat/test_mlp_keras.R b/tests/testthat/test_mlp_keras.R index a2022923e..fead31752 100644 --- a/tests/testthat/test_mlp_keras.R +++ b/tests/testthat/test_mlp_keras.R @@ -69,7 +69,7 @@ test_that('keras classification prediction', { control = ctrl ) - xy_pred <- predict_classes(xy_fit$fit, x = 
as.matrix(iris[1:8, num_pred])) + xy_pred <- keras::predict_classes(xy_fit$fit, x = as.matrix(iris[1:8, num_pred])) xy_pred <- factor(levels(iris$Species)[xy_pred + 1], levels = levels(iris$Species)) expect_equal(xy_pred, predict(xy_fit, new_data = iris[1:8, num_pred], type = "class")[[".pred_class"]]) @@ -82,7 +82,7 @@ test_that('keras classification prediction', { control = ctrl ) - form_pred <- predict_classes(form_fit$fit, x = as.matrix(iris[1:8, num_pred])) + form_pred <- keras::predict_classes(form_fit$fit, x = as.matrix(iris[1:8, num_pred])) form_pred <- factor(levels(iris$Species)[form_pred + 1], levels = levels(iris$Species)) expect_equal(form_pred, predict(form_fit, new_data = iris[1:8, num_pred], type = "class")[[".pred_class"]]) @@ -101,7 +101,7 @@ test_that('keras classification probabilities', { control = ctrl ) - xy_pred <- predict_proba(xy_fit$fit, x = as.matrix(iris[1:8, num_pred])) + xy_pred <- keras::predict_proba(xy_fit$fit, x = as.matrix(iris[1:8, num_pred])) xy_pred <- as_tibble(xy_pred) colnames(xy_pred) <- paste0(".pred_", levels(iris$Species)) expect_equal(xy_pred, predict(xy_fit, new_data = iris[1:8, num_pred], type = "prob")) @@ -115,7 +115,7 @@ test_that('keras classification probabilities', { control = ctrl ) - form_pred <- predict_proba(form_fit$fit, x = as.matrix(iris[1:8, num_pred])) + form_pred <- keras::predict_proba(form_fit$fit, x = as.matrix(iris[1:8, num_pred])) form_pred <- as_tibble(form_pred) colnames(form_pred) <- paste0(".pred_", levels(iris$Species)) expect_equal(form_pred, predict(form_fit, new_data = iris[1:8, num_pred], type = "prob")) @@ -218,7 +218,7 @@ test_that('multivariate nnet formula', { data = nn_dat[-(1:5),] ) expect_equal(length(unlist(keras::get_weights(nnet_form$fit))), 24) - nnet_form_pred <- predict_numeric(nnet_form, new_data = nn_dat[1:5, -(1:3)]) + nnet_form_pred <- parsnip:::predict_numeric(nnet_form, new_data = nn_dat[1:5, -(1:3)]) expect_equal(ncol(nnet_form_pred), 3) 
expect_equal(nrow(nnet_form_pred), 5) expect_equal(names(nnet_form_pred), c("V1", "V2", "V3")) @@ -233,7 +233,7 @@ test_that('multivariate nnet formula', { y = nn_dat[-(1:5), 1:3 ] ) expect_equal(length(unlist(keras::get_weights(nnet_xy$fit))), 24) - nnet_form_xy <- predict_numeric(nnet_xy, new_data = nn_dat[1:5, -(1:3)]) + nnet_form_xy <- parsnip:::predict_numeric(nnet_xy, new_data = nn_dat[1:5, -(1:3)]) expect_equal(ncol(nnet_form_xy), 3) expect_equal(nrow(nnet_form_xy), 5) expect_equal(names(nnet_form_xy), c("V1", "V2", "V3")) diff --git a/tests/testthat/test_mlp_nnet.R b/tests/testthat/test_mlp_nnet.R index a112086fe..6a5022786 100644 --- a/tests/testthat/test_mlp_nnet.R +++ b/tests/testthat/test_mlp_nnet.R @@ -64,7 +64,7 @@ test_that('nnet classification prediction', { xy_pred <- predict(xy_fit$fit, newdata = iris[1:8, num_pred], type = "class") xy_pred <- factor(xy_pred, levels = levels(iris$Species)) - expect_equal(xy_pred, predict_class(xy_fit, new_data = iris[1:8, num_pred])) + expect_equal(xy_pred, parsnip:::predict_class(xy_fit, new_data = iris[1:8, num_pred])) form_fit <- fit( iris_nnet, @@ -75,7 +75,7 @@ test_that('nnet classification prediction', { form_pred <- predict(form_fit$fit, newdata = iris[1:8, num_pred], type = "class") form_pred <- factor(form_pred, levels = levels(iris$Species)) - expect_equal(form_pred, predict_class(form_fit, new_data = iris[1:8, num_pred])) + expect_equal(form_pred, parsnip:::predict_class(form_fit, new_data = iris[1:8, num_pred])) }) @@ -141,7 +141,7 @@ test_that('nnet regression prediction', { xy_pred <- predict(xy_fit$fit, newdata = mtcars[1:8, -1])[,1] xy_pred <- unname(xy_pred) - expect_equal(xy_pred, predict_numeric(xy_fit, new_data = mtcars[1:8, -1])) + expect_equal(xy_pred, parsnip:::predict_numeric(xy_fit, new_data = mtcars[1:8, -1])) form_fit <- fit( car_basic, @@ -152,7 +152,7 @@ test_that('nnet regression prediction', { form_pred <- predict(form_fit$fit, newdata = mtcars[1:8, -1])[,1] form_pred <- 
unname(form_pred) - expect_equal(form_pred, predict_numeric(form_fit, new_data = mtcars[1:8, -1])) + expect_equal(form_pred, parsnip:::predict_numeric(form_fit, new_data = mtcars[1:8, -1])) }) # ------------------------------------------------------------------------------ @@ -175,7 +175,7 @@ test_that('multivariate nnet formula', { data = nn_dat[-(1:5),] ) expect_equal(length(nnet_form$fit$wts), 24) - nnet_form_pred <- predict_numeric(nnet_form, new_data = nn_dat[1:5, -(1:3)]) + nnet_form_pred <- parsnip:::predict_numeric(nnet_form, new_data = nn_dat[1:5, -(1:3)]) expect_equal(ncol(nnet_form_pred), 3) expect_equal(nrow(nnet_form_pred), 5) expect_equal(names(nnet_form_pred), c("V1", "V2", "V3")) @@ -192,7 +192,7 @@ test_that('multivariate nnet formula', { y = nn_dat[-(1:5), 1:3 ] ) expect_equal(length(nnet_xy$fit$wts), 24) - nnet_form_xy <- predict_numeric(nnet_xy, new_data = nn_dat[1:5, -(1:3)]) + nnet_form_xy <- parsnip:::predict_numeric(nnet_xy, new_data = nn_dat[1:5, -(1:3)]) expect_equal(ncol(nnet_form_xy), 3) expect_equal(nrow(nnet_form_xy), 5) expect_equal(names(nnet_form_xy), c("V1", "V2", "V3")) diff --git a/tests/testthat/test_multinom_reg_glmnet.R b/tests/testthat/test_multinom_reg_glmnet.R index 6043aae6e..af6999434 100644 --- a/tests/testthat/test_multinom_reg_glmnet.R +++ b/tests/testthat/test_multinom_reg_glmnet.R @@ -58,7 +58,7 @@ test_that('glmnet prediction, one lambda', { uni_pred <- factor(uni_pred[,1], levels = levels(iris$Species)) uni_pred <- unname(uni_pred) - expect_equal(uni_pred, predict_class(xy_fit, iris[rows, 1:4])) + expect_equal(uni_pred, parsnip:::predict_class.model_fit(xy_fit, iris[rows, 1:4])) expect_equal(uni_pred, predict(xy_fit, iris[rows, 1:4], type = "class")$.pred_class) res_form <- fit( @@ -77,7 +77,7 @@ test_that('glmnet prediction, one lambda', { s = res_form$spec$args$penalty, type = "class") form_pred <- factor(form_pred[,1], levels = levels(iris$Species)) - expect_equal(form_pred, predict_class(res_form, iris[rows, 
c("Sepal.Width", "Petal.Width")])) + expect_equal(form_pred, parsnip:::predict_class.model_fit(res_form, iris[rows, c("Sepal.Width", "Petal.Width")])) expect_equal(form_pred, predict(res_form, iris[rows, c("Sepal.Width", "Petal.Width")], type = "class")$.pred_class) }) diff --git a/tests/testthat/test_multinom_reg_keras.R b/tests/testthat/test_multinom_reg_keras.R index a6937e4d6..eafda5871 100644 --- a/tests/testthat/test_multinom_reg_keras.R +++ b/tests/testthat/test_multinom_reg_keras.R @@ -154,7 +154,7 @@ test_that('classification probabilities', { ) keras_pred <- - predict_proba(lr_fit$fit, as.matrix(te_dat[, -5])) %>% + keras::predict_proba(lr_fit$fit, as.matrix(te_dat[, -5])) %>% as_tibble() %>% setNames(paste0(".pred_", lr_fit$lvl)) @@ -171,7 +171,7 @@ test_that('classification probabilities', { ) keras_pred <- - predict_proba(plrfit$fit, as.matrix(te_dat[, -5])) %>% + keras::predict_proba(plrfit$fit, as.matrix(te_dat[, -5])) %>% as_tibble() %>% setNames(paste0(".pred_", lr_fit$lvl)) parsnip_pred <- predict(plrfit, te_dat[, -5], type = "prob") diff --git a/tests/testthat/test_multinom_reg_spark.R b/tests/testthat/test_multinom_reg_spark.R index 0954e52a0..e28238207 100644 --- a/tests/testthat/test_multinom_reg_spark.R +++ b/tests/testthat/test_multinom_reg_spark.R @@ -45,7 +45,7 @@ test_that('spark execution', { ) expect_error( - spark_class_pred_class <- predict_class(spark_class_fit, iris_te), + spark_class_pred_class <- parsnip:::predict_class(spark_class_fit, iris_te), regexp = NA ) @@ -62,7 +62,7 @@ test_that('spark execution', { ) expect_error( - spark_class_prob_classprob <- predict_classprob(spark_class_fit, iris_te), + spark_class_prob_classprob <- parsnip:::predict_classprob(spark_class_fit, iris_te), regexp = NA ) diff --git a/tests/testthat/test_nearest_neighbor_kknn.R b/tests/testthat/test_nearest_neighbor_kknn.R index 1d764d0aa..cc483a156 100644 --- a/tests/testthat/test_nearest_neighbor_kknn.R +++ b/tests/testthat/test_nearest_neighbor_kknn.R 
@@ -20,6 +20,8 @@ quiet_ctrl <- fit_control(verbosity = 0, catch = TRUE) test_that('kknn execution', { skip_if_not_installed("kknn") + library(kknn) + # continuous # expect no error @@ -74,7 +76,7 @@ test_that('kknn prediction', { newdata = iris[1:5, num_pred] ) - expect_equal(uni_pred, predict_numeric(res_xy, iris[1:5, num_pred])) + expect_equal(uni_pred, parsnip:::predict_numeric(res_xy, iris[1:5, num_pred])) # nominal res_xy_nom <- fit_xy( @@ -89,7 +91,7 @@ test_that('kknn prediction', { newdata = iris[1:5, c("Sepal.Length", "Petal.Width")] ) - expect_equal(uni_pred_nom, predict_class(res_xy_nom, iris[1:5, c("Sepal.Length", "Petal.Width")])) + expect_equal(uni_pred_nom, parsnip:::predict_class(res_xy_nom, iris[1:5, c("Sepal.Length", "Petal.Width")])) # continuous - formula interface res_form <- fit( @@ -104,5 +106,5 @@ test_that('kknn prediction', { newdata = iris[1:5,] ) - expect_equal(form_pred, predict_numeric(res_form, iris[1:5, c("Sepal.Width", "Species")])) + expect_equal(form_pred, parsnip:::predict_numeric(res_form, iris[1:5, c("Sepal.Width", "Species")])) }) diff --git a/tests/testthat/test_nullmodel.R b/tests/testthat/test_nullmodel.R index 45a6a9932..da0294d1f 100644 --- a/tests/testthat/test_nullmodel.R +++ b/tests/testthat/test_nullmodel.R @@ -109,7 +109,7 @@ test_that('nullmodel prediction', { Petal.Length ~ log(Sepal.Width) + Species, data = iris ) - expect_equal(inl_pred, predict_numeric(res_form, iris[1:5, ])) + expect_equal(inl_pred, parsnip:::predict_numeric(res_form, iris[1:5, ])) # Multivariate y res <- fit( @@ -118,7 +118,7 @@ test_that('nullmodel prediction', { data = mtcars ) - expect_equal(mw_pred, predict_numeric(res, mtcars[1:5, ])) + expect_equal(mw_pred, parsnip:::predict_numeric(res, mtcars[1:5, ])) }) # ------------------------------------------------------------------------------ diff --git a/tests/testthat/test_predict_formats.R b/tests/testthat/test_predict_formats.R index cd10d2add..51bcc8f91 100644 --- 
a/tests/testthat/test_predict_formats.R +++ b/tests/testthat/test_predict_formats.R @@ -1,6 +1,7 @@ library(testthat) library(parsnip) library(tibble) +library(dplyr) # ------------------------------------------------------------------------------ @@ -31,30 +32,48 @@ lr_fit_2 <- test_that('regression predictions', { expect_true(is_tibble(predict(lm_fit, new_data = iris[1:5,-1]))) - expect_true(is.vector(predict_numeric(lm_fit, new_data = iris[1:5,-1]))) + expect_true(is.vector(parsnip:::predict_numeric(lm_fit, new_data = iris[1:5,-1]))) expect_equal(names(predict(lm_fit, new_data = iris[1:5,-1])), ".pred") }) test_that('classification predictions', { expect_true(is_tibble(predict(lr_fit, new_data = class_dat[1:5,-1]))) - expect_true(is.factor(predict_class(lr_fit, new_data = class_dat[1:5,-1]))) + expect_true(is.factor(parsnip:::predict_class(lr_fit, new_data = class_dat[1:5,-1]))) expect_equal(names(predict(lr_fit, new_data = class_dat[1:5,-1])), ".pred_class") expect_true(is_tibble(predict(lr_fit, new_data = class_dat[1:5,-1], type = "prob"))) - expect_true(is_tibble(predict_classprob(lr_fit, new_data = class_dat[1:5,-1]))) + expect_true(is_tibble(parsnip:::predict_classprob(lr_fit, new_data = class_dat[1:5,-1]))) expect_equal(names(predict(lr_fit, new_data = class_dat[1:5,-1], type = "prob")), c(".pred_high", ".pred_low")) }) test_that('non-standard levels', { expect_true(is_tibble(predict(lr_fit, new_data = class_dat[1:5,-1]))) - expect_true(is.factor(predict_class(lr_fit, new_data = class_dat[1:5,-1]))) + expect_true(is.factor(parsnip:::predict_class(lr_fit, new_data = class_dat[1:5,-1]))) expect_equal(names(predict(lr_fit, new_data = class_dat[1:5,-1])), ".pred_class") expect_true(is_tibble(predict(lr_fit_2, new_data = class_dat2[1:5,-1], type = "prob"))) - expect_true(is_tibble(predict_classprob(lr_fit_2, new_data = class_dat2[1:5,-1]))) + expect_true(is_tibble(parsnip:::predict_classprob(lr_fit_2, new_data = class_dat2[1:5,-1]))) 
expect_equal(names(predict(lr_fit_2, new_data = class_dat2[1:5,-1], type = "prob")), c(".pred_2low", ".pred_high+values")) - expect_equal(names(predict_classprob(lr_fit_2, new_data = class_dat2[1:5,-1])), + expect_equal(names(parsnip:::predict_classprob(lr_fit_2, new_data = class_dat2[1:5,-1])), c("2low", "high+values")) }) + +# ------------------------------------------------------------------------------ + +test_that('bad predict args', { + lm_model <- + linear_reg() %>% + set_engine("lm") %>% + fit(mpg ~ ., data = mtcars %>% dplyr::slice(11:32)) + + pred_cars <- + mtcars %>% + dplyr::slice(1:10) %>% + dplyr::select(-mpg) + + expect_error(predict(lm_model, pred_cars, yes = "no")) + expect_error(predict(lm_model, pred_cars, type = "conf_int", level = 0.95, yes = "no")) +}) + diff --git a/tests/testthat/test_rand_forest_randomForest.R b/tests/testthat/test_rand_forest_randomForest.R index 35937b244..cfba216b7 100644 --- a/tests/testthat/test_rand_forest_randomForest.R +++ b/tests/testthat/test_rand_forest_randomForest.R @@ -90,7 +90,7 @@ test_that('randomForest classification prediction', { xy_pred <- predict(xy_fit$fit, newdata = lending_club[1:6, num_pred]) xy_pred <- unname(xy_pred) - expect_equal(xy_pred, predict_class(xy_fit, new_data = lending_club[1:6, num_pred])) + expect_equal(xy_pred, parsnip:::predict_class(xy_fit, new_data = lending_club[1:6, num_pred])) form_fit <- fit( lc_basic, @@ -101,7 +101,7 @@ test_that('randomForest classification prediction', { form_pred <- predict(form_fit$fit, newdata = lending_club[1:6, c("funded_amnt", "int_rate")]) form_pred <- unname(form_pred) - expect_equal(form_pred, predict_class(form_fit, new_data = lending_club[1:6, c("funded_amnt", "int_rate")])) + expect_equal(form_pred, parsnip:::predict_class(form_fit, new_data = lending_club[1:6, c("funded_amnt", "int_rate")])) }) test_that('randomForest classification probabilities', { @@ -117,9 +117,9 @@ test_that('randomForest classification probabilities', { xy_pred <- 
predict(xy_fit$fit, newdata = lending_club[1:6, num_pred], type = "prob") xy_pred <- as_tibble(as.data.frame(xy_pred)) - expect_equal(xy_pred, predict_classprob(xy_fit, new_data = lending_club[1:6, num_pred])) + expect_equal(xy_pred, parsnip:::predict_classprob(xy_fit, new_data = lending_club[1:6, num_pred])) - one_row <- predict_classprob(xy_fit, new_data = lending_club[1, num_pred]) + one_row <- parsnip:::predict_classprob(xy_fit, new_data = lending_club[1, num_pred]) expect_equivalent(xy_pred[1,], one_row) form_fit <- fit( @@ -131,7 +131,7 @@ test_that('randomForest classification probabilities', { form_pred <- predict(form_fit$fit, newdata = lending_club[1:6, c("funded_amnt", "int_rate")], type = "prob") form_pred <- as_tibble(as.data.frame(form_pred)) - expect_equal(form_pred, predict_classprob(form_fit, new_data = lending_club[1:6, c("funded_amnt", "int_rate")])) + expect_equal(form_pred, parsnip:::predict_classprob(form_fit, new_data = lending_club[1:6, c("funded_amnt", "int_rate")])) }) @@ -209,6 +209,6 @@ test_that('randomForest regression prediction', { xy_pred <- predict(xy_fit$fit, newdata = tail(mtcars)) xy_pred <- unname(xy_pred) - expect_equal(xy_pred, predict_numeric(xy_fit, new_data = tail(mtcars))) + expect_equal(xy_pred, parsnip:::predict_numeric(xy_fit, new_data = tail(mtcars))) }) diff --git a/tests/testthat/test_rand_forest_ranger.R b/tests/testthat/test_rand_forest_ranger.R index 82d767c4b..3e2f963d7 100644 --- a/tests/testthat/test_rand_forest_ranger.R +++ b/tests/testthat/test_rand_forest_ranger.R @@ -96,7 +96,7 @@ test_that('ranger classification prediction', { xy_pred <- predict(xy_fit$fit, data = lending_club[1:6, num_pred])$prediction xy_pred <- colnames(xy_pred)[apply(xy_pred, 1, which.max)] xy_pred <- factor(xy_pred, levels = levels(lending_club$Class)) - expect_equal(xy_pred, predict_class(xy_fit, new_data = lending_club[1:6, num_pred])) + expect_equal(xy_pred, parsnip:::predict_class(xy_fit, new_data = lending_club[1:6, num_pred])) 
form_fit <- fit( rand_forest() %>% set_engine("ranger"), @@ -109,7 +109,7 @@ test_that('ranger classification prediction', { form_pred <- predict(form_fit$fit, data = lending_club[1:6, c("funded_amnt", "int_rate")])$prediction form_pred <- colnames(form_pred)[apply(form_pred, 1, which.max)] form_pred <- factor(form_pred, levels = levels(lending_club$Class)) - expect_equal(form_pred, predict_class(form_fit, new_data = lending_club[1:6, c("funded_amnt", "int_rate")])) + expect_equal(form_pred, parsnip:::predict_class(form_fit, new_data = lending_club[1:6, c("funded_amnt", "int_rate")])) }) @@ -128,9 +128,9 @@ test_that('ranger classification probabilities', { xy_pred <- predict(xy_fit$fit, data = lending_club[1:6, num_pred])$predictions xy_pred <- as_tibble(xy_pred) - expect_equal(xy_pred, predict_classprob(xy_fit, new_data = lending_club[1:6, num_pred])) + expect_equal(xy_pred, parsnip:::predict_classprob(xy_fit, new_data = lending_club[1:6, num_pred])) - one_row <- predict_classprob(xy_fit, new_data = lending_club[1, num_pred]) + one_row <- parsnip:::predict_classprob(xy_fit, new_data = lending_club[1, num_pred]) expect_equivalent(xy_pred[1,], one_row) form_fit <- fit( @@ -143,7 +143,7 @@ test_that('ranger classification probabilities', { form_pred <- predict(form_fit$fit, data = lending_club[1:6, c("funded_amnt", "int_rate")])$predictions form_pred <- as_tibble(form_pred) - expect_equal(form_pred, predict_classprob(form_fit, new_data = lending_club[1:6, c("funded_amnt", "int_rate")])) + expect_equal(form_pred, parsnip:::predict_classprob(form_fit, new_data = lending_club[1:6, c("funded_amnt", "int_rate")])) no_prob_model <- fit_xy( rand_forest() %>% set_engine("ranger", probability = FALSE), @@ -154,7 +154,7 @@ test_that('ranger classification probabilities', { ) expect_error( - predict_classprob(no_prob_model, new_data = lending_club[1:6, num_pred]) + parsnip:::predict_classprob(no_prob_model, new_data = lending_club[1:6, num_pred]) ) }) @@ -229,7 +229,7 @@ 
test_that('ranger regression prediction', { xy_pred <- predict(xy_fit$fit, data = tail(mtcars[, -1]))$prediction - expect_equal(xy_pred, predict_numeric(xy_fit, new_data = tail(mtcars[, -1]))) + expect_equal(xy_pred, parsnip:::predict_numeric(xy_fit, new_data = tail(mtcars[, -1]))) }) diff --git a/tests/testthat/test_rand_forest_spark.R b/tests/testthat/test_rand_forest_spark.R index 4184e6abf..da0b5fab6 100644 --- a/tests/testthat/test_rand_forest_spark.R +++ b/tests/testthat/test_rand_forest_spark.R @@ -58,7 +58,7 @@ test_that('spark execution', { ) expect_error( - spark_reg_pred_num <- predict_numeric(spark_reg_fit, iris_rf_te), + spark_reg_pred_num <- parsnip:::predict_numeric(spark_reg_fit, iris_rf_te), regexp = NA ) @@ -68,7 +68,7 @@ test_that('spark execution', { ) expect_error( - spark_reg_num_dup <- predict_numeric(spark_reg_fit_dup, iris_rf_te), + spark_reg_num_dup <- parsnip:::predict_numeric(spark_reg_fit_dup, iris_rf_te), regexp = NA ) @@ -124,7 +124,7 @@ test_that('spark execution', { ) expect_error( - spark_class_pred_class <- predict_class(spark_class_fit, churn_rf_te), + spark_class_pred_class <- parsnip:::predict_class(spark_class_fit, churn_rf_te), regexp = NA ) @@ -134,7 +134,7 @@ test_that('spark execution', { ) expect_error( - spark_class_dup_class <- predict_class(spark_class_fit_dup, churn_rf_te), + spark_class_dup_class <- parsnip:::predict_class(spark_class_fit_dup, churn_rf_te), regexp = NA ) @@ -156,7 +156,7 @@ test_that('spark execution', { ) expect_error( - spark_class_prob_classprob <- predict_classprob(spark_class_fit, churn_rf_te), + spark_class_prob_classprob <- parsnip:::predict_classprob(spark_class_fit, churn_rf_te), regexp = NA ) @@ -166,7 +166,7 @@ test_that('spark execution', { ) expect_error( - spark_class_dup_classprob <- predict_classprob(spark_class_fit_dup, churn_rf_te), + spark_class_dup_classprob <- parsnip:::predict_classprob(spark_class_fit_dup, churn_rf_te), regexp = NA )