diff --git a/DESCRIPTION b/DESCRIPTION index 818e7219b..14535b9a0 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: parsnip Title: A Common API to Modeling and Analysis Functions -Version: 0.1.7.9005 +Version: 0.1.7.9006 Authors@R: c( person("Max", "Kuhn", , "max@rstudio.com", role = c("aut", "cre")), person("Davis", "Vaughan", , "davis@rstudio.com", role = "aut"), @@ -21,6 +21,7 @@ Imports: cli, dplyr (>= 0.8.0.1), generics (>= 0.1.0.9000), + ggplot2, globals, glue, hardhat (>= 0.1.6.9001), @@ -41,7 +42,7 @@ Suggests: dials (>= 0.0.10.9001), earth, tensorflow, - ggplot2, + ggrepel, keras, kernlab, kknn, diff --git a/NAMESPACE b/NAMESPACE index a54c266fa..5236d3b7f 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -165,6 +165,7 @@ export(C5.0_train) export(C5_rules) export(add_rowindex) export(augment) +export(autoplot) export(bag_mars) export(bag_tree) export(bart) @@ -314,6 +315,7 @@ importFrom(generics,glance) importFrom(generics,required_pkgs) importFrom(generics,tidy) importFrom(generics,varying_args) +importFrom(ggplot2,autoplot) importFrom(glue,glue_collapse) importFrom(hardhat,extract_fit_engine) importFrom(hardhat,extract_parameter_dials) @@ -321,6 +323,7 @@ importFrom(hardhat,extract_parameter_set_dials) importFrom(hardhat,extract_spec_parsnip) importFrom(hardhat,tune) importFrom(magrittr,"%>%") +importFrom(purrr,"%||%") importFrom(purrr,as_vector) importFrom(purrr,imap) importFrom(purrr,imap_lgl) diff --git a/NEWS.md b/NEWS.md index aef0531ab..e12ca453f 100644 --- a/NEWS.md +++ b/NEWS.md @@ -40,10 +40,13 @@ * `varying_args()` is soft-deprecated in favor of `tune_args()`. +* An `autoplot()` method was added for glmnet objects, showing the coefficient paths versus the penalty values (#642). + * parsnip is now more robust working with keras and tensorflow for a larger range of versions (#596). * xgboost engines now use the new `iterationrange` parameter instead of the deprecated `ntreelimit` (#656). + # parsnip 0.1.7 ## Model Specification Changes diff --git a/R/0_imports.R b/R/0_imports.R index 0ea5bb959..a5a747b36 100644 --- a/R/0_imports.R +++ b/R/0_imports.R @@ -3,7 +3,7 @@ #' @importFrom generics varying_args #' @importFrom glue glue_collapse #' @importFrom purrr as_vector imap imap_lgl map map_chr map_dbl map_df map_dfr -#' @importFrom purrr map_lgl +#' @importFrom purrr map_lgl %||% #' @importFrom rlang abort call2 caller_env current_env enquo enquos eval_tidy #' @importFrom rlang expr get_expr is_empty is_missing is_null is_quosure #' @importFrom rlang is_symbolic lgl missing_arg quo_get_expr quos sym syms @@ -16,4 +16,5 @@ #' @importFrom utils capture.output getFromNamespace globalVariables head #' @importFrom utils methods stack #' @importFrom vctrs vec_size vec_unique +#' @importFrom ggplot2 autoplot NULL diff --git a/R/aaa.R b/R/aaa.R index 9906ac037..5137a54d8 100644 --- a/R/aaa.R +++ b/R/aaa.R @@ -82,7 +82,7 @@ function(results, object) { } # ------------------------------------------------------------------------------ -# nocov +# nocov start utils::globalVariables( c('.', '.label', '.pred', '.row', 'data', 'engine', 'engine2', 'group', @@ -91,7 +91,9 @@ utils::globalVariables( "max_terms", "max_tree", "model", "name", "num_terms", "penalty", "trees", "sub_neighbors", ".pred_class", "x", "y", "predictor_indicators", "compute_intercept", "remove_intercept", "estimate", "term", - "call_info", "component", "component_id", "func", "pkg", ".order", "item", "tunable") + "call_info", "component", "component_id", "func", "tunable", "label", + "pkg", ".order", "item", "tunable" + ) ) # nocov end diff --git a/R/autoplot.R b/R/autoplot.R new file mode 100644 index 000000000..f18933cdd --- /dev/null +++ b/R/autoplot.R @@ -0,0 +1,170 @@ +#' Create a ggplot for a model object +#' +#' This method provides a good visualization method for model results. +#' Currently, only methods for glmnet models are implemented. +#' +#' @param object A model fit object. +#' @param min_penalty A single, non-negative number for the smallest penalty +#' value that should be shown in the plot. If left `NULL`, the whole data +#' range is used. +#' @param best_penalty A single, non-negative number that will show a vertical +#' line marker. If left `NULL`, no line is shown. When this argument is used, +#' the \pkg{ggrepl} package is required. +#' @param top_n A non-negative integer for how many model predictors to label. +#' The top predictors are ranked by their absolute coefficient value. For +#' multinomial or multivariate models, the `top_n` terms are selected within +#' class or response, respectively. +#' @param ... For [autoplot.glmnet()], options to pass to +#' [ggrepel::geom_label_repel()]. Otherwise, this argument is ignored. +#' @return A ggplot object with penalty on the x-axis and coefficients on the +#' y-axis. For multinomial or multivariate models, the plot is faceted. +#' @details The \pkg{glmnet} package will need to be attached or loaded for +#' its `autoplot()` method to work correctly. +#' +# registered in zzz.R +autoplot.model_fit <- function(object, ...) { + autoplot(object$fit, ...) +} + +# glmnet is not a formal dependency here. +# unit tests are located at https://github.com/tidymodels/extratests +# nocov start + +# registered in zzz.R +#' @rdname autoplot.model_fit +autoplot.glmnet <- function(object, ..., min_penalty = 0, best_penalty = NULL, + top_n = 3L) { + autoplot_glmnet(object, min_penalty, best_penalty, top_n, ...) +} + + +map_glmnet_coefs <- function(x) { + coefs <- coef(x) + # If parsnip is used to fit the model, glmnet should be attached and this will + # work. If an object is loaded from a new session, they will need to load the + # package. + if (is.null(coefs)) { + rlang::abort("Please load the glmnet package before running `autoplot()`.") + } + p <- x$dim[1] + if (is.list(coefs)) { + classes <- names(coefs) + coefs <- purrr::map(coefs, reformat_coefs, p = p, penalty = x$lambda) + coefs <- purrr::map2_dfr(coefs, classes, ~ dplyr::mutate(.x, class = .y)) + } else { + coefs <- reformat_coefs(coefs, p = p, penalty = x$lambda) + } + coefs +} + +reformat_coefs <- function(x, p, penalty) { + x <- as.matrix(x) + num_estimates <- nrow(x) + if (num_estimates > p) { + # The intercept is first + x <- x[-(num_estimates - p),, drop = FALSE] + } + term_lab <- rownames(x) + colnames(x) <- paste(seq_along(penalty)) + x <- tibble::as_tibble(x) + x$term <- term_lab + x <- tidyr::pivot_longer(x, cols = -term, names_to = "index", values_to = "estimate") + x$penalty <- rep(penalty, p) + x$index <- NULL + x +} + +top_coefs <- function(x, top_n = 5) { + x %>% + dplyr::group_by(term) %>% + dplyr::arrange(term, dplyr::desc(abs(estimate))) %>% + dplyr::slice(1) %>% + dplyr::ungroup() %>% + dplyr::arrange(dplyr::desc(abs(estimate))) %>% + dplyr::slice(1:top_n) +} + +autoplot_glmnet <- function(x, min_penalty = 0, best_penalty = NULL, top_n = 3L, ...) { + check_penalty_value(min_penalty) + + tidy_coefs <- + map_glmnet_coefs(x) %>% + dplyr::filter(penalty >= min_penalty) + + actual_min_penalty <- min(tidy_coefs$penalty) + num_terms <- length(unique(tidy_coefs$term)) + top_n <- min(top_n[1], num_terms) + if (top_n < 0) { + top_n <- 0 + } + + has_groups <- any(names(tidy_coefs) == "class") + + # Keep the large values + if (has_groups) { + label_coefs <- + tidy_coefs %>% + dplyr::group_nest(class) %>% + dplyr::mutate(data = purrr::map(data, top_coefs, top_n = top_n)) %>% + dplyr::select(class, data) %>% + tidyr::unnest(cols = data) + } else { + if (is.null(best_penalty)) { + label_coefs <- tidy_coefs %>% + top_coefs(top_n) + } else { + label_coefs <- tidy_coefs %>% + dplyr::filter(penalty > best_penalty) %>% + dplyr::filter(penalty == min(penalty)) %>% + dplyr::arrange(dplyr::desc(abs(estimate))) %>% + dplyr::slice(seq_len(top_n)) + } + } + + label_coefs <- + label_coefs %>% + dplyr::mutate(penalty = best_penalty %||% actual_min_penalty) %>% + dplyr::mutate(label = gsub(".pred_no_", "", term)) + + # plot the paths and highlight the large values + p <- + tidy_coefs %>% + ggplot2::ggplot(ggplot2::aes(x = penalty, y = estimate, group = term, col = term)) + + if (has_groups) { + p <- p + ggplot2::facet_wrap(~ class) + } + + if (!is.null(best_penalty)) { + check_penalty_value(best_penalty) + p <- p + ggplot2::geom_vline(xintercept = best_penalty, lty = 3) + } + + p <- p + + ggplot2::geom_line(alpha = .4, show.legend = FALSE) + + ggplot2::scale_x_log10() + + if(top_n > 0) { + rlang::check_installed("ggrepel") + p <- p + + ggrepel::geom_label_repel( + data = label_coefs, + ggplot2::aes(y = estimate, label = label), + show.legend = FALSE, + ... + ) + } + p +} + +check_penalty_value <- function(x) { + cl <- match.call() + arg_val <- as.character(cl$x) + if (!is.vector(x) || length(x) != 1 || !is.numeric(x) || x < 0) { + msg <- paste0("Argument '", arg_val, "' should be a single, non-negative value.") + rlang::abort(msg) + } + invisible(x) +} + +# nocov end diff --git a/R/reexports.R b/R/reexports.R index 3a1a22227..e26510794 100644 --- a/R/reexports.R +++ b/R/reexports.R @@ -1,3 +1,6 @@ +#' @importFrom ggplot2 autoplot +#' @export +ggplot2::autoplot #' @importFrom magrittr %>% #' @export diff --git a/R/zzz.R b/R/zzz.R index 1cf9c353c..60e919b37 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -12,6 +12,9 @@ s3_register("generics::required_pkgs", "model_fit") s3_register("generics::required_pkgs", "model_spec") + s3_register("ggplot2::autoplot", "model_fit") + s3_register("ggplot2::autoplot", "glmnet") + # - If tune isn't installed, register the method (`packageVersion()` will error here) # - If tune >= 0.1.6.9001 is installed, register the method should_register_tune_args_method <- tryCatch( @@ -108,61 +111,3 @@ s3_register <- function(generic, class, method = NULL) { # nocov end - - -#' ## nocov start -#' -#' data_obj <- ls(pattern = "_data$") -#' data_obj <- data_obj[data_obj != "prepare_data"] -#' -#' data_names <- -#' map_dfr( -#' data_obj, -#' function(x) { -#' module <- names(get(x)) -#' if (length(module) > 1) { -#' module <- table(module) -#' module <- as_tibble(module) -#' module$object <- x -#' module -#' } else -#' module <- NULL -#' module -#' } -#' ) -#' -#' if(any(data_names$n > 1)) { -#' print(data_names[data_names$n > 1,]) -#' rlang::abort("Some models have duplicate module names.") -#' } -#' rm(data_names) -#' -#' # ------------------------------------------------------------------------------ -#' -#' engine_objects <- ls(pattern = "_engines$") -#' engine_objects <- engine_objects[engine_objects != "possible_engines"] -#' -#' get_engine_info <- function(x) { -#' y <- x -#' y <- get(y) -#' z <- stack(y) -#' z$mode <- rownames(y) -#' z$model <- gsub("_engines$", "", x) -#' z$object <- x -#' z <- z[z$values,] -#' z <- z[z$mode != "unknown",] -#' z$values <- NULL -#' names(z)[1] <- "engine" -#' z$engine <- as.character(z$engine) -#' z -#' } -#' -#' engine_info <- -#' purrr::map_df( -#' parsnip:::engine_objects, -#' get_engine_info -#' ) -#' -#' rm(engine_objects) -#' -#' ## nocov end diff --git a/_pkgdown.yml b/_pkgdown.yml index 8897bed6d..cb99d5bbd 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -48,6 +48,7 @@ reference: - svm_rbf - title: Infrastructure contents: + - autoplot.model_fit - add_rowindex - augment.model_fit - descriptors diff --git a/man/autoplot.model_fit.Rd b/man/autoplot.model_fit.Rd new file mode 100644 index 000000000..1f0c5ab7e --- /dev/null +++ b/man/autoplot.model_fit.Rd @@ -0,0 +1,42 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/autoplot.R +\name{autoplot.model_fit} +\alias{autoplot.model_fit} +\alias{autoplot.glmnet} +\title{Create a ggplot for a model object} +\usage{ +\method{autoplot}{model_fit}(object, ...) + +\method{autoplot}{glmnet}(object, ..., min_penalty = 0, best_penalty = NULL, top_n = 3L) +} +\arguments{ +\item{object}{A model fit object.} + +\item{...}{For \code{\link[=autoplot.glmnet]{autoplot.glmnet()}}, options to pass to +\code{\link[ggrepel:geom_text_repel]{ggrepel::geom_label_repel()}}. Otherwise, this argument is ignored.} + +\item{min_penalty}{A single, non-negative number for the smallest penalty +value that should be shown in the plot. If left \code{NULL}, the whole data +range is used.} + +\item{best_penalty}{A single, non-negative number that will show a vertical +line marker. If left \code{NULL}, no line is shown. When this argument is used, +the \pkg{ggrepl} package is required.} + +\item{top_n}{A non-negative integer for how many model predictors to label. +The top predictors are ranked by their absolute coefficient value. For +multinomial or multivariate models, the \code{top_n} terms are selected within +class or response, respectively.} +} +\value{ +A ggplot object with penalty on the x-axis and coefficients on the +y-axis. For multinomial or multivariate models, the plot is faceted. +} +\description{ +This method provides a good visualization method for model results. +Currently, only methods for glmnet models are implemented. +} +\details{ +The \pkg{glmnet} package will need to be attached or loaded for +its \code{autoplot()} method to work correctly. +} diff --git a/man/details_proportional_hazards_glmnet.Rd b/man/details_proportional_hazards_glmnet.Rd index ae235edda..eaa584b45 100644 --- a/man/details_proportional_hazards_glmnet.Rd +++ b/man/details_proportional_hazards_glmnet.Rd @@ -116,7 +116,7 @@ tidymodels does not treat different models differently when computing performance metrics. To standardize across model types, the default for proportional hazards models is to have \emph{increasing values with time}. As a result, the sign of the linear predictor will be the opposite of the -value produced by the \code{predict()} method in the package. +value produced by the \code{predict()} method in the engine package. This behavior can be changed by using the \code{increasing} argument when calling \code{predict()} on a model object. diff --git a/man/details_proportional_hazards_survival.Rd b/man/details_proportional_hazards_survival.Rd index 5d6e57a2e..84fe2fb40 100644 --- a/man/details_proportional_hazards_survival.Rd +++ b/man/details_proportional_hazards_survival.Rd @@ -100,7 +100,7 @@ tidymodels does not treat different models differently when computing performance metrics. To standardize across model types, the default for proportional hazards models is to have \emph{increasing values with time}. As a result, the sign of the linear predictor will be the opposite of the -value produced by the \code{predict()} method in the package. +value produced by the \code{predict()} method in the engine package. This behavior can be changed by using the \code{increasing} argument when calling \code{predict()} on a model object. diff --git a/man/reexports.Rd b/man/reexports.Rd index f0ecfc522..3a5f8c898 100644 --- a/man/reexports.Rd +++ b/man/reexports.Rd @@ -3,6 +3,7 @@ \docType{import} \name{reexports} \alias{reexports} +\alias{autoplot} \alias{\%>\%} \alias{fit} \alias{fit_xy} @@ -25,6 +26,8 @@ below to see their documentation. \describe{ \item{generics}{\code{\link[generics]{augment}}, \code{\link[generics]{fit}}, \code{\link[generics]{fit_xy}}, \code{\link[generics]{glance}}, \code{\link[generics]{required_pkgs}}, \code{\link[generics]{tidy}}, \code{\link[generics]{varying_args}}} + \item{ggplot2}{\code{\link[ggplot2]{autoplot}}} + \item{hardhat}{\code{\link[hardhat:hardhat-extract]{extract_fit_engine}}, \code{\link[hardhat:hardhat-extract]{extract_parameter_dials}}, \code{\link[hardhat:hardhat-extract]{extract_parameter_set_dials}}, \code{\link[hardhat:hardhat-extract]{extract_spec_parsnip}}, \code{\link[hardhat]{tune}}} \item{magrittr}{\code{\link[magrittr:pipe]{\%>\%}}} diff --git a/tests/testthat/test-augment.R b/tests/testthat/test_augment.R similarity index 100% rename from tests/testthat/test-augment.R rename to tests/testthat/test_augment.R diff --git a/tests/testthat/test-extract.R b/tests/testthat/test_extract.R similarity index 100% rename from tests/testthat/test-extract.R rename to tests/testthat/test_extract.R diff --git a/tests/testthat/test-proportional_hazards.R b/tests/testthat/test_proportional_hazards.R similarity index 100% rename from tests/testthat/test-proportional_hazards.R rename to tests/testthat/test_proportional_hazards.R diff --git a/tests/testthat/test-survival_reg.R b/tests/testthat/test_survival_reg.R similarity index 100% rename from tests/testthat/test-survival_reg.R rename to tests/testthat/test_survival_reg.R