diff --git a/NEWS.md b/NEWS.md index 195780ae9..990cdc821 100644 --- a/NEWS.md +++ b/NEWS.md @@ -14,9 +14,11 @@ ## Other Changes - * When the xy interface is used and the underlying model expects to use a matrix, a better warning is issued when predictors contain non-numeric columns (including dates). +* When the xy interface is used and the underlying model expects to use a matrix, a better warning is issued when predictors contain non-numeric columns (including dates). - * The fit time is only calculated when the `verbosity` argument of `control_parsnip()` is 2L or greater. Also, the call to `system.time()` now uses `gcFirst = FALSE`. (#611) +* The fit time is only calculated when the `verbosity` argument of `control_parsnip()` is 2L or greater. Also, the call to `system.time()` now uses `gcFirst = FALSE`. (#611) + +* Argument `interval` was added for prediction: For types "survival" and "quantile", estimates for the confidence or prediction interval can be added if available (#615). # parsnip 0.1.7 diff --git a/R/bart.R b/R/bart.R index d527d67b2..06377be09 100644 --- a/R/bart.R +++ b/R/bart.R @@ -162,10 +162,11 @@ update.bart <- } +#' Developer functions for predictions via BART models #' @export #' @keywords internal #' @name bart-internal -#' @inherit predict.model_fit +#' @inheritParams predict.model_fit #' @param obj A parsnip object. #' @param ci Confidence (TRUE) or prediction interval (FALSE) #' @param level Confidence level. diff --git a/R/predict.R b/R/predict.R index 18611acee..abb366499 100644 --- a/R/predict.R +++ b/R/predict.R @@ -19,8 +19,11 @@ #' `parsnip` related options that can be passed, depending on the #' value of `type`. Possible arguments are: #' \itemize{ -#' \item `level`: for `type`s of "conf_int" and "pred_int" this -#' is the parameter for the tail area of the intervals +#' \item `interval`: for `type`s of "survival" and "quantile", should +#' interval estimates be added, if available? Options are `"none"` +#' and `"confidence"`. +#' \item `level`: for `type`s of "conf_int", "pred_int", and "survival" +#' this is the parameter for the tail area of the intervals #' (e.g. confidence level for confidence intervals). #' Default value is 0.95. #' \item `std_error`: add the standard error of fit or prediction (on @@ -82,12 +85,10 @@ #' For censored regression: #' #' * `type = "time"` produces a column `.pred_time`. -#' * `type = "hazard"` results in a column `.pred_hazard`. -#' * `type = "survival"` results in a list column containing tibbles with a -#' `.pred_survival` column. -#' -#' For the last two types, the results are a nested tibble with an overall -#' column called `.pred` with sub-tibbles with the above format. +#' * `type = "hazard"` results in a list column `.pred` containing tibbles +#' with a column `.pred_hazard`. +#' * `type = "survival"` results in a list column `.pred` containing tibbles +#' with a `.pred_survival` column. #' #' In the case of Spark-based models, since table columns cannot #' contain dots, the same convention is used except 1) no dots @@ -98,6 +99,7 @@ #' `predict()` function will return the same structure as above but #' filled with missing values. This does not currently work for #' multivariate models. +#' #' @examples #' library(dplyr) #' @@ -309,7 +311,7 @@ check_pred_type_dots <- function(object, type, ...) { # ---------------------------------------------------------------------------- - other_args <- c("level", "std_error", "quantile", "time", "increasing") + other_args <- c("interval", "level", "std_error", "quantile", "time", "increasing") is_pred_arg <- names(the_dots) %in% other_args if (any(!is_pred_arg)) { bad_args <- names(the_dots)[!is_pred_arg] diff --git a/R/predict_quantile.R b/R/predict_quantile.R index 7965069e6..c2817e48b 100644 --- a/R/predict_quantile.R +++ b/R/predict_quantile.R @@ -1,13 +1,17 @@ #' @keywords internal #' @rdname other_predict -#' @param quant A vector of numbers between 0 and 1 for the quantile being +#' @param quantile A vector of numbers between 0 and 1 for the quantile being #' predicted. #' @inheritParams predict.model_fit #' @method predict_quantile model_fit #' @export predict_quantile.model_fit #' @export -predict_quantile.model_fit <- - function(object, new_data, quantile = (1:9)/10, ...) { +predict_quantile.model_fit <- function(object, + new_data, + quantile = (1:9)/10, + interval = "none", + level = 0.95, + ...) { check_spec_pred_type(object, "quantile") diff --git a/R/predict_survival.R b/R/predict_survival.R index eb87a08ad..fe58c62cb 100644 --- a/R/predict_survival.R +++ b/R/predict_survival.R @@ -5,7 +5,7 @@ #' @export predict_survival.model_fit #' @export predict_survival.model_fit <- - function(object, new_data, time, ...) { + function(object, new_data, time, interval = "none", level = 0.95, ...) { check_spec_pred_type(object, "survival") diff --git a/man/bart-internal.Rd b/man/bart-internal.Rd index 577d05dda..690f715f7 100644 --- a/man/bart-internal.Rd +++ b/man/bart-internal.Rd @@ -4,7 +4,7 @@ \alias{bart-internal} \alias{bartMachine_interval_calc} \alias{dbart_predict_calc} -\title{Model predictions} +\title{Developer functions for predictions via BART models} \usage{ bartMachine_interval_calc(new_data, obj, ci = TRUE, level = 0.95) @@ -26,112 +26,7 @@ appropriate value based on the model's mode.} \item{std_err}{Attach column for standard error of prediction or not.} } -\value{ -With the exception of \code{type = "raw"}, the results of -\code{predict.model_fit()} will be a tibble as many rows in the output -as there are rows in \code{new_data} and the column names will be -predictable. - -For numeric results with a single outcome, the tibble will have -a \code{.pred} column and \code{.pred_Yname} for multivariate results. - -For hard class predictions, the column is named \code{.pred_class} -and, when \code{type = "prob"}, the columns are \code{.pred_classlevel}. - -\code{type = "conf_int"} and \code{type = "pred_int"} return tibbles with -columns \code{.pred_lower} and \code{.pred_upper} with an attribute for -the confidence level. In the case where intervals can be -produces for class probabilities (or other non-scalar outputs), -the columns will be named \code{.pred_lower_classlevel} and so on. - -Quantile predictions return a tibble with a column \code{.pred}, which is -a list-column. Each list element contains a tibble with columns -\code{.pred} and \code{.quantile} (and perhaps other columns). - -Using \code{type = "raw"} with \code{predict.model_fit()} will return -the unadulterated results of the prediction function. - -For censored regression: -\itemize{ -\item \code{type = "time"} produces a column \code{.pred_time}. -\item \code{type = "hazard"} results in a column \code{.pred_hazard}. -\item \code{type = "survival"} results in a list column containing tibbles with a -\code{.pred_survival} column. -} - -For the last two types, the results are a nested tibble with an overall -column called \code{.pred} with sub-tibbles with the above format. - -In the case of Spark-based models, since table columns cannot -contain dots, the same convention is used except 1) no dots -appear in names and 2) vectors are never returned but -type-specific prediction functions. - -When the model fit failed and the error was captured, the -\code{predict()} function will return the same structure as above but -filled with missing values. This does not currently work for -multivariate models. -} \description{ -Apply a model to create different types of predictions. -\code{predict()} can be used for all types of models and uses the -"type" argument for more specificity. -} -\details{ -If "type" is not supplied to \code{predict()}, then a choice -is made: -\itemize{ -\item \code{type = "numeric"} for regression models, -\item \code{type = "class"} for classification, and -\item \code{type = "time"} for censored regression. -} - -\code{predict()} is designed to provide a tidy result (see "Value" -section below) in a tibble output format. -\subsection{Interval predictions}{ - -When using \code{type = "conf_int"} and \code{type = "pred_int"}, the options -\code{level} and \code{std_error} can be used. The latter is a logical for an -extra column of standard error values (if available). -} - -\subsection{Censored regression predictions}{ - -For censored regression, a numeric vector for \code{time} is required when -survival or hazard probabilities are requested. Also, when -\code{type = "linear_pred"}, censored regression models will by default be -formatted such that the linear predictor \emph{increases} with time. This may -have the opposite sign as what the underlying model's \code{predict()} method -produces. Set \code{increasing = FALSE} to suppress this behavior. -} -} -\examples{ -library(dplyr) - -lm_model <- - linear_reg() \%>\% - set_engine("lm") \%>\% - fit(mpg ~ ., data = mtcars \%>\% dplyr::slice(11:32)) - -pred_cars <- - mtcars \%>\% - dplyr::slice(1:10) \%>\% - dplyr::select(-mpg) - -predict(lm_model, pred_cars) - -predict( - lm_model, - pred_cars, - type = "conf_int", - level = 0.90 -) - -predict( - lm_model, - pred_cars, - type = "raw", - opts = list(type = "terms") -) +Developer functions for predictions via BART models } \keyword{internal} diff --git a/man/other_predict.Rd b/man/other_predict.Rd index dd02edb8b..ee49f70d1 100644 --- a/man/other_predict.Rd +++ b/man/other_predict.Rd @@ -35,9 +35,16 @@ predict_linear_pred(object, ...) predict_numeric(object, ...) -\method{predict_quantile}{model_fit}(object, new_data, quantile = (1:9)/10, ...) +\method{predict_quantile}{model_fit}( + object, + new_data, + quantile = (1:9)/10, + interval = "none", + level = 0.95, + ... +) -\method{predict_survival}{model_fit}(object, new_data, time, ...) +\method{predict_survival}{model_fit}(object, new_data, time, interval = "none", level = 0.95, ...) predict_survival(object, ...) @@ -55,8 +62,11 @@ function cannot be passed here (see \code{opts}). There are some \code{parsnip} related options that can be passed, depending on the value of \code{type}. Possible arguments are: \itemize{ -\item \code{level}: for \code{type}s of "conf_int" and "pred_int" this -is the parameter for the tail area of the intervals +\item \code{interval}: for \code{type}s of "survival" and "quantile", should +interval estimates be added, if available? Options are \code{"none"} +and \code{"confidence"}. +\item \code{level}: for \code{type}s of "conf_int", "pred_int", and "survival" +this is the parameter for the tail area of the intervals (e.g. confidence level for confidence intervals). Default value is 0.95. \item \code{std_error}: add the standard error of fit or prediction (on @@ -73,7 +83,7 @@ interval estimates.} \item{std_error}{A single logical for whether the standard error should be returned (assuming that the model can compute it).} -\item{quant}{A vector of numbers between 0 and 1 for the quantile being +\item{quantile}{A vector of numbers between 0 and 1 for the quantile being predicted.} } \description{ diff --git a/man/predict.model_fit.Rd b/man/predict.model_fit.Rd index ae645fa5e..4632bb205 100644 --- a/man/predict.model_fit.Rd +++ b/man/predict.model_fit.Rd @@ -32,8 +32,11 @@ function cannot be passed here (see \code{opts}). There are some \code{parsnip} related options that can be passed, depending on the value of \code{type}. Possible arguments are: \itemize{ -\item \code{level}: for \code{type}s of "conf_int" and "pred_int" this -is the parameter for the tail area of the intervals +\item \code{interval}: for \code{type}s of "survival" and "quantile", should +interval estimates be added, if available? Options are \code{"none"} +and \code{"confidence"}. +\item \code{level}: for \code{type}s of "conf_int", "pred_int", and "survival" +this is the parameter for the tail area of the intervals (e.g. confidence level for confidence intervals). Default value is 0.95. \item \code{std_error}: add the standard error of fit or prediction (on @@ -72,14 +75,12 @@ the unadulterated results of the prediction function. For censored regression: \itemize{ \item \code{type = "time"} produces a column \code{.pred_time}. -\item \code{type = "hazard"} results in a column \code{.pred_hazard}. -\item \code{type = "survival"} results in a list column containing tibbles with a -\code{.pred_survival} column. +\item \code{type = "hazard"} results in a list column \code{.pred} containing tibbles +with a column \code{.pred_hazard}. +\item \code{type = "survival"} results in a list column \code{.pred} containing tibbles +with a \code{.pred_survival} column. } -For the last two types, the results are a nested tibble with an overall -column called \code{.pred} with sub-tibbles with the above format. - In the case of Spark-based models, since table columns cannot contain dots, the same convention is used except 1) no dots appear in names and 2) vectors are never returned but