Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,22 @@ S3method(multi_predict,"_lognet")
S3method(multi_predict,"_multnet")
S3method(multi_predict,"_xgb.Booster")
S3method(multi_predict,default)
S3method(predict,"_elnet")
S3method(predict,"_lognet")
S3method(predict,"_multnet")
S3method(predict,model_fit)
S3method(predict_class,"_lognet")
S3method(predict_class,model_fit)
S3method(predict_classprob,"_lognet")
S3method(predict_classprob,"_multnet")
S3method(predict_classprob,model_fit)
S3method(predict_confint,model_fit)
S3method(predict_num,"_elnet")
S3method(predict_num,model_fit)
S3method(predict_predint,model_fit)
S3method(predict_raw,"_elnet")
S3method(predict_raw,"_lognet")
S3method(predict_raw,"_multnet")
S3method(predict_raw,model_fit)
S3method(print,boost_tree)
S3method(print,linear_reg)
Expand Down Expand Up @@ -131,6 +140,7 @@ importFrom(purrr,map_dbl)
importFrom(purrr,map_df)
importFrom(purrr,map_dfr)
importFrom(purrr,map_lgl)
importFrom(rlang,eval_tidy)
importFrom(rlang,sym)
importFrom(rlang,syms)
importFrom(stats,.checkMFClasses)
Expand Down
16 changes: 16 additions & 0 deletions R/arguments.R
Original file line number Diff line number Diff line change
Expand Up @@ -116,4 +116,20 @@ set_mode <- function(object, mode) {
object
}

# ------------------------------------------------------------------------------

# Try to evaluate `x` with `rlang::eval_tidy()`; if evaluation signals an
# error, hand back `x` itself, unevaluated.
#
# x: the object (typically a captured argument) to evaluate.
# Returns the evaluated value on success, otherwise `x` unchanged.
#' @importFrom rlang eval_tidy
#' @importFrom purrr map
maybe_eval <- function(x) {
  # If descriptors are in `x`, evaluation fails; in that case the argument is
  # deliberately left as-is rather than raising an error.
  #
  # `tryCatch()` is used instead of `try()` plus a class check so that a
  # successfully evaluated value that happens to inherit from "try-error" is
  # never mistaken for a failed evaluation.
  tryCatch(
    rlang::eval_tidy(x),
    error = function(e) x
  )
}

# Run `maybe_eval()` over every element of a specification's `args` and
# `others` lists and return the updated specification.
#
# spec: an object with `args` and `others` elements.
# ...:  accepted but not used.
eval_args <- function(spec, ...) {
  evaluated_args <- purrr::map(spec$args, maybe_eval)
  spec$args <- evaluated_args

  evaluated_others <- purrr::map(spec$others, maybe_eval)
  spec$others <- evaluated_others

  spec
}
42 changes: 40 additions & 2 deletions R/boost_tree.R
Original file line number Diff line number Diff line change
Expand Up @@ -258,8 +258,24 @@ check_args.boost_tree <- function(object) {

# xgboost helpers --------------------------------------------------------------

#' Training helper for xgboost
#' Boosted trees via xgboost
#'
#' `xgb_train` is a wrapper for `xgboost` tree-based models
#' where all of the model arguments are in the main function.
#'
#' @param x A data frame or matrix of predictors
#' @param y A vector (factor or numeric) or matrix (numeric) of outcome data.
#' @param max_depth An integer for the maximum depth of the tree.
#' @param nrounds An integer for the number of boosting iterations.
#' @param eta A numeric value between zero and one to control the learning rate.
#' @param colsample_bytree Subsampling proportion of columns.
#' @param min_child_weight A numeric value for the minimum sum of instance
#' weights needed in a child to continue to split.
#' @param gamma A number for the minimum loss reduction required to make a
#' further partition on a leaf node of the tree.
#' @param subsample Subsampling proportion of rows.
#' @param ... Other options to pass to `xgb.train`.
#' @return A fitted `xgboost` object.
#' @export
xgb_train <- function(
x, y,
Expand Down Expand Up @@ -403,8 +419,30 @@ xgb_by_tree <- function(tree, object, new_data, type, ...) {

# C5.0 helpers -----------------------------------------------------------------

#' Training helper for C5.0
#' Boosted trees via C5.0
#'
#' `C5.0_train` is a wrapper for [C50::C5.0()] tree-based models
#' where all of the model arguments are in the main function.
#'
#' @param x A data frame or matrix of predictors.
#' @param y A factor vector with 2 or more levels.
#' @param trials An integer specifying the number of boosting
#' iterations. A value of one indicates that a single model is
#' used.
#' @param weights An optional numeric vector of case weights. Note
#' that the data used for the case weights will not be used as a
#' splitting variable in the model (see
#' \url{http://www.rulequest.com/see5-win.html#CASEWEIGHT} for
#' Quinlan's notes on case weights).
#' @param minCases An integer for the smallest number of samples
#' that must be put in at least two of the splits.
#' @param sample A value between (0, .999) that specifies the
#' random proportion of the data that should be used to train the model.
#' By default, all the samples are used for model training. Samples
#' not used for training are used to evaluate the accuracy of the
#' model in the printed output.
#' @param ... Other arguments to pass.
#' @return A fitted C5.0 model.
#' @export
C5.0_train <-
function(x, y, weights = NULL, trials = 15, minCases = 2, sample = 0, ...) {
Expand Down
4 changes: 2 additions & 2 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ fit.model_spec <-
cl <- match.call(expand.dots = TRUE)
# Create an environment with the evaluated argument objects. This will be
# used when a model call is made later.
eval_env <- rlang::env()

eval_env <- rlang::new_environment(parent = rlang::base_env())
eval_env$data <- data
eval_env$formula <- formula
fit_interface <-
Expand Down Expand Up @@ -184,7 +184,7 @@ fit_xy.model_spec <-
) {

cl <- match.call(expand.dots = TRUE)
eval_env <- rlang::new_environment(parent = rlang::base_env())
eval_env <- rlang::env()
eval_env$x <- x
eval_env$y <- y
fit_interface <- check_xy_interface(eval_env$x, eval_env$y, cl, object)
Expand Down
23 changes: 23 additions & 0 deletions R/linear_reg.R
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,27 @@ organize_glmnet_pred <- function(x, object) {
}


# ------------------------------------------------------------------------------

#' @export
predict._elnet <- function(object, new_data, type = NULL, opts = list(), ...) {
  # Replace the stored specification with one whose arguments have been
  # evaluated where possible, then defer to the generic `model_fit` method.
  # NOTE(review): presumably needed so that values such as
  # `object$spec$args$penalty` are concrete at prediction time -- confirm.
  evaluated_spec <- eval_args(object$spec)
  object$spec <- evaluated_spec

  predict.model_fit(
    object,
    new_data = new_data,
    type = type,
    opts = opts,
    ...
  )
}

#' @export
predict_num._elnet <- function(object, new_data, ...) {
  # Evaluate the specification's arguments first, then hand off to the
  # `model_fit` numeric prediction method.
  evaluated_spec <- eval_args(object$spec)
  object$spec <- evaluated_spec

  predict_num.model_fit(
    object,
    new_data = new_data,
    ...
  )
}

#' @export
predict_raw._elnet <- function(object, new_data, opts = list(), ...) {
  # Evaluate the specification's arguments first, then hand off to the
  # `model_fit` raw prediction method, forwarding `opts` untouched.
  evaluated_spec <- eval_args(object$spec)
  object$spec <- evaluated_spec

  predict_raw.model_fit(
    object,
    new_data = new_data,
    opts = opts,
    ...
  )
}

#' @importFrom dplyr full_join as_tibble arrange
#' @importFrom tidyr gather
#' @export
Expand All @@ -235,6 +256,8 @@ multi_predict._elnet <-
if (is.null(penalty))
penalty <- object$fit$lambda
dots$s <- penalty

object$spec <- eval_args(object$spec)
pred <- predict(object, new_data = new_data, type = "raw", opts = dots)
param_key <- tibble(group = colnames(pred), penalty = penalty)
pred <- as_tibble(pred)
Expand Down
54 changes: 28 additions & 26 deletions R/linear_reg_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ linear_reg_lm_data <-
func = c(fun = "predict"),
args =
list(
object = quote(object$fit),
newdata = quote(new_data),
object = expr(object$fit),
newdata = expr(new_data),
type = "response"
)
),
Expand All @@ -51,10 +51,10 @@ linear_reg_lm_data <-
func = c(fun = "predict"),
args =
list(
object = quote(object$fit),
newdata = quote(new_data),
object = expr(object$fit),
newdata = expr(new_data),
interval = "confidence",
level = quote(level),
level = expr(level),
type = "response"
)
),
Expand All @@ -68,10 +68,10 @@ linear_reg_lm_data <-
func = c(fun = "predict"),
args =
list(
object = quote(object$fit),
newdata = quote(new_data),
object = expr(object$fit),
newdata = expr(new_data),
interval = "prediction",
level = quote(level),
level = expr(level),
type = "response"
)
),
Expand All @@ -80,12 +80,14 @@ linear_reg_lm_data <-
func = c(fun = "predict"),
args =
list(
object = quote(object$fit),
newdata = quote(new_data)
object = expr(object$fit),
newdata = expr(new_data)
)
)
)

# Note: For glmnet, you will need to make model-specific predict methods.
# See linear_reg.R
linear_reg_glmnet_data <-
list(
libs = "glmnet",
Expand All @@ -104,19 +106,19 @@ linear_reg_glmnet_data <-
func = c(fun = "predict"),
args =
list(
object = quote(object$fit),
newx = quote(as.matrix(new_data)),
object = expr(object$fit),
newx = expr(as.matrix(new_data)),
type = "response",
s = quote(object$spec$args$penalty)
s = expr(object$spec$args$penalty)
)
),
raw = list(
pre = NULL,
func = c(fun = "predict"),
args =
list(
object = quote(object$fit),
newx = quote(as.matrix(new_data))
object = expr(object$fit),
newx = expr(as.matrix(new_data))
)
)
)
Expand All @@ -130,7 +132,7 @@ linear_reg_stan_data <-
func = c(pkg = "rstanarm", fun = "stan_glm"),
defaults =
list(
family = "gaussian"
family = expr(stats::gaussian)
)
),
pred = list(
Expand All @@ -139,8 +141,8 @@ linear_reg_stan_data <-
func = c(fun = "predict"),
args =
list(
object = quote(object$fit),
newdata = quote(new_data)
object = expr(object$fit),
newdata = expr(new_data)
)
),
confint = list(
Expand All @@ -167,8 +169,8 @@ linear_reg_stan_data <-
func = c(pkg = "rstanarm", fun = "posterior_linpred"),
args =
list(
object = quote(object$fit),
newdata = quote(new_data),
object = expr(object$fit),
newdata = expr(new_data),
transform = TRUE,
seed = expr(sample.int(10^5, 1))
)
Expand Down Expand Up @@ -197,8 +199,8 @@ linear_reg_stan_data <-
func = c(pkg = "rstanarm", fun = "posterior_predict"),
args =
list(
object = quote(object$fit),
newdata = quote(new_data),
object = expr(object$fit),
newdata = expr(new_data),
seed = expr(sample.int(10^5, 1))
)
),
Expand All @@ -207,8 +209,8 @@ linear_reg_stan_data <-
func = c(fun = "predict"),
args =
list(
object = quote(object$fit),
newdata = quote(new_data)
object = expr(object$fit),
newdata = expr(new_data)
)
)
)
Expand All @@ -232,8 +234,8 @@ linear_reg_spark_data <-
func = c(pkg = "sparklyr", fun = "ml_predict"),
args =
list(
x = quote(object$fit),
dataset = quote(new_data)
x = expr(object$fit),
dataset = expr(new_data)
)
)
)
Expand Down
28 changes: 27 additions & 1 deletion R/logistic_reg.R
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,31 @@ organize_glmnet_prob <- function(x, object) {

# ------------------------------------------------------------------------------

#' @export
predict._lognet <- function(object,
                            new_data,
                            type = NULL,
                            opts = list(),
                            ...) {
  # Swap in a specification whose arguments have been evaluated where
  # possible, then delegate to the generic `model_fit` predict method.
  evaluated_spec <- eval_args(object$spec)
  object$spec <- evaluated_spec

  predict.model_fit(
    object,
    new_data = new_data,
    type = type,
    opts = opts,
    ...
  )
}

#' @export
predict_class._lognet <- function(object, new_data, ...) {
  # Evaluate the specification's arguments first, then hand off to the
  # `model_fit` class prediction method.
  evaluated_spec <- eval_args(object$spec)
  object$spec <- evaluated_spec

  predict_class.model_fit(
    object,
    new_data = new_data,
    ...
  )
}

#' @export
predict_classprob._lognet <- function(object, new_data, ...) {
  # Evaluate the specification's arguments first, then hand off to the
  # `model_fit` class-probability prediction method.
  evaluated_spec <- eval_args(object$spec)
  object$spec <- evaluated_spec

  predict_classprob.model_fit(
    object,
    new_data = new_data,
    ...
  )
}

#' @export
predict_raw._lognet <- function(object, new_data, opts = list(), ...) {
  # Evaluate the specification's arguments first, then hand off to the
  # `model_fit` raw prediction method, forwarding `opts` untouched.
  evaluated_spec <- eval_args(object$spec)
  object$spec <- evaluated_spec

  predict_raw.model_fit(
    object,
    new_data = new_data,
    opts = opts,
    ...
  )
}


#' @importFrom dplyr full_join as_tibble arrange
#' @importFrom tidyr gather
#' @export
Expand All @@ -255,6 +280,7 @@ multi_predict._lognet <-
dots <- list(...)
if (is.null(penalty))
penalty <- object$lambda
dots$s <- penalty

if (is.null(type))
type <- "class"
Expand All @@ -266,7 +292,7 @@ multi_predict._lognet <-
else
dots$type <- type

dots$s <- penalty
object$spec <- eval_args(object$spec)
pred <- predict(object, new_data = new_data, type = "raw", opts = dots)
param_key <- tibble(group = colnames(pred), penalty = penalty)
pred <- as_tibble(pred)
Expand Down
2 changes: 2 additions & 0 deletions R/logistic_reg_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ logistic_reg_glm_data <-
)
)

# Note: For glmnet, you will need to make model-specific predict methods.
# See logistic_reg.R
logistic_reg_glmnet_data <-
list(
libs = "glmnet",
Expand Down
Loading