Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding new prediction types for Survnip #396

Merged
merged 6 commits into from
Mar 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
Package: parsnip
Version: 0.1.5.9000
Version: 0.1.5.9001
Title: A Common API to Modeling and Analysis Functions
Description: A common interface is provided to allow users to specify a model without having to remember the different argument names across different functions or computational engines (e.g. 'R', 'Spark', 'Stan', etc).
Authors@R: c(
Expand Down Expand Up @@ -32,7 +32,7 @@ Imports:
prettyunits,
vctrs (>= 0.2.0)
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.1.1.9000
RoxygenNote: 7.1.1.9001
Suggests:
testthat,
knitr,
Expand Down
10 changes: 10 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,17 @@ S3method(predict_classprob,"_lognet")
S3method(predict_classprob,"_multnet")
S3method(predict_classprob,model_fit)
S3method(predict_confint,model_fit)
S3method(predict_hazard,model_fit)
S3method(predict_linear_pred,model_fit)
S3method(predict_numeric,"_elnet")
S3method(predict_numeric,model_fit)
S3method(predict_quantile,model_fit)
S3method(predict_raw,"_elnet")
S3method(predict_raw,"_lognet")
S3method(predict_raw,"_multnet")
S3method(predict_raw,model_fit)
S3method(predict_survival,model_fit)
S3method(predict_time,model_fit)
S3method(print,boost_tree)
S3method(print,control_parsnip)
S3method(print,decision_tree)
Expand Down Expand Up @@ -156,11 +160,17 @@ export(predict.model_fit)
export(predict_class.model_fit)
export(predict_classprob.model_fit)
export(predict_confint.model_fit)
export(predict_hazard.model_fit)
export(predict_linear_pred)
export(predict_linear_pred.model_fit)
export(predict_numeric)
export(predict_numeric.model_fit)
export(predict_quantile.model_fit)
export(predict_raw)
export(predict_raw.model_fit)
export(predict_survival.model_fit)
export(predict_time)
export(predict_time.model_fit)
export(prepare_data)
export(rand_forest)
export(repair_call)
Expand Down
3 changes: 2 additions & 1 deletion R/aaa_models.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ parsnip$modes <- c("regression", "classification", "unknown")
# ------------------------------------------------------------------------------

pred_types <-
c("raw", "numeric", "class", "prob", "conf_int", "pred_int", "quantile")
c("raw", "numeric", "class", "prob", "conf_int", "pred_int", "quantile",
"time", "survival", "linear_pred", "hazard")

# ------------------------------------------------------------------------------

Expand Down
2 changes: 1 addition & 1 deletion R/misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ check_empty_ellipse <- function (...) {
terms
}

all_modes <- c("classification", "regression")
all_modes <- c("classification", "regression", "censored regression")


deparserizer <- function(x, limit = options()$width - 10) {
Expand Down
199 changes: 162 additions & 37 deletions R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
#' @param object An object of class `model_fit`
#' @param new_data A rectangular data object, such as a data frame.
#' @param type A single character value or `NULL`. Possible values
#' are "numeric", "class", "prob", "conf_int", "pred_int", "quantile",
#' or "raw". When `NULL`, `predict()` will choose an appropriate value
#' based on the model's mode.
#' are "numeric", "class", "prob", "conf_int", "pred_int", "quantile", "time",
#' "hazard", "survival", or "raw". When `NULL`, `predict()` will choose an
#' appropriate value based on the model's mode.
#' @param opts A list of optional arguments to the underlying
#' predict function that will be used when `type = "raw"`. The
#' list should not include options for the model object or the
Expand All @@ -28,20 +28,32 @@
#' and "pred_int". Default value is `FALSE`.
#' \item `quantile`: the quantile(s) for quantile regression
#' (not implemented yet)
#' \item `time`: the time(s) for hazard probability estimates
#' (not implemented yet)
#' \item `.time`: the time(s) for hazard and survival probability estimates.
#' }
#' @details If "type" is not supplied to `predict()`, then a choice
#' is made (`type = "numeric"` for regression models and
#' `type = "class"` for classification).
#' is made:
#'
#' * `type = "numeric"` for regression models,
#' * `type = "class"` for classification, and
#' * `type = "time"` for censored regression.
#'
#' `predict()` is designed to provide a tidy result (see "Value"
#' section below) in a tibble output format.
#'
#' ## Interval predictions
#'
#' When using `type = "conf_int"` and `type = "pred_int"`, the options
#' `level` and `std_error` can be used. The latter is a logical for an
#' extra column of standard error values (if available).
#'
#' ## Censored regression predictions
#'
#' For censored regression, a numeric vector for `.time` is required when
#' survival or hazard probabilities are requested. Also, when
#' `type = "linear_pred"`, censored regression models will be formatted such
#' that the linear predictor _increases_ with time. This may have the opposite
#' sign as what the underlying model's `predict()` method produces.
#'
#' @return With the exception of `type = "raw"`, the results of
#' `predict.model_fit()` will be a tibble with as many rows in the output
#' as there are rows in `new_data` and the column names will be
Expand All @@ -66,6 +78,15 @@
#' Using `type = "raw"` with `predict.model_fit()` will return
#' the unadulterated results of the prediction function.
#'
#' For censored regression:
#'
#' * `type = "time"` produces a column `.pred_time`.
#' * `type = "hazard"` results in a column `.pred_hazard`.
#' * `type = "survival"` results in a column `.pred_survival`.
#'
#' For the last two types, the results are a nested tibble with an overall
#' column called `.pred` with sub-tibbles with the above format.
#'
#' In the case of Spark-based models, since table columns cannot
#' contain dots, the same convention is used except 1) no dots
#' appear in names and 2) vectors are never returned but
Expand Down Expand Up @@ -108,10 +129,6 @@
#' @export predict.model_fit
#' @export
predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...) {
the_dots <- enquos(...)
if (any(names(the_dots) == "newdata"))
rlang::abort("Did you mean to use `new_data` instead of `newdata`?")

if (inherits(object$fit, "try-error")) {
rlang::warn("Model fit failed; cannot make predictions.")
return(NULL)
Expand All @@ -120,53 +137,54 @@ predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...)
check_installs(object$spec)
load_libs(object$spec, quiet = TRUE)

other_args <- c("level", "std_error", "quantile") # "time" for survival probs later
is_pred_arg <- names(the_dots) %in% other_args
if (any(!is_pred_arg)) {
bad_args <- names(the_dots)[!is_pred_arg]
bad_args <- paste0("`", bad_args, "`", collapse = ", ")
rlang::abort(
glue::glue(
"The ellipses are not used to pass args to the model function's ",
"predict function. These arguments cannot be used: {bad_args}",
)
)
}

type <- check_pred_type(object, type)
if (type != "raw" && length(opts) > 0)
if (type != "raw" && length(opts) > 0) {
rlang::warn("`opts` is only used with `type = 'raw'` and was ignored.")
}
check_pred_type_dots(type, ...)

res <- switch(
type,
numeric = predict_numeric(object = object, new_data = new_data, ...),
class = predict_class(object = object, new_data = new_data, ...),
prob = predict_classprob(object = object, new_data = new_data, ...),
conf_int = predict_confint(object = object, new_data = new_data, ...),
pred_int = predict_predint(object = object, new_data = new_data, ...),
quantile = predict_quantile(object = object, new_data = new_data, ...),
raw = predict_raw(object = object, new_data = new_data, opts = opts, ...),
numeric = predict_numeric(object = object, new_data = new_data, ...),
class = predict_class(object = object, new_data = new_data, ...),
prob = predict_classprob(object = object, new_data = new_data, ...),
conf_int = predict_confint(object = object, new_data = new_data, ...),
pred_int = predict_predint(object = object, new_data = new_data, ...),
quantile = predict_quantile(object = object, new_data = new_data, ...),
time = predict_time(object = object, new_data = new_data, ...),
survival = predict_survival(object = object, new_data = new_data, ...),
linear_pred = predict_linear_pred(object = object, new_data = new_data, ...),
hazard = predict_hazard(object = object, new_data = new_data, ...),
raw = predict_raw(object = object, new_data = new_data, opts = opts, ...),
rlang::abort(glue::glue("I don't know about type = '{type}'"))
)
if (!inherits(res, "tbl_spark")) {
res <- switch(
type,
numeric = format_num(res),
class = format_class(res),
prob = format_classprobs(res),
numeric = format_num(res),
class = format_class(res),
prob = format_classprobs(res),
time = format_time(res),
survival = format_survival(res),
hazard = format_hazard(res),
linear_pred = format_linear_pred(res),
res
)
}
res
}

# Prediction types that `check_pred_type()` only accepts when the model's mode
# is "censored regression".
surv_types <- c("time", "survival", "hazard")

#' @importFrom glue glue_collapse
check_pred_type <- function(object, type) {
check_pred_type <- function(object, type, ...) {
if (is.null(type)) {
type <-
switch(object$spec$mode,
regression = "numeric",
classification = "class",
rlang::abort("`type` should be 'regression' or 'classification'."))
"censored regression" = "time",
rlang::abort("`type` should be 'regression', 'censored regression', or 'classification'."))
}
if (!(type %in% pred_types))
rlang::abort(
Expand All @@ -181,6 +199,10 @@ check_pred_type <- function(object, type) {
rlang::abort("For class predictions, the object should be a classification model.")
if (type == "prob" & object$spec$mode != "classification")
rlang::abort("For probability predictions, the object should be a classification model.")
if (type %in% surv_types & object$spec$mode != "censored regression")
rlang::abort("For event time predictions, the object should be a censored regression.")

# TODO check for ... options when not the correct type
type
}

Expand Down Expand Up @@ -216,6 +238,61 @@ format_classprobs <- function(x) {
x
}

# Format event-time predictions as a tibble.
#
# @param x Raw predictions: a bare vector, or a matrix / data frame.
# @return A tibble. Data frames and objects with more than one column are
#   converted with `as_tibble()`; unless at least one column name already
#   starts with `.time`, every column name is given a `.time_` prefix.
#   Any other input is stored, with its names removed, in a single
#   `.pred_time` column.
format_time <- function(x) {
  # Both operands are length-one logicals, so use the scalar `||`.
  is_rectangular <- isTRUE(ncol(x) > 1) || is.data.frame(x)
  if (!is_rectangular) {
    return(tibble(.pred_time = unname(x)))
  }

  x <- as_tibble(x, .name_repair = "minimal")
  # Leave the names alone when any of them already carries the prefix.
  if (!any(grepl("^\\.time", names(x)))) {
    names(x) <- paste0(".time_", names(x))
  }
  x
}

# Format survival-probability predictions as a tibble.
#
# @param x Raw predictions: a bare vector, or a matrix / data frame.
# @return A tibble. Data frames and objects with more than one column are
#   converted with `as_tibble()`; unless at least one column name already
#   starts with `.time`, every column name is given a `.time_` prefix.
#   Any other input is stored, with its names removed, in a single
#   `.pred_survival` column.
format_survival <- function(x) {
  # Both operands are length-one logicals, so use the scalar `||`.
  is_rectangular <- isTRUE(ncol(x) > 1) || is.data.frame(x)
  if (!is_rectangular) {
    return(tibble(.pred_survival = unname(x)))
  }

  x <- as_tibble(x, .name_repair = "minimal")
  # Leave the names alone when any of them already carries the prefix.
  if (!any(grepl("^\\.time", names(x)))) {
    names(x) <- paste0(".time_", names(x))
  }
  x
}

# Format linear-predictor predictions as a tibble.
#
# @param x Raw predictions: a bare vector, a matrix / data frame, or an
#   object inheriting from `tbl_spark`.
# @return `x` unchanged when it inherits from `tbl_spark`. Otherwise a
#   tibble: data frames and objects with more than one column are converted
#   with `as_tibble()`, and unless at least one column name already starts
#   with `.time`, every column name is given a `.time_` prefix. Any other
#   input is stored, with its names removed, in a single `.pred_linear_pred`
#   column.
format_linear_pred <- function(x) {
  if (inherits(x, "tbl_spark")) {
    return(x)
  }

  # Both operands are length-one logicals, so use the scalar `||`.
  is_rectangular <- isTRUE(ncol(x) > 1) || is.data.frame(x)
  if (!is_rectangular) {
    return(tibble(.pred_linear_pred = unname(x)))
  }

  x <- as_tibble(x, .name_repair = "minimal")
  # Leave the names alone when any of them already carries the prefix.
  if (!any(grepl("^\\.time", names(x)))) {
    names(x) <- paste0(".time_", names(x))
  }
  x
}

# Format hazard predictions as a tibble.
#
# @param x Raw predictions: a bare vector, or a matrix / data frame.
# @return A tibble. Data frames and objects with more than one column are
#   converted with `as_tibble()`; unless at least one column name already
#   starts with `.time`, every column name is given a `.time_` prefix.
#   Any other input is stored, with its names removed, in a single
#   `.pred_hazard` column.
format_hazard <- function(x) {
  # Both operands are length-one logicals, so use the scalar `||`.
  is_rectangular <- isTRUE(ncol(x) > 1) || is.data.frame(x)
  if (!is_rectangular) {
    return(tibble(.pred_hazard = unname(x)))
  }

  x <- as_tibble(x, .name_repair = "minimal")
  # Leave the names alone when any of them already carries the prefix.
  if (!any(grepl("^\\.time", names(x)))) {
    names(x) <- paste0(".time_", names(x))
  }
  x
}

make_pred_call <- function(x) {
if ("pkg" %in% names(x$func))
cl <-
Expand All @@ -226,6 +303,54 @@ make_pred_call <- function(x) {
cl
}

# Validate the extra arguments given to `predict()` through `...`.
#
# Checks, in order:
#   1. `newdata` was not supplied (the argument is called `new_data`).
#   2. Every named argument is one of `level`, `std_error`, `quantile`,
#      or `.time`.
#   3. `.time` is supplied if and only if `type` is "survival" or "hazard".
#
# @param type A single string giving the prediction type.
# @param ... The extra arguments that were passed to `predict()`.
# @return `TRUE`, invisibly, when every check passes; otherwise an error is
#   signalled with `rlang::abort()`.
check_pred_type_dots <- function(type, ...) {
  the_dots <- list(...)
  nms <- names(the_dots)

  # ----------------------------------------------------------------------------

  if (any(nms == "newdata")) {
    rlang::abort("Did you mean to use `new_data` instead of `newdata`?")
  }

  # ----------------------------------------------------------------------------

  other_args <- c("level", "std_error", "quantile", ".time")
  is_pred_arg <- nms %in% other_args
  if (!all(is_pred_arg)) {
    bad_args <- paste0("`", nms[!is_pred_arg], "`", collapse = ", ")
    # No trailing comma here: it would hand `glue()` an empty argument.
    rlang::abort(
      glue::glue(
        "The ellipses are not used to pass args to the model function's ",
        "predict function. These arguments cannot be used: {bad_args}"
      )
    )
  }

  # ----------------------------------------------------------------------------

  time_types <- c("survival", "hazard")
  has_time <- ".time" %in% nms
  needs_time <- type %in% time_types

  # places where .time should not be given
  if (has_time && !needs_time) {
    rlang::abort(
      paste(
        ".time should only be passed to `predict()` when 'type' is one of:",
        paste0("'", time_types, "'", collapse = ", ")
      )
    )
  }
  # when .time should be passed
  if (!has_time && needs_time) {
    rlang::abort(
      paste(
        "When using 'type' values of 'survival' or 'hazard',",
        "a numeric vector '.time' should also be given."
      )
    )
  }
  invisible(TRUE)
}


#' Prepare data based on parsnip encoding information
#' @param object A parsnip model object
#' @param new_data A data frame
Expand Down