Merged
5 changes: 3 additions & 2 deletions DESCRIPTION
@@ -1,6 +1,6 @@
Package: parsnip
Title: A Common API to Modeling and Analysis Functions
Version: 0.1.7.9005
Version: 0.1.7.9006
Authors@R: c(
person("Max", "Kuhn", , "max@rstudio.com", role = c("aut", "cre")),
person("Davis", "Vaughan", , "davis@rstudio.com", role = "aut"),
@@ -21,6 +21,7 @@ Imports:
cli,
dplyr (>= 0.8.0.1),
generics (>= 0.1.0.9000),
ggplot2,
globals,
glue,
hardhat (>= 0.1.6.9001),
@@ -41,7 +42,7 @@ Suggests:
dials (>= 0.0.10.9001),
earth,
tensorflow,
ggplot2,
ggrepel,
keras,
kernlab,
kknn,
3 changes: 3 additions & 0 deletions NAMESPACE
@@ -165,6 +165,7 @@ export(C5.0_train)
export(C5_rules)
export(add_rowindex)
export(augment)
export(autoplot)
export(bag_mars)
export(bag_tree)
export(bart)
@@ -314,13 +315,15 @@ importFrom(generics,glance)
importFrom(generics,required_pkgs)
importFrom(generics,tidy)
importFrom(generics,varying_args)
importFrom(ggplot2,autoplot)
importFrom(glue,glue_collapse)
importFrom(hardhat,extract_fit_engine)
importFrom(hardhat,extract_parameter_dials)
importFrom(hardhat,extract_parameter_set_dials)
importFrom(hardhat,extract_spec_parsnip)
importFrom(hardhat,tune)
importFrom(magrittr,"%>%")
importFrom(purrr,"%||%")
importFrom(purrr,as_vector)
importFrom(purrr,imap)
importFrom(purrr,imap_lgl)
3 changes: 3 additions & 0 deletions NEWS.md
@@ -40,10 +40,13 @@

* `varying_args()` is soft-deprecated in favor of `tune_args()`.

* An `autoplot()` method was added for glmnet objects, showing the coefficient paths versus the penalty values (#642).

* parsnip is now more robust working with keras and tensorflow for a larger range of versions (#596).

* xgboost engines now use the new `iterationrange` parameter instead of the deprecated `ntreelimit` (#656).


# parsnip 0.1.7

## Model Specification Changes
3 changes: 2 additions & 1 deletion R/0_imports.R
@@ -3,7 +3,7 @@
#' @importFrom generics varying_args
#' @importFrom glue glue_collapse
#' @importFrom purrr as_vector imap imap_lgl map map_chr map_dbl map_df map_dfr
#' @importFrom purrr map_lgl
#' @importFrom purrr map_lgl %||%
#' @importFrom rlang abort call2 caller_env current_env enquo enquos eval_tidy
#' @importFrom rlang expr get_expr is_empty is_missing is_null is_quosure
#' @importFrom rlang is_symbolic lgl missing_arg quo_get_expr quos sym syms
@@ -16,4 +16,5 @@
#' @importFrom utils capture.output getFromNamespace globalVariables head
#' @importFrom utils methods stack
#' @importFrom vctrs vec_size vec_unique
#' @importFrom ggplot2 autoplot
NULL
6 changes: 4 additions & 2 deletions R/aaa.R
@@ -82,7 +82,7 @@ function(results, object) {
}

# ------------------------------------------------------------------------------
# nocov
# nocov start

utils::globalVariables(
c('.', '.label', '.pred', '.row', 'data', 'engine', 'engine2', 'group',
@@ -91,7 +91,9 @@ utils::globalVariables(
"max_terms", "max_tree", "model", "name", "num_terms", "penalty", "trees",
"sub_neighbors", ".pred_class", "x", "y", "predictor_indicators",
"compute_intercept", "remove_intercept", "estimate", "term",
"call_info", "component", "component_id", "func", "pkg", ".order", "item", "tunable")
"call_info", "component", "component_id", "func", "tunable", "label",
"pkg", ".order", "item", "tunable"
)
)

# nocov end
170 changes: 170 additions & 0 deletions R/autoplot.R
@@ -0,0 +1,170 @@
#' Create a ggplot for a model object
#'
#' This method provides a visualization of model results. Currently, methods
#' are implemented only for glmnet models.
#'
#' @param object A model fit object.
#' @param min_penalty A single, non-negative number for the smallest penalty
#' value that should be shown in the plot. The default, `0`, shows the whole
#' data range.
#' @param best_penalty A single, non-negative number that will show a vertical
#' line marker. If left `NULL`, no line is shown. When this argument is used,
#' the \pkg{ggrepel} package is required.
#' @param top_n A non-negative integer for how many model predictors to label.
#' The top predictors are ranked by their absolute coefficient value. For
#' multinomial or multivariate models, the `top_n` terms are selected within
#' class or response, respectively.
#' @param ... For [autoplot.glmnet()], options to pass to
#' [ggrepel::geom_label_repel()]. Otherwise, this argument is ignored.
#' @return A ggplot object with penalty on the x-axis and coefficients on the
#' y-axis. For multinomial or multivariate models, the plot is faceted.
#' @details The \pkg{glmnet} package needs to be attached or loaded for the
#' `autoplot()` method on glmnet objects to work correctly.
#'
# registered in zzz.R
autoplot.model_fit <- function(object, ...) {
autoplot(object$fit, ...)
}

# glmnet is not a formal dependency here.
# unit tests are located at https://github.com/tidymodels/extratests
# nocov start

# registered in zzz.R
#' @rdname autoplot.model_fit
autoplot.glmnet <- function(object, ..., min_penalty = 0, best_penalty = NULL,
top_n = 3L) {
autoplot_glmnet(object, min_penalty, best_penalty, top_n, ...)
}


map_glmnet_coefs <- function(x) {
coefs <- coef(x)
# If parsnip was used to fit the model, glmnet should be attached and this
# will work. If the object was loaded in a new session, the user will need
# to load the glmnet package first.
if (is.null(coefs)) {
rlang::abort("Please load the glmnet package before running `autoplot()`.")
}
p <- x$dim[1]
if (is.list(coefs)) {
classes <- names(coefs)
coefs <- purrr::map(coefs, reformat_coefs, p = p, penalty = x$lambda)
coefs <- purrr::map2_dfr(coefs, classes, ~ dplyr::mutate(.x, class = .y))
} else {
coefs <- reformat_coefs(coefs, p = p, penalty = x$lambda)
}
coefs
}

reformat_coefs <- function(x, p, penalty) {
x <- as.matrix(x)
num_estimates <- nrow(x)
if (num_estimates > p) {
# The intercept is first
x <- x[-(num_estimates - p),, drop = FALSE]
}
term_lab <- rownames(x)
colnames(x) <- paste(seq_along(penalty))
x <- tibble::as_tibble(x)
x$term <- term_lab
x <- tidyr::pivot_longer(x, cols = -term, names_to = "index", values_to = "estimate")
x$penalty <- rep(penalty, p)
x$index <- NULL
x
}

top_coefs <- function(x, top_n = 5) {
x %>%
dplyr::group_by(term) %>%
dplyr::arrange(term, dplyr::desc(abs(estimate))) %>%
dplyr::slice(1) %>%
dplyr::ungroup() %>%
dplyr::arrange(dplyr::desc(abs(estimate))) %>%
dplyr::slice(1:top_n)
}

autoplot_glmnet <- function(x, min_penalty = 0, best_penalty = NULL, top_n = 3L, ...) {
check_penalty_value(min_penalty)

tidy_coefs <-
map_glmnet_coefs(x) %>%
dplyr::filter(penalty >= min_penalty)

actual_min_penalty <- min(tidy_coefs$penalty)
num_terms <- length(unique(tidy_coefs$term))
top_n <- min(top_n[1], num_terms)
if (top_n < 0) {
top_n <- 0
}

has_groups <- any(names(tidy_coefs) == "class")

# Keep the large values
if (has_groups) {
label_coefs <-
tidy_coefs %>%
dplyr::group_nest(class) %>%
dplyr::mutate(data = purrr::map(data, top_coefs, top_n = top_n)) %>%
dplyr::select(class, data) %>%
tidyr::unnest(cols = data)
} else {
if (is.null(best_penalty)) {
label_coefs <- tidy_coefs %>%
top_coefs(top_n)
} else {
label_coefs <- tidy_coefs %>%
dplyr::filter(penalty > best_penalty) %>%
dplyr::filter(penalty == min(penalty)) %>%
dplyr::arrange(dplyr::desc(abs(estimate))) %>%
dplyr::slice(seq_len(top_n))
}
}

label_coefs <-
label_coefs %>%
dplyr::mutate(penalty = best_penalty %||% actual_min_penalty) %>%
dplyr::mutate(label = gsub(".pred_no_", "", term))

# plot the paths and highlight the large values
p <-
tidy_coefs %>%
ggplot2::ggplot(ggplot2::aes(x = penalty, y = estimate, group = term, col = term))

if (has_groups) {
p <- p + ggplot2::facet_wrap(~ class)
}

if (!is.null(best_penalty)) {
check_penalty_value(best_penalty)
p <- p + ggplot2::geom_vline(xintercept = best_penalty, lty = 3)
}

p <- p +
ggplot2::geom_line(alpha = .4, show.legend = FALSE) +
ggplot2::scale_x_log10()

if(top_n > 0) {
rlang::check_installed("ggrepel")
p <- p +
ggrepel::geom_label_repel(
data = label_coefs,
ggplot2::aes(y = estimate, label = label),
show.legend = FALSE,
...
Member:
I am slightly uncomfortable with the idea that dots are passed all the way through to here. I would be more comfortable with an explicit argument for `repel_opts = named list()` that gets passed through to here instead.

Then the dots of this `autoplot()` method would either be ignored or you would call `rlang::check_dots_empty()` to ensure the user didn't have any typos.

Member (Author):
Why is it uncomfortable? It seems like the most standard application of `...`.

Member:
I think just because it feels a bit arbitrary. Like, why not send the dots to `geom_line()` instead, which also has plenty of options to tweak and is a larger part of the overall plot?

I don't feel extremely strongly about this though.
)
}
p
}

check_penalty_value <- function(x) {
cl <- match.call()
arg_val <- as.character(cl$x)
if (!is.vector(x) || length(x) != 1 || !is.numeric(x) || x < 0) {
msg <- paste0("Argument '", arg_val, "' should be a single, non-negative value.")
rlang::abort(msg)
}
invisible(x)
}

# nocov end
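For reference, a usage sketch of the new method (not taken from this PR; the model, formula, and data are illustrative assumptions) on a parsnip model fit with the glmnet engine:

# Usage sketch, assuming glmnet and ggrepel are installed.
library(parsnip)
library(glmnet)  # glmnet must be attached or loaded for autoplot() to work

mod_fit <- linear_reg(penalty = 0.01, mixture = 1) %>%
  set_engine("glmnet") %>%
  fit(mpg ~ ., data = mtcars)

# Coefficient paths versus penalty, labeling the 5 largest coefficients
autoplot(mod_fit, top_n = 5)

# Trim small penalty values and mark a chosen penalty with a vertical line
autoplot(mod_fit, min_penalty = 0.001, best_penalty = 0.01, top_n = 3)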
3 changes: 3 additions & 0 deletions R/reexports.R
@@ -1,3 +1,6 @@
#' @importFrom ggplot2 autoplot
#' @export
ggplot2::autoplot

#' @importFrom magrittr %>%
#' @export
61 changes: 3 additions & 58 deletions R/zzz.R
@@ -12,6 +12,9 @@
s3_register("generics::required_pkgs", "model_fit")
s3_register("generics::required_pkgs", "model_spec")

s3_register("ggplot2::autoplot", "model_fit")
s3_register("ggplot2::autoplot", "glmnet")

# - If tune isn't installed, register the method (`packageVersion()` will error here)
# - If tune >= 0.1.6.9001 is installed, register the method
should_register_tune_args_method <- tryCatch(
@@ -108,61 +111,3 @@ s3_register <- function(generic, class, method = NULL) {

# nocov end



#' ## nocov start
#'
#' data_obj <- ls(pattern = "_data$")
#' data_obj <- data_obj[data_obj != "prepare_data"]
#'
#' data_names <-
#' map_dfr(
#' data_obj,
#' function(x) {
#' module <- names(get(x))
#' if (length(module) > 1) {
#' module <- table(module)
#' module <- as_tibble(module)
#' module$object <- x
#' module
#' } else
#' module <- NULL
#' module
#' }
#' )
#'
#' if(any(data_names$n > 1)) {
#' print(data_names[data_names$n > 1,])
#' rlang::abort("Some models have duplicate module names.")
#' }
#' rm(data_names)
#'
#' # ------------------------------------------------------------------------------
#'
#' engine_objects <- ls(pattern = "_engines$")
#' engine_objects <- engine_objects[engine_objects != "possible_engines"]
#'
#' get_engine_info <- function(x) {
#' y <- x
#' y <- get(y)
#' z <- stack(y)
#' z$mode <- rownames(y)
#' z$model <- gsub("_engines$", "", x)
#' z$object <- x
#' z <- z[z$values,]
#' z <- z[z$mode != "unknown",]
#' z$values <- NULL
#' names(z)[1] <- "engine"
#' z$engine <- as.character(z$engine)
#' z
#' }
#'
#' engine_info <-
#' purrr::map_df(
#' parsnip:::engine_objects,
#' get_engine_info
#' )
#'
#' rm(engine_objects)
#'
#' ## nocov end
1 change: 1 addition & 0 deletions _pkgdown.yml
@@ -48,6 +48,7 @@ reference:
- svm_rbf
- title: Infrastructure
contents:
- autoplot.model_fit
- add_rowindex
- augment.model_fit
- descriptors
42 changes: 42 additions & 0 deletions man/autoplot.model_fit.Rd
