Commit 882c192

Merge pull request #94 from topepo/quosure-passthrough-tests
Quosured argument changes
2 parents 59f0c66 + 1c6d633

48 files changed: +1336 −943 lines

NAMESPACE

Lines changed: 10 additions & 0 deletions
@@ -9,13 +9,22 @@ S3method(multi_predict,"_lognet")
 S3method(multi_predict,"_multnet")
 S3method(multi_predict,"_xgb.Booster")
 S3method(multi_predict,default)
+S3method(predict,"_elnet")
+S3method(predict,"_lognet")
 S3method(predict,"_multnet")
 S3method(predict,model_fit)
+S3method(predict_class,"_lognet")
 S3method(predict_class,model_fit)
+S3method(predict_classprob,"_lognet")
+S3method(predict_classprob,"_multnet")
 S3method(predict_classprob,model_fit)
 S3method(predict_confint,model_fit)
+S3method(predict_num,"_elnet")
 S3method(predict_num,model_fit)
 S3method(predict_predint,model_fit)
+S3method(predict_raw,"_elnet")
+S3method(predict_raw,"_lognet")
+S3method(predict_raw,"_multnet")
 S3method(predict_raw,model_fit)
 S3method(print,boost_tree)
 S3method(print,linear_reg)
@@ -131,6 +140,7 @@ importFrom(purrr,map_dbl)
 importFrom(purrr,map_df)
 importFrom(purrr,map_dfr)
 importFrom(purrr,map_lgl)
+importFrom(rlang,eval_tidy)
 importFrom(rlang,sym)
 importFrom(rlang,syms)
 importFrom(stats,.checkMFClasses)
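
These registrations wire up the engine-specific S3 methods: a glmnet fit carries an extra class such as "_lognet" or "_elnet" ahead of "model_fit", so predict() and friends hit the glmnet-aware methods before falling back to the model_fit ones. A minimal dispatch sketch (my illustration; `show_hit` is a throwaway generic, not parsnip code, and the class vector follows parsnip's convention):

obj <- structure(list(), class = c("_lognet", "model_fit"))
show_hit <- function(x) UseMethod("show_hit")
show_hit._lognet   <- function(x) "dispatched to the _lognet method"
show_hit.model_fit <- function(x) "dispatched to the model_fit fallback"
show_hit(obj)
#> [1] "dispatched to the _lognet method"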

R/arguments.R

Lines changed: 16 additions & 0 deletions
@@ -116,4 +116,20 @@ set_mode <- function(object, mode) {
   object
 }
 
+# ------------------------------------------------------------------------------
 
+#' @importFrom rlang eval_tidy
+#' @importFrom purrr map
+maybe_eval <- function(x) {
+  # if descriptors are in `x`, eval fails
+  y <- try(rlang::eval_tidy(x), silent = TRUE)
+  if (inherits(y, "try-error"))
+    y <- x
+  y
+}
+
+eval_args <- function(spec, ...) {
+  spec$args <- purrr::map(spec$args, maybe_eval)
+  spec$others <- purrr::map(spec$others, maybe_eval)
+  spec
+}
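
A quick illustration of what these helpers do (my sketch, not part of the commit; `not_yet_known` is a placeholder symbol and the triple colon is needed because the helpers are internal): a quosure that can be evaluated is replaced by its value, while one that still refers to an object that only exists at fit time, such as a data descriptor, is passed through unchanged.

library(rlang)
parsnip:::maybe_eval(quo(10^seq(-4, -1, length.out = 3)))  # evaluates to a numeric vector
parsnip:::maybe_eval(quo(floor(not_yet_known / 3)))        # eval fails, the quosure is returned as-is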

R/boost_tree.R

Lines changed: 40 additions & 2 deletions
@@ -258,8 +258,24 @@ check_args.boost_tree <- function(object)
 
 # xgboost helpers --------------------------------------------------------------
 
-#' Training helper for xgboost
+#' Boosted trees via xgboost
 #'
+#' `xgb_train` is a wrapper for `xgboost` tree-based models
+#' where all of the model arguments are in the main function.
+#'
+#' @param x A data frame or matrix of predictors.
+#' @param y A vector (factor or numeric) or matrix (numeric) of outcome data.
+#' @param max_depth An integer for the maximum depth of the tree.
+#' @param nrounds An integer for the number of boosting iterations.
+#' @param eta A numeric value between zero and one to control the learning rate.
+#' @param colsample_bytree Subsampling proportion of columns.
+#' @param min_child_weight A numeric value for the minimum sum of instance
+#'  weights needed in a child to continue to split.
+#' @param gamma A number for the minimum loss reduction required to make a
+#'  further partition on a leaf node of the tree.
+#' @param subsample Subsampling proportion of rows.
+#' @param ... Other options to pass to `xgb.train`.
+#' @return A fitted `xgboost` object.
 #' @export
 xgb_train <- function(
   x, y,
@@ -403,8 +419,30 @@ xgb_by_tree <- function(tree, object, new_data, type, ...) {
 
 # C5.0 helpers -----------------------------------------------------------------
 
-#' Training helper for C5.0
+#' Boosted trees via C5.0
+#'
+#' `C5.0_train` is a wrapper for [C50::C5.0()] tree-based models
+#' where all of the model arguments are in the main function.
 #'
+#' @param x A data frame or matrix of predictors.
+#' @param y A factor vector with 2 or more levels.
+#' @param trials An integer specifying the number of boosting
+#'  iterations. A value of one indicates that a single model is
+#'  used.
+#' @param weights An optional numeric vector of case weights. Note
+#'  that the data used for the case weights will not be used as a
+#'  splitting variable in the model (see
+#'  \url{http://www.rulequest.com/see5-win.html#CASEWEIGHT} for
+#'  Quinlan's notes on case weights).
+#' @param minCases An integer for the smallest number of samples
+#'  that must be put in at least two of the splits.
+#' @param sample A value between (0, .999) that specifies the
+#'  random proportion of the data that should be used to train the model.
+#'  By default, all the samples are used for model training. Samples
+#'  not used for training are used to evaluate the accuracy of the
+#'  model in the printed output.
+#' @param ... Other arguments to pass.
+#' @return A fitted C5.0 model.
 #' @export
 C5.0_train <-
   function(x, y, weights = NULL, trials = 15, minCases = 2, sample = 0, ...) {
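
For context, a hedged usage sketch of the two wrappers documented above (it assumes the xgboost and C50 engines are installed; the argument values are arbitrary and only parameters named in the documentation are used):

library(parsnip)
# regression with xgboost: all of the main tuning arguments in one call
xgb_fit <- xgb_train(x = as.matrix(mtcars[, -1]), y = mtcars$mpg,
                     max_depth = 4, nrounds = 50, eta = 0.1)
# classification with C5.0 and ten boosting iterations
c5_fit <- C5.0_train(x = iris[, 1:4], y = iris$Species, trials = 10)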

R/fit.R

Lines changed: 2 additions & 2 deletions
@@ -103,8 +103,8 @@ fit.model_spec <-
     cl <- match.call(expand.dots = TRUE)
     # Create an environment with the evaluated argument objects. This will be
     # used when a model call is made later.
+    eval_env <- rlang::env()
 
-    eval_env <- rlang::new_environment(parent = rlang::base_env())
     eval_env$data <- data
     eval_env$formula <- formula
     fit_interface <-
@@ -184,7 +184,7 @@ fit_xy.model_spec <-
   ) {
 
     cl <- match.call(expand.dots = TRUE)
-    eval_env <- rlang::new_environment(parent = rlang::base_env())
+    eval_env <- rlang::env()
    eval_env$x <- x
     eval_env$y <- y
     fit_interface <- check_xy_interface(eval_env$x, eval_env$y, cl, object)
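
The practical difference between the two constructors (my illustration, not from the commit): `rlang::env()` parents the new environment on the calling frame, so bindings visible where `fit()` was called can still be found when the model call is evaluated later, whereas an environment parented on `base_env()` cannot see them.

library(rlang)
f <- function() {
  local_weight <- 10
  e_new <- env()                                 # parent: this function's frame
  e_old <- new_environment(parent = base_env())  # parent: the base package env
  list(
    found     = eval(quote(local_weight), envir = e_new),
    not_found = tryCatch(eval(quote(local_weight), envir = e_old),
                         error = function(e) NA)
  )
}
f()
#> $found
#> [1] 10
#> $not_found
#> [1] NA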

R/linear_reg.R

Lines changed: 23 additions & 0 deletions
@@ -226,6 +226,27 @@ organize_glmnet_pred <- function(x, object) {
 }
 
 
+# ------------------------------------------------------------------------------
+
+#' @export
+predict._elnet <-
+  function(object, new_data, type = NULL, opts = list(), ...) {
+    object$spec <- eval_args(object$spec)
+    predict.model_fit(object, new_data = new_data, type = type, opts = opts, ...)
+  }
+
+#' @export
+predict_num._elnet <- function(object, new_data, ...) {
+  object$spec <- eval_args(object$spec)
+  predict_num.model_fit(object, new_data = new_data, ...)
+}
+
+#' @export
+predict_raw._elnet <- function(object, new_data, opts = list(), ...) {
+  object$spec <- eval_args(object$spec)
+  predict_raw.model_fit(object, new_data = new_data, opts = opts, ...)
+}
+
 #' @importFrom dplyr full_join as_tibble arrange
 #' @importFrom tidyr gather
 #' @export
@@ -235,6 +256,8 @@ multi_predict._elnet <-
     if (is.null(penalty))
       penalty <- object$fit$lambda
     dots$s <- penalty
+
+    object$spec <- eval_args(object$spec)
     pred <- predict(object, new_data = new_data, type = "raw", opts = dots)
     param_key <- tibble(group = colnames(pred), penalty = penalty)
     pred <- as_tibble(pred)
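
A small sketch of why eval_args() is called in these methods (my illustration; it assumes a parsnip version where main arguments are stored as quosures, which is what this PR introduces, and uses the triple colon because eval_args() is internal): the quosured `penalty` in the spec is resolved to a plain value so that the glmnet template `s = object$spec$args$penalty` evaluates to a number at predict time.

library(parsnip)
spec <- linear_reg(penalty = 0.01)
class(spec$args$penalty)                 # a quosure under the new scheme
parsnip:::eval_args(spec)$args$penalty   # 0.01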

R/linear_reg_data.R

Lines changed: 28 additions & 26 deletions
@@ -36,8 +36,8 @@ linear_reg_lm_data <-
       func = c(fun = "predict"),
       args =
         list(
-          object = quote(object$fit),
-          newdata = quote(new_data),
+          object = expr(object$fit),
+          newdata = expr(new_data),
           type = "response"
         )
     ),
@@ -51,10 +51,10 @@ linear_reg_lm_data <-
       func = c(fun = "predict"),
       args =
         list(
-          object = quote(object$fit),
-          newdata = quote(new_data),
+          object = expr(object$fit),
+          newdata = expr(new_data),
           interval = "confidence",
-          level = quote(level),
+          level = expr(level),
           type = "response"
         )
     ),
@@ -68,10 +68,10 @@ linear_reg_lm_data <-
       func = c(fun = "predict"),
       args =
         list(
-          object = quote(object$fit),
-          newdata = quote(new_data),
+          object = expr(object$fit),
+          newdata = expr(new_data),
           interval = "prediction",
-          level = quote(level),
+          level = expr(level),
           type = "response"
         )
     ),
@@ -80,12 +80,14 @@ linear_reg_lm_data <-
       func = c(fun = "predict"),
       args =
         list(
-          object = quote(object$fit),
-          newdata = quote(new_data)
+          object = expr(object$fit),
+          newdata = expr(new_data)
         )
     )
   )
 
+# Note: For glmnet, you will need to make model-specific predict methods.
+# See linear_reg.R
 linear_reg_glmnet_data <-
   list(
     libs = "glmnet",
@@ -104,19 +106,19 @@ linear_reg_glmnet_data <-
       func = c(fun = "predict"),
      args =
         list(
-          object = quote(object$fit),
-          newx = quote(as.matrix(new_data)),
+          object = expr(object$fit),
+          newx = expr(as.matrix(new_data)),
           type = "response",
-          s = quote(object$spec$args$penalty)
+          s = expr(object$spec$args$penalty)
         )
     ),
     raw = list(
       pre = NULL,
       func = c(fun = "predict"),
       args =
         list(
-          object = quote(object$fit),
-          newx = quote(as.matrix(new_data))
+          object = expr(object$fit),
+          newx = expr(as.matrix(new_data))
        )
     )
   )
@@ -130,7 +132,7 @@ linear_reg_stan_data <-
      func = c(pkg = "rstanarm", fun = "stan_glm"),
      defaults =
        list(
-         family = "gaussian"
+         family = expr(stats::gaussian)
        )
    ),
    pred = list(
@@ -139,8 +141,8 @@
      func = c(fun = "predict"),
      args =
        list(
-         object = quote(object$fit),
-         newdata = quote(new_data)
+         object = expr(object$fit),
+         newdata = expr(new_data)
        )
    ),
    confint = list(
@@ -167,8 +169,8 @@
      func = c(pkg = "rstanarm", fun = "posterior_linpred"),
      args =
        list(
-         object = quote(object$fit),
-         newdata = quote(new_data),
+         object = expr(object$fit),
+         newdata = expr(new_data),
          transform = TRUE,
          seed = expr(sample.int(10^5, 1))
        )
@@ -197,8 +199,8 @@
      func = c(pkg = "rstanarm", fun = "posterior_predict"),
      args =
        list(
-         object = quote(object$fit),
-         newdata = quote(new_data),
+         object = expr(object$fit),
+         newdata = expr(new_data),
          seed = expr(sample.int(10^5, 1))
        )
    ),
@@ -207,8 +209,8 @@
      func = c(fun = "predict"),
      args =
        list(
-         object = quote(object$fit),
-         newdata = quote(new_data)
+         object = expr(object$fit),
+         newdata = expr(new_data)
        )
    )
  )
@@ -232,8 +234,8 @@ linear_reg_spark_data <-
      func = c(pkg = "sparklyr", fun = "ml_predict"),
      args =
        list(
-         x = quote(object$fit),
-         dataset = quote(new_data)
+         x = expr(object$fit),
+         dataset = expr(new_data)
        )
    )
  )
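
The quote()-to-expr() swap does not change what these templates capture; rlang::expr() returns the same expression but also supports unquoting with `!!`, which keeps the templates consistent with the quosure handling introduced elsewhere in this change. A small check (my illustration; `fit` and `pen` are placeholder names):

library(rlang)
identical(expr(object$fit), quote(object$fit))  # TRUE: same captured expression
pen <- 0.01
expr(predict(fit, s = !!pen))                   # unquoting inlines the value
#> predict(fit, s = 0.01)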

R/logistic_reg.R

Lines changed: 27 additions & 1 deletion
@@ -247,6 +247,31 @@ organize_glmnet_prob <- function(x, object) {
 
 # ------------------------------------------------------------------------------
 
+#' @export
+predict._lognet <- function (object, new_data, type = NULL, opts = list(), ...) {
+  object$spec <- eval_args(object$spec)
+  predict.model_fit(object, new_data = new_data, type = type, opts = opts, ...)
+}
+
+#' @export
+predict_class._lognet <- function (object, new_data, ...) {
+  object$spec <- eval_args(object$spec)
+  predict_class.model_fit(object, new_data = new_data, ...)
+}
+
+#' @export
+predict_classprob._lognet <- function (object, new_data, ...) {
+  object$spec <- eval_args(object$spec)
+  predict_classprob.model_fit(object, new_data = new_data, ...)
+}
+
+#' @export
+predict_raw._lognet <- function (object, new_data, opts = list(), ...) {
+  object$spec <- eval_args(object$spec)
+  predict_raw.model_fit(object, new_data = new_data, opts = opts, ...)
+}
+
+
 #' @importFrom dplyr full_join as_tibble arrange
 #' @importFrom tidyr gather
 #' @export
@@ -255,6 +280,7 @@ multi_predict._lognet <-
     dots <- list(...)
     if (is.null(penalty))
       penalty <- object$lambda
+    dots$s <- penalty
 
     if (is.null(type))
       type <- "class"
@@ -266,7 +292,7 @@ multi_predict._lognet <-
     else
       dots$type <- type
 
-    dots$s <- penalty
+    object$spec <- eval_args(object$spec)
     pred <- predict(object, new_data = new_data, type = "raw", opts = dots)
     param_key <- tibble(group = colnames(pred), penalty = penalty)
     pred <- as_tibble(pred)
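
A hedged end-to-end sketch (it assumes the glmnet engine is installed and a parsnip version with `set_engine()`; `mtcars$am` is coerced to a factor only to get a binary outcome): multi_predict._lognet() fills `s` with the requested penalties, re-evaluates the quosured spec, and returns one set of class predictions per penalty value.

library(parsnip)
cars2 <- mtcars
cars2$am <- factor(cars2$am)
spec   <- set_engine(logistic_reg(penalty = 0.01), "glmnet")
lr_fit <- fit(spec, am ~ ., data = cars2)
predict(lr_fit, new_data = cars2[1:3, ])   # dispatches to predict._lognet()
multi_predict(lr_fit, new_data = cars2[1:3, ], penalty = c(0.001, 0.01, 0.1))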

R/logistic_reg_data.R

Lines changed: 2 additions & 0 deletions
@@ -95,6 +95,8 @@ logistic_reg_glm_data <-
     )
   )
 
+# Note: For glmnet, you will need to make model-specific predict methods.
+# See logistic_reg.R
 logistic_reg_glmnet_data <-
   list(
     libs = "glmnet",
