Skip to content

Commit

Permalink
Merge pull request #396 from EmilHvitfeldt/survnip-integration
Browse files Browse the repository at this point in the history
adding new prediction types for Survnip
  • Loading branch information
topepo committed Mar 1, 2021
2 parents 154c1ab + c0c0191 commit 963ba95
Show file tree
Hide file tree
Showing 12 changed files with 419 additions and 57 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
Package: parsnip
Version: 0.1.5.9000
Version: 0.1.5.9001
Title: A Common API to Modeling and Analysis Functions
Description: A common interface is provided to allow users to specify a model without having to remember the different argument names across different functions or computational engines (e.g. 'R', 'Spark', 'Stan', etc).
Authors@R: c(
Expand Down
10 changes: 10 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,17 @@ S3method(predict_classprob,"_lognet")
S3method(predict_classprob,"_multnet")
S3method(predict_classprob,model_fit)
S3method(predict_confint,model_fit)
S3method(predict_hazard,model_fit)
S3method(predict_linear_pred,model_fit)
S3method(predict_numeric,"_elnet")
S3method(predict_numeric,model_fit)
S3method(predict_quantile,model_fit)
S3method(predict_raw,"_elnet")
S3method(predict_raw,"_lognet")
S3method(predict_raw,"_multnet")
S3method(predict_raw,model_fit)
S3method(predict_survival,model_fit)
S3method(predict_time,model_fit)
S3method(print,boost_tree)
S3method(print,control_parsnip)
S3method(print,decision_tree)
Expand Down Expand Up @@ -156,11 +160,17 @@ export(predict.model_fit)
export(predict_class.model_fit)
export(predict_classprob.model_fit)
export(predict_confint.model_fit)
export(predict_hazard.model_fit)
export(predict_linear_pred)
export(predict_linear_pred.model_fit)
export(predict_numeric)
export(predict_numeric.model_fit)
export(predict_quantile.model_fit)
export(predict_raw)
export(predict_raw.model_fit)
export(predict_survival.model_fit)
export(predict_time)
export(predict_time.model_fit)
export(prepare_data)
export(rand_forest)
export(repair_call)
Expand Down
3 changes: 2 additions & 1 deletion R/aaa_models.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ parsnip$modes <- c("regression", "classification", "unknown")
# ------------------------------------------------------------------------------

pred_types <-
c("raw", "numeric", "class", "prob", "conf_int", "pred_int", "quantile")
c("raw", "numeric", "class", "prob", "conf_int", "pred_int", "quantile",
"time", "survival", "linear_pred", "hazard")

# ------------------------------------------------------------------------------

Expand Down
2 changes: 1 addition & 1 deletion R/misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ check_empty_ellipse <- function (...) {
terms
}

all_modes <- c("classification", "regression")
all_modes <- c("classification", "regression", "censored regression")


deparserizer <- function(x, limit = options()$width - 10) {
Expand Down
199 changes: 162 additions & 37 deletions R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
#' @param object An object of class `model_fit`
#' @param new_data A rectangular data object, such as a data frame.
#' @param type A single character value or `NULL`. Possible values
#' are "numeric", "class", "prob", "conf_int", "pred_int", "quantile",
#' or "raw". When `NULL`, `predict()` will choose an appropriate value
#' based on the model's mode.
#' are "numeric", "class", "prob", "conf_int", "pred_int", "quantile", "time",
#' "hazard", "survival", or "raw". When `NULL`, `predict()` will choose an
#' appropriate value based on the model's mode.
#' @param opts A list of optional arguments to the underlying
#' predict function that will be used when `type = "raw"`. The
#' list should not include options for the model object or the
Expand All @@ -28,20 +28,32 @@
#' and "pred_int". Default value is `FALSE`.
#' \item `quantile`: the quantile(s) for quantile regression
#' (not implemented yet)
#' \item `time`: the time(s) for hazard probability estimates
#' (not implemented yet)
#' \item `.time`: the time(s) for hazard and survival probability estimates.
#' }
#' @details If "type" is not supplied to `predict()`, then a choice
#' is made (`type = "numeric"` for regression models and
#' `type = "class"` for classification).
#' is made:
#'
#' * `type = "numeric"` for regression models,
#' * `type = "class"` for classification, and
#' * `type = "time"` for censored regression.
#'
#' `predict()` is designed to provide a tidy result (see "Value"
#' section below) in a tibble output format.
#'
#' ## Interval predictions
#'
#' When using `type = "conf_int"` and `type = "pred_int"`, the options
#' `level` and `std_error` can be used. The latter is a logical for an
#' extra column of standard error values (if available).
#'
#' ## Censored regression predictions
#'
#' For censored regression, a numeric vector for `.time` is required when
#' survival or hazard probabilities are requested. Also, when
#' `type = "linear_pred"`, censored regression models will be formatted such
#' that the linear predictor _increases_ with time. This may have the opposite
#' sign as what the underlying model's `predict()` method produces.
#'
#' @return With the exception of `type = "raw"`, the results of
#' `predict.model_fit()` will be a tibble with as many rows in the output
#' as there are rows in `new_data` and the column names will be
Expand All @@ -66,6 +78,15 @@
#' Using `type = "raw"` with `predict.model_fit()` will return
#' the unadulterated results of the prediction function.
#'
#' For censored regression:
#'
#' * `type = "time"` produces a column `.pred_time`.
#' * `type = "hazard"` results in a column `.pred_hazard`.
#' * `type = "survival"` results in a column `.pred_survival`.
#'
#' For the last two types, the results are a nested tibble with an overall
#' column called `.pred` with sub-tibbles with the above format.
#'
#' In the case of Spark-based models, since table columns cannot
#' contain dots, the same convention is used except 1) no dots
#' appear in names and 2) vectors are never returned but
Expand Down Expand Up @@ -108,10 +129,6 @@
#' @export predict.model_fit
#' @export
predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...) {
the_dots <- enquos(...)
if (any(names(the_dots) == "newdata"))
rlang::abort("Did you mean to use `new_data` instead of `newdata`?")

if (inherits(object$fit, "try-error")) {
rlang::warn("Model fit failed; cannot make predictions.")
return(NULL)
Expand All @@ -120,53 +137,54 @@ predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...)
check_installs(object$spec)
load_libs(object$spec, quiet = TRUE)

other_args <- c("level", "std_error", "quantile") # "time" for survival probs later
is_pred_arg <- names(the_dots) %in% other_args
if (any(!is_pred_arg)) {
bad_args <- names(the_dots)[!is_pred_arg]
bad_args <- paste0("`", bad_args, "`", collapse = ", ")
rlang::abort(
glue::glue(
"The ellipses are not used to pass args to the model function's ",
"predict function. These arguments cannot be used: {bad_args}",
)
)
}

type <- check_pred_type(object, type)
if (type != "raw" && length(opts) > 0)
if (type != "raw" && length(opts) > 0) {
rlang::warn("`opts` is only used with `type = 'raw'` and was ignored.")
}
check_pred_type_dots(type, ...)

res <- switch(
type,
numeric = predict_numeric(object = object, new_data = new_data, ...),
class = predict_class(object = object, new_data = new_data, ...),
prob = predict_classprob(object = object, new_data = new_data, ...),
conf_int = predict_confint(object = object, new_data = new_data, ...),
pred_int = predict_predint(object = object, new_data = new_data, ...),
quantile = predict_quantile(object = object, new_data = new_data, ...),
raw = predict_raw(object = object, new_data = new_data, opts = opts, ...),
numeric = predict_numeric(object = object, new_data = new_data, ...),
class = predict_class(object = object, new_data = new_data, ...),
prob = predict_classprob(object = object, new_data = new_data, ...),
conf_int = predict_confint(object = object, new_data = new_data, ...),
pred_int = predict_predint(object = object, new_data = new_data, ...),
quantile = predict_quantile(object = object, new_data = new_data, ...),
time = predict_time(object = object, new_data = new_data, ...),
survival = predict_survival(object = object, new_data = new_data, ...),
linear_pred = predict_linear_pred(object = object, new_data = new_data, ...),
hazard = predict_hazard(object = object, new_data = new_data, ...),
raw = predict_raw(object = object, new_data = new_data, opts = opts, ...),
rlang::abort(glue::glue("I don't know about type = '{type}'"))
)
if (!inherits(res, "tbl_spark")) {
res <- switch(
type,
numeric = format_num(res),
class = format_class(res),
prob = format_classprobs(res),
numeric = format_num(res),
class = format_class(res),
prob = format_classprobs(res),
time = format_time(res),
survival = format_survival(res),
hazard = format_hazard(res),
linear_pred = format_linear_pred(res),
res
)
}
res
}

surv_types <- c("time", "survival", "hazard")

#' @importFrom glue glue_collapse
check_pred_type <- function(object, type) {
check_pred_type <- function(object, type, ...) {
if (is.null(type)) {
type <-
switch(object$spec$mode,
regression = "numeric",
classification = "class",
rlang::abort("`type` should be 'regression' or 'classification'."))
"censored regression" = "time",
rlang::abort("`type` should be 'regression', 'censored regression', or 'classification'."))
}
if (!(type %in% pred_types))
rlang::abort(
Expand All @@ -181,6 +199,10 @@ check_pred_type <- function(object, type) {
rlang::abort("For class predictions, the object should be a classification model.")
if (type == "prob" & object$spec$mode != "classification")
rlang::abort("For probability predictions, the object should be a classification model.")
if (type %in% surv_types & object$spec$mode != "censored regression")
rlang::abort("For event time predictions, the object should be a censored regression.")

# TODO check for ... options when not the correct type
type
}

Expand Down Expand Up @@ -216,6 +238,61 @@ format_classprobs <- function(x) {
x
}

# Format event-time predictions as a tibble.
#
# `x` is the value returned by the engine's prediction code.
# * If `x` is a data frame or has more than one column, it is converted with
#   `as_tibble()`. When none of its column names start with ".time", every
#   column name is prefixed with ".time_".
# * Otherwise `x` is stored, with its names dropped, in a single column
#   called `.pred_time`.
#
# Returns a tibble.
format_time <- function(x) {
  # Both operands are length-one logicals, so use the scalar `||` in `if ()`.
  if (isTRUE(ncol(x) > 1) || is.data.frame(x)) {
    x <- as_tibble(x, .name_repair = "minimal")
    if (!any(grepl("^\\.time", names(x)))) {
      names(x) <- paste0(".time_", names(x))
    }
  } else {
    x <- tibble(.pred_time = unname(x))
  }

  x
}

# Format survival-probability predictions as a tibble.
#
# `x` is the value returned by the engine's prediction code.
# * If `x` is a data frame or has more than one column, it is converted with
#   `as_tibble()`. When none of its column names start with ".time", every
#   column name is prefixed with ".time_".
# * Otherwise `x` is stored, with its names dropped, in a single column
#   called `.pred_survival`.
#
# Returns a tibble.
format_survival <- function(x) {
  # Both operands are length-one logicals, so use the scalar `||` in `if ()`.
  if (isTRUE(ncol(x) > 1) || is.data.frame(x)) {
    x <- as_tibble(x, .name_repair = "minimal")
    if (!any(grepl("^\\.time", names(x)))) {
      names(x) <- paste0(".time_", names(x))
    }
  } else {
    x <- tibble(.pred_survival = unname(x))
  }

  x
}

# Format linear-predictor predictions as a tibble.
#
# `x` is the value returned by the engine's prediction code.
# * Objects inheriting from "tbl_spark" are returned untouched.
# * If `x` is a data frame or has more than one column, it is converted with
#   `as_tibble()`. When none of its column names start with ".time", every
#   column name is prefixed with ".time_".
# * Otherwise `x` is stored, with its names dropped, in a single column
#   called `.pred_linear_pred`.
#
# Returns a tibble (or the original Spark table).
format_linear_pred <- function(x) {
  if (inherits(x, "tbl_spark")) {
    return(x)
  }

  # Both operands are length-one logicals, so use the scalar `||` in `if ()`.
  if (isTRUE(ncol(x) > 1) || is.data.frame(x)) {
    x <- as_tibble(x, .name_repair = "minimal")
    if (!any(grepl("^\\.time", names(x)))) {
      names(x) <- paste0(".time_", names(x))
    }
  } else {
    x <- tibble(.pred_linear_pred = unname(x))
  }

  x
}

# Format hazard predictions as a tibble.
#
# `x` is the value returned by the engine's prediction code.
# * If `x` is a data frame or has more than one column, it is converted with
#   `as_tibble()`. When none of its column names start with ".time", every
#   column name is prefixed with ".time_".
# * Otherwise `x` is stored, with its names dropped, in a single column
#   called `.pred_hazard`.
#
# Returns a tibble.
format_hazard <- function(x) {
  # Both operands are length-one logicals, so use the scalar `||` in `if ()`.
  if (isTRUE(ncol(x) > 1) || is.data.frame(x)) {
    x <- as_tibble(x, .name_repair = "minimal")
    if (!any(grepl("^\\.time", names(x)))) {
      names(x) <- paste0(".time_", names(x))
    }
  } else {
    x <- tibble(.pred_hazard = unname(x))
  }

  x
}

make_pred_call <- function(x) {
if ("pkg" %in% names(x$func))
cl <-
Expand All @@ -226,6 +303,54 @@ make_pred_call <- function(x) {
cl
}

# Validate the arguments passed through `...` for a given prediction type.
#
# type: a single character string naming the prediction type.
# ...:  the extra arguments the caller supplied.
#
# Aborts when
# * an argument is named `newdata` (a likely typo for `new_data`),
# * any argument name is not one of `level`, `std_error`, `quantile`, `.time`,
# * `.time` is supplied but `type` is not "survival" or "hazard", or
# * `.time` is missing although `type` is "survival" or "hazard".
#
# Returns `TRUE` invisibly when all checks pass.
check_pred_type_dots <- function(type, ...) {
  the_dots <- list(...)
  # NULL when no argument in `...` is named; the comparisons below then yield
  # zero-length logicals and `any()` is FALSE.
  nms <- names(the_dots)

  # ----------------------------------------------------------------------------
  # checked first so the user gets the specific hint instead of the generic
  # "cannot be used" message below

  if (any(nms == "newdata")) {
    rlang::abort("Did you mean to use `new_data` instead of `newdata`?")
  }

  # ----------------------------------------------------------------------------

  other_args <- c("level", "std_error", "quantile", ".time")
  is_pred_arg <- nms %in% other_args
  if (any(!is_pred_arg)) {
    bad_args <- nms[!is_pred_arg]
    bad_args <- paste0("`", bad_args, "`", collapse = ", ")
    rlang::abort(
      glue::glue(
        "The ellipses are not used to pass args to the model function's ",
        "predict function. These arguments cannot be used: {bad_args}"
      )
    )
  }

  # ----------------------------------------------------------------------------
  # `.time` and prediction types that need it must be given together

  time_types <- c("survival", "hazard")
  has_time <- any(nms == ".time")
  needs_time <- type %in% time_types

  # places where .time should not be given
  if (has_time && !needs_time) {
    rlang::abort(
      paste(
        ".time should only be passed to `predict()` when 'type' is one of:",
        paste0("'", time_types, "'", collapse = ", ")
      )
    )
  }
  # when .time should be passed
  if (!has_time && needs_time) {
    rlang::abort(
      paste(
        "When using 'type' values of 'survival' or 'hazard' are given,",
        "a numeric vector '.time' should also be given."
      )
    )
  }
  invisible(TRUE)
}


#' Prepare data based on parsnip encoding information
#' @param object A parsnip model object
#' @param new_data A data frame
Expand Down
43 changes: 43 additions & 0 deletions R/predict_hazard.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#' @keywords internal
#' @rdname other_predict
#' @inheritParams predict.model_fit
#' @method predict_hazard model_fit
#' @export predict_hazard.model_fit
#' @export
predict_hazard.model_fit <- function(object, new_data, .time, ...) {
  # The engine must have registered a hazard prediction module.
  if (is.null(object$spec$method$pred$hazard)) {
    rlang::abort("No hazard prediction method defined for this engine.")
  }

  # A failed fit cannot predict; warn and hand back NULL.
  fit_failed <- inherits(object$fit, "try-error")
  if (fit_failed) {
    rlang::warn("Model fit failed; cannot make predictions.")
    return(NULL)
  }

  # NOTE: `object` and `new_data` keep their names on purpose; the prediction
  # call built below is evaluated by `eval_tidy()` from within this function.
  new_data <- prepare_data(object, new_data)

  # Optional engine-specific pre-processing of the data.
  pre_fn <- object$spec$method$pred$hazard$pre
  if (!is.null(pre_fn)) {
    new_data <- pre_fn(new_data, object)
  }

  # Store `.time` in the call arguments; the modified `object` is also what
  # the post-processor receives.
  object$spec$method$pred$hazard$args$.time <- .time
  hazard_call <- make_pred_call(object$spec$method$pred$hazard)

  res <- eval_tidy(hazard_call)

  # Optional engine-specific post-processing of the raw predictions.
  post_fn <- object$spec$method$pred$hazard$post
  if (is.null(post_fn)) {
    return(res)
  }

  post_fn(res, object)
}

# @export
# @keywords internal
# @rdname other_predict
# @inheritParams predict.model_fit
# S3 generic for hazard predictions; dispatches on the class of `object`.
predict_hazard <- function(object, ...) {
  UseMethod("predict_hazard")
}

0 comments on commit 963ba95

Please sign in to comment.