Skip to content

Commit

Permalink
rand_forest() - aorsf
Browse files Browse the repository at this point in the history
`time` -> `eval_time`
  • Loading branch information
hfrick committed Mar 22, 2023
1 parent d654a7d commit 14491d8
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 20 deletions.
24 changes: 17 additions & 7 deletions R/rand_forest-aorsf.R
Original file line number Diff line number Diff line change
@@ -1,30 +1,40 @@
#' Internal helper function for aorsf objects
#' @param object A model object from `aorsf::orsf()`.
#' @param new_data A data frame to be predicted.
#' @param time A vector of times to predict the survival probability.
#' @param eval_time A vector of times to predict the survival probability.
#' @param time Deprecated. A vector of times to predict the survival probability.
#' @return A tibble with a list column of nested tibbles.
#' @export
#' @keywords internal
#' @name aorsf_internal
#' @examples
#' library(aorsf)
#' aorsf <- orsf(na.omit(lung), Surv(time, status) ~ age + ph.ecog, n_tree = 10)
#' preds <- survival_prob_orsf(aorsf, lung[1:3, ], time = c(250, 100))
survival_prob_orsf <- function(object, new_data, time) {
#' preds <- survival_prob_orsf(aorsf, lung[1:3, ], eval_time = c(250, 100))
survival_prob_orsf <- function(object, new_data, eval_time, time = deprecated()) {
if (lifecycle::is_present(time)) {
lifecycle::deprecate_warn(
"0.1.1.9002",
"survival_prob_orsf(time)",
"survival_prob_orsf(eval_time)"
)
eval_time <- time
}

# This is not just a `post` hook in `set_pred()` because parsnip adds the
# argument `time` to the prediction call and `aorsf::predict.orsf_fit()`
# expects empty dots, i.e. no `time` argument.
# argument `eval_time` to the prediction call and `aorsf::predict.orsf_fit()`
# expects empty dots, i.e. no `eval_time` argument.

res <- predict(
object,
new_data = new_data,
pred_horizon = time,
pred_horizon = eval_time,
pred_type = "surv",
na_action = "pass",
boundary_checks = FALSE
)

res <- matrix_to_nested_tibbles_survival(res, time)
res <- matrix_to_nested_tibbles_survival(res, eval_time)

# return a tibble
tibble(.pred = res)
Expand Down
3 changes: 2 additions & 1 deletion R/rand_forest-data.R
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,8 @@ make_rand_forest_aorsf <- function() {
func = c(pkg = "censored", fun = "survival_prob_orsf"),
args = list(
object = rlang::expr(object$fit),
new_data = rlang::expr(new_data)
new_data = rlang::expr(new_data),
eval_time = rlang::expr(eval_time)
)
)
)
Expand Down
8 changes: 5 additions & 3 deletions man/aorsf_internal.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 9 additions & 9 deletions tests/testthat/test-rand_forest-aorsf.R
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@ test_that("survival predictions", {

expect_error(
predict(f_fit, lung_orsf, type = "survival"),
"When using 'type' values of 'survival' or 'hazard' are given"
"When using `type` values of 'survival' or 'hazard', a numeric vector"
)

f_pred <- predict(f_fit, lung, type = "survival", time = c(100, 500, 1200))
f_pred <- predict(f_fit, lung, type = "survival", eval_time = c(100, 500, 1200))

expect_s3_class(f_pred, "tbl_df")
expect_equal(names(f_pred), ".pred")
Expand All @@ -66,7 +66,7 @@ test_that("survival predictions", {
)

cf_names <-
c(".time", ".pred_survival")
c(".eval_time", ".pred_survival")

expect_true(
all(purrr::map_lgl(
Expand All @@ -77,7 +77,7 @@ test_that("survival predictions", {

# correct prediction times in object
expect_equal(
tidyr::unnest(f_pred, cols = c(.pred))$.time,
tidyr::unnest(f_pred, cols = c(.pred))$.eval_time,
rep(c(100, 500, 1200), nrow(lung))
)

Expand All @@ -103,7 +103,7 @@ test_that("survival predictions", {
)

# equal predictions with multiple times and one testing observation
f_pred <- predict(f_fit, lung[1, ], type = "survival", time = c(100, 500, 1200))
f_pred <- predict(f_fit, lung[1, ], type = "survival", eval_time = c(100, 500, 1200))

new_km <- predict(
exp_f_fit,
Expand All @@ -124,7 +124,7 @@ test_that("survival predictions", {
)

# equal predictions with one time and multiple testing observation
f_pred <- predict(f_fit, lung, type = "survival", time = 306)
f_pred <- predict(f_fit, lung, type = "survival", eval_time = 306)

new_km <- predict(
exp_f_fit,
Expand All @@ -144,7 +144,7 @@ test_that("survival predictions", {
)

# equal predictions with one time and one testing observation
f_pred <- predict(f_fit, lung[1, ], type = "survival", time = 306)
f_pred <- predict(f_fit, lung[1, ], type = "survival", eval_time = 306)

new_km <- predict(
exp_f_fit,
Expand Down Expand Up @@ -195,13 +195,13 @@ test_that("`fix_xy()` works", {
f_fit,
new_data = lung_pred,
type = "survival",
time = c(100, 200)
eval_time = c(100, 200)
)
xy_pred_survival <- predict(
xy_fit,
new_data = lung_pred,
type = "survival",
time = c(100, 200)
eval_time = c(100, 200)
)
expect_equal(f_pred_survival, xy_pred_survival)
})
Expand Down

0 comments on commit 14491d8

Please sign in to comment.