From 93fb310c0de4c3e6eeb2b6d10f321da993711ce6 Mon Sep 17 00:00:00 2001 From: topepo Date: Mon, 25 Feb 2019 15:45:13 -0500 Subject: [PATCH 01/15] unexported prediction modules --- NAMESPACE | 21 ------ R/predict.R | 5 +- R/predict_class.R | 28 ++++---- R/predict_classprob.R | 20 +++--- R/predict_interval.R | 48 ++++++------- R/predict_numeric.R | 20 +++--- R/predict_quantile.R | 24 +++---- R/predict_raw.R | 16 ++--- man/other_predict.Rd | 67 ------------------- man/predict.model_fit.Rd | 13 +--- tests/testthat/test_boost_tree_C50.R | 6 +- tests/testthat/test_boost_tree_spark.R | 12 ++-- tests/testthat/test_boost_tree_xgboost.R | 8 +-- tests/testthat/test_linear_reg.R | 6 +- tests/testthat/test_linear_reg_glmnet.R | 12 ++-- tests/testthat/test_linear_reg_spark.R | 2 +- tests/testthat/test_linear_reg_stan.R | 4 +- tests/testthat/test_logistic_reg.R | 6 +- tests/testthat/test_logistic_reg_glmnet.R | 26 +++---- tests/testthat/test_logistic_reg_keras.R | 4 +- tests/testthat/test_logistic_reg_spark.R | 4 +- tests/testthat/test_logistic_reg_stan.R | 8 +-- tests/testthat/test_mars.R | 8 +-- tests/testthat/test_mlp_keras.R | 12 ++-- tests/testthat/test_mlp_nnet.R | 12 ++-- tests/testthat/test_multinom_reg_glmnet.R | 4 +- tests/testthat/test_multinom_reg_keras.R | 4 +- tests/testthat/test_multinom_reg_spark.R | 4 +- tests/testthat/test_nearest_neighbor_kknn.R | 6 +- tests/testthat/test_nullmodel.R | 4 +- tests/testthat/test_predict_formats.R | 12 ++-- .../testthat/test_rand_forest_randomForest.R | 12 ++-- tests/testthat/test_rand_forest_ranger.R | 14 ++-- tests/testthat/test_rand_forest_spark.R | 12 ++-- 34 files changed, 184 insertions(+), 280 deletions(-) delete mode 100644 man/other_predict.Rd diff --git a/NAMESPACE b/NAMESPACE index 662cfaff0..b57fd5df0 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -17,19 +17,12 @@ S3method(predict,model_fit) S3method(predict,model_spec) S3method(predict,nullmodel) S3method(predict_class,"_lognet") -S3method(predict_class,model_fit) 
S3method(predict_classprob,"_lognet") S3method(predict_classprob,"_multnet") -S3method(predict_classprob,model_fit) -S3method(predict_confint,model_fit) S3method(predict_numeric,"_elnet") -S3method(predict_numeric,model_fit) -S3method(predict_predint,model_fit) -S3method(predict_quantile,model_fit) S3method(predict_raw,"_elnet") S3method(predict_raw,"_lognet") S3method(predict_raw,"_multnet") -S3method(predict_raw,model_fit) S3method(print,boost_tree) S3method(print,decision_tree) S3method(print,linear_reg) @@ -102,20 +95,6 @@ export(nearest_neighbor) export(null_model) export(nullmodel) export(predict.model_fit) -export(predict_class) -export(predict_class.model_fit) -export(predict_classprob) -export(predict_classprob.model_fit) -export(predict_confint) -export(predict_confint.model_fit) -export(predict_numeric) -export(predict_numeric.model_fit) -export(predict_predint) -export(predict_predint.model_fit) -export(predict_quantile) -export(predict_quantile.model_fit) -export(predict_raw) -export(predict_raw.model_fit) export(rand_forest) export(rpart_train) export(set_args) diff --git a/R/predict.R b/R/predict.R index 7b498bc6b..28f4ad505 100644 --- a/R/predict.R +++ b/R/predict.R @@ -49,9 +49,8 @@ #' a list-column. Each list element contains a tibble with columns #' `.pred` and `.quantile` (and perhaps other columns). #' -#' Using `type = "raw"` with `predict.model_fit()` (or using -#' `predict_raw()`) will return the unadulterated results of the -#' prediction function. +#' Using `type = "raw"` with `predict.model_fit()` will return +#' the unadulterated results of the prediction function. #' #' In the case of Spark-based models, since table columns cannot #' contain dots, the same convention is used except 1) no dots diff --git a/R/predict_class.R b/R/predict_class.R index b1cd24d8e..39a051e5e 100644 --- a/R/predict_class.R +++ b/R/predict_class.R @@ -1,13 +1,13 @@ -#' Other predict methods. 
-#' -#' These are internal functions not meant to be directly called by the user. -#' -#' @keywords internal -#' @rdname other_predict -#' @inheritParams predict.model_fit -#' @method predict_class model_fit -#' @export predict_class.model_fit -#' @export +# Other predict methods. +# +# These are internal functions not meant to be directly called by the user. +# +# @keywords internal +# @rdname other_predict +# @inheritParams predict.model_fit +# @method predict_class model_fit +# @export predict_class.model_fit +# @export predict_class.model_fit <- function (object, new_data, ...) { if(object$spec$mode != "classification") stop("`predict.model_fit()` is for predicting factor outcomes.", @@ -43,9 +43,9 @@ predict_class.model_fit <- function (object, new_data, ...) { res } -#' @export -#' @keywords internal -#' @rdname other_predict -#' @inheritParams predict.model_fit +# @export +# @keywords internal +# @rdname other_predict +# @inheritParams predict.model_fit predict_class <- function (object, ...) UseMethod("predict_class") diff --git a/R/predict_classprob.R b/R/predict_classprob.R index de816c190..ff481fd58 100644 --- a/R/predict_classprob.R +++ b/R/predict_classprob.R @@ -1,9 +1,9 @@ -#' @keywords internal -#' @rdname other_predict -#' @inheritParams predict.model_fit -#' @method predict_classprob model_fit -#' @export predict_classprob.model_fit -#' @export +# @keywords internal +# @rdname other_predict +# @inheritParams predict.model_fit +# @method predict_classprob model_fit +# @export predict_classprob.model_fit +# @export #' @importFrom tibble as_tibble is_tibble tibble predict_classprob.model_fit <- function (object, new_data, ...) { if(object$spec$mode != "classification") @@ -39,9 +39,9 @@ predict_classprob.model_fit <- function (object, new_data, ...) 
{ res } -#' @export -#' @keywords internal -#' @rdname other_predict -#' @inheritParams predict.model_fit +# @export +# @keywords internal +# @rdname other_predict +# @inheritParams predict.model_fit predict_classprob <- function (object, ...) UseMethod("predict_classprob") diff --git a/R/predict_interval.R b/R/predict_interval.R index 390e97936..9203324ad 100644 --- a/R/predict_interval.R +++ b/R/predict_interval.R @@ -1,13 +1,13 @@ -#' @keywords internal -#' @rdname other_predict -#' @param level A single numeric value between zero and one for the -#' interval estimates. -#' @param std_error A single logical for wether the standard error should be -#' returned (assuming that the model can compute it). -#' @inheritParams predict.model_fit -#' @method predict_confint model_fit -#' @export predict_confint.model_fit -#' @export +# @keywords internal +# @rdname other_predict +# @param level A single numeric value between zero and one for the +# interval estimates. +# @param std_error A single logical for wether the standard error should be +# returned (assuming that the model can compute it). +# @inheritParams predict.model_fit +# @method predict_confint model_fit +# @export predict_confint.model_fit +# @export predict_confint.model_fit <- function (object, new_data, level = 0.95, std_error = FALSE, ...) { @@ -38,21 +38,21 @@ predict_confint.model_fit <- res } -#' @export -#' @keywords internal -#' @rdname other_predict -#' @inheritParams predict.model_fit +# @export +# @keywords internal +# @rdname other_predict +# @inheritParams predict.model_fit predict_confint <- function (object, ...) 
UseMethod("predict_confint") ################################################################## -#' @keywords internal -#' @rdname other_predict -#' @inheritParams predict.model_fit -#' @method predict_predint model_fit -#' @export predict_predint.model_fit -#' @export +# @keywords internal +# @rdname other_predict +# @inheritParams predict.model_fit +# @method predict_predint model_fit +# @export predict_predint.model_fit +# @export predict_predint.model_fit <- function (object, new_data, level = 0.95, std_error = FALSE, ...) { @@ -84,10 +84,10 @@ predict_predint.model_fit <- res } -#' @export -#' @keywords internal -#' @rdname other_predict -#' @inheritParams predict.model_fit +# @export +# @keywords internal +# @rdname other_predict +# @inheritParams predict.model_fit predict_predint <- function (object, ...) UseMethod("predict_predint") diff --git a/R/predict_numeric.R b/R/predict_numeric.R index 054fc3eb5..945eb92a7 100644 --- a/R/predict_numeric.R +++ b/R/predict_numeric.R @@ -1,9 +1,9 @@ -#' @keywords internal -#' @rdname other_predict -#' @inheritParams predict.model_fit -#' @method predict_numeric model_fit -#' @export predict_numeric.model_fit -#' @export +# @keywords internal +# @rdname other_predict +# @inheritParams predict.model_fit +# @method predict_numeric model_fit +# @export predict_numeric.model_fit +# @export predict_numeric.model_fit <- function (object, new_data, ...) { if (object$spec$mode != "regression") @@ -40,9 +40,9 @@ predict_numeric.model_fit <- function (object, new_data, ...) { } -#' @export -#' @keywords internal -#' @rdname other_predict -#' @inheritParams predict_numeric.model_fit +# @export +# @keywords internal +# @rdname other_predict +# @inheritParams predict_numeric.model_fit predict_numeric <- function (object, ...) 
UseMethod("predict_numeric") diff --git a/R/predict_quantile.R b/R/predict_quantile.R index ed8cfdbe3..17c9786de 100644 --- a/R/predict_quantile.R +++ b/R/predict_quantile.R @@ -1,11 +1,11 @@ -#' @keywords internal -#' @rdname other_predict -#' @param quant A vector of numbers between 0 and 1 for the quantile being -#' predicted. -#' @inheritParams predict.model_fit -#' @method predict_quantile model_fit -#' @export predict_quantile.model_fit -#' @export +# @keywords internal +# @rdname other_predict +# @param quant A vector of numbers between 0 and 1 for the quantile being +# predicted. +# @inheritParams predict.model_fit +# @method predict_quantile model_fit +# @export predict_quantile.model_fit +# @export predict_quantile.model_fit <- function (object, new_data, quantile = (1:9)/10, ...) { @@ -33,9 +33,9 @@ predict_quantile.model_fit <- res } -#' @export -#' @keywords internal -#' @rdname other_predict -#' @inheritParams predict.model_fit +# @export +# @keywords internal +# @rdname other_predict +# @inheritParams predict.model_fit predict_quantile <- function (object, ...) UseMethod("predict_quantile") diff --git a/R/predict_raw.R b/R/predict_raw.R index c3f7dee25..4d972fac3 100644 --- a/R/predict_raw.R +++ b/R/predict_raw.R @@ -1,8 +1,8 @@ -#' @rdname predict.model_fit -#' @inheritParams predict.model_fit -#' @method predict_raw model_fit -#' @export predict_raw.model_fit -#' @export +# @rdname predict.model_fit +# @inheritParams predict.model_fit +# @method predict_raw model_fit +# @export predict_raw.model_fit +# @export predict_raw.model_fit <- function (object, new_data, opts = list(), ...) { protected_args <- names(object$spec$method$raw$args) dup_args <- names(opts) %in% protected_args @@ -32,8 +32,8 @@ predict_raw.model_fit <- function (object, new_data, opts = list(), ...) 
{ } -#' @export -#' @rdname predict.model_fit -#' @inheritParams predict_raw.model_fit +# @export +# @rdname predict.model_fit +# @inheritParams predict_raw.model_fit predict_raw <- function (object, ...) UseMethod("predict_raw") diff --git a/man/other_predict.Rd b/man/other_predict.Rd deleted file mode 100644 index 57e3bf3f2..000000000 --- a/man/other_predict.Rd +++ /dev/null @@ -1,67 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/predict_class.R, R/predict_classprob.R, -% R/predict_interval.R, R/predict_numeric.R, R/predict_quantile.R -\name{predict_class.model_fit} -\alias{predict_class.model_fit} -\alias{predict_class} -\alias{predict_classprob.model_fit} -\alias{predict_classprob} -\alias{predict_confint.model_fit} -\alias{predict_confint} -\alias{predict_predint.model_fit} -\alias{predict_predint} -\alias{predict_numeric.model_fit} -\alias{predict_numeric} -\alias{predict_quantile.model_fit} -\alias{predict_quantile} -\title{Other predict methods.} -\usage{ -\method{predict_class}{model_fit}(object, new_data, ...) - -predict_class(object, ...) - -\method{predict_classprob}{model_fit}(object, new_data, ...) - -predict_classprob(object, ...) - -\method{predict_confint}{model_fit}(object, new_data, level = 0.95, - std_error = FALSE, ...) - -predict_confint(object, ...) - -\method{predict_predint}{model_fit}(object, new_data, level = 0.95, - std_error = FALSE, ...) - -predict_predint(object, ...) - -\method{predict_numeric}{model_fit}(object, new_data, ...) - -predict_numeric(object, ...) - -\method{predict_quantile}{model_fit}(object, new_data, - quantile = (1:9)/10, ...) - -predict_quantile(object, ...) -} -\arguments{ -\item{object}{An object of class \code{model_fit}} - -\item{new_data}{A rectangular data object, such as a data frame.} - -\item{...}{Ignored. 
To pass arguments to pass to the underlying -function when \code{predict.model_fit(type = "raw")}, -use the \code{opts} argument.} - -\item{level}{A single numeric value between zero and one for the -interval estimates.} - -\item{std_error}{A single logical for wether the standard error should be -returned (assuming that the model can compute it).} - -\item{quant}{A vector of numbers between 0 and 1 for the quantile being -predicted.} -} -\description{ -These are internal functions not meant to be directly called by the user. -} -\keyword{internal} diff --git a/man/predict.model_fit.Rd b/man/predict.model_fit.Rd index 78999cbd0..e093f83a7 100644 --- a/man/predict.model_fit.Rd +++ b/man/predict.model_fit.Rd @@ -1,17 +1,11 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/predict.R, R/predict_raw.R +% Please edit documentation in R/predict.R \name{predict.model_fit} \alias{predict.model_fit} -\alias{predict_raw.model_fit} -\alias{predict_raw} \title{Model predictions} \usage{ \method{predict}{model_fit}(object, new_data, type = NULL, opts = list(), ...) - -\method{predict_raw}{model_fit}(object, new_data, opts = list(), ...) - -predict_raw(object, ...) } \arguments{ \item{object}{An object of class \code{model_fit}} @@ -54,9 +48,8 @@ Quantile predictions return a tibble with a column \code{.pred}, which is a list-column. Each list element contains a tibble with columns \code{.pred} and \code{.quantile} (and perhaps other columns). -Using \code{type = "raw"} with \code{predict.model_fit()} (or using -\code{predict_raw()}) will return the unadulterated results of the -prediction function. +Using \code{type = "raw"} with \code{predict.model_fit()} will return +the unadulterated results of the prediction function. 
In the case of Spark-based models, since table columns cannot contain dots, the same convention is used except 1) no dots diff --git a/tests/testthat/test_boost_tree_C50.R b/tests/testthat/test_boost_tree_C50.R index 81e30fd62..3d1f0e911 100644 --- a/tests/testthat/test_boost_tree_C50.R +++ b/tests/testthat/test_boost_tree_C50.R @@ -80,7 +80,7 @@ test_that('C5.0 prediction', { ) xy_pred <- predict(classes_xy$fit, newdata = lending_club[1:7, num_pred]) - expect_equal(xy_pred, predict_class(classes_xy, lending_club[1:7, num_pred])) + expect_equal(xy_pred, parsnip:::predict_class(classes_xy, lending_club[1:7, num_pred])) }) @@ -97,9 +97,9 @@ test_that('C5.0 probabilities', { xy_pred <- predict(classes_xy$fit, newdata = as.data.frame(lending_club[1:7, num_pred]), type = "prob") xy_pred <- as_tibble(xy_pred) - expect_equal(xy_pred, predict_classprob(classes_xy, lending_club[1:7, num_pred])) + expect_equal(xy_pred, parsnip:::predict_classprob(classes_xy, lending_club[1:7, num_pred])) - one_row <- predict_classprob(classes_xy, lending_club[1, num_pred]) + one_row <- parsnip:::predict_classprob(classes_xy, lending_club[1, num_pred]) expect_equal(xy_pred[1,], one_row) }) diff --git a/tests/testthat/test_boost_tree_spark.R b/tests/testthat/test_boost_tree_spark.R index ce148620c..49e9892ec 100644 --- a/tests/testthat/test_boost_tree_spark.R +++ b/tests/testthat/test_boost_tree_spark.R @@ -58,7 +58,7 @@ test_that('spark execution', { ) expect_error( - spark_reg_pred_num <- predict_numeric(spark_reg_fit, iris_bt_te), + spark_reg_pred_num <- parsnip:::predict_numeric(spark_reg_fit, iris_bt_te), regexp = NA ) @@ -68,7 +68,7 @@ test_that('spark execution', { ) expect_error( - spark_reg_num_dup <- predict_numeric(spark_reg_fit_dup, iris_bt_te), + spark_reg_num_dup <- parsnip:::predict_numeric(spark_reg_fit_dup, iris_bt_te), regexp = NA ) @@ -124,7 +124,7 @@ test_that('spark execution', { ) expect_error( - spark_class_pred_class <- predict_class(spark_class_fit, churn_bt_te), + 
spark_class_pred_class <- parsnip:::predict_class(spark_class_fit, churn_bt_te), regexp = NA ) @@ -134,7 +134,7 @@ test_that('spark execution', { ) expect_error( - spark_class_dup_class <- predict_class(spark_class_fit_dup, churn_bt_te), + spark_class_dup_class <- parsnip:::predict_class(spark_class_fit_dup, churn_bt_te), regexp = NA ) @@ -156,7 +156,7 @@ test_that('spark execution', { ) expect_error( - spark_class_prob_classprob <- predict_classprob(spark_class_fit, churn_bt_te), + spark_class_prob_classprob <- parsnip:::predict_classprob(spark_class_fit, churn_bt_te), regexp = NA ) @@ -166,7 +166,7 @@ test_that('spark execution', { ) expect_error( - spark_class_dup_classprob <- predict_classprob(spark_class_fit_dup, churn_bt_te), + spark_class_dup_classprob <- parsnip:::predict_classprob(spark_class_fit_dup, churn_bt_te), regexp = NA ) diff --git a/tests/testthat/test_boost_tree_xgboost.R b/tests/testthat/test_boost_tree_xgboost.R index 0c6be9417..e740cfa12 100644 --- a/tests/testthat/test_boost_tree_xgboost.R +++ b/tests/testthat/test_boost_tree_xgboost.R @@ -66,7 +66,7 @@ test_that('xgboost classification prediction', { xy_pred <- predict(xy_fit$fit, newdata = xgb.DMatrix(data = as.matrix(iris[1:8, num_pred])), type = "class") xy_pred <- matrix(xy_pred, ncol = 3, byrow = TRUE) xy_pred <- factor(levels(iris$Species)[apply(xy_pred, 1, which.max)], levels = levels(iris$Species)) - expect_equal(xy_pred, predict_class(xy_fit, new_data = iris[1:8, num_pred])) + expect_equal(xy_pred, parsnip:::predict_class(xy_fit, new_data = iris[1:8, num_pred])) form_fit <- fit( iris_xgboost, @@ -78,7 +78,7 @@ test_that('xgboost classification prediction', { form_pred <- predict(form_fit$fit, newdata = xgb.DMatrix(data = as.matrix(iris[1:8, num_pred])), type = "class") form_pred <- matrix(form_pred, ncol = 3, byrow = TRUE) form_pred <- factor(levels(iris$Species)[apply(form_pred, 1, which.max)], levels = levels(iris$Species)) - expect_equal(form_pred, predict_class(form_fit, 
new_data = iris[1:8, num_pred])) + expect_equal(form_pred, parsnip:::predict_class(form_fit, new_data = iris[1:8, num_pred])) }) @@ -141,7 +141,7 @@ test_that('xgboost regression prediction', { ) xy_pred <- predict(xy_fit$fit, newdata = xgb.DMatrix(data = as.matrix(mtcars[1:8, -1]))) - expect_equal(xy_pred, predict_numeric(xy_fit, new_data = mtcars[1:8, -1])) + expect_equal(xy_pred, parsnip:::predict_numeric(xy_fit, new_data = mtcars[1:8, -1])) form_fit <- fit( car_basic, @@ -151,7 +151,7 @@ test_that('xgboost regression prediction', { ) form_pred <- predict(form_fit$fit, newdata = xgb.DMatrix(data = as.matrix(mtcars[1:8, -1]))) - expect_equal(form_pred, predict_numeric(form_fit, new_data = mtcars[1:8, -1])) + expect_equal(form_pred, parsnip:::predict_numeric(form_fit, new_data = mtcars[1:8, -1])) }) diff --git a/tests/testthat/test_linear_reg.R b/tests/testthat/test_linear_reg.R index 86d2e2bed..ba5f804b0 100644 --- a/tests/testthat/test_linear_reg.R +++ b/tests/testthat/test_linear_reg.R @@ -269,7 +269,7 @@ test_that('lm prediction', { control = ctrl ) - expect_equal(uni_pred, predict_numeric(res_xy, iris[1:5, num_pred])) + expect_equal(uni_pred, parsnip:::predict_numeric(res_xy, iris[1:5, num_pred])) res_form <- fit( iris_basic, @@ -277,7 +277,7 @@ test_that('lm prediction', { data = iris, control = ctrl ) - expect_equal(inl_pred, predict_numeric(res_form, iris[1:5, ])) + expect_equal(inl_pred, parsnip:::predict_numeric(res_form, iris[1:5, ])) res_mv <- fit( iris_basic, @@ -285,7 +285,7 @@ test_that('lm prediction', { data = iris, control = ctrl ) - expect_equal(mv_pred, predict_numeric(res_mv, iris[1:5,])) + expect_equal(mv_pred, parsnip:::predict_numeric(res_mv, iris[1:5,])) }) test_that('lm intervals', { diff --git a/tests/testthat/test_linear_reg_glmnet.R b/tests/testthat/test_linear_reg_glmnet.R index a22b6e73d..6c58182b2 100644 --- a/tests/testthat/test_linear_reg_glmnet.R +++ b/tests/testthat/test_linear_reg_glmnet.R @@ -69,7 +69,7 @@ test_that('glmnet 
prediction, single lambda', { s = iris_basic$spec$args$penalty) uni_pred <- unname(uni_pred[,1]) - expect_equal(uni_pred, predict_numeric(res_xy, iris[1:5, num_pred])) + expect_equal(uni_pred, parsnip:::predict_numeric(res_xy, iris[1:5, num_pred])) res_form <- fit( iris_basic, @@ -87,7 +87,7 @@ test_that('glmnet prediction, single lambda', { s = res_form$spec$spec$args$penalty) form_pred <- unname(form_pred[,1]) - expect_equal(form_pred, predict_numeric(res_form, iris[1:5, c("Sepal.Width", "Species")])) + expect_equal(form_pred, parsnip:::predict_numeric(res_form, iris[1:5, c("Sepal.Width", "Species")])) }) @@ -115,7 +115,7 @@ test_that('glmnet prediction, multiple lambda', { mult_pred$lambda <- rep(lams, each = 5) mult_pred <- mult_pred[,-2] - expect_equal(mult_pred, predict_numeric(res_xy, iris[1:5, num_pred])) + expect_equal(mult_pred, parsnip:::predict_numeric(res_xy, iris[1:5, num_pred])) res_form <- fit( iris_mult, @@ -135,7 +135,7 @@ test_that('glmnet prediction, multiple lambda', { form_pred$lambda <- rep(lams, each = 5) form_pred <- form_pred[,-2] - expect_equal(form_pred, predict_numeric(res_form, iris[1:5, c("Sepal.Width", "Species")])) + expect_equal(form_pred, parsnip:::predict_numeric(res_form, iris[1:5, c("Sepal.Width", "Species")])) }) test_that('glmnet prediction, all lambda', { @@ -157,7 +157,7 @@ test_that('glmnet prediction, all lambda', { all_pred$lambda <- rep(res_xy$fit$lambda, each = 5) all_pred <- all_pred[,-2] - expect_equal(all_pred, predict_numeric(res_xy, iris[1:5, num_pred])) + expect_equal(all_pred, parsnip:::predict_numeric(res_xy, iris[1:5, num_pred])) # test that the lambda seq is in the right order (since no docs on this) tmp_pred <- predict(res_xy$fit, newx = as.matrix(iris[1:5, num_pred]), @@ -180,7 +180,7 @@ test_that('glmnet prediction, all lambda', { form_pred$lambda <- rep(res_form$fit$lambda, each = 5) form_pred <- form_pred[,-2] - expect_equal(form_pred, predict_numeric(res_form, iris[1:5, c("Sepal.Width", "Species")])) + 
expect_equal(form_pred, parsnip:::predict_numeric(res_form, iris[1:5, c("Sepal.Width", "Species")])) }) diff --git a/tests/testthat/test_linear_reg_spark.R b/tests/testthat/test_linear_reg_spark.R index f1a033beb..28859a030 100644 --- a/tests/testthat/test_linear_reg_spark.R +++ b/tests/testthat/test_linear_reg_spark.R @@ -42,7 +42,7 @@ test_that('spark execution', { ) expect_error( - spark_pred_num <- predict_numeric(spark_fit, iris_linreg_te), + spark_pred_num <- parsnip:::predict_numeric(spark_fit, iris_linreg_te), regexp = NA ) diff --git a/tests/testthat/test_linear_reg_stan.R b/tests/testthat/test_linear_reg_stan.R index 4891e65e1..8fff7084e 100644 --- a/tests/testthat/test_linear_reg_stan.R +++ b/tests/testthat/test_linear_reg_stan.R @@ -67,7 +67,7 @@ test_that('stan prediction', { control = quiet_ctrl ) - expect_equal(uni_pred, predict_numeric(res_xy, iris[1:5, num_pred]), tolerance = 0.001) + expect_equal(uni_pred, parsnip:::predict_numeric(res_xy, iris[1:5, num_pred]), tolerance = 0.001) res_form <- fit( iris_basic, @@ -75,7 +75,7 @@ test_that('stan prediction', { data = iris, control = quiet_ctrl ) - expect_equal(inl_pred, predict_numeric(res_form, iris[1:5, ]), tolerance = 0.001) + expect_equal(inl_pred, parsnip:::predict_numeric(res_form, iris[1:5, ]), tolerance = 0.001) }) diff --git a/tests/testthat/test_logistic_reg.R b/tests/testthat/test_logistic_reg.R index c74b0c492..8778a2f3e 100644 --- a/tests/testthat/test_logistic_reg.R +++ b/tests/testthat/test_logistic_reg.R @@ -275,7 +275,7 @@ test_that('glm prediction', { xy_pred <- ifelse(xy_pred >= 0.5, "good", "bad") xy_pred <- factor(xy_pred, levels = levels(lending_club$Class)) xy_pred <- unname(xy_pred) - expect_equal(xy_pred, predict_class(classes_xy, lending_club[1:7, num_pred])) + expect_equal(xy_pred, parsnip:::predict_class(classes_xy, lending_club[1:7, num_pred])) }) @@ -289,9 +289,9 @@ test_that('glm probabilities', { xy_pred <- predict(classes_xy$fit, newdata = lending_club[1:7, num_pred], 
type = "response") xy_pred <- tibble(bad = 1 - xy_pred, good = xy_pred) - expect_equal(xy_pred, predict_classprob(classes_xy, lending_club[1:7, num_pred])) + expect_equal(xy_pred, parsnip:::predict_classprob(classes_xy, lending_club[1:7, num_pred])) - one_row <- predict_classprob(classes_xy, lending_club[1, num_pred]) + one_row <- parsnip:::predict_classprob(classes_xy, lending_club[1, num_pred]) expect_equal(xy_pred[1,], one_row) }) diff --git a/tests/testthat/test_logistic_reg_glmnet.R b/tests/testthat/test_logistic_reg_glmnet.R index e59533209..510ace9c9 100644 --- a/tests/testthat/test_logistic_reg_glmnet.R +++ b/tests/testthat/test_logistic_reg_glmnet.R @@ -64,7 +64,7 @@ test_that('glmnet prediction, one lambda', { uni_pred <- factor(uni_pred, levels = levels(lending_club$Class)) uni_pred <- unname(uni_pred) - expect_equal(uni_pred, predict_class(xy_fit, lending_club[1:7, num_pred])) + expect_equal(uni_pred, parsnip:::predict_class(xy_fit, lending_club[1:7, num_pred])) res_form <- fit( logistic_reg(penalty = 0.1) %>% set_engine("glmnet"), @@ -84,7 +84,7 @@ test_that('glmnet prediction, one lambda', { form_pred <- factor(form_pred, levels = levels(lending_club$Class)) form_pred <- unname(form_pred) - expect_equal(form_pred, predict_class(res_form, lending_club[1:7, c("funded_amnt", "int_rate")])) + expect_equal(form_pred, parsnip:::predict_class(res_form, lending_club[1:7, c("funded_amnt", "int_rate")])) }) @@ -112,7 +112,7 @@ test_that('glmnet prediction, mulitiple lambda', { mult_pred$lambda <- rep(lams, each = 7) mult_pred <- mult_pred[, -2] - expect_equal(mult_pred, predict_class(xy_fit, lending_club[1:7, num_pred])) + expect_equal(mult_pred, parsnip:::predict_class(xy_fit, lending_club[1:7, num_pred])) res_form <- fit( logistic_reg(penalty = lams) %>% set_engine("glmnet"), @@ -134,7 +134,7 @@ test_that('glmnet prediction, mulitiple lambda', { form_pred$lambda <- rep(lams, each = 7) form_pred <- form_pred[, -2] - expect_equal(form_pred, 
predict_class(res_form, lending_club[1:7, c("funded_amnt", "int_rate")])) + expect_equal(form_pred, parsnip:::predict_class(res_form, lending_club[1:7, c("funded_amnt", "int_rate")])) }) @@ -159,7 +159,7 @@ test_that('glmnet prediction, no lambda', { mult_pred$lambda <- rep(xy_fit$fit$lambda, each = 7) mult_pred <- mult_pred[, -2] - expect_equal(mult_pred, predict_class(xy_fit, lending_club[1:7, num_pred])) + expect_equal(mult_pred, parsnip:::predict_class(xy_fit, lending_club[1:7, num_pred])) res_form <- fit( logistic_reg() %>% set_engine("glmnet", nlambda = 11), @@ -180,7 +180,7 @@ test_that('glmnet prediction, no lambda', { form_pred$values <- factor(form_pred$values, levels = levels(lending_club$Class)) form_pred$lambda <- rep(res_form$fit$lambda, each = 7) form_pred <- form_pred[, -2] - expect_equal(form_pred, predict_class(res_form, lending_club[1:7, c("funded_amnt", "int_rate")])) + expect_equal(form_pred, parsnip:::predict_class(res_form, lending_club[1:7, c("funded_amnt", "int_rate")])) }) @@ -202,7 +202,7 @@ test_that('glmnet probabilities, one lambda', { s = 0.1, type = "response")[,1] uni_pred <- tibble(bad = 1 - uni_pred, good = uni_pred) - expect_equal(uni_pred, predict_classprob(xy_fit, lending_club[1:7, num_pred])) + expect_equal(uni_pred, parsnip:::predict_classprob(xy_fit, lending_club[1:7, num_pred])) res_form <- fit( logistic_reg(penalty = 0.1) %>% set_engine("glmnet"), @@ -219,9 +219,9 @@ test_that('glmnet probabilities, one lambda', { newx = form_mat, s = 0.1, type = "response")[, 1] form_pred <- tibble(bad = 1 - form_pred, good = form_pred) - expect_equal(form_pred, predict_classprob(res_form, lending_club[1:7, c("funded_amnt", "int_rate")])) + expect_equal(form_pred, parsnip:::predict_classprob(res_form, lending_club[1:7, c("funded_amnt", "int_rate")])) - one_row <- predict_classprob(res_form, lending_club[1, c("funded_amnt", "int_rate")]) + one_row <- parsnip:::predict_classprob(res_form, lending_club[1, c("funded_amnt", "int_rate")]) 
expect_equal(form_pred[1,], one_row) }) @@ -247,7 +247,7 @@ test_that('glmnet probabilities, mulitiple lambda', { mult_pred <- tibble(bad = 1 - mult_pred$values, good = mult_pred$values) mult_pred$lambda <- rep(lams, each = 7) - expect_equal(mult_pred, predict_classprob(xy_fit, lending_club[1:7, num_pred])) + expect_equal(mult_pred, parsnip:::predict_classprob(xy_fit, lending_club[1:7, num_pred])) res_form <- fit( logistic_reg(penalty = lams) %>% set_engine("glmnet"), @@ -267,7 +267,7 @@ test_that('glmnet probabilities, mulitiple lambda', { form_pred <- tibble(bad = 1 - form_pred$values, good = form_pred$values) form_pred$lambda <- rep(lams, each = 7) - expect_equal(form_pred, predict_classprob(res_form, lending_club[1:7, c("funded_amnt", "int_rate")])) + expect_equal(form_pred, parsnip:::predict_classprob(res_form, lending_club[1:7, c("funded_amnt", "int_rate")])) }) @@ -291,7 +291,7 @@ test_that('glmnet probabilities, no lambda', { mult_pred <- tibble(bad = 1 - mult_pred$values, good = mult_pred$values) mult_pred$lambda <- rep(xy_fit$fit$lambda, each = 7) - expect_equal(mult_pred, predict_classprob(xy_fit, lending_club[1:7, num_pred])) + expect_equal(mult_pred, parsnip:::predict_classprob(xy_fit, lending_club[1:7, num_pred])) res_form <- fit( logistic_reg() %>% set_engine("glmnet"), @@ -311,7 +311,7 @@ test_that('glmnet probabilities, no lambda', { form_pred <- tibble(bad = 1 - form_pred$values, good = form_pred$values) form_pred$lambda <- rep(res_form$fit$lambda, each = 7) - expect_equal(form_pred, predict_classprob(res_form, lending_club[1:7, c("funded_amnt", "int_rate")])) + expect_equal(form_pred, parsnip:::predict_classprob(res_form, lending_club[1:7, c("funded_amnt", "int_rate")])) }) diff --git a/tests/testthat/test_logistic_reg_keras.R b/tests/testthat/test_logistic_reg_keras.R index b338f4ea6..eb3238902 100644 --- a/tests/testthat/test_logistic_reg_keras.R +++ b/tests/testthat/test_logistic_reg_keras.R @@ -161,7 +161,7 @@ test_that('classification 
probabilities', { ) keras_pred <- - predict_proba(lr_fit$fit, as.matrix(te_dat[, -1])) %>% + keras::predict_proba(lr_fit$fit, as.matrix(te_dat[, -1])) %>% as_tibble() %>% setNames(paste0(".pred_", lr_fit$lvl)) @@ -178,7 +178,7 @@ test_that('classification probabilities', { ) keras_pred <- - predict_proba(plrfit$fit, as.matrix(te_dat[, -1])) %>% + keras::predict_proba(plrfit$fit, as.matrix(te_dat[, -1])) %>% as_tibble() %>% setNames(paste0(".pred_", lr_fit$lvl)) parsnip_pred <- predict(plrfit, te_dat[, -1], type = "prob") diff --git a/tests/testthat/test_logistic_reg_spark.R b/tests/testthat/test_logistic_reg_spark.R index 9d435dc85..b9ac7a1fa 100644 --- a/tests/testthat/test_logistic_reg_spark.R +++ b/tests/testthat/test_logistic_reg_spark.R @@ -56,7 +56,7 @@ test_that('spark execution', { ) expect_error( - spark_class_pred_class <- predict_class(spark_class_fit, churn_logit_te), + spark_class_pred_class <- parsnip:::predict_class(spark_class_fit, churn_logit_te), regexp = NA ) @@ -73,7 +73,7 @@ test_that('spark execution', { ) expect_error( - spark_class_prob_classprob <- predict_classprob(spark_class_fit, churn_logit_te), + spark_class_prob_classprob <- parsnip:::predict_classprob(spark_class_fit, churn_logit_te), regexp = NA ) diff --git a/tests/testthat/test_logistic_reg_stan.R b/tests/testthat/test_logistic_reg_stan.R index 0a322d1e4..261850588 100644 --- a/tests/testthat/test_logistic_reg_stan.R +++ b/tests/testthat/test_logistic_reg_stan.R @@ -63,7 +63,7 @@ test_that('stan_glm prediction', { xy_pred <- factor(xy_pred, levels = levels(lending_club$Class)) xy_pred <- unname(xy_pred) - expect_equal(xy_pred, predict_class(xy_fit, lending_club[1:7, num_pred])) + expect_equal(xy_pred, parsnip:::predict_class(xy_fit, lending_club[1:7, num_pred])) res_form <- fit( logistic_reg() %>% @@ -80,7 +80,7 @@ test_that('stan_glm prediction', { form_pred <- unname(form_pred) form_pred <- ifelse(form_pred >= 0.5, "good", "bad") form_pred <- factor(form_pred, levels = 
levels(lending_club$Class)) - expect_equal(form_pred, predict_class(res_form, lending_club[1:7, c("funded_amnt", "int_rate")])) + expect_equal(form_pred, parsnip:::predict_class(res_form, lending_club[1:7, c("funded_amnt", "int_rate")])) }) @@ -102,7 +102,7 @@ test_that('stan_glm probability', { xy_pred <- xy_fit$fit$family$linkinv(xy_pred) xy_pred <- tibble(bad = 1 - xy_pred, good = xy_pred) - expect_equal(xy_pred, predict_classprob(xy_fit, lending_club[1:7, num_pred])) + expect_equal(xy_pred, parsnip:::predict_classprob(xy_fit, lending_club[1:7, num_pred])) res_form <- fit( logistic_reg() %>% @@ -117,7 +117,7 @@ test_that('stan_glm probability', { newdata = lending_club[1:7, c("funded_amnt", "int_rate")]) form_pred <- xy_fit$fit$family$linkinv(form_pred) form_pred <- tibble(bad = 1 - form_pred, good = form_pred) - expect_equal(form_pred, predict_classprob(res_form, lending_club[1:7, c("funded_amnt", "int_rate")])) + expect_equal(form_pred, parsnip:::predict_classprob(res_form, lending_club[1:7, c("funded_amnt", "int_rate")])) }) diff --git a/tests/testthat/test_mars.R b/tests/testthat/test_mars.R index b23330516..7ea9213a3 100644 --- a/tests/testthat/test_mars.R +++ b/tests/testthat/test_mars.R @@ -185,7 +185,7 @@ test_that('mars prediction', { control = ctrl ) - expect_equal(uni_pred, predict_numeric(res_xy, iris[1:5, num_pred])) + expect_equal(uni_pred, parsnip:::predict_numeric(res_xy, iris[1:5, num_pred])) res_form <- fit( iris_basic, @@ -193,7 +193,7 @@ test_that('mars prediction', { data = iris, control = ctrl ) - expect_equal(inl_pred, predict_numeric(res_form, iris[1:5, ])) + expect_equal(inl_pred, parsnip:::predict_numeric(res_form, iris[1:5, ])) res_mv <- fit( iris_basic, @@ -201,7 +201,7 @@ test_that('mars prediction', { data = iris, control = ctrl ) - expect_equal(mv_pred, predict_numeric(res_mv, iris[1:5,])) + expect_equal(mv_pred, parsnip:::predict_numeric(res_mv, iris[1:5,])) }) @@ -270,7 +270,7 @@ test_that('classification', { regexp = NA ) 
expect_true(!is.null(glm_mars$fit$glm.list)) - parsnip_pred <- predict_classprob(glm_mars, new_data = lending_club[1:5, -ncol(lending_club)]) + parsnip_pred <- parsnip:::predict_classprob(glm_mars, new_data = lending_club[1:5, -ncol(lending_club)]) earth_pred <- c(0.95631355972526, 0.971917781277731, 0.894245392500336, 0.962667553751077, diff --git a/tests/testthat/test_mlp_keras.R b/tests/testthat/test_mlp_keras.R index a2022923e..fead31752 100644 --- a/tests/testthat/test_mlp_keras.R +++ b/tests/testthat/test_mlp_keras.R @@ -69,7 +69,7 @@ test_that('keras classification prediction', { control = ctrl ) - xy_pred <- predict_classes(xy_fit$fit, x = as.matrix(iris[1:8, num_pred])) + xy_pred <- keras::predict_classes(xy_fit$fit, x = as.matrix(iris[1:8, num_pred])) xy_pred <- factor(levels(iris$Species)[xy_pred + 1], levels = levels(iris$Species)) expect_equal(xy_pred, predict(xy_fit, new_data = iris[1:8, num_pred], type = "class")[[".pred_class"]]) @@ -82,7 +82,7 @@ test_that('keras classification prediction', { control = ctrl ) - form_pred <- predict_classes(form_fit$fit, x = as.matrix(iris[1:8, num_pred])) + form_pred <- keras::predict_classes(form_fit$fit, x = as.matrix(iris[1:8, num_pred])) form_pred <- factor(levels(iris$Species)[form_pred + 1], levels = levels(iris$Species)) expect_equal(form_pred, predict(form_fit, new_data = iris[1:8, num_pred], type = "class")[[".pred_class"]]) @@ -101,7 +101,7 @@ test_that('keras classification probabilities', { control = ctrl ) - xy_pred <- predict_proba(xy_fit$fit, x = as.matrix(iris[1:8, num_pred])) + xy_pred <- keras::predict_proba(xy_fit$fit, x = as.matrix(iris[1:8, num_pred])) xy_pred <- as_tibble(xy_pred) colnames(xy_pred) <- paste0(".pred_", levels(iris$Species)) expect_equal(xy_pred, predict(xy_fit, new_data = iris[1:8, num_pred], type = "prob")) @@ -115,7 +115,7 @@ test_that('keras classification probabilities', { control = ctrl ) - form_pred <- predict_proba(form_fit$fit, x = as.matrix(iris[1:8, num_pred])) + 
form_pred <- keras::predict_proba(form_fit$fit, x = as.matrix(iris[1:8, num_pred])) form_pred <- as_tibble(form_pred) colnames(form_pred) <- paste0(".pred_", levels(iris$Species)) expect_equal(form_pred, predict(form_fit, new_data = iris[1:8, num_pred], type = "prob")) @@ -218,7 +218,7 @@ test_that('multivariate nnet formula', { data = nn_dat[-(1:5),] ) expect_equal(length(unlist(keras::get_weights(nnet_form$fit))), 24) - nnet_form_pred <- predict_numeric(nnet_form, new_data = nn_dat[1:5, -(1:3)]) + nnet_form_pred <- parsnip:::predict_numeric(nnet_form, new_data = nn_dat[1:5, -(1:3)]) expect_equal(ncol(nnet_form_pred), 3) expect_equal(nrow(nnet_form_pred), 5) expect_equal(names(nnet_form_pred), c("V1", "V2", "V3")) @@ -233,7 +233,7 @@ test_that('multivariate nnet formula', { y = nn_dat[-(1:5), 1:3 ] ) expect_equal(length(unlist(keras::get_weights(nnet_xy$fit))), 24) - nnet_form_xy <- predict_numeric(nnet_xy, new_data = nn_dat[1:5, -(1:3)]) + nnet_form_xy <- parsnip:::predict_numeric(nnet_xy, new_data = nn_dat[1:5, -(1:3)]) expect_equal(ncol(nnet_form_xy), 3) expect_equal(nrow(nnet_form_xy), 5) expect_equal(names(nnet_form_xy), c("V1", "V2", "V3")) diff --git a/tests/testthat/test_mlp_nnet.R b/tests/testthat/test_mlp_nnet.R index a112086fe..6a5022786 100644 --- a/tests/testthat/test_mlp_nnet.R +++ b/tests/testthat/test_mlp_nnet.R @@ -64,7 +64,7 @@ test_that('nnet classification prediction', { xy_pred <- predict(xy_fit$fit, newdata = iris[1:8, num_pred], type = "class") xy_pred <- factor(xy_pred, levels = levels(iris$Species)) - expect_equal(xy_pred, predict_class(xy_fit, new_data = iris[1:8, num_pred])) + expect_equal(xy_pred, parsnip:::predict_class(xy_fit, new_data = iris[1:8, num_pred])) form_fit <- fit( iris_nnet, @@ -75,7 +75,7 @@ test_that('nnet classification prediction', { form_pred <- predict(form_fit$fit, newdata = iris[1:8, num_pred], type = "class") form_pred <- factor(form_pred, levels = levels(iris$Species)) - expect_equal(form_pred, 
predict_class(form_fit, new_data = iris[1:8, num_pred])) + expect_equal(form_pred, parsnip:::predict_class(form_fit, new_data = iris[1:8, num_pred])) }) @@ -141,7 +141,7 @@ test_that('nnet regression prediction', { xy_pred <- predict(xy_fit$fit, newdata = mtcars[1:8, -1])[,1] xy_pred <- unname(xy_pred) - expect_equal(xy_pred, predict_numeric(xy_fit, new_data = mtcars[1:8, -1])) + expect_equal(xy_pred, parsnip:::predict_numeric(xy_fit, new_data = mtcars[1:8, -1])) form_fit <- fit( car_basic, @@ -152,7 +152,7 @@ test_that('nnet regression prediction', { form_pred <- predict(form_fit$fit, newdata = mtcars[1:8, -1])[,1] form_pred <- unname(form_pred) - expect_equal(form_pred, predict_numeric(form_fit, new_data = mtcars[1:8, -1])) + expect_equal(form_pred, parsnip:::predict_numeric(form_fit, new_data = mtcars[1:8, -1])) }) # ------------------------------------------------------------------------------ @@ -175,7 +175,7 @@ test_that('multivariate nnet formula', { data = nn_dat[-(1:5),] ) expect_equal(length(nnet_form$fit$wts), 24) - nnet_form_pred <- predict_numeric(nnet_form, new_data = nn_dat[1:5, -(1:3)]) + nnet_form_pred <- parsnip:::predict_numeric(nnet_form, new_data = nn_dat[1:5, -(1:3)]) expect_equal(ncol(nnet_form_pred), 3) expect_equal(nrow(nnet_form_pred), 5) expect_equal(names(nnet_form_pred), c("V1", "V2", "V3")) @@ -192,7 +192,7 @@ test_that('multivariate nnet formula', { y = nn_dat[-(1:5), 1:3 ] ) expect_equal(length(nnet_xy$fit$wts), 24) - nnet_form_xy <- predict_numeric(nnet_xy, new_data = nn_dat[1:5, -(1:3)]) + nnet_form_xy <- parsnip:::predict_numeric(nnet_xy, new_data = nn_dat[1:5, -(1:3)]) expect_equal(ncol(nnet_form_xy), 3) expect_equal(nrow(nnet_form_xy), 5) expect_equal(names(nnet_form_xy), c("V1", "V2", "V3")) diff --git a/tests/testthat/test_multinom_reg_glmnet.R b/tests/testthat/test_multinom_reg_glmnet.R index 6043aae6e..4043bfc21 100644 --- a/tests/testthat/test_multinom_reg_glmnet.R +++ b/tests/testthat/test_multinom_reg_glmnet.R @@ -58,7 
+58,7 @@ test_that('glmnet prediction, one lambda', { uni_pred <- factor(uni_pred[,1], levels = levels(iris$Species)) uni_pred <- unname(uni_pred) - expect_equal(uni_pred, predict_class(xy_fit, iris[rows, 1:4])) + expect_equal(uni_pred, parsnip:::predict_class(xy_fit, iris[rows, 1:4])) expect_equal(uni_pred, predict(xy_fit, iris[rows, 1:4], type = "class")$.pred_class) res_form <- fit( @@ -77,7 +77,7 @@ test_that('glmnet prediction, one lambda', { s = res_form$spec$args$penalty, type = "class") form_pred <- factor(form_pred[,1], levels = levels(iris$Species)) - expect_equal(form_pred, predict_class(res_form, iris[rows, c("Sepal.Width", "Petal.Width")])) + expect_equal(form_pred, parsnip:::predict_class(res_form, iris[rows, c("Sepal.Width", "Petal.Width")])) expect_equal(form_pred, predict(res_form, iris[rows, c("Sepal.Width", "Petal.Width")], type = "class")$.pred_class) }) diff --git a/tests/testthat/test_multinom_reg_keras.R b/tests/testthat/test_multinom_reg_keras.R index a6937e4d6..eafda5871 100644 --- a/tests/testthat/test_multinom_reg_keras.R +++ b/tests/testthat/test_multinom_reg_keras.R @@ -154,7 +154,7 @@ test_that('classification probabilities', { ) keras_pred <- - predict_proba(lr_fit$fit, as.matrix(te_dat[, -5])) %>% + keras::predict_proba(lr_fit$fit, as.matrix(te_dat[, -5])) %>% as_tibble() %>% setNames(paste0(".pred_", lr_fit$lvl)) @@ -171,7 +171,7 @@ test_that('classification probabilities', { ) keras_pred <- - predict_proba(plrfit$fit, as.matrix(te_dat[, -5])) %>% + keras::predict_proba(plrfit$fit, as.matrix(te_dat[, -5])) %>% as_tibble() %>% setNames(paste0(".pred_", lr_fit$lvl)) parsnip_pred <- predict(plrfit, te_dat[, -5], type = "prob") diff --git a/tests/testthat/test_multinom_reg_spark.R b/tests/testthat/test_multinom_reg_spark.R index 0954e52a0..e28238207 100644 --- a/tests/testthat/test_multinom_reg_spark.R +++ b/tests/testthat/test_multinom_reg_spark.R @@ -45,7 +45,7 @@ test_that('spark execution', { ) expect_error( - spark_class_pred_class 
<- predict_class(spark_class_fit, iris_te), + spark_class_pred_class <- parsnip:::predict_class(spark_class_fit, iris_te), regexp = NA ) @@ -62,7 +62,7 @@ test_that('spark execution', { ) expect_error( - spark_class_prob_classprob <- predict_classprob(spark_class_fit, iris_te), + spark_class_prob_classprob <- parsnip:::predict_classprob(spark_class_fit, iris_te), regexp = NA ) diff --git a/tests/testthat/test_nearest_neighbor_kknn.R b/tests/testthat/test_nearest_neighbor_kknn.R index 1d764d0aa..52474f692 100644 --- a/tests/testthat/test_nearest_neighbor_kknn.R +++ b/tests/testthat/test_nearest_neighbor_kknn.R @@ -74,7 +74,7 @@ test_that('kknn prediction', { newdata = iris[1:5, num_pred] ) - expect_equal(uni_pred, predict_numeric(res_xy, iris[1:5, num_pred])) + expect_equal(uni_pred, parsnip:::predict_numeric(res_xy, iris[1:5, num_pred])) # nominal res_xy_nom <- fit_xy( @@ -89,7 +89,7 @@ test_that('kknn prediction', { newdata = iris[1:5, c("Sepal.Length", "Petal.Width")] ) - expect_equal(uni_pred_nom, predict_class(res_xy_nom, iris[1:5, c("Sepal.Length", "Petal.Width")])) + expect_equal(uni_pred_nom, parsnip:::predict_class(res_xy_nom, iris[1:5, c("Sepal.Length", "Petal.Width")])) # continuous - formula interface res_form <- fit( @@ -104,5 +104,5 @@ test_that('kknn prediction', { newdata = iris[1:5,] ) - expect_equal(form_pred, predict_numeric(res_form, iris[1:5, c("Sepal.Width", "Species")])) + expect_equal(form_pred, parsnip:::predict_numeric(res_form, iris[1:5, c("Sepal.Width", "Species")])) }) diff --git a/tests/testthat/test_nullmodel.R b/tests/testthat/test_nullmodel.R index 45a6a9932..da0294d1f 100644 --- a/tests/testthat/test_nullmodel.R +++ b/tests/testthat/test_nullmodel.R @@ -109,7 +109,7 @@ test_that('nullmodel prediction', { Petal.Length ~ log(Sepal.Width) + Species, data = iris ) - expect_equal(inl_pred, predict_numeric(res_form, iris[1:5, ])) + expect_equal(inl_pred, parsnip:::predict_numeric(res_form, iris[1:5, ])) # Multivariate y res <- fit( @@ 
-118,7 +118,7 @@ test_that('nullmodel prediction', { data = mtcars ) - expect_equal(mw_pred, predict_numeric(res, mtcars[1:5, ])) + expect_equal(mw_pred, parsnip:::predict_numeric(res, mtcars[1:5, ])) }) # ------------------------------------------------------------------------------ diff --git a/tests/testthat/test_predict_formats.R b/tests/testthat/test_predict_formats.R index cd10d2add..f4e63e2fd 100644 --- a/tests/testthat/test_predict_formats.R +++ b/tests/testthat/test_predict_formats.R @@ -31,30 +31,30 @@ lr_fit_2 <- test_that('regression predictions', { expect_true(is_tibble(predict(lm_fit, new_data = iris[1:5,-1]))) - expect_true(is.vector(predict_numeric(lm_fit, new_data = iris[1:5,-1]))) + expect_true(is.vector(parsnip:::predict_numeric(lm_fit, new_data = iris[1:5,-1]))) expect_equal(names(predict(lm_fit, new_data = iris[1:5,-1])), ".pred") }) test_that('classification predictions', { expect_true(is_tibble(predict(lr_fit, new_data = class_dat[1:5,-1]))) - expect_true(is.factor(predict_class(lr_fit, new_data = class_dat[1:5,-1]))) + expect_true(is.factor(parsnip:::predict_class(lr_fit, new_data = class_dat[1:5,-1]))) expect_equal(names(predict(lr_fit, new_data = class_dat[1:5,-1])), ".pred_class") expect_true(is_tibble(predict(lr_fit, new_data = class_dat[1:5,-1], type = "prob"))) - expect_true(is_tibble(predict_classprob(lr_fit, new_data = class_dat[1:5,-1]))) + expect_true(is_tibble(parsnip:::predict_classprob(lr_fit, new_data = class_dat[1:5,-1]))) expect_equal(names(predict(lr_fit, new_data = class_dat[1:5,-1], type = "prob")), c(".pred_high", ".pred_low")) }) test_that('non-standard levels', { expect_true(is_tibble(predict(lr_fit, new_data = class_dat[1:5,-1]))) - expect_true(is.factor(predict_class(lr_fit, new_data = class_dat[1:5,-1]))) + expect_true(is.factor(parsnip:::predict_class(lr_fit, new_data = class_dat[1:5,-1]))) expect_equal(names(predict(lr_fit, new_data = class_dat[1:5,-1])), ".pred_class") expect_true(is_tibble(predict(lr_fit_2, 
new_data = class_dat2[1:5,-1], type = "prob"))) - expect_true(is_tibble(predict_classprob(lr_fit_2, new_data = class_dat2[1:5,-1]))) + expect_true(is_tibble(parsnip:::predict_classprob(lr_fit_2, new_data = class_dat2[1:5,-1]))) expect_equal(names(predict(lr_fit_2, new_data = class_dat2[1:5,-1], type = "prob")), c(".pred_2low", ".pred_high+values")) - expect_equal(names(predict_classprob(lr_fit_2, new_data = class_dat2[1:5,-1])), + expect_equal(names(parsnip:::predict_classprob(lr_fit_2, new_data = class_dat2[1:5,-1])), c("2low", "high+values")) }) diff --git a/tests/testthat/test_rand_forest_randomForest.R b/tests/testthat/test_rand_forest_randomForest.R index 35937b244..cfba216b7 100644 --- a/tests/testthat/test_rand_forest_randomForest.R +++ b/tests/testthat/test_rand_forest_randomForest.R @@ -90,7 +90,7 @@ test_that('randomForest classification prediction', { xy_pred <- predict(xy_fit$fit, newdata = lending_club[1:6, num_pred]) xy_pred <- unname(xy_pred) - expect_equal(xy_pred, predict_class(xy_fit, new_data = lending_club[1:6, num_pred])) + expect_equal(xy_pred, parsnip:::predict_class(xy_fit, new_data = lending_club[1:6, num_pred])) form_fit <- fit( lc_basic, @@ -101,7 +101,7 @@ test_that('randomForest classification prediction', { form_pred <- predict(form_fit$fit, newdata = lending_club[1:6, c("funded_amnt", "int_rate")]) form_pred <- unname(form_pred) - expect_equal(form_pred, predict_class(form_fit, new_data = lending_club[1:6, c("funded_amnt", "int_rate")])) + expect_equal(form_pred, parsnip:::predict_class(form_fit, new_data = lending_club[1:6, c("funded_amnt", "int_rate")])) }) test_that('randomForest classification probabilities', { @@ -117,9 +117,9 @@ test_that('randomForest classification probabilities', { xy_pred <- predict(xy_fit$fit, newdata = lending_club[1:6, num_pred], type = "prob") xy_pred <- as_tibble(as.data.frame(xy_pred)) - expect_equal(xy_pred, predict_classprob(xy_fit, new_data = lending_club[1:6, num_pred])) + expect_equal(xy_pred, 
parsnip:::predict_classprob(xy_fit, new_data = lending_club[1:6, num_pred])) - one_row <- predict_classprob(xy_fit, new_data = lending_club[1, num_pred]) + one_row <- parsnip:::predict_classprob(xy_fit, new_data = lending_club[1, num_pred]) expect_equivalent(xy_pred[1,], one_row) form_fit <- fit( @@ -131,7 +131,7 @@ test_that('randomForest classification probabilities', { form_pred <- predict(form_fit$fit, newdata = lending_club[1:6, c("funded_amnt", "int_rate")], type = "prob") form_pred <- as_tibble(as.data.frame(form_pred)) - expect_equal(form_pred, predict_classprob(form_fit, new_data = lending_club[1:6, c("funded_amnt", "int_rate")])) + expect_equal(form_pred, parsnip:::predict_classprob(form_fit, new_data = lending_club[1:6, c("funded_amnt", "int_rate")])) }) @@ -209,6 +209,6 @@ test_that('randomForest regression prediction', { xy_pred <- predict(xy_fit$fit, newdata = tail(mtcars)) xy_pred <- unname(xy_pred) - expect_equal(xy_pred, predict_numeric(xy_fit, new_data = tail(mtcars))) + expect_equal(xy_pred, parsnip:::predict_numeric(xy_fit, new_data = tail(mtcars))) }) diff --git a/tests/testthat/test_rand_forest_ranger.R b/tests/testthat/test_rand_forest_ranger.R index 82d767c4b..3e2f963d7 100644 --- a/tests/testthat/test_rand_forest_ranger.R +++ b/tests/testthat/test_rand_forest_ranger.R @@ -96,7 +96,7 @@ test_that('ranger classification prediction', { xy_pred <- predict(xy_fit$fit, data = lending_club[1:6, num_pred])$prediction xy_pred <- colnames(xy_pred)[apply(xy_pred, 1, which.max)] xy_pred <- factor(xy_pred, levels = levels(lending_club$Class)) - expect_equal(xy_pred, predict_class(xy_fit, new_data = lending_club[1:6, num_pred])) + expect_equal(xy_pred, parsnip:::predict_class(xy_fit, new_data = lending_club[1:6, num_pred])) form_fit <- fit( rand_forest() %>% set_engine("ranger"), @@ -109,7 +109,7 @@ test_that('ranger classification prediction', { form_pred <- predict(form_fit$fit, data = lending_club[1:6, c("funded_amnt", "int_rate")])$prediction 
form_pred <- colnames(form_pred)[apply(form_pred, 1, which.max)] form_pred <- factor(form_pred, levels = levels(lending_club$Class)) - expect_equal(form_pred, predict_class(form_fit, new_data = lending_club[1:6, c("funded_amnt", "int_rate")])) + expect_equal(form_pred, parsnip:::predict_class(form_fit, new_data = lending_club[1:6, c("funded_amnt", "int_rate")])) }) @@ -128,9 +128,9 @@ test_that('ranger classification probabilities', { xy_pred <- predict(xy_fit$fit, data = lending_club[1:6, num_pred])$predictions xy_pred <- as_tibble(xy_pred) - expect_equal(xy_pred, predict_classprob(xy_fit, new_data = lending_club[1:6, num_pred])) + expect_equal(xy_pred, parsnip:::predict_classprob(xy_fit, new_data = lending_club[1:6, num_pred])) - one_row <- predict_classprob(xy_fit, new_data = lending_club[1, num_pred]) + one_row <- parsnip:::predict_classprob(xy_fit, new_data = lending_club[1, num_pred]) expect_equivalent(xy_pred[1,], one_row) form_fit <- fit( @@ -143,7 +143,7 @@ test_that('ranger classification probabilities', { form_pred <- predict(form_fit$fit, data = lending_club[1:6, c("funded_amnt", "int_rate")])$predictions form_pred <- as_tibble(form_pred) - expect_equal(form_pred, predict_classprob(form_fit, new_data = lending_club[1:6, c("funded_amnt", "int_rate")])) + expect_equal(form_pred, parsnip:::predict_classprob(form_fit, new_data = lending_club[1:6, c("funded_amnt", "int_rate")])) no_prob_model <- fit_xy( rand_forest() %>% set_engine("ranger", probability = FALSE), @@ -154,7 +154,7 @@ test_that('ranger classification probabilities', { ) expect_error( - predict_classprob(no_prob_model, new_data = lending_club[1:6, num_pred]) + parsnip:::predict_classprob(no_prob_model, new_data = lending_club[1:6, num_pred]) ) }) @@ -229,7 +229,7 @@ test_that('ranger regression prediction', { xy_pred <- predict(xy_fit$fit, data = tail(mtcars[, -1]))$prediction - expect_equal(xy_pred, predict_numeric(xy_fit, new_data = tail(mtcars[, -1]))) + expect_equal(xy_pred, 
parsnip:::predict_numeric(xy_fit, new_data = tail(mtcars[, -1]))) }) diff --git a/tests/testthat/test_rand_forest_spark.R b/tests/testthat/test_rand_forest_spark.R index 4184e6abf..da0b5fab6 100644 --- a/tests/testthat/test_rand_forest_spark.R +++ b/tests/testthat/test_rand_forest_spark.R @@ -58,7 +58,7 @@ test_that('spark execution', { ) expect_error( - spark_reg_pred_num <- predict_numeric(spark_reg_fit, iris_rf_te), + spark_reg_pred_num <- parsnip:::predict_numeric(spark_reg_fit, iris_rf_te), regexp = NA ) @@ -68,7 +68,7 @@ test_that('spark execution', { ) expect_error( - spark_reg_num_dup <- predict_numeric(spark_reg_fit_dup, iris_rf_te), + spark_reg_num_dup <- parsnip:::predict_numeric(spark_reg_fit_dup, iris_rf_te), regexp = NA ) @@ -124,7 +124,7 @@ test_that('spark execution', { ) expect_error( - spark_class_pred_class <- predict_class(spark_class_fit, churn_rf_te), + spark_class_pred_class <- parsnip:::predict_class(spark_class_fit, churn_rf_te), regexp = NA ) @@ -134,7 +134,7 @@ test_that('spark execution', { ) expect_error( - spark_class_dup_class <- predict_class(spark_class_fit_dup, churn_rf_te), + spark_class_dup_class <- parsnip:::predict_class(spark_class_fit_dup, churn_rf_te), regexp = NA ) @@ -156,7 +156,7 @@ test_that('spark execution', { ) expect_error( - spark_class_prob_classprob <- predict_classprob(spark_class_fit, churn_rf_te), + spark_class_prob_classprob <- parsnip:::predict_classprob(spark_class_fit, churn_rf_te), regexp = NA ) @@ -166,7 +166,7 @@ test_that('spark execution', { ) expect_error( - spark_class_dup_classprob <- predict_classprob(spark_class_fit_dup, churn_rf_te), + spark_class_dup_classprob <- parsnip:::predict_classprob(spark_class_fit_dup, churn_rf_te), regexp = NA ) From 24ba3f9e3d33023de18016a112019c4b5e301eee Mon Sep 17 00:00:00 2001 From: topepo Date: Mon, 25 Feb 2019 23:12:24 -0500 Subject: [PATCH 02/15] updates for #156 --- NEWS.md | 6 +++ R/logistic_reg_data.R | 52 ++++++++++++++++++------- 
tests/testthat/test_logistic_reg.R | 6 ++- tests/testthat/test_logistic_reg_stan.R | 12 ++++-- 4 files changed, 56 insertions(+), 20 deletions(-) diff --git a/NEWS.md b/NEWS.md index 0f76cee0c..16b00685d 100644 --- a/NEWS.md +++ b/NEWS.md @@ -3,6 +3,7 @@ ## New Features * A "null model" is now available that fits a predictor-free model (using the mean of the outcome for regression or the mode for classification). + * `fit_xy()` can take a single column data frame or matrix for `y` without error ## Other Changes @@ -11,6 +12,10 @@ of possible varying arguments is returned (as opposed to only the arguments that are actually varying). +* `fit_control()` now returns an S3 method. + +* The prediction modules (e.g. `predict_class`, `predict_numeric`, etc.) were de-exported. These were internal functions that were not to be used by the users and the users were using them. + ## Bug Fixes * `varying_args()` now uses the version from the `generics` package. This means @@ -31,6 +36,7 @@ column names once (#107). * For multinomial regression using glmnet, `multi_predict()` now pulls the correct default penalty (#108). +* Confidence and prediction intervals for logistic regression were only computed for a single level. Both are now computed. 
(#156) # parsnip 0.0.1 diff --git a/R/logistic_reg_data.R b/R/logistic_reg_data.R index a5aef8bfb..c17c8245b 100644 --- a/R/logistic_reg_data.R +++ b/R/logistic_reg_data.R @@ -4,7 +4,7 @@ logistic_reg_arg_key <- data.frame( glmnet = c( "lambda", "alpha"), spark = c("reg_param", "elastic_net_param"), stan = c( NA, NA), - keras = c( "decay", NA), + keras = c( "decay", NA), stringsAsFactors = FALSE, row.names = c("penalty", "mixture") ) @@ -77,12 +77,20 @@ logistic_reg_glm_data <- const <- qt(hf_lvl, df = object$fit$df.residual, lower.tail = FALSE) trans <- object$fit$family$linkinv - res <- + res_2 <- tibble( - .pred_lower = trans(results$fit - const * results$se.fit), - .pred_upper = trans(results$fit + const * results$se.fit) + lo = trans(results$fit - const * results$se.fit), + hi = trans(results$fit + const * results$se.fit) ) - if(object$spec$method$confint$extras$std_error) + res_1 <- res_2 + res_1$lo <- 1 - res_2$hi + res_1$hi <- 1 - res_2$lo + res <- bind_cols(res_1, res_2) + lo_nms <- paste0(".pred_", object$lvl, "_lower") + hi_nms <- paste0(".pred_", object$lvl, "_upper") + colnames(res) <- c(lo_nms[1], hi_nms[1], lo_nms[2], hi_nms[2]) + + if (object$spec$method$confint$extras$std_error) res$.std_error <- results$se.fit res }, @@ -199,21 +207,29 @@ logistic_reg_stan_data <- confint = list( pre = NULL, post = function(results, object) { - res <- + res_2 <- tibble( - .pred_lower = + lo = convert_stan_interval( results, level = object$spec$method$confint$extras$level ), - .pred_upper = + hi = convert_stan_interval( results, level = object$spec$method$confint$extras$level, lower = FALSE ), ) - if(object$spec$method$confint$extras$std_error) + res_1 <- res_2 + res_1$lo <- 1 - res_2$hi + res_1$hi <- 1 - res_2$lo + res <- bind_cols(res_1, res_2) + lo_nms <- paste0(".pred_", object$lvl, "_lower") + hi_nms <- paste0(".pred_", object$lvl, "_upper") + colnames(res) <- c(lo_nms[1], hi_nms[1], lo_nms[2], hi_nms[2]) + + if (object$spec$method$confint$extras$std_error) 
res$.std_error <- apply(results, 2, sd, na.rm = TRUE) res }, @@ -229,21 +245,29 @@ logistic_reg_stan_data <- predint = list( pre = NULL, post = function(results, object) { - res <- + res_2 <- tibble( - .pred_lower = + lo = convert_stan_interval( results, level = object$spec$method$predint$extras$level ), - .pred_upper = + hi = convert_stan_interval( results, level = object$spec$method$predint$extras$level, lower = FALSE ), ) - if(object$spec$method$predint$extras$std_error) + res_1 <- res_2 + res_1$lo <- 1 - res_2$hi + res_1$hi <- 1 - res_2$lo + res <- bind_cols(res_1, res_2) + lo_nms <- paste0(".pred_", object$lvl, "_lower") + hi_nms <- paste0(".pred_", object$lvl, "_upper") + colnames(res) <- c(lo_nms[1], hi_nms[1], lo_nms[2], hi_nms[2]) + + if (object$spec$method$predint$extras$std_error) res$.std_error <- apply(results, 2, sd, na.rm = TRUE) res }, @@ -327,4 +351,4 @@ logistic_reg_keras_data <- x = quote(as.matrix(new_data)) ) ) - ) \ No newline at end of file + ) diff --git a/tests/testthat/test_logistic_reg.R b/tests/testthat/test_logistic_reg.R index 8778a2f3e..2eb74e658 100644 --- a/tests/testthat/test_logistic_reg.R +++ b/tests/testthat/test_logistic_reg.R @@ -323,8 +323,10 @@ test_that('glm intervals', { level = 0.93, std_error = TRUE) - expect_equivalent(confidence_parsnip$.pred_lower, lower_glm) - expect_equivalent(confidence_parsnip$.pred_upper, upper_glm) + expect_equivalent(confidence_parsnip$.pred_good_lower, lower_glm) + expect_equivalent(confidence_parsnip$.pred_good_upper, upper_glm) + expect_equivalent(confidence_parsnip$.pred_bad_lower, 1 - upper_glm) + expect_equivalent(confidence_parsnip$.pred_bad_upper, 1 - lower_glm) expect_equivalent(confidence_parsnip$.std_error, pred_glm$se.fit) }) diff --git a/tests/testthat/test_logistic_reg_stan.R b/tests/testthat/test_logistic_reg_stan.R index 261850588..ec260ffea 100644 --- a/tests/testthat/test_logistic_reg_stan.R +++ b/tests/testthat/test_logistic_reg_stan.R @@ -156,8 +156,10 @@ test_that('stan 
intervals', { stan_upper <- apply(stan_post, 2, quantile, prob = 0.965) stan_std <- apply(stan_post, 2, sd) - expect_equivalent(confidence_parsnip$.pred_lower, stan_lower) - expect_equivalent(confidence_parsnip$.pred_upper, stan_upper) + expect_equivalent(confidence_parsnip$.pred_good_lower, stan_lower) + expect_equivalent(confidence_parsnip$.pred_good_upper, stan_upper) + expect_equivalent(confidence_parsnip$.pred_bad_lower, 1 - stan_upper) + expect_equivalent(confidence_parsnip$.pred_bad_upper, 1 - stan_lower) expect_equivalent(confidence_parsnip$.std_error, stan_std) stan_pred_post <- @@ -168,8 +170,10 @@ test_that('stan intervals', { stan_pred_upper <- apply(stan_pred_post, 2, quantile, prob = 0.965) stan_pred_std <- apply(stan_pred_post, 2, sd) - expect_equivalent(prediction_parsnip$.pred_lower, stan_pred_lower) - expect_equivalent(prediction_parsnip$.pred_upper, stan_pred_upper) + expect_equivalent(prediction_parsnip$.pred_good_lower, stan_pred_lower) + expect_equivalent(prediction_parsnip$.pred_good_upper, stan_pred_upper) + expect_equivalent(prediction_parsnip$.pred_bad_lower, 1 - stan_pred_upper) + expect_equivalent(prediction_parsnip$.pred_bad_upper, 1 - stan_pred_lower) expect_equivalent(prediction_parsnip$.std_error, stan_pred_std, tolerance = 0.1) }) From 2bdb427cd1ba030e2936a86eb8612f37ed6d6767 Mon Sep 17 00:00:00 2001 From: topepo Date: Tue, 26 Feb 2019 13:58:11 -0500 Subject: [PATCH 03/15] rewrote man page entry for ... for predict --- man/predict.model_fit.Rd | 23 ++++++++++++++++++++--- tests/testthat/test_predict_formats.R | 19 +++++++++++++++++++ 2 files changed, 39 insertions(+), 3 deletions(-) diff --git a/man/predict.model_fit.Rd b/man/predict.model_fit.Rd index e093f83a7..6a1b0c54a 100644 --- a/man/predict.model_fit.Rd +++ b/man/predict.model_fit.Rd @@ -22,9 +22,21 @@ predict function that will be used when \code{type = "raw"}. The list should not include options for the model object or the new data being predicted.} -\item{...}{Ignored. 
To pass arguments to pass to the underlying -function when \code{predict.model_fit(type = "raw")}, -use the \code{opts} argument.} +\item{...}{Arguments to the underlying model's prediction +function cannot be passed here (see \code{opts}). There are some +\code{parsnip} related options that can be passed, depending on the +value of \code{type}. Possible arguments are: +\itemize{ +\item \code{level}: for \code{type}s of "conf_int" and "pred_int" this +is the parameter for the tail area of the intervals +(e.g. confidence level for confidence intervals). +\item \code{std_error}: add the standard error of fit or +prediction for \code{type}s of "conf_int" and "pred_int". +\item \code{quantile}: the quantile(s) for quantile regression +(not implemented yet) +\item \code{time}: the time(s) for hazard probability estimates +(not implemented yet) +}} } \value{ With the exception of \code{type = "raw"}, the results of @@ -55,6 +67,11 @@ In the case of Spark-based models, since table columns cannot contain dots, the same convention is used except 1) no dots appear in names and 2) vectors are never returned but type-specific prediction functions. + +When the model fit failed and the error was captured, the +\code{predict()} function will return the same structure as above but +filled with missing values. This does not currently work for +multivariate models. } \description{ Apply a model to create different types of predictions. 
diff --git a/tests/testthat/test_predict_formats.R b/tests/testthat/test_predict_formats.R index f4e63e2fd..eefa6ab53 100644 --- a/tests/testthat/test_predict_formats.R +++ b/tests/testthat/test_predict_formats.R @@ -1,6 +1,7 @@ library(testthat) library(parsnip) library(tibble) +library(dplyr) # ------------------------------------------------------------------------------ @@ -58,3 +59,21 @@ test_that('non-standard levels', { expect_equal(names(parsnip:::predict_classprob(lr_fit_2, new_data = class_dat2[1:5,-1])), c("2low", "high+values")) }) + +# ------------------------------------------------------------------------------ + +test_that('bad predict args', { + lm_model <- + linear_reg() %>% + set_engine("lm") %>% + fit(mpg ~ ., data = mtcars %>% slice(11:32)) + + pred_cars <- + mtcars %>% + slice(1:10) %>% + select(-mpg) + + expect_error(predict(lm_model, pred_cars, yes = "no")) + expect_error(predict(lm_model, pred_cars, type = "conf_int", level = 0.95, yes = "no")) +}) + From b42b43da3ca878006c44336c4d693dc3d3cf900b Mon Sep 17 00:00:00 2001 From: topepo Date: Tue, 26 Feb 2019 13:58:45 -0500 Subject: [PATCH 04/15] first attempt at failed models --- R/predict.R | 51 ++++++++++++++++--- R/predict_class.R | 23 +++++++-- R/predict_classprob.R | 24 +++++++-- R/predict_interval.R | 46 ++++++++++++----- R/predict_numeric.R | 27 +++++++++- R/predict_raw.R | 18 ++++--- tests/testthat/test_failed_models.R | 77 +++++++++++++++++++++++++++++ 7 files changed, 229 insertions(+), 37 deletions(-) create mode 100644 tests/testthat/test_failed_models.R diff --git a/R/predict.R b/R/predict.R index 28f4ad505..596de4b9d 100644 --- a/R/predict.R +++ b/R/predict.R @@ -14,9 +14,23 @@ #' predict function that will be used when `type = "raw"`. The #' list should not include options for the model object or the #' new data being predicted. -#' @param ... Ignored. To pass arguments to pass to the underlying -#' function when `predict.model_fit(type = "raw")`, -#' use the `opts` argument. 
+#' @param ... Arguments to the underlying model's prediction +#' function cannot be passed here (see `opts`). There are some +#' `parsnip` related options that can be passed, depending on the +#' value of `type`. Possible arguments are: +#' \itemize{ +#' \item `level`: for `type`s of "conf_int" and "pred_int" this +#' is the parameter for the tail area of the intervals +#' (e.g. confidence level for confidence intervals). +#' Default value is 0.95. +#' \item `std_error`: add the standard error of fit or +#' prediction for `type`s of "conf_int" and "pred_int". +#' Default value is `FALSE`. +#' \item `quantile`: the quantile(s) for quantile regression +#' (not implemented yet) +#' \item `time`: the time(s) for hazard probability estimates +#' (not implemented yet) +#' } #' @details If "type" is not supplied to `predict()`, then a choice #' is made (`type = "numeric"` for regression models and #' `type = "class"` for classification). @@ -49,7 +63,7 @@ #' a list-column. Each list element contains a tibble with columns #' `.pred` and `.quantile` (and perhaps other columns). #' -#' Using `type = "raw"` with `predict.model_fit()` will return +#' Using `type = "raw"` with `predict.model_fit()` will return #' the unadulterated results of the prediction function. #' #' In the case of Spark-based models, since table columns cannot @@ -57,6 +71,10 @@ #' appear in names and 2) vectors are never returned but #' type-specific prediction functions. #' +#' When the model fit failed and the error was captured, the +#' `predict()` function will return the same structure as above but +#' filled with missing values. This does not currently work for +#' multivariate models. #' @examples #' library(dplyr) #' @@ -89,10 +107,21 @@ #' @method predict model_fit #' @export predict.model_fit #' @export -predict.model_fit <- function (object, new_data, type = NULL, opts = list(), ...) 
{ - if (any(names(enquos(...)) == "newdata")) +predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...) { + the_dots <- enquos(...) + if (any(names(the_dots) == "newdata")) stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) + other_args <- c("level", "std_error", "quantile") # "time" for survival probs later + is_pred_arg <- names(the_dots) %in% other_args + if (any(!is_pred_arg)) { + bad_args <- names(the_dots)[!is_pred_arg] + bad_args <- paste0("`", bad_args, "`", collapse = ", ") + stop("The ellipses are not used to pass args to the model function's ", + "predict function. These arguments cannot be used: ", + bad_args, call. = FALSE) + } + type <- check_pred_type(object, type) if (type != "raw" && length(opts) > 0) warning("`opts` is only used with `type = 'raw'` and was ignored.") @@ -220,10 +249,18 @@ multi_predict <- function(object, ...) #' @export #' @rdname multi_predict multi_predict.default <- function(object, ...) - stop ("No `multi_predict` method exists for objects with classes ", - paste0("'", class(), "'", collapse = ", "), call. = FALSE) + stop("No `multi_predict` method exists for objects with classes ", + paste0("'", class(object), "'", collapse = ", "), call. = FALSE) #' @export predict.model_spec <- function(object, ...) { stop("You must use `fit()` on your model specification before you can use `predict()`.", call. = FALSE) } + + +failed_class <- function(n, lvl) { + tibble(.pred = rep(NA_real_, n)) +} + + + diff --git a/R/predict_class.R b/R/predict_class.R index 39a051e5e..e422a9d7f 100644 --- a/R/predict_class.R +++ b/R/predict_class.R @@ -8,14 +8,18 @@ # @method predict_class model_fit # @export predict_class.model_fit # @export -predict_class.model_fit <- function (object, new_data, ...) { - if(object$spec$mode != "classification") +predict_class.model_fit <- function(object, new_data, ...) { + if (object$spec$mode != "classification") stop("`predict.model_fit()` is for predicting factor outcomes.", call.
= FALSE) if (!any(names(object$spec$method) == "class")) stop("No class prediction module defined for this model.", call. = FALSE) + if (inherits(object$fit, "try-error")) { + return(failed_class(n = nrow(new_data), lvl = object$lvl)) + } + new_data <- prepare_data(object, new_data) # preprocess data @@ -28,7 +32,7 @@ predict_class.model_fit <- function (object, new_data, ...) { res <- eval_tidy(pred_call) # post-process the predictions - if(!is.null(object$spec$method$class$post)) { + if (!is.null(object$spec$method$class$post)) { res <- object$spec$method$class$post(res, object) } @@ -47,5 +51,16 @@ predict_class.model_fit <- function (object, new_data, ...) { # @keywords internal # @rdname other_predict # @inheritParams predict.model_fit -predict_class <- function (object, ...) +predict_class <- function(object, ...) UseMethod("predict_class") + +# ------------------------------------------------------------------------------ + +# Some `predict()` helpers for failed models: + +failed_class <- function(n, lvl) { + res <- rep(NA_character_, n) + res <- factor(res, levels = lvl) + res +} + diff --git a/R/predict_classprob.R b/R/predict_classprob.R index ff481fd58..d0cf628fa 100644 --- a/R/predict_classprob.R +++ b/R/predict_classprob.R @@ -5,14 +5,18 @@ # @export predict_classprob.model_fit # @export #' @importFrom tibble as_tibble is_tibble tibble -predict_classprob.model_fit <- function (object, new_data, ...) { - if(object$spec$mode != "classification") +predict_classprob.model_fit <- function(object, new_data, ...) { + if (object$spec$mode != "classification") stop("`predict.model_fit()` is for predicting factor outcomes.", call. = FALSE) if (!any(names(object$spec$method) == "classprob")) stop("No class probability module defined for this model.", call. 
= FALSE) + if (inherits(object$fit, "try-error")) { + return(failed_classprob(n = nrow(new_data), lvl = object$lvl)) + } + new_data <- prepare_data(object, new_data) # preprocess data @@ -25,7 +29,7 @@ predict_classprob.model_fit <- function (object, new_data, ...) { res <- eval_tidy(pred_call) # post-process the predictions - if(!is.null(object$spec$method$classprob$post)) { + if (!is.null(object$spec$method$classprob$post)) { res <- object$spec$method$classprob$post(res, object) } @@ -43,5 +47,17 @@ predict_classprob.model_fit <- function (object, new_data, ...) { # @keywords internal # @rdname other_predict # @inheritParams predict.model_fit -predict_classprob <- function (object, ...) +predict_classprob <- function(object, ...) UseMethod("predict_classprob") + + +# ------------------------------------------------------------------------------ + +# Some `predict()` helpers for failed models: + +failed_classprob <- function(n, lvl) { + res <- matrix(NA_real_, nrow = n, ncol = length(lvl)) + colnames(res) <- lvl + as_tibble(res) +} + diff --git a/R/predict_interval.R b/R/predict_interval.R index 9203324ad..f71cb3da2 100644 --- a/R/predict_interval.R +++ b/R/predict_interval.R @@ -8,13 +8,16 @@ # @method predict_confint model_fit # @export predict_confint.model_fit # @export -predict_confint.model_fit <- - function (object, new_data, level = 0.95, std_error = FALSE, ...) { +predict_confint.model_fit <- function(object, new_data, level = 0.95, std_error = FALSE, ...) { if (is.null(object$spec$method$confint)) stop("No confidence interval method defined for this ", "engine.", call. 
= FALSE) + if (inherits(object$fit, "try-error")) { + return(failed_int(n = nrow(new_data), lvl = object$lvl)) + } + new_data <- prepare_data(object, new_data) # preprocess data @@ -29,7 +32,7 @@ predict_confint.model_fit <- res <- eval_tidy(pred_call) # post-process the predictions - if(!is.null(object$spec$method$confint$post)) { + if (!is.null(object$spec$method$confint$post)) { res <- object$spec$method$confint$post(res, object) } @@ -42,10 +45,28 @@ predict_confint.model_fit <- # @keywords internal # @rdname other_predict # @inheritParams predict.model_fit -predict_confint <- function (object, ...) +predict_confint <- function(object, ...) UseMethod("predict_confint") -################################################################## +# ------------------------------------------------------------------------------ + +# Some `predict()` helpers for failed models: + +failed_int <- function(n, lvl = NULL, nms = ".pred") { + # TODO figure out multivariate models + if (is.null(lvl)) { + res <- matrix(NA_real_, nrow = n, ncol = length(nms) * 2) + colnames(res) <- c(".pred_lower", ".pred_upper") + } else { + res <- matrix(NA_real_, ncol = length(lvl) * 2, nrow = n) + nms <- expand.grid(c("lower", "upper"), lvl) + nms <- paste(".pred", nms$Var2, nms$Var1, sep = "_") + colnames(res) <- nms + } + as_tibble(res) +} + +# ------------------------------------------------------------------------------ # @keywords internal # @rdname other_predict @@ -53,13 +74,16 @@ predict_confint <- function (object, ...) # @method predict_predint model_fit # @export predict_predint.model_fit # @export -predict_predint.model_fit <- - function (object, new_data, level = 0.95, std_error = FALSE, ...) { +predict_predint.model_fit <- function(object, new_data, level = 0.95, std_error = FALSE, ...) { if (is.null(object$spec$method$predint)) stop("No prediction interval method defined for this ", "engine.", call. 
= FALSE) + if (inherits(object$fit, "try-error")) { + return(failed_int(n = nrow(new_data), lvl = object$lvl)) + } + new_data <- prepare_data(object, new_data) # preprocess data @@ -75,7 +99,7 @@ predict_predint.model_fit <- res <- eval_tidy(pred_call) # post-process the predictions - if(!is.null(object$spec$method$predint$post)) { + if (!is.null(object$spec$method$predint$post)) { res <- object$spec$method$predint$post(res, object) } @@ -88,10 +112,6 @@ predict_predint.model_fit <- # @keywords internal # @rdname other_predict # @inheritParams predict.model_fit -predict_predint <- function (object, ...) +predict_predint <- function(object, ...) UseMethod("predict_predint") - - - - diff --git a/R/predict_numeric.R b/R/predict_numeric.R index 945eb92a7..873a0cf36 100644 --- a/R/predict_numeric.R +++ b/R/predict_numeric.R @@ -5,7 +5,7 @@ # @export predict_numeric.model_fit # @export -predict_numeric.model_fit <- function (object, new_data, ...) { +predict_numeric.model_fit <- function(object, new_data, ...) { if (object$spec$mode != "regression") stop("`predict_numeric()` is for predicting numeric outcomes. ", "Use `predict_class()` or `predict_classprob()` for ", @@ -14,6 +14,11 @@ predict_numeric.model_fit <- function (object, new_data, ...) { if (!any(names(object$spec$method) == "numeric")) stop("No prediction module defined for this model.", call. = FALSE) + if (inherits(object$fit, "try-error")) { + # TODO handle multivariate cases + return(failed_numeric(n = nrow(new_data))) + } + new_data <- prepare_data(object, new_data) # preprocess data @@ -44,5 +49,23 @@ predict_numeric.model_fit <- function (object, new_data, ...) { # @keywords internal # @rdname other_predict # @inheritParams predict_numeric.model_fit -predict_numeric <- function (object, ...) +predict_numeric <- function(object, ...) 
UseMethod("predict_numeric") + +# ------------------------------------------------------------------------------ + +# Some `predict()` helpers for failed models: + +failed_numeric <- function(n, nms = ".pred") { + res <- matrix(NA_real_, ncol = length(nms), nrow = n) + if (length(nms) > 1) { + colnames(res) <- nms + res <- as_tibble(res) + } else { + res <- res[,1] + } + res +} + + + diff --git a/R/predict_raw.R b/R/predict_raw.R index 4d972fac3..20a970d56 100644 --- a/R/predict_raw.R +++ b/R/predict_raw.R @@ -3,7 +3,7 @@ # @method predict_raw model_fit # @export predict_raw.model_fit # @export -predict_raw.model_fit <- function (object, new_data, opts = list(), ...) { +predict_raw.model_fit <- function(object, new_data, opts = list(), ...) { protected_args <- names(object$spec$method$raw$args) dup_args <- names(opts) %in% protected_args if (any(dup_args)) { @@ -13,19 +13,23 @@ predict_raw.model_fit <- function (object, new_data, opts = list(), ...) { object$spec$method$raw$args <- c(object$spec$method$raw$args, opts) } - + if (!any(names(object$spec$method) == "raw")) stop("No raw prediction module defined for this model.", call. = FALSE) - + + if (inherits(object$fit, "try-error")) { + stop("Model fit failed; cannot make predictions.") + } + new_data <- prepare_data(object, new_data) - + # preprocess data if (!is.null(object$spec$method$raw$pre)) new_data <- object$spec$method$raw$pre(new_data, object) - + # create prediction call pred_call <- make_pred_call(object$spec$method$raw) - + res <- eval_tidy(pred_call) res @@ -35,5 +39,5 @@ predict_raw.model_fit <- function (object, new_data, opts = list(), ...) { # @export # @rdname predict.model_fit # @inheritParams predict_raw.model_fit -predict_raw <- function (object, ...) +predict_raw <- function(object, ...) 
UseMethod("predict_raw") diff --git a/tests/testthat/test_failed_models.R b/tests/testthat/test_failed_models.R new file mode 100644 index 000000000..6d9edc9b2 --- /dev/null +++ b/tests/testthat/test_failed_models.R @@ -0,0 +1,77 @@ +library(testthat) +library(parsnip) +library(dplyr) +library(rlang) + +# ------------------------------------------------------------------------------ + +context("prediciton with failed models") + +# ------------------------------------------------------------------------------ + +iris_bad <- + iris %>% + mutate(big_num = Inf) + +data("lending_club") + +lending_club <- + lending_club %>% + slice(1:200) %>% + mutate(big_num = Inf) + +lvl <- levels(lending_club$Class) + +# ------------------------------------------------------------------------------ + +ctrl <- fit_control(catch = TRUE) + +# ------------------------------------------------------------------------------ + +test_that('numeric model', { + lm_mod <- + linear_reg() %>% + set_engine("lm") %>% + fit(Sepal.Length ~ ., data = iris_bad, control = ctrl) + + num_res <- predict(lm_mod, iris_bad[1:11, -1]) + expect_equal(num_res, tibble(.pred = rep(NA_real_, 11))) + + exp_int_res <- tibble(.pred_lower = rep(NA_real_, 11), .pred_upper = rep(NA_real_, 11)) + ci_res <- predict(lm_mod, iris_bad[1:11, -1], type = "conf_int") + expect_equal(ci_res, exp_int_res) + + pi_res <- predict(lm_mod, iris_bad[1:11, -1], type = "pred_int") + expect_equal(pi_res, exp_int_res) + +}) + +# ------------------------------------------------------------------------------ + +test_that('classification model', { + log_reg <- + logistic_reg() %>% + set_engine("glm") %>% + fit(Class ~ log(funded_amnt) + int_rate + big_num, data = lending_club, control = ctrl) + + cls_res <- predict(log_reg, lending_club %>% slice(1:7) %>% dplyr::select(-Class)) + exp_cls_res <- tibble(.pred_class = factor(rep(NA_character_, 7), levels = lvl)) + expect_equal(cls_res, exp_cls_res) + + prb_res <- + predict(log_reg, lending_club %>% 
slice(1:7) %>% dplyr::select(-Class), type = "prob") + exp_prb_res <- tibble(.pred_bad = rep(NA_real_, 7), .pred_good = rep(NA_real_, 7)) + expect_equal(prb_res, exp_prb_res) + + ci_res <- + predict(log_reg, lending_club %>% slice(1:7) %>% dplyr::select(-Class), type = "conf_int") + exp_ci_res <- + tibble( + .pred_bad_lower = rep(NA_real_, 7), + .pred_bad_upper = rep(NA_real_, 7), + .pred_good_lower = rep(NA_real_, 7), + .pred_good_upper = rep(NA_real_, 7) + ) + expect_equal(ci_res, exp_ci_res) +}) + From 2e8a809abe4d284f2c2c7e58af0ceb95abd110e1 Mon Sep 17 00:00:00 2001 From: topepo Date: Tue, 26 Feb 2019 16:21:24 -0500 Subject: [PATCH 05/15] troubleshooting travis failure --- tests/testthat/test_predict_formats.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/testthat/test_predict_formats.R b/tests/testthat/test_predict_formats.R index eefa6ab53..cc8874487 100644 --- a/tests/testthat/test_predict_formats.R +++ b/tests/testthat/test_predict_formats.R @@ -71,7 +71,7 @@ test_that('bad predict args', { pred_cars <- mtcars %>% slice(1:10) %>% - select(-mpg) + dplyr::select(-mpg) expect_error(predict(lm_model, pred_cars, yes = "no")) expect_error(predict(lm_model, pred_cars, type = "conf_int", level = 0.95, yes = "no")) From d9ed3d566449285738070593b6367eefcf5222fe Mon Sep 17 00:00:00 2001 From: topepo Date: Tue, 26 Feb 2019 17:12:48 -0500 Subject: [PATCH 06/15] troubleshooting travis failure --- man/predict.model_fit.Rd | 2 ++ tests/testthat/test_predict_formats.R | 3 +-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/man/predict.model_fit.Rd b/man/predict.model_fit.Rd index 6a1b0c54a..9999429d8 100644 --- a/man/predict.model_fit.Rd +++ b/man/predict.model_fit.Rd @@ -30,8 +30,10 @@ value of \code{type}. Possible arguments are: \item \code{level}: for \code{type}s of "conf_int" and "pred_int" this is the parameter for the tail area of the intervals (e.g. confidence level for confidence intervals). +Default value is 0.95. 
\item \code{std_error}: add the standard error of fit or prediction for \code{type}s of "conf_int" and "pred_int". +Default value is \code{FALSE}. \item \code{quantile}: the quantile(s) for quantile regression (not implemented yet) \item \code{time}: the time(s) for hazard probability estimates diff --git a/tests/testthat/test_predict_formats.R b/tests/testthat/test_predict_formats.R index cc8874487..0ea615c71 100644 --- a/tests/testthat/test_predict_formats.R +++ b/tests/testthat/test_predict_formats.R @@ -66,11 +66,10 @@ test_that('bad predict args', { lm_model <- linear_reg() %>% set_engine("lm") %>% - fit(mpg ~ ., data = mtcars %>% slice(11:32)) + fit(mpg ~ ., data = mtcars) pred_cars <- mtcars %>% - slice(1:10) %>% dplyr::select(-mpg) expect_error(predict(lm_model, pred_cars, yes = "no")) From a7b76f4bee523602178019906adb4913cea95d9b Mon Sep 17 00:00:00 2001 From: topepo Date: Tue, 26 Feb 2019 17:36:42 -0500 Subject: [PATCH 07/15] using namespaced slice to test travis error --- tests/testthat/test_failed_models.R | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/testthat/test_failed_models.R b/tests/testthat/test_failed_models.R index 6d9edc9b2..da94394af 100644 --- a/tests/testthat/test_failed_models.R +++ b/tests/testthat/test_failed_models.R @@ -17,7 +17,7 @@ data("lending_club") lending_club <- lending_club %>% - slice(1:200) %>% + dplyr::slice(1:200) %>% mutate(big_num = Inf) lvl <- levels(lending_club$Class) @@ -54,17 +54,17 @@ test_that('classification model', { set_engine("glm") %>% fit(Class ~ log(funded_amnt) + int_rate + big_num, data = lending_club, control = ctrl) - cls_res <- predict(log_reg, lending_club %>% slice(1:7) %>% dplyr::select(-Class)) + cls_res <- predict(log_reg, lending_club %>% dplyr::slice(1:7) %>% dplyr::select(-Class)) exp_cls_res <- tibble(.pred_class = factor(rep(NA_character_, 7), levels = lvl)) expect_equal(cls_res, exp_cls_res) prb_res <- - predict(log_reg, lending_club %>% slice(1:7) %>% 
dplyr::select(-Class), type = "prob") + predict(log_reg, lending_club %>% dplyr::slice(1:7) %>% dplyr::select(-Class), type = "prob") exp_prb_res <- tibble(.pred_bad = rep(NA_real_, 7), .pred_good = rep(NA_real_, 7)) expect_equal(prb_res, exp_prb_res) ci_res <- - predict(log_reg, lending_club %>% slice(1:7) %>% dplyr::select(-Class), type = "conf_int") + predict(log_reg, lending_club %>% dplyr::slice(1:7) %>% dplyr::select(-Class), type = "conf_int") exp_ci_res <- tibble( .pred_bad_lower = rep(NA_real_, 7), From eea2760fa4d3ed6bbd05be5207c87887cf8b7e14 Mon Sep 17 00:00:00 2001 From: topepo Date: Tue, 26 Feb 2019 18:10:25 -0500 Subject: [PATCH 08/15] fixed names for intervals --- R/logistic_reg_data.R | 12 ++++++------ R/predict_interval.R | 2 +- tests/testthat/test_logistic_reg.R | 8 ++++---- tests/testthat/test_logistic_reg_stan.R | 16 ++++++++-------- tests/testthat/test_predict_formats.R | 3 ++- 5 files changed, 21 insertions(+), 20 deletions(-) diff --git a/R/logistic_reg_data.R b/R/logistic_reg_data.R index c17c8245b..6fa1e0c94 100644 --- a/R/logistic_reg_data.R +++ b/R/logistic_reg_data.R @@ -86,8 +86,8 @@ logistic_reg_glm_data <- res_1$lo <- 1 - res_2$hi res_1$hi <- 1 - res_2$lo res <- bind_cols(res_1, res_2) - lo_nms <- paste0(".pred_", object$lvl, "_lower") - hi_nms <- paste0(".pred_", object$lvl, "_upper") + lo_nms <- paste0(".pred_lower_", object$lvl) + hi_nms <- paste0(".pred_upper_", object$lvl) colnames(res) <- c(lo_nms[1], hi_nms[1], lo_nms[2], hi_nms[2]) if (object$spec$method$confint$extras$std_error) @@ -225,8 +225,8 @@ logistic_reg_stan_data <- res_1$lo <- 1 - res_2$hi res_1$hi <- 1 - res_2$lo res <- bind_cols(res_1, res_2) - lo_nms <- paste0(".pred_", object$lvl, "_lower") - hi_nms <- paste0(".pred_", object$lvl, "_upper") + lo_nms <- paste0(".pred_lower_", object$lvl) + hi_nms <- paste0(".pred_upper_", object$lvl) colnames(res) <- c(lo_nms[1], hi_nms[1], lo_nms[2], hi_nms[2]) if (object$spec$method$confint$extras$std_error) @@ -263,8 +263,8 @@ 
logistic_reg_stan_data <- res_1$lo <- 1 - res_2$hi res_1$hi <- 1 - res_2$lo res <- bind_cols(res_1, res_2) - lo_nms <- paste0(".pred_", object$lvl, "_lower") - hi_nms <- paste0(".pred_", object$lvl, "_upper") + lo_nms <- paste0(".pred_lower_", object$lvl) + hi_nms <- paste0(".pred_upper_", object$lvl) colnames(res) <- c(lo_nms[1], hi_nms[1], lo_nms[2], hi_nms[2]) if (object$spec$method$predint$extras$std_error) diff --git a/R/predict_interval.R b/R/predict_interval.R index f71cb3da2..5160d93ef 100644 --- a/R/predict_interval.R +++ b/R/predict_interval.R @@ -60,7 +60,7 @@ failed_int <- function(n, lvl = NULL, nms = ".pred") { } else { res <- matrix(NA_real_, ncol = length(lvl) * 2, nrow = n) nms <- expand.grid(c("lower", "upper"), lvl) - nms <- paste(".pred", nms$Var2, nms$Var1, sep = "_") + nms <- paste(".pred", nms$Var1, nms$Var2, sep = "_") colnames(res) <- nms } as_tibble(res) diff --git a/tests/testthat/test_logistic_reg.R b/tests/testthat/test_logistic_reg.R index 2eb74e658..7971e1c40 100644 --- a/tests/testthat/test_logistic_reg.R +++ b/tests/testthat/test_logistic_reg.R @@ -323,10 +323,10 @@ test_that('glm intervals', { level = 0.93, std_error = TRUE) - expect_equivalent(confidence_parsnip$.pred_good_lower, lower_glm) - expect_equivalent(confidence_parsnip$.pred_good_upper, upper_glm) - expect_equivalent(confidence_parsnip$.pred_bad_lower, 1 - upper_glm) - expect_equivalent(confidence_parsnip$.pred_bad_upper, 1 - lower_glm) + expect_equivalent(confidence_parsnip$.pred_lower_good, lower_glm) + expect_equivalent(confidence_parsnip$.pred_upper_good, upper_glm) + expect_equivalent(confidence_parsnip$.pred_lower_bad, 1 - upper_glm) + expect_equivalent(confidence_parsnip$.pred_upper_bad, 1 - lower_glm) expect_equivalent(confidence_parsnip$.std_error, pred_glm$se.fit) }) diff --git a/tests/testthat/test_logistic_reg_stan.R b/tests/testthat/test_logistic_reg_stan.R index ec260ffea..3373b3ce3 100644 --- a/tests/testthat/test_logistic_reg_stan.R +++ 
b/tests/testthat/test_logistic_reg_stan.R @@ -156,10 +156,10 @@ test_that('stan intervals', { stan_upper <- apply(stan_post, 2, quantile, prob = 0.965) stan_std <- apply(stan_post, 2, sd) - expect_equivalent(confidence_parsnip$.pred_good_lower, stan_lower) - expect_equivalent(confidence_parsnip$.pred_good_upper, stan_upper) - expect_equivalent(confidence_parsnip$.pred_bad_lower, 1 - stan_upper) - expect_equivalent(confidence_parsnip$.pred_bad_upper, 1 - stan_lower) + expect_equivalent(confidence_parsnip$.pred_lower_good, stan_lower) + expect_equivalent(confidence_parsnip$.pred_upper_good, stan_upper) + expect_equivalent(confidence_parsnip$.pred_lower_bad, 1 - stan_upper) + expect_equivalent(confidence_parsnip$.pred_upper_bad, 1 - stan_lower) expect_equivalent(confidence_parsnip$.std_error, stan_std) stan_pred_post <- @@ -170,10 +170,10 @@ test_that('stan intervals', { stan_pred_upper <- apply(stan_pred_post, 2, quantile, prob = 0.965) stan_pred_std <- apply(stan_pred_post, 2, sd) - expect_equivalent(prediction_parsnip$.pred_good_lower, stan_pred_lower) - expect_equivalent(prediction_parsnip$.pred_good_upper, stan_pred_upper) - expect_equivalent(prediction_parsnip$.pred_bad_lower, 1 - stan_pred_upper) - expect_equivalent(prediction_parsnip$.pred_bad_upper, 1 - stan_pred_lower) + expect_equivalent(prediction_parsnip$.pred_lower_good, stan_pred_lower) + expect_equivalent(prediction_parsnip$.pred_upper_good, stan_pred_upper) + expect_equivalent(prediction_parsnip$.pred_lower_bad, 1 - stan_pred_upper) + expect_equivalent(prediction_parsnip$.pred_upper_bad, 1 - stan_pred_lower) expect_equivalent(prediction_parsnip$.std_error, stan_pred_std, tolerance = 0.1) }) diff --git a/tests/testthat/test_predict_formats.R b/tests/testthat/test_predict_formats.R index 0ea615c71..51bcc8f91 100644 --- a/tests/testthat/test_predict_formats.R +++ b/tests/testthat/test_predict_formats.R @@ -66,10 +66,11 @@ test_that('bad predict args', { lm_model <- linear_reg() %>% set_engine("lm") %>% - 
fit(mpg ~ ., data = mtcars) + fit(mpg ~ ., data = mtcars %>% dplyr::slice(11:32)) pred_cars <- mtcars %>% + dplyr::slice(1:10) %>% dplyr::select(-mpg) expect_error(predict(lm_model, pred_cars, yes = "no")) From 0d8a87dcab68e72848424043973abde867a88584 Mon Sep 17 00:00:00 2001 From: topepo Date: Tue, 26 Feb 2019 18:39:17 -0500 Subject: [PATCH 09/15] fixed test case --- tests/testthat/test_failed_models.R | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/testthat/test_failed_models.R b/tests/testthat/test_failed_models.R index da94394af..116231cb3 100644 --- a/tests/testthat/test_failed_models.R +++ b/tests/testthat/test_failed_models.R @@ -67,10 +67,10 @@ test_that('classification model', { predict(log_reg, lending_club %>% dplyr::slice(1:7) %>% dplyr::select(-Class), type = "conf_int") exp_ci_res <- tibble( - .pred_bad_lower = rep(NA_real_, 7), - .pred_bad_upper = rep(NA_real_, 7), - .pred_good_lower = rep(NA_real_, 7), - .pred_good_upper = rep(NA_real_, 7) + .pred_lower_bad = rep(NA_real_, 7), + .pred_upper_bad = rep(NA_real_, 7), + .pred_lower_good = rep(NA_real_, 7), + .pred_upper_good = rep(NA_real_, 7) ) expect_equal(ci_res, exp_ci_res) }) From b440a391605c38a8cfb99d1c96c58b061da21ed3 Mon Sep 17 00:00:00 2001 From: topepo Date: Tue, 26 Feb 2019 18:55:46 -0500 Subject: [PATCH 10/15] changed failed model predictions to n = 1 --- R/predict_class.R | 4 ++-- R/predict_classprob.R | 4 ++-- R/predict_interval.R | 6 +++--- R/predict_numeric.R | 4 ++-- tests/testthat/test_failed_models.R | 16 ++++++++-------- 5 files changed, 17 insertions(+), 17 deletions(-) diff --git a/R/predict_class.R b/R/predict_class.R index e422a9d7f..926846e49 100644 --- a/R/predict_class.R +++ b/R/predict_class.R @@ -17,7 +17,7 @@ predict_class.model_fit <- function(object, new_data, ...) { stop("No class prediction module defined for this model.", call. 
= FALSE) if (inherits(object$fit, "try-error")) { - return(failed_class(n = nrow(new_data), lvl = object$lvl)) + return(failed_class(lvl = object$lvl)) } new_data <- prepare_data(object, new_data) @@ -58,7 +58,7 @@ predict_class <- function(object, ...) # Some `predict()` helpers for failed models: -failed_class <- function(n, lvl) { +failed_class <- function(n = 1, lvl) { res <- rep(NA_character_, n) res <- factor(res, levels = lvl) res diff --git a/R/predict_classprob.R b/R/predict_classprob.R index d0cf628fa..7eab48c39 100644 --- a/R/predict_classprob.R +++ b/R/predict_classprob.R @@ -14,7 +14,7 @@ predict_classprob.model_fit <- function(object, new_data, ...) { stop("No class probability module defined for this model.", call. = FALSE) if (inherits(object$fit, "try-error")) { - return(failed_classprob(n = nrow(new_data), lvl = object$lvl)) + return(failed_classprob(lvl = object$lvl)) } new_data <- prepare_data(object, new_data) @@ -55,7 +55,7 @@ predict_classprob <- function(object, ...) # Some `predict()` helpers for failed models: -failed_classprob <- function(n, lvl) { +failed_classprob <- function(n = 1, lvl) { res <- matrix(NA_real_, nrow = n, ncol = length(lvl)) colnames(res) <- lvl as_tibble(res) diff --git a/R/predict_interval.R b/R/predict_interval.R index 5160d93ef..fc9a8bf2d 100644 --- a/R/predict_interval.R +++ b/R/predict_interval.R @@ -15,7 +15,7 @@ predict_confint.model_fit <- function(object, new_data, level = 0.95, std_error "engine.", call. = FALSE) if (inherits(object$fit, "try-error")) { - return(failed_int(n = nrow(new_data), lvl = object$lvl)) + return(failed_int(lvl = object$lvl)) } new_data <- prepare_data(object, new_data) @@ -52,7 +52,7 @@ predict_confint <- function(object, ...) 
# Some `predict()` helpers for failed models: -failed_int <- function(n, lvl = NULL, nms = ".pred") { +failed_int <- function(n = 1, lvl = NULL, nms = ".pred") { # TODO figure out multivariate models if (is.null(lvl)) { res <- matrix(NA_real_, nrow = n, ncol = length(nms) * 2) @@ -81,7 +81,7 @@ predict_predint.model_fit <- function(object, new_data, level = 0.95, std_error "engine.", call. = FALSE) if (inherits(object$fit, "try-error")) { - return(failed_int(n = nrow(new_data), lvl = object$lvl)) + return(failed_int(lvl = object$lvl)) } new_data <- prepare_data(object, new_data) diff --git a/R/predict_numeric.R b/R/predict_numeric.R index 873a0cf36..4f5191a24 100644 --- a/R/predict_numeric.R +++ b/R/predict_numeric.R @@ -16,7 +16,7 @@ predict_numeric.model_fit <- function(object, new_data, ...) { if (inherits(object$fit, "try-error")) { # TODO handle multivariate cases - return(failed_numeric(n = nrow(new_data))) + return(failed_numeric()) } new_data <- prepare_data(object, new_data) @@ -56,7 +56,7 @@ predict_numeric <- function(object, ...) 
# Some `predict()` helpers for failed models: -failed_numeric <- function(n, nms = ".pred") { +failed_numeric <- function(n = 1, nms = ".pred") { res <- matrix(NA_real_, ncol = length(nms), nrow = n) if (length(nms) > 1) { colnames(res) <- nms diff --git a/tests/testthat/test_failed_models.R b/tests/testthat/test_failed_models.R index 116231cb3..8e09568c2 100644 --- a/tests/testthat/test_failed_models.R +++ b/tests/testthat/test_failed_models.R @@ -35,9 +35,9 @@ test_that('numeric model', { fit(Sepal.Length ~ ., data = iris_bad, control = ctrl) num_res <- predict(lm_mod, iris_bad[1:11, -1]) - expect_equal(num_res, tibble(.pred = rep(NA_real_, 11))) + expect_equal(num_res, tibble(.pred = rep(NA_real_, 1))) - exp_int_res <- tibble(.pred_lower = rep(NA_real_, 11), .pred_upper = rep(NA_real_, 11)) + exp_int_res <- tibble(.pred_lower = rep(NA_real_, 1), .pred_upper = rep(NA_real_, 1)) ci_res <- predict(lm_mod, iris_bad[1:11, -1], type = "conf_int") expect_equal(ci_res, exp_int_res) @@ -55,22 +55,22 @@ test_that('classification model', { fit(Class ~ log(funded_amnt) + int_rate + big_num, data = lending_club, control = ctrl) cls_res <- predict(log_reg, lending_club %>% dplyr::slice(1:7) %>% dplyr::select(-Class)) - exp_cls_res <- tibble(.pred_class = factor(rep(NA_character_, 7), levels = lvl)) + exp_cls_res <- tibble(.pred_class = factor(rep(NA_character_, 1), levels = lvl)) expect_equal(cls_res, exp_cls_res) prb_res <- predict(log_reg, lending_club %>% dplyr::slice(1:7) %>% dplyr::select(-Class), type = "prob") - exp_prb_res <- tibble(.pred_bad = rep(NA_real_, 7), .pred_good = rep(NA_real_, 7)) + exp_prb_res <- tibble(.pred_bad = rep(NA_real_, 1), .pred_good = rep(NA_real_, 1)) expect_equal(prb_res, exp_prb_res) ci_res <- predict(log_reg, lending_club %>% dplyr::slice(1:7) %>% dplyr::select(-Class), type = "conf_int") exp_ci_res <- tibble( - .pred_lower_bad = rep(NA_real_, 7), - .pred_upper_bad = rep(NA_real_, 7), - .pred_lower_good = rep(NA_real_, 7), - 
.pred_upper_good = rep(NA_real_, 7) + .pred_lower_bad = rep(NA_real_, 1), + .pred_upper_bad = rep(NA_real_, 1), + .pred_lower_good = rep(NA_real_, 1), + .pred_upper_good = rep(NA_real_, 1) ) expect_equal(ci_res, exp_ci_res) }) From 0ea300d7163e5444c442fd68817abb84f3aaa7ca Mon Sep 17 00:00:00 2001 From: topepo Date: Tue, 26 Feb 2019 20:42:44 -0500 Subject: [PATCH 11/15] travis test to see if glmnet is installed --- tests/testthat/test_linear_reg_glmnet.R | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/testthat/test_linear_reg_glmnet.R b/tests/testthat/test_linear_reg_glmnet.R index 6c58182b2..ce594fb28 100644 --- a/tests/testthat/test_linear_reg_glmnet.R +++ b/tests/testthat/test_linear_reg_glmnet.R @@ -1,7 +1,7 @@ library(testthat) library(parsnip) library(rlang) - +library(glmnet) # ------------------------------------------------------------------------------ context("linear regression execution with glmnet") @@ -198,9 +198,9 @@ test_that('submodel prediction', { mp_res <- multi_predict(reg_fit, new_data = mtcars[1:4, -1], penalty = .1) mp_res <- do.call("rbind", mp_res$.pred) expect_equal(mp_res[[".pred"]], unname(pred_glmn[,1])) - + expect_error( - multi_predict(reg_fit, newdata = mtcars[1:4, -1], penalty = .1), + multi_predict(reg_fit, newdata = mtcars[1:4, -1], penalty = .1), "Did you mean" ) }) From aba34b623c0e16fb1843b1b16d556453175d7bbf Mon Sep 17 00:00:00 2001 From: topepo Date: Tue, 26 Feb 2019 22:16:30 -0500 Subject: [PATCH 12/15] initial refactoring of glmnet prediction code --- R/linear_reg.R | 78 ++++++++++++++--- R/logistic_reg.R | 97 ++++++++++++++------- R/multinom_reg.R | 93 +++++++++++--------- tests/testthat/test_linear_reg_glmnet.R | 47 ++++++++++ tests/testthat/test_multinom_reg_glmnet.R | 4 +- tests/testthat/test_nearest_neighbor_kknn.R | 2 + 6 files changed, 239 insertions(+), 82 deletions(-) diff --git a/R/linear_reg.R b/R/linear_reg.R index 456f4bc9c..3e0e68513 100644 --- a/R/linear_reg.R +++ 
b/R/linear_reg.R @@ -63,7 +63,7 @@ #' \pkg{spark} #' #' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "spark")} -#' +#' #' \pkg{keras} #' #' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "keras")} @@ -216,12 +216,66 @@ organize_glmnet_pred <- function(x, object) { # ------------------------------------------------------------------------------ +# For `predict` methods that use `glmnet`, we have specific methods. +# Only one value of the penalty should be allowed when called by `predict()`: + +check_penalty <- function(penalty = NULL, object, multi = FALSE) { + + if (is.null(penalty)) { + penalty <- object$fit$lambda + } + + # when using `predict()`, allow for a single lambda + if (!multi) { + if (length(penalty) != 1) + stop("`penalty` should be a single numeric value. ", + "`multi_predict()` can be used to get multiple predictions ", + "per row of data.", call. = FALSE) + } + + if (length(object$fit$lambda) == 1 && penalty != object$fit$lambda) + stop("The glmnet model was fit with a single penalty value of ", + object$fit$lambda, ". Predicting with a value of ", + penalty, " will give incorrect results from `glmnet()`.", + call. = FALSE) + + penalty +} + +# ------------------------------------------------------------------------------ +# glmnet call stack for linear regression using `predict` when object has +# classes "_elnet" and "model_fit": +# +# predict() +# predict._elnet(penalty = NULL) <-- checks and sets penalty +# predict.model_fit() <-- checks for extra vars in ... +# predict_numeric() +# predict_numeric._elnet() +# predict_numeric.model_fit() +# predict.elnet() + + +# glmnet call stack for linear regression using `multi_predict` when object has +# classes "_elnet" and "model_fit": +# +# multi_predict() +# multi_predict._elnet(penalty = NULL) +# predict._elnet(multi = TRUE) <-- checks and sets penalty +# predict.model_fit() <-- checks for extra vars in ... 
+# predict_raw() +# predict_raw._elnet() +# predict_raw.model_fit(opts = list(s = penalty)) +# predict.elnet() + + #' @export predict._elnet <- - function(object, new_data, type = NULL, opts = list(), ...) { + function(object, new_data, type = NULL, opts = list(), penalty = NULL, multi = FALSE, ...) { if (any(names(enquos(...)) == "newdata")) stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) - + + object$spec$args$penalty <- check_penalty(penalty, object, multi) + object$spec <- eval_args(object$spec) predict.model_fit(object, new_data = new_data, type = type, opts = opts, ...) } @@ -230,7 +284,7 @@ predict._elnet <- predict_numeric._elnet <- function(object, new_data, ...) { if (any(names(enquos(...)) == "newdata")) stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) - + object$spec <- eval_args(object$spec) predict_numeric.model_fit(object, new_data = new_data, ...) } @@ -239,8 +293,9 @@ predict_numeric._elnet <- function(object, new_data, ...) { predict_raw._elnet <- function(object, new_data, opts = list(), ...) { if (any(names(enquos(...)) == "newdata")) stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) - + object$spec <- eval_args(object$spec) + opts$s <- object$spec$args$penalty predict_raw.model_fit(object, new_data = new_data, opts = opts, ...) } @@ -251,14 +306,17 @@ multi_predict._elnet <- function(object, new_data, type = NULL, penalty = NULL, ...) { if (any(names(enquos(...)) == "newdata")) stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) - + dots <- list(...) 
- if (is.null(penalty)) - penalty <- object$fit$lambda - dots$s <- penalty object$spec <- eval_args(object$spec) - pred <- predict(object, new_data = new_data, type = "raw", opts = dots) + + if (is.null(penalty)) { + penalty <- object$fit$lambda + } + + pred <- predict._elnet(object, new_data = new_data, type = "raw", + opts = dots, penalty = penalty, multi = TRUE) param_key <- tibble(group = colnames(pred), penalty = penalty) pred <- as_tibble(pred) pred$.row <- 1:nrow(pred) diff --git a/R/logistic_reg.R b/R/logistic_reg.R index 10a4df1db..a5da7b148 100644 --- a/R/logistic_reg.R +++ b/R/logistic_reg.R @@ -235,41 +235,41 @@ organize_glmnet_prob <- function(x, object) { } # ------------------------------------------------------------------------------ +# glmnet call stack for linear regression using `predict` when object has +# classes "_lognet" and "model_fit" (for class predictions): +# +# predict() +# predict._lognet(penalty = NULL) <-- checks and sets penalty +# predict.model_fit() <-- checks for extra vars in ... +# predict_class() +# predict_class._lognet() +# predict_class.model_fit() +# predict.lognet() + + +# glmnet call stack for linear regression using `multi_predict` when object has +# classes "_lognet" and "model_fit" (for class predictions): +# +# multi_predict() +# multi_predict._lognet(penalty = NULL) +# predict._lognet(multi = TRUE) <-- checks and sets penalty +# predict.model_fit() <-- checks for extra vars in ... +# predict_raw() +# predict_raw._lognet() +# predict_raw.model_fit(opts = list(s = penalty)) +# predict.lognet() -#' @export -predict._lognet <- function (object, new_data, type = NULL, opts = list(), ...) { - if (any(names(enquos(...)) == "newdata")) - stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) - - object$spec <- eval_args(object$spec) - predict.model_fit(object, new_data = new_data, type = type, opts = opts, ...) -} - -#' @export -predict_class._lognet <- function (object, new_data, ...) 
{ - if (any(names(enquos(...)) == "newdata")) - stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) - - object$spec <- eval_args(object$spec) - predict_class.model_fit(object, new_data = new_data, ...) -} +# ------------------------------------------------------------------------------ #' @export -predict_classprob._lognet <- function (object, new_data, ...) { +predict._lognet <- function (object, new_data, type = NULL, opts = list(), penalty = NULL, multi = FALSE, ...) { if (any(names(enquos(...)) == "newdata")) stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) - object$spec <- eval_args(object$spec) - predict_classprob.model_fit(object, new_data = new_data, ...) -} - -#' @export -predict_raw._lognet <- function (object, new_data, opts = list(), ...) { - if (any(names(enquos(...)) == "newdata")) - stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) + object$spec$args$penalty <- check_penalty(penalty, object, multi) object$spec <- eval_args(object$spec) - predict_raw.model_fit(object, new_data = new_data, opts = opts, ...) + predict.model_fit(object, new_data = new_data, type = type, opts = opts, ...) } @@ -281,15 +281,18 @@ multi_predict._lognet <- if (any(names(enquos(...)) == "newdata")) stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) + if (is_quosure(penalty)) + penalty <- eval_tidy(penalty) + dots <- list(...) if (is.null(penalty)) - penalty <- object$fit$lambda + penalty <- eval_tidy(object$fit$lambda) dots$s <- penalty if (is.null(type)) type <- "class" - if (!(type %in% c("class", "prob", "link"))) { - stop ("`type` should be either 'class', 'link', or 'prob'.", call. = FALSE) + if (!(type %in% c("class", "prob", "link", "raw"))) { + stop ("`type` should be either 'class', 'link', 'raw', or 'prob'.", call. 
= FALSE) } if (type == "prob") dots$type <- "response" @@ -297,7 +300,7 @@ multi_predict._lognet <- dots$type <- type object$spec <- eval_args(object$spec) - pred <- predict(object, new_data = new_data, type = "raw", opts = dots) + pred <- predict.model_fit(object, new_data = new_data, type = "raw", opts = dots) param_key <- tibble(group = colnames(pred), penalty = penalty) pred <- as_tibble(pred) pred$.row <- 1:nrow(pred) @@ -321,6 +324,38 @@ multi_predict._lognet <- tibble(.pred = pred) } + + + + +#' @export +predict_class._lognet <- function (object, new_data, ...) { + if (any(names(enquos(...)) == "newdata")) + stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) + + object$spec <- eval_args(object$spec) + predict_class.model_fit(object, new_data = new_data, ...) +} + +#' @export +predict_classprob._lognet <- function (object, new_data, ...) { + if (any(names(enquos(...)) == "newdata")) + stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) + + object$spec <- eval_args(object$spec) + predict_classprob.model_fit(object, new_data = new_data, ...) +} + +#' @export +predict_raw._lognet <- function (object, new_data, opts = list(), ...) { + if (any(names(enquos(...)) == "newdata")) + stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) + + object$spec <- eval_args(object$spec) + predict_raw.model_fit(object, new_data = new_data, opts = opts, ...) 
+} + + # ------------------------------------------------------------------------------ #' @importFrom utils globalVariables diff --git a/R/multinom_reg.R b/R/multinom_reg.R index 33fc37918..79154796c 100644 --- a/R/multinom_reg.R +++ b/R/multinom_reg.R @@ -188,54 +188,46 @@ organize_multnet_prob <- function(x, object) { } # ------------------------------------------------------------------------------ +# glmnet call stack for linear regression using `predict` when object has +# classes "_multnet" and "model_fit" (for class predictions): +# +# predict() +# predict._multnet(penalty = NULL) <-- checks and sets penalty +# predict.model_fit() <-- checks for extra vars in ... +# predict_class() +# predict_class._multnet() +# predict.multnet() + + +# glmnet call stack for linear regression using `multi_predict` when object has +# classes "_multnet" and "model_fit" (for class predictions): +# +# multi_predict() +# multi_predict._multnet(penalty = NULL) +# predict._multnet(multi = TRUE) <-- checks and sets penalty +# predict.model_fit() <-- checks for extra vars in ... +# predict_raw() +# predict_raw._multnet() +# predict_raw.model_fit(opts = list(s = penalty)) +# predict.multnet() -#' @export -predict._lognet <- function (object, new_data, type = NULL, opts = list(), ...) { - object$spec <- eval_args(object$spec) - predict.model_fit(object, new_data = new_data, type = type, opts = opts, ...) -} - -#' @export -predict_class._lognet <- function (object, new_data, ...) { - object$spec <- eval_args(object$spec) - predict_class.model_fit(object, new_data = new_data, ...) -} - -#' @export -predict_classprob._multnet <- function (object, new_data, ...) { - object$spec <- eval_args(object$spec) - predict_classprob.model_fit(object, new_data = new_data, ...) -} - -#' @export -predict_raw._multnet <- function (object, new_data, opts = list(), ...) { - object$spec <- eval_args(object$spec) - predict_raw.model_fit(object, new_data = new_data, opts = opts, ...) 
-} - +# ------------------------------------------------------------------------------ #' @export predict._multnet <- - function(object, new_data, type = NULL, opts = list(), penalty = NULL, ...) { - dots <- list(...) - if (is.null(penalty)) - penalty <- object$fit$lambda + function(object, new_data, type = NULL, opts = list(), penalty = NULL, multi = FALSE, ...) { + + object$spec$args$penalty <- check_penalty(penalty, object, multi) - if (length(penalty) != 1) - stop("`penalty` should be a single numeric value. ", - "`multi_predict()` can be used to get multiple predictions ", - "per row of data.", call. = FALSE) object$spec <- eval_args(object$spec) res <- predict.model_fit( object = object, new_data = new_data, type = type, - opts = opts, - penalty = penalty + opts = opts ) - res -} - + res + } #' @importFrom dplyr full_join as_tibble arrange #' @importFrom tidyr gather @@ -255,8 +247,8 @@ multi_predict._multnet <- if (is.null(type)) type <- "class" - if (!(type %in% c("class", "prob", "link"))) { - stop ("`type` should be either 'class', 'link', or 'prob'.", call. = FALSE) + if (!(type %in% c("class", "prob", "link", "raw"))) { + stop ("`type` should be either 'class', 'link', 'raw', or 'prob'.", call. = FALSE) } if (type == "prob") dots$type <- "response" @@ -296,6 +288,29 @@ multi_predict._multnet <- tibble(.pred = pred) } +#' @export +predict_class._multnet <- function (object, new_data, ...) { + object$spec <- eval_args(object$spec) + predict_class.model_fit(object, new_data = new_data, ...) +} + +#' @export +predict_classprob._multnet <- function (object, new_data, ...) { + object$spec <- eval_args(object$spec) + predict_classprob.model_fit(object, new_data = new_data, ...) +} + +#' @export +predict_raw._multnet <- function (object, new_data, opts = list(), ...) { + object$spec <- eval_args(object$spec) + predict_raw.model_fit(object, new_data = new_data, opts = opts, ...) 
+} + + + +# ------------------------------------------------------------------------------ + +# This checks as a pre-processor in the model data object check_glmnet_lambda <- function(dat, object) { if (length(object$fit$lambda) > 1) stop( diff --git a/tests/testthat/test_linear_reg_glmnet.R b/tests/testthat/test_linear_reg_glmnet.R index ce594fb28..7c64ae3fe 100644 --- a/tests/testthat/test_linear_reg_glmnet.R +++ b/tests/testthat/test_linear_reg_glmnet.R @@ -2,6 +2,7 @@ library(testthat) library(parsnip) library(rlang) library(glmnet) + # ------------------------------------------------------------------------------ context("linear regression execution with glmnet") @@ -203,5 +204,51 @@ test_that('submodel prediction', { multi_predict(reg_fit, newdata = mtcars[1:4, -1], penalty = .1), "Did you mean" ) + + reg_fit <- + linear_reg(penalty = c(0, 0.01, 0.1)) %>% + set_engine("glmnet") %>% + fit(mpg ~ ., data = mtcars[-(1:4), ]) + + + pred_glmn_all <- + predict(reg_fit$fit, as.matrix(mtcars[1:2, -1])) %>% + as.data.frame() %>% + stack() %>% + dplyr::arrange(ind) + + + mp_res_all <- + multi_predict(reg_fit, new_data = mtcars[1:2, -1]) %>% + tidyr::unnest() + + expect_equal(sort(mp_res_all$.pred), sort(pred_glmn_all$values)) + +}) + + +test_that('error traps', { + + skip_if_not_installed("glmnet") + + expect_error( + linear_reg(penalty = .1) %>% + set_engine("glmnet") %>% + fit(mpg ~ ., data = mtcars[-(1:4), ]) %>% + predict(mtcars[-(1:4), ], penalty = .2) + ) + expect_error( + linear_reg() %>% + set_engine("glmnet") %>% + fit(mpg ~ ., data = mtcars[-(1:4), ]) %>% + predict(mtcars[-(1:4), ], penalty = 0:1) + ) + expect_error( + linear_reg() %>% + set_engine("glmnet") %>% + fit(mpg ~ ., data = mtcars[-(1:4), ]) %>% + predict(mtcars[-(1:4), ]) + ) + }) diff --git a/tests/testthat/test_multinom_reg_glmnet.R b/tests/testthat/test_multinom_reg_glmnet.R index 4043bfc21..af6999434 100644 --- a/tests/testthat/test_multinom_reg_glmnet.R +++ 
b/tests/testthat/test_multinom_reg_glmnet.R @@ -58,7 +58,7 @@ test_that('glmnet prediction, one lambda', { uni_pred <- factor(uni_pred[,1], levels = levels(iris$Species)) uni_pred <- unname(uni_pred) - expect_equal(uni_pred, parsnip:::predict_class(xy_fit, iris[rows, 1:4])) + expect_equal(uni_pred, parsnip:::predict_class.model_fit(xy_fit, iris[rows, 1:4])) expect_equal(uni_pred, predict(xy_fit, iris[rows, 1:4], type = "class")$.pred_class) res_form <- fit( @@ -77,7 +77,7 @@ test_that('glmnet prediction, one lambda', { s = res_form$spec$args$penalty, type = "class") form_pred <- factor(form_pred[,1], levels = levels(iris$Species)) - expect_equal(form_pred, parsnip:::predict_class(res_form, iris[rows, c("Sepal.Width", "Petal.Width")])) + expect_equal(form_pred, parsnip:::predict_class.model_fit(res_form, iris[rows, c("Sepal.Width", "Petal.Width")])) expect_equal(form_pred, predict(res_form, iris[rows, c("Sepal.Width", "Petal.Width")], type = "class")$.pred_class) }) diff --git a/tests/testthat/test_nearest_neighbor_kknn.R b/tests/testthat/test_nearest_neighbor_kknn.R index 52474f692..cc483a156 100644 --- a/tests/testthat/test_nearest_neighbor_kknn.R +++ b/tests/testthat/test_nearest_neighbor_kknn.R @@ -20,6 +20,8 @@ quiet_ctrl <- fit_control(verbosity = 0, catch = TRUE) test_that('kknn execution', { skip_if_not_installed("kknn") + library(kknn) + # continuous # expect no error From 7fc1cf47b4add5c97c419431c6971992141a00b6 Mon Sep 17 00:00:00 2001 From: topepo Date: Wed, 27 Feb 2019 06:26:18 -0500 Subject: [PATCH 13/15] removed glmnet load --- tests/testthat/test_linear_reg_glmnet.R | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/testthat/test_linear_reg_glmnet.R b/tests/testthat/test_linear_reg_glmnet.R index 7c64ae3fe..f03fcb38a 100644 --- a/tests/testthat/test_linear_reg_glmnet.R +++ b/tests/testthat/test_linear_reg_glmnet.R @@ -1,7 +1,6 @@ library(testthat) library(parsnip) library(rlang) -library(glmnet) # 
------------------------------------------------------------------------------ From 9f7a25a68745c350d9d4602dc4258627c59249d4 Mon Sep 17 00:00:00 2001 From: topepo Date: Wed, 27 Feb 2019 10:32:19 -0500 Subject: [PATCH 14/15] re-wrote the glmnet notes --- R/linear_reg.R | 18 +++++++++--------- R/logistic_reg.R | 16 ++++++++-------- R/multinom_reg.R | 16 ++++++++-------- 3 files changed, 25 insertions(+), 25 deletions(-) diff --git a/R/linear_reg.R b/R/linear_reg.R index 3e0e68513..91ed2d125 100644 --- a/R/linear_reg.R +++ b/R/linear_reg.R @@ -69,14 +69,14 @@ #' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "keras")} #' #' When using `glmnet` models, there is the option to pass -#' multiple values (or no values) to the `penalty` argument. -#' This can have an effect on the model object results. When using -#' the `predict()` method in these cases, the return object type -#' depends on the value of `penalty`. If a single value is -#' given, the results will be a simple numeric vector. When -#' multiple values or no values for `penalty` are used in -#' `linear_reg()`, the `predict()` method will return a data frame with -#' columns `values` and `lambda`. +#' multiple values (or no values) to the `penalty` argument. This +#' can have an effect on the model object results. When using the +#' `predict()` method in these cases, the return value depends on +#' the value of `penalty`. When using `predict()`, only a single +#' value of the penalty can be used. When predicting on multiple +#' penalties, the `multi_predict()` function can be used. It +#' returns a tibble with a list column called `.pred` that contains +#' a tibble with all of the penalty results. #' #' For prediction, the `stan` engine can compute posterior #' intervals analogous to confidence and prediction intervals. In @@ -130,7 +130,7 @@ print.linear_reg <- function(x, ...) { cat("Linear Regression Model Specification (", x$mode, ")\n\n", sep = "") model_printer(x, ...) 
- if(!is.null(x$method$fit$args)) { + if (!is.null(x$method$fit$args)) { cat("Model fit template:\n") print(show_call(x)) } diff --git a/R/logistic_reg.R b/R/logistic_reg.R index a5da7b148..1c5a76d0b 100644 --- a/R/logistic_reg.R +++ b/R/logistic_reg.R @@ -67,14 +67,14 @@ #' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::logistic_reg(), "keras")} #' #' When using `glmnet` models, there is the option to pass -#' multiple values (or no values) to the `penalty` argument. -#' This can have an effect on the model object results. When using -#' the `predict()` method in these cases, the return object type -#' depends on the value of `penalty`. If a single value is -#' given, the results will be a simple numeric vector. When -#' multiple values or no values for `penalty` are used in -#' `logistic_reg()`, the `predict()` method will return a data frame with -#' columns `values` and `lambda`. +#' multiple values (or no values) to the `penalty` argument. This +#' can have an effect on the model object results. When using the +#' `predict()` method in these cases, the return value depends on +#' the value of `penalty`. When using `predict()`, only a single +#' value of the penalty can be used. When predicting on multiple +#' penalties, the `multi_predict()` function can be used. It +#' returns a tibble with a list column called `.pred` that contains +#' a tibble with all of the penalty results. #' #' For prediction, the `stan` engine can compute posterior #' intervals analogous to confidence and prediction intervals. In diff --git a/R/multinom_reg.R b/R/multinom_reg.R index 79154796c..8bc0deed8 100644 --- a/R/multinom_reg.R +++ b/R/multinom_reg.R @@ -58,14 +58,14 @@ #' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::multinom_reg(), "keras")} #' #' When using `glmnet` models, there is the option to pass -#' multiple values (or no values) to the `penalty` argument. -#' This can have an effect on the model object results. 
When using -#' the `predict()` method in these cases, the return object type -#' depends on the value of `penalty`. If a single value is -#' given, the results will be a simple numeric vector. When -#' multiple values or no values for `penalty` are used in -#' `multinom_reg()`, the `predict()` method will return a data frame with -#' columns `values` and `lambda`. +#' multiple values (or no values) to the `penalty` argument. This +#' can have an effect on the model object results. When using the +#' `predict()` method in these cases, the return value depends on +#' the value of `penalty`. When using `predict()`, only a single +#' value of the penalty can be used. When predicting on multiple +#' penalties, the `multi_predict()` function can be used. It +#' returns a tibble with a list column called `.pred` that contains +#' a tibble with all of the penalty results. #' #' @note For models created using the spark engine, there are #' several differences to consider. First, only the formula From d8c7484e165faef664a79de2bcd6b048468cb8fe Mon Sep 17 00:00:00 2001 From: topepo Date: Wed, 27 Feb 2019 12:23:17 -0500 Subject: [PATCH 15/15] un-do return values for failed models for issue #123 --- NAMESPACE | 1 + R/predict.R | 20 ++++++------ R/predict_class.R | 13 ++------ R/predict_classprob.R | 15 ++------- R/predict_interval.R | 24 +++----------- R/predict_numeric.R | 22 ++----------- R/predict_quantile.R | 23 ++++++++------ R/predict_raw.R | 3 +- man/linear_reg.Rd | 16 +++++----- man/logistic_reg.Rd | 16 +++++----- man/multinom_reg.Rd | 16 +++++----- tests/testthat/test_failed_models.R | 49 ++++++++++++++--------------- 12 files changed, 85 insertions(+), 133 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index b57fd5df0..c4d345c5b 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -17,6 +17,7 @@ S3method(predict,model_fit) S3method(predict,model_spec) S3method(predict,nullmodel) S3method(predict_class,"_lognet") +S3method(predict_class,"_multnet") 
S3method(predict_classprob,"_lognet") S3method(predict_classprob,"_multnet") S3method(predict_numeric,"_elnet") diff --git a/R/predict.R b/R/predict.R index 596de4b9d..d0dc1f735 100644 --- a/R/predict.R +++ b/R/predict.R @@ -112,6 +112,11 @@ predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...) if (any(names(the_dots) == "newdata")) stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) + if (inherits(object$fit, "try-error")) { + warning("Model fit failed; cannot make predictions.", call. = FALSE) + return(NULL) + } + other_args <- c("level", "std_error", "quantile") # "time" for survival probs later is_pred_arg <- names(the_dots) %in% other_args if (any(!is_pred_arg)) { @@ -242,8 +247,13 @@ prepare_data <- function(object, new_data) { #' multiple rows per sub-model. #' @keywords internal #' @export -multi_predict <- function(object, ...) +multi_predict <- function(object, ...) { + if (inherits(object$fit, "try-error")) { + warning("Model fit failed; cannot make predictions.", call. = FALSE) + return(NULL) + } UseMethod("multi_predict") +} #' @keywords internal #' @export @@ -256,11 +266,3 @@ multi_predict.default <- function(object, ...) predict.model_spec <- function(object, ...) { stop("You must use `fit()` on your model specification before you can use `predict()`.", call. = FALSE) } - - -failed_class <- function(n, lvl) { - tibble(.pred = rep(NA_real_, n)) -} - - - diff --git a/R/predict_class.R b/R/predict_class.R index 926846e49..0e292a8b3 100644 --- a/R/predict_class.R +++ b/R/predict_class.R @@ -17,7 +17,8 @@ predict_class.model_fit <- function(object, new_data, ...) { stop("No class prediction module defined for this model.", call. = FALSE) if (inherits(object$fit, "try-error")) { - return(failed_class(lvl = object$lvl)) + warning("Model fit failed; cannot make predictions.", call. 
= FALSE) + return(NULL) } new_data <- prepare_data(object, new_data) @@ -54,13 +55,3 @@ predict_class.model_fit <- function(object, new_data, ...) { predict_class <- function(object, ...) UseMethod("predict_class") -# ------------------------------------------------------------------------------ - -# Some `predict()` helpers for failed models: - -failed_class <- function(n = 1, lvl) { - res <- rep(NA_character_, n) - res <- factor(res, levels = lvl) - res -} - diff --git a/R/predict_classprob.R b/R/predict_classprob.R index 7eab48c39..d902e7735 100644 --- a/R/predict_classprob.R +++ b/R/predict_classprob.R @@ -14,7 +14,8 @@ predict_classprob.model_fit <- function(object, new_data, ...) { stop("No class probability module defined for this model.", call. = FALSE) if (inherits(object$fit, "try-error")) { - return(failed_classprob(lvl = object$lvl)) + warning("Model fit failed; cannot make predictions.", call. = FALSE) + return(NULL) } new_data <- prepare_data(object, new_data) @@ -49,15 +50,3 @@ predict_classprob.model_fit <- function(object, new_data, ...) { # @inheritParams predict.model_fit predict_classprob <- function(object, ...) UseMethod("predict_classprob") - - -# ------------------------------------------------------------------------------ - -# Some `predict()` helpers for failed models: - -failed_classprob <- function(n = 1, lvl) { - res <- matrix(NA_real_, nrow = n, ncol = length(lvl)) - colnames(res) <- lvl - as_tibble(res) -} - diff --git a/R/predict_interval.R b/R/predict_interval.R index fc9a8bf2d..f38838e00 100644 --- a/R/predict_interval.R +++ b/R/predict_interval.R @@ -15,7 +15,8 @@ predict_confint.model_fit <- function(object, new_data, level = 0.95, std_error "engine.", call. = FALSE) if (inherits(object$fit, "try-error")) { - return(failed_int(lvl = object$lvl)) + warning("Model fit failed; cannot make predictions.", call. 
= FALSE) + return(NULL) } new_data <- prepare_data(object, new_data) @@ -50,24 +51,6 @@ predict_confint <- function(object, ...) # ------------------------------------------------------------------------------ -# Some `predict()` helpers for failed models: - -failed_int <- function(n = 1, lvl = NULL, nms = ".pred") { - # TODO figure out multivariate models - if (is.null(lvl)) { - res <- matrix(NA_real_, nrow = n, ncol = length(nms) * 2) - colnames(res) <- c(".pred_lower", ".pred_upper") - } else { - res <- matrix(NA_real_, ncol = length(lvl) * 2, nrow = n) - nms <- expand.grid(c("lower", "upper"), lvl) - nms <- paste(".pred", nms$Var1, nms$Var2, sep = "_") - colnames(res) <- nms - } - as_tibble(res) -} - -# ------------------------------------------------------------------------------ - # @keywords internal # @rdname other_predict # @inheritParams predict.model_fit @@ -81,7 +64,8 @@ predict_predint.model_fit <- function(object, new_data, level = 0.95, std_error "engine.", call. = FALSE) if (inherits(object$fit, "try-error")) { - return(failed_int(lvl = object$lvl)) + warning("Model fit failed; cannot make predictions.", call. = FALSE) + return(NULL) } new_data <- prepare_data(object, new_data) diff --git a/R/predict_numeric.R b/R/predict_numeric.R index 4f5191a24..3a509546b 100644 --- a/R/predict_numeric.R +++ b/R/predict_numeric.R @@ -15,8 +15,8 @@ predict_numeric.model_fit <- function(object, new_data, ...) { stop("No prediction module defined for this model.", call. = FALSE) if (inherits(object$fit, "try-error")) { - # TODO handle multivariate cases - return(failed_numeric()) + warning("Model fit failed; cannot make predictions.", call. = FALSE) + return(NULL) } new_data <- prepare_data(object, new_data) @@ -51,21 +51,3 @@ predict_numeric.model_fit <- function(object, new_data, ...) { # @inheritParams predict_numeric.model_fit predict_numeric <- function(object, ...) 
UseMethod("predict_numeric") - -# ------------------------------------------------------------------------------ - -# Some `predict()` helpers for failed models: - -failed_numeric <- function(n = 1, nms = ".pred") { - res <- matrix(NA_real_, ncol = length(nms), nrow = n) - if (length(nms) > 1) { - colnames(res) <- nms - res <- as_tibble(res) - } else { - res <- res[,1] - } - res -} - - - diff --git a/R/predict_quantile.R b/R/predict_quantile.R index 17c9786de..698ddb4c8 100644 --- a/R/predict_quantile.R +++ b/R/predict_quantile.R @@ -1,35 +1,40 @@ # @keywords internal # @rdname other_predict -# @param quant A vector of numbers between 0 and 1 for the quantile being -# predicted. +# @param quant A vector of numbers between 0 and 1 for the quantile being +# predicted. # @inheritParams predict.model_fit # @method predict_quantile model_fit # @export predict_quantile.model_fit # @export predict_quantile.model_fit <- function (object, new_data, quantile = (1:9)/10, ...) { - + if (is.null(object$spec$method$quantile)) stop("No quantile prediction method defined for this ", "engine.", call. = FALSE) - + + if (inherits(object$fit, "try-error")) { + warning("Model fit failed; cannot make predictions.", call. = FALSE) + return(NULL) + } + new_data <- prepare_data(object, new_data) - + # preprocess data if (!is.null(object$spec$method$quantile$pre)) new_data <- object$spec$method$quantile$pre(new_data, object) - + # Pass some extra arguments to be used in post-processor object$spec$method$quantile$args$p <- quantile pred_call <- make_pred_call(object$spec$method$quantile) - + res <- eval_tidy(pred_call) - + # post-process the predictions if(!is.null(object$spec$method$quantile$post)) { res <- object$spec$method$quantile$post(res, object) } - + res } diff --git a/R/predict_raw.R b/R/predict_raw.R index 20a970d56..315c9dd0a 100644 --- a/R/predict_raw.R +++ b/R/predict_raw.R @@ -18,7 +18,8 @@ predict_raw.model_fit <- function(object, new_data, opts = list(), ...) 
{ stop("No raw prediction module defined for this model.", call. = FALSE) if (inherits(object$fit, "try-error")) { - stop("Model fit failed; cannot make predictions.") + warning("Model fit failed; cannot make predictions.", call. = FALSE) + return(NULL) } new_data <- prepare_data(object, new_data) diff --git a/man/linear_reg.Rd b/man/linear_reg.Rd index c4539fd37..c009ba8a3 100644 --- a/man/linear_reg.Rd +++ b/man/linear_reg.Rd @@ -106,14 +106,14 @@ model, the template of the fit calls are: \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "keras")} When using \code{glmnet} models, there is the option to pass -multiple values (or no values) to the \code{penalty} argument. -This can have an effect on the model object results. When using -the \code{predict()} method in these cases, the return object type -depends on the value of \code{penalty}. If a single value is -given, the results will be a simple numeric vector. When -multiple values or no values for \code{penalty} are used in -\code{linear_reg()}, the \code{predict()} method will return a data frame with -columns \code{values} and \code{lambda}. +multiple values (or no values) to the \code{penalty} argument. This +can have an effect on the model object results. When using the +\code{predict()} method in these cases, the return value depends on +the value of \code{penalty}. When using \code{predict()}, only a single +value of the penalty can be used. When predicting on multiple +penalties, the \code{multi_predict()} function can be used. It +returns a tibble with a list column called \code{.pred} that contains +a tibble with all of the penalty results. For prediction, the \code{stan} engine can compute posterior intervals analogous to confidence and prediction intervals. 
In diff --git a/man/logistic_reg.Rd b/man/logistic_reg.Rd index e7ac89673..0d36b3b73 100644 --- a/man/logistic_reg.Rd +++ b/man/logistic_reg.Rd @@ -104,14 +104,14 @@ model, the template of the fit calls are: \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::logistic_reg(), "keras")} When using \code{glmnet} models, there is the option to pass -multiple values (or no values) to the \code{penalty} argument. -This can have an effect on the model object results. When using -the \code{predict()} method in these cases, the return object type -depends on the value of \code{penalty}. If a single value is -given, the results will be a simple numeric vector. When -multiple values or no values for \code{penalty} are used in -\code{logistic_reg()}, the \code{predict()} method will return a data frame with -columns \code{values} and \code{lambda}. +multiple values (or no values) to the \code{penalty} argument. This +can have an effect on the model object results. When using the +\code{predict()} method in these cases, the return value depends on +the value of \code{penalty}. When using \code{predict()}, only a single +value of the penalty can be used. When predicting on multiple +penalties, the \code{multi_predict()} function can be used. It +returns a tibble with a list column called \code{.pred} that contains +a tibble with all of the penalty results. For prediction, the \code{stan} engine can compute posterior intervals analogous to confidence and prediction intervals. In diff --git a/man/multinom_reg.Rd b/man/multinom_reg.Rd index ebcc03f49..f85f7a99c 100644 --- a/man/multinom_reg.Rd +++ b/man/multinom_reg.Rd @@ -95,14 +95,14 @@ model, the template of the fit calls are: \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::multinom_reg(), "keras")} When using \code{glmnet} models, there is the option to pass -multiple values (or no values) to the \code{penalty} argument. -This can have an effect on the model object results. 
When using -the \code{predict()} method in these cases, the return object type -depends on the value of \code{penalty}. If a single value is -given, the results will be a simple numeric vector. When -multiple values or no values for \code{penalty} are used in -\code{multinom_reg()}, the \code{predict()} method will return a data frame with -columns \code{values} and \code{lambda}. +multiple values (or no values) to the \code{penalty} argument. This +can have an effect on the model object results. When using the +\code{predict()} method in these cases, the return value depends on +the value of \code{penalty}. When using \code{predict()}, only a single +value of the penalty can be used. When predicting on multiple +penalties, the \code{multi_predict()} function can be used. It +returns a tibble with a list column called \code{.pred} that contains +a tibble with all of the penalty results. } \examples{ diff --git a/tests/testthat/test_failed_models.R b/tests/testthat/test_failed_models.R index 8e09568c2..39fd4a88a 100644 --- a/tests/testthat/test_failed_models.R +++ b/tests/testthat/test_failed_models.R @@ -34,15 +34,14 @@ test_that('numeric model', { set_engine("lm") %>% fit(Sepal.Length ~ ., data = iris_bad, control = ctrl) - num_res <- predict(lm_mod, iris_bad[1:11, -1]) - expect_equal(num_res, tibble(.pred = rep(NA_real_, 1))) + expect_warning(num_res <- predict(lm_mod, iris_bad[1:11, -1])) + expect_equal(num_res, NULL) - exp_int_res <- tibble(.pred_lower = rep(NA_real_, 1), .pred_upper = rep(NA_real_, 1)) - ci_res <- predict(lm_mod, iris_bad[1:11, -1], type = "conf_int") - expect_equal(ci_res, exp_int_res) + expect_warning(ci_res <- predict(lm_mod, iris_bad[1:11, -1], type = "conf_int")) + expect_equal(ci_res, NULL) - pi_res <- predict(lm_mod, iris_bad[1:11, -1], type = "pred_int") - expect_equal(pi_res, exp_int_res) + expect_warning(pi_res <- predict(lm_mod, iris_bad[1:11, -1], type = "pred_int")) + expect_equal(pi_res, NULL) }) @@ -54,24 +53,22 @@ 
test_that('classification model', { set_engine("glm") %>% fit(Class ~ log(funded_amnt) + int_rate + big_num, data = lending_club, control = ctrl) - cls_res <- predict(log_reg, lending_club %>% dplyr::slice(1:7) %>% dplyr::select(-Class)) - exp_cls_res <- tibble(.pred_class = factor(rep(NA_character_, 1), levels = lvl)) - expect_equal(cls_res, exp_cls_res) - - prb_res <- - predict(log_reg, lending_club %>% dplyr::slice(1:7) %>% dplyr::select(-Class), type = "prob") - exp_prb_res <- tibble(.pred_bad = rep(NA_real_, 1), .pred_good = rep(NA_real_, 1)) - expect_equal(prb_res, exp_prb_res) - - ci_res <- - predict(log_reg, lending_club %>% dplyr::slice(1:7) %>% dplyr::select(-Class), type = "conf_int") - exp_ci_res <- - tibble( - .pred_lower_bad = rep(NA_real_, 1), - .pred_upper_bad = rep(NA_real_, 1), - .pred_lower_good = rep(NA_real_, 1), - .pred_upper_good = rep(NA_real_, 1) - ) - expect_equal(ci_res, exp_ci_res) + expect_warning( + cls_res <- + predict(log_reg, lending_club %>% dplyr::slice(1:7) %>% dplyr::select(-Class)) + ) + expect_equal(cls_res, NULL) + + expect_warning( + prb_res <- + predict(log_reg, lending_club %>% dplyr::slice(1:7) %>% dplyr::select(-Class), type = "prob") + ) + expect_equal(prb_res, NULL) + + expect_warning( + ci_res <- + predict(log_reg, lending_club %>% dplyr::slice(1:7) %>% dplyr::select(-Class), type = "conf_int") + ) + expect_equal(ci_res, NULL) })