Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ Depends:
R (>= 2.10)
Imports:
dplyr,
rlang (>= 0.2.0.9001),
rlang (>= 0.3.0.1),
purrr,
utils,
tibble,
Expand All @@ -38,6 +38,4 @@ Suggests:
C50,
xgboost,
covr
Remotes:
tidyverse/rlang,
r-lib/generics

16 changes: 12 additions & 4 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ S3method(predict_classprob,"_lognet")
S3method(predict_classprob,"_multnet")
S3method(predict_classprob,model_fit)
S3method(predict_confint,model_fit)
S3method(predict_num,"_elnet")
S3method(predict_num,model_fit)
S3method(predict_numeric,"_elnet")
S3method(predict_numeric,model_fit)
S3method(predict_predint,model_fit)
S3method(predict_quantile,model_fit)
S3method(predict_raw,"_elnet")
Expand All @@ -38,12 +38,16 @@ S3method(print,multinom_reg)
S3method(print,nearest_neighbor)
S3method(print,rand_forest)
S3method(print,surv_reg)
S3method(print,svm_poly)
S3method(print,svm_rbf)
S3method(translate,boost_tree)
S3method(translate,default)
S3method(translate,mars)
S3method(translate,mlp)
S3method(translate,rand_forest)
S3method(translate,surv_reg)
S3method(translate,svm_poly)
S3method(translate,svm_rbf)
S3method(type_sum,model_fit)
S3method(type_sum,model_spec)
S3method(update,boost_tree)
Expand All @@ -55,6 +59,8 @@ S3method(update,multinom_reg)
S3method(update,nearest_neighbor)
S3method(update,rand_forest)
S3method(update,surv_reg)
S3method(update,svm_poly)
S3method(update,svm_rbf)
S3method(varying_args,model_spec)
S3method(varying_args,recipe)
S3method(varying_args,step)
Expand Down Expand Up @@ -92,8 +98,8 @@ export(predict_classprob)
export(predict_classprob.model_fit)
export(predict_confint)
export(predict_confint.model_fit)
export(predict_num)
export(predict_num.model_fit)
export(predict_numeric)
export(predict_numeric.model_fit)
export(predict_predint)
export(predict_predint.model_fit)
export(predict_quantile)
Expand All @@ -106,6 +112,8 @@ export(set_engine)
export(set_mode)
export(show_call)
export(surv_reg)
export(svm_poly)
export(svm_rbf)
export(translate)
export(varying)
export(varying_args)
Expand Down
10 changes: 8 additions & 2 deletions R/boost_tree.R
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,9 @@ xgb_pred <- function(object, newdata, ...) {
#' @export
multi_predict._xgb.Booster <-
function(object, new_data, type = NULL, trees = NULL, ...) {
if (any(names(enquos(...)) == "newdata"))
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)

if (is.null(trees))
trees <- object$fit$nIter
trees <- sort(trees)
Expand Down Expand Up @@ -388,10 +391,10 @@ xgb_by_tree <- function(tree, object, new_data, type, ...) {
nms <- names(pred)
} else {
if (type == "class") {
pred <- boost_tree_xgboost_data$classes$post(pred, object)
pred <- boost_tree_xgboost_data$class$post(pred, object)
pred <- tibble(.pred = factor(pred, levels = object$lvl))
} else {
pred <- boost_tree_xgboost_data$prob$post(pred, object)
pred <- boost_tree_xgboost_data$classprob$post(pred, object)
pred <- as_tibble(pred)
names(pred) <- paste0(".pred_", names(pred))
}
Expand Down Expand Up @@ -458,6 +461,9 @@ C5.0_train <-
#' @export
multi_predict._C5.0 <-
function(object, new_data, type = NULL, trees = NULL, ...) {
if (any(names(enquos(...)) == "newdata"))
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)

if (is.null(trees))
trees <- min(object$fit$trials)
trees <- sort(trees)
Expand Down
16 changes: 8 additions & 8 deletions R/boost_tree_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ boost_tree_xgboost_data <-
verbose = 0
)
),
pred = list(
numeric = list(
pre = NULL,
post = NULL,
func = c(fun = "xgb_pred"),
Expand All @@ -41,7 +41,7 @@ boost_tree_xgboost_data <-
newdata = quote(new_data)
)
),
classes = list(
class = list(
pre = NULL,
post = function(x, object) {
if (is.vector(x)) {
Expand All @@ -58,7 +58,7 @@ boost_tree_xgboost_data <-
newdata = quote(new_data)
)
),
prob = list(
classprob = list(
pre = NULL,
post = function(x, object) {
if (is.vector(x)) {
Expand Down Expand Up @@ -97,7 +97,7 @@ boost_tree_C5.0_data <-
func = c(pkg = "parsnip", fun = "C5.0_train"),
defaults = list()
),
classes = list(
class = list(
pre = NULL,
post = NULL,
func = c(fun = "predict"),
Expand All @@ -106,7 +106,7 @@ boost_tree_C5.0_data <-
newdata = quote(new_data)
)
),
prob = list(
classprob = list(
pre = NULL,
post = function(x, object) {
as_tibble(x)
Expand Down Expand Up @@ -142,7 +142,7 @@ boost_tree_spark_data <-
seed = expr(sample.int(10^5, 1))
)
),
pred = list(
numeric = list(
pre = NULL,
post = format_spark_num,
func = c(pkg = "sparklyr", fun = "ml_predict"),
Expand All @@ -152,7 +152,7 @@ boost_tree_spark_data <-
dataset = quote(new_data)
)
),
classes = list(
class = list(
pre = NULL,
post = format_spark_class,
func = c(pkg = "sparklyr", fun = "ml_predict"),
Expand All @@ -162,7 +162,7 @@ boost_tree_spark_data <-
dataset = quote(new_data)
)
),
prob = list(
classprob = list(
pre = NULL,
post = format_spark_probs,
func = c(pkg = "sparklyr", fun = "ml_predict"),
Expand Down
34 changes: 27 additions & 7 deletions R/linear_reg.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
#'
#' `linear_reg` is a way to generate a _specification_ of a model
#' before fitting and allows the model to be created using
#' different packages in R, Stan, or via Spark. The main arguments for the
#' model are:
#' different packages in R, Stan, keras, or via Spark. The main
#' arguments for the model are:
#' \itemize{
#' \item \code{penalty}: The total amount of regularization
#' in the model. Note that this must be zero for some engines .
#' in the model. Note that this must be zero for some engines.
#' \item \code{mixture}: The proportion of L1 regularization in
#' the model. Note that this will be ignored for some engines.
#' }
Expand All @@ -19,8 +19,11 @@
#' @inheritParams boost_tree
#' @param mode A single character string for the type of model.
#' The only possible value for this model is "regression".
#' @param penalty An non-negative number representing the
#' total amount of regularization (`glmnet` and `spark` only).
#' @param penalty A non-negative number representing the total
#' amount of regularization (`glmnet`, `keras`, and `spark` only).
#' For `keras` models, this corresponds to purely L2 regularization
#' (aka weight decay) while the other models can be a combination
#' of L1 and L2 (depending on the value of `mixture`).
#' @param mixture A number between zero and one (inclusive) that
#' represents the proportion of regularization that is used for the
#' L2 penalty (i.e. weight decay, or ridge regression) versus L1
Expand All @@ -36,6 +39,7 @@
#' \item \pkg{R}: `"lm"` or `"glmnet"`
#' \item \pkg{Stan}: `"stan"`
#' \item \pkg{Spark}: `"spark"`
#' \item \pkg{keras}: `"keras"`
#' }
#'
#' @section Engine Details:
Expand All @@ -59,6 +63,10 @@
#' \pkg{spark}
#'
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "spark")}
#'
#' \pkg{keras}
#'
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "keras")}
#'
#' When using `glmnet` models, there is the option to pass
#' multiple values (or no values) to the `penalty` argument.
Expand Down Expand Up @@ -211,18 +219,27 @@ organize_glmnet_pred <- function(x, object) {
#' @export
predict._elnet <-
  function(object, new_data, type = NULL, opts = list(), ...) {
    # Reject a `newdata` argument passed through `...` and point the caller
    # at the `new_data` spelling used by this function.
    dot_names <- names(enquos(...))
    if ("newdata" %in% dot_names) {
      stop(
        "Did you mean to use `new_data` instead of `newdata`?",
        call. = FALSE
      )
    }

    # Replace the spec with the result of `eval_args()`, then hand off to the
    # `model_fit` method with the same arguments.
    object$spec <- eval_args(object$spec)
    predict.model_fit(
      object,
      new_data = new_data,
      type = type,
      opts = opts,
      ...
    )
  }

#' @export
predict_num._elnet <- function(object, new_data, ...) {
predict_numeric._elnet <- function(object, new_data, ...) {
if (any(names(enquos(...)) == "newdata"))
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)

object$spec <- eval_args(object$spec)
predict_num.model_fit(object, new_data = new_data, ...)
predict_numeric.model_fit(object, new_data = new_data, ...)
}

#' @export
predict_raw._elnet <- function(object, new_data, opts = list(), ...) {
  # Reject a `newdata` argument passed through `...` and point the caller
  # at the `new_data` spelling used by this function.
  dot_names <- names(enquos(...))
  if ("newdata" %in% dot_names) {
    stop(
      "Did you mean to use `new_data` instead of `newdata`?",
      call. = FALSE
    )
  }

  # Replace the spec with the result of `eval_args()`, then hand off to the
  # `model_fit` method for raw predictions.
  object$spec <- eval_args(object$spec)
  predict_raw.model_fit(object, new_data = new_data, opts = opts, ...)
}
Expand All @@ -232,6 +249,9 @@ predict_raw._elnet <- function(object, new_data, opts = list(), ...) {
#' @export
multi_predict._elnet <-
function(object, new_data, type = NULL, penalty = NULL, ...) {
if (any(names(enquos(...)) == "newdata"))
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)

dots <- list(...)
if (is.null(penalty))
penalty <- object$fit$lambda
Expand Down
33 changes: 27 additions & 6 deletions R/linear_reg_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,19 @@ linear_reg_arg_key <- data.frame(
glmnet = c( "lambda", "alpha"),
spark = c("reg_param", "elastic_net_param"),
stan = c( NA, NA),
keras = c( "decay", NA),
stringsAsFactors = FALSE,
row.names = c("penalty", "mixture")
)

# Modes available for `linear_reg()`; "regression" is the only value listed.
linear_reg_modes <- "regression"

linear_reg_engines <- data.frame(
lm = TRUE,
lm = TRUE,
glmnet = TRUE,
spark = TRUE,
stan = TRUE,
keras = TRUE,
row.names = c("regression")
)

Expand All @@ -30,7 +32,7 @@ linear_reg_lm_data <-
func = c(pkg = "stats", fun = "lm"),
defaults = list()
),
pred = list(
numeric = list(
pre = NULL,
post = NULL,
func = c(fun = "predict"),
Expand Down Expand Up @@ -100,7 +102,7 @@ linear_reg_glmnet_data <-
family = "gaussian"
)
),
pred = list(
numeric = list(
pre = NULL,
post = organize_glmnet_pred,
func = c(fun = "predict"),
Expand Down Expand Up @@ -135,7 +137,7 @@ linear_reg_stan_data <-
family = expr(stats::gaussian)
)
),
pred = list(
numeric = list(
pre = NULL,
post = NULL,
func = c(fun = "predict"),
Expand Down Expand Up @@ -224,7 +226,7 @@ linear_reg_spark_data <-
protect = c("x", "formula", "weight_col"),
func = c(pkg = "sparklyr", fun = "ml_linear_regression")
),
pred = list(
numeric = list(
pre = NULL,
post = function(results, object) {
results <- dplyr::rename(results, pred = prediction)
Expand All @@ -240,5 +242,24 @@ linear_reg_spark_data <-
)
)


# Fit and prediction settings for the "keras" engine of `linear_reg()`
# (NOTE(review): usage inferred from the name and the roxygen engine list;
# confirm against the code that looks these objects up).
linear_reg_keras_data <-
  list(
    # Packages listed for this engine.
    libs = c("keras", "magrittr"),
    # Fitting goes through `keras_mlp` from parsnip using the matrix
    # interface; the defaults request one hidden unit with a linear activation.
    fit = list(
      interface = "matrix",
      protect = c("x", "y"),
      func = c(pkg = "parsnip", fun = "keras_mlp"),
      defaults = list(hidden_units = 1, act = "linear")
    ),
    # Numeric predictions: call `predict` on the fitted object with the new
    # data coerced to a matrix, then post-process with `maybe_multivariate`.
    numeric = list(
      pre = NULL,
      post = maybe_multivariate,
      func = c(fun = "predict"),
      args = list(
        object = quote(object$fit),
        x = quote(as.matrix(new_data))
      )
    )
  )

Loading