From 5ce561dafe136c93454d6df3b63d6d521d3877d1 Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Tue, 1 Feb 2022 21:56:08 -0500 Subject: [PATCH 01/13] glmnet autoplot method --- DESCRIPTION | 5 +- NAMESPACE | 3 + NEWS.md | 3 + R/0_imports.R | 1 + R/aaa.R | 3 +- R/autoplot.R | 156 ++++++++++++++++++++++++++++++++++++++ R/zzz.R | 61 +-------------- man/autoplot.model_fit.Rd | 37 +++++++++ 8 files changed, 208 insertions(+), 61 deletions(-) create mode 100644 R/autoplot.R create mode 100644 man/autoplot.model_fit.Rd diff --git a/DESCRIPTION b/DESCRIPTION index 716857000..5d402e28e 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: parsnip Title: A Common API to Modeling and Analysis Functions -Version: 0.1.7.9005 +Version: 0.1.7.9006 Authors@R: c( person("Max", "Kuhn", , "max@rstudio.com", role = c("aut", "cre")), person("Davis", "Vaughan", , "davis@rstudio.com", role = "aut"), @@ -20,6 +20,8 @@ Depends: Imports: dplyr (>= 0.8.0.1), generics (>= 0.1.0.9000), + ggplot2, + ggrepel, globals, glue, hardhat (>= 0.1.6.9001), @@ -39,7 +41,6 @@ Suggests: covr, dials (>= 0.0.10.9001), earth, - ggplot2, keras, kernlab, kknn, diff --git a/NAMESPACE b/NAMESPACE index 1ef6d6736..21cef744c 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -165,6 +165,8 @@ export(C5.0_train) export(C5_rules) export(add_rowindex) export(augment) +export(autoplot.glmnet) +export(autoplot.model_fit) export(bag_mars) export(bag_tree) export(bart) @@ -309,6 +311,7 @@ importFrom(generics,glance) importFrom(generics,required_pkgs) importFrom(generics,tidy) importFrom(generics,varying_args) +importFrom(ggplot2,autoplot) importFrom(glue,glue_collapse) importFrom(hardhat,extract_fit_engine) importFrom(hardhat,extract_parameter_dials) diff --git a/NEWS.md b/NEWS.md index f16dbcf6d..aa12045aa 100644 --- a/NEWS.md +++ b/NEWS.md @@ -40,6 +40,9 @@ * `varying_args()` is soft-deprecated in favor of `tune_args()`. +* An `autoplot()` method was added for glmnet objects, showing the coefficient paths versus the penalty values (#642). + + # parsnip 0.1.7 ## Model Specification Changes diff --git a/R/0_imports.R b/R/0_imports.R index 0ea5bb959..1fd648e5d 100644 --- a/R/0_imports.R +++ b/R/0_imports.R @@ -16,4 +16,5 @@ #' @importFrom utils capture.output getFromNamespace globalVariables head #' @importFrom utils methods stack #' @importFrom vctrs vec_size vec_unique +#' @importFrom ggplot2 autoplot NULL diff --git a/R/aaa.R b/R/aaa.R index d55ce51d8..f50f4f71f 100644 --- a/R/aaa.R +++ b/R/aaa.R @@ -91,7 +91,8 @@ utils::globalVariables( "max_terms", "max_tree", "model", "name", "num_terms", "penalty", "trees", "sub_neighbors", ".pred_class", "x", "y", "predictor_indicators", "compute_intercept", "remove_intercept", "estimate", "term", - "call_info", "component", "component_id", "func", "tunable") + "call_info", "component", "component_id", "func", "tunable", "label" + ) ) # nocov end diff --git a/R/autoplot.R b/R/autoplot.R new file mode 100644 index 000000000..d5425907b --- /dev/null +++ b/R/autoplot.R @@ -0,0 +1,156 @@ +#' Create a ggplot for a model object +#' +#' This method provides a good visualization method for model results. +#' Currently, only methods for glmnet models are implemented. +#' +#' @param object A model fit object. +#' @param min_penalty A single, non-negative number for the smallest penalty +#' value that should be shown in the plot. If left `NULL`, the whole data +#' range is used. +#' @param best_penalty A single, non-negative number that will show a vertical +#' line marker. If left `NULL`, no line is shown. +#' @param top_n A non-negative integer for how many model predictors to label. +#' The top predictors are ranked by their absolute coefficient value. For +#' multinomial or multivariate models, the `top_n` terms are selected within +#' class or response, respectively. +#' @param ... For [autoplot.glmnet()], options to pass to +#' [ggrepel::geom_label_repel()]. Otherwise, this argument is ignored. +#' @return A ggplot object with penalty on the x-axis and coefficients on the +#' y-axis. For multinomial or multivariate models, the plot is faceted. +#' +#' @export +autoplot.model_fit <- function(object, ...) { + autoplot(object$fit, ...) +} + +# glmnet is not a formal dpendency here. +# unit tests are located at https://github.com/tidymodels/extratests + +# nocov start + +#' @export +#' @rdname autoplot.model_fit +autoplot.glmnet <- function(object, min_penalty = 0, best_penalty = NULL, + top_n = 3L, ...) { + autoplot_glmnet(object, min_penalty, best_penalty, top_n, ...) +} + + +map_glmnet_coefs <- function(x) { + coefs <- coef(x) + p <- x$dim[1] + if (is.list(coefs)) { + classes <- names(coefs) + coefs <- purrr::map(coefs, reformat_coefs, p = p, penalty = x$lambda) + coefs <- purrr::map2_dfr(coefs, classes, ~ .x %>% dplyr::mutate(class = .y)) + } else { + coefs <- reformat_coefs(coefs, p = p, penalty = x$lambda) + } + coefs +} + +reformat_coefs <- function(x, p, penalty) { + x <- as.matrix(x) + num_estimates <- nrow(x) + if (num_estimates > p) { + # The intercept is first + x <- x[-(num_estimates - p),, drop = FALSE] + } + term_lab <- rownames(x) + colnames(x) <- paste(seq_along(penalty)) + x <- tibble::as_tibble(x) + x$term <- term_lab + x <- tidyr::pivot_longer(x, cols = c(-term), names_to = "index", values_to = "estimate") + x$penalty <- rep(penalty, p) + x$index <- NULL + x +} + +top_coefs <- function(x, top_n = 5) { + x %>% + dplyr::group_by(term) %>% + dplyr::arrange(term, dplyr::desc(abs(estimate))) %>% + dplyr::slice(1) %>% + dplyr::ungroup() %>% + dplyr::arrange(dplyr::desc(abs(estimate))) %>% + dplyr::slice(1:top_n) +} + +autoplot_glmnet <- function(x, min_penalty = 0, best_penalty = NULL, top_n = 3L, ...) { + if (!is.null(min_penalty)) { + check_penalty_value(min_penalty) + } + + tidy_coefs <- + map_glmnet_coefs(x) %>% + dplyr::filter(penalty >= min_penalty) + + actual_min_penalty <- min(tidy_coefs$penalty) + num_terms <- length(unique(tidy_coefs$term)) + top_n <- min(top_n[1], num_terms) + if (top_n < 0) { + top_n <- 0 + } + + has_groups <- any(names(tidy_coefs) %in% c("class")) + + # Keep the large values + if (has_groups) { + label_coefs <- + tidy_coefs %>% + dplyr::group_nest(class) %>% + dplyr::mutate(data = purrr::map(data, top_coefs, top_n = top_n)) %>% + dplyr::select(class, data) %>% + tidyr::unnest(cols = data) + } else { + label_coefs <- + tidy_coefs %>% + top_coefs(top_n) + } + + label_coefs <- + label_coefs %>% + dplyr::mutate(penalty = actual_min_penalty) %>% + dplyr::mutate(label = gsub(".pred_no_", "", term)) + + # plot the paths and highlight the large values + p <- + tidy_coefs %>% + ggplot2::ggplot(ggplot2::aes(x = penalty, y = estimate, group = term, col = term)) + + if (has_groups) { + p <- p + ggplot2::facet_wrap(~ class) + } + + if (!is.null(best_penalty)) { + check_penalty_value(best_penalty) + p <- p + ggplot2::geom_vline(xintercept = best_penalty, lty = 3) + } + + p <- p + + ggplot2::geom_line(alpha = .4, show.legend = FALSE) + + ggplot2::scale_x_log10() + + if(top_n > 0) { + p <- p + + ggrepel::geom_label_repel( + data = label_coefs, + ggplot2::aes(y = estimate, label = label), + show.legend = FALSE, + ... + ) + } + p +} + +check_penalty_value <- function(x) { + cl <- match.call() + arg_val <- as.character(cl$x) + if (!is.vector(x) || length(x) != 1 || !is.numeric(x) || x < 0) { + msg <- paste0("Argument '", arg_val, "' should be a single, non-negative value.") + rlang::abort(msg) + } + invisible(x) +} + +# nocov stop diff --git a/R/zzz.R b/R/zzz.R index 1cf9c353c..60e919b37 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -12,6 +12,9 @@ s3_register("generics::required_pkgs", "model_fit") s3_register("generics::required_pkgs", "model_spec") + s3_register("ggplot2::autoplot", "model_fit") + s3_register("ggplot2::autoplot", "glmnet") + # - If tune isn't installed, register the method (`packageVersion()` will error here) # - If tune >= 0.1.6.9001 is installed, register the method should_register_tune_args_method <- tryCatch( @@ -108,61 +111,3 @@ s3_register <- function(generic, class, method = NULL) { # nocov end - - -#' ## nocov start -#' -#' data_obj <- ls(pattern = "_data$") -#' data_obj <- data_obj[data_obj != "prepare_data"] -#' -#' data_names <- -#' map_dfr( -#' data_obj, -#' function(x) { -#' module <- names(get(x)) -#' if (length(module) > 1) { -#' module <- table(module) -#' module <- as_tibble(module) -#' module$object <- x -#' module -#' } else -#' module <- NULL -#' module -#' } -#' ) -#' -#' if(any(data_names$n > 1)) { -#' print(data_names[data_names$n > 1,]) -#' rlang::abort("Some models have duplicate module names.") -#' } -#' rm(data_names) -#' -#' # ------------------------------------------------------------------------------ -#' -#' engine_objects <- ls(pattern = "_engines$") -#' engine_objects <- engine_objects[engine_objects != "possible_engines"] -#' -#' get_engine_info <- function(x) { -#' y <- x -#' y <- get(y) -#' z <- stack(y) -#' z$mode <- rownames(y) -#' z$model <- gsub("_engines$", "", x) -#' z$object <- x -#' z <- z[z$values,] -#' z <- z[z$mode != "unknown",] -#' z$values <- NULL -#' names(z)[1] <- "engine" -#' z$engine <- as.character(z$engine) -#' z -#' } -#' -#' engine_info <- -#' purrr::map_df( -#' parsnip:::engine_objects, -#' get_engine_info -#' ) -#' -#' rm(engine_objects) -#' -#' ## nocov end diff --git a/man/autoplot.model_fit.Rd b/man/autoplot.model_fit.Rd new file mode 100644 index 000000000..1618b4213 --- /dev/null +++ b/man/autoplot.model_fit.Rd @@ -0,0 +1,37 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/autoplot.R +\name{autoplot.model_fit} +\alias{autoplot.model_fit} +\alias{autoplot.glmnet} +\title{Create a ggplot for a model object} +\usage{ +autoplot.model_fit(object, ...) + +autoplot.glmnet(object, min_penalty = 0, best_penalty = NULL, top_n = 3L, ...) +} +\arguments{ +\item{object}{A model fit object.} + +\item{...}{For \code{\link[=autoplot.glmnet]{autoplot.glmnet()}}, options to pass to +\code{\link[ggrepel:geom_text_repel]{ggrepel::geom_label_repel()}}. Otherwise, this argument is ignored.} + +\item{min_penalty}{A single, non-negative number for the smallest penalty +value that should be shown in the plot. If left \code{NULL}, the whole data +range is used.} + +\item{best_penalty}{A single, non-negative number that will show a vertical +line marker. If left \code{NULL}, no line is shown.} + +\item{top_n}{A non-negative integer for how many model predictors to label. +The top predictors are ranked by their absolute coefficient value. For +multinomial or multivariate models, the \code{top_n} terms are selected within +class or response, respectively.} +} +\value{ +A ggplot object with penalty on the x-axis and coefficients on the +y-axis. For multinomial or multivariate models, the plot is faceted. +} +\description{ +This method provides a good visualization method for model results. +Currently, only methods for glmnet models are implemented. +} From 3e0b4aabd00cf3136d09fcb23fcc734ace997418 Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Tue, 1 Feb 2022 22:04:05 -0500 Subject: [PATCH 02/13] workflow methods --- NAMESPACE | 5 +++-- R/autoplot.R | 6 ++++++ man/autoplot.model_fit.Rd | 7 +++++-- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index 21cef744c..9581c7fe6 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -1,6 +1,9 @@ # Generated by roxygen2: do not edit by hand S3method(augment,model_fit) +S3method(autoplot,glmnet) +S3method(autoplot,model_fit) +S3method(autoplot,workflow) S3method(extract_fit_engine,model_fit) S3method(extract_parameter_dials,model_spec) S3method(extract_parameter_set_dials,model_spec) @@ -165,8 +168,6 @@ export(C5.0_train) export(C5_rules) export(add_rowindex) export(augment) -export(autoplot.glmnet) -export(autoplot.model_fit) export(bag_mars) export(bag_tree) export(bart) diff --git a/R/autoplot.R b/R/autoplot.R index d5425907b..ff7745a2d 100644 --- a/R/autoplot.R +++ b/R/autoplot.R @@ -23,6 +23,12 @@ autoplot.model_fit <- function(object, ...) { autoplot(object$fit, ...) } +#' @export +#' @rdname autoplot.model_fit +autoplot.workflow <- function(object, ...) { + object %>% extract_fit_engine %>% autoplot(...) +} + # glmnet is not a formal dpendency here. # unit tests are located at https://github.com/tidymodels/extratests diff --git a/man/autoplot.model_fit.Rd b/man/autoplot.model_fit.Rd index 1618b4213..9fae91c37 100644 --- a/man/autoplot.model_fit.Rd +++ b/man/autoplot.model_fit.Rd @@ -2,12 +2,15 @@ % Please edit documentation in R/autoplot.R \name{autoplot.model_fit} \alias{autoplot.model_fit} +\alias{autoplot.workflow} \alias{autoplot.glmnet} \title{Create a ggplot for a model object} \usage{ -autoplot.model_fit(object, ...) +\method{autoplot}{model_fit}(object, ...) -autoplot.glmnet(object, min_penalty = 0, best_penalty = NULL, top_n = 3L, ...) +\method{autoplot}{workflow}(object, ...) + +\method{autoplot}{glmnet}(object, min_penalty = 0, best_penalty = NULL, top_n = 3L, ...) } \arguments{ \item{object}{A model fit object.} From 880ff96d27befa11c7faf014c468fa5c03306985 Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Wed, 2 Feb 2022 05:52:35 -0500 Subject: [PATCH 03/13] add pkgdown entry --- _pkgdown.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/_pkgdown.yml b/_pkgdown.yml index 8897bed6d..cb99d5bbd 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -48,6 +48,7 @@ reference: - svm_rbf - title: Infrastructure contents: + - autoplot.model_fit - add_rowindex - augment.model_fit - descriptors From 783d32316d0ff1e47aa543d67614cd58bc32f2ee Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Tue, 8 Feb 2022 17:12:46 -0800 Subject: [PATCH 04/13] makes labels appear on best_penalty line if present --- R/autoplot.R | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/R/autoplot.R b/R/autoplot.R index ff7745a2d..f1ac44259 100644 --- a/R/autoplot.R +++ b/R/autoplot.R @@ -109,14 +109,21 @@ autoplot_glmnet <- function(x, min_penalty = 0, best_penalty = NULL, top_n = 3L, dplyr::select(class, data) %>% tidyr::unnest(cols = data) } else { - label_coefs <- - tidy_coefs %>% - top_coefs(top_n) + if (is.null(best_penalty)) { + label_coefs <- tidy_coefs %>% + top_coefs(top_n) + } else { + label_coefs <- tidy_coefs %>% + filter(penalty > best_penalty) %>% + filter(penalty == min(penalty)) %>% + arrange(desc(abs(estimate))) %>% + slice(seq_len(top_n)) + } } label_coefs <- label_coefs %>% - dplyr::mutate(penalty = actual_min_penalty) %>% + dplyr::mutate(penalty = best_penalty %||% actual_min_penalty) %>% dplyr::mutate(label = gsub(".pred_no_", "", term)) # plot the paths and highlight the large values From 86b9448782d23a8dce39a62b91182759c1a65433 Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Tue, 15 Feb 2022 16:26:33 -0500 Subject: [PATCH 05/13] Apply suggestions from code review Co-authored-by: Davis Vaughan --- R/autoplot.R | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/R/autoplot.R b/R/autoplot.R index f1ac44259..92990642c 100644 --- a/R/autoplot.R +++ b/R/autoplot.R @@ -48,7 +48,7 @@ map_glmnet_coefs <- function(x) { if (is.list(coefs)) { classes <- names(coefs) coefs <- purrr::map(coefs, reformat_coefs, p = p, penalty = x$lambda) - coefs <- purrr::map2_dfr(coefs, classes, ~ .x %>% dplyr::mutate(class = .y)) + coefs <- purrr::map2_dfr(coefs, classes, ~ dplyr::mutate(.x, class = .y)) } else { coefs <- reformat_coefs(coefs, p = p, penalty = x$lambda) } @@ -66,7 +66,7 @@ reformat_coefs <- function(x, p, penalty) { colnames(x) <- paste(seq_along(penalty)) x <- tibble::as_tibble(x) x$term <- term_lab - x <- tidyr::pivot_longer(x, cols = c(-term), names_to = "index", values_to = "estimate") + x <- tidyr::pivot_longer(x, cols = -term, names_to = "index", values_to = "estimate") x$penalty <- rep(penalty, p) x$index <- NULL x @@ -98,7 +98,7 @@ autoplot_glmnet <- function(x, min_penalty = 0, best_penalty = NULL, top_n = 3L, top_n <- 0 } - has_groups <- any(names(tidy_coefs) %in% c("class")) + has_groups <- any(names(tidy_coefs) == "class") # Keep the large values if (has_groups) { From c11c783e6091045bf68ed65c98f666ed95d4bd80 Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Tue, 15 Feb 2022 19:06:52 -0500 Subject: [PATCH 06/13] move ggrepl to suggests --- DESCRIPTION | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 15f2e0fc0..91e7b6fa3 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -22,7 +22,6 @@ Imports: dplyr (>= 0.8.0.1), generics (>= 0.1.0.9000), ggplot2, - ggrepel, globals, glue, hardhat (>= 0.1.6.9001), @@ -43,7 +42,7 @@ Suggests: dials (>= 0.0.10.9001), earth, tensorflow, - ggplot2, + ggrepel, keras, kernlab, kknn, From c1ec27295ec49af6fb06a525e4922c240109068f Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Tue, 15 Feb 2022 19:07:10 -0500 Subject: [PATCH 07/13] doc updates --- man/details_proportional_hazards_glmnet.Rd | 2 +- man/details_proportional_hazards_survival.Rd | 2 +- man/reexports.Rd | 3 +++ 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/man/details_proportional_hazards_glmnet.Rd b/man/details_proportional_hazards_glmnet.Rd index ae235edda..eaa584b45 100644 --- a/man/details_proportional_hazards_glmnet.Rd +++ b/man/details_proportional_hazards_glmnet.Rd @@ -116,7 +116,7 @@ tidymodels does not treat different models differently when computing performance metrics. To standardize across model types, the default for proportional hazards models is to have \emph{increasing values with time}. As a result, the sign of the linear predictor will be the opposite of the -value produced by the \code{predict()} method in the package. +value produced by the \code{predict()} method in the engine package. This behavior can be changed by using the \code{increasing} argument when calling \code{predict()} on a model object. diff --git a/man/details_proportional_hazards_survival.Rd b/man/details_proportional_hazards_survival.Rd index 5d6e57a2e..84fe2fb40 100644 --- a/man/details_proportional_hazards_survival.Rd +++ b/man/details_proportional_hazards_survival.Rd @@ -100,7 +100,7 @@ tidymodels does not treat different models differently when computing performance metrics. To standardize across model types, the default for proportional hazards models is to have \emph{increasing values with time}. As a result, the sign of the linear predictor will be the opposite of the -value produced by the \code{predict()} method in the package. +value produced by the \code{predict()} method in the engine package. This behavior can be changed by using the \code{increasing} argument when calling \code{predict()} on a model object. diff --git a/man/reexports.Rd b/man/reexports.Rd index f0ecfc522..3a5f8c898 100644 --- a/man/reexports.Rd +++ b/man/reexports.Rd @@ -3,6 +3,7 @@ \docType{import} \name{reexports} \alias{reexports} +\alias{autoplot} \alias{\%>\%} \alias{fit} \alias{fit_xy} @@ -25,6 +26,8 @@ below to see their documentation. \describe{ \item{generics}{\code{\link[generics]{augment}}, \code{\link[generics]{fit}}, \code{\link[generics]{fit_xy}}, \code{\link[generics]{glance}}, \code{\link[generics]{required_pkgs}}, \code{\link[generics]{tidy}}, \code{\link[generics]{varying_args}}} + \item{ggplot2}{\code{\link[ggplot2]{autoplot}}} + \item{hardhat}{\code{\link[hardhat:hardhat-extract]{extract_fit_engine}}, \code{\link[hardhat:hardhat-extract]{extract_parameter_dials}}, \code{\link[hardhat:hardhat-extract]{extract_parameter_set_dials}}, \code{\link[hardhat:hardhat-extract]{extract_spec_parsnip}}, \code{\link[hardhat]{tune}}} \item{magrittr}{\code{\link[magrittr:pipe]{\%>\%}}} From e10a6b96693cad5fc5516744c6a7998b04c43194 Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Tue, 15 Feb 2022 19:07:34 -0500 Subject: [PATCH 08/13] anmespace functions and check for glmnet package --- NAMESPACE | 5 ++--- R/0_imports.R | 2 +- R/autoplot.R | 23 +++++++++++++++-------- R/reexports.R | 3 +++ 4 files changed, 21 insertions(+), 12 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index 2798c6d40..5236d3b7f 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -1,9 +1,6 @@ # Generated by roxygen2: do not edit by hand S3method(augment,model_fit) -S3method(autoplot,glmnet) -S3method(autoplot,model_fit) -S3method(autoplot,workflow) S3method(extract_fit_engine,model_fit) S3method(extract_parameter_dials,model_spec) S3method(extract_parameter_set_dials,model_spec) @@ -168,6 +165,7 @@ export(C5.0_train) export(C5_rules) export(add_rowindex) export(augment) +export(autoplot) export(bag_mars) export(bag_tree) export(bart) @@ -325,6 +323,7 @@ importFrom(hardhat,extract_parameter_set_dials) importFrom(hardhat,extract_spec_parsnip) importFrom(hardhat,tune) importFrom(magrittr,"%>%") +importFrom(purrr,"%||%") importFrom(purrr,as_vector) importFrom(purrr,imap) importFrom(purrr,imap_lgl) diff --git a/R/0_imports.R b/R/0_imports.R index 1fd648e5d..a5a747b36 100644 --- a/R/0_imports.R +++ b/R/0_imports.R @@ -3,7 +3,7 @@ #' @importFrom generics varying_args #' @importFrom glue glue_collapse #' @importFrom purrr as_vector imap imap_lgl map map_chr map_dbl map_df map_dfr -#' @importFrom purrr map_lgl +#' @importFrom purrr map_lgl %||% #' @importFrom rlang abort call2 caller_env current_env enquo enquos eval_tidy #' @importFrom rlang expr get_expr is_empty is_missing is_null is_quosure #' @importFrom rlang is_symbolic lgl missing_arg quo_get_expr quos sym syms diff --git a/R/autoplot.R b/R/autoplot.R index 92990642c..6e638cec2 100644 --- a/R/autoplot.R +++ b/R/autoplot.R @@ -18,23 +18,23 @@ #' @return A ggplot object with penalty on the x-axis and coefficients on the #' y-axis. For multinomial or multivariate models, the plot is faceted. #' -#' @export +# registered in zzz.R autoplot.model_fit <- function(object, ...) { autoplot(object$fit, ...) } -#' @export +# registered in zzz.R #' @rdname autoplot.model_fit autoplot.workflow <- function(object, ...) { object %>% extract_fit_engine %>% autoplot(...) } -# glmnet is not a formal dpendency here. +# glmnet is not a formal dependency here. # unit tests are located at https://github.com/tidymodels/extratests # nocov start -#' @export +# registered in zzz.R #' @rdname autoplot.model_fit autoplot.glmnet <- function(object, min_penalty = 0, best_penalty = NULL, top_n = 3L, ...) { @@ -44,6 +44,12 @@ autoplot.glmnet <- function(object, min_penalty = 0, best_penalty = NULL, map_glmnet_coefs <- function(x) { coefs <- coef(x) + # If parsnip is used to fit the model, glmnet should be attached and this will + # work. If an object is loaded from a new session, they will need to load the + # package. + if (is.null(coefs)) { + rlang::abort("Please load the glmnet package before running `autoplot()`.") + } p <- x$dim[1] if (is.list(coefs)) { classes <- names(coefs) @@ -114,10 +120,10 @@ autoplot_glmnet <- function(x, min_penalty = 0, best_penalty = NULL, top_n = 3L, top_coefs(top_n) } else { label_coefs <- tidy_coefs %>% - filter(penalty > best_penalty) %>% - filter(penalty == min(penalty)) %>% - arrange(desc(abs(estimate))) %>% - slice(seq_len(top_n)) + dplyr::filter(penalty > best_penalty) %>% + dplyr::filter(penalty == min(penalty)) %>% + dplyr::arrange(dplyr::desc(abs(estimate))) %>% + dplyr::slice(seq_len(top_n)) } } @@ -145,6 +151,7 @@ autoplot_glmnet <- function(x, min_penalty = 0, best_penalty = NULL, top_n = 3L, ggplot2::scale_x_log10() if(top_n > 0) { + rlang::check_installed("ggrepel") p <- p + ggrepel::geom_label_repel( data = label_coefs, diff --git a/R/reexports.R b/R/reexports.R index 3a1a22227..e26510794 100644 --- a/R/reexports.R +++ b/R/reexports.R @@ -1,3 +1,6 @@ +#' @importFrom ggplot2 autoplot +#' @export +ggplot2::autoplot #' @importFrom magrittr %>% #' @export From 38db0d0b843aac8393373ecc5f850064b07cc59f Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Tue, 15 Feb 2022 19:10:00 -0500 Subject: [PATCH 09/13] remove workflow method --- R/autoplot.R | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/R/autoplot.R b/R/autoplot.R index 6e638cec2..1cbeaafed 100644 --- a/R/autoplot.R +++ b/R/autoplot.R @@ -17,21 +17,16 @@ #' [ggrepel::geom_label_repel()]. Otherwise, this argument is ignored. #' @return A ggplot object with penalty on the x-axis and coefficients on the #' y-axis. For multinomial or multivariate models, the plot is faceted. +#' @details The \pkg{glmnet} package will need to be attached or loaded for +#' its `autoplot()` method to work correctly. #' # registered in zzz.R autoplot.model_fit <- function(object, ...) { autoplot(object$fit, ...) } -# registered in zzz.R -#' @rdname autoplot.model_fit -autoplot.workflow <- function(object, ...) { - object %>% extract_fit_engine %>% autoplot(...) -} - # glmnet is not a formal dependency here. # unit tests are located at https://github.com/tidymodels/extratests - # nocov start # registered in zzz.R @@ -89,9 +84,7 @@ top_coefs <- function(x, top_n = 5) { } autoplot_glmnet <- function(x, min_penalty = 0, best_penalty = NULL, top_n = 3L, ...) { - if (!is.null(min_penalty)) { - check_penalty_value(min_penalty) - } + check_penalty_value(min_penalty) tidy_coefs <- map_glmnet_coefs(x) %>% From 45c796b4e3d4d28296c7b5a26df7a97f8a8c5b44 Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Tue, 15 Feb 2022 21:28:24 -0500 Subject: [PATCH 10/13] move ... up in order --- R/autoplot.R | 4 ++-- man/autoplot.model_fit.Rd | 9 +++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/R/autoplot.R b/R/autoplot.R index 1cbeaafed..96f805b16 100644 --- a/R/autoplot.R +++ b/R/autoplot.R @@ -31,8 +31,8 @@ autoplot.model_fit <- function(object, ...) { # registered in zzz.R #' @rdname autoplot.model_fit -autoplot.glmnet <- function(object, min_penalty = 0, best_penalty = NULL, - top_n = 3L, ...) { +autoplot.glmnet <- function(object, ..., min_penalty = 0, best_penalty = NULL, + top_n = 3L) { autoplot_glmnet(object, min_penalty, best_penalty, top_n, ...) } diff --git a/man/autoplot.model_fit.Rd b/man/autoplot.model_fit.Rd index 9fae91c37..3843de663 100644 --- a/man/autoplot.model_fit.Rd +++ b/man/autoplot.model_fit.Rd @@ -2,15 +2,12 @@ % Please edit documentation in R/autoplot.R \name{autoplot.model_fit} \alias{autoplot.model_fit} -\alias{autoplot.workflow} \alias{autoplot.glmnet} \title{Create a ggplot for a model object} \usage{ \method{autoplot}{model_fit}(object, ...) -\method{autoplot}{workflow}(object, ...) - -\method{autoplot}{glmnet}(object, min_penalty = 0, best_penalty = NULL, top_n = 3L, ...) +\method{autoplot}{glmnet}(object, ..., min_penalty = 0, best_penalty = NULL, top_n = 3L) } \arguments{ \item{object}{A model fit object.} @@ -38,3 +35,7 @@ y-axis. For multinomial or multivariate models, the plot is faceted. This method provides a good visualization method for model results. Currently, only methods for glmnet models are implemented. } +\details{ +The \pkg{glmnet} package will need to be attached or loaded for +its \code{autoplot()} method to work correctly. +} From 56ea1c0a9bf477ad9ea1cb38ad86d6cb1cc1515f Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Tue, 15 Feb 2022 21:31:49 -0500 Subject: [PATCH 11/13] A note about ggrepl --- R/autoplot.R | 3 ++- man/autoplot.model_fit.Rd | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/R/autoplot.R b/R/autoplot.R index 96f805b16..00538205a 100644 --- a/R/autoplot.R +++ b/R/autoplot.R @@ -8,7 +8,8 @@ #' value that should be shown in the plot. If left `NULL`, the whole data #' range is used. #' @param best_penalty A single, non-negative number that will show a vertical -#' line marker. If left `NULL`, no line is shown. +#' line marker. If left `NULL`, no line is shown. When this argument is used, +#' the \pkg{ggrepl} package is required. #' @param top_n A non-negative integer for how many model predictors to label. #' The top predictors are ranked by their absolute coefficient value. For #' multinomial or multivariate models, the `top_n` terms are selected within diff --git a/man/autoplot.model_fit.Rd b/man/autoplot.model_fit.Rd index 3843de663..1f0c5ab7e 100644 --- a/man/autoplot.model_fit.Rd +++ b/man/autoplot.model_fit.Rd @@ -20,7 +20,8 @@ value that should be shown in the plot. If left \code{NULL}, the whole data range is used.} \item{best_penalty}{A single, non-negative number that will show a vertical -line marker. If left \code{NULL}, no line is shown.} +line marker. If left \code{NULL}, no line is shown. When this argument is used, +the \pkg{ggrepl} package is required.} \item{top_n}{A non-negative integer for how many model predictors to label. The top predictors are ranked by their absolute coefficient value. For From bfbc26a863af5de33eac9595a0197484aef4ac56 Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Wed, 16 Feb 2022 18:18:24 -0500 Subject: [PATCH 12/13] more consistent test files --- tests/testthat/{test-augment.R => test_augment.R} | 0 tests/testthat/{test-extract.R => test_extract.R} | 0 .../{test-proportional_hazards.R => test_proportional_hazards.R} | 0 tests/testthat/{test-survival_reg.R => test_survival_reg.R} | 0 4 files changed, 0 insertions(+), 0 deletions(-) rename tests/testthat/{test-augment.R => test_augment.R} (100%) rename tests/testthat/{test-extract.R => test_extract.R} (100%) rename tests/testthat/{test-proportional_hazards.R => test_proportional_hazards.R} (100%) rename tests/testthat/{test-survival_reg.R => test_survival_reg.R} (100%) diff --git a/tests/testthat/test-augment.R b/tests/testthat/test_augment.R similarity index 100% rename from tests/testthat/test-augment.R rename to tests/testthat/test_augment.R diff --git a/tests/testthat/test-extract.R b/tests/testthat/test_extract.R similarity index 100% rename from tests/testthat/test-extract.R rename to tests/testthat/test_extract.R diff --git a/tests/testthat/test-proportional_hazards.R b/tests/testthat/test_proportional_hazards.R similarity index 100% rename from tests/testthat/test-proportional_hazards.R rename to tests/testthat/test_proportional_hazards.R diff --git a/tests/testthat/test-survival_reg.R b/tests/testthat/test_survival_reg.R similarity index 100% rename from tests/testthat/test-survival_reg.R rename to tests/testthat/test_survival_reg.R From 862b15b5aa09fb0d8e526c24337215618cf749bf Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Wed, 16 Feb 2022 19:00:59 -0500 Subject: [PATCH 13/13] fix nocov tags --- R/aaa.R | 4 ++-- R/autoplot.R | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/R/aaa.R b/R/aaa.R index 1bc0fc241..5137a54d8 100644 --- a/R/aaa.R +++ b/R/aaa.R @@ -82,7 +82,7 @@ function(results, object) { } # ------------------------------------------------------------------------------ -# nocov +# nocov start utils::globalVariables( c('.', '.label', '.pred', '.row', 'data', 'engine', 'engine2', 'group', @@ -91,7 +91,7 @@ utils::globalVariables( "max_terms", "max_tree", "model", "name", "num_terms", "penalty", "trees", "sub_neighbors", ".pred_class", "x", "y", "predictor_indicators", "compute_intercept", "remove_intercept", "estimate", "term", - "call_info", "component", "component_id", "func", "tunable", "label", + "call_info", "component", "component_id", "func", "tunable", "label", "pkg", ".order", "item", "tunable" ) ) diff --git a/R/autoplot.R b/R/autoplot.R index 00538205a..f18933cdd 100644 --- a/R/autoplot.R +++ b/R/autoplot.R @@ -167,4 +167,4 @@ check_penalty_value <- function(x) { invisible(x) } -# nocov stop +# nocov end