diff --git a/R/rand_forest-aorsf.R b/R/rand_forest-aorsf.R index cf29708a..28a26345 100644 --- a/R/rand_forest-aorsf.R +++ b/R/rand_forest-aorsf.R @@ -1,7 +1,8 @@ #' Internal helper function for aorsf objects #' @param object A model object from `aorsf::orsf()`. #' @param new_data A data frame to be predicted. -#' @param time A vector of times to predict the survival probability. +#' @param eval_time A vector of times to predict the survival probability. +#' @param time Deprecated. A vector of times to predict the survival probability. #' @return A tibble with a list column of nested tibbles. #' @export #' @keywords internal @@ -9,22 +10,31 @@ #' @examples #' library(aorsf) #' aorsf <- orsf(na.omit(lung), Surv(time, status) ~ age + ph.ecog, n_tree = 10) -#' preds <- survival_prob_orsf(aorsf, lung[1:3, ], time = c(250, 100)) -survival_prob_orsf <- function(object, new_data, time) { +#' preds <- survival_prob_orsf(aorsf, lung[1:3, ], eval_time = c(250, 100)) +survival_prob_orsf <- function(object, new_data, eval_time, time = deprecated()) { + if (lifecycle::is_present(time)) { + lifecycle::deprecate_warn( + "0.1.1.9002", + "survival_prob_orsf(time)", + "survival_prob_orsf(eval_time)" + ) + eval_time <- time + } + # This is not just a `post` hook in `set_pred()` because parsnip adds the - # argument `time` to the prediction call and `aorsf::predict.orsf_fit()` - # expects empty dots, i.e. no `time` argument. + # argument `eval_time` to the prediction call and `aorsf::predict.orsf_fit()` + # expects empty dots, i.e. no `eval_time` argument. res <- predict( object, new_data = new_data, - pred_horizon = time, + pred_horizon = eval_time, pred_type = "surv", na_action = "pass", boundary_checks = FALSE ) - res <- matrix_to_nested_tibbles_survival(res, time) + res <- matrix_to_nested_tibbles_survival(res, eval_time) # return a tibble tibble(.pred = res) diff --git a/R/rand_forest-data.R b/R/rand_forest-data.R index 0da93408..16000176 100644 --- a/R/rand_forest-data.R +++ b/R/rand_forest-data.R @@ -197,7 +197,8 @@ make_rand_forest_aorsf <- function() { func = c(pkg = "censored", fun = "survival_prob_orsf"), args = list( object = rlang::expr(object$fit), - new_data = rlang::expr(new_data) + new_data = rlang::expr(new_data), + eval_time = rlang::expr(eval_time) ) ) ) diff --git a/man/aorsf_internal.Rd b/man/aorsf_internal.Rd index b3e6b0ad..9b09298d 100644 --- a/man/aorsf_internal.Rd +++ b/man/aorsf_internal.Rd @@ -5,14 +5,16 @@ \alias{survival_prob_orsf} \title{Internal helper function for aorsf objects} \usage{ -survival_prob_orsf(object, new_data, time) +survival_prob_orsf(object, new_data, eval_time, time = deprecated()) } \arguments{ \item{object}{A model object from \code{aorsf::orsf()}.} \item{new_data}{A data frame to be predicted.} -\item{time}{A vector of times to predict the survival probability.} +\item{eval_time}{A vector of times to predict the survival probability.} + +\item{time}{Deprecated. A vector of times to predict the survival probability.} } \value{ A tibble with a list column of nested tibbles. @@ -23,6 +25,6 @@ Internal helper function for aorsf objects \examples{ library(aorsf) aorsf <- orsf(na.omit(lung), Surv(time, status) ~ age + ph.ecog, n_tree = 10) -preds <- survival_prob_orsf(aorsf, lung[1:3, ], time = c(250, 100)) +preds <- survival_prob_orsf(aorsf, lung[1:3, ], eval_time = c(250, 100)) } \keyword{internal} diff --git a/tests/testthat/test-rand_forest-aorsf.R b/tests/testthat/test-rand_forest-aorsf.R index 6ef7fa76..ce832c09 100644 --- a/tests/testthat/test-rand_forest-aorsf.R +++ b/tests/testthat/test-rand_forest-aorsf.R @@ -52,10 +52,10 @@ test_that("survival predictions", { expect_error( predict(f_fit, lung_orsf, type = "survival"), - "When using 'type' values of 'survival' or 'hazard' are given" + "When using `type` values of 'survival' or 'hazard', a numeric vector" ) - f_pred <- predict(f_fit, lung, type = "survival", time = c(100, 500, 1200)) + f_pred <- predict(f_fit, lung, type = "survival", eval_time = c(100, 500, 1200)) expect_s3_class(f_pred, "tbl_df") expect_equal(names(f_pred), ".pred") @@ -66,7 +66,7 @@ test_that("survival predictions", { ) cf_names <- - c(".time", ".pred_survival") + c(".eval_time", ".pred_survival") expect_true( all(purrr::map_lgl( @@ -77,7 +77,7 @@ test_that("survival predictions", { # correct prediction times in object expect_equal( - tidyr::unnest(f_pred, cols = c(.pred))$.time, + tidyr::unnest(f_pred, cols = c(.pred))$.eval_time, rep(c(100, 500, 1200), nrow(lung)) ) @@ -103,7 +103,7 @@ test_that("survival predictions", { ) # equal predictions with multiple times and one testing observation - f_pred <- predict(f_fit, lung[1, ], type = "survival", time = c(100, 500, 1200)) + f_pred <- predict(f_fit, lung[1, ], type = "survival", eval_time = c(100, 500, 1200)) new_km <- predict( exp_f_fit, @@ -124,7 +124,7 @@ test_that("survival predictions", { ) # equal predictions with one time and multiple testing observation - f_pred <- predict(f_fit, lung, type = "survival", time = 306) + f_pred <- predict(f_fit, lung, type = "survival", eval_time = 306) new_km <- predict( exp_f_fit, @@ -144,7 +144,7 @@ test_that("survival predictions", { ) # equal predictions with one time and one testing observation - f_pred <- predict(f_fit, lung[1, ], type = "survival", time = 306) + f_pred <- predict(f_fit, lung[1, ], type = "survival", eval_time = 306) new_km <- predict( exp_f_fit, @@ -195,13 +195,13 @@ test_that("`fix_xy()` works", { f_fit, new_data = lung_pred, type = "survival", - time = c(100, 200) + eval_time = c(100, 200) ) xy_pred_survival <- predict( xy_fit, new_data = lung_pred, type = "survival", - time = c(100, 200) + eval_time = c(100, 200) ) expect_equal(f_pred_survival, xy_pred_survival) })