diff --git a/DESCRIPTION b/DESCRIPTION index 983d8f28a..0bf18317e 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: parsnip Title: A Common API to Modeling and Analysis Functions -Version: 1.1.0.9002 +Version: 1.1.0.9003 Authors@R: c( person("Max", "Kuhn", , "max@posit.co", role = c("aut", "cre")), person("Davis", "Vaughan", , "davis@posit.co", role = "aut"), diff --git a/R/survival-censoring-weights.R b/R/survival-censoring-weights.R index 0254f0431..28b935c0c 100644 --- a/R/survival-censoring-weights.R +++ b/R/survival-censoring-weights.R @@ -63,6 +63,11 @@ trunc_probs <- function(probs, trunc = 0.01) { } .check_censor_model <- function(x) { + if (x$spec$mode != "censored regression") { + cli::cli_abort( + "The model needs to be for mode 'censored regression', not for mode '{x$spec$mode}'." + ) + } nms <- names(x) if (!any(nms == "censor_probs")) { rlang::abort("Please refit the model with parsnip version 1.0.4 or greater.") @@ -245,7 +250,9 @@ add_graf_weights_vec <- function(object, .pred, surv_obj, trunc = 0.05, eps = 10 num_times <- vctrs::list_sizes(.pred) y <- vctrs::list_unchop(.pred) y$surv_obj <- vctrs::vec_rep_each(surv_obj, times = num_times) + names(y)[names(y) == ".time"] <- ".eval_time" # Temporary + # Compute the actual time of evaluation y$.weight_time <- graf_weight_time_vec(y$surv_obj, y$.eval_time, eps = eps) # Compute the corresponding probability of being censored @@ -253,6 +260,7 @@ add_graf_weights_vec <- function(object, .pred, surv_obj, trunc = 0.05, eps = 10 y$.pred_censored <- trunc_probs(y$.pred_censored, trunc = trunc) # Invert the probabilities to create weights y$.weight_censored = 1 / y$.pred_censored + # Convert back the list column format y$surv_obj <- NULL vctrs::vec_chop(y, sizes = num_times)