diff --git a/DESCRIPTION b/DESCRIPTION index 268c70268..2cf450ba3 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -17,7 +17,7 @@ Depends: R (>= 2.10) Imports: dplyr, - rlang (>= 0.2.0.9001), + rlang (>= 0.3.0.1), purrr, utils, tibble, @@ -38,6 +38,4 @@ Suggests: C50, xgboost, covr -Remotes: - tidyverse/rlang, - r-lib/generics + diff --git a/NAMESPACE b/NAMESPACE index 7d4305157..f067c2785 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -19,8 +19,8 @@ S3method(predict_classprob,"_lognet") S3method(predict_classprob,"_multnet") S3method(predict_classprob,model_fit) S3method(predict_confint,model_fit) -S3method(predict_num,"_elnet") -S3method(predict_num,model_fit) +S3method(predict_numeric,"_elnet") +S3method(predict_numeric,model_fit) S3method(predict_predint,model_fit) S3method(predict_quantile,model_fit) S3method(predict_raw,"_elnet") @@ -38,12 +38,16 @@ S3method(print,multinom_reg) S3method(print,nearest_neighbor) S3method(print,rand_forest) S3method(print,surv_reg) +S3method(print,svm_poly) +S3method(print,svm_rbf) S3method(translate,boost_tree) S3method(translate,default) S3method(translate,mars) S3method(translate,mlp) S3method(translate,rand_forest) S3method(translate,surv_reg) +S3method(translate,svm_poly) +S3method(translate,svm_rbf) S3method(type_sum,model_fit) S3method(type_sum,model_spec) S3method(update,boost_tree) @@ -55,6 +59,8 @@ S3method(update,multinom_reg) S3method(update,nearest_neighbor) S3method(update,rand_forest) S3method(update,surv_reg) +S3method(update,svm_poly) +S3method(update,svm_rbf) S3method(varying_args,model_spec) S3method(varying_args,recipe) S3method(varying_args,step) @@ -92,8 +98,8 @@ export(predict_classprob) export(predict_classprob.model_fit) export(predict_confint) export(predict_confint.model_fit) -export(predict_num) -export(predict_num.model_fit) +export(predict_numeric) +export(predict_numeric.model_fit) export(predict_predint) export(predict_predint.model_fit) export(predict_quantile) @@ -106,6 +112,8 @@ export(set_engine) export(set_mode) export(show_call) export(surv_reg) +export(svm_poly) +export(svm_rbf) export(translate) export(varying) export(varying_args) diff --git a/R/boost_tree.R b/R/boost_tree.R index f196d4e8e..8c2614547 100644 --- a/R/boost_tree.R +++ b/R/boost_tree.R @@ -359,6 +359,9 @@ xgb_pred <- function(object, newdata, ...) { #' @export multi_predict._xgb.Booster <- function(object, new_data, type = NULL, trees = NULL, ...) { + if (any(names(enquos(...)) == "newdata")) + stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) + if (is.null(trees)) trees <- object$fit$nIter trees <- sort(trees) @@ -388,10 +391,10 @@ xgb_by_tree <- function(tree, object, new_data, type, ...) { nms <- names(pred) } else { if (type == "class") { - pred <- boost_tree_xgboost_data$classes$post(pred, object) + pred <- boost_tree_xgboost_data$class$post(pred, object) pred <- tibble(.pred = factor(pred, levels = object$lvl)) } else { - pred <- boost_tree_xgboost_data$prob$post(pred, object) + pred <- boost_tree_xgboost_data$classprob$post(pred, object) pred <- as_tibble(pred) names(pred) <- paste0(".pred_", names(pred)) } @@ -458,6 +461,9 @@ C5.0_train <- #' @export multi_predict._C5.0 <- function(object, new_data, type = NULL, trees = NULL, ...) { + if (any(names(enquos(...)) == "newdata")) + stop("Did you mean to use `new_data` instead of `newdata`?", call. 
= FALSE) + if (is.null(trees)) trees <- min(object$fit$trials) trees <- sort(trees) diff --git a/R/boost_tree_data.R b/R/boost_tree_data.R index 206b78e20..559698511 100644 --- a/R/boost_tree_data.R +++ b/R/boost_tree_data.R @@ -31,7 +31,7 @@ boost_tree_xgboost_data <- verbose = 0 ) ), - pred = list( + numeric = list( pre = NULL, post = NULL, func = c(fun = "xgb_pred"), @@ -41,7 +41,7 @@ boost_tree_xgboost_data <- newdata = quote(new_data) ) ), - classes = list( + class = list( pre = NULL, post = function(x, object) { if (is.vector(x)) { @@ -58,7 +58,7 @@ boost_tree_xgboost_data <- newdata = quote(new_data) ) ), - prob = list( + classprob = list( pre = NULL, post = function(x, object) { if (is.vector(x)) { @@ -97,7 +97,7 @@ boost_tree_C5.0_data <- func = c(pkg = "parsnip", fun = "C5.0_train"), defaults = list() ), - classes = list( + class = list( pre = NULL, post = NULL, func = c(fun = "predict"), @@ -106,7 +106,7 @@ boost_tree_C5.0_data <- newdata = quote(new_data) ) ), - prob = list( + classprob = list( pre = NULL, post = function(x, object) { as_tibble(x) @@ -142,7 +142,7 @@ boost_tree_spark_data <- seed = expr(sample.int(10^5, 1)) ) ), - pred = list( + numeric = list( pre = NULL, post = format_spark_num, func = c(pkg = "sparklyr", fun = "ml_predict"), @@ -152,7 +152,7 @@ boost_tree_spark_data <- dataset = quote(new_data) ) ), - classes = list( + class = list( pre = NULL, post = format_spark_class, func = c(pkg = "sparklyr", fun = "ml_predict"), @@ -162,7 +162,7 @@ boost_tree_spark_data <- dataset = quote(new_data) ) ), - prob = list( + classprob = list( pre = NULL, post = format_spark_probs, func = c(pkg = "sparklyr", fun = "ml_predict"), diff --git a/R/linear_reg.R b/R/linear_reg.R index e0805d288..b4b716212 100644 --- a/R/linear_reg.R +++ b/R/linear_reg.R @@ -2,11 +2,11 @@ #' #' `linear_reg` is a way to generate a _specification_ of a model #' before fitting and allows the model to be created using -#' different packages in R, Stan, or via Spark. The main arguments for the -#' model are: +#' different packages in R, Stan, keras, or via Spark. The main +#' arguments for the model are: #' \itemize{ #' \item \code{penalty}: The total amount of regularization -#' in the model. Note that this must be zero for some engines . +#' in the model. Note that this must be zero for some engines. #' \item \code{mixture}: The proportion of L1 regularization in #' the model. Note that this will be ignored for some engines. #' } @@ -19,8 +19,11 @@ #' @inheritParams boost_tree #' @param mode A single character string for the type of model. #' The only possible value for this model is "regression". -#' @param penalty An non-negative number representing the -#' total amount of regularization (`glmnet` and `spark` only). +#' @param penalty An non-negative number representing the total +#' amount of regularization (`glmnet`, `keras`, and `spark` only). +#' For `keras` models, this corresponds to purely L2 regularization +#' (aka weight decay) while the other models can be a combination +#' of L1 and L2 (depending on the value of `mixture`). #' @param mixture A number between zero and one (inclusive) that #' represents the proportion of regularization that is used for the #' L2 penalty (i.e. 
weight decay, or ridge regression) versus L1 @@ -36,6 +39,7 @@ #' \item \pkg{R}: `"lm"` or `"glmnet"` #' \item \pkg{Stan}: `"stan"` #' \item \pkg{Spark}: `"spark"` +#' \item \pkg{keras}: `"keras"` #' } #' #' @section Engine Details: @@ -59,6 +63,10 @@ #' \pkg{spark} #' #' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "spark")} +#' +#' \pkg{keras} +#' +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "keras")} #' #' When using `glmnet` models, there is the option to pass #' multiple values (or no values) to the `penalty` argument. @@ -211,18 +219,27 @@ organize_glmnet_pred <- function(x, object) { #' @export predict._elnet <- function(object, new_data, type = NULL, opts = list(), ...) { + if (any(names(enquos(...)) == "newdata")) + stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) + object$spec <- eval_args(object$spec) predict.model_fit(object, new_data = new_data, type = type, opts = opts, ...) } #' @export -predict_num._elnet <- function(object, new_data, ...) { +predict_numeric._elnet <- function(object, new_data, ...) { + if (any(names(enquos(...)) == "newdata")) + stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) + object$spec <- eval_args(object$spec) - predict_num.model_fit(object, new_data = new_data, ...) + predict_numeric.model_fit(object, new_data = new_data, ...) } #' @export predict_raw._elnet <- function(object, new_data, opts = list(), ...) { + if (any(names(enquos(...)) == "newdata")) + stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) + object$spec <- eval_args(object$spec) predict_raw.model_fit(object, new_data = new_data, opts = opts, ...) } @@ -232,6 +249,9 @@ predict_raw._elnet <- function(object, new_data, opts = list(), ...) { #' @export multi_predict._elnet <- function(object, new_data, type = NULL, penalty = NULL, ...) { + if (any(names(enquos(...)) == "newdata")) + stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) + dots <- list(...) 
if (is.null(penalty)) penalty <- object$fit$lambda diff --git a/R/linear_reg_data.R b/R/linear_reg_data.R index 57aebfd02..6225d1023 100644 --- a/R/linear_reg_data.R +++ b/R/linear_reg_data.R @@ -4,6 +4,7 @@ linear_reg_arg_key <- data.frame( glmnet = c( "lambda", "alpha"), spark = c("reg_param", "elastic_net_param"), stan = c( NA, NA), + keras = c( "decay", NA), stringsAsFactors = FALSE, row.names = c("penalty", "mixture") ) @@ -11,10 +12,11 @@ linear_reg_arg_key <- data.frame( linear_reg_modes <- "regression" linear_reg_engines <- data.frame( - lm = TRUE, + lm = TRUE, glmnet = TRUE, spark = TRUE, stan = TRUE, + keras = TRUE, row.names = c("regression") ) @@ -30,7 +32,7 @@ linear_reg_lm_data <- func = c(pkg = "stats", fun = "lm"), defaults = list() ), - pred = list( + numeric = list( pre = NULL, post = NULL, func = c(fun = "predict"), @@ -100,7 +102,7 @@ linear_reg_glmnet_data <- family = "gaussian" ) ), - pred = list( + numeric = list( pre = NULL, post = organize_glmnet_pred, func = c(fun = "predict"), @@ -135,7 +137,7 @@ linear_reg_stan_data <- family = expr(stats::gaussian) ) ), - pred = list( + numeric = list( pre = NULL, post = NULL, func = c(fun = "predict"), @@ -224,7 +226,7 @@ linear_reg_spark_data <- protect = c("x", "formula", "weight_col"), func = c(pkg = "sparklyr", fun = "ml_linear_regression") ), - pred = list( + numeric = list( pre = NULL, post = function(results, object) { results <- dplyr::rename(results, pred = prediction) @@ -240,5 +242,24 @@ linear_reg_spark_data <- ) ) - +linear_reg_keras_data <- + list( + libs = c("keras", "magrittr"), + fit = list( + interface = "matrix", + protect = c("x", "y"), + func = c(pkg = "parsnip", fun = "keras_mlp"), + defaults = list(hidden_units = 1, act = "linear") + ), + numeric = list( + pre = NULL, + post = maybe_multivariate, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + x = quote(as.matrix(new_data)) + ) + ) + ) diff --git a/R/logistic_reg.R b/R/logistic_reg.R index a0d67f0c1..af2382243 100644 --- a/R/logistic_reg.R +++ b/R/logistic_reg.R @@ -2,8 +2,8 @@ #' #' `logistic_reg` is a way to generate a _specification_ of a model #' before fitting and allows the model to be created using -#' different packages in R, Stan, or via Spark. The main arguments for the -#' model are: +#' different packages in R, Stan, keras, or via Spark. The main +#' arguments for the model are: #' \itemize{ #' \item \code{penalty}: The total amount of regularization #' in the model. Note that this must be zero for some engines. @@ -19,8 +19,11 @@ #' @inheritParams boost_tree #' @param mode A single character string for the type of model. #' The only possible value for this model is "classification". -#' @param penalty An non-negative number representing the -#' total amount of regularization (`glmnet` and `spark` only). +#' @param penalty An non-negative number representing the total +#' amount of regularization (`glmnet`, `keras`, and `spark` only). +#' For `keras` models, this corresponds to purely L2 regularization +#' (aka weight decay) while the other models can be a combination +#' of L1 and L2 (depending on the value of `mixture`). #' @param mixture A number between zero and one (inclusive) that #' represents the proportion of regularization that is used for the #' L2 penalty (i.e. 
weight decay, or ridge regression) versus L1 @@ -34,6 +37,7 @@ #' \item \pkg{R}: `"glm"` or `"glmnet"` #' \item \pkg{Stan}: `"stan"` #' \item \pkg{Spark}: `"spark"` +#' \item \pkg{keras}: `"keras"` #' } #' #' @section Engine Details: @@ -58,6 +62,10 @@ #' #' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::logistic_reg(), "spark")} #' +#' \pkg{keras} +#' +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::logistic_reg(), "keras")} +#' #' When using `glmnet` models, there is the option to pass #' multiple values (or no values) to the `penalty` argument. #' This can have an effect on the model object results. When using @@ -230,24 +238,36 @@ organize_glmnet_prob <- function(x, object) { #' @export predict._lognet <- function (object, new_data, type = NULL, opts = list(), ...) { + if (any(names(enquos(...)) == "newdata")) + stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) + object$spec <- eval_args(object$spec) predict.model_fit(object, new_data = new_data, type = type, opts = opts, ...) } #' @export predict_class._lognet <- function (object, new_data, ...) { + if (any(names(enquos(...)) == "newdata")) + stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) + object$spec <- eval_args(object$spec) predict_class.model_fit(object, new_data = new_data, ...) } #' @export predict_classprob._lognet <- function (object, new_data, ...) { + if (any(names(enquos(...)) == "newdata")) + stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) + object$spec <- eval_args(object$spec) predict_classprob.model_fit(object, new_data = new_data, ...) } #' @export predict_raw._lognet <- function (object, new_data, opts = list(), ...) { + if (any(names(enquos(...)) == "newdata")) + stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) + object$spec <- eval_args(object$spec) predict_raw.model_fit(object, new_data = new_data, opts = opts, ...) } @@ -258,6 +278,9 @@ predict_raw._lognet <- function (object, new_data, opts = list(), ...) { #' @export multi_predict._lognet <- function(object, new_data, type = NULL, penalty = NULL, ...) { + if (any(names(enquos(...)) == "newdata")) + stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) + dots <- list(...) 
if (is.null(penalty)) penalty <- object$lambda diff --git a/R/logistic_reg_data.R b/R/logistic_reg_data.R index 972add371..a5aef8bfb 100644 --- a/R/logistic_reg_data.R +++ b/R/logistic_reg_data.R @@ -4,6 +4,7 @@ logistic_reg_arg_key <- data.frame( glmnet = c( "lambda", "alpha"), spark = c("reg_param", "elastic_net_param"), stan = c( NA, NA), + keras = c( "decay", NA), stringsAsFactors = FALSE, row.names = c("penalty", "mixture") ) @@ -15,6 +16,7 @@ logistic_reg_engines <- data.frame( glmnet = TRUE, spark = TRUE, stan = TRUE, + keras = TRUE, row.names = c("classification") ) @@ -33,7 +35,7 @@ logistic_reg_glm_data <- family = expr(stats::binomial) ) ), - classes = list( + class = list( pre = NULL, post = prob_to_class_2, func = c(fun = "predict"), @@ -44,7 +46,7 @@ logistic_reg_glm_data <- type = "response" ) ), - prob = list( + classprob = list( pre = NULL, post = function(x, object) { x <- tibble(v1 = 1 - x, v2 = x) @@ -109,7 +111,7 @@ logistic_reg_glmnet_data <- family = "binomial" ) ), - classes = list( + class = list( pre = NULL, post = organize_glmnet_class, func = c(fun = "predict"), @@ -121,7 +123,7 @@ logistic_reg_glmnet_data <- s = quote(object$spec$args$penalty) ) ), - prob = list( + classprob = list( pre = NULL, post = organize_glmnet_prob, func = c(fun = "predict"), @@ -156,7 +158,7 @@ logistic_reg_stan_data <- family = expr(stats::binomial) ) ), - classes = list( + class = list( pre = NULL, post = function(x, object) { x <- object$fit$family$linkinv(x) @@ -170,7 +172,7 @@ logistic_reg_stan_data <- newdata = quote(new_data) ) ), - prob = list( + classprob = list( pre = NULL, post = function(x, object) { x <- object$fit$family$linkinv(x) @@ -268,7 +270,7 @@ logistic_reg_spark_data <- family = "binomial" ) ), - classes = list( + class = list( pre = NULL, post = format_spark_class, func = c(pkg = "sparklyr", fun = "ml_predict"), @@ -278,7 +280,7 @@ logistic_reg_spark_data <- dataset = quote(new_data) ) ), - prob = list( + classprob = list( pre = NULL, post = format_spark_probs, func = c(pkg = "sparklyr", fun = "ml_predict"), @@ -290,3 +292,39 @@ logistic_reg_spark_data <- ) ) +logistic_reg_keras_data <- + list( + libs = c("keras", "magrittr"), + fit = list( + interface = "matrix", + protect = c("x", "y"), + func = c(pkg = "parsnip", fun = "keras_mlp"), + defaults = list(hidden_units = 1, act = "linear") + ), + class = list( + pre = NULL, + post = function(x, object) { + object$lvl[x + 1] + }, + func = c(pkg = "keras", fun = "predict_classes"), + args = + list( + object = quote(object$fit), + x = quote(as.matrix(new_data)) + ) + ), + classprob = list( + pre = NULL, + post = function(x, object) { + x <- as_tibble(x) + colnames(x) <- object$lvl + x + }, + func = c(pkg = "keras", fun = "predict_proba"), + args = + list( + object = quote(object$fit), + x = quote(as.matrix(new_data)) + ) + ) + ) \ No newline at end of file diff --git a/R/mars.R b/R/mars.R index 7835eb05b..f1f1c8151 100644 --- a/R/mars.R +++ b/R/mars.R @@ -206,6 +206,9 @@ earth_reg_updater <- function(num, object, new_data, ...) { #' @export multi_predict._earth <- function(object, new_data, type = NULL, num_terms = NULL, ...) { + if (any(names(enquos(...)) == "newdata")) + stop("Did you mean to use `new_data` instead of `newdata`?", call. 
= FALSE) + if (is.null(num_terms)) num_terms <- object$fit$selected.terms[-1] diff --git a/R/mars_data.R b/R/mars_data.R index 83addb849..0c1076e68 100644 --- a/R/mars_data.R +++ b/R/mars_data.R @@ -23,7 +23,7 @@ mars_earth_data <- func = c(pkg = "earth", fun = "earth"), defaults = list(keepxy = TRUE) ), - pred = list( + numeric = list( pre = NULL, post = maybe_multivariate, func = c(fun = "predict"), @@ -34,7 +34,7 @@ mars_earth_data <- type = "response" ) ), - classes = list( + class = list( pre = NULL, post = function(x, object) { x <- ifelse(x[,1] >= 0.5, object$lvl[2], object$lvl[1]) @@ -48,7 +48,7 @@ mars_earth_data <- type = "response" ) ), - prob = list( + classprob = list( pre = NULL, post = function(x, object) { x <- x[,1] diff --git a/R/mlp_data.R b/R/mlp_data.R index 5e5ccd3f8..4eac89dcc 100644 --- a/R/mlp_data.R +++ b/R/mlp_data.R @@ -25,7 +25,7 @@ mlp_keras_data <- func = c(pkg = "parsnip", fun = "keras_mlp"), defaults = list() ), - pred = list( + numeric = list( pre = NULL, post = maybe_multivariate, func = c(fun = "predict"), @@ -35,7 +35,7 @@ mlp_keras_data <- x = quote(as.matrix(new_data)) ) ), - classes = list( + class = list( pre = NULL, post = function(x, object) { object$lvl[x + 1] @@ -47,7 +47,7 @@ mlp_keras_data <- x = quote(as.matrix(new_data)) ) ), - prob = list( + classprob = list( pre = NULL, post = function(x, object) { x <- as_tibble(x) @@ -72,6 +72,17 @@ mlp_keras_data <- ) ) + +nnet_softmax <- function(results, object) { + if (ncol(results) == 1) + results <- cbind(1 - results, results) + + results <- apply(results, 1, function(x) exp(x)/sum(exp(x))) + results <- as_tibble(t(results)) + names(results) <- paste0(".pred_", object$lvl) + results +} + mlp_nnet_data <- list( libs = "nnet", @@ -81,7 +92,7 @@ mlp_nnet_data <- func = c(pkg = "nnet", fun = "nnet"), defaults = list(trace = FALSE) ), - pred = list( + numeric = list( pre = NULL, post = maybe_multivariate, func = c(fun = "predict"), @@ -92,7 +103,7 @@ mlp_nnet_data <- type = "raw" ) ), - classes = list( + class = list( pre = NULL, post = NULL, func = c(fun = "predict"), @@ -103,6 +114,17 @@ mlp_nnet_data <- type = "class" ) ), + classprob = list( + pre = NULL, + post = nnet_softmax, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "raw" + ) + ), raw = list( pre = NULL, func = c(fun = "predict"), @@ -114,6 +136,7 @@ mlp_nnet_data <- ) ) + # ------------------------------------------------------------------------------ # keras wrapper for feed-forward nnet diff --git a/R/multinom_reg.R b/R/multinom_reg.R index 6f6a41b43..5ac4393cd 100644 --- a/R/multinom_reg.R +++ b/R/multinom_reg.R @@ -2,7 +2,7 @@ #' #' `multinom_reg` is a way to generate a _specification_ of a model #' before fitting and allows the model to be created using -#' different packages in R or Spark. The main arguments for the +#' different packages in R, keras, or Spark. The main arguments for the #' model are: #' \itemize{ #' \item \code{penalty}: The total amount of regularization @@ -19,8 +19,11 @@ #' @inheritParams boost_tree #' @param mode A single character string for the type of model. #' The only possible value for this model is "classification". -#' @param penalty An non-negative number representing the -#' total amount of regularization. +#' @param penalty An non-negative number representing the total +#' amount of regularization (`glmnet`, `keras`, and `spark` only). 
+#' For `keras` models, this corresponds to purely L2 regularization +#' (aka weight decay) while the other models can be a combination +#' of L1 and L2 (depending on the value of `mixture`). #' @param mixture A number between zero and one (inclusive) that #' represents the proportion of regularization that is used for the #' L2 penalty (i.e. weight decay, or ridge regression) versus L1 @@ -33,6 +36,7 @@ #' \itemize{ #' \item \pkg{R}: `"glmnet"` #' \item \pkg{Stan}: `"stan"` +#' \item \pkg{keras}: `"keras"` #' } #' #' @section Engine Details: @@ -49,6 +53,10 @@ #' #' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::multinom_reg(), "spark")} #' +#' \pkg{keras} +#' +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::multinom_reg(), "keras")} +#' #' When using `glmnet` models, there is the option to pass #' multiple values (or no values) to the `penalty` argument. #' This can have an effect on the model object results. When using @@ -236,6 +244,9 @@ predict._multnet <- #' @export multi_predict._multnet <- function(object, new_data, type = NULL, penalty = NULL, ...) { + if (any(names(enquos(...)) == "newdata")) + stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) + if (is_quosure(penalty)) penalty <- eval_tidy(penalty) diff --git a/R/multinom_reg_data.R b/R/multinom_reg_data.R index 921e6a0dc..186291003 100644 --- a/R/multinom_reg_data.R +++ b/R/multinom_reg_data.R @@ -2,6 +2,7 @@ multinom_reg_arg_key <- data.frame( glmnet = c( "lambda", "alpha"), spark = c("reg_param", "elastic_net_param"), + keras = c( "decay", NA), stringsAsFactors = FALSE, row.names = c("penalty", "mixture") ) @@ -11,6 +12,7 @@ multinom_reg_modes <- "classification" multinom_reg_engines <- data.frame( glmnet = TRUE, spark = TRUE, + keras = TRUE, row.names = c("classification") ) @@ -28,7 +30,7 @@ multinom_reg_glmnet_data <- family = "multinomial" ) ), - classes = list( + class = list( pre = check_glmnet_lambda, post = organize_multnet_class, func = c(fun = "predict"), @@ -40,7 +42,7 @@ multinom_reg_glmnet_data <- s = quote(object$spec$args$penalty) ) ), - prob = list( + classprob = list( pre = check_glmnet_lambda, post = organize_multnet_prob, func = c(fun = "predict"), @@ -75,7 +77,7 @@ multinom_reg_spark_data <- family = "multinomial" ) ), - classes = list( + class = list( pre = NULL, post = format_spark_class, func = c(pkg = "sparklyr", fun = "ml_predict"), @@ -85,7 +87,7 @@ multinom_reg_spark_data <- dataset = quote(new_data) ) ), - prob = list( + classprob = list( pre = NULL, post = format_spark_probs, func = c(pkg = "sparklyr", fun = "ml_predict"), @@ -98,3 +100,39 @@ multinom_reg_spark_data <- ) +multinom_reg_keras_data <- + list( + libs = c("keras", "magrittr"), + fit = list( + interface = "matrix", + protect = c("x", "y"), + func = c(pkg = "parsnip", fun = "keras_mlp"), + defaults = list(hidden_units = 1, act = "linear") + ), + class = list( + pre = NULL, + post = function(x, object) { + object$lvl[x + 1] + }, + func = c(pkg = "keras", fun = "predict_classes"), + args = + list( + object = quote(object$fit), + x = quote(as.matrix(new_data)) + ) + ), + classprob = list( + pre = NULL, + post = function(x, object) { + x <- as_tibble(x) + colnames(x) <- object$lvl + x + }, + func = c(pkg = "keras", fun = "predict_proba"), + args = + list( + object = quote(object$fit), + x = quote(as.matrix(new_data)) + ) + ) + ) diff --git a/R/nearest_neighbor_data.R b/R/nearest_neighbor_data.R index c6106561c..0191d8614 100644 --- a/R/nearest_neighbor_data.R +++ b/R/nearest_neighbor_data.R @@ -22,8 +22,8 @@ 
nearest_neighbor_kknn_data <- func = c(pkg = "kknn", fun = "train.kknn"), defaults = list() ), - pred = list( - # seems unnecessary here as the predict_num catches it based on the + numeric = list( + # seems unnecessary here as the predict_numeric catches it based on the # model mode pre = function(x, object) { if (object$fit$response != "continuous") { @@ -42,7 +42,7 @@ nearest_neighbor_kknn_data <- type = "raw" ) ), - classes = list( + class = list( pre = function(x, object) { if (!(object$fit$response %in% c("ordinal", "nominal"))) { stop("`kknn` model does not appear to use class predictions. Was ", @@ -60,7 +60,7 @@ nearest_neighbor_kknn_data <- type = "raw" ) ), - prob = list( + classprob = list( pre = function(x, object) { if (!(object$fit$response %in% c("ordinal", "nominal"))) { stop("`kknn` model does not appear to use class predictions. Was ", diff --git a/R/predict.R b/R/predict.R index ea7ea7149..947b8ed72 100644 --- a/R/predict.R +++ b/R/predict.R @@ -91,12 +91,15 @@ #' @export predict.model_fit #' @export predict.model_fit <- function (object, new_data, type = NULL, opts = list(), ...) { + if (any(names(enquos(...)) == "newdata")) + stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) + type <- check_pred_type(object, type) if (type != "raw" && length(opts) > 0) warning("`opts` is only used with `type = 'raw'` and was ignored.") res <- switch( type, - numeric = predict_num(object = object, new_data = new_data, ...), + numeric = predict_numeric(object = object, new_data = new_data, ...), class = predict_class(object = object, new_data = new_data, ...), prob = predict_classprob(object = object, new_data = new_data, ...), conf_int = predict_confint(object = object, new_data = new_data, ...), diff --git a/R/predict_class.R b/R/predict_class.R index 67e15de95..c586b1039 100644 --- a/R/predict_class.R +++ b/R/predict_class.R @@ -13,23 +13,23 @@ predict_class.model_fit <- function (object, new_data, ...) { stop("`predict.model_fit` is for predicting factor outcomes.", call. = FALSE) - if (!any(names(object$spec$method) == "classes")) + if (!any(names(object$spec$method) == "class")) stop("No class prediction module defined for this model.", call. = FALSE) new_data <- prepare_data(object, new_data) # preprocess data - if (!is.null(object$spec$method$classes$pre)) - new_data <- object$spec$method$classes$pre(new_data, object) + if (!is.null(object$spec$method$class$pre)) + new_data <- object$spec$method$class$pre(new_data, object) # create prediction call - pred_call <- make_pred_call(object$spec$method$classes) + pred_call <- make_pred_call(object$spec$method$class) res <- eval_tidy(pred_call) # post-process the predictions - if(!is.null(object$spec$method$classes$post)) { - res <- object$spec$method$classes$post(res, object) + if(!is.null(object$spec$method$class$post)) { + res <- object$spec$method$class$post(res, object) } # coerce levels to those in `object` diff --git a/R/predict_classprob.R b/R/predict_classprob.R index 113015fe8..8f2f79b8d 100644 --- a/R/predict_classprob.R +++ b/R/predict_classprob.R @@ -10,23 +10,23 @@ predict_classprob.model_fit <- function (object, new_data, ...) { stop("`predict.model_fit` is for predicting factor outcomes.", call. = FALSE) - if (!any(names(object$spec$method) == "prob")) + if (!any(names(object$spec$method) == "classprob")) stop("No class probability module defined for this model.", call. 
= FALSE) new_data <- prepare_data(object, new_data) # preprocess data - if (!is.null(object$spec$method$prob$pre)) - new_data <- object$spec$method$prob$pre(new_data, object) + if (!is.null(object$spec$method$classprob$pre)) + new_data <- object$spec$method$classprob$pre(new_data, object) # create prediction call - pred_call <- make_pred_call(object$spec$method$prob) + pred_call <- make_pred_call(object$spec$method$classprob) res <- eval_tidy(pred_call) # post-process the predictions - if(!is.null(object$spec$method$prob$post)) { - res <- object$spec$method$prob$post(res, object) + if(!is.null(object$spec$method$classprob$post)) { + res <- object$spec$method$classprob$post(res, object) } # check and sort names diff --git a/R/predict_num.R b/R/predict_numeric.R similarity index 50% rename from R/predict_num.R rename to R/predict_numeric.R index 243a73890..5b9750146 100644 --- a/R/predict_num.R +++ b/R/predict_numeric.R @@ -1,33 +1,33 @@ #' @keywords internal #' @rdname other_predict #' @inheritParams predict.model_fit -#' @method predict_num model_fit -#' @export predict_num.model_fit +#' @method predict_numeric model_fit +#' @export predict_numeric.model_fit #' @export -# TODO add ... -predict_num.model_fit <- function (object, new_data, ...) { + +predict_numeric.model_fit <- function (object, new_data, ...) { if (object$spec$mode != "regression") - stop("`predict_num` is for predicting numeric outcomes. ", + stop("`predict_numeric` is for predicting numeric outcomes. ", "Use `predict_class` or `predict_prob` for ", "classification models.", call. = FALSE) - if (!any(names(object$spec$method) == "pred")) + if (!any(names(object$spec$method) == "numeric")) stop("No prediction module defined for this model.", call. = FALSE) new_data <- prepare_data(object, new_data) # preprocess data - if (!is.null(object$spec$method$pred$pre)) - new_data <- object$spec$method$pred$pre(new_data, object) + if (!is.null(object$spec$method$numeric$pre)) + new_data <- object$spec$method$numeric$pre(new_data, object) # create prediction call - pred_call <- make_pred_call(object$spec$method$pred) + pred_call <- make_pred_call(object$spec$method$numeric) res <- eval_tidy(pred_call) # post-process the predictions - if (!is.null(object$spec$method$pred$post)) { - res <- object$spec$method$pred$post(res, object) + if (!is.null(object$spec$method$numeric$post)) { + res <- object$spec$method$numeric$post(res, object) } if (is.vector(res)) { @@ -43,6 +43,6 @@ predict_num.model_fit <- function (object, new_data, ...) { #' @export #' @keywords internal #' @rdname other_predict -#' @inheritParams predict_num.model_fit -predict_num <- function (object, ...) - UseMethod("predict_num") +#' @inheritParams predict_numeric.model_fit +predict_numeric <- function (object, ...) 
+ UseMethod("predict_numeric") diff --git a/R/rand_forest_data.R b/R/rand_forest_data.R index 3db4a439a..65eb84864 100644 --- a/R/rand_forest_data.R +++ b/R/rand_forest_data.R @@ -110,7 +110,7 @@ rand_forest_ranger_data <- seed = expr(sample.int(10^5, 1)) ) ), - pred = list( + numeric = list( pre = NULL, post = function(results, object) results$predictions, func = c(fun = "predict"), @@ -123,7 +123,7 @@ rand_forest_ranger_data <- verbose = FALSE ) ), - classes = list( + class = list( pre = NULL, post = ranger_class_pred, func = c(fun = "predict"), @@ -136,7 +136,7 @@ rand_forest_ranger_data <- verbose = FALSE ) ), - prob = list( + classprob = list( pre = function(x, object) { if (object$fit$forest$treetype != "Probability estimation") stop("`ranger` model does not appear to use class probabilities. Was ", @@ -190,7 +190,7 @@ rand_forest_randomForest_data <- defaults = list() ), - pred = list( + numeric = list( pre = NULL, post = NULL, func = c(fun = "predict"), @@ -200,7 +200,7 @@ rand_forest_randomForest_data <- newdata = quote(new_data) ) ), - classes = list( + class = list( pre = NULL, post = NULL, func = c(fun = "predict"), @@ -210,7 +210,7 @@ rand_forest_randomForest_data <- newdata = quote(new_data) ) ), - prob = list( + classprob = list( pre = NULL, post = function(x, object) { as_tibble(as.data.frame(x)) @@ -247,7 +247,7 @@ rand_forest_spark_data <- seed = expr(sample.int(10^5, 1)) ) ), - pred = list( + numeric = list( pre = NULL, post = format_spark_num, func = c(pkg = "sparklyr", fun = "ml_predict"), @@ -257,7 +257,7 @@ rand_forest_spark_data <- dataset = quote(new_data) ) ), - classes = list( + class = list( pre = NULL, post = format_spark_class, func = c(pkg = "sparklyr", fun = "ml_predict"), @@ -267,7 +267,7 @@ rand_forest_spark_data <- dataset = quote(new_data) ) ), - prob = list( + classprob = list( pre = NULL, post = format_spark_probs, func = c(pkg = "sparklyr", fun = "ml_predict"), diff --git a/R/surv_reg_data.R b/R/surv_reg_data.R index 43f55cecb..3b030717b 100644 --- a/R/surv_reg_data.R +++ b/R/surv_reg_data.R @@ -26,7 +26,7 @@ surv_reg_flexsurv_data <- func = c(pkg = "flexsurv", fun = "flexsurvreg"), defaults = list() ), - pred = list( + numeric = list( pre = NULL, post = flexsurv_mean, func = c(fun = "summary"), @@ -62,7 +62,7 @@ surv_reg_survreg_data <- func = c(pkg = "survival", fun = "survreg"), defaults = list(model = TRUE) ), - pred = list( + numeric = list( pre = NULL, post = NULL, func = c(fun = "predict"), @@ -101,7 +101,7 @@ surv_reg_survreg_data <- # seed = expr(sample.int(10^5, 1)) # ) # ), -# pred = list( +# numeric = list( # pre = NULL, # post = function(results, object) { # tibble::as_tibble(results) %>% diff --git a/R/svm_poly.R b/R/svm_poly.R new file mode 100644 index 000000000..adb6d043d --- /dev/null +++ b/R/svm_poly.R @@ -0,0 +1,196 @@ +#' General interface for polynomial support vector machines +#' +#' `svm_poly` is a way to generate a _specification_ of a model +#' before fitting and allows the model to be created using +#' different packages in R or via Spark. The main arguments for the +#' model are: +#' \itemize{ +#' \item \code{cost}: The cost of predicting a sample within or on the +#' wrong side of the margin. +#' \item \code{degree}: The polynomial degree. +#' \item \code{scale_factor}: A scaling factor for the kernel. +#' \item \code{margin}: The epsilon in the SVM insensitive loss function +#' (regression only) +#' } +#' These arguments are converted to their specific names at the +#' time that the model is fit. 
Other options and argument can be +#' set using `set_engine`. If left to their defaults +#' here (`NULL`), the values are taken from the underlying model +#' functions. If parameters need to be modified, `update` can be used +#' in lieu of recreating the object from scratch. +#' +#' @inheritParams boost_tree +#' @param mode A single character string for the type of model. +#' Possible values for this model are "unknown", "regression", or +#' "classification". +#' @param cost A positive number for the cost of predicting a sample within +#' or on the wrong side of the margin +#' @param degree A positive number for polynomial degree. +#' @param scale_factor A positive number for the polynomial scaling factor. +#' @param margin A positive number for the epsilon in the SVM insensitive +#' loss function (regression only) +#' @details +#' The model can be created using the `fit()` function using the +#' following _engines_: +#' \itemize{ +#' \item \pkg{R}: `"kernlab"` +#' } +#' +#' @section Engine Details: +#' +#' Engines may have pre-set default arguments when executing the +#' model fit call. For this type of +#' model, the template of the fit calls are:: +#' +#' \pkg{kernlab} classification +#' +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::svm_poly(mode = "classification"), "kernlab")} +#' +#' \pkg{kernlab} regression +#' +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::svm_poly(mode = "regression"), "kernlab")} +#' +#' @importFrom purrr map_lgl +#' @seealso [varying()], [fit()] +#' @examples +#' svm_poly(mode = "classification", degree = 1.2) +#' # Parameters can be represented by a placeholder: +#' svm_poly(mode = "regression", cost = varying()) +#' @export + +svm_poly <- + function(mode = "unknown", + cost = NULL, degree = NULL, scale_factor = NULL, margin = NULL) { + + args <- list( + cost = enquo(cost), + degree = enquo(degree), + scale_factor = enquo(scale_factor), + margin = enquo(margin) + ) + + new_model_spec( + "svm_poly", + args = args, + eng_args = NULL, + mode = mode, + method = NULL, + engine = NULL + ) + } + +#' @export +print.svm_poly <- function(x, ...) { + cat("Polynomial Support Vector Machine Specification (", x$mode, ")\n\n", sep = "") + model_printer(x, ...) + + if(!is.null(x$method$fit$args)) { + cat("Model fit template:\n") + print(show_call(x)) + } + invisible(x) +} + +# ------------------------------------------------------------------------------ + +#' @export +#' @inheritParams update.boost_tree +#' @param object A polynomial SVM model specification. +#' @examples +#' model <- svm_poly(cost = 10, scale_factor = 0.1) +#' model +#' update(model, cost = 1) +#' update(model, cost = 1, fresh = TRUE) +#' @method update svm_poly +#' @rdname svm_poly +#' @export +update.svm_poly <- + function(object, + cost = NULL, degree = NULL, scale_factor = NULL, margin = NULL, + fresh = FALSE, + ...) { + update_dot_check(...) + + args <- list( + cost = enquo(cost), + degree = enquo(degree), + scale_factor = enquo(scale_factor), + margin = enquo(margin) + ) + + if (fresh) { + object$args <- args + } else { + null_args <- map_lgl(args, null_value) + if (any(null_args)) + args <- args[!null_args] + if (length(args) > 0) + object$args[names(args)] <- args + } + + new_model_spec( + "svm_poly", + args = object$args, + eng_args = object$eng_args, + mode = object$mode, + method = NULL, + engine = object$engine + ) + } + +# ------------------------------------------------------------------------------ + +#' @export +translate.svm_poly <- function(x, engine = x$engine, ...) 
{ + x <- translate.default(x, engine = engine, ...) + + # slightly cleaner code using + arg_vals <- x$method$fit$args + arg_names <- names(arg_vals) + + # add checks to error trap or change things for this method + if (x$engine == "kernlab") { + + # unless otherwise specified, classification models predict probabilities + if (x$mode == "classification" && !any(arg_names == "prob.model")) + arg_vals$prob.model <- TRUE + if (x$mode == "classification" && any(arg_names == "epsilon")) + arg_vals$epsilon <- NULL + + # convert degree and scale to a `kpar` argument. + if (any(arg_names %in% c("degree", "scale", "offset"))) { + kpar <- expr(list()) + if (any(arg_names == "degree")) { + kpar$degree <- arg_vals$degree + arg_vals$degree <- NULL + } + if (any(arg_names == "scale")) { + kpar$scale <- arg_vals$scale + arg_vals$scale <- NULL + } + if (any(arg_names == "offset")) { + kpar$offset <- arg_vals$offset + arg_vals$offset <- NULL + } + arg_vals$kpar <- kpar + } + + } + x$method$fit$args <- arg_vals + + # worried about people using this to modify the specification + x +} + +# ------------------------------------------------------------------------------ + +check_args.svm_poly <- function(object) { + invisible(object) +} + +# ------------------------------------------------------------------------------ + +svm_reg_post <- function(results, object) { + results[,1] +} + diff --git a/R/svm_poly_data.R b/R/svm_poly_data.R new file mode 100644 index 000000000..04b0bc55e --- /dev/null +++ b/R/svm_poly_data.R @@ -0,0 +1,69 @@ +svm_poly_arg_key <- data.frame( + kernlab = c( "C", "degree", "scale", "epsilon"), + row.names = c("cost", "degree", "scale_factor", "margin"), + stringsAsFactors = FALSE +) + +svm_poly_modes <- c("classification", "regression", "unknown") + +svm_poly_engines <- data.frame( + kernlab = c(TRUE, TRUE, FALSE), + row.names = c("classification", "regression", "unknown") +) + +# ------------------------------------------------------------------------------ + +svm_poly_kernlab_data <- + list( + libs = "kernlab", + fit = list( + interface = "matrix", + protect = c("x", "y"), + func = c(pkg = "kernlab", fun = "ksvm"), + defaults = list( + kernel = "polydot" + ) + ), + numeric = list( + pre = NULL, + post = svm_reg_post, + func = c(pkg = "kernlab", fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "response" + ) + ), + class = list( + pre = NULL, + post = NULL, + func = c(pkg = "kernlab", fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "response" + ) + ), + classprob = list( + pre = NULL, + post = function(result, object) as_tibble(result), + func = c(pkg = "kernlab", fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "probabilities" + ) + ), + raw = list( + pre = NULL, + func = c(pkg = "kernlab", fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data) + ) + ) + ) diff --git a/R/svm_rbf.R b/R/svm_rbf.R new file mode 100644 index 000000000..7670dfc75 --- /dev/null +++ b/R/svm_rbf.R @@ -0,0 +1,183 @@ +#' General interface for radial basis function support vector machines +#' +#' `svm_rbf` is a way to generate a _specification_ of a model +#' before fitting and allows the model to be created using +#' different packages in R or via Spark. The main arguments for the +#' model are: +#' \itemize{ +#' \item \code{cost}: The cost of predicting a sample within or on the +#' wrong side of the margin. 
+#' \item \code{rbf_sigma}: The precision parameter for the radial basis +#' function. +#' \item \code{margin}: The epsilon in the SVM insensitive loss function +#' (regression only) +#' } +#' These arguments are converted to their specific names at the +#' time that the model is fit. Other options and argument can be +#' set using `set_engine`. If left to their defaults +#' here (`NULL`), the values are taken from the underlying model +#' functions. If parameters need to be modified, `update` can be used +#' in lieu of recreating the object from scratch. +#' +#' @inheritParams boost_tree +#' @param mode A single character string for the type of model. +#' Possible values for this model are "unknown", "regression", or +#' "classification". +#' @param cost A positive number for the cost of predicting a sample within +#' or on the wrong side of the margin +#' @param rbf_sigma A positive number for radial basis function. +#' @param margin A positive number for the epsilon in the SVM insensitive +#' loss function (regression only) +#' @details +#' The model can be created using the `fit()` function using the +#' following _engines_: +#' \itemize{ +#' \item \pkg{R}: `"kernlab"` +#' } +#' +#' @section Engine Details: +#' +#' Engines may have pre-set default arguments when executing the +#' model fit call. For this type of +#' model, the template of the fit calls are:: +#' +#' \pkg{kernlab} classification +#' +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::svm_rbf(mode = "classification"), "kernlab")} +#' +#' \pkg{kernlab} regression +#' +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::svm_rbf(mode = "regression"), "kernlab")} +#' +#' @importFrom purrr map_lgl +#' @seealso [varying()], [fit()] +#' @examples +#' svm_rbf(mode = "classification", rbf_sigma = 0.2) +#' # Parameters can be represented by a placeholder: +#' svm_rbf(mode = "regression", cost = varying()) +#' @export + +svm_rbf <- + function(mode = "unknown", + cost = NULL, rbf_sigma = NULL, margin = NULL) { + + args <- list( + cost = enquo(cost), + rbf_sigma = enquo(rbf_sigma), + margin = enquo(margin) + ) + + new_model_spec( + "svm_rbf", + args = args, + eng_args = NULL, + mode = mode, + method = NULL, + engine = NULL + ) + } + +#' @export +print.svm_rbf <- function(x, ...) { + cat("Radial Basis Function Support Vector Machine Specification (", x$mode, ")\n\n", sep = "") + model_printer(x, ...) + + if(!is.null(x$method$fit$args)) { + cat("Model fit template:\n") + print(show_call(x)) + } + invisible(x) +} + +# ------------------------------------------------------------------------------ + +#' @export +#' @inheritParams update.boost_tree +#' @param object A radial basis function SVM model specification. +#' @examples +#' model <- svm_rbf(cost = 10, rbf_sigma = 0.1) +#' model +#' update(model, cost = 1) +#' update(model, cost = 1, fresh = TRUE) +#' @method update svm_rbf +#' @rdname svm_rbf +#' @export +update.svm_rbf <- + function(object, + cost = NULL, rbf_sigma = NULL, margin = NULL, + fresh = FALSE, + ...) { + update_dot_check(...) 
+ + args <- list( + cost = enquo(cost), + rbf_sigma = enquo(rbf_sigma), + margin = enquo(margin) + ) + + if (fresh) { + object$args <- args + } else { + null_args <- map_lgl(args, null_value) + if (any(null_args)) + args <- args[!null_args] + if (length(args) > 0) + object$args[names(args)] <- args + } + + new_model_spec( + "svm_rbf", + args = object$args, + eng_args = object$eng_args, + mode = object$mode, + method = NULL, + engine = object$engine + ) + } + +# ------------------------------------------------------------------------------ + +#' @export +translate.svm_rbf <- function(x, engine = x$engine, ...) { + x <- translate.default(x, engine = engine, ...) + + # slightly cleaner code using + arg_vals <- x$method$fit$args + arg_names <- names(arg_vals) + + # add checks to error trap or change things for this method + if (x$engine == "kernlab") { + + # unless otherwise specified, classification models predict probabilities + if (x$mode == "classification" && !any(arg_names == "prob.model")) + arg_vals$prob.model <- TRUE + if (x$mode == "classification" && any(arg_names == "epsilon")) + arg_vals$epsilon <- NULL + + # convert sigma and scale to a `kpar` argument. + if (any(arg_names == "sigma")) { + kpar <- expr(list()) + kpar$sigma <- arg_vals$sigma + arg_vals$sigma <- NULL + arg_vals$kpar <- kpar + } + + } + x$method$fit$args <- arg_vals + + # worried about people using this to modify the specification + x +} + +# ------------------------------------------------------------------------------ + +check_args.svm_rbf <- function(object) { + invisible(object) +} + +# ------------------------------------------------------------------------------ + +svm_reg_post <- function(results, object) { + results[,1] +} + diff --git a/R/svm_rbf_data.R b/R/svm_rbf_data.R new file mode 100644 index 000000000..fdb12727d --- /dev/null +++ b/R/svm_rbf_data.R @@ -0,0 +1,69 @@ +svm_rbf_arg_key <- data.frame( + kernlab = c( "C", "sigma", "epsilon"), + row.names = c("cost", "rbf_sigma", "margin"), + stringsAsFactors = FALSE +) + +svm_rbf_modes <- c("classification", "regression", "unknown") + +svm_rbf_engines <- data.frame( + kernlab = c(TRUE, TRUE, FALSE), + row.names = c("classification", "regression", "unknown") +) + +# ------------------------------------------------------------------------------ + +svm_rbf_kernlab_data <- + list( + libs = "kernlab", + fit = list( + interface = "matrix", + protect = c("x", "y"), + func = c(pkg = "kernlab", fun = "ksvm"), + defaults = list( + kernel = "rbfdot" + ) + ), + numeric = list( + pre = NULL, + post = svm_reg_post, + func = c(pkg = "kernlab", fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "response" + ) + ), + class = list( + pre = NULL, + post = NULL, + func = c(pkg = "kernlab", fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "response" + ) + ), + classprob = list( + pre = NULL, + post = function(result, object) as_tibble(result), + func = c(pkg = "kernlab", fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "probabilities" + ) + ), + raw = list( + pre = NULL, + func = c(pkg = "kernlab", fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data) + ) + ) + ) diff --git a/R/translate.R b/R/translate.R index 7c88e8c18..2c342e360 100644 --- a/R/translate.R +++ b/R/translate.R @@ -19,6 +19,11 @@ #' This function can be useful when you need to understand how #' `parsnip` goes from 
a generic model specific to a model fitting #' function. +#' +#' **Note**: this function is used internally and users should only use it +#' to understand what the underlying syntax would be. It should not be used +#' to modify the model specification. +#' #' @examples #' lm_spec <- linear_reg(penalty = 0.01) #' diff --git a/_pkgdown.yml b/_pkgdown.yml index cf5e1e4e3..b386ca244 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -26,6 +26,8 @@ reference: - nearest_neighbor - rand_forest - surv_reg + - svm_poly + - svm_rbf - title: Infrastructure contents: - descriptors diff --git a/docs/articles/articles/Classification.html b/docs/articles/articles/Classification.html index 9810d6296..a0752b40d 100644 --- a/docs/articles/articles/Classification.html +++ b/docs/articles/articles/Classification.html @@ -175,17 +175,17 @@
As an example of a model with multiple engines, here is the object for logistic regression:
parsnip:::logistic_reg_arg_key
-#> glm glmnet spark stan
-#> penalty NA lambda reg_param NA
-#> mixture NA alpha elastic_net_param NA
The internals of parsnip will use these objects during the creation of the model code.
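For reference, the source of this object after the change (drawn from R/logistic_reg_data.R in this patch, which adds the keras column) is sketched below: the row names are parsnip's argument names and each column maps them to an engine's native names.

# Sketch taken from R/logistic_reg_data.R in this change: one column per
# engine, with NA where an engine has no counterpart for a parsnip argument.
logistic_reg_arg_key <- data.frame(
  glmnet = c(   "lambda",             "alpha"),
  spark  = c("reg_param", "elastic_net_param"),
  stan   = c(         NA,                  NA),
  keras  = c(    "decay",                  NA),
  stringsAsFactors = FALSE,
  row.names = c("penalty", "mixture")
)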
pred module
This is defined only for regression models (so is not added to the list). The convention used here is very similar to the two that are detailed in the next section. For pred, the model requires an unnamed numeric vector output (usually).
numeric module
+This is defined only for regression models (so is not added to the list). The convention used here is very similar to the two that are detailed in the next section. For numeric, the model requires an unnamed numeric vector output (usually).
For multivariate models, the return value should be a matrix or data frame (otherwise a vector should be the results).
-Note that the pred module maps to the predict_num function in parsnip. However, the user-facing predict function is used to generate predictions and returns a tibble with a column named .pred (see the example below). When creating new models, you don’t have to write code for that part.
Note that the numeric module maps to the predict_numeric function in parsnip. However, the user-facing predict function is used to generate predictions and returns a tibble with a column named .pred (see the example below). When creating new models, you don’t have to write code for that part.
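To make that concrete, here is a minimal usage sketch (the lm engine and the mtcars data are arbitrary choices for illustration):

# predict() is the user-facing entry point; for regression models it
# dispatches to predict_numeric() and returns a tibble with `.pred`.
library(parsnip)
spec <- set_engine(linear_reg(), "lm")
fit_obj <- fit(spec, mpg ~ ., data = mtcars)
predict(fit_obj, new_data = mtcars[1:3, ])
#> expected shape: a 3-row tibble with a single `.pred` column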
classes module
To make hard class predictions, the classes object contains the details. The elements of the list are:
class module
+To make hard class predictions, the class object contains the details. The elements of the list are:
pre and post are optional functions that can preprocess the data being fed to the prediction code and to postprocess the raw output of the predictions. These won’t be needed for this example, but a section below has examples of how these can be used when the model code is not easy to use. If the data being predicted has a simple type requirement, you can avoid using a pre function with the args below.
args is a list of arguments to pass to the prediction function. These will most likely be wrapped in rlang::expr so that they are not evaluated when defining the method. For mda, the code would be predict(object, newdata, type = "class"). What is actually given to the function is the parsnip model fit object, which includes a sub-object called fit and this houses the mda model object. If the data need to be a matrix or data frame, you could also use new_data = quote(as.data.frame(new_data)) and so on.
+mixture_da_mda_data$class <-
list(
pre = NULL,
post = NULL,
@@ -253,12 +253,12 @@
)
The predict_class function will expect the result to be an unnamed character string or factor. This will be coerced to a factor with the same levels as the original data. As with the pred module, the user doesn’t call predict_class but uses predict instead and this produces a tibble with a column named .pred_class per the model guidlines.
prob moduleThis defines the class probabilities (if they can be computed). The format is identical to the classes module but the output is expected to be a tibble with columns for each factor level.
classprob module
+This defines the class probabilities (if they can be computed). The format is identical to the class module but the output is expected to be a tibble with columns for each factor level.
As an example of the post function, the data frame created by mda:::predict.mda will be converted to a tibble. The arguments are x (the raw results coming from the predict method) and object (the parsnip model fit object). The latter has a sub-object called lvl which is a character string of the outcome’s factor levels (if any).
mixture_da_mda_data$prob <-
+mixture_da_mda_data$classprob <-
list(
pre = NULL,
post = function(x, object) {
@@ -371,9 +371,9 @@
Do I have to return a simple vector for predict_num and predict_class?
-Previously, when discussing the pred information:
+Previously, when discussing the numeric information:
-For pred, the model requires an unnamed numeric vector output (usually).
+For numeric, the model requires an unnamed numeric vector output (usually).
There are some occasions where a prediction for a single new sample may be multidimensional. Examples are enumerated here but some easy examples are:
diff --git a/docs/reference/index.html b/docs/reference/index.html
index df9884277..3682273be 100644
--- a/docs/reference/index.html
+++ b/docs/reference/index.html
@@ -190,6 +190,18 @@ surv_reg() update(<surv_reg>)
General Interface for Parametric Survival Models
+
+ svm_poly()
+
+ svm_rbf()
diff --git a/docs/reference/linear_reg.html b/docs/reference/linear_reg.html
index b8d391464..b6961b5fd 100644
--- a/docs/reference/linear_reg.html
+++ b/docs/reference/linear_reg.html
@@ -33,10 +33,10 @@
Arg
penalty
- An non-negative number representing the
-total amount of regularization (glmnet and spark only).
+ An non-negative number representing the total
+amount of regularization (glmnet, keras, and spark only).
+For keras models, this corresponds to purely L2 regularization
+(aka weight decay) while the other models can be a combination
+of L1 and L2 (depending on the value of mixture).
mixture
@@ -209,6 +212,7 @@ Details
R: "lm" or "glmnet"
Stan: "stan"
Spark: "spark"
+keras: "keras"
Note
@@ -251,6 +255,11 @@ keras
+
+parsnip::keras_mlp(x = missing_arg(), y = missing_arg(), hidden_units = 1,
+ act = "linear")
+
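As a hedged usage sketch of the new engine (it assumes keras is installed and configured; per the arg key, `penalty` maps to the `decay` argument of keras_mlp()):

# With the keras engine, linear_reg() is fit as a single-unit,
# linear-activation network via parsnip::keras_mlp().
library(parsnip)
spec <- set_engine(linear_reg(penalty = 0.1), "keras")
translate(spec)
#> fit template should resemble: parsnip::keras_mlp(x = missing_arg(),
#>   y = missing_arg(), hidden_units = 1, act = "linear", decay = 0.1)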
When using glmnet models, there is the option to pass
multiple values (or no values) to the penalty argument.
This can have an effect on the model object results. When using
diff --git a/docs/reference/logistic_reg.html b/docs/reference/logistic_reg.html
index 122fc482c..626830342 100644
--- a/docs/reference/logistic_reg.html
+++ b/docs/reference/logistic_reg.html
@@ -33,8 +33,8 @@
Arg
penalty
- An non-negative number representing the
-total amount of regularization (glmnet and spark only).
+ A non-negative number representing the total
+amount of regularization (glmnet, keras, and spark only).
+For keras models, this corresponds to purely L2 regularization
+(aka weight decay) while the other models can be a combination
+of L1 and L2 (depending on the value of mixture).
mixture
@@ -207,6 +210,7 @@ Details
R: "glm" or "glmnet"
Stan: "stan"
Spark: "spark"
+keras: "keras"
Note
@@ -250,6 +254,11 @@
+keras
+
+parsnip::keras_mlp(x = missing_arg(), y = missing_arg(), hidden_units = 1,
+ act = "linear")
+
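The post-processing for this engine is worth a quick sketch (taken from the logistic_reg_keras_data object added in this change):

# keras::predict_classes() returns 0-based integer codes, so the class
# module maps them back onto the factor levels stored on the fit object:
post_class <- function(x, object) object$lvl[x + 1]
# keras::predict_proba() returns a plain matrix; the classprob module
# renames its columns to the outcome levels:
post_prob <- function(x, object) {
  x <- tibble::as_tibble(x)
  colnames(x) <- object$lvl
  x
}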
When using glmnet models, there is the option to pass
multiple values (or no values) to the penalty argument.
This can have an effect on the model object results. When using
diff --git a/docs/reference/multinom_reg.html b/docs/reference/multinom_reg.html
index 6765274a5..87e2cfd61 100644
--- a/docs/reference/multinom_reg.html
+++ b/docs/reference/multinom_reg.html
@@ -33,7 +33,7 @@
Arg
penalty
- An non-negative number representing the
-total amount of regularization.
+ A non-negative number representing the total
+amount of regularization (glmnet, keras, and spark only).
+For keras models, this corresponds to purely L2 regularization
+(aka weight decay) while the other models can be a combination
+of L1 and L2 (depending on the value of mixture).
mixture
@@ -206,6 +209,7 @@ Details
following engines:
R: "glmnet"
Stan: "stan"
+keras: "keras"
Note
@@ -239,6 +243,11 @@
+keras
+
+parsnip::keras_mlp(x = missing_arg(), y = missing_arg(), hidden_units = 1,
+ act = "linear")
+
When using glmnet models, there is the option to pass
multiple values (or no values) to the penalty argument.
This can have an effect on the model object results. When using
diff --git a/docs/reference/other_predict.html b/docs/reference/other_predict.html
index 15d418104..c0a7c3c1c 100644
--- a/docs/reference/other_predict.html
+++ b/docs/reference/other_predict.html
@@ -155,9 +155,9 @@
Other predict methods.
predict_predint(object, ...)
# S3 method for model_fit
-predict_num(object, new_data, ...)
+predict_numeric(object, new_data, ...)
-predict_num(object, ...)
+predict_numeric(object, ...)
# S3 method for model_fit
predict_quantile(object, new_data,
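Since the predict_num() to predict_numeric() rename is user-visible, a short before/after sketch (any fitted regression model would do):

library(parsnip)

lm_fit <-
  linear_reg() %>%
  set_engine("lm") %>%
  fit(mpg ~ ., data = mtcars)

# Old (removed): predict_num(lm_fit, new_data = mtcars[1:3, -1])
# New:
predict_numeric(lm_fit, new_data = mtcars[1:3, -1])

# The high-level predict() interface is unchanged:
predict(lm_fit, new_data = mtcars[1:3, -1])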
diff --git a/docs/reference/svm_poly.html b/docs/reference/svm_poly.html
new file mode 100644
index 000000000..083ceedcc
--- /dev/null
+++ b/docs/reference/svm_poly.html
@@ -0,0 +1,304 @@
+General interface for polynomial support vector machines — svm_poly • parsnip
+
+ General interface for polynomial support vector machines
+
+ svm_poly.Rd
+
+ svm_poly is a way to generate a specification of a model
+before fitting and allows the model to be created using
+different packages in R. The main arguments for the
+model are:
+cost: The cost of predicting a sample within or on the
+wrong side of the margin.
+degree: The polynomial degree.
+scale_factor: A scaling factor for the kernel.
+margin: The epsilon in the SVM insensitive loss function
+(regression only).
+
+These arguments are converted to their specific names at the
+time that the model is fit. Other options and arguments can be
+set using set_engine. If left to their defaults
+here (NULL), the values are taken from the underlying model
+functions. If parameters need to be modified, update can be used
+in lieu of recreating the object from scratch.
+
+
+
+ svm_poly(mode = "unknown", cost = NULL, degree = NULL,
+ scale_factor = NULL, margin = NULL)
+
+# S3 method for svm_poly
+update(object, cost = NULL, degree = NULL,
+ scale_factor = NULL, margin = NULL, fresh = FALSE, ...)
+
+ Arguments
+
+
+
+ mode
+ A single character string for the type of model.
+Possible values for this model are "unknown", "regression", or
+"classification".
+
+
+ cost
+ A positive number for the cost of predicting a sample within
+or on the wrong side of the margin.
+
+
+ degree
+ A positive number for the polynomial degree.
+
+
+ scale_factor
+ A positive number for the polynomial scaling factor.
+
+
+ margin
+ A positive number for the epsilon in the SVM insensitive
+loss function (regression only).
+
+
+ object
+ A polynomial SVM model specification.
+
+
+ fresh
+ A logical for whether the arguments should be
+modified in-place or replaced wholesale.
+
+
+ ...
+ Not used for update.
+
+
+
+ Details
+
+ The model can be created using the fit() function with the
+following engines:
+R: "kernlab"
+
+
+ Engine Details
+
+
+ Engines may have pre-set default arguments when executing the
+model fit call. For this type of
+model, the templates of the fit calls are:
+kernlab classification
+
+kernlab::ksvm(x = missing_arg(), y = missing_arg(), kernel = "polydot",
+ prob.model = TRUE)
+
+kernlab regression
+
+kernlab::ksvm(x = missing_arg(), y = missing_arg(), kernel = "polydot")
+
+
+ See also
+ varying(), fit()
+
+ Examples
+svm_poly(mode = "classification", degree = 1.2)
+#> Polynomial Support Vector Machine Specification (classification)
+#>
+#> Main Arguments:
+#>   degree = 1.2
+#>
+# Parameters can be represented by a placeholder:
+svm_poly(mode = "regression", cost = varying())
+#> Polynomial Support Vector Machine Specification (regression)
+#>
+#> Main Arguments:
+#>   cost = varying()
+#>
+model <- svm_poly(cost = 10, scale_factor = 0.1)
+model
+#> Polynomial Support Vector Machine Specification (unknown)
+#>
+#> Main Arguments:
+#>   cost = 10
+#>   scale_factor = 0.1
+#>
+update(model, cost = 1)
+#> Polynomial Support Vector Machine Specification (unknown)
+#>
+#> Main Arguments:
+#>   cost = 1
+#>   scale_factor = 0.1
+#>
+update(model, cost = 1, fresh = TRUE)
+#> Polynomial Support Vector Machine Specification (unknown)
+#>
+#> Main Arguments:
+#>   cost = 1
+#>
diff --git a/docs/reference/svm_rbf.html b/docs/reference/svm_rbf.html
new file mode 100644
index 000000000..661ff3f5c
--- /dev/null
+++ b/docs/reference/svm_rbf.html
@@ -0,0 +1,300 @@
+General interface for radial basis function support vector machines — svm_rbf • parsnip
+
+ General interface for radial basis function support vector machines
+
+ svm_rbf.Rd
+
+ svm_rbf is a way to generate a specification of a model
+before fitting and allows the model to be created using
+different packages in R. The main arguments for the
+model are:
+cost: The cost of predicting a sample within or on the
+wrong side of the margin.
+rbf_sigma: The precision parameter for the radial basis
+function.
+margin: The epsilon in the SVM insensitive loss function
+(regression only).
+
+These arguments are converted to their specific names at the
+time that the model is fit. Other options and arguments can be
+set using set_engine. If left to their defaults
+here (NULL), the values are taken from the underlying model
+functions. If parameters need to be modified, update can be used
+in lieu of recreating the object from scratch.
+
+
+
+ svm_rbf(mode = "unknown", cost = NULL, rbf_sigma = NULL,
+ margin = NULL)
+
+# S3 method for svm_rbf
+update(object, cost = NULL, rbf_sigma = NULL,
+ margin = NULL, fresh = FALSE, ...)
+
+ Arguments
+
+
+
+ mode
+ A single character string for the type of model.
+Possible values for this model are "unknown", "regression", or
+"classification".
+
+
+ cost
+ A positive number for the cost of predicting a sample within
+or on the wrong side of the margin.
+
+
+ rbf_sigma
+ A positive number for the precision parameter of the radial basis function.
+
+
+ margin
+ A positive number for the epsilon in the SVM insensitive
+loss function (regression only).
+
+
+ object
+ A radial basis function SVM model specification.
+
+
+ fresh
+ A logical for whether the arguments should be
+modified in-place or replaced wholesale.
+
+
+ ...
+ Not used for update.
+
+
+
+ Details
+
+ The model can be created using the fit() function with the
+following engines:
+R: "kernlab"
+
+
+ Engine Details
+
+
+ Engines may have pre-set default arguments when executing the
+model fit call. For this type of
+model, the templates of the fit calls are:
+kernlab classification
+
+kernlab::ksvm(x = missing_arg(), y = missing_arg(), kernel = "rbfdot",
+ prob.model = TRUE)
+
+kernlab regression
+
+kernlab::ksvm(x = missing_arg(), y = missing_arg(), kernel = "rbfdot")
+
+
+ See also
+ varying(), fit()
+
+ Examples
+svm_rbf(mode = "classification", rbf_sigma = 0.2)
+#> Radial Basis Function Support Vector Machine Specification (classification)
+#>
+#> Main Arguments:
+#>   rbf_sigma = 0.2
+#>
+# Parameters can be represented by a placeholder:
+svm_rbf(mode = "regression", cost = varying())
+#> Radial Basis Function Support Vector Machine Specification (regression)
+#>
+#> Main Arguments:
+#>   cost = varying()
+#>
+model <- svm_rbf(cost = 10, rbf_sigma = 0.1)
+model
+#> Radial Basis Function Support Vector Machine Specification (unknown)
+#>
+#> Main Arguments:
+#>   cost = 10
+#>   rbf_sigma = 0.1
+#>
+update(model, cost = 1)
+#> Radial Basis Function Support Vector Machine Specification (unknown)
+#>
+#> Main Arguments:
+#>   cost = 1
+#>   rbf_sigma = 0.1
+#>
+update(model, cost = 1, fresh = TRUE)
+#> Radial Basis Function Support Vector Machine Specification (unknown)
+#>
+#> Main Arguments:
+#>   cost = 1
+#>
diff --git a/docs/reference/translate.html b/docs/reference/translate.html
index 3c3abc423..965fc7ff6 100644
--- a/docs/reference/translate.html
+++ b/docs/reference/translate.html
@@ -163,6 +163,9 @@ Details
This function can be useful when you need to understand how
parsnip goes from a generic model specific to a model fitting
function.
+Note: this function is used internally and users should only use it
+to understand what the underlying syntax would be. It should not be used
+to modify the model specification.
Examples
diff --git a/man/linear_reg.Rd b/man/linear_reg.Rd
index f395732d7..5c00e5f60 100644
--- a/man/linear_reg.Rd
+++ b/man/linear_reg.Rd
@@ -14,8 +14,11 @@ linear_reg(mode = "regression", penalty = NULL, mixture = NULL)
\item{mode}{A single character string for the type of model.
The only possible value for this model is "regression".}
-\item{penalty}{An non-negative number representing the
-total amount of regularization (\code{glmnet} and \code{spark} only).}
+\item{penalty}{A non-negative number representing the total
+amount of regularization (\code{glmnet}, \code{keras}, and \code{spark} only).
+For \code{keras} models, this corresponds to purely L2 regularization
+(aka weight decay) while the other models can be a combination
+of L1 and L2 (depending on the value of \code{mixture}).}
\item{mixture}{A number between zero and one (inclusive) that
represents the proportion of regularization that is used for the
@@ -32,11 +35,11 @@ modified in-place of or replaced wholesale.}
\description{
\code{linear_reg} is a way to generate a \emph{specification} of a model
before fitting and allows the model to be created using
-different packages in R, Stan, or via Spark. The main arguments for the
-model are:
+different packages in R, Stan, keras, or via Spark. The main
+arguments for the model are:
\itemize{
\item \code{penalty}: The total amount of regularization
-in the model. Note that this must be zero for some engines .
+in the model. Note that this must be zero for some engines.
\item \code{mixture}: The proportion of L1 regularization in
the model. Note that this will be ignored for some engines.
}
@@ -58,6 +61,7 @@ following \emph{engines}:
\item \pkg{R}: \code{"lm"} or \code{"glmnet"}
\item \pkg{Stan}: \code{"stan"}
\item \pkg{Spark}: \code{"spark"}
+\item \pkg{keras}: \code{"keras"}
}
}
\note{
@@ -97,6 +101,10 @@ model, the template of the fit calls are:
\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "spark")}
+\pkg{keras}
+
+\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "keras")}
+
When using \code{glmnet} models, there is the option to pass
multiple values (or no values) to the \code{penalty} argument.
This can have an effect on the model object results. When using
diff --git a/man/logistic_reg.Rd b/man/logistic_reg.Rd
index 0b2918a46..0a3416c0c 100644
--- a/man/logistic_reg.Rd
+++ b/man/logistic_reg.Rd
@@ -14,8 +14,11 @@ logistic_reg(mode = "classification", penalty = NULL, mixture = NULL)
\item{mode}{A single character string for the type of model.
The only possible value for this model is "classification".}
-\item{penalty}{An non-negative number representing the
-total amount of regularization (\code{glmnet} and \code{spark} only).}
+\item{penalty}{A non-negative number representing the total
+amount of regularization (\code{glmnet}, \code{keras}, and \code{spark} only).
+For \code{keras} models, this corresponds to purely L2 regularization
+(aka weight decay) while the other models can be a combination
+of L1 and L2 (depending on the value of \code{mixture}).}
\item{mixture}{A number between zero and one (inclusive) that
represents the proportion of regularization that is used for the
@@ -32,8 +35,8 @@ modified in-place of or replaced wholesale.}
\description{
\code{logistic_reg} is a way to generate a \emph{specification} of a model
before fitting and allows the model to be created using
-different packages in R, Stan, or via Spark. The main arguments for the
-model are:
+different packages in R, Stan, keras, or via Spark. The main
+arguments for the model are:
\itemize{
\item \code{penalty}: The total amount of regularization
in the model. Note that this must be zero for some engines.
@@ -56,6 +59,7 @@ following \emph{engines}:
\item \pkg{R}: \code{"glm"} or \code{"glmnet"}
\item \pkg{Stan}: \code{"stan"}
\item \pkg{Spark}: \code{"spark"}
+\item \pkg{keras}: \code{"keras"}
}
}
\note{
@@ -95,6 +99,10 @@ model, the template of the fit calls are:
\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::logistic_reg(), "spark")}
+\pkg{keras}
+
+\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::logistic_reg(), "keras")}
+
When using \code{glmnet} models, there is the option to pass
multiple values (or no values) to the \code{penalty} argument.
This can have an effect on the model object results. When using
diff --git a/man/multinom_reg.Rd b/man/multinom_reg.Rd
index 2bcc20b3e..0de9c5ee1 100644
--- a/man/multinom_reg.Rd
+++ b/man/multinom_reg.Rd
@@ -14,8 +14,11 @@ multinom_reg(mode = "classification", penalty = NULL, mixture = NULL)
\item{mode}{A single character string for the type of model.
The only possible value for this model is "classification".}
-\item{penalty}{An non-negative number representing the
-total amount of regularization.}
+\item{penalty}{A non-negative number representing the total
+amount of regularization (\code{glmnet}, \code{keras}, and \code{spark} only).
+For \code{keras} models, this corresponds to purely L2 regularization
+(aka weight decay) while the other models can be a combination
+of L1 and L2 (depending on the value of \code{mixture}).}
\item{mixture}{A number between zero and one (inclusive) that
represents the proportion of regularization that is used for the
@@ -32,7 +35,7 @@ modified in-place of or replaced wholesale.}
\description{
\code{multinom_reg} is a way to generate a \emph{specification} of a model
before fitting and allows the model to be created using
-different packages in R or Spark. The main arguments for the
+different packages in R, keras, or Spark. The main arguments for the
model are:
\itemize{
\item \code{penalty}: The total amount of regularization
@@ -55,6 +58,7 @@ following \emph{engines}:
\itemize{
\item \pkg{R}: \code{"glmnet"}
\item \pkg{Stan}: \code{"stan"}
+\item \pkg{keras}: \code{"keras"}
}
}
\note{
@@ -86,6 +90,10 @@ model, the template of the fit calls are:
\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::multinom_reg(), "spark")}
+\pkg{keras}
+
+\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::multinom_reg(), "keras")}
+
When using \code{glmnet} models, there is the option to pass
multiple values (or no values) to the \code{penalty} argument.
This can have an effect on the model object results. When using
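A hedged sketch of the keras engine for multinom_reg(), paralleling the documentation above (assumes keras is installed; the engine arguments are illustrative):

library(parsnip)

mn_mod <-
  multinom_reg(penalty = 0.1) %>%
  set_engine("keras", epochs = 50, verbose = 0)

# As for the other keras engines, penalty is pure L2 regularization.
mn_fit <- fit(mn_mod, Species ~ ., data = iris)
predict(mn_fit, iris[c(1, 51, 101), -5], type = "prob")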
diff --git a/man/other_predict.Rd b/man/other_predict.Rd
index f462f4d0b..57e3bf3f2 100644
--- a/man/other_predict.Rd
+++ b/man/other_predict.Rd
@@ -1,6 +1,6 @@
% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/predict_class.R, R/predict_classprob.R,
-% R/predict_interval.R, R/predict_num.R, R/predict_quantile.R
+% R/predict_interval.R, R/predict_numeric.R, R/predict_quantile.R
\name{predict_class.model_fit}
\alias{predict_class.model_fit}
\alias{predict_class}
@@ -10,8 +10,8 @@
\alias{predict_confint}
\alias{predict_predint.model_fit}
\alias{predict_predint}
-\alias{predict_num.model_fit}
-\alias{predict_num}
+\alias{predict_numeric.model_fit}
+\alias{predict_numeric}
\alias{predict_quantile.model_fit}
\alias{predict_quantile}
\title{Other predict methods.}
@@ -34,9 +34,9 @@ predict_confint(object, ...)
predict_predint(object, ...)
-\method{predict_num}{model_fit}(object, new_data, ...)
+\method{predict_numeric}{model_fit}(object, new_data, ...)
-predict_num(object, ...)
+predict_numeric(object, ...)
\method{predict_quantile}{model_fit}(object, new_data,
quantile = (1:9)/10, ...)
diff --git a/man/svm_poly.Rd b/man/svm_poly.Rd
new file mode 100644
index 000000000..66a0f361d
--- /dev/null
+++ b/man/svm_poly.Rd
@@ -0,0 +1,90 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/svm_poly.R
+\name{svm_poly}
+\alias{svm_poly}
+\alias{update.svm_poly}
+\title{General interface for polynomial support vector machines}
+\usage{
+svm_poly(mode = "unknown", cost = NULL, degree = NULL,
+ scale_factor = NULL, margin = NULL)
+
+\method{update}{svm_poly}(object, cost = NULL, degree = NULL,
+ scale_factor = NULL, margin = NULL, fresh = FALSE, ...)
+}
+\arguments{
+\item{mode}{A single character string for the type of model.
+Possible values for this model are "unknown", "regression", or
+"classification".}
+
+\item{cost}{A positive number for the cost of predicting a sample within
+or on the wrong side of the margin.}
+
+\item{degree}{A positive number for the polynomial degree.}
+
+\item{scale_factor}{A positive number for the polynomial scaling factor.}
+
+\item{margin}{A positive number for the epsilon in the SVM insensitive
+loss function (regression only).}
+
+\item{object}{A polynomial SVM model specification.}
+
+\item{fresh}{A logical for whether the arguments should be
+modified in-place or replaced wholesale.}
+
+\item{...}{Not used for \code{update}.}
+}
+\description{
+\code{svm_poly} is a way to generate a \emph{specification} of a model
+before fitting and allows the model to be created using
+different packages in R. The main arguments for the
+model are:
+\itemize{
+\item \code{cost}: The cost of predicting a sample within or on the
+wrong side of the margin.
+\item \code{degree}: The polynomial degree.
+\item \code{scale_factor}: A scaling factor for the kernel.
+\item \code{margin}: The epsilon in the SVM insensitive loss function
+(regression only).
+}
+These arguments are converted to their specific names at the
+time that the model is fit. Other options and arguments can be
+set using \code{set_engine}. If left to their defaults
+here (\code{NULL}), the values are taken from the underlying model
+functions. If parameters need to be modified, \code{update} can be used
+in lieu of recreating the object from scratch.
+}
+\details{
+The model can be created using the \code{fit()} function with the
+following \emph{engines}:
+\itemize{
+\item \pkg{R}: \code{"kernlab"}
+}
+}
+\section{Engine Details}{
+
+
+Engines may have pre-set default arguments when executing the
+model fit call. For this type of
+model, the templates of the fit calls are:
+
+\pkg{kernlab} classification
+
+\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::svm_poly(mode = "classification"), "kernlab")}
+
+\pkg{kernlab} regression
+
+\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::svm_poly(mode = "regression"), "kernlab")}
+}
+
+\examples{
+svm_poly(mode = "classification", degree = 1.2)
+# Parameters can be represented by a placeholder:
+svm_poly(mode = "regression", cost = varying())
+model <- svm_poly(cost = 10, scale_factor = 0.1)
+model
+update(model, cost = 1)
+update(model, cost = 1, fresh = TRUE)
+}
+\seealso{
+\code{\link[=varying]{varying()}}, \code{\link[=fit]{fit()}}
+}
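A sketch of how svm_poly()'s main arguments reach kernlab, consistent with the translate() tests later in this diff (not itself part of the change):

library(parsnip)

poly_spec <-
  svm_poly(mode = "classification", degree = 2, scale_factor = 1.2) %>%
  set_engine("kernlab")

# degree and scale_factor are folded into ksvm()'s kpar list
# (kpar$degree, kpar$scale), alongside kernel = "polydot".
translate(poly_spec)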
diff --git a/man/svm_rbf.Rd b/man/svm_rbf.Rd
new file mode 100644
index 000000000..76fa58d61
--- /dev/null
+++ b/man/svm_rbf.Rd
@@ -0,0 +1,88 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/svm_rbf.R
+\name{svm_rbf}
+\alias{svm_rbf}
+\alias{update.svm_rbf}
+\title{General interface for radial basis function support vector machines}
+\usage{
+svm_rbf(mode = "unknown", cost = NULL, rbf_sigma = NULL,
+ margin = NULL)
+
+\method{update}{svm_rbf}(object, cost = NULL, rbf_sigma = NULL,
+ margin = NULL, fresh = FALSE, ...)
+}
+\arguments{
+\item{mode}{A single character string for the type of model.
+Possible values for this model are "unknown", "regression", or
+"classification".}
+
+\item{cost}{A positive number for the cost of predicting a sample within
+or on the wrong side of the margin.}
+
+\item{rbf_sigma}{A positive number for the precision parameter of the radial basis function.}
+
+\item{margin}{A positive number for the epsilon in the SVM insensitive
+loss function (regression only).}
+
+\item{object}{A radial basis function SVM model specification.}
+
+\item{fresh}{A logical for whether the arguments should be
+modified in-place or replaced wholesale.}
+
+\item{...}{Not used for \code{update}.}
+}
+\description{
+\code{svm_rbf} is a way to generate a \emph{specification} of a model
+before fitting and allows the model to be created using
+different packages in R. The main arguments for the
+model are:
+\itemize{
+\item \code{cost}: The cost of predicting a sample within or on the
+wrong side of the margin.
+\item \code{rbf_sigma}: The precision parameter for the radial basis
+function.
+\item \code{margin}: The epsilon in the SVM insensitive loss function
+(regression only).
+}
+These arguments are converted to their specific names at the
+time that the model is fit. Other options and arguments can be
+set using \code{set_engine}. If left to their defaults
+here (\code{NULL}), the values are taken from the underlying model
+functions. If parameters need to be modified, \code{update} can be used
+in lieu of recreating the object from scratch.
+}
+\details{
+The model can be created using the \code{fit()} function with the
+following \emph{engines}:
+\itemize{
+\item \pkg{R}: \code{"kernlab"}
+}
+}
+\section{Engine Details}{
+
+
+Engines may have pre-set default arguments when executing the
+model fit call. For this type of
+model, the templates of the fit calls are:
+
+\pkg{kernlab} classification
+
+\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::svm_rbf(mode = "classification"), "kernlab")}
+
+\pkg{kernlab} regression
+
+\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::svm_rbf(mode = "regression"), "kernlab")}
+}
+
+\examples{
+svm_rbf(mode = "classification", rbf_sigma = 0.2)
+# Parameters can be represented by a placeholder:
+svm_rbf(mode = "regression", cost = varying())
+model <- svm_rbf(cost = 10, rbf_sigma = 0.1)
+model
+update(model, cost = 1)
+update(model, cost = 1, fresh = TRUE)
+}
+\seealso{
+\code{\link[=varying]{varying()}}, \code{\link[=fit]{fit()}}
+}
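And a minimal fit-and-predict sketch for svm_rbf() (assumes kernlab is installed; the prob.model = TRUE default in the classification template is what enables type = "prob"):

library(parsnip)

rbf_fit <-
  svm_rbf(mode = "classification", cost = 1/8, rbf_sigma = 0.2) %>%
  set_engine("kernlab") %>%
  fit(Species ~ ., data = iris)

# rbf_sigma is passed to kernlab::ksvm() as kpar = list(sigma = 0.2),
# with kernel = "rbfdot".
predict(rbf_fit, iris[c(1, 51, 101), -5], type = "prob")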
diff --git a/man/translate.Rd b/man/translate.Rd
index a0bf27d18..82442df66 100644
--- a/man/translate.Rd
+++ b/man/translate.Rd
@@ -29,6 +29,10 @@ the model fitting function/engine.
This function can be useful when you need to understand how
\code{parsnip} goes from a generic model specific to a model fitting
function.
+
+\strong{Note}: this function is used internally and users should only use it
+to understand what the underlying syntax would be. It should not be used
+to modify the model specification.
}
\examples{
lm_spec <- linear_reg(penalty = 0.01)
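A short sketch of the inspection-only use described by the new note:

library(parsnip)

lm_spec <- linear_reg(penalty = 0.01)

# Intended: inspect the call parsnip would construct for an engine.
translate(lm_spec %>% set_engine("glmnet"))

# To change a specification, use update() or set_engine() arguments,
# not the translated object:
update(lm_spec, penalty = 0.1)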
diff --git a/tests/testthat/test_boost_tree_C50.R b/tests/testthat/test_boost_tree_C50.R
index 5c3e6096e..c1f189d7c 100644
--- a/tests/testthat/test_boost_tree_C50.R
+++ b/tests/testthat/test_boost_tree_C50.R
@@ -121,5 +121,10 @@ test_that('submodel prediction', {
mp_res <- multi_predict(class_fit, new_data = wa_churn[1:4, vars], trees = 4, type = "prob")
mp_res <- do.call("rbind", mp_res$.pred)
expect_equal(mp_res[[".pred_No"]], unname(pred_class[, "No"]))
+
+ expect_error(
+ multi_predict(class_fit, newdata = wa_churn[1:4, vars], trees = 4, type = "prob"),
+ "Did you mean"
+ )
})
diff --git a/tests/testthat/test_boost_tree_spark.R b/tests/testthat/test_boost_tree_spark.R
index 6c517e3b8..ce148620c 100644
--- a/tests/testthat/test_boost_tree_spark.R
+++ b/tests/testthat/test_boost_tree_spark.R
@@ -58,7 +58,7 @@ test_that('spark execution', {
)
expect_error(
- spark_reg_pred_num <- predict_num(spark_reg_fit, iris_bt_te),
+ spark_reg_pred_num <- predict_numeric(spark_reg_fit, iris_bt_te),
regexp = NA
)
@@ -68,7 +68,7 @@ test_that('spark execution', {
)
expect_error(
- spark_reg_num_dup <- predict_num(spark_reg_fit_dup, iris_bt_te),
+ spark_reg_num_dup <- predict_numeric(spark_reg_fit_dup, iris_bt_te),
regexp = NA
)
diff --git a/tests/testthat/test_boost_tree_xgboost.R b/tests/testthat/test_boost_tree_xgboost.R
index f8a7f7aa1..0c6be9417 100644
--- a/tests/testthat/test_boost_tree_xgboost.R
+++ b/tests/testthat/test_boost_tree_xgboost.R
@@ -141,7 +141,7 @@ test_that('xgboost regression prediction', {
)
xy_pred <- predict(xy_fit$fit, newdata = xgb.DMatrix(data = as.matrix(mtcars[1:8, -1])))
- expect_equal(xy_pred, predict_num(xy_fit, new_data = mtcars[1:8, -1]))
+ expect_equal(xy_pred, predict_numeric(xy_fit, new_data = mtcars[1:8, -1]))
form_fit <- fit(
car_basic,
@@ -151,7 +151,7 @@ test_that('xgboost regression prediction', {
)
form_pred <- predict(form_fit$fit, newdata = xgb.DMatrix(data = as.matrix(mtcars[1:8, -1])))
- expect_equal(form_pred, predict_num(form_fit, new_data = mtcars[1:8, -1]))
+ expect_equal(form_pred, predict_numeric(form_fit, new_data = mtcars[1:8, -1]))
})
@@ -188,5 +188,10 @@ test_that('submodel prediction', {
mp_res <- multi_predict(class_fit, new_data = wa_churn[1:4, vars], trees = 5, type = "prob")
mp_res <- do.call("rbind", mp_res$.pred)
expect_equal(mp_res[[".pred_No"]], pred_class)
+
+ expect_error(
+ multi_predict(class_fit, newdata = wa_churn[1:4, vars], trees = 5, type = "prob"),
+ "Did you mean"
+ )
})
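The user-facing behavior these tests pin down, as a sketch (assumes xgboost is installed):

library(parsnip)

xgb_fit <-
  boost_tree(mode = "regression", trees = 20) %>%
  set_engine("xgboost") %>%
  fit(mpg ~ ., data = mtcars)

# The base-R spelling `newdata` used to be silently absorbed by `...`;
# it now fails fast with a pointer to the correct argument name:
try(multi_predict(xgb_fit, newdata = mtcars[1:4, -1], trees = 10))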
diff --git a/tests/testthat/test_linear_reg.R b/tests/testthat/test_linear_reg.R
index 9df468114..86d2e2bed 100644
--- a/tests/testthat/test_linear_reg.R
+++ b/tests/testthat/test_linear_reg.R
@@ -269,7 +269,7 @@ test_that('lm prediction', {
control = ctrl
)
- expect_equal(uni_pred, predict_num(res_xy, iris[1:5, num_pred]))
+ expect_equal(uni_pred, predict_numeric(res_xy, iris[1:5, num_pred]))
res_form <- fit(
iris_basic,
@@ -277,7 +277,7 @@ test_that('lm prediction', {
data = iris,
control = ctrl
)
- expect_equal(inl_pred, predict_num(res_form, iris[1:5, ]))
+ expect_equal(inl_pred, predict_numeric(res_form, iris[1:5, ]))
res_mv <- fit(
iris_basic,
@@ -285,7 +285,7 @@ test_that('lm prediction', {
data = iris,
control = ctrl
)
- expect_equal(mv_pred, predict_num(res_mv, iris[1:5,]))
+ expect_equal(mv_pred, predict_numeric(res_mv, iris[1:5,]))
})
test_that('lm intervals', {
@@ -322,3 +322,14 @@ test_that('lm intervals', {
expect_equivalent(prediction_parsnip$.pred_upper, prediction_lm[, "upr"])
})
+
+test_that('newdata error trapping', {
+ res_xy <- fit_xy(
+ iris_basic,
+ x = iris[, num_pred],
+ y = iris$Sepal.Length,
+ control = ctrl
+ )
+ expect_error(predict(res_xy, newdata = iris[1:3, num_pred]), "Did you mean")
+})
+
diff --git a/tests/testthat/test_linear_reg_glmnet.R b/tests/testthat/test_linear_reg_glmnet.R
index ba7458038..a22b6e73d 100644
--- a/tests/testthat/test_linear_reg_glmnet.R
+++ b/tests/testthat/test_linear_reg_glmnet.R
@@ -69,7 +69,7 @@ test_that('glmnet prediction, single lambda', {
s = iris_basic$spec$args$penalty)
uni_pred <- unname(uni_pred[,1])
- expect_equal(uni_pred, predict_num(res_xy, iris[1:5, num_pred]))
+ expect_equal(uni_pred, predict_numeric(res_xy, iris[1:5, num_pred]))
res_form <- fit(
iris_basic,
@@ -87,7 +87,7 @@ test_that('glmnet prediction, single lambda', {
s = res_form$spec$spec$args$penalty)
form_pred <- unname(form_pred[,1])
- expect_equal(form_pred, predict_num(res_form, iris[1:5, c("Sepal.Width", "Species")]))
+ expect_equal(form_pred, predict_numeric(res_form, iris[1:5, c("Sepal.Width", "Species")]))
})
@@ -115,7 +115,7 @@ test_that('glmnet prediction, multiple lambda', {
mult_pred$lambda <- rep(lams, each = 5)
mult_pred <- mult_pred[,-2]
- expect_equal(mult_pred, predict_num(res_xy, iris[1:5, num_pred]))
+ expect_equal(mult_pred, predict_numeric(res_xy, iris[1:5, num_pred]))
res_form <- fit(
iris_mult,
@@ -135,7 +135,7 @@ test_that('glmnet prediction, multiple lambda', {
form_pred$lambda <- rep(lams, each = 5)
form_pred <- form_pred[,-2]
- expect_equal(form_pred, predict_num(res_form, iris[1:5, c("Sepal.Width", "Species")]))
+ expect_equal(form_pred, predict_numeric(res_form, iris[1:5, c("Sepal.Width", "Species")]))
})
test_that('glmnet prediction, all lambda', {
@@ -157,7 +157,7 @@ test_that('glmnet prediction, all lambda', {
all_pred$lambda <- rep(res_xy$fit$lambda, each = 5)
all_pred <- all_pred[,-2]
- expect_equal(all_pred, predict_num(res_xy, iris[1:5, num_pred]))
+ expect_equal(all_pred, predict_numeric(res_xy, iris[1:5, num_pred]))
# test that the lambda seq is in the right order (since no docs on this)
tmp_pred <- predict(res_xy$fit, newx = as.matrix(iris[1:5, num_pred]),
@@ -180,7 +180,7 @@ test_that('glmnet prediction, all lambda', {
form_pred$lambda <- rep(res_form$fit$lambda, each = 5)
form_pred <- form_pred[,-2]
- expect_equal(form_pred, predict_num(res_form, iris[1:5, c("Sepal.Width", "Species")]))
+ expect_equal(form_pred, predict_numeric(res_form, iris[1:5, c("Sepal.Width", "Species")]))
})
@@ -198,5 +198,10 @@ test_that('submodel prediction', {
mp_res <- multi_predict(reg_fit, new_data = mtcars[1:4, -1], penalty = .1)
mp_res <- do.call("rbind", mp_res$.pred)
expect_equal(mp_res[[".pred"]], unname(pred_glmn[,1]))
+
+ expect_error(
+ multi_predict(reg_fit, newdata = mtcars[1:4, -1], penalty = .1),
+ "Did you mean"
+ )
})
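For context, a sketch of the multi_predict() return shape these submodel tests rely on (assumes glmnet is installed):

library(parsnip)

reg_fit <-
  linear_reg(penalty = 0.1) %>%
  set_engine("glmnet") %>%
  fit(mpg ~ ., data = mtcars)

# One row per observation; `.pred` is a list column holding one
# prediction per requested penalty value.
preds <- multi_predict(reg_fit, new_data = mtcars[1:4, -1], penalty = c(0.01, 0.1))
preds$.pred[[1]]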
diff --git a/tests/testthat/test_linear_reg_keras.R b/tests/testthat/test_linear_reg_keras.R
new file mode 100644
index 000000000..c8bcfafb3
--- /dev/null
+++ b/tests/testthat/test_linear_reg_keras.R
@@ -0,0 +1,126 @@
+library(testthat)
+library(parsnip)
+library(rlang)
+library(tibble)
+
+# ------------------------------------------------------------------------------
+
+context("keras linear regression")
+source("helpers.R")
+
+# ------------------------------------------------------------------------------
+
+basic_mod <-
+ linear_reg() %>%
+ set_engine("keras", epochs = 50, verbose = 0)
+
+ridge_mod <-
+ linear_reg(penalty = 0.1) %>%
+ set_engine("keras", epochs = 50, verbose = 0)
+
+ctrl <- fit_control(verbosity = 0, catch = FALSE)
+
+# ------------------------------------------------------------------------------
+
+test_that('model fitting', {
+
+ skip_if_not_installed("keras")
+
+ set.seed(257)
+ expect_error(
+ fit1 <-
+ fit_xy(
+ basic_mod,
+ control = ctrl,
+ x = iris[,2:4],
+ y = iris$Sepal.Length
+ ),
+ regexp = NA
+ )
+
+ set.seed(257)
+ expect_error(
+ fit2 <-
+ fit_xy(
+ basic_mod,
+ control = ctrl,
+ x = iris[,2:4],
+ y = iris$Sepal.Length
+ ),
+ regexp = NA
+ )
+ expect_equal(fit1, fit2)
+
+ expect_error(
+ fit(
+ basic_mod,
+ Sepal.Length ~ .,
+ data = iris[, -5],
+ control = ctrl
+ ),
+ regexp = NA
+ )
+
+ expect_error(
+ fit1 <-
+ fit_xy(
+ ridge_mod,
+ control = ctrl,
+ x = iris[,2:4],
+ y = iris$Sepal.Length
+ ),
+ regexp = NA
+ )
+
+ expect_error(
+ fit(
+ ridge_mod,
+ Sepal.Length ~ .,
+ data = iris[, -5],
+ control = ctrl
+ ),
+ regexp = NA
+ )
+
+})
+
+
+test_that('regression prediction', {
+
+ skip_if_not_installed("keras")
+
+ library(keras)
+
+ set.seed(257)
+ lm_fit <-
+ fit_xy(
+ basic_mod,
+ control = ctrl,
+ x = iris[,2:4],
+ y = iris$Sepal.Length
+ )
+
+ keras_pred <-
+ predict(lm_fit$fit, as.matrix(iris[1:3,2:4])) %>%
+ as_tibble() %>%
+ setNames(".pred")
+ parsnip_pred <- predict(lm_fit, iris[1:3,2:4])
+ expect_equal(as.data.frame(keras_pred), as.data.frame(parsnip_pred))
+
+ set.seed(257)
+ rr_fit <-
+ fit_xy(
+ ridge_mod,
+ control = ctrl,
+ x = iris[,2:4],
+ y = iris$Sepal.Length
+ )
+
+ keras_pred <-
+ predict(rr_fit$fit, as.matrix(iris[1:3,2:4])) %>%
+ as_tibble() %>%
+ setNames(".pred")
+ parsnip_pred <- predict(rr_fit, iris[1:3,2:4])
+ expect_equal(as.data.frame(keras_pred), as.data.frame(parsnip_pred))
+
+})
diff --git a/tests/testthat/test_linear_reg_spark.R b/tests/testthat/test_linear_reg_spark.R
index 2ed55d7d6..f1a033beb 100644
--- a/tests/testthat/test_linear_reg_spark.R
+++ b/tests/testthat/test_linear_reg_spark.R
@@ -42,7 +42,7 @@ test_that('spark execution', {
)
expect_error(
- spark_pred_num <- predict_num(spark_fit, iris_linreg_te),
+ spark_pred_num <- predict_numeric(spark_fit, iris_linreg_te),
regexp = NA
)
diff --git a/tests/testthat/test_linear_reg_stan.R b/tests/testthat/test_linear_reg_stan.R
index 4f7288d90..bf7bd53de 100644
--- a/tests/testthat/test_linear_reg_stan.R
+++ b/tests/testthat/test_linear_reg_stan.R
@@ -70,7 +70,7 @@ test_that('stan prediction', {
control = quiet_ctrl
)
- expect_equal(uni_pred, predict_num(res_xy, iris[1:5, num_pred]), tolerance = 0.001)
+ expect_equal(uni_pred, predict_numeric(res_xy, iris[1:5, num_pred]), tolerance = 0.001)
res_form <- fit(
iris_basic,
@@ -78,7 +78,7 @@ test_that('stan prediction', {
data = iris,
control = quiet_ctrl
)
- expect_equal(inl_pred, predict_num(res_form, iris[1:5, ]), tolerance = 0.001)
+ expect_equal(inl_pred, predict_numeric(res_form, iris[1:5, ]), tolerance = 0.001)
})
diff --git a/tests/testthat/test_logistic_reg_glmnet.R b/tests/testthat/test_logistic_reg_glmnet.R
index 9b751cf46..ed4d0ff77 100644
--- a/tests/testthat/test_logistic_reg_glmnet.R
+++ b/tests/testthat/test_logistic_reg_glmnet.R
@@ -331,6 +331,11 @@ test_that('submodel prediction', {
mp_res <- multi_predict(class_fit, new_data = wa_churn[1:4, vars], penalty = .1, type = "prob")
mp_res <- do.call("rbind", mp_res$.pred)
expect_equal(mp_res[[".pred_No"]], unname(pred_glmn[,1]))
+
+ expect_error(
+ multi_predict(class_fit, newdata = wa_churn[1:4, vars], penalty = .1, type = "prob"),
+ "Did you mean"
+ )
})
diff --git a/tests/testthat/test_logistic_reg_keras.R b/tests/testthat/test_logistic_reg_keras.R
new file mode 100644
index 000000000..8a7c51986
--- /dev/null
+++ b/tests/testthat/test_logistic_reg_keras.R
@@ -0,0 +1,189 @@
+library(testthat)
+library(parsnip)
+library(rlang)
+library(tibble)
+library(dplyr)
+
+# ------------------------------------------------------------------------------
+
+context("keras logistic regression")
+source("helpers.R")
+
+# ------------------------------------------------------------------------------
+
+data("lending_club")
+set.seed(352)
+dat <-
+ lending_club %>%
+ group_by(Class) %>%
+ sample_n(500) %>%
+ ungroup() %>%
+ dplyr::select(Class, funded_amnt, int_rate)
+dat <- dat[order(runif(nrow(dat))),]
+
+tr_dat <- dat[1:995, ]
+te_dat <- dat[996:1000, ]
+
+# ------------------------------------------------------------------------------
+
+basic_mod <-
+ logistic_reg() %>%
+ set_engine("keras", epochs = 50, verbose = 0)
+
+reg_mod <-
+ logistic_reg(penalty = 0.1) %>%
+ set_engine("keras", epochs = 50, verbose = 0)
+
+ctrl <- fit_control(verbosity = 0, catch = FALSE)
+
+# ------------------------------------------------------------------------------
+
+test_that('model fitting', {
+
+ skip_if_not_installed("keras")
+
+ set.seed(257)
+ expect_error(
+ fit1 <-
+ fit_xy(
+ basic_mod,
+ control = ctrl,
+ x = tr_dat[, -1],
+ y = tr_dat$Class
+ ),
+ regexp = NA
+ )
+
+ set.seed(257)
+ expect_error(
+ fit2 <-
+ fit_xy(
+ basic_mod,
+ control = ctrl,
+ x = tr_dat[, -1],
+ y = tr_dat$Class
+ ),
+ regexp = NA
+ )
+ expect_equal(fit1, fit2)
+
+ expect_error(
+ fit(
+ basic_mod,
+ Class ~ .,
+ data = tr_dat,
+ control = ctrl
+ ),
+ regexp = NA
+ )
+
+ expect_error(
+ fit1 <-
+ fit_xy(
+ reg_mod,
+ control = ctrl,
+ x = tr_dat[, -1],
+ y = tr_dat$Class
+ ),
+ regexp = NA
+ )
+
+ expect_error(
+ fit(
+ reg_mod,
+ Class ~ .,
+ data = tr_dat,
+ control = ctrl
+ ),
+ regexp = NA
+ )
+
+})
+
+
+test_that('classification prediction', {
+
+ skip_if_not_installed("keras")
+
+ library(keras)
+
+ set.seed(257)
+ lr_fit <-
+ fit_xy(
+ basic_mod,
+ control = ctrl,
+ x = tr_dat[, -1],
+ y = tr_dat$Class
+ )
+
+ keras_raw <-
+ predict(lr_fit$fit, as.matrix(te_dat[, -1]))
+ keras_pred <-
+ tibble(.pred_class = apply(keras_raw, 1, which.max)) %>%
+ mutate(.pred_class = factor(lr_fit$lvl[.pred_class], levels = lr_fit$lvl))
+
+ parsnip_pred <- predict(lr_fit, te_dat[, -1])
+ expect_equal(as.data.frame(keras_pred), as.data.frame(parsnip_pred))
+
+ set.seed(257)
+ plrfit <-
+ fit_xy(
+ reg_mod,
+ control = ctrl,
+ x = tr_dat[, -1],
+ y = tr_dat$Class
+ )
+
+ keras_raw <-
+ predict(plrfit$fit, as.matrix(te_dat[, -1]))
+ keras_pred <-
+ tibble(.pred_class = apply(keras_raw, 1, which.max)) %>%
+ mutate(.pred_class = factor(plrfit$lvl[.pred_class], levels = plrfit$lvl))
+ parsnip_pred <- predict(plrfit, te_dat[, -1])
+ expect_equal(as.data.frame(keras_pred), as.data.frame(parsnip_pred))
+
+})
+
+
+test_that('classification probabilities', {
+
+ skip_if_not_installed("keras")
+
+ library(keras)
+
+ set.seed(257)
+ lr_fit <-
+ fit_xy(
+ basic_mod,
+ control = ctrl,
+ x = tr_dat[, -1],
+ y = tr_dat$Class
+ )
+
+ keras_pred <-
+ predict_proba(lr_fit$fit, as.matrix(te_dat[, -1])) %>%
+ as_tibble() %>%
+ setNames(paste0(".pred_", lr_fit$lvl))
+
+ parsnip_pred <- predict(lr_fit, te_dat[, -1], type = "prob")
+ expect_equal(as.data.frame(keras_pred), as.data.frame(parsnip_pred))
+
+ set.seed(257)
+ plrfit <-
+ fit_xy(
+ reg_mod,
+ control = ctrl,
+ x = tr_dat[, -1],
+ y = tr_dat$Class
+ )
+
+ keras_pred <-
+ predict_proba(plrfit$fit, as.matrix(te_dat[, -1])) %>%
+ as_tibble() %>%
+ setNames(paste0(".pred_", lr_fit$lvl))
+ parsnip_pred <- predict(plrfit, te_dat[, -1], type = "prob")
+ expect_equal(as.data.frame(keras_pred), as.data.frame(parsnip_pred))
+
+})
+
+
diff --git a/tests/testthat/test_mars.R b/tests/testthat/test_mars.R
index bf3d35c11..ed3505ceb 100644
--- a/tests/testthat/test_mars.R
+++ b/tests/testthat/test_mars.R
@@ -180,7 +180,7 @@ test_that('mars prediction', {
control = ctrl
)
- expect_equal(uni_pred, predict_num(res_xy, iris[1:5, num_pred]))
+ expect_equal(uni_pred, predict_numeric(res_xy, iris[1:5, num_pred]))
res_form <- fit(
iris_basic,
@@ -188,7 +188,7 @@ test_that('mars prediction', {
data = iris,
control = ctrl
)
- expect_equal(inl_pred, predict_num(res_form, iris[1:5, ]))
+ expect_equal(inl_pred, predict_numeric(res_form, iris[1:5, ]))
res_mv <- fit(
iris_basic,
@@ -196,7 +196,7 @@ test_that('mars prediction', {
data = iris,
control = ctrl
)
- expect_equal(mv_pred, predict_num(res_mv, iris[1:5,]))
+ expect_equal(mv_pred, predict_numeric(res_mv, iris[1:5,]))
})
@@ -242,6 +242,11 @@ test_that('submodel prediction', {
mp_res <- multi_predict(class_fit, new_data = wa_churn[1:4, vars], num_terms = 5, type = "prob")
mp_res <- do.call("rbind", mp_res$.pred)
expect_equal(mp_res[[".pred_No"]], pruned_cls_pred)
+
+ expect_error(
+ multi_predict(reg_fit, newdata = mtcars[1:4, -1], num_terms = 5),
+ "Did you mean"
+ )
})
diff --git a/tests/testthat/test_mlp_keras.R b/tests/testthat/test_mlp_keras.R
index 9e1a55448..09ae0cb61 100644
--- a/tests/testthat/test_mlp_keras.R
+++ b/tests/testthat/test_mlp_keras.R
@@ -218,7 +218,7 @@ test_that('multivariate nnet formula', {
data = nn_dat[-(1:5),]
)
expect_equal(length(unlist(keras::get_weights(nnet_form$fit))), 24)
- nnet_form_pred <- predict_num(nnet_form, new_data = nn_dat[1:5, -(1:3)])
+ nnet_form_pred <- predict_numeric(nnet_form, new_data = nn_dat[1:5, -(1:3)])
expect_equal(ncol(nnet_form_pred), 3)
expect_equal(nrow(nnet_form_pred), 5)
expect_equal(names(nnet_form_pred), c("V1", "V2", "V3"))
@@ -233,7 +233,7 @@ test_that('multivariate nnet formula', {
y = nn_dat[-(1:5), 1:3 ]
)
expect_equal(length(unlist(keras::get_weights(nnet_xy$fit))), 24)
- nnet_form_xy <- predict_num(nnet_xy, new_data = nn_dat[1:5, -(1:3)])
+ nnet_form_xy <- predict_numeric(nnet_xy, new_data = nn_dat[1:5, -(1:3)])
expect_equal(ncol(nnet_form_xy), 3)
expect_equal(nrow(nnet_form_xy), 5)
expect_equal(names(nnet_form_xy), c("V1", "V2", "V3"))
diff --git a/tests/testthat/test_mlp_nnet.R b/tests/testthat/test_mlp_nnet.R
index 24dd3fa9e..a112086fe 100644
--- a/tests/testthat/test_mlp_nnet.R
+++ b/tests/testthat/test_mlp_nnet.R
@@ -141,7 +141,7 @@ test_that('nnet regression prediction', {
xy_pred <- predict(xy_fit$fit, newdata = mtcars[1:8, -1])[,1]
xy_pred <- unname(xy_pred)
- expect_equal(xy_pred, predict_num(xy_fit, new_data = mtcars[1:8, -1]))
+ expect_equal(xy_pred, predict_numeric(xy_fit, new_data = mtcars[1:8, -1]))
form_fit <- fit(
car_basic,
@@ -152,7 +152,7 @@ test_that('nnet regression prediction', {
form_pred <- predict(form_fit$fit, newdata = mtcars[1:8, -1])[,1]
form_pred <- unname(form_pred)
- expect_equal(form_pred, predict_num(form_fit, new_data = mtcars[1:8, -1]))
+ expect_equal(form_pred, predict_numeric(form_fit, new_data = mtcars[1:8, -1]))
})
# ------------------------------------------------------------------------------
@@ -175,7 +175,7 @@ test_that('multivariate nnet formula', {
data = nn_dat[-(1:5),]
)
expect_equal(length(nnet_form$fit$wts), 24)
- nnet_form_pred <- predict_num(nnet_form, new_data = nn_dat[1:5, -(1:3)])
+ nnet_form_pred <- predict_numeric(nnet_form, new_data = nn_dat[1:5, -(1:3)])
expect_equal(ncol(nnet_form_pred), 3)
expect_equal(nrow(nnet_form_pred), 5)
expect_equal(names(nnet_form_pred), c("V1", "V2", "V3"))
@@ -192,7 +192,7 @@ test_that('multivariate nnet formula', {
y = nn_dat[-(1:5), 1:3 ]
)
expect_equal(length(nnet_xy$fit$wts), 24)
- nnet_form_xy <- predict_num(nnet_xy, new_data = nn_dat[1:5, -(1:3)])
+ nnet_form_xy <- predict_numeric(nnet_xy, new_data = nn_dat[1:5, -(1:3)])
expect_equal(ncol(nnet_form_xy), 3)
expect_equal(nrow(nnet_form_xy), 5)
expect_equal(names(nnet_form_xy), c("V1", "V2", "V3"))
diff --git a/tests/testthat/test_multinom_reg_glmnet.R b/tests/testthat/test_multinom_reg_glmnet.R
index c658feb6f..517d5f28b 100644
--- a/tests/testthat/test_multinom_reg_glmnet.R
+++ b/tests/testthat/test_multinom_reg_glmnet.R
@@ -134,6 +134,12 @@ test_that('glmnet probabilities, mulitiple lambda', {
mult_class$.pred,
multi_predict(xy_fit, iris[rows, 1:4], penalty = lams)$.pred
)
+
+ expect_error(
+ multi_predict(xy_fit, newdata = iris[rows, 1:4], penalty = lams),
+ "Did you mean"
+ )
+
})
diff --git a/tests/testthat/test_multinom_reg_keras.R b/tests/testthat/test_multinom_reg_keras.R
new file mode 100644
index 000000000..faaa56c0f
--- /dev/null
+++ b/tests/testthat/test_multinom_reg_keras.R
@@ -0,0 +1,182 @@
+library(testthat)
+library(parsnip)
+library(rlang)
+library(tibble)
+library(dplyr)
+
+# ------------------------------------------------------------------------------
+
+context("keras logistic regression")
+source("helpers.R")
+
+# ------------------------------------------------------------------------------
+
+set.seed(352)
+dat <- iris[order(runif(nrow(iris))),]
+
+tr_dat <- dat[1:140, ]
+te_dat <- dat[141:150, ]
+
+# ------------------------------------------------------------------------------
+
+basic_mod <-
+ multinom_reg() %>%
+ set_engine("keras", epochs = 50, verbose = 0)
+
+reg_mod <-
+ multinom_reg(penalty = 0.1) %>%
+ set_engine("keras", epochs = 50, verbose = 0)
+
+ctrl <- fit_control(verbosity = 0, catch = FALSE)
+
+# ------------------------------------------------------------------------------
+
+test_that('model fitting', {
+
+ skip_if_not_installed("keras")
+
+ set.seed(257)
+ expect_error(
+ fit1 <-
+ fit_xy(
+ basic_mod,
+ control = ctrl,
+ x = tr_dat[, -5],
+ y = tr_dat$Species
+ ),
+ regexp = NA
+ )
+
+ set.seed(257)
+ expect_error(
+ fit2 <-
+ fit_xy(
+ basic_mod,
+ control = ctrl,
+ x = tr_dat[, -5],
+ y = tr_dat$Species
+ ),
+ regexp = NA
+ )
+ expect_equal(fit1, fit2)
+
+ expect_error(
+ fit(
+ basic_mod,
+ Species ~ .,
+ data = tr_dat,
+ control = ctrl
+ ),
+ regexp = NA
+ )
+
+ expect_error(
+ fit1 <-
+ fit_xy(
+ reg_mod,
+ control = ctrl,
+ x = tr_dat[, -5],
+ y = tr_dat$Species
+ ),
+ regexp = NA
+ )
+
+ expect_error(
+ fit(
+ reg_mod,
+ Species ~ .,
+ data = tr_dat,
+ control = ctrl
+ ),
+ regexp = NA
+ )
+
+})
+
+
+test_that('classification prediction', {
+
+ skip_if_not_installed("keras")
+
+ library(keras)
+
+ set.seed(257)
+ lr_fit <-
+ fit_xy(
+ basic_mod,
+ control = ctrl,
+ x = tr_dat[, -5],
+ y = tr_dat$Species
+ )
+
+ keras_raw <-
+ predict(lr_fit$fit, as.matrix(te_dat[, -5]))
+ keras_pred <-
+ tibble(.pred_class = apply(keras_raw, 1, which.max)) %>%
+ mutate(.pred_class = factor(lr_fit$lvl[.pred_class], levels = lr_fit$lvl))
+
+ parsnip_pred <- predict(lr_fit, te_dat[, -5])
+ expect_equal(as.data.frame(keras_pred), as.data.frame(parsnip_pred))
+
+ set.seed(257)
+ plrfit <-
+ fit_xy(
+ reg_mod,
+ control = ctrl,
+ x = tr_dat[, -5],
+ y = tr_dat$Species
+ )
+
+ keras_raw <-
+ predict(plrfit$fit, as.matrix(te_dat[, -5]))
+ keras_pred <-
+ tibble(.pred_class = apply(keras_raw, 1, which.max)) %>%
+ mutate(.pred_class = factor(plrfit$lvl[.pred_class], levels = plrfit$lvl))
+ parsnip_pred <- predict(plrfit, te_dat[, -5])
+ expect_equal(as.data.frame(keras_pred), as.data.frame(parsnip_pred))
+
+})
+
+
+test_that('classification probabilities', {
+
+ skip_if_not_installed("keras")
+
+ library(keras)
+
+ set.seed(257)
+ lr_fit <-
+ fit_xy(
+ basic_mod,
+ control = ctrl,
+ x = tr_dat[, -5],
+ y = tr_dat$Species
+ )
+
+ keras_pred <-
+ predict_proba(lr_fit$fit, as.matrix(te_dat[, -5])) %>%
+ as_tibble() %>%
+ setNames(paste0(".pred_", lr_fit$lvl))
+
+ parsnip_pred <- predict(lr_fit, te_dat[, -5], type = "prob")
+ expect_equal(as.data.frame(keras_pred), as.data.frame(parsnip_pred))
+
+ set.seed(257)
+ plrfit <-
+ fit_xy(
+ reg_mod,
+ control = ctrl,
+ x = tr_dat[, -5],
+ y = tr_dat$Species
+ )
+
+ keras_pred <-
+ predict_proba(plrfit$fit, as.matrix(te_dat[, -5])) %>%
+ as_tibble() %>%
+ setNames(paste0(".pred_", lr_fit$lvl))
+ parsnip_pred <- predict(plrfit, te_dat[, -5], type = "prob")
+ expect_equal(as.data.frame(keras_pred), as.data.frame(parsnip_pred))
+
+})
+
+
diff --git a/tests/testthat/test_nearest_neighbor_kknn.R b/tests/testthat/test_nearest_neighbor_kknn.R
index 8cdddd2dc..999c31842 100644
--- a/tests/testthat/test_nearest_neighbor_kknn.R
+++ b/tests/testthat/test_nearest_neighbor_kknn.R
@@ -74,7 +74,7 @@ test_that('kknn prediction', {
newdata = iris[1:5, num_pred]
)
- expect_equal(uni_pred, predict_num(res_xy, iris[1:5, num_pred]))
+ expect_equal(uni_pred, predict_numeric(res_xy, iris[1:5, num_pred]))
# nominal
res_xy_nom <- fit_xy(
@@ -105,5 +105,5 @@ test_that('kknn prediction', {
newdata = iris[1:5,]
)
- expect_equal(form_pred, predict_num(res_form, iris[1:5, c("Sepal.Width", "Species")]))
+ expect_equal(form_pred, predict_numeric(res_form, iris[1:5, c("Sepal.Width", "Species")]))
})
diff --git a/tests/testthat/test_predict_formats.R b/tests/testthat/test_predict_formats.R
index 4b4c24d77..cd10d2add 100644
--- a/tests/testthat/test_predict_formats.R
+++ b/tests/testthat/test_predict_formats.R
@@ -31,7 +31,7 @@ lr_fit_2 <-
test_that('regression predictions', {
expect_true(is_tibble(predict(lm_fit, new_data = iris[1:5,-1])))
- expect_true(is.vector(predict_num(lm_fit, new_data = iris[1:5,-1])))
+ expect_true(is.vector(predict_numeric(lm_fit, new_data = iris[1:5,-1])))
expect_equal(names(predict(lm_fit, new_data = iris[1:5,-1])), ".pred")
})
diff --git a/tests/testthat/test_rand_forest_randomForest.R b/tests/testthat/test_rand_forest_randomForest.R
index 9ae446edc..35937b244 100644
--- a/tests/testthat/test_rand_forest_randomForest.R
+++ b/tests/testthat/test_rand_forest_randomForest.R
@@ -209,6 +209,6 @@ test_that('randomForest regression prediction', {
xy_pred <- predict(xy_fit$fit, newdata = tail(mtcars))
xy_pred <- unname(xy_pred)
- expect_equal(xy_pred, predict_num(xy_fit, new_data = tail(mtcars)))
+ expect_equal(xy_pred, predict_numeric(xy_fit, new_data = tail(mtcars)))
})
diff --git a/tests/testthat/test_rand_forest_ranger.R b/tests/testthat/test_rand_forest_ranger.R
index 7e90bbc43..82d767c4b 100644
--- a/tests/testthat/test_rand_forest_ranger.R
+++ b/tests/testthat/test_rand_forest_ranger.R
@@ -229,7 +229,7 @@ test_that('ranger regression prediction', {
xy_pred <- predict(xy_fit$fit, data = tail(mtcars[, -1]))$prediction
- expect_equal(xy_pred, predict_num(xy_fit, new_data = tail(mtcars[, -1])))
+ expect_equal(xy_pred, predict_numeric(xy_fit, new_data = tail(mtcars[, -1])))
})
diff --git a/tests/testthat/test_rand_forest_spark.R b/tests/testthat/test_rand_forest_spark.R
index bcda93626..4184e6abf 100644
--- a/tests/testthat/test_rand_forest_spark.R
+++ b/tests/testthat/test_rand_forest_spark.R
@@ -58,7 +58,7 @@ test_that('spark execution', {
)
expect_error(
- spark_reg_pred_num <- predict_num(spark_reg_fit, iris_rf_te),
+ spark_reg_pred_num <- predict_numeric(spark_reg_fit, iris_rf_te),
regexp = NA
)
@@ -68,7 +68,7 @@ test_that('spark execution', {
)
expect_error(
- spark_reg_num_dup <- predict_num(spark_reg_fit_dup, iris_rf_te),
+ spark_reg_num_dup <- predict_numeric(spark_reg_fit_dup, iris_rf_te),
regexp = NA
)
diff --git a/tests/testthat/test_svm_poly.R b/tests/testthat/test_svm_poly.R
new file mode 100644
index 000000000..6aea6acbf
--- /dev/null
+++ b/tests/testthat/test_svm_poly.R
@@ -0,0 +1,270 @@
+library(testthat)
+library(parsnip)
+library(rlang)
+library(tibble)
+
+# ------------------------------------------------------------------------------
+
+context("RBF SVM")
+source("helpers.R")
+
+# ------------------------------------------------------------------------------
+
+test_that('primary arguments', {
+ basic <- svm_poly()
+ basic_kernlab <- translate(basic %>% set_engine("kernlab"))
+
+ expect_equal(
+ object = basic_kernlab$method$fit$args,
+ expected = list(
+ x = expr(missing_arg()),
+ y = expr(missing_arg()),
+ kernel = "polydot"
+ )
+ )
+
+ degree <- svm_poly(degree = 2)
+ degree_kernlab <- translate(degree %>% set_engine("kernlab"))
+ degree_obj <- expr(list())
+ degree_obj$degree <- new_empty_quosure(2)
+
+ expect_equal(
+ object = degree_kernlab$method$fit$args,
+ expected = list(
+ x = expr(missing_arg()),
+ y = expr(missing_arg()),
+ kernel = "polydot",
+ kpar = degree_obj
+ )
+ )
+
+ degree_scale <- svm_poly(degree = 2, scale_factor = 1.2)
+ degree_scale_kernlab <- translate(degree_scale %>% set_engine("kernlab"))
+ degree_scale_obj <- expr(list())
+ degree_scale_obj$degree <- new_empty_quosure(2)
+ degree_scale_obj$scale <- new_empty_quosure(1.2)
+
+ expect_equal(
+ object = degree_scale_kernlab$method$fit$args,
+ expected = list(
+ x = expr(missing_arg()),
+ y = expr(missing_arg()),
+ kernel = "polydot",
+ kpar = degree_scale_obj
+ )
+ )
+
+})
+
+test_that('engine arguments', {
+
+ kernlab_cv <- svm_poly() %>% set_engine("kernlab", cross = 10)
+
+ expect_equal(
+ object = translate(kernlab_cv, "kernlab")$method$fit$args,
+ expected = list(
+ x = expr(missing_arg()),
+ y = expr(missing_arg()),
+ cross = new_empty_quosure(10),
+ kernel = "polydot"
+ )
+ )
+
+})
+
+
+test_that('updating', {
+
+ expr1 <- svm_poly() %>% set_engine("kernlab", cross = 10)
+ expr1_exp <- svm_poly(degree = 1) %>% set_engine("kernlab", cross = 10)
+
+ expr2 <- svm_poly(degree = varying()) %>% set_engine("kernlab")
+ expr2_exp <- svm_poly(degree = varying(), scale_factor = 1) %>% set_engine("kernlab")
+
+ expr3 <- svm_poly(degree = 2, scale_factor = varying()) %>% set_engine("kernlab")
+ expr3_exp <- svm_poly(degree = 3) %>% set_engine("kernlab")
+
+ expect_equal(update(expr1, degree = 1), expr1_exp)
+ expect_equal(update(expr2, scale_factor = 1), expr2_exp)
+ expect_equal(update(expr3, degree = 3, fresh = TRUE), expr3_exp)
+})
+
+test_that('bad input', {
+ expect_error(svm_poly(mode = "reallyunknown"))
+ expect_error(translate(svm_poly() %>% set_engine(NULL)))
+})
+
+# ------------------------------------------------------------------------------
+
+reg_mod <-
+ svm_poly(degree = 1, cost = 1/4) %>%
+ set_engine("kernlab") %>%
+ set_mode("regression")
+
+cls_mod <-
+ svm_poly(degree = 2, cost = 1/8) %>%
+ set_engine("kernlab") %>%
+ set_mode("classification")
+
+ctrl <- fit_control(verbosity = 0, catch = FALSE)
+
+# ------------------------------------------------------------------------------
+
+test_that('svm poly regression', {
+
+ skip_if_not_installed("kernlab")
+
+ expect_error(
+ fit_xy(
+ reg_mod,
+ control = ctrl,
+ x = iris[,2:4],
+ y = iris$Sepal.Length
+ ),
+ regexp = NA
+ )
+
+ expect_error(
+ fit(
+ reg_mod,
+ Sepal.Length ~ .,
+ data = iris[, -5],
+ control = ctrl
+ ),
+ regexp = NA
+ )
+
+})
+
+
+test_that('svm poly regression prediction', {
+
+ skip_if_not_installed("kernlab")
+
+ library(kernlab)
+
+ reg_form <-
+ fit(
+ reg_mod,
+ Sepal.Length ~ .,
+ data = iris[, -5],
+ control = ctrl
+ )
+
+ # kern_pred <-
+ # predict(reg_form$fit, iris[1:3, -c(1, 5)]) %>%
+ # as_tibble() %>%
+ # setNames(".pred")
+ kern_pred <-
+ structure(
+ list(
+ .pred = c(5.02154233477783, 4.71496213707127, 4.78370369917621)),
+ row.names = c(NA,-3L),
+ class = c("tbl_df", "tbl", "data.frame")
+ )
+
+ parsnip_pred <- predict(reg_form, iris[1:3, -c(1, 5)])
+ expect_equal(as.data.frame(kern_pred), as.data.frame(parsnip_pred))
+
+
+ reg_xy_form <-
+ fit_xy(
+ reg_mod,
+ x = iris[, 2:4],
+ y = iris$Sepal.Length,
+ control = ctrl
+ )
+ expect_equal(reg_form$fit, reg_xy_form$fit)
+
+ parsnip_xy_pred <- predict(reg_xy_form, iris[1:3, -c(1, 5)])
+ expect_equal(as.data.frame(kern_pred), as.data.frame(parsnip_xy_pred))
+})
+
+# ------------------------------------------------------------------------------
+
+test_that('svm poly classification', {
+
+ skip_if_not_installed("kernlab")
+
+ expect_error(
+ fit_xy(
+ cls_mod,
+ control = ctrl,
+ x = iris[, -5],
+ y = iris$Species
+ ),
+ regexp = NA
+ )
+
+ expect_error(
+ fit(
+ cls_mod,
+ Species ~ .,
+ data = iris,
+ control = ctrl
+ ),
+ regexp = NA
+ )
+
+})
+
+
+test_that('svm poly classification probabilities', {
+
+ skip_if_not_installed("kernlab")
+
+ library(kernlab)
+ ind <- c(1, 51, 101)
+
+ set.seed(34562)
+ cls_form <-
+ fit(
+ cls_mod,
+ Species ~ .,
+ data = iris,
+ control = ctrl
+ )
+
+ # kern_class <-
+ # tibble(.pred_class = predict(cls_form$fit, iris[ind, -5]))
+
+ kern_class <-
+ structure(
+ list(
+ .pred_class =
+ structure(1:3, .Label = c("setosa", "versicolor", "virginica"), class = "factor")),
+ row.names = c(NA, -3L), class = c("tbl_df", "tbl", "data.frame"))
+
+ parsnip_class <- predict(cls_form, iris[ind, -5])
+ expect_equal(kern_class, parsnip_class)
+
+ set.seed(34562)
+ cls_xy_form <-
+ fit_xy(
+ cls_mod,
+ x = iris[, 1:4],
+ y = iris$Species,
+ control = ctrl
+ )
+ expect_equal(cls_form$fit, cls_xy_form$fit)
+
+ # kern_probs <-
+ # predict(cls_form$fit, iris[ind, -5], type = "probabilities") %>%
+ # as_tibble() %>%
+ # setNames(c('.pred_setosa', '.pred_versicolor', '.pred_virginica'))
+
+ kern_probs <-
+ structure(
+ list(
+ .pred_setosa = c(0.982990083267231, 0.0167077303224448, 0.00930879923686657),
+ .pred_versicolor = c(0.00417116710624842, 0.946131931665357, 0.0015524073332013),
+ .pred_virginica = c(0.0128387496265202, 0.0371603380121978, 0.989138793429932)),
+ row.names = c(NA,-3L),
+ class = c("tbl_df", "tbl", "data.frame"))
+
+ parsnip_probs <- predict(cls_form, iris[ind, -5], type = "prob")
+ expect_equal(as.data.frame(kern_probs), as.data.frame(parsnip_probs))
+
+ parsnip_xy_probs <- predict(cls_xy_form, iris[ind, -5], type = "prob")
+ expect_equal(as.data.frame(kern_probs), as.data.frame(parsnip_xy_probs))
+})
diff --git a/tests/testthat/test_svm_rbf.R b/tests/testthat/test_svm_rbf.R
new file mode 100644
index 000000000..61261e0d0
--- /dev/null
+++ b/tests/testthat/test_svm_rbf.R
@@ -0,0 +1,245 @@
+library(testthat)
+library(parsnip)
+library(rlang)
+
+# ------------------------------------------------------------------------------
+
+context("poly SVM")
+source("helpers.R")
+
+# ------------------------------------------------------------------------------
+
+test_that('primary arguments', {
+ basic <- svm_rbf()
+ basic_kernlab <- translate(basic %>% set_engine("kernlab"))
+
+ expect_equal(
+ object = basic_kernlab$method$fit$args,
+ expected = list(
+ x = expr(missing_arg()),
+ y = expr(missing_arg()),
+ kernel = "rbfdot"
+ )
+ )
+
+ rbf_sigma <- svm_rbf(rbf_sigma = .2)
+ rbf_sigma_kernlab <- translate(rbf_sigma %>% set_engine("kernlab"))
+ rbf_sigma_obj <- expr(list())
+ rbf_sigma_obj$sigma <- new_empty_quosure(.2)
+
+ expect_equal(
+ object = rbf_sigma_kernlab$method$fit$args,
+ expected = list(
+ x = expr(missing_arg()),
+ y = expr(missing_arg()),
+ kernel = "rbfdot",
+ kpar = rbf_sigma_obj
+ )
+ )
+
+})
+
+test_that('engine arguments', {
+
+ kernlab_cv <- svm_rbf() %>% set_engine("kernlab", cross = 10)
+
+ expect_equal(
+ object = translate(kernlab_cv, "kernlab")$method$fit$args,
+ expected = list(
+ x = expr(missing_arg()),
+ y = expr(missing_arg()),
+ cross = new_empty_quosure(10),
+ kernel = "rbfdot"
+ )
+ )
+
+})
+
+
+test_that('updating', {
+
+ expr1 <- svm_rbf() %>% set_engine("kernlab", cross = 10)
+ expr1_exp <- svm_rbf(rbf_sigma = .1) %>% set_engine("kernlab", cross = 10)
+
+ expr3 <- svm_rbf(rbf_sigma = .2) %>% set_engine("kernlab")
+ expr3_exp <- svm_rbf(rbf_sigma = .3) %>% set_engine("kernlab")
+
+ expect_equal(update(expr1, rbf_sigma = .1), expr1_exp)
+ expect_equal(update(expr3, rbf_sigma = .3, fresh = TRUE), expr3_exp)
+})
+
+test_that('bad input', {
+ expect_error(svm_rbf(mode = "reallyunknown"))
+ expect_error(translate(svm_rbf() %>% set_engine(NULL)))
+})
+
+# ------------------------------------------------------------------------------
+
+reg_mod <-
+ svm_rbf(rbf_sigma = .1, cost = 1/4) %>%
+ set_engine("kernlab") %>%
+ set_mode("regression")
+
+cls_mod <-
+ svm_rbf(rbf_sigma = .1, cost = 1/8) %>%
+ set_engine("kernlab") %>%
+ set_mode("classification")
+
+ctrl <- fit_control(verbosity = 0, catch = FALSE)
+
+# ------------------------------------------------------------------------------
+
+test_that('svm rbf regression', {
+
+ skip_if_not_installed("kernlab")
+
+ expect_error(
+ fit_xy(
+ reg_mod,
+ control = ctrl,
+ x = iris[,2:4],
+ y = iris$Sepal.Length
+ ),
+ regexp = NA
+ )
+
+ expect_error(
+ fit(
+ reg_mod,
+ Sepal.Length ~ .,
+ data = iris[, -5],
+ control = ctrl
+ ),
+ regexp = NA
+ )
+
+})
+
+
+test_that('svm rbf regression prediction', {
+
+ skip_if_not_installed("kernlab")
+
+ library(kernlab)
+
+ reg_form <-
+ fit(
+ reg_mod,
+ Sepal.Length ~ .,
+ data = iris[, -5],
+ control = ctrl
+ )
+
+ # kern_pred <-
+ # predict(reg_form$fit, iris[1:3, -c(1, 5)]) %>%
+ # as_tibble() %>%
+ # setNames(".pred")
+ kern_pred <-
+ structure(
+ list(.pred = c(5.02786147259765, 4.81715220026091, 4.86817852816449)),
+ row.names = c(NA, -3L), class = c("tbl_df", "tbl", "data.frame"))
+
+ parsnip_pred <- predict(reg_form, iris[1:3, -c(1, 5)])
+ expect_equal(as.data.frame(kern_pred), as.data.frame(parsnip_pred))
+
+
+ reg_xy_form <-
+ fit_xy(
+ reg_mod,
+ x = iris[, 2:4],
+ y = iris$Sepal.Length,
+ control = ctrl
+ )
+ expect_equal(reg_form$fit, reg_xy_form$fit)
+
+ parsnip_xy_pred <- predict(reg_xy_form, iris[1:3, -c(1, 5)])
+ expect_equal(as.data.frame(kern_pred), as.data.frame(parsnip_xy_pred))
+})
+
+# ------------------------------------------------------------------------------
+
+test_that('svm rbf classification', {
+
+ skip_if_not_installed("kernlab")
+
+ expect_error(
+ fit_xy(
+ cls_mod,
+ control = ctrl,
+ x = iris[, -5],
+ y = iris$Species
+ ),
+ regexp = NA
+ )
+
+ expect_error(
+ fit(
+ cls_mod,
+ Species ~ .,
+ data = iris,
+ control = ctrl
+ ),
+ regexp = NA
+ )
+
+})
+
+
+test_that('svm rbf classification probabilities', {
+
+ skip_if_not_installed("kernlab")
+
+ library(kernlab)
+ ind <- c(1, 51, 101)
+
+ set.seed(34562)
+ cls_form <-
+ fit(
+ cls_mod,
+ Species ~ .,
+ data = iris,
+ control = ctrl
+ )
+
+ # kern_class <-
+ # tibble(.pred_class = predict(cls_form$fit, iris[ind, -5]))
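+ # reference values generated once with the commented code above, then hard-coded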
+
+ kern_class <-
+ structure(list(
+ .pred_class = structure(
+ c(1L, 3L, 3L),
+ .Label = c("setosa", "versicolor", "virginica"), class = "factor")),
+ row.names = c(NA, -3L), class = c("tbl_df", "tbl", "data.frame"))
+
+ parsnip_class <- predict(cls_form, iris[ind, -5])
+ expect_equal(kern_class, parsnip_class)
+
+ set.seed(34562)
+ cls_xy_form <-
+ fit_xy(
+ cls_mod,
+ x = iris[, 1:4],
+ y = iris$Species,
+ control = ctrl
+ )
+ expect_equal(cls_form$fit, cls_xy_form$fit)
+
+ # kern_probs <-
+ # predict(cls_form$fit, iris[ind, -5], type = "probabilities") %>%
+ # as_tibble() %>%
+ # setNames(c('.pred_setosa', '.pred_versicolor', '.pred_virginica'))
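+ # reference values generated once with the commented code above, then hard-coded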
+
+ kern_probs <-
+ structure(
+ list(
+ .pred_setosa = c(0.985403715135807, 0.0158818274678279, 0.00633995479908973),
+ .pred_versicolor = c(0.00818691538722139, 0.359005663318986, 0.0173471664171275),
+ .pred_virginica = c(0.00640936947697121, 0.625112509213187, 0.976312878783783)),
+ row.names = c(NA, -3L), class = c("tbl_df", "tbl", "data.frame"))
+
+ parsnip_probs <- predict(cls_form, iris[ind, -5], type = "prob")
+ expect_equal(as.data.frame(kern_probs), as.data.frame(parsnip_probs))
+
+ parsnip_xy_probs <- predict(cls_xy_form, iris[ind, -5], type = "prob")
+ expect_equal(as.data.frame(kern_probs), as.data.frame(parsnip_xy_probs))
+})
diff --git a/vignettes/articles/Models.Rmd b/vignettes/articles/Models.Rmd
index 35ea2f380..8fd627782 100644
--- a/vignettes/articles/Models.Rmd
+++ b/vignettes/articles/Models.Rmd
@@ -62,10 +62,8 @@ _How_ the model is created is related to the _engine_. In many cases, this is an
map2_dfr(engine_info$model, engine_info$engine, mod_names) %>%
dplyr::filter(!(module %in% c("libs", "fit"))) %>%
mutate(
- module = ifelse(module == "pred", "num", module),
module = ifelse(module == "confint", "conf_int", module),
module = ifelse(module == "predint", "pred_int", module),
- module = ifelse(module == "classes", "class", module),
module = paste0("`", module, "`"),
model = paste0("`", model, "()`"),
) %>%
@@ -76,5 +74,5 @@ _How_ the model is created is related to the _engine_. In many cases, this is an
collapse_rows(columns = 1)
```
-Models can be added by the user too. See the "Making a `parsnip` model from scratch" vignette.
+Models can be added by the user too. See the ["Making a `parsnip` model from scratch" vignette](Scratch.html).
diff --git a/vignettes/articles/Scratch.Rmd b/vignettes/articles/Scratch.Rmd
index 22d17a63f..ecccb90d6 100644
--- a/vignettes/articles/Scratch.Rmd
+++ b/vignettes/articles/Scratch.Rmd
@@ -155,27 +155,27 @@ mixture_da_mda_data$fit <-
)
```
-### The `pred` module
+### The `numeric` module
-This is defined only for regression models (so is not added to the list). The convention used here is very similar to the two that are detailed in the next section. For `pred`, the model requires an unnamed numeric vector output (usually).
+This is defined only for regression models (so is not added to the list). The convention used here is very similar to the two that are detailed in the next section. For `numeric`, the model requires an unnamed numeric vector output (usually).
Examples are [here](https://github.com/topepo/parsnip/blob/master/R/linear_reg_data.R) and [here](https://github.com/topepo/parsnip/blob/master/R/rand_forest_data.R).
For multivariate models, the return value should be a matrix or data frame (otherwise the result should be a vector).
-Note that the `pred` module maps to the `predict_num` function in `parsnip`. However, the user-facing `predict` function is used to generate predictions and returns a tibble with a column named `.pred` (see the example below). When creating new models, you don't have to write code for that part.
+Note that the `numeric` module maps to the `predict_numeric` function in `parsnip`. However, the user-facing `predict` function is used to generate predictions and returns a tibble with a column named `.pred` (see the example below). When creating new models, you don't have to write code for that part.
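+
+As an illustrative sketch (not part of the `mda` example, which is classification-only), a `numeric` module for a hypothetical regression model whose predict method returns a plain numeric vector could look like the following; `hypothetical_reg_data` is an assumed name for this sketch, not a real `parsnip` object:
+
+```{r numeric-sketch, eval = FALSE}
+hypothetical_reg_data$numeric <-
+  list(
+    pre = NULL,
+    post = NULL,
+    # a generic call to the model's predict method
+    func = c(fun = "predict"),
+    args =
+      list(
+        object = quote(object$fit),
+        newdata = quote(new_data)
+      )
+  )
+```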
-### The `classes` module
+### The `class` module
-To make hard class predictions, the `classes` object contains the details. The elements of the list are:
+To make hard class predictions, the `class` object contains the details. The elements of the list are:
* `pre` and `post` are optional functions that can preprocess the data being fed to the prediction code and postprocess the raw output of the predictions. These won't be needed for this example, but a section below has examples of how these can be used when the model code is not easy to use. If the data being predicted has a simple type requirement, you can avoid using a `pre` function with the `args` below.
* `func` is the prediction function (in the same format as above). In many cases, packages have a predict method for their model's class but this is typically not exported. In this case (and the example below), it is simple enough to make a generic call to `predict` with no associated package.
* `args` is a list of arguments to pass to the prediction function. These will most likely be wrapped in `rlang::expr` so that they are not evaluated when defining the method. For `mda`, the code would be `predict(object, newdata, type = "class")`. What is actually given to the function is the `parsnip` model fit object, which includes a sub-object called `fit` that houses the `mda` model object. If the data need to be a matrix or data frame, you could also use `new_data = quote(as.data.frame(new_data))` and so on.
-```{r mda-classes}
-mixture_da_mda_data$classes <-
+```{r mda-class}
+mixture_da_mda_data$class <-
list(
pre = NULL,
post = NULL,
@@ -196,14 +196,14 @@ mixture_da_mda_data$classes <-
-The `predict_class` function will expect the result to be an unnamed character string or factor. This will be coerced to a factor with the same levels as the original data. As with the `pred` module, the user doesn't call `predict_class` but uses `predict` instead and this produces a tibble with a column named `.pred_class` [per the model guidlines](https://tidymodels.github.io/model-implementation-principles/model-predictions.html#return-values).
+The `predict_class` function will expect the result to be an unnamed character string or factor. This will be coerced to a factor with the same levels as the original data. As with the `numeric` module, the user doesn't call `predict_class` but uses `predict` instead, and this produces a tibble with a column named `.pred_class` [per the model guidelines](https://tidymodels.github.io/model-implementation-principles/model-predictions.html#return-values).
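+
+For illustration, a minimal sketch of that user-facing call (the fitted object name `mda_fit` is hypothetical):
+
+```{r class-usage, eval = FALSE}
+# `mda_fit` stands in for a model fitted with the specification above
+predict(mda_fit, new_data = head(iris))
+# returns a tibble with a single factor column named `.pred_class`
+```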
-### The `prob` module
+### The `classprob` module
-This defines the class probabilities (if they can be computed). The format is identical to the `classes` module but the output is expected to be a tibble with columns for each factor level.
+This defines the class probabilities (if they can be computed). The format is identical to the `class` module but the output is expected to be a tibble with columns for each factor level.
As an example of the `post` function, the data frame created by `mda:::predict.mda` will be converted to a tibble. The arguments are `x` (the raw results coming from the predict method) and `object` (the `parsnip` model fit object). The latter has a sub-object called `lvl` which is a character vector of the outcome's factor levels (if any).
-```{r mda-prob}
-mixture_da_mda_data$prob <-
+```{r mda-classprob}
+mixture_da_mda_data$classprob <-
list(
pre = NULL,
post = function(x, object) {
@@ -263,9 +263,9 @@ There are various things that came to mind while writing this document.
-### Do I have to return a simple vector for `predict_num` and `predict_class`?
+### Do I have to return a simple vector for `predict_numeric` and `predict_class`?
-Previously, when discussing the `pred` information:
+Previously, when discussing the `numeric` information:
-> For `pred`, the model requires an unnamed numeric vector output **(usually)**.
+> For `numeric`, the model requires an unnamed numeric vector output **(usually)**.
There are some occasions where a prediction for a single new sample may be multidimensional. Several cases are enumerated [here](https://tidymodels.github.io/model-implementation-principles/notes.html#list-cols) but some simple ones are: