Skip to content

Commit 1f96393

Browse files
authored
Add full argument to varying_args() (#138)
Also reexport generics::varying_args()
1 parent 28b5449 commit 1f96393

File tree

6 files changed

+93
-50
lines changed

6 files changed

+93
-50
lines changed

NAMESPACE

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,6 @@ export(svm_rbf)
122122
export(translate)
123123
export(varying)
124124
export(varying_args)
125-
export(varying_args.model_spec)
126-
export(varying_args.recipe)
127-
export(varying_args.step)
128125
export(xgb_train)
129126
import(rlang)
130127
importFrom(dplyr,arrange)
@@ -146,6 +143,7 @@ importFrom(dplyr,tally)
146143
importFrom(dplyr,vars)
147144
importFrom(generics,fit)
148145
importFrom(generics,fit_xy)
146+
importFrom(generics,varying_args)
149147
importFrom(glue,glue_collapse)
150148
importFrom(magrittr,"%>%")
151149
importFrom(purrr,as_vector)

NEWS.md

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,16 @@
11
# parsnip 0.0.1.9000
22

3-
## Bug fixes
3+
## Other Changes
4+
5+
* `varying_args()` now has a `full` argument to control whether the full set
6+
of possible varying arguments is returned (as opposed to only the arguments
7+
that are actually varying).
8+
9+
## Bug Fixes
10+
11+
* `varying_args()` now uses the version from the `generics` package. This means
12+
that the first argument, `x`, has been renamed to `object` to align with
13+
generics.
414

515
* For the recipes step method of `varying_args()`, there is now error checking
616
to catch if a user tries to specify an argument that _cannot_ be varying as

R/varying.R

Lines changed: 43 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@ varying <- function() {
66
quote(varying())
77
}
88

9+
#' @importFrom generics varying_args
10+
#' @export
11+
generics::varying_args
12+
913
#' Determine varying arguments
1014
#'
1115
#' `varying_args()` takes a model specification or a recipe and returns a tibble
@@ -16,13 +20,16 @@ varying <- function() {
1620
#' or a `recipe` is used. For a `model_spec`, the first class is used. For
1721
#' a `recipe`, the unique step `id` is used.
1822
#'
19-
#' @param x A `model_spec` or a `recipe`.
23+
#' @param object A `model_spec` or a `recipe`.
24+
#' @param full A single logical. Should all possible varying parameters be
25+
#' returned? If `FALSE`, then only the parameters that
26+
#' are actually varying are returned.
2027
#'
2128
#' @param ... Not currently used.
2229
#'
2330
#' @return A tibble with columns for the parameter name (`name`), whether it
24-
#' contains _any_ varying value (`varying`), the `id` for the object, and the
25-
#' class that was used to call the method (`type`).
31+
#' contains _any_ varying value (`varying`), the `id` for the object (`id`),
32+
#' and the class that was used to call the method (`type`).
2633
#'
2734
#' @examples
2835
#'
@@ -37,6 +44,11 @@ varying <- function() {
3744
#' set_engine("ranger", sample.fraction = varying()) %>%
3845
#' varying_args()
3946
#'
47+
#' # List only the arguments that actually vary
48+
#' rand_forest() %>%
49+
#' set_engine("ranger", sample.fraction = varying()) %>%
50+
#' varying_args(full = FALSE)
51+
#'
4052
#' rand_forest() %>%
4153
#' set_engine(
4254
#' "randomForest",
@@ -45,37 +57,30 @@ varying <- function() {
4557
#' ) %>%
4658
#' varying_args()
4759
#'
48-
#' @export
49-
varying_args <- function (x, ...) {
50-
UseMethod("varying_args")
51-
}
52-
5360
#' @importFrom purrr map map_lgl
5461
#' @export
55-
#' @export varying_args.model_spec
56-
#' @rdname varying_args
57-
varying_args.model_spec <- function(x, ...) {
62+
varying_args.model_spec <- function(object, full = TRUE, ...) {
5863

5964
# use the model_spec top level class as the id
60-
id <- class(x)[1]
65+
id <- class(object)[1]
6166

62-
if (length(x$args) == 0L & length(x$eng_args) == 0L) {
67+
if (length(object$args) == 0L & length(object$eng_args) == 0L) {
6368
return(varying_tbl())
6469
}
6570

6671
# Locate varying args in spec args and engine specific args
67-
varying_args <- map_lgl(x$args, find_varying)
68-
varying_eng_args <- map_lgl(x$eng_args, find_varying)
72+
varying_args <- map_lgl(object$args, find_varying)
73+
varying_eng_args <- map_lgl(object$eng_args, find_varying)
6974

7075
res <- c(varying_args, varying_eng_args)
7176

7277
varying_tbl(
7378
name = names(res),
7479
varying = unname(res),
7580
id = id,
76-
type = "model_spec"
81+
type = "model_spec",
82+
full = full
7783
)
78-
7984
}
8085

8186
# Need to figure out a way to meld the results of varying_args with
@@ -89,66 +94,70 @@ varying_args.model_spec <- function(x, ...) {
8994

9095
#' @importFrom purrr map2_dfr map_chr
9196
#' @export
92-
#' @export varying_args.recipe
93-
#' @rdname varying_args
94-
varying_args.recipe <- function(x, ...) {
97+
#' @rdname varying_args.model_spec
98+
varying_args.recipe <- function(object, full = TRUE, ...) {
9599

96-
steps <- x$steps
100+
steps <- object$steps
97101

98102
if (length(steps) == 0L) {
99103
return(varying_tbl())
100104
}
101105

102-
map_dfr(x$steps, varying_args)
106+
map_dfr(object$steps, varying_args, full = full)
103107
}
104108

105109
#' @importFrom purrr map map_lgl
106110
#' @export
107-
#' @export varying_args.step
108-
#' @rdname varying_args
109-
varying_args.step <- function(x, ...) {
111+
#' @rdname varying_args.model_spec
112+
varying_args.step <- function(object, full = TRUE, ...) {
110113

111114
# Unique step id
112-
id <- x$id
115+
id <- object$id
113116

114117
# Grab the step class before the subset, as that removes the class
115-
step_type <- class(x)[1]
118+
step_type <- class(object)[1]
116119

117120
# Remove NULL argument steps. These are reserved
118121
# for deprecated args or those set at prep() time.
119-
x <- x[!map_lgl(x, is.null)]
122+
object <- object[!map_lgl(object, is.null)]
120123

121-
res <- map_lgl(x, find_varying)
124+
res <- map_lgl(object, find_varying)
122125

123126
# ensure the user didn't specify a non-varying argument as varying()
124127
validate_only_allowed_step_args(res, step_type)
125128

126129
# remove the non-varying arguments as they are not important
127-
res <- res[!(names(x) %in% non_varying_step_arguments)]
130+
res <- res[!(names(object) %in% non_varying_step_arguments)]
128131

129132
varying_tbl(
130133
name = names(res),
131134
varying = unname(res),
132135
id = id,
133-
type = "step"
136+
type = "step",
137+
full = full
134138
)
135-
136139
}
137140

138141
# useful for standardization and for creating a 0 row varying tbl
139142
# (i.e. for when there are no steps in a recipe)
140143
varying_tbl <- function(name = character(),
141144
varying = logical(),
142145
id = character(),
143-
type = character()) {
146+
type = character(),
147+
full = FALSE) {
144148

145-
tibble(
149+
vry_tbl <- tibble(
146150
name = name,
147151
varying = varying,
148152
id = id,
149153
type = type
150154
)
151155

156+
if (!full) {
157+
vry_tbl <- vry_tbl[vry_tbl$varying,]
158+
}
159+
160+
vry_tbl
152161
}
153162

154163
validate_only_allowed_step_args <- function(x, step_type) {

man/reexports.Rd

Lines changed: 3 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/varying_args.Rd renamed to man/varying_args.model_spec.Rd

Lines changed: 16 additions & 10 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test_varying.R

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,3 +175,22 @@ test_that("recipe steps with non-varying args error if specified as varying()",
175175
"The following argument for a recipe step of type 'step_center' is not allowed to vary: 'skip'."
176176
)
177177
})
178+
179+
test_that("`full = FALSE` returns only varying arguments", {
180+
181+
x_spec <- rand_forest(min_n = varying()) %>%
182+
set_engine("ranger", sample.fraction = varying())
183+
184+
x_rec <- rec_1
185+
186+
expect_equal(
187+
varying_args(x_spec, full = FALSE)$name,
188+
c("min_n", "sample.fraction")
189+
)
190+
191+
expect_equal(
192+
varying_args(x_rec, full = FALSE)$name,
193+
c("K", "num")
194+
)
195+
196+
})

0 commit comments

Comments
 (0)