diff --git a/NAMESPACE b/NAMESPACE index d2c4d1c54..b194623ba 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -55,6 +55,7 @@ S3method(print,model_spec) S3method(print,multinom_reg) S3method(print,nearest_neighbor) S3method(print,nullmodel) +S3method(print,proportional_hazards) S3method(print,rand_forest) S3method(print,surv_reg) S3method(print,survival_reg) @@ -95,6 +96,7 @@ S3method(update,mars) S3method(update,mlp) S3method(update,multinom_reg) S3method(update,nearest_neighbor) +S3method(update,proportional_hazards) S3method(update,rand_forest) S3method(update,surv_reg) S3method(update,survival_reg) @@ -176,6 +178,7 @@ export(predict_survival.model_fit) export(predict_time) export(predict_time.model_fit) export(prepare_data) +export(proportional_hazards) export(rand_forest) export(repair_call) export(req_pkgs) diff --git a/NEWS.md b/NEWS.md index 090012dfc..7b094d38a 100644 --- a/NEWS.md +++ b/NEWS.md @@ -8,6 +8,8 @@ * New model specification `survival_reg()` for the new mode `"censored regression"`. (#444) +* New model specification `proportional_hazards()` for the `"censored regression"` mode (#451). + # parsnip 0.1.5 * An RStudio add-in is available that makes writing multiple `parsnip` model specifications to the source window. It can be accessed via the IDE addin menus or by calling `parsnip_addin()`. diff --git a/R/proportional_hazards.R b/R/proportional_hazards.R new file mode 100644 index 000000000..98855c9b8 --- /dev/null +++ b/R/proportional_hazards.R @@ -0,0 +1,116 @@ +#' General Interface for Proportional Hazards Models +#' +#' `proportional_hazards()` is a way to generate a _specification_ of a model +#' before fitting and allows the model to be created using different packages +#' in R. The main arguments for the model are: +#' \itemize{ +#' \item \code{penalty}: The total amount of regularization +#' in the model. Note that this must be zero for some engines. 
+#' \item \code{mixture}: The mixture amounts of different types of +#' regularization (see below). Note that this will be ignored for some engines. +#' } +#' These arguments are converted to their specific names at the +#' time that the model is fit. Other options and arguments can be +#' set using `set_engine()`. If left to their defaults +#' here (`NULL`), the values are taken from the underlying model +#' functions. If parameters need to be modified, `update()` can be used +#' in lieu of recreating the object from scratch. +#' +#' @param mode A single character string for the type of model. +#' Possible values for this model are "unknown", or "censored regression". +#' @inheritParams linear_reg +#' +#' @details +#' Proportional hazards models include the Cox model. +#' For `proportional_hazards()`, the mode will always be "censored regression". +#' +#' @examples +#' show_engines("proportional_hazards") +#' +#' @export +proportional_hazards <- function(mode = "censored regression", + penalty = NULL, + mixture = NULL) { + + args <- list( + penalty = enquo(penalty), + mixture = enquo(mixture) + ) + + new_model_spec( + "proportional_hazards", + args = args, + eng_args = NULL, + mode = mode, + method = NULL, + engine = NULL + ) + } + +#' @export +print.proportional_hazards <- function(x, ...) { + cat("Proportional Hazards Model Specification (", x$mode, ")\n\n", sep = "") + model_printer(x, ...) + + if (!is.null(x$method$fit$args)) { + cat("Model fit template:\n") + print(show_call(x)) + } + + invisible(x) +} + +# ------------------------------------------------------------------------------ + +#' @param object A proportional hazards model specification. +#' @param ... Engine arguments to update; passed to `update_engine_parameters()`. +#' @param fresh A logical for whether the arguments should be +#' modified in-place or replaced wholesale. 
+#' @examples +#' model <- proportional_hazards(penalty = 10, mixture = 0.1) +#' model +#' update(model, penalty = 1) +#' update(model, penalty = 1, fresh = TRUE) +#' @method update proportional_hazards +#' @rdname proportional_hazards +#' @export +update.proportional_hazards <- function(object, + parameters = NULL, + penalty = NULL, + mixture = NULL, + fresh = FALSE, ...) { + + eng_args <- update_engine_parameters(object$eng_args, ...) + + if (!is.null(parameters)) { + parameters <- check_final_param(parameters) + } + args <- list( + penalty = enquo(penalty), + mixture = enquo(mixture) + ) + + args <- update_main_parameters(args, parameters) + + if (fresh) { + object$args <- args + object$eng_args <- eng_args + } else { + null_args <- map_lgl(args, null_value) + if (any(null_args)) + args <- args[!null_args] + if (length(args) > 0) + object$args[names(args)] <- args + if (length(eng_args) > 0) + object$eng_args[names(eng_args)] <- eng_args + } + + new_model_spec( + "proportional_hazards", + args = object$args, + eng_args = object$eng_args, + mode = object$mode, + method = NULL, + engine = object$engine + ) + } diff --git a/R/proportional_hazards_data.R b/R/proportional_hazards_data.R new file mode 100644 index 000000000..2ac4407a4 --- /dev/null +++ b/R/proportional_hazards_data.R @@ -0,0 +1,5 @@ + +# parsnip just contains the model specification, the engines are in the censored package. 
+ +set_new_model("proportional_hazards") +set_model_mode("proportional_hazards", "censored regression") diff --git a/man/proportional_hazards.Rd b/man/proportional_hazards.Rd new file mode 100644 index 000000000..640d50d94 --- /dev/null +++ b/man/proportional_hazards.Rd @@ -0,0 +1,78 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/proportional_hazards.R +\name{proportional_hazards} +\alias{proportional_hazards} +\alias{update.proportional_hazards} +\title{General Interface for Proportional Hazards Models} +\usage{ +proportional_hazards( + mode = "censored regression", + penalty = NULL, + mixture = NULL +) + +\method{update}{proportional_hazards}( + object, + parameters = NULL, + penalty = NULL, + mixture = NULL, + fresh = FALSE, + ... +) +} +\arguments{ +\item{mode}{A single character string for the type of model. +Possible values for this model are "unknown", or "censored regression".} + +\item{penalty}{A non-negative number representing the total +amount of regularization (\code{glmnet}, \code{keras}, and \code{spark} only). +For \code{keras} models, this corresponds to purely L2 regularization +(aka weight decay) while the other models can be a combination +of L1 and L2 (depending on the value of \code{mixture}; see below).} + +\item{mixture}{A number between zero and one (inclusive) that is the +proportion of L1 regularization (i.e. lasso) in the model. When +\code{mixture = 1}, it is a pure lasso model while \code{mixture = 0} indicates that +ridge regression is being used. (\code{glmnet} and \code{spark} only).} + +\item{object}{A proportional hazards model specification.} + +\item{parameters}{A 1-row tibble or named list with \emph{main} +parameters to update. If the individual arguments are used, +these will supersede the values in \code{parameters}. 
Also, using +engine arguments in this object will result in an error.} + +\item{fresh}{A logical for whether the arguments should be +modified in-place or replaced wholesale.} + +\item{...}{Engine arguments to update; passed to \code{update_engine_parameters()}.} +} +\description{ +\code{proportional_hazards()} is a way to generate a \emph{specification} of a model +before fitting and allows the model to be created using different packages +in R. The main arguments for the model are: +\itemize{ +\item \code{penalty}: The total amount of regularization +in the model. Note that this must be zero for some engines. +\item \code{mixture}: The mixture amounts of different types of +regularization (see below). Note that this will be ignored for some engines. +} +These arguments are converted to their specific names at the +time that the model is fit. Other options and arguments can be +set using \code{set_engine()}. If left to their defaults +here (\code{NULL}), the values are taken from the underlying model +functions. If parameters need to be modified, \code{update()} can be used +in lieu of recreating the object from scratch. +} +\details{ +Proportional hazards models include the Cox model. +For \code{proportional_hazards()}, the mode will always be "censored regression". 
+} +\examples{ +show_engines("proportional_hazards") + +model <- proportional_hazards(penalty = 10, mixture = 0.1) +model +update(model, penalty = 1) +update(model, penalty = 1, fresh = TRUE) +} diff --git a/tests/testthat/test-proportional_hazards.R b/tests/testthat/test-proportional_hazards.R new file mode 100644 index 000000000..894d8eeee --- /dev/null +++ b/tests/testthat/test-proportional_hazards.R @@ -0,0 +1,77 @@ + +test_that("primary arguments", { + new_empty_quosure <- function(expr) { + rlang::new_quosure(expr, env = rlang::empty_env()) + } + + ph_penalty <- proportional_hazards(penalty = 0.05) + expect_equal( + ph_penalty$args, + list(penalty = new_empty_quosure(0.05), + mixture = new_empty_quosure(NULL)) + ) + + ph_mixture <- proportional_hazards(mixture = 0.34) + expect_equal( + ph_mixture$args, + list(penalty = new_empty_quosure(NULL), + mixture = new_empty_quosure(0.34)) + ) + + ph_mixture_v <- proportional_hazards(mixture = varying()) + expect_equal( + ph_mixture_v$args, + list(penalty = new_empty_quosure(NULL), + mixture = new_empty_quosure(varying())) + ) +}) + +test_that("printing", { + expect_output( + print(proportional_hazards()), + "Proportional Hazards Model Specification \\(censored regression\\)" + ) +}) + +test_that("updating", { + new_empty_quosure <- function(expr) { + rlang::new_quosure(expr, env = rlang::empty_env()) + } + + basic <- proportional_hazards() + + update_num <- update(basic, penalty = 0.05) + expect_equal( + update_num$args, + list(penalty = new_empty_quosure(0.05), + mixture = new_empty_quosure(NULL)) + ) + + param_tibb <- tibble::tibble(penalty = 0.05) + update_tibb <- update(basic, param_tibb) + expect_equal( + update_tibb$args, + list(penalty = 0.05, + mixture = new_empty_quosure(NULL)) + ) + + param_list <- as.list(param_tibb) + update_list <- update(basic, param_list) + expect_equal( + update_list$args, + list(penalty = 0.05, + mixture = new_empty_quosure(NULL)) + ) +}) + + +test_that("bad input", { + 
expect_error(proportional_hazards(mode = "classification")) +}) + +test_that("wrong fit interface", { + expect_error( + proportional_hazards() %>% fit_xy(), + "must use the formula interface" + ) +})