-
Notifications
You must be signed in to change notification settings - Fork 17
/
plot_top_loadings.R
84 lines (80 loc) · 2.65 KB
/
plot_top_loadings.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
#' Plot the largest PCA component loadings from a recipe or workflow
#'
#' A plot of the `n` largest component loadings is produced.
#' @inheritParams get_loading_data
#' @param n The number of columns to plot (per component).
#' @return A `ggplot` object.
#' @examples
#' library(recipes)
#' library(parsnip)
#' library(workflows)
#' library(ggplot2)
#'
#' data("cells", package = "modeldata")
#'
#' theme_set(theme_minimal())
#'
#' ## -----------------------------------------------------------------------------
#'
#' cell_pca <-
#' recipe(class ~ ., data = cells %>% dplyr::select(-case)) %>%
#' step_center(all_predictors()) %>%
#' step_scale(all_predictors()) %>%
#' step_pca(all_predictors())
#'
#' # or when used in a workflow
#' lr_workflow <-
#' workflow() %>%
#' add_model(logistic_reg() %>% set_engine("glm")) %>%
#' add_recipe(cell_pca)
#'
#' ## -----------------------------------------------------------------------------
#'
#' cell_pca <- prep(cell_pca)
#'
#' # What were the top 10 channel 1 columns in the first three components?
#' plot_top_loadings(cell_pca, grepl("ch_1", terms) & component_number <= 3, n = 10)
#'
#' ## -----------------------------------------------------------------------------
#'
#' lr_workflow <- lr_workflow %>% fit(data = cells)
#'
#' plot_top_loadings(lr_workflow, component_number <= 3)
#'
#' @export
plot_top_loadings <- function(x, ...) {
UseMethod("plot_top_loadings")
}
#' @export
#' @rdname plot_top_loadings
plot_top_loadings.recipe <- function(x, ..., n = 4, id = NULL, type = "pca") {
comp_vals <- get_loading_data(x, ..., id = id, type = type)
comp_vals <-
comp_vals %>%
dplyr::mutate(
`Positive?` = value > 0,
abs_value = abs(value)
)%>%
dplyr::group_by(component) %>%
dplyr::top_n(n, abs_value) %>%
dplyr::ungroup() %>%
dplyr::arrange(component, abs_value) %>%
dplyr::mutate(order = dplyr::row_number())
# Tactics based on
# https://drsimonj.svbtle.com/ordering-categories-within-ggplot2-facets
ggplot2::ggplot(comp_vals, ggplot2::aes(x = abs_value, y = terms, fill = `Positive?`)) +
ggplot2::geom_col() +
ggplot2::facet_wrap( ~ component, scales = "free_y") +
ggplot2::scale_x_continuous(
breaks = comp_vals$order,
labels = comp_vals$terms,
expand = c(0,0)
) +
ggplot2::labs(y = NULL, x = "Abs. Coefficient Value")
}
#' @export
#' @rdname plot_top_loadings
plot_top_loadings.workflow <- function(x, ..., n = 4, id = NULL, type = "pca") {
x <- workflows::pull_workflow_prepped_recipe(x)
plot_top_loadings(x, ..., id = id, n = n, type = type)
}