Skip to content

Commit

Permalink
rand_forest() - partykit
Browse files Browse the repository at this point in the history
`time` -> `eval_time`
  • Loading branch information
hfrick committed Mar 22, 2023
1 parent 6e4bd08 commit d654a7d
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion R/rand_forest-data.R
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ make_rand_forest_partykit <- function() {
args = list(
object = rlang::expr(object$fit),
new_data = rlang::expr(new_data),
time = rlang::expr(time)
eval_time = rlang::expr(eval_time)
)
)
)
Expand Down
16 changes: 8 additions & 8 deletions tests/testthat/test-rand_forest-partykit.R
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,10 @@ test_that("survival predictions", {

expect_error(
predict(f_fit, lung, type = "survival"),
"When using 'type' values of 'survival' or 'hazard' are given"
"When using `type` values of 'survival' or 'hazard', a numeric vector"
)

f_pred <- predict(f_fit, lung, type = "survival", time = 100:200)
f_pred <- predict(f_fit, lung, type = "survival", eval_time = 100:200)

expect_s3_class(f_pred, "tbl_df")
expect_equal(names(f_pred), ".pred")
Expand All @@ -93,7 +93,7 @@ test_that("survival predictions", {
101
)
cf_names <-
c(".time", ".pred_survival")
c(".eval_time", ".pred_survival")
expect_true(
all(purrr::map_lgl(
f_pred$.pred,
Expand All @@ -102,12 +102,12 @@ test_that("survival predictions", {
)

expect_equal(
tidyr::unnest(f_pred, cols = c(.pred))$.time,
tidyr::unnest(f_pred, cols = c(.pred))$.eval_time,
rep(100:200, nrow(lung))
)


f_pred <- predict(f_fit, lung[1, ], type = "survival", time = 306)
f_pred <- predict(f_fit, lung[1, ], type = "survival", eval_time = 306)
new_km <- predict(exp_f_fit, lung[1, ], type = "prob")[[1]]

expect_equal(
Expand All @@ -117,7 +117,7 @@ test_that("survival predictions", {

# with NA in one of the predictors
set.seed(1234)
f_pred <- predict(f_fit, lung[14, ], type = "survival", time = 71)
f_pred <- predict(f_fit, lung[14, ], type = "survival", eval_time = 71)
set.seed(1234)
new_km <- predict(exp_f_fit, lung[14, ], type = "prob")[[1]]

Expand Down Expand Up @@ -162,14 +162,14 @@ test_that("`fix_xy()` works", {
f_fit,
new_data = lung_pred,
type = "survival",
time = c(100, 200)
eval_time = c(100, 200)
)
set.seed(1)
xy_pred_survival <- predict(
xy_fit,
new_data = lung_pred,
type = "survival",
time = c(100, 200)
eval_time = c(100, 200)
)
expect_equal(f_pred_survival, xy_pred_survival)
})

0 comments on commit d654a7d

Please sign in to comment.