diff --git a/DESCRIPTION b/DESCRIPTION index c0ceacd2d..f74ff58f8 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: parsnip Title: A Common API to Modeling and Analysis Functions -Version: 1.0.4.9003 +Version: 1.0.4.9004 Authors@R: c( person("Max", "Kuhn", , "max@posit.co", role = c("aut", "cre")), person("Davis", "Vaughan", , "davis@posit.co", role = "aut"), diff --git a/NAMESPACE b/NAMESPACE index 21e02a22e..39e4b3bbb 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -1,5 +1,8 @@ # Generated by roxygen2: do not edit by hand +S3method(.censoring_weights_graf,default) +S3method(.censoring_weights_graf,model_fit) +S3method(.censoring_weights_graf,workflow) S3method(augment,model_fit) S3method(autoplot,glmnet) S3method(autoplot,model_fit) @@ -144,6 +147,7 @@ S3method(varying_args,model_spec) S3method(varying_args,recipe) S3method(varying_args,step) export("%>%") +export(.censoring_weights_graf) export(.check_glmnet_penalty_fit) export(.check_glmnet_penalty_predict) export(.cols) diff --git a/R/ipcw.R b/R/ipcw.R index ae1ef7bce..a21c5d6bf 100644 --- a/R/ipcw.R +++ b/R/ipcw.R @@ -19,6 +19,9 @@ trunc_probs <- function(probs, trunc = 0.01) { } .filter_eval_time <- function(eval_time, fail = TRUE) { + if (!is.null(eval_time)) { + eval_time <- as.numeric(eval_time) + } # will still propagate nulls: eval_time <- eval_time[!is.na(eval_time)] eval_time <- unique(eval_time) @@ -32,3 +35,199 @@ trunc_probs <- function(probs, trunc = 0.01) { } eval_time } + +add_dot_row_to_weights <- function(dat, rows = NULL) { + if (is.null(rows)) { + dat <- add_rowindex(dat) + } else { + m <- length(rows) + n <- nrow(dat) + if (m != n) { + rlang::abort( + glue::glue( + "The length of 'rows' ({m}) should be equal to the number of rows in 'data' ({n})" + ) + ) + } + dat$.row <- rows + } + dat +} + +.check_censor_model <- function(x) { + nms <- names(x) + if (!any(nms == "censor_probs")) { + rlang::abort("Please refit the model with parsnip version 1.0.4 or greater.") + } + 
invisible(NULL) +} + +# nocov start +# these are tested in extratests +# ------------------------------------------------------------------------------ +# Brier score helpers. Most of this is based off of Graf, E., Schmoor, C., +# Sauerbrei, W. and Schumacher, M. (1999), Assessment and comparison of +# prognostic classification schemes for survival data. _Statist. Med._, 18: +# 2529-2545. + +# We need to use the time of analysis to determine what time to use to evaluate +# the IPCWs. + +graf_weight_time <- function(surv_obj, eval_time, rows = NULL, eps = 10^-10) { + event_time <- .extract_surv_time(surv_obj) + status <- .extract_surv_status(surv_obj) + is_event_before_t <- event_time <= eval_time & status == 1 + is_censored <- event_time > eval_time + + # Three possible contributions to the statistic from Graf 1999 + + # Censoring time before eval_time, no contribution (Graf category 3) + weight_time <- rep(NA_real_, length(event_time)) + + # A real event prior to eval_time (Graf category 1) + weight_time[is_event_before_t] <- event_time[is_event_before_t] - eps + + # Observed time greater than eval_time (Graf category 2) + weight_time[is_censored] <- eval_time - eps + + weight_time <- ifelse(weight_time < 0, 0, weight_time) + + res <- tibble::tibble(surv = surv_obj, weight_time = weight_time, eval_time) + add_dot_row_to_weights(res, rows) +} + +# ------------------------------------------------------------------------------ +#' Calculations for inverse probability of censoring weights (IPCW) +#' +#' The method of Graf _et al_ (1999) is used to compute weights at specific +#' evaluation times that can be used to help measure a model's time-dependent +#' performance (e.g. the time-dependent Brier score or the area under the ROC +#' curve). +#' @param data A data frame with a column containing a [survival::Surv()] object. +#' @param predictors Not currently used. A potential future slot for models with +#' informative censoring based on columns in `data`. 
+#' @param rows An optional integer vector with length equal to the number of +#' rows in `data` that is used to index the original data. The default is to +#' use a fresh index on data (i.e. `1:nrow(data)`). +#' @param eval_time A vector of finite, non-negative times at which to +#' compute the probability of censoring and the corresponding weights. +#' @param object A fitted parsnip model object or fitted workflow with a mode +#' of "censored regression". +#' @param trunc A potential lower bound for the probability of censoring to avoid +#' very large weight values. +#' @param eps A small value that is subtracted from the evaluation time when +#' computing the censoring probabilities. See Details below. +#' @return A tibble with columns `.row`, `eval_time`, `.prob_cens` (the +#' probability of being censored just prior to the evaluation time), and +#' `.weight_cens` (the inverse probability of censoring weight). +#' @details +#' +#' A probability that the data are censored immediately prior to a specific +#' time is computed. To do this, we must determine at what time to +#' make the prediction. There are two time values for each row of the data set: +#' the observed time (either censored or not) and the time that the model is +#' being evaluated at (e.g. the survival function prediction at some time point), +#' which is constant across rows. +#' +#' From Graf _et al_ (1999) there are three cases: +#' +#' - If the observed time is a censoring time and that is before the +#' evaluation time, the data point should make no contribution to the +#' performance metric (their "category 3"). These values have a missing +#' value for their probability estimate (and also for their weight column). +#' +#' - If the observed time corresponds to an actual event, and that time is +#' prior to the evaluation time (category 1), the probability of being +#' censored is predicted at the observed time (minus an epsilon). 
+#' +#' - If the observed time is _after_ the evaluation time (category 2), regardless of +#' the status, the probability of being censored is predicted at the evaluation +#' time (minus an epsilon). +#' +#' The epsilon is used since we would not have actual information at time `t` +#' for a data point being predicted at time `t` (only data prior to time `t` +#' should be available). +#' +#' After the censoring probability is computed, the `trunc` option is used to +#' avoid using numbers pathologically close to zero. After this, the weight is +#' computed by inverting the censoring probability. +#' +#' The `eps` argument is used to avoid information leakage when computing the +#' censoring probability. Subtracting a small number avoids using data that +#' would not be known at the time of prediction. For example, if we are making +#' survival probability predictions at `eval_time = 3.0`, we would not know +#' about the probability of being censored at that exact time (since it has not +#' occurred yet). +#' +#' Note that if there are `n` rows in `data` and `t` time points, the resulting +#' data has `n * t` rows. Computations will not easily scale well as `t` becomes +#' large. +#' @references Graf, E., Schmoor, C., Sauerbrei, W. and Schumacher, M. (1999), +#' Assessment and comparison of prognostic classification schemes for survival +#' data. _Statist. Med._, 18: 2529-2545. +#' @export +#' @name censoring_weights +#' @keywords internal +.censoring_weights_graf <- function(object, ...) { + UseMethod(".censoring_weights_graf") +} + +#' @export +#' @rdname censoring_weights +.censoring_weights_graf.default <- function(object, ...) 
{ + cls <- paste0("'", class(object), "'", collapse = ", ") + msg <- paste("There is no `.censoring_weights_graf()` method for objects with class(es):", + cls) + rlang::abort(msg) +} + + +#' @export +#' @rdname censoring_weights +.censoring_weights_graf.workflow <- function(object, + data, + eval_time, + rows = NULL, + predictors = NULL, + trunc = 0.05, eps = 10^-10, ...) { + if (is.null(object$fit$fit)) { + rlang::abort("The workflow does not have a model fit object.", call = FALSE) + } + .censoring_weights_graf(object$fit$fit, data, eval_time, rows, predictors, trunc, eps) +} + +#' @export +#' @rdname censoring_weights +.censoring_weights_graf.model_fit <- function(object, + data, + eval_time, + rows = NULL, + predictors = NULL, + trunc = 0.05, eps = 10^-10, ...) { + rlang::check_dots_empty() + .check_censor_model(object) + if (!is.null(predictors)) { + rlang::warn("The 'predictors' argument to the survival weighting function is not currently used.", call = FALSE) + } + eval_time <- .filter_eval_time(eval_time) + + truth <- object$preproc$y_var + if (length(truth) != 1) { + # check_outcome() tests that the outcome column is a Surv object + rlang::abort("The event time data should be in a single column with class 'Surv'", call = FALSE) + } + surv_data <- dplyr::select(data, dplyr::all_of(!!truth)) %>% setNames("surv") + .check_censored_right(surv_data$surv) + + purrr::map(eval_time, + ~ graf_weight_time(surv_data$surv, .x, eps = eps, rows = rows)) %>% + purrr::list_rbind() %>% + dplyr::mutate( + .prob_cens = predict(object$censor_probs, time = weight_time, as_vector = TRUE), + .prob_cens = trunc_probs(.prob_cens, trunc), + .weight_cens = 1 / .prob_cens + ) %>% + dplyr::select(.row, eval_time, .prob_cens, .weight_cens) +} + +# nocov end diff --git a/R/parsnip-package.R b/R/parsnip-package.R index e2033ee3b..b63fb1f35 100644 --- a/R/parsnip-package.R +++ b/R/parsnip-package.R @@ -44,7 +44,7 @@ utils::globalVariables( "compute_intercept", "remove_intercept", 
"estimate", "term", "call_info", "component", "component_id", "func", "tunable", "label", "pkg", ".order", "item", "tunable", "has_ext", "id", "weights", "has_wts", - "protect", "s" + "protect", "weight_time", ".prob_cens", ".weight_cens", "s" ) ) diff --git a/man/censoring_weights.Rd b/man/censoring_weights.Rd new file mode 100644 index 000000000..cb9008f0c --- /dev/null +++ b/man/censoring_weights.Rd @@ -0,0 +1,116 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/ipcw.R +\name{censoring_weights} +\alias{censoring_weights} +\alias{.censoring_weights_graf} +\alias{.censoring_weights_graf.default} +\alias{.censoring_weights_graf.workflow} +\alias{.censoring_weights_graf.model_fit} +\title{Calculations for inverse probability of censoring weights (IPCW)} +\usage{ +.censoring_weights_graf(object, ...) + +\method{.censoring_weights_graf}{default}(object, ...) + +\method{.censoring_weights_graf}{workflow}( + object, + data, + eval_time, + rows = NULL, + predictors = NULL, + trunc = 0.05, + eps = 10^-10, + ... +) + +\method{.censoring_weights_graf}{model_fit}( + object, + data, + eval_time, + rows = NULL, + predictors = NULL, + trunc = 0.05, + eps = 10^-10, + ... +) +} +\arguments{ +\item{object}{A fitted parsnip model object or fitted workflow with a mode +of "censored regression".} + +\item{data}{A data frame with a column containing a \code{\link[survival:Surv]{survival::Surv()}} object.} + +\item{eval_time}{A vector of finite, non-negative times at which to +compute the probability of censoring and the corresponding weights.} + +\item{rows}{An optional integer vector with length equal to the number of +rows in \code{data} that is used to index the original data. The default is to +use a fresh index on data (i.e. \code{1:nrow(data)}).} + +\item{predictors}{Not currently used. 
A potential future slot for models with +informative censoring based on columns in \code{data}.} + +\item{trunc}{A potential lower bound for the probability of censoring to avoid +very large weight values.} + +\item{eps}{A small value that is subtracted from the evaluation time when +computing the censoring probabilities. See Details below.} +} +\value{ +A tibble with columns \code{.row}, \code{eval_time}, \code{.prob_cens} (the +probability of being censored just prior to the evaluation time), and +\code{.weight_cens} (the inverse probability of censoring weight). +} +\description{ +The method of Graf \emph{et al} (1999) is used to compute weights at specific +evaluation times that can be used to help measure a model's time-dependent +performance (e.g. the time-dependent Brier score or the area under the ROC +curve). +} +\details{ +A probability that the data are censored immediately prior to a specific +time is computed. To do this, we must determine at what time to +make the prediction. There are two time values for each row of the data set: +the observed time (either censored or not) and the time that the model is +being evaluated at (e.g. the survival function prediction at some time point), +which is constant across rows. + +From Graf \emph{et al} (1999) there are three cases: +\itemize{ +\item If the observed time is a censoring time and that is before the +evaluation time, the data point should make no contribution to the +performance metric (their "category 3"). These values have a missing +value for their probability estimate (and also for their weight column). +\item If the observed time corresponds to an actual event, and that time is +prior to the evaluation time (category 1), the probability of being +censored is predicted at the observed time (minus an epsilon). 
+\item If the observed time is \emph{after} the evaluation time (category 2), regardless of +the status, the probability of being censored is predicted at the evaluation +time (minus an epsilon). +} + +The epsilon is used since we would not have actual information at time \code{t} +for a data point being predicted at time \code{t} (only data prior to time \code{t} +should be available). + +After the censoring probability is computed, the \code{trunc} option is used to +avoid using numbers pathologically close to zero. After this, the weight is +computed by inverting the censoring probability. + +The \code{eps} argument is used to avoid information leakage when computing the +censoring probability. Subtracting a small number avoids using data that +would not be known at the time of prediction. For example, if we are making +survival probability predictions at \code{eval_time = 3.0}, we would not know +about the probability of being censored at that exact time (since it has not +occurred yet). + +Note that if there are \code{n} rows in \code{data} and \code{t} time points, the resulting +data has \code{n * t} rows. Computations will not easily scale well as \code{t} becomes +large. +} +\references{ +Graf, E., Schmoor, C., Sauerbrei, W. and Schumacher, M. (1999), +Assessment and comparison of prognostic classification schemes for survival +data. \emph{Statist. Med.}, 18: 2529-2545. 
+} +\keyword{internal} diff --git a/tests/testthat/helper-objects.R b/tests/testthat/helper-objects.R index 70eb0fb2e..a9297a65a 100644 --- a/tests/testthat/helper-objects.R +++ b/tests/testthat/helper-objects.R @@ -11,3 +11,16 @@ caught_ctrl <- control_parsnip(verbosity = 1, catch = TRUE) quiet_ctrl <- control_parsnip(verbosity = 0, catch = TRUE) run_glmnet <- utils::compareVersion('3.6.0', as.character(getRversion())) > 0 + +# ------------------------------------------------------------------------------ +# for skips + +is_tf_ok <- function() { + tf_ver <- try(tensorflow::tf_version(), silent = TRUE) + if (inherits(tf_ver, "try-error")) { + res <- FALSE + } else { + res <- !is.null(tf_ver) + } + res +} diff --git a/tests/testthat/test-ipcw.R b/tests/testthat/test-ipcw.R index 9a1b3ed3f..376d20816 100644 --- a/tests/testthat/test-ipcw.R +++ b/tests/testthat/test-ipcw.R @@ -17,8 +17,20 @@ test_that('probability truncation', { min(parsnip:::trunc_probs((1:200)/200)), 1 / 200 ) -}) + probs_1 <- (0:10) / 20 + probs_2 <- probs_1 + probs_2[3] <- NA_real_ + + expect_equal(parsnip:::trunc_probs(probs_1, 0), probs_1) + expect_equal(parsnip:::trunc_probs(probs_2, 0), probs_2) + expect_equal( + parsnip:::trunc_probs(probs_1, 0.1), + ifelse(probs_1 < 0.05 / 2, 0.05 / 2, probs_1) + ) + expect_equal(min(parsnip:::trunc_probs(probs_2, 0.1), na.rm = TRUE), 0.05 / 2) + expect_equal(is.na(parsnip:::trunc_probs(probs_2, 0.1)),is.na(probs_2)) +}) test_that('time filtering', { times_1 <- 0:10 @@ -35,8 +47,3 @@ test_that('time filtering', { expect_snapshot(error = TRUE, parsnip:::.filter_eval_time(-1)) expect_null(parsnip:::.filter_eval_time(NULL)) }) - - - - - diff --git a/tests/testthat/test_linear_reg_keras.R b/tests/testthat/test_linear_reg_keras.R index f2d0b1dd8..fbf37f6c2 100644 --- a/tests/testthat/test_linear_reg_keras.R +++ b/tests/testthat/test_linear_reg_keras.R @@ -17,7 +17,7 @@ ctrl <- control_parsnip(verbosity = 0, catch = FALSE) test_that('model fitting', { 
skip_on_cran() skip_if_not_installed("keras") - skip_if(is.null(tensorflow::tf_version())) + skip_if(!is_tf_ok()) set_tf_seed(257) @@ -87,7 +87,7 @@ test_that('model fitting', { test_that('regression prediction', { skip_on_cran() skip_if_not_installed("keras") - skip_if(is.null(tensorflow::tf_version())) + skip_if(!is_tf_ok()) library(keras) diff --git a/tests/testthat/test_logistic_reg_keras.R b/tests/testthat/test_logistic_reg_keras.R index 6e5897f58..2fae93320 100644 --- a/tests/testthat/test_logistic_reg_keras.R +++ b/tests/testthat/test_logistic_reg_keras.R @@ -27,7 +27,7 @@ ctrl <- control_parsnip(verbosity = 0, catch = FALSE) test_that('model fitting', { skip_on_cran() skip_if_not_installed("keras") - skip_if(is.null(tensorflow::tf_version())) + skip_if(!is_tf_ok()) set_tf_seed(257) @@ -97,7 +97,7 @@ test_that('model fitting', { test_that('classification prediction', { skip_on_cran() skip_if_not_installed("keras") - skip_if(is.null(tensorflow::tf_version())) + skip_if(!is_tf_ok()) library(keras) @@ -142,7 +142,7 @@ test_that('classification prediction', { test_that('classification probabilities', { skip_on_cran() skip_if_not_installed("keras") - skip_if(is.null(tensorflow::tf_version())) + skip_if(!is_tf_ok()) library(keras) diff --git a/tests/testthat/test_mlp_keras.R b/tests/testthat/test_mlp_keras.R index 57ae81e94..939ca9f0c 100644 --- a/tests/testthat/test_mlp_keras.R +++ b/tests/testthat/test_mlp_keras.R @@ -14,7 +14,7 @@ nn_dat <- read.csv("nnet_test.txt") test_that('keras execution, classification', { skip_on_cran() skip_if_not_installed("keras") - skip_if(is.null(tensorflow::tf_version())) + skip_if(!is_tf_ok()) expect_error( res <- parsnip::fit( @@ -58,7 +58,7 @@ test_that('keras execution, classification', { test_that('keras classification prediction', { skip_on_cran() skip_if_not_installed("keras") - skip_if(is.null(tensorflow::tf_version())) + skip_if(!is_tf_ok()) library(keras) xy_fit <- parsnip::fit_xy( @@ -107,7 +107,7 @@ test_that('keras 
classification prediction', { test_that('keras classification probabilities', { skip_on_cran() skip_if_not_installed("keras") - skip_if(is.null(tensorflow::tf_version())) + skip_if(!is_tf_ok()) xy_fit <- parsnip::fit_xy( hpc_keras, @@ -158,7 +158,7 @@ bad_keras_reg <- test_that('keras execution, regression', { skip_on_cran() skip_if_not_installed("keras") - skip_if(is.null(tensorflow::tf_version())) + skip_if(!is_tf_ok()) expect_error( res <- parsnip::fit( @@ -186,7 +186,7 @@ test_that('keras execution, regression', { test_that('keras regression prediction', { skip_on_cran() skip_if_not_installed("keras") - skip_if(is.null(tensorflow::tf_version())) + skip_if(!is_tf_ok()) xy_fit <- parsnip::fit_xy( mlp(mode = "regression", hidden_units = 2, epochs = 500, penalty = .1) %>% @@ -220,7 +220,7 @@ test_that('keras regression prediction', { test_that('multivariate nnet formula', { skip_on_cran() skip_if_not_installed("keras") - skip_if(is.null(tensorflow::tf_version())) + skip_if(!is_tf_ok()) nnet_form <- mlp(mode = "regression", hidden_units = 3, penalty = 0.01) %>% diff --git a/tests/testthat/test_multinom_reg_keras.R b/tests/testthat/test_multinom_reg_keras.R index e321e9389..698ff2caf 100644 --- a/tests/testthat/test_multinom_reg_keras.R +++ b/tests/testthat/test_multinom_reg_keras.R @@ -25,7 +25,7 @@ ctrl <- control_parsnip(verbosity = 0, catch = FALSE) test_that('model fitting', { skip_on_cran() skip_if_not_installed("keras") - skip_if(is.null(tensorflow::tf_version())) + skip_if(!is_tf_ok()) set_tf_seed(257) @@ -95,7 +95,7 @@ test_that('model fitting', { test_that('classification prediction', { skip_on_cran() skip_if_not_installed("keras") - skip_if(is.null(tensorflow::tf_version())) + skip_if(!is_tf_ok()) library(keras) @@ -140,7 +140,7 @@ test_that('classification prediction', { test_that('classification probabilities', { skip_on_cran() skip_if_not_installed("keras") - skip_if(is.null(tensorflow::tf_version())) + skip_if(!is_tf_ok()) library(keras)