Skip to content

add argument interval for survival/quantile predictions #615

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Dec 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@

## Other Changes

* When the xy interface is used and the underlying model expects to use a matrix, a better warning is issued when predictors contain non-numeric columns (including dates).
* When the xy interface is used and the underlying model expects to use a matrix, a better warning is issued when predictors contain non-numeric columns (including dates).

* The fit time is only calculated when the `verbosity` argument of `control_parsnip()` is 2L or greater. Also, the call to `system.time()` now uses `gcFirst = FALSE`. (#611)
* The fit time is only calculated when the `verbosity` argument of `control_parsnip()` is 2L or greater. Also, the call to `system.time()` now uses `gcFirst = FALSE`. (#611)

* Argument `interval` was added for prediction: For types "survival" and "quantile", estimates for the confidence or prediction interval can be added if available (#615).

# parsnip 0.1.7

Expand Down
3 changes: 2 additions & 1 deletion R/bart.R
Original file line number Diff line number Diff line change
Expand Up @@ -162,10 +162,11 @@ update.bart <-
}


#' Developer functions for predictions via BART models
#' @export
#' @keywords internal
#' @name bart-internal
#' @inherit predict.model_fit
#' @inheritParams predict.model_fit
#' @param obj A parsnip object.
#' @param ci Confidence (TRUE) or prediction interval (FALSE)
#' @param level Confidence level.
Expand Down
20 changes: 11 additions & 9 deletions R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,11 @@
#' `parsnip` related options that can be passed, depending on the
#' value of `type`. Possible arguments are:
#' \itemize{
#' \item `level`: for `type`s of "conf_int" and "pred_int" this
#' is the parameter for the tail area of the intervals
#' \item `interval`: for `type`s of "survival" and "quantile", should
#' interval estimates be added, if available? Options are `"none"`
#' and `"confidence"`.
#' \item `level`: for `type`s of "conf_int", "pred_int", and "survival"
#' this is the parameter for the tail area of the intervals
#' (e.g. confidence level for confidence intervals).
#' Default value is 0.95.
#' \item `std_error`: add the standard error of fit or prediction (on
Expand Down Expand Up @@ -82,12 +85,10 @@
#' For censored regression:
#'
#' * `type = "time"` produces a column `.pred_time`.
#' * `type = "hazard"` results in a column `.pred_hazard`.
#' * `type = "survival"` results in a list column containing tibbles with a
#' `.pred_survival` column.
#'
#' For the last two types, the results are a nested tibble with an overall
#' column called `.pred` with sub-tibbles with the above format.
#' * `type = "hazard"` results in a list column `.pred` containing tibbles
#' with a column `.pred_hazard`.
#' * `type = "survival"` results in a list column `.pred` containing tibbles
#' with a `.pred_survival` column.
#'
#' In the case of Spark-based models, since table columns cannot
#' contain dots, the same convention is used except 1) no dots
Expand All @@ -98,6 +99,7 @@
#' `predict()` function will return the same structure as above but
#' filled with missing values. This does not currently work for
#' multivariate models.
#'
#' @examples
#' library(dplyr)
#'
Expand Down Expand Up @@ -309,7 +311,7 @@ check_pred_type_dots <- function(object, type, ...) {

# ----------------------------------------------------------------------------

other_args <- c("level", "std_error", "quantile", "time", "increasing")
other_args <- c("interval", "level", "std_error", "quantile", "time", "increasing")
is_pred_arg <- names(the_dots) %in% other_args
if (any(!is_pred_arg)) {
bad_args <- names(the_dots)[!is_pred_arg]
Expand Down
10 changes: 7 additions & 3 deletions R/predict_quantile.R
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
#' @keywords internal
#' @rdname other_predict
#' @param quant A vector of numbers between 0 and 1 for the quantile being
#' @param quantile A vector of numbers between 0 and 1 for the quantile being
#' predicted.
#' @inheritParams predict.model_fit
#' @method predict_quantile model_fit
#' @export predict_quantile.model_fit
#' @export
predict_quantile.model_fit <-
function(object, new_data, quantile = (1:9)/10, ...) {
predict_quantile.model_fit <- function(object,
new_data,
quantile = (1:9)/10,
interval = "none",
level = 0.95,
...) {

check_spec_pred_type(object, "quantile")

Expand Down
2 changes: 1 addition & 1 deletion R/predict_survival.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#' @export predict_survival.model_fit
#' @export
predict_survival.model_fit <-
function(object, new_data, time, ...) {
function(object, new_data, time, interval = "none", level = 0.95, ...) {

check_spec_pred_type(object, "survival")

Expand Down
109 changes: 2 additions & 107 deletions man/bart-internal.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 15 additions & 5 deletions man/other_predict.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 9 additions & 8 deletions man/predict.model_fit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.