From 43853329d1993162958218ea2320e54f2dfca92d Mon Sep 17 00:00:00 2001 From: topepo Date: Thu, 25 Oct 2018 09:52:47 -0400 Subject: [PATCH 01/10] closes #89 --- R/boost_tree.R | 6 ++++++ R/linear_reg.R | 12 ++++++++++++ R/logistic_reg.R | 15 +++++++++++++++ R/mars.R | 3 +++ R/multinom_reg.R | 3 +++ R/predict.R | 3 +++ tests/testthat/test_boost_tree_C50.R | 5 +++++ tests/testthat/test_boost_tree_xgboost.R | 5 +++++ tests/testthat/test_linear_reg.R | 11 +++++++++++ tests/testthat/test_linear_reg_glmnet.R | 5 +++++ tests/testthat/test_logistic_reg_glmnet.R | 5 +++++ tests/testthat/test_mars.R | 5 +++++ tests/testthat/test_multinom_reg_glmnet.R | 6 ++++++ 13 files changed, 84 insertions(+) diff --git a/R/boost_tree.R b/R/boost_tree.R index f196d4e8e..12c0b5758 100644 --- a/R/boost_tree.R +++ b/R/boost_tree.R @@ -359,6 +359,9 @@ xgb_pred <- function(object, newdata, ...) { #' @export multi_predict._xgb.Booster <- function(object, new_data, type = NULL, trees = NULL, ...) { + if (any(names(enquos(...)) == "newdata")) + stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) + if (is.null(trees)) trees <- object$fit$nIter trees <- sort(trees) @@ -458,6 +461,9 @@ C5.0_train <- #' @export multi_predict._C5.0 <- function(object, new_data, type = NULL, trees = NULL, ...) { + if (any(names(enquos(...)) == "newdata")) + stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) + if (is.null(trees)) trees <- min(object$fit$trials) trees <- sort(trees) diff --git a/R/linear_reg.R b/R/linear_reg.R index e0805d288..fb012b26a 100644 --- a/R/linear_reg.R +++ b/R/linear_reg.R @@ -211,18 +211,27 @@ organize_glmnet_pred <- function(x, object) { #' @export predict._elnet <- function(object, new_data, type = NULL, opts = list(), ...) { + if (any(names(enquos(...)) == "newdata")) + stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) + object$spec <- eval_args(object$spec) predict.model_fit(object, new_data = new_data, type = type, opts = opts, ...) } #' @export predict_num._elnet <- function(object, new_data, ...) { + if (any(names(enquos(...)) == "newdata")) + stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) + object$spec <- eval_args(object$spec) predict_num.model_fit(object, new_data = new_data, ...) } #' @export predict_raw._elnet <- function(object, new_data, opts = list(), ...) { + if (any(names(enquos(...)) == "newdata")) + stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) + object$spec <- eval_args(object$spec) predict_raw.model_fit(object, new_data = new_data, opts = opts, ...) } @@ -232,6 +241,9 @@ predict_raw._elnet <- function(object, new_data, opts = list(), ...) { #' @export multi_predict._elnet <- function(object, new_data, type = NULL, penalty = NULL, ...) { + if (any(names(enquos(...)) == "newdata")) + stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) + dots <- list(...) if (is.null(penalty)) penalty <- object$fit$lambda diff --git a/R/logistic_reg.R b/R/logistic_reg.R index a0d67f0c1..931772a60 100644 --- a/R/logistic_reg.R +++ b/R/logistic_reg.R @@ -230,24 +230,36 @@ organize_glmnet_prob <- function(x, object) { #' @export predict._lognet <- function (object, new_data, type = NULL, opts = list(), ...) { + if (any(names(enquos(...)) == "newdata")) + stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) + object$spec <- eval_args(object$spec) predict.model_fit(object, new_data = new_data, type = type, opts = opts, ...) 
} #' @export predict_class._lognet <- function (object, new_data, ...) { + if (any(names(enquos(...)) == "newdata")) + stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) + object$spec <- eval_args(object$spec) predict_class.model_fit(object, new_data = new_data, ...) } #' @export predict_classprob._lognet <- function (object, new_data, ...) { + if (any(names(enquos(...)) == "newdata")) + stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) + object$spec <- eval_args(object$spec) predict_classprob.model_fit(object, new_data = new_data, ...) } #' @export predict_raw._lognet <- function (object, new_data, opts = list(), ...) { + if (any(names(enquos(...)) == "newdata")) + stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) + object$spec <- eval_args(object$spec) predict_raw.model_fit(object, new_data = new_data, opts = opts, ...) } @@ -258,6 +270,9 @@ predict_raw._lognet <- function (object, new_data, opts = list(), ...) { #' @export multi_predict._lognet <- function(object, new_data, type = NULL, penalty = NULL, ...) { + if (any(names(enquos(...)) == "newdata")) + stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) + dots <- list(...) if (is.null(penalty)) penalty <- object$lambda diff --git a/R/mars.R b/R/mars.R index 7835eb05b..f1f1c8151 100644 --- a/R/mars.R +++ b/R/mars.R @@ -206,6 +206,9 @@ earth_reg_updater <- function(num, object, new_data, ...) { #' @export multi_predict._earth <- function(object, new_data, type = NULL, num_terms = NULL, ...) { + if (any(names(enquos(...)) == "newdata")) + stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) + if (is.null(num_terms)) num_terms <- object$fit$selected.terms[-1] diff --git a/R/multinom_reg.R b/R/multinom_reg.R index 6f6a41b43..cb5b17785 100644 --- a/R/multinom_reg.R +++ b/R/multinom_reg.R @@ -236,6 +236,9 @@ predict._multnet <- #' @export multi_predict._multnet <- function(object, new_data, type = NULL, penalty = NULL, ...) { + if (any(names(enquos(...)) == "newdata")) + stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) + if (is_quosure(penalty)) penalty <- eval_tidy(penalty) diff --git a/R/predict.R b/R/predict.R index ea7ea7149..87750a726 100644 --- a/R/predict.R +++ b/R/predict.R @@ -91,6 +91,9 @@ #' @export predict.model_fit #' @export predict.model_fit <- function (object, new_data, type = NULL, opts = list(), ...) { + if (any(names(enquos(...)) == "newdata")) + stop("Did you mean to use `new_data` instead of `newdata`?", call. 
= FALSE) + type <- check_pred_type(object, type) if (type != "raw" && length(opts) > 0) warning("`opts` is only used with `type = 'raw'` and was ignored.") diff --git a/tests/testthat/test_boost_tree_C50.R b/tests/testthat/test_boost_tree_C50.R index 5c3e6096e..c1f189d7c 100644 --- a/tests/testthat/test_boost_tree_C50.R +++ b/tests/testthat/test_boost_tree_C50.R @@ -121,5 +121,10 @@ test_that('submodel prediction', { mp_res <- multi_predict(class_fit, new_data = wa_churn[1:4, vars], trees = 4, type = "prob") mp_res <- do.call("rbind", mp_res$.pred) expect_equal(mp_res[[".pred_No"]], unname(pred_class[, "No"])) + + expect_error( + multi_predict(class_fit, newdata = wa_churn[1:4, vars], trees = 4, type = "prob"), + "Did you mean" + ) }) diff --git a/tests/testthat/test_boost_tree_xgboost.R b/tests/testthat/test_boost_tree_xgboost.R index f8a7f7aa1..5ff9e481f 100644 --- a/tests/testthat/test_boost_tree_xgboost.R +++ b/tests/testthat/test_boost_tree_xgboost.R @@ -188,5 +188,10 @@ test_that('submodel prediction', { mp_res <- multi_predict(class_fit, new_data = wa_churn[1:4, vars], trees = 5, type = "prob") mp_res <- do.call("rbind", mp_res$.pred) expect_equal(mp_res[[".pred_No"]], pred_class) + + expect_error( + multi_predict(class_fit, newdata = wa_churn[1:4, vars], trees = 5, type = "prob"), + "Did you mean" + ) }) diff --git a/tests/testthat/test_linear_reg.R b/tests/testthat/test_linear_reg.R index 9df468114..27737194b 100644 --- a/tests/testthat/test_linear_reg.R +++ b/tests/testthat/test_linear_reg.R @@ -322,3 +322,14 @@ test_that('lm intervals', { expect_equivalent(prediction_parsnip$.pred_upper, prediction_lm[, "upr"]) }) + +test_that('newdata error trapping', { + res_xy <- fit_xy( + iris_basic, + x = iris[, num_pred], + y = iris$Sepal.Length, + control = ctrl + ) + expect_error(predict(res_xy, newdata = iris[1:3, num_pred]), "Did you mean") +}) + diff --git a/tests/testthat/test_linear_reg_glmnet.R b/tests/testthat/test_linear_reg_glmnet.R index ba7458038..70abc8cf2 100644 --- a/tests/testthat/test_linear_reg_glmnet.R +++ b/tests/testthat/test_linear_reg_glmnet.R @@ -198,5 +198,10 @@ test_that('submodel prediction', { mp_res <- multi_predict(reg_fit, new_data = mtcars[1:4, -1], penalty = .1) mp_res <- do.call("rbind", mp_res$.pred) expect_equal(mp_res[[".pred"]], unname(pred_glmn[,1])) + + expect_error( + multi_predict(reg_fit, newdata = mtcars[1:4, -1], penalty = .1), + "Did you mean" + ) }) diff --git a/tests/testthat/test_logistic_reg_glmnet.R b/tests/testthat/test_logistic_reg_glmnet.R index 9b751cf46..ed4d0ff77 100644 --- a/tests/testthat/test_logistic_reg_glmnet.R +++ b/tests/testthat/test_logistic_reg_glmnet.R @@ -331,6 +331,11 @@ test_that('submodel prediction', { mp_res <- multi_predict(class_fit, new_data = wa_churn[1:4, vars], penalty = .1, type = "prob") mp_res <- do.call("rbind", mp_res$.pred) expect_equal(mp_res[[".pred_No"]], unname(pred_glmn[,1])) + + expect_error( + multi_predict(class_fit, newdata = wa_churn[1:4, vars], penalty = .1, type = "prob"), + "Did you mean" + ) }) diff --git a/tests/testthat/test_mars.R b/tests/testthat/test_mars.R index bf3d35c11..a2e405fd6 100644 --- a/tests/testthat/test_mars.R +++ b/tests/testthat/test_mars.R @@ -242,6 +242,11 @@ test_that('submodel prediction', { mp_res <- multi_predict(class_fit, new_data = wa_churn[1:4, vars], num_terms = 5, type = "prob") mp_res <- do.call("rbind", mp_res$.pred) expect_equal(mp_res[[".pred_No"]], pruned_cls_pred) + + expect_error( + multi_predict(reg_fit, newdata = mtcars[1:4, -1], num_terms = 5), + 
"Did you mean" + ) }) diff --git a/tests/testthat/test_multinom_reg_glmnet.R b/tests/testthat/test_multinom_reg_glmnet.R index c658feb6f..517d5f28b 100644 --- a/tests/testthat/test_multinom_reg_glmnet.R +++ b/tests/testthat/test_multinom_reg_glmnet.R @@ -134,6 +134,12 @@ test_that('glmnet probabilities, mulitiple lambda', { mult_class$.pred, multi_predict(xy_fit, iris[rows, 1:4], penalty = lams)$.pred ) + + expect_error( + multi_predict(xy_fit, newdata = iris[rows, 1:4], penalty = lams), + "Did you mean" + ) + }) From 8005421faad20f1ce0ee05fa02d6ef602b6644a8 Mon Sep 17 00:00:00 2001 From: topepo Date: Thu, 25 Oct 2018 10:16:17 -0400 Subject: [PATCH 02/10] num -> numeric, prob -> classprob for #65 --- NAMESPACE | 8 +++--- R/boost_tree.R | 2 +- R/boost_tree_data.R | 10 +++---- R/linear_reg.R | 4 +-- R/linear_reg_data.R | 8 +++--- R/logistic_reg_data.R | 8 +++--- R/mars_data.R | 4 +-- R/mlp_data.R | 6 ++-- R/multinom_reg_data.R | 4 +-- R/nearest_neighbor_data.R | 6 ++-- R/predict.R | 2 +- R/predict_classprob.R | 12 ++++---- R/{predict_num.R => predict_numeric.R} | 28 +++++++++---------- R/rand_forest_data.R | 12 ++++---- R/surv_reg_data.R | 6 ++-- man/other_predict.Rd | 10 +++---- tests/testthat/test_boost_tree_spark.R | 4 +-- tests/testthat/test_boost_tree_xgboost.R | 4 +-- tests/testthat/test_linear_reg.R | 6 ++-- tests/testthat/test_linear_reg_glmnet.R | 12 ++++---- tests/testthat/test_linear_reg_spark.R | 2 +- tests/testthat/test_linear_reg_stan.R | 4 +-- tests/testthat/test_mars.R | 6 ++-- tests/testthat/test_mlp_keras.R | 4 +-- tests/testthat/test_mlp_nnet.R | 8 +++--- tests/testthat/test_nearest_neighbor_kknn.R | 4 +-- tests/testthat/test_predict_formats.R | 2 +- .../testthat/test_rand_forest_randomForest.R | 2 +- tests/testthat/test_rand_forest_ranger.R | 2 +- tests/testthat/test_rand_forest_spark.R | 4 +-- 30 files changed, 97 insertions(+), 97 deletions(-) rename R/{predict_num.R => predict_numeric.R} (50%) diff --git a/NAMESPACE b/NAMESPACE index 7d4305157..06caa62cb 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -19,8 +19,8 @@ S3method(predict_classprob,"_lognet") S3method(predict_classprob,"_multnet") S3method(predict_classprob,model_fit) S3method(predict_confint,model_fit) -S3method(predict_num,"_elnet") -S3method(predict_num,model_fit) +S3method(predict_numeric,"_elnet") +S3method(predict_numeric,model_fit) S3method(predict_predint,model_fit) S3method(predict_quantile,model_fit) S3method(predict_raw,"_elnet") @@ -92,8 +92,8 @@ export(predict_classprob) export(predict_classprob.model_fit) export(predict_confint) export(predict_confint.model_fit) -export(predict_num) -export(predict_num.model_fit) +export(predict_numeric) +export(predict_numeric.model_fit) export(predict_predint) export(predict_predint.model_fit) export(predict_quantile) diff --git a/R/boost_tree.R b/R/boost_tree.R index 12c0b5758..8c31419ff 100644 --- a/R/boost_tree.R +++ b/R/boost_tree.R @@ -394,7 +394,7 @@ xgb_by_tree <- function(tree, object, new_data, type, ...) 
{ pred <- boost_tree_xgboost_data$classes$post(pred, object) pred <- tibble(.pred = factor(pred, levels = object$lvl)) } else { - pred <- boost_tree_xgboost_data$prob$post(pred, object) + pred <- boost_tree_xgboost_data$classprob$post(pred, object) pred <- as_tibble(pred) names(pred) <- paste0(".pred_", names(pred)) } diff --git a/R/boost_tree_data.R b/R/boost_tree_data.R index 206b78e20..20e238aaa 100644 --- a/R/boost_tree_data.R +++ b/R/boost_tree_data.R @@ -31,7 +31,7 @@ boost_tree_xgboost_data <- verbose = 0 ) ), - pred = list( + numeric = list( pre = NULL, post = NULL, func = c(fun = "xgb_pred"), @@ -58,7 +58,7 @@ boost_tree_xgboost_data <- newdata = quote(new_data) ) ), - prob = list( + classprob = list( pre = NULL, post = function(x, object) { if (is.vector(x)) { @@ -106,7 +106,7 @@ boost_tree_C5.0_data <- newdata = quote(new_data) ) ), - prob = list( + classprob = list( pre = NULL, post = function(x, object) { as_tibble(x) @@ -142,7 +142,7 @@ boost_tree_spark_data <- seed = expr(sample.int(10^5, 1)) ) ), - pred = list( + numeric = list( pre = NULL, post = format_spark_num, func = c(pkg = "sparklyr", fun = "ml_predict"), @@ -162,7 +162,7 @@ boost_tree_spark_data <- dataset = quote(new_data) ) ), - prob = list( + classprob = list( pre = NULL, post = format_spark_probs, func = c(pkg = "sparklyr", fun = "ml_predict"), diff --git a/R/linear_reg.R b/R/linear_reg.R index fb012b26a..633088e7f 100644 --- a/R/linear_reg.R +++ b/R/linear_reg.R @@ -219,12 +219,12 @@ predict._elnet <- } #' @export -predict_num._elnet <- function(object, new_data, ...) { +predict_numeric._elnet <- function(object, new_data, ...) { if (any(names(enquos(...)) == "newdata")) stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE) object$spec <- eval_args(object$spec) - predict_num.model_fit(object, new_data = new_data, ...) + predict_numeric.model_fit(object, new_data = new_data, ...) 
} #' @export diff --git a/R/linear_reg_data.R b/R/linear_reg_data.R index 57aebfd02..750212634 100644 --- a/R/linear_reg_data.R +++ b/R/linear_reg_data.R @@ -30,7 +30,7 @@ linear_reg_lm_data <- func = c(pkg = "stats", fun = "lm"), defaults = list() ), - pred = list( + numeric = list( pre = NULL, post = NULL, func = c(fun = "predict"), @@ -100,7 +100,7 @@ linear_reg_glmnet_data <- family = "gaussian" ) ), - pred = list( + numeric = list( pre = NULL, post = organize_glmnet_pred, func = c(fun = "predict"), @@ -135,7 +135,7 @@ linear_reg_stan_data <- family = expr(stats::gaussian) ) ), - pred = list( + numeric = list( pre = NULL, post = NULL, func = c(fun = "predict"), @@ -224,7 +224,7 @@ linear_reg_spark_data <- protect = c("x", "formula", "weight_col"), func = c(pkg = "sparklyr", fun = "ml_linear_regression") ), - pred = list( + numeric = list( pre = NULL, post = function(results, object) { results <- dplyr::rename(results, pred = prediction) diff --git a/R/logistic_reg_data.R b/R/logistic_reg_data.R index 972add371..a39dc137c 100644 --- a/R/logistic_reg_data.R +++ b/R/logistic_reg_data.R @@ -44,7 +44,7 @@ logistic_reg_glm_data <- type = "response" ) ), - prob = list( + classprob = list( pre = NULL, post = function(x, object) { x <- tibble(v1 = 1 - x, v2 = x) @@ -121,7 +121,7 @@ logistic_reg_glmnet_data <- s = quote(object$spec$args$penalty) ) ), - prob = list( + classprob = list( pre = NULL, post = organize_glmnet_prob, func = c(fun = "predict"), @@ -170,7 +170,7 @@ logistic_reg_stan_data <- newdata = quote(new_data) ) ), - prob = list( + classprob = list( pre = NULL, post = function(x, object) { x <- object$fit$family$linkinv(x) @@ -278,7 +278,7 @@ logistic_reg_spark_data <- dataset = quote(new_data) ) ), - prob = list( + classprob = list( pre = NULL, post = format_spark_probs, func = c(pkg = "sparklyr", fun = "ml_predict"), diff --git a/R/mars_data.R b/R/mars_data.R index 83addb849..e67c5e3b2 100644 --- a/R/mars_data.R +++ b/R/mars_data.R @@ -23,7 +23,7 @@ mars_earth_data <- func = c(pkg = "earth", fun = "earth"), defaults = list(keepxy = TRUE) ), - pred = list( + numeric = list( pre = NULL, post = maybe_multivariate, func = c(fun = "predict"), @@ -48,7 +48,7 @@ mars_earth_data <- type = "response" ) ), - prob = list( + classprob = list( pre = NULL, post = function(x, object) { x <- x[,1] diff --git a/R/mlp_data.R b/R/mlp_data.R index 5e5ccd3f8..d2d6c472f 100644 --- a/R/mlp_data.R +++ b/R/mlp_data.R @@ -25,7 +25,7 @@ mlp_keras_data <- func = c(pkg = "parsnip", fun = "keras_mlp"), defaults = list() ), - pred = list( + numeric = list( pre = NULL, post = maybe_multivariate, func = c(fun = "predict"), @@ -47,7 +47,7 @@ mlp_keras_data <- x = quote(as.matrix(new_data)) ) ), - prob = list( + classprob = list( pre = NULL, post = function(x, object) { x <- as_tibble(x) @@ -81,7 +81,7 @@ mlp_nnet_data <- func = c(pkg = "nnet", fun = "nnet"), defaults = list(trace = FALSE) ), - pred = list( + numeric = list( pre = NULL, post = maybe_multivariate, func = c(fun = "predict"), diff --git a/R/multinom_reg_data.R b/R/multinom_reg_data.R index 921e6a0dc..b32bacd7e 100644 --- a/R/multinom_reg_data.R +++ b/R/multinom_reg_data.R @@ -40,7 +40,7 @@ multinom_reg_glmnet_data <- s = quote(object$spec$args$penalty) ) ), - prob = list( + classprob = list( pre = check_glmnet_lambda, post = organize_multnet_prob, func = c(fun = "predict"), @@ -85,7 +85,7 @@ multinom_reg_spark_data <- dataset = quote(new_data) ) ), - prob = list( + classprob = list( pre = NULL, post = format_spark_probs, func = c(pkg = "sparklyr", fun 
= "ml_predict"), diff --git a/R/nearest_neighbor_data.R b/R/nearest_neighbor_data.R index c6106561c..9b322981e 100644 --- a/R/nearest_neighbor_data.R +++ b/R/nearest_neighbor_data.R @@ -22,8 +22,8 @@ nearest_neighbor_kknn_data <- func = c(pkg = "kknn", fun = "train.kknn"), defaults = list() ), - pred = list( - # seems unnecessary here as the predict_num catches it based on the + numeric = list( + # seems unnecessary here as the predict_numeric catches it based on the # model mode pre = function(x, object) { if (object$fit$response != "continuous") { @@ -60,7 +60,7 @@ nearest_neighbor_kknn_data <- type = "raw" ) ), - prob = list( + classprob = list( pre = function(x, object) { if (!(object$fit$response %in% c("ordinal", "nominal"))) { stop("`kknn` model does not appear to use class predictions. Was ", diff --git a/R/predict.R b/R/predict.R index 87750a726..947b8ed72 100644 --- a/R/predict.R +++ b/R/predict.R @@ -99,7 +99,7 @@ predict.model_fit <- function (object, new_data, type = NULL, opts = list(), ... warning("`opts` is only used with `type = 'raw'` and was ignored.") res <- switch( type, - numeric = predict_num(object = object, new_data = new_data, ...), + numeric = predict_numeric(object = object, new_data = new_data, ...), class = predict_class(object = object, new_data = new_data, ...), prob = predict_classprob(object = object, new_data = new_data, ...), conf_int = predict_confint(object = object, new_data = new_data, ...), diff --git a/R/predict_classprob.R b/R/predict_classprob.R index 113015fe8..8f2f79b8d 100644 --- a/R/predict_classprob.R +++ b/R/predict_classprob.R @@ -10,23 +10,23 @@ predict_classprob.model_fit <- function (object, new_data, ...) { stop("`predict.model_fit` is for predicting factor outcomes.", call. = FALSE) - if (!any(names(object$spec$method) == "prob")) + if (!any(names(object$spec$method) == "classprob")) stop("No class probability module defined for this model.", call. = FALSE) new_data <- prepare_data(object, new_data) # preprocess data - if (!is.null(object$spec$method$prob$pre)) - new_data <- object$spec$method$prob$pre(new_data, object) + if (!is.null(object$spec$method$classprob$pre)) + new_data <- object$spec$method$classprob$pre(new_data, object) # create prediction call - pred_call <- make_pred_call(object$spec$method$prob) + pred_call <- make_pred_call(object$spec$method$classprob) res <- eval_tidy(pred_call) # post-process the predictions - if(!is.null(object$spec$method$prob$post)) { - res <- object$spec$method$prob$post(res, object) + if(!is.null(object$spec$method$classprob$post)) { + res <- object$spec$method$classprob$post(res, object) } # check and sort names diff --git a/R/predict_num.R b/R/predict_numeric.R similarity index 50% rename from R/predict_num.R rename to R/predict_numeric.R index 243a73890..5b9750146 100644 --- a/R/predict_num.R +++ b/R/predict_numeric.R @@ -1,33 +1,33 @@ #' @keywords internal #' @rdname other_predict #' @inheritParams predict.model_fit -#' @method predict_num model_fit -#' @export predict_num.model_fit +#' @method predict_numeric model_fit +#' @export predict_numeric.model_fit #' @export -# TODO add ... -predict_num.model_fit <- function (object, new_data, ...) { + +predict_numeric.model_fit <- function (object, new_data, ...) { if (object$spec$mode != "regression") - stop("`predict_num` is for predicting numeric outcomes. ", + stop("`predict_numeric` is for predicting numeric outcomes. ", "Use `predict_class` or `predict_prob` for ", "classification models.", call. 
= FALSE) - if (!any(names(object$spec$method) == "pred")) + if (!any(names(object$spec$method) == "numeric")) stop("No prediction module defined for this model.", call. = FALSE) new_data <- prepare_data(object, new_data) # preprocess data - if (!is.null(object$spec$method$pred$pre)) - new_data <- object$spec$method$pred$pre(new_data, object) + if (!is.null(object$spec$method$numeric$pre)) + new_data <- object$spec$method$numeric$pre(new_data, object) # create prediction call - pred_call <- make_pred_call(object$spec$method$pred) + pred_call <- make_pred_call(object$spec$method$numeric) res <- eval_tidy(pred_call) # post-process the predictions - if (!is.null(object$spec$method$pred$post)) { - res <- object$spec$method$pred$post(res, object) + if (!is.null(object$spec$method$numeric$post)) { + res <- object$spec$method$numeric$post(res, object) } if (is.vector(res)) { @@ -43,6 +43,6 @@ predict_num.model_fit <- function (object, new_data, ...) { #' @export #' @keywords internal #' @rdname other_predict -#' @inheritParams predict_num.model_fit -predict_num <- function (object, ...) - UseMethod("predict_num") +#' @inheritParams predict_numeric.model_fit +predict_numeric <- function (object, ...) + UseMethod("predict_numeric") diff --git a/R/rand_forest_data.R b/R/rand_forest_data.R index 3db4a439a..5f95de5c2 100644 --- a/R/rand_forest_data.R +++ b/R/rand_forest_data.R @@ -110,7 +110,7 @@ rand_forest_ranger_data <- seed = expr(sample.int(10^5, 1)) ) ), - pred = list( + numeric = list( pre = NULL, post = function(results, object) results$predictions, func = c(fun = "predict"), @@ -136,7 +136,7 @@ rand_forest_ranger_data <- verbose = FALSE ) ), - prob = list( + classprob = list( pre = function(x, object) { if (object$fit$forest$treetype != "Probability estimation") stop("`ranger` model does not appear to use class probabilities. 
Was ", @@ -190,7 +190,7 @@ rand_forest_randomForest_data <- defaults = list() ), - pred = list( + numeric = list( pre = NULL, post = NULL, func = c(fun = "predict"), @@ -210,7 +210,7 @@ rand_forest_randomForest_data <- newdata = quote(new_data) ) ), - prob = list( + classprob = list( pre = NULL, post = function(x, object) { as_tibble(as.data.frame(x)) @@ -247,7 +247,7 @@ rand_forest_spark_data <- seed = expr(sample.int(10^5, 1)) ) ), - pred = list( + numeric = list( pre = NULL, post = format_spark_num, func = c(pkg = "sparklyr", fun = "ml_predict"), @@ -267,7 +267,7 @@ rand_forest_spark_data <- dataset = quote(new_data) ) ), - prob = list( + classprob = list( pre = NULL, post = format_spark_probs, func = c(pkg = "sparklyr", fun = "ml_predict"), diff --git a/R/surv_reg_data.R b/R/surv_reg_data.R index 43f55cecb..3b030717b 100644 --- a/R/surv_reg_data.R +++ b/R/surv_reg_data.R @@ -26,7 +26,7 @@ surv_reg_flexsurv_data <- func = c(pkg = "flexsurv", fun = "flexsurvreg"), defaults = list() ), - pred = list( + numeric = list( pre = NULL, post = flexsurv_mean, func = c(fun = "summary"), @@ -62,7 +62,7 @@ surv_reg_survreg_data <- func = c(pkg = "survival", fun = "survreg"), defaults = list(model = TRUE) ), - pred = list( + numeric = list( pre = NULL, post = NULL, func = c(fun = "predict"), @@ -101,7 +101,7 @@ surv_reg_survreg_data <- # seed = expr(sample.int(10^5, 1)) # ) # ), -# pred = list( +# numeric = list( # pre = NULL, # post = function(results, object) { # tibble::as_tibble(results) %>% diff --git a/man/other_predict.Rd b/man/other_predict.Rd index f462f4d0b..57e3bf3f2 100644 --- a/man/other_predict.Rd +++ b/man/other_predict.Rd @@ -1,6 +1,6 @@ % Generated by roxygen2: do not edit by hand % Please edit documentation in R/predict_class.R, R/predict_classprob.R, -% R/predict_interval.R, R/predict_num.R, R/predict_quantile.R +% R/predict_interval.R, R/predict_numeric.R, R/predict_quantile.R \name{predict_class.model_fit} \alias{predict_class.model_fit} \alias{predict_class} @@ -10,8 +10,8 @@ \alias{predict_confint} \alias{predict_predint.model_fit} \alias{predict_predint} -\alias{predict_num.model_fit} -\alias{predict_num} +\alias{predict_numeric.model_fit} +\alias{predict_numeric} \alias{predict_quantile.model_fit} \alias{predict_quantile} \title{Other predict methods.} @@ -34,9 +34,9 @@ predict_confint(object, ...) predict_predint(object, ...) -\method{predict_num}{model_fit}(object, new_data, ...) +\method{predict_numeric}{model_fit}(object, new_data, ...) -predict_num(object, ...) +predict_numeric(object, ...) \method{predict_quantile}{model_fit}(object, new_data, quantile = (1:9)/10, ...) 
diff --git a/tests/testthat/test_boost_tree_spark.R b/tests/testthat/test_boost_tree_spark.R index 6c517e3b8..ce148620c 100644 --- a/tests/testthat/test_boost_tree_spark.R +++ b/tests/testthat/test_boost_tree_spark.R @@ -58,7 +58,7 @@ test_that('spark execution', { ) expect_error( - spark_reg_pred_num <- predict_num(spark_reg_fit, iris_bt_te), + spark_reg_pred_num <- predict_numeric(spark_reg_fit, iris_bt_te), regexp = NA ) @@ -68,7 +68,7 @@ test_that('spark execution', { ) expect_error( - spark_reg_num_dup <- predict_num(spark_reg_fit_dup, iris_bt_te), + spark_reg_num_dup <- predict_numeric(spark_reg_fit_dup, iris_bt_te), regexp = NA ) diff --git a/tests/testthat/test_boost_tree_xgboost.R b/tests/testthat/test_boost_tree_xgboost.R index 5ff9e481f..0c6be9417 100644 --- a/tests/testthat/test_boost_tree_xgboost.R +++ b/tests/testthat/test_boost_tree_xgboost.R @@ -141,7 +141,7 @@ test_that('xgboost regression prediction', { ) xy_pred <- predict(xy_fit$fit, newdata = xgb.DMatrix(data = as.matrix(mtcars[1:8, -1]))) - expect_equal(xy_pred, predict_num(xy_fit, new_data = mtcars[1:8, -1])) + expect_equal(xy_pred, predict_numeric(xy_fit, new_data = mtcars[1:8, -1])) form_fit <- fit( car_basic, @@ -151,7 +151,7 @@ test_that('xgboost regression prediction', { ) form_pred <- predict(form_fit$fit, newdata = xgb.DMatrix(data = as.matrix(mtcars[1:8, -1]))) - expect_equal(form_pred, predict_num(form_fit, new_data = mtcars[1:8, -1])) + expect_equal(form_pred, predict_numeric(form_fit, new_data = mtcars[1:8, -1])) }) diff --git a/tests/testthat/test_linear_reg.R b/tests/testthat/test_linear_reg.R index 27737194b..86d2e2bed 100644 --- a/tests/testthat/test_linear_reg.R +++ b/tests/testthat/test_linear_reg.R @@ -269,7 +269,7 @@ test_that('lm prediction', { control = ctrl ) - expect_equal(uni_pred, predict_num(res_xy, iris[1:5, num_pred])) + expect_equal(uni_pred, predict_numeric(res_xy, iris[1:5, num_pred])) res_form <- fit( iris_basic, @@ -277,7 +277,7 @@ test_that('lm prediction', { data = iris, control = ctrl ) - expect_equal(inl_pred, predict_num(res_form, iris[1:5, ])) + expect_equal(inl_pred, predict_numeric(res_form, iris[1:5, ])) res_mv <- fit( iris_basic, @@ -285,7 +285,7 @@ test_that('lm prediction', { data = iris, control = ctrl ) - expect_equal(mv_pred, predict_num(res_mv, iris[1:5,])) + expect_equal(mv_pred, predict_numeric(res_mv, iris[1:5,])) }) test_that('lm intervals', { diff --git a/tests/testthat/test_linear_reg_glmnet.R b/tests/testthat/test_linear_reg_glmnet.R index 70abc8cf2..a22b6e73d 100644 --- a/tests/testthat/test_linear_reg_glmnet.R +++ b/tests/testthat/test_linear_reg_glmnet.R @@ -69,7 +69,7 @@ test_that('glmnet prediction, single lambda', { s = iris_basic$spec$args$penalty) uni_pred <- unname(uni_pred[,1]) - expect_equal(uni_pred, predict_num(res_xy, iris[1:5, num_pred])) + expect_equal(uni_pred, predict_numeric(res_xy, iris[1:5, num_pred])) res_form <- fit( iris_basic, @@ -87,7 +87,7 @@ test_that('glmnet prediction, single lambda', { s = res_form$spec$spec$args$penalty) form_pred <- unname(form_pred[,1]) - expect_equal(form_pred, predict_num(res_form, iris[1:5, c("Sepal.Width", "Species")])) + expect_equal(form_pred, predict_numeric(res_form, iris[1:5, c("Sepal.Width", "Species")])) }) @@ -115,7 +115,7 @@ test_that('glmnet prediction, multiple lambda', { mult_pred$lambda <- rep(lams, each = 5) mult_pred <- mult_pred[,-2] - expect_equal(mult_pred, predict_num(res_xy, iris[1:5, num_pred])) + expect_equal(mult_pred, predict_numeric(res_xy, iris[1:5, num_pred])) res_form <- fit( 
iris_mult, @@ -135,7 +135,7 @@ test_that('glmnet prediction, multiple lambda', { form_pred$lambda <- rep(lams, each = 5) form_pred <- form_pred[,-2] - expect_equal(form_pred, predict_num(res_form, iris[1:5, c("Sepal.Width", "Species")])) + expect_equal(form_pred, predict_numeric(res_form, iris[1:5, c("Sepal.Width", "Species")])) }) test_that('glmnet prediction, all lambda', { @@ -157,7 +157,7 @@ test_that('glmnet prediction, all lambda', { all_pred$lambda <- rep(res_xy$fit$lambda, each = 5) all_pred <- all_pred[,-2] - expect_equal(all_pred, predict_num(res_xy, iris[1:5, num_pred])) + expect_equal(all_pred, predict_numeric(res_xy, iris[1:5, num_pred])) # test that the lambda seq is in the right order (since no docs on this) tmp_pred <- predict(res_xy$fit, newx = as.matrix(iris[1:5, num_pred]), @@ -180,7 +180,7 @@ test_that('glmnet prediction, all lambda', { form_pred$lambda <- rep(res_form$fit$lambda, each = 5) form_pred <- form_pred[,-2] - expect_equal(form_pred, predict_num(res_form, iris[1:5, c("Sepal.Width", "Species")])) + expect_equal(form_pred, predict_numeric(res_form, iris[1:5, c("Sepal.Width", "Species")])) }) diff --git a/tests/testthat/test_linear_reg_spark.R b/tests/testthat/test_linear_reg_spark.R index 2ed55d7d6..f1a033beb 100644 --- a/tests/testthat/test_linear_reg_spark.R +++ b/tests/testthat/test_linear_reg_spark.R @@ -42,7 +42,7 @@ test_that('spark execution', { ) expect_error( - spark_pred_num <- predict_num(spark_fit, iris_linreg_te), + spark_pred_num <- predict_numeric(spark_fit, iris_linreg_te), regexp = NA ) diff --git a/tests/testthat/test_linear_reg_stan.R b/tests/testthat/test_linear_reg_stan.R index 4f7288d90..bf7bd53de 100644 --- a/tests/testthat/test_linear_reg_stan.R +++ b/tests/testthat/test_linear_reg_stan.R @@ -70,7 +70,7 @@ test_that('stan prediction', { control = quiet_ctrl ) - expect_equal(uni_pred, predict_num(res_xy, iris[1:5, num_pred]), tolerance = 0.001) + expect_equal(uni_pred, predict_numeric(res_xy, iris[1:5, num_pred]), tolerance = 0.001) res_form <- fit( iris_basic, @@ -78,7 +78,7 @@ test_that('stan prediction', { data = iris, control = quiet_ctrl ) - expect_equal(inl_pred, predict_num(res_form, iris[1:5, ]), tolerance = 0.001) + expect_equal(inl_pred, predict_numeric(res_form, iris[1:5, ]), tolerance = 0.001) }) diff --git a/tests/testthat/test_mars.R b/tests/testthat/test_mars.R index a2e405fd6..ed3505ceb 100644 --- a/tests/testthat/test_mars.R +++ b/tests/testthat/test_mars.R @@ -180,7 +180,7 @@ test_that('mars prediction', { control = ctrl ) - expect_equal(uni_pred, predict_num(res_xy, iris[1:5, num_pred])) + expect_equal(uni_pred, predict_numeric(res_xy, iris[1:5, num_pred])) res_form <- fit( iris_basic, @@ -188,7 +188,7 @@ test_that('mars prediction', { data = iris, control = ctrl ) - expect_equal(inl_pred, predict_num(res_form, iris[1:5, ])) + expect_equal(inl_pred, predict_numeric(res_form, iris[1:5, ])) res_mv <- fit( iris_basic, @@ -196,7 +196,7 @@ test_that('mars prediction', { data = iris, control = ctrl ) - expect_equal(mv_pred, predict_num(res_mv, iris[1:5,])) + expect_equal(mv_pred, predict_numeric(res_mv, iris[1:5,])) }) diff --git a/tests/testthat/test_mlp_keras.R b/tests/testthat/test_mlp_keras.R index 9e1a55448..09ae0cb61 100644 --- a/tests/testthat/test_mlp_keras.R +++ b/tests/testthat/test_mlp_keras.R @@ -218,7 +218,7 @@ test_that('multivariate nnet formula', { data = nn_dat[-(1:5),] ) expect_equal(length(unlist(keras::get_weights(nnet_form$fit))), 24) - nnet_form_pred <- predict_num(nnet_form, new_data = nn_dat[1:5, 
-(1:3)]) + nnet_form_pred <- predict_numeric(nnet_form, new_data = nn_dat[1:5, -(1:3)]) expect_equal(ncol(nnet_form_pred), 3) expect_equal(nrow(nnet_form_pred), 5) expect_equal(names(nnet_form_pred), c("V1", "V2", "V3")) @@ -233,7 +233,7 @@ test_that('multivariate nnet formula', { y = nn_dat[-(1:5), 1:3 ] ) expect_equal(length(unlist(keras::get_weights(nnet_xy$fit))), 24) - nnet_form_xy <- predict_num(nnet_xy, new_data = nn_dat[1:5, -(1:3)]) + nnet_form_xy <- predict_numeric(nnet_xy, new_data = nn_dat[1:5, -(1:3)]) expect_equal(ncol(nnet_form_xy), 3) expect_equal(nrow(nnet_form_xy), 5) expect_equal(names(nnet_form_xy), c("V1", "V2", "V3")) diff --git a/tests/testthat/test_mlp_nnet.R b/tests/testthat/test_mlp_nnet.R index 24dd3fa9e..a112086fe 100644 --- a/tests/testthat/test_mlp_nnet.R +++ b/tests/testthat/test_mlp_nnet.R @@ -141,7 +141,7 @@ test_that('nnet regression prediction', { xy_pred <- predict(xy_fit$fit, newdata = mtcars[1:8, -1])[,1] xy_pred <- unname(xy_pred) - expect_equal(xy_pred, predict_num(xy_fit, new_data = mtcars[1:8, -1])) + expect_equal(xy_pred, predict_numeric(xy_fit, new_data = mtcars[1:8, -1])) form_fit <- fit( car_basic, @@ -152,7 +152,7 @@ test_that('nnet regression prediction', { form_pred <- predict(form_fit$fit, newdata = mtcars[1:8, -1])[,1] form_pred <- unname(form_pred) - expect_equal(form_pred, predict_num(form_fit, new_data = mtcars[1:8, -1])) + expect_equal(form_pred, predict_numeric(form_fit, new_data = mtcars[1:8, -1])) }) # ------------------------------------------------------------------------------ @@ -175,7 +175,7 @@ test_that('multivariate nnet formula', { data = nn_dat[-(1:5),] ) expect_equal(length(nnet_form$fit$wts), 24) - nnet_form_pred <- predict_num(nnet_form, new_data = nn_dat[1:5, -(1:3)]) + nnet_form_pred <- predict_numeric(nnet_form, new_data = nn_dat[1:5, -(1:3)]) expect_equal(ncol(nnet_form_pred), 3) expect_equal(nrow(nnet_form_pred), 5) expect_equal(names(nnet_form_pred), c("V1", "V2", "V3")) @@ -192,7 +192,7 @@ test_that('multivariate nnet formula', { y = nn_dat[-(1:5), 1:3 ] ) expect_equal(length(nnet_xy$fit$wts), 24) - nnet_form_xy <- predict_num(nnet_xy, new_data = nn_dat[1:5, -(1:3)]) + nnet_form_xy <- predict_numeric(nnet_xy, new_data = nn_dat[1:5, -(1:3)]) expect_equal(ncol(nnet_form_xy), 3) expect_equal(nrow(nnet_form_xy), 5) expect_equal(names(nnet_form_xy), c("V1", "V2", "V3")) diff --git a/tests/testthat/test_nearest_neighbor_kknn.R b/tests/testthat/test_nearest_neighbor_kknn.R index 8cdddd2dc..999c31842 100644 --- a/tests/testthat/test_nearest_neighbor_kknn.R +++ b/tests/testthat/test_nearest_neighbor_kknn.R @@ -74,7 +74,7 @@ test_that('kknn prediction', { newdata = iris[1:5, num_pred] ) - expect_equal(uni_pred, predict_num(res_xy, iris[1:5, num_pred])) + expect_equal(uni_pred, predict_numeric(res_xy, iris[1:5, num_pred])) # nominal res_xy_nom <- fit_xy( @@ -105,5 +105,5 @@ test_that('kknn prediction', { newdata = iris[1:5,] ) - expect_equal(form_pred, predict_num(res_form, iris[1:5, c("Sepal.Width", "Species")])) + expect_equal(form_pred, predict_numeric(res_form, iris[1:5, c("Sepal.Width", "Species")])) }) diff --git a/tests/testthat/test_predict_formats.R b/tests/testthat/test_predict_formats.R index 4b4c24d77..cd10d2add 100644 --- a/tests/testthat/test_predict_formats.R +++ b/tests/testthat/test_predict_formats.R @@ -31,7 +31,7 @@ lr_fit_2 <- test_that('regression predictions', { expect_true(is_tibble(predict(lm_fit, new_data = iris[1:5,-1]))) - expect_true(is.vector(predict_num(lm_fit, new_data = iris[1:5,-1]))) + 
expect_true(is.vector(predict_numeric(lm_fit, new_data = iris[1:5,-1]))) expect_equal(names(predict(lm_fit, new_data = iris[1:5,-1])), ".pred") }) diff --git a/tests/testthat/test_rand_forest_randomForest.R b/tests/testthat/test_rand_forest_randomForest.R index 9ae446edc..35937b244 100644 --- a/tests/testthat/test_rand_forest_randomForest.R +++ b/tests/testthat/test_rand_forest_randomForest.R @@ -209,6 +209,6 @@ test_that('randomForest regression prediction', { xy_pred <- predict(xy_fit$fit, newdata = tail(mtcars)) xy_pred <- unname(xy_pred) - expect_equal(xy_pred, predict_num(xy_fit, new_data = tail(mtcars))) + expect_equal(xy_pred, predict_numeric(xy_fit, new_data = tail(mtcars))) }) diff --git a/tests/testthat/test_rand_forest_ranger.R b/tests/testthat/test_rand_forest_ranger.R index 7e90bbc43..82d767c4b 100644 --- a/tests/testthat/test_rand_forest_ranger.R +++ b/tests/testthat/test_rand_forest_ranger.R @@ -229,7 +229,7 @@ test_that('ranger regression prediction', { xy_pred <- predict(xy_fit$fit, data = tail(mtcars[, -1]))$prediction - expect_equal(xy_pred, predict_num(xy_fit, new_data = tail(mtcars[, -1]))) + expect_equal(xy_pred, predict_numeric(xy_fit, new_data = tail(mtcars[, -1]))) }) diff --git a/tests/testthat/test_rand_forest_spark.R b/tests/testthat/test_rand_forest_spark.R index bcda93626..4184e6abf 100644 --- a/tests/testthat/test_rand_forest_spark.R +++ b/tests/testthat/test_rand_forest_spark.R @@ -58,7 +58,7 @@ test_that('spark execution', { ) expect_error( - spark_reg_pred_num <- predict_num(spark_reg_fit, iris_rf_te), + spark_reg_pred_num <- predict_numeric(spark_reg_fit, iris_rf_te), regexp = NA ) @@ -68,7 +68,7 @@ test_that('spark execution', { ) expect_error( - spark_reg_num_dup <- predict_num(spark_reg_fit_dup, iris_rf_te), + spark_reg_num_dup <- predict_numeric(spark_reg_fit_dup, iris_rf_te), regexp = NA ) From 620adafb726ca6d0b3519db8e74e7c761c63e5d0 Mon Sep 17 00:00:00 2001 From: topepo Date: Thu, 25 Oct 2018 10:22:26 -0400 Subject: [PATCH 03/10] classes -> class for #65 --- R/boost_tree.R | 2 +- R/boost_tree_data.R | 6 +++--- R/logistic_reg_data.R | 8 ++++---- R/mars_data.R | 2 +- R/mlp_data.R | 4 ++-- R/multinom_reg_data.R | 4 ++-- R/nearest_neighbor_data.R | 2 +- R/predict_class.R | 12 ++++++------ R/rand_forest_data.R | 6 +++--- 9 files changed, 23 insertions(+), 23 deletions(-) diff --git a/R/boost_tree.R b/R/boost_tree.R index 8c31419ff..8c2614547 100644 --- a/R/boost_tree.R +++ b/R/boost_tree.R @@ -391,7 +391,7 @@ xgb_by_tree <- function(tree, object, new_data, type, ...) 
{ nms <- names(pred) } else { if (type == "class") { - pred <- boost_tree_xgboost_data$classes$post(pred, object) + pred <- boost_tree_xgboost_data$class$post(pred, object) pred <- tibble(.pred = factor(pred, levels = object$lvl)) } else { pred <- boost_tree_xgboost_data$classprob$post(pred, object) diff --git a/R/boost_tree_data.R b/R/boost_tree_data.R index 20e238aaa..559698511 100644 --- a/R/boost_tree_data.R +++ b/R/boost_tree_data.R @@ -41,7 +41,7 @@ boost_tree_xgboost_data <- newdata = quote(new_data) ) ), - classes = list( + class = list( pre = NULL, post = function(x, object) { if (is.vector(x)) { @@ -97,7 +97,7 @@ boost_tree_C5.0_data <- func = c(pkg = "parsnip", fun = "C5.0_train"), defaults = list() ), - classes = list( + class = list( pre = NULL, post = NULL, func = c(fun = "predict"), @@ -152,7 +152,7 @@ boost_tree_spark_data <- dataset = quote(new_data) ) ), - classes = list( + class = list( pre = NULL, post = format_spark_class, func = c(pkg = "sparklyr", fun = "ml_predict"), diff --git a/R/logistic_reg_data.R b/R/logistic_reg_data.R index a39dc137c..e44a8ed2a 100644 --- a/R/logistic_reg_data.R +++ b/R/logistic_reg_data.R @@ -33,7 +33,7 @@ logistic_reg_glm_data <- family = expr(stats::binomial) ) ), - classes = list( + class = list( pre = NULL, post = prob_to_class_2, func = c(fun = "predict"), @@ -109,7 +109,7 @@ logistic_reg_glmnet_data <- family = "binomial" ) ), - classes = list( + class = list( pre = NULL, post = organize_glmnet_class, func = c(fun = "predict"), @@ -156,7 +156,7 @@ logistic_reg_stan_data <- family = expr(stats::binomial) ) ), - classes = list( + class = list( pre = NULL, post = function(x, object) { x <- object$fit$family$linkinv(x) @@ -268,7 +268,7 @@ logistic_reg_spark_data <- family = "binomial" ) ), - classes = list( + class = list( pre = NULL, post = format_spark_class, func = c(pkg = "sparklyr", fun = "ml_predict"), diff --git a/R/mars_data.R b/R/mars_data.R index e67c5e3b2..0c1076e68 100644 --- a/R/mars_data.R +++ b/R/mars_data.R @@ -34,7 +34,7 @@ mars_earth_data <- type = "response" ) ), - classes = list( + class = list( pre = NULL, post = function(x, object) { x <- ifelse(x[,1] >= 0.5, object$lvl[2], object$lvl[1]) diff --git a/R/mlp_data.R b/R/mlp_data.R index d2d6c472f..0e4df7a65 100644 --- a/R/mlp_data.R +++ b/R/mlp_data.R @@ -35,7 +35,7 @@ mlp_keras_data <- x = quote(as.matrix(new_data)) ) ), - classes = list( + class = list( pre = NULL, post = function(x, object) { object$lvl[x + 1] @@ -92,7 +92,7 @@ mlp_nnet_data <- type = "raw" ) ), - classes = list( + class = list( pre = NULL, post = NULL, func = c(fun = "predict"), diff --git a/R/multinom_reg_data.R b/R/multinom_reg_data.R index b32bacd7e..eca2c8c96 100644 --- a/R/multinom_reg_data.R +++ b/R/multinom_reg_data.R @@ -28,7 +28,7 @@ multinom_reg_glmnet_data <- family = "multinomial" ) ), - classes = list( + class = list( pre = check_glmnet_lambda, post = organize_multnet_class, func = c(fun = "predict"), @@ -75,7 +75,7 @@ multinom_reg_spark_data <- family = "multinomial" ) ), - classes = list( + class = list( pre = NULL, post = format_spark_class, func = c(pkg = "sparklyr", fun = "ml_predict"), diff --git a/R/nearest_neighbor_data.R b/R/nearest_neighbor_data.R index 9b322981e..0191d8614 100644 --- a/R/nearest_neighbor_data.R +++ b/R/nearest_neighbor_data.R @@ -42,7 +42,7 @@ nearest_neighbor_kknn_data <- type = "raw" ) ), - classes = list( + class = list( pre = function(x, object) { if (!(object$fit$response %in% c("ordinal", "nominal"))) { stop("`kknn` model does not appear to use class 
predictions. Was ", diff --git a/R/predict_class.R b/R/predict_class.R index 67e15de95..c586b1039 100644 --- a/R/predict_class.R +++ b/R/predict_class.R @@ -13,23 +13,23 @@ predict_class.model_fit <- function (object, new_data, ...) { stop("`predict.model_fit` is for predicting factor outcomes.", call. = FALSE) - if (!any(names(object$spec$method) == "classes")) + if (!any(names(object$spec$method) == "class")) stop("No class prediction module defined for this model.", call. = FALSE) new_data <- prepare_data(object, new_data) # preprocess data - if (!is.null(object$spec$method$classes$pre)) - new_data <- object$spec$method$classes$pre(new_data, object) + if (!is.null(object$spec$method$class$pre)) + new_data <- object$spec$method$class$pre(new_data, object) # create prediction call - pred_call <- make_pred_call(object$spec$method$classes) + pred_call <- make_pred_call(object$spec$method$class) res <- eval_tidy(pred_call) # post-process the predictions - if(!is.null(object$spec$method$classes$post)) { - res <- object$spec$method$classes$post(res, object) + if(!is.null(object$spec$method$class$post)) { + res <- object$spec$method$class$post(res, object) } # coerce levels to those in `object` diff --git a/R/rand_forest_data.R b/R/rand_forest_data.R index 5f95de5c2..65eb84864 100644 --- a/R/rand_forest_data.R +++ b/R/rand_forest_data.R @@ -123,7 +123,7 @@ rand_forest_ranger_data <- verbose = FALSE ) ), - classes = list( + class = list( pre = NULL, post = ranger_class_pred, func = c(fun = "predict"), @@ -200,7 +200,7 @@ rand_forest_randomForest_data <- newdata = quote(new_data) ) ), - classes = list( + class = list( pre = NULL, post = NULL, func = c(fun = "predict"), @@ -257,7 +257,7 @@ rand_forest_spark_data <- dataset = quote(new_data) ) ), - classes = list( + class = list( pre = NULL, post = format_spark_class, func = c(pkg = "sparklyr", fun = "ml_predict"), From 1bab926a32260356226cb05a633190aa1b1142bf Mon Sep 17 00:00:00 2001 From: topepo Date: Thu, 25 Oct 2018 10:34:26 -0400 Subject: [PATCH 04/10] closes #65 --- docs/articles/articles/Classification.html | 8 +- docs/articles/articles/Models.html | 112 ++++++++++----------- docs/articles/articles/Scratch.html | 28 +++--- docs/reference/other_predict.html | 4 +- vignettes/articles/Models.Rmd | 4 +- vignettes/articles/Scratch.Rmd | 26 ++--- 6 files changed, 90 insertions(+), 92 deletions(-) diff --git a/docs/articles/articles/Classification.html b/docs/articles/articles/Classification.html index 9810d6296..699b7b5a8 100644 --- a/docs/articles/articles/Classification.html +++ b/docs/articles/articles/Classification.html @@ -175,17 +175,17 @@

Classification Example

#> # A tibble: 1 x 2 #> .metric .estimate #> <chr> <dbl> -#> 1 roc_auc 0.824 +#> 1 roc_auc 0.823 test_results %>% accuracy(truth = Status, estimate = `nnet class`) #> # A tibble: 1 x 2 #> .metric .estimate #> <chr> <dbl> -#> 1 accuracy 0.809 +#> 1 accuracy 0.801 test_results %>% conf_mat(truth = Status, estimate = `nnet class`) #> Truth #> Prediction bad good -#> bad 184 84 -#> good 129 716 +#> bad 175 84 +#> good 138 716 -

-The pred module

-This is defined only for regression models (so is not added to the list). The convention used here is very similar to the two that are detailed in the next section. For pred, the model requires an unnamed numeric vector output (usually).

+The numeric module
+This is defined only for regression models (so is not added to the list). The convention used here is very similar to the two that are detailed in the next section. For numeric, the model requires an unnamed numeric vector output (usually).
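For concreteness, a minimal numeric module for a hypothetical engine might look like the sketch below. It mirrors the randomForest entry in rand_forest_data.R from this patch; the example_data object name is made up for illustration.

example_data$numeric <-
  list(
    pre = NULL,                    # no preprocessing of new_data is needed
    post = NULL,                   # predict() already returns a numeric vector
    func = c(fun = "predict"),     # a generic predict() call, no package prefix
    args = list(
      object = quote(object$fit),  # the underlying model object
      newdata = quote(new_data)    # parsnip passes the data in as new_data
    )
  )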

Examples are here and here.

For multivariate models, the return value should be a matrix or data frame (otherwise a vector should be the results).
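A sketch of a post function that handles both cases (parsnip's own maybe_multivariate() plays this role for the earth and keras engines in this patch; the function name here is invented):

# keep multi-column predictions rectangular: a tibble with one column per
# outcome; univariate results come back as a plain unnamed vector
multivariate_post <- function(results, object) {
  if (is.matrix(results) && ncol(results) > 1) {
    tibble::as_tibble(results)
  } else {
    unname(as.vector(results))
  }
}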

-Note that the pred module maps to the predict_num function in parsnip. However, the user-facing predict function is used to generate predictions and returns a tibble with a column named .pred (see the example below). When creating new models, you don’t have to write code for that part.

+Note that the numeric module maps to the predict_numeric function in parsnip. However, the user-facing predict function is used to generate predictions and returns a tibble with a column named .pred (see the example below). When creating new models, you don’t have to write code for that part.
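A hedged usage sketch (any fitted parsnip regression model will do; lm on mtcars is just a stand-in):

library(parsnip)
lm_spec <- set_engine(linear_reg(), "lm")
lm_fit <- fit(lm_spec, mpg ~ ., data = mtcars)
# predict_numeric() runs under the hood; the user sees a tibble
predict(lm_fit, new_data = mtcars[1:3, -1])
#> a tibble with a single numeric column named .pred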

-
+

-The classes module

-

To make hard class predictions, the classes object contains the details. The elements of the list are:

+The class module +

To make hard class predictions, the class object contains the details. The elements of the list are:

  • pre and post are optional functions that can preprocess the data being fed to the prediction code and to postprocess the raw output of the predictions. These won’t be needed for this example, but a section below has examples of how these can be used when the model code is not easy to use. If the data being predicted has a simple type requirement, you can avoid using a pre function with the args below.
  • func is the prediction function (in the same format as above). In many cases, packages have a predict method for their model’s class but this is typically not exported. In this case (and the example below), it is simple enough to make a generic call to predict with no associated package.

  • args is a list of arguments to pass to the prediction function. These will most likely be wrapped in rlang::expr so that they are not evaluated when defining the method. For mda, the code would be predict(object, newdata, type = "class"). What is actually given to the function is the parsnip model fit object, which includes a sub-object called fit and this houses the mda model object. If the data need to be a matrix or data frame, you could also use new_data = quote(as.data.frame(new_data)) and so on.
-mixture_da_mda_data$classes <-
+mixture_da_mda_data$class <-
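The hunk above only shows the renamed first line; filled out from the surrounding description (the type = "class" argument comes from the args bullet above), the whole module is roughly:

mixture_da_mda_data$class <-
  list(
    pre = NULL,
    post = NULL,
    func = c(fun = "predict"),
    args = list(
      object = quote(object$fit),    # the underlying mda model
      newdata = quote(new_data),
      type = "class"
    )
  )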

The predict_class function will expect the result to be an unnamed character string or factor. This will be coerced to a factor with the same levels as the original data. As with the numeric module, the user doesn’t call predict_class but uses predict instead and this produces a tibble with a column named .pred_class per the model guidelines.
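A usage sketch (da_fit stands in for a model fit with the mda engine):

# the user-facing call; predict_class() runs under the hood
predict(da_fit, new_data = iris[1:3, -5], type = "class")
#> a tibble with a factor column named .pred_class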


-The prob module

-This defines the class probabilities (if they can be computed). The format is identical to the classes module but the output is expected to be a tibble with columns for each factor level.

+The classprob module
+This defines the class probabilities (if they can be computed). The format is identical to the class module but the output is expected to be a tibble with columns for each factor level.

As an example of the post function, the data frame created by mda:::predict.mda will be converted to a tibble. The arguments are x (the raw results coming from the predict method) and object (the parsnip model fit object). The latter has a sub-object called lvl which is a character string of the outcome’s factor levels (if any).

-mixture_da_mda_data$prob <-
+mixture_da_mda_data$classprob <-
   list(
     pre = NULL,
     post = function(x, object) {
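The hunk cuts the function off here; reconstructed from the description above (the type = "posterior" value is an assumption about mda:::predict.mda, and the post body is a sketch rather than the file's verbatim contents), the full module is approximately:

mixture_da_mda_data$classprob <-
  list(
    pre = NULL,
    post = function(x, object) {
      # convert the matrix/data frame of class probabilities to a
      # tibble with one column per factor level in object$lvl
      x <- tibble::as_tibble(x)
      names(x) <- object$lvl
      x
    },
    func = c(fun = "predict"),
    args = list(
      object = quote(object$fit),
      newdata = quote(new_data),
      type = "posterior"
    )
  )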
@@ -371,9 +371,9 @@ 

Do I have to return a simple vector for predict_numeric and predict_class?

-Previously, when discussing the pred information:

+Previously, when discussing the numeric information:

-For pred, the model requires an unnamed numeric vector output (usually).

+For numeric, the model requires an unnamed numeric vector output (usually).

There are some occasions where a prediction for a single new sample may be multidimensional. Examples are enumerated here but some easy examples are:

    diff --git a/docs/reference/other_predict.html b/docs/reference/other_predict.html index 15d418104..c0a7c3c1c 100644 --- a/docs/reference/other_predict.html +++ b/docs/reference/other_predict.html @@ -155,9 +155,9 @@

    Other predict methods.

    predict_predint(object, ...) # S3 method for model_fit -predict_num(object, new_data, ...) +predict_numeric(object, new_data, ...) -predict_num(object, ...) +predict_numeric(object, ...) # S3 method for model_fit predict_quantile(object, new_data, diff --git a/vignettes/articles/Models.Rmd b/vignettes/articles/Models.Rmd index 35ea2f380..8fd627782 100644 --- a/vignettes/articles/Models.Rmd +++ b/vignettes/articles/Models.Rmd @@ -62,10 +62,8 @@ _How_ the model is created is related to the _engine_. In many cases, this is an map2_dfr(engine_info$model, engine_info$engine, mod_names) %>% dplyr::filter(!(module %in% c("libs", "fit"))) %>% mutate( - module = ifelse(module == "pred", "num", module), module = ifelse(module == "confint", "conf_int", module), module = ifelse(module == "predint", "pred_int", module), - module = ifelse(module == "classes", "class", module), module = paste0("`", module, "`"), model = paste0("`", model, "()`"), ) %>% @@ -76,5 +74,5 @@ _How_ the model is created is related to the _engine_. In many cases, this is an collapse_rows(columns = 1) ``` -Models can be added by the user too. See the "Making a `parsnip` model from scratch" vignette. +Models can be added by the user too. See the ["Making a `parsnip` model from scratch" vignette](Scratch.html). diff --git a/vignettes/articles/Scratch.Rmd b/vignettes/articles/Scratch.Rmd index 22d17a63f..ecccb90d6 100644 --- a/vignettes/articles/Scratch.Rmd +++ b/vignettes/articles/Scratch.Rmd @@ -155,27 +155,27 @@ mixture_da_mda_data$fit <- ) ``` -### The `pred` module +### The `numeric` module -This is defined only for regression models (so is not added to the list). The convention used here is very similar to the two that are detailed in the next section. For `pred`, the model requires an unnamed numeric vector output (usually). +This is defined only for regression models (so is not added to the list). The convention used here is very similar to the two that are detailed in the next section. For `numeric`, the model requires an unnamed numeric vector output (usually). Examples are [here](https://github.com/topepo/parsnip/blob/master/R/linear_reg_data.R) and [here](https://github.com/topepo/parsnip/blob/master/R/rand_forest_data.R). For multivariate models, the return value should be a matrix or data frame (otherwise a vector should be the results). -Note that the `pred` module maps to the `predict_num` function in `parsnip`. However, the user-facing `predict` function is used to generate predictions and returns a tibble with a column named `.pred` (see the example below). When creating new models, you don't have to write code for that part. +Note that the `numeric` module maps to the `predict_numeric` function in `parsnip`. However, the user-facing `predict` function is used to generate predictions and returns a tibble with a column named `.pred` (see the example below). When creating new models, you don't have to write code for that part. -### The `classes` module +### The `class` module -To make hard class predictions, the `classes` object contains the details. The elements of the list are: +To make hard class predictions, the `class` object contains the details. The elements of the list are: * `pre` and `post` are optional functions that can preprocess the data being fed to the prediction code and to postprocess the raw output of the predictions. These won't be need for this example, but a section below has examples of how these can be used when the model code is not easy to use. 
If the data being predicted has a simple type requirement, you can avoid using a `pre` function with the `args` below.

* `func` is the prediction function (in the same format as above). In many cases, packages have a predict method for their model's class but this is typically not exported. In this case (and the example below), it is simple enough to make a generic call to `predict` with no associated package.

* `args` is a list of arguments to pass to the prediction function. These will most likely be wrapped in `rlang::expr` so that they are not evaluated when defining the method. For `mda`, the code would be `predict(object, newdata, type = "class")`. What is actually given to the function is the `parsnip` model fit object, which includes a sub-object called `fit` and this houses the `mda` model object. If the data need to be a matrix or data frame, you could also use `new_data = quote(as.data.frame(new_data))` and so on.

-```{r mda-classes}
-mixture_da_mda_data$classes <-
+```{r mda-class}
+mixture_da_mda_data$class <-
   list(
     pre = NULL,
     post = NULL,
@@ -196,14 +196,14 @@

 The `predict_class` function will expect the result to be an unnamed character string or factor. This will be coerced to a factor with the same levels as the original data. As with the `numeric` module, the user doesn't call `predict_class` but uses `predict` instead and this produces a tibble with a column named `.pred_class` [per the model guidelines](https://tidymodels.github.io/model-implementation-principles/model-predictions.html#return-values).

-### The `prob` module
+### The `classprob` module

-This defines the class probabilities (if they can be computed). The format is identical to the `classes` module but the output is expected to be a tibble with columns for each factor level.
+This defines the class probabilities (if they can be computed). The format is identical to the `class` module but the output is expected to be a tibble with columns for each factor level.

 As an example of the `post` function, the data frame created by `mda:::predict.mda` will be converted to a tibble. The arguments are `x` (the raw results coming from the predict method) and `object` (the `parsnip` model fit object). The latter has a sub-object called `lvl` which is a character string of the outcome's factor levels (if any).

-```{r mda-prob}
-mixture_da_mda_data$prob <-
+```{r mda-classprob}
+mixture_da_mda_data$classprob <-
   list(
     pre = NULL,
     post = function(x, object) {
@@ -263,9 +263,9 @@ There are various things that came to mind while writing this document.

 ### Do I have to return a simple vector for `predict_numeric` and `predict_class`?

-Previously, when discussing the `pred` information:
+Previously, when discussing the `numeric` information:

-> For `pred`, the model requires an unnamed numeric vector output **(usually)**.
+> For `numeric`, the model requires an unnamed numeric vector output **(usually)**.

 There are some occasions where a prediction for a single new sample may be multidimensional.
Examples are enumerated [here](https://tidymodels.github.io/model-implementation-principles/notes.html#list-cols) but some easy examples are: From 4964061ddfe2e51b496a3021c9108761875491f4 Mon Sep 17 00:00:00 2001 From: topepo Date: Thu, 25 Oct 2018 19:51:40 -0400 Subject: [PATCH 05/10] svm models --- NAMESPACE | 8 + R/svm_poly.R | 196 +++++++++++++ R/svm_poly_data.R | 69 +++++ R/svm_rbf.R | 183 +++++++++++++ R/svm_rbf_data.R | 69 +++++ R/translate.R | 5 + docs/articles/articles/Classification.html | 8 +- docs/articles/articles/Models.html | 62 ++++- docs/reference/svm_poly.html | 304 +++++++++++++++++++++ docs/reference/svm_rbf.html | 300 ++++++++++++++++++++ docs/reference/translate.html | 3 + man/svm_poly.Rd | 90 ++++++ man/svm_rbf.Rd | 88 ++++++ man/translate.Rd | 4 + tests/testthat/test_svm_poly.R | 270 ++++++++++++++++++ tests/testthat/test_svm_rbf.R | 245 +++++++++++++++++ 16 files changed, 1898 insertions(+), 6 deletions(-) create mode 100644 R/svm_poly.R create mode 100644 R/svm_poly_data.R create mode 100644 R/svm_rbf.R create mode 100644 R/svm_rbf_data.R create mode 100644 docs/reference/svm_poly.html create mode 100644 docs/reference/svm_rbf.html create mode 100644 man/svm_poly.Rd create mode 100644 man/svm_rbf.Rd create mode 100644 tests/testthat/test_svm_poly.R create mode 100644 tests/testthat/test_svm_rbf.R diff --git a/NAMESPACE b/NAMESPACE index 06caa62cb..f067c2785 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -38,12 +38,16 @@ S3method(print,multinom_reg) S3method(print,nearest_neighbor) S3method(print,rand_forest) S3method(print,surv_reg) +S3method(print,svm_poly) +S3method(print,svm_rbf) S3method(translate,boost_tree) S3method(translate,default) S3method(translate,mars) S3method(translate,mlp) S3method(translate,rand_forest) S3method(translate,surv_reg) +S3method(translate,svm_poly) +S3method(translate,svm_rbf) S3method(type_sum,model_fit) S3method(type_sum,model_spec) S3method(update,boost_tree) @@ -55,6 +59,8 @@ S3method(update,multinom_reg) S3method(update,nearest_neighbor) S3method(update,rand_forest) S3method(update,surv_reg) +S3method(update,svm_poly) +S3method(update,svm_rbf) S3method(varying_args,model_spec) S3method(varying_args,recipe) S3method(varying_args,step) @@ -106,6 +112,8 @@ export(set_engine) export(set_mode) export(show_call) export(surv_reg) +export(svm_poly) +export(svm_rbf) export(translate) export(varying) export(varying_args) diff --git a/R/svm_poly.R b/R/svm_poly.R new file mode 100644 index 000000000..adb6d043d --- /dev/null +++ b/R/svm_poly.R @@ -0,0 +1,196 @@ +#' General interface for polynomial support vector machines +#' +#' `svm_poly` is a way to generate a _specification_ of a model +#' before fitting and allows the model to be created using +#' different packages in R or via Spark. The main arguments for the +#' model are: +#' \itemize{ +#' \item \code{cost}: The cost of predicting a sample within or on the +#' wrong side of the margin. +#' \item \code{degree}: The polynomial degree. +#' \item \code{scale_factor}: A scaling factor for the kernel. +#' \item \code{margin}: The epsilon in the SVM insensitive loss function +#' (regression only). +#' } +#' These arguments are converted to their specific names at the +#' time that the model is fit. Other options and arguments can be +#' set using `set_engine`. If left to their defaults +#' here (`NULL`), the values are taken from the underlying model +#' functions. If parameters need to be modified, `update` can be used +#' in lieu of recreating the object from scratch.
+#' +#' @inheritParams boost_tree +#' @param mode A single character string for the type of model. +#' Possible values for this model are "unknown", "regression", or +#' "classification". +#' @param cost A positive number for the cost of predicting a sample within +#' or on the wrong side of the margin. +#' @param degree A positive number for the polynomial degree. +#' @param scale_factor A positive number for the polynomial scaling factor. +#' @param margin A positive number for the epsilon in the SVM insensitive +#' loss function (regression only). +#' @details +#' The model can be created using the `fit()` function using the +#' following _engines_: +#' \itemize{ +#' \item \pkg{R}: `"kernlab"` +#' } +#' +#' @section Engine Details: +#' +#' Engines may have pre-set default arguments when executing the +#' model fit call. For this type of +#' model, the template of the fit calls are: +#' +#' \pkg{kernlab} classification +#' +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::svm_poly(mode = "classification"), "kernlab")} +#' +#' \pkg{kernlab} regression +#' +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::svm_poly(mode = "regression"), "kernlab")} +#' +#' @importFrom purrr map_lgl +#' @seealso [varying()], [fit()] +#' @examples +#' svm_poly(mode = "classification", degree = 1.2) +#' # Parameters can be represented by a placeholder: +#' svm_poly(mode = "regression", cost = varying()) +#' @export + +svm_poly <- + function(mode = "unknown", + cost = NULL, degree = NULL, scale_factor = NULL, margin = NULL) { + + args <- list( + cost = enquo(cost), + degree = enquo(degree), + scale_factor = enquo(scale_factor), + margin = enquo(margin) + ) + + new_model_spec( + "svm_poly", + args = args, + eng_args = NULL, + mode = mode, + method = NULL, + engine = NULL + ) + } + +#' @export +print.svm_poly <- function(x, ...) { + cat("Polynomial Support Vector Machine Specification (", x$mode, ")\n\n", sep = "") + model_printer(x, ...) + + if(!is.null(x$method$fit$args)) { + cat("Model fit template:\n") + print(show_call(x)) + } + invisible(x) +} + +# ------------------------------------------------------------------------------ + +#' @export +#' @inheritParams update.boost_tree +#' @param object A polynomial SVM model specification. +#' @examples +#' model <- svm_poly(cost = 10, scale_factor = 0.1) +#' model +#' update(model, cost = 1) +#' update(model, cost = 1, fresh = TRUE) +#' @method update svm_poly +#' @rdname svm_poly +#' @export +update.svm_poly <- + function(object, + cost = NULL, degree = NULL, scale_factor = NULL, margin = NULL, + fresh = FALSE, + ...) { + update_dot_check(...) + + args <- list( + cost = enquo(cost), + degree = enquo(degree), + scale_factor = enquo(scale_factor), + margin = enquo(margin) + ) + + if (fresh) { + object$args <- args + } else { + null_args <- map_lgl(args, null_value) + if (any(null_args)) + args <- args[!null_args] + if (length(args) > 0) + object$args[names(args)] <- args + } + + new_model_spec( + "svm_poly", + args = object$args, + eng_args = object$eng_args, + mode = object$mode, + method = NULL, + engine = object$engine + ) + } + +# ------------------------------------------------------------------------------ + +#' @export +translate.svm_poly <- function(x, engine = x$engine, ...) { + x <- translate.default(x, engine = engine, ...) 
+ + # shorter names for the fit arguments make the code below slightly cleaner + arg_vals <- x$method$fit$args + arg_names <- names(arg_vals) + + # add checks to error-trap or adjust the arguments for this engine + if (x$engine == "kernlab") { + + # unless otherwise specified, classification models predict probabilities + if (x$mode == "classification" && !any(arg_names == "prob.model")) + arg_vals$prob.model <- TRUE + if (x$mode == "classification" && any(arg_names == "epsilon")) + arg_vals$epsilon <- NULL + + # convert degree, scale, and offset to a single `kpar` list argument. + if (any(arg_names %in% c("degree", "scale", "offset"))) { + kpar <- expr(list()) + if (any(arg_names == "degree")) { + kpar$degree <- arg_vals$degree + arg_vals$degree <- NULL + } + if (any(arg_names == "scale")) { + kpar$scale <- arg_vals$scale + arg_vals$scale <- NULL + } + if (any(arg_names == "offset")) { + kpar$offset <- arg_vals$offset + arg_vals$offset <- NULL + } + arg_vals$kpar <- kpar + } + + } + x$method$fit$args <- arg_vals + + # note: the result should not be used to modify the specification (see ?translate) + x +} + +# ------------------------------------------------------------------------------ + +check_args.svm_poly <- function(object) { + invisible(object) +} + +# ------------------------------------------------------------------------------ + +svm_reg_post <- function(results, object) { + results[,1] +} + diff --git a/R/svm_poly_data.R b/R/svm_poly_data.R new file mode 100644 index 000000000..04b0bc55e --- /dev/null +++ b/R/svm_poly_data.R @@ -0,0 +1,69 @@ +svm_poly_arg_key <- data.frame( + kernlab = c( "C", "degree", "scale", "epsilon"), + row.names = c("cost", "degree", "scale_factor", "margin"), + stringsAsFactors = FALSE +) + +svm_poly_modes <- c("classification", "regression", "unknown") + +svm_poly_engines <- data.frame( + kernlab = c(TRUE, TRUE, FALSE), + row.names = c("classification", "regression", "unknown") +) + +# ------------------------------------------------------------------------------ + +svm_poly_kernlab_data <- + list( + libs = "kernlab", + fit = list( + interface = "matrix", + protect = c("x", "y"), + func = c(pkg = "kernlab", fun = "ksvm"), + defaults = list( + kernel = "polydot" + ) + ), + numeric = list( + pre = NULL, + post = svm_reg_post, + func = c(pkg = "kernlab", fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "response" + ) + ), + class = list( + pre = NULL, + post = NULL, + func = c(pkg = "kernlab", fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "response" + ) + ), + classprob = list( + pre = NULL, + post = function(result, object) as_tibble(result), + func = c(pkg = "kernlab", fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "probabilities" + ) + ), + raw = list( + pre = NULL, + func = c(pkg = "kernlab", fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data) + ) + ) + ) diff --git a/R/svm_rbf.R b/R/svm_rbf.R new file mode 100644 index 000000000..7670dfc75 --- /dev/null +++ b/R/svm_rbf.R @@ -0,0 +1,183 @@ +#' General interface for radial basis function support vector machines +#' +#' `svm_rbf` is a way to generate a _specification_ of a model +#' before fitting and allows the model to be created using +#' different packages in R or via Spark. The main arguments for the +#' model are: +#' \itemize{ +#' \item \code{cost}: The cost of predicting a sample within or on the +#' wrong side of the margin. 
+#' \item \code{rbf_sigma}: The precision parameter for the radial basis +#' function. +#' \item \code{margin}: The epsilon in the SVM insensitive loss function +#' (regression only). +#' } +#' These arguments are converted to their specific names at the +#' time that the model is fit. Other options and arguments can be +#' set using `set_engine`. If left to their defaults +#' here (`NULL`), the values are taken from the underlying model +#' functions. If parameters need to be modified, `update` can be used +#' in lieu of recreating the object from scratch. +#' +#' @inheritParams boost_tree +#' @param mode A single character string for the type of model. +#' Possible values for this model are "unknown", "regression", or +#' "classification". +#' @param cost A positive number for the cost of predicting a sample within +#' or on the wrong side of the margin. +#' @param rbf_sigma A positive number for the precision parameter of the +#' radial basis function. +#' @param margin A positive number for the epsilon in the SVM insensitive +#' loss function (regression only). +#' @details +#' The model can be created using the `fit()` function using the +#' following _engines_: +#' \itemize{ +#' \item \pkg{R}: `"kernlab"` +#' } +#' +#' @section Engine Details: +#' +#' Engines may have pre-set default arguments when executing the +#' model fit call. For this type of +#' model, the template of the fit calls are: +#' +#' \pkg{kernlab} classification +#' +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::svm_rbf(mode = "classification"), "kernlab")} +#' +#' \pkg{kernlab} regression +#' +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::svm_rbf(mode = "regression"), "kernlab")} +#' +#' @importFrom purrr map_lgl +#' @seealso [varying()], [fit()] +#' @examples +#' svm_rbf(mode = "classification", rbf_sigma = 0.2) +#' # Parameters can be represented by a placeholder: +#' svm_rbf(mode = "regression", cost = varying()) +#' @export + +svm_rbf <- + function(mode = "unknown", + cost = NULL, rbf_sigma = NULL, margin = NULL) { + + args <- list( + cost = enquo(cost), + rbf_sigma = enquo(rbf_sigma), + margin = enquo(margin) + ) + + new_model_spec( + "svm_rbf", + args = args, + eng_args = NULL, + mode = mode, + method = NULL, + engine = NULL + ) + } + +#' @export +print.svm_rbf <- function(x, ...) { + cat("Radial Basis Function Support Vector Machine Specification (", x$mode, ")\n\n", sep = "") + model_printer(x, ...) + + if(!is.null(x$method$fit$args)) { + cat("Model fit template:\n") + print(show_call(x)) + } + invisible(x) +} + +# ------------------------------------------------------------------------------ + +#' @export +#' @inheritParams update.boost_tree +#' @param object A radial basis function SVM model specification. +#' @examples +#' model <- svm_rbf(cost = 10, rbf_sigma = 0.1) +#' model +#' update(model, cost = 1) +#' update(model, cost = 1, fresh = TRUE) +#' @method update svm_rbf +#' @rdname svm_rbf +#' @export +update.svm_rbf <- + function(object, + cost = NULL, rbf_sigma = NULL, margin = NULL, + fresh = FALSE, + ...) { + update_dot_check(...) 
+ + args <- list( + cost = enquo(cost), + rbf_sigma = enquo(rbf_sigma), + margin = enquo(margin) + ) + + if (fresh) { + object$args <- args + } else { + null_args <- map_lgl(args, null_value) + if (any(null_args)) + args <- args[!null_args] + if (length(args) > 0) + object$args[names(args)] <- args + } + + new_model_spec( + "svm_rbf", + args = object$args, + eng_args = object$eng_args, + mode = object$mode, + method = NULL, + engine = object$engine + ) + } + +# ------------------------------------------------------------------------------ + +#' @export +translate.svm_rbf <- function(x, engine = x$engine, ...) { + x <- translate.default(x, engine = engine, ...) + + # shorter names for the fit arguments make the code below slightly cleaner + arg_vals <- x$method$fit$args + arg_names <- names(arg_vals) + + # add checks to error-trap or adjust the arguments for this engine + if (x$engine == "kernlab") { + + # unless otherwise specified, classification models predict probabilities + if (x$mode == "classification" && !any(arg_names == "prob.model")) + arg_vals$prob.model <- TRUE + if (x$mode == "classification" && any(arg_names == "epsilon")) + arg_vals$epsilon <- NULL + + # convert sigma to a `kpar` list argument. + if (any(arg_names == "sigma")) { + kpar <- expr(list()) + kpar$sigma <- arg_vals$sigma + arg_vals$sigma <- NULL + arg_vals$kpar <- kpar + } + + } + x$method$fit$args <- arg_vals + + # note: the result should not be used to modify the specification (see ?translate) + x +} + +# ------------------------------------------------------------------------------ + +check_args.svm_rbf <- function(object) { + invisible(object) +} + +# ------------------------------------------------------------------------------ + +svm_reg_post <- function(results, object) { + results[,1] +} + diff --git a/R/svm_rbf_data.R b/R/svm_rbf_data.R new file mode 100644 index 000000000..fdb12727d --- /dev/null +++ b/R/svm_rbf_data.R @@ -0,0 +1,69 @@ +svm_rbf_arg_key <- data.frame( + kernlab = c( "C", "sigma", "epsilon"), + row.names = c("cost", "rbf_sigma", "margin"), + stringsAsFactors = FALSE +) + +svm_rbf_modes <- c("classification", "regression", "unknown") + +svm_rbf_engines <- data.frame( + kernlab = c(TRUE, TRUE, FALSE), + row.names = c("classification", "regression", "unknown") +) + +# ------------------------------------------------------------------------------ + +svm_rbf_kernlab_data <- + list( + libs = "kernlab", + fit = list( + interface = "matrix", + protect = c("x", "y"), + func = c(pkg = "kernlab", fun = "ksvm"), + defaults = list( + kernel = "rbfdot" + ) + ), + numeric = list( + pre = NULL, + post = svm_reg_post, + func = c(pkg = "kernlab", fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "response" + ) + ), + class = list( + pre = NULL, + post = NULL, + func = c(pkg = "kernlab", fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "response" + ) + ), + classprob = list( + pre = NULL, + post = function(result, object) as_tibble(result), + func = c(pkg = "kernlab", fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "probabilities" + ) + ), + raw = list( + pre = NULL, + func = c(pkg = "kernlab", fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data) + ) + ) + ) diff --git a/R/translate.R b/R/translate.R index 7c88e8c18..2c342e360 100644 --- a/R/translate.R +++ b/R/translate.R @@ -19,6 +19,11 @@ #' This function can be useful when you need to understand how #' `parsnip` goes from 
a generic model specification to a model fitting #' function. +#' +#' **Note**: this function is used internally and users should only use it +#' to understand what the underlying syntax would be. It should not be used +#' to modify the model specification. +#' +#' @examples #' lm_spec <- linear_reg(penalty = 0.01) #' diff --git a/docs/articles/articles/Classification.html b/docs/articles/articles/Classification.html index 699b7b5a8..04769fe2a 100644 --- a/docs/articles/articles/Classification.html +++ b/docs/articles/articles/Classification.html @@ -175,17 +175,17 @@

    Classification Example

    #> # A tibble: 1 x 2 #> .metric .estimate #> <chr> <dbl> -#> 1 roc_auc 0.823 +#> 1 roc_auc 0.821 test_results %>% accuracy(truth = Status, estimate = `nnet class`) #> # A tibble: 1 x 2 #> .metric .estimate #> <chr> <dbl> -#> 1 accuracy 0.801 +#> 1 accuracy 0.794 test_results %>% conf_mat(truth = Status, estimate = `nnet class`) #> Truth #> Prediction bad good -#> bad 175 84 -#> good 138 716

+#> bad 169 85 +#> good 144 715
+#> bad 174 86 +#> good 139 714
+#> bad 176 81 +#> good 137 719

As an example of a model with multiple engines, here is the object for logistic regression:

+#> glm glmnet spark stan keras +#> penalty NA lambda reg_param NA decay +#> mixture NA alpha elastic_net_param NA <NA>

The internals of parsnip will use these objects during the creation of the model code.
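For reference, the structure behind a table like this is just a small data frame, in the same style as the `svm_poly_arg_key` and `svm_rbf_arg_key` objects added above. A minimal sketch (the object name here is hypothetical, used only for illustration):

```r
# Hypothetical sketch of an arg-key object: row names are parsnip's
# standardized argument names; each column holds the engine-specific
# argument name (NA when an engine has no counterpart).
logistic_reg_arg_key_sketch <- data.frame(
  glm    = c(NA, NA),
  glmnet = c("lambda", "alpha"),
  spark  = c("reg_param", "elastic_net_param"),
  stan   = c(NA, NA),
  keras  = c("decay", NA),
  row.names = c("penalty", "mixture"),
  stringsAsFactors = FALSE
)
```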

diff --git a/docs/reference/index.html b/docs/reference/index.html index df9884277..3682273be 100644 --- a/docs/reference/index.html +++ b/docs/reference/index.html @@ -190,6 +190,18 @@

surv_reg() update(<surv_reg>)

General Interface for Parametric Survival Models

+ + + +

svm_poly() update(<svm_poly>)

+ +

General interface for polynomial support vector machines

+ + + +

svm_rbf() update(<svm_rbf>)

+ +

General interface for radial basis function support vector machines

diff --git a/docs/reference/linear_reg.html b/docs/reference/linear_reg.html index 1d692e35f..b6961b5fd 100644 --- a/docs/reference/linear_reg.html +++ b/docs/reference/linear_reg.html @@ -33,8 +33,8 @@ Arg penalty -

An non-negative number representing the -total amount of regularization (glmnet and spark only).

+

A non-negative number representing the total +amount of regularization (glmnet, keras, and spark only). +For keras models, this corresponds to purely L2 regularization +(aka weight decay) while the other models can be a combination +of L1 and L2 (depending on the value of mixture).

mixture @@ -207,6 +210,7 @@

Details
  • R: "glm" or "glmnet"

  • Stan: "stan"

  • Spark: "spark"

  • +
  • keras: "keras"

  • Note

    @@ -250,6 +254,11 @@

    +

    keras

    +

    +parsnip::keras_mlp(x = missing_arg(), y = missing_arg(), hidden_units = 1, 
    +    act = "linear")
    +
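For instance, fitting through this engine could look like the following minimal sketch (assuming the keras package and a working backend are installed; `epochs` and `verbose` are engine options passed through by `set_engine`):

```r
library(parsnip)

# A sketch: penalized linear regression via the keras engine. The penalty
# maps to pure L2 weight decay, as noted above.
lin_fit <-
  linear_reg(penalty = 0.1) %>%
  set_engine("keras", epochs = 50, verbose = 0) %>%
  fit(mpg ~ ., data = mtcars)

# predictions come back as a tibble with a `.pred` column
predict(lin_fit, new_data = mtcars[1:3, ])
```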

    When using glmnet models, there is the option to pass multiple values (or no values) to the penalty argument. This can have an effect on the model object results. When using diff --git a/docs/reference/multinom_reg.html b/docs/reference/multinom_reg.html index 6765274a5..87e2cfd61 100644 --- a/docs/reference/multinom_reg.html +++ b/docs/reference/multinom_reg.html @@ -33,7 +33,7 @@ Arg penalty -

    An non-negative number representing the -total amount of regularization.

    +

A non-negative number representing the total +amount of regularization (glmnet, keras, and spark only). +For keras models, this corresponds to purely L2 regularization +(aka weight decay) while the other models can be a combination +of L1 and L2 (depending on the value of mixture).

    mixture @@ -206,6 +209,7 @@

    Details following engines:

    • R: "glmnet"

    • Stan: "stan"

    • +
    • keras: "keras"

    Note

    @@ -239,6 +243,11 @@

    +

    keras

    +

    +parsnip::keras_mlp(x = missing_arg(), y = missing_arg(), hidden_units = 1, 
    +    act = "linear")
    +
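Analogously, a minimal sketch for the multinomial model (again assuming keras is installed), following the pattern used in the new tests added in this patch:

```r
library(parsnip)

# A sketch: multinomial regression via keras, with class probabilities.
mn_fit <-
  multinom_reg(penalty = 0.1) %>%
  set_engine("keras", epochs = 50, verbose = 0) %>%
  fit(Species ~ ., data = iris)

# a tibble with one `.pred_*` probability column per factor level
predict(mn_fit, new_data = iris[1:3, -5], type = "prob")
```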

When using glmnet models, there is the option to pass multiple values (or no values) to the penalty argument. This can have an effect on the model object results. When using diff --git a/man/logistic_reg.Rd b/man/logistic_reg.Rd index 0b2918a46..0a3416c0c 100644 --- a/man/logistic_reg.Rd +++ b/man/logistic_reg.Rd @@ -14,8 +14,11 @@ logistic_reg(mode = "classification", penalty = NULL, mixture = NULL) \item{mode}{A single character string for the type of model. The only possible value for this model is "classification".} -\item{penalty}{An non-negative number representing the -total amount of regularization (\code{glmnet} and \code{spark} only).} +\item{penalty}{A non-negative number representing the total +amount of regularization (\code{glmnet}, \code{keras}, and \code{spark} only). +For \code{keras} models, this corresponds to purely L2 regularization +(aka weight decay) while the other models can be a combination +of L1 and L2 (depending on the value of \code{mixture}).} \item{mixture}{A number between zero and one (inclusive) that represents the proportion of regularization that is used for the @@ -32,8 +35,8 @@ modified in-place of or replaced wholesale.} \description{ \code{logistic_reg} is a way to generate a \emph{specification} of a model before fitting and allows the model to be created using -different packages in R, Stan, or via Spark. The main arguments for the -model are: +different packages in R, Stan, keras, or via Spark. The main +arguments for the model are: \itemize{ \item \code{penalty}: The total amount of regularization in the model. Note that this must be zero for some engines. @@ -56,6 +59,7 @@ following \emph{engines}: \item \pkg{R}: \code{"glm"} or \code{"glmnet"} \item \pkg{Stan}: \code{"stan"} \item \pkg{Spark}: \code{"spark"} +\item \pkg{keras}: \code{"keras"} } } \note{ @@ -95,6 +99,10 @@ model, the template of the fit calls are: \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::logistic_reg(), "spark")} + +\pkg{keras} + +\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::logistic_reg(), "keras")} + When using \code{glmnet} models, there is the option to pass multiple values (or no values) to the \code{penalty} argument. This can have an effect on the model object results. When using diff --git a/man/multinom_reg.Rd b/man/multinom_reg.Rd index 2bcc20b3e..0de9c5ee1 100644 --- a/man/multinom_reg.Rd +++ b/man/multinom_reg.Rd @@ -14,8 +14,11 @@ multinom_reg(mode = "classification", penalty = NULL, mixture = NULL) \item{mode}{A single character string for the type of model. The only possible value for this model is "classification".} -\item{penalty}{An non-negative number representing the -total amount of regularization.} +\item{penalty}{A non-negative number representing the total +amount of regularization (\code{glmnet}, \code{keras}, and \code{spark} only). +For \code{keras} models, this corresponds to purely L2 regularization +(aka weight decay) while the other models can be a combination +of L1 and L2 (depending on the value of \code{mixture}).} \item{mixture}{A number between zero and one (inclusive) that represents the proportion of regularization that is used for the @@ -32,7 +35,7 @@ modified in-place of or replaced wholesale.} \description{ \code{multinom_reg} is a way to generate a \emph{specification} of a model before fitting and allows the model to be created using -different packages in R or Spark. The main arguments for the +different packages in R, keras, or Spark. 
The main arguments for the model are: \itemize{ \item \code{penalty}: The total amount of regularization @@ -55,6 +58,7 @@ following \emph{engines}: \itemize{ \item \pkg{R}: \code{"glmnet"} \item \pkg{Stan}: \code{"stan"} +\item \pkg{keras}: \code{"keras"} } } \note{ @@ -86,6 +90,10 @@ model, the template of the fit calls are: \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::multinom_reg(), "spark")} +\pkg{keras} + +\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::multinom_reg(), "keras")} + When using \code{glmnet} models, there is the option to pass multiple values (or no values) to the \code{penalty} argument. This can have an effect on the model object results. When using diff --git a/tests/testthat/test_logistic_reg_keras.R b/tests/testthat/test_logistic_reg_keras.R new file mode 100644 index 000000000..8a7c51986 --- /dev/null +++ b/tests/testthat/test_logistic_reg_keras.R @@ -0,0 +1,189 @@ +library(testthat) +library(parsnip) +library(rlang) +library(tibble) +library(dplyr) + +# ------------------------------------------------------------------------------ + +context("keras logistic regression") +source("helpers.R") + +# ------------------------------------------------------------------------------ + +data("lending_club") +set.seed(352) +dat <- + lending_club %>% + group_by(Class) %>% + sample_n(500) %>% + ungroup() %>% + dplyr::select(Class, funded_amnt, int_rate) +dat <- dat[order(runif(nrow(dat))),] + +tr_dat <- dat[1:995, ] +te_dat <- dat[996:1000, ] + +# ------------------------------------------------------------------------------ + +basic_mod <- + logistic_reg() %>% + set_engine("keras", epochs = 50, verbose = 0) + +reg_mod <- + logistic_reg(penalty = 0.1) %>% + set_engine("keras", epochs = 50, verbose = 0) + +ctrl <- fit_control(verbosity = 0, catch = FALSE) + +# ------------------------------------------------------------------------------ + +test_that('model fitting', { + + skip_if_not_installed("keras") + + set.seed(257) + expect_error( + fit1 <- + fit_xy( + basic_mod, + control = ctrl, + x = tr_dat[, -1], + y = tr_dat$Class + ), + regexp = NA + ) + + set.seed(257) + expect_error( + fit2 <- + fit_xy( + basic_mod, + control = ctrl, + x = tr_dat[, -1], + y = tr_dat$Class + ), + regexp = NA + ) + expect_equal(fit1, fit2) + + expect_error( + fit( + basic_mod, + Class ~ ., + data = tr_dat, + control = ctrl + ), + regexp = NA + ) + + expect_error( + fit1 <- + fit_xy( + reg_mod, + control = ctrl, + x = tr_dat[, -1], + y = tr_dat$Class + ), + regexp = NA + ) + + expect_error( + fit( + reg_mod, + Class ~ ., + data = tr_dat, + control = ctrl + ), + regexp = NA + ) + +}) + + +test_that('classification prediction', { + + skip_if_not_installed("keras") + + library(keras) + + set.seed(257) + lr_fit <- + fit_xy( + basic_mod, + control = ctrl, + x = tr_dat[, -1], + y = tr_dat$Class + ) + + keras_raw <- + predict(lr_fit$fit, as.matrix(te_dat[, -1])) + keras_pred <- + tibble(.pred_class = apply(keras_raw, 1, which.max)) %>% + mutate(.pred_class = factor(lr_fit$lvl[.pred_class], levels = lr_fit$lvl)) + + parsnip_pred <- predict(lr_fit, te_dat[, -1]) + expect_equal(as.data.frame(keras_pred), as.data.frame(parsnip_pred)) + + set.seed(257) + plrfit <- + fit_xy( + reg_mod, + control = ctrl, + x = tr_dat[, -1], + y = tr_dat$Class + ) + + keras_raw <- + predict(plrfit$fit, as.matrix(te_dat[, -1])) + keras_pred <- + tibble(.pred_class = apply(keras_raw, 1, which.max)) %>% + mutate(.pred_class = factor(plrfit$lvl[.pred_class], levels = plrfit$lvl)) + parsnip_pred <- predict(plrfit, te_dat[, -1]) + 
expect_equal(as.data.frame(keras_pred), as.data.frame(parsnip_pred)) + +}) + + +test_that('classification probabilities', { + + skip_if_not_installed("keras") + + library(keras) + + set.seed(257) + lr_fit <- + fit_xy( + basic_mod, + control = ctrl, + x = tr_dat[, -1], + y = tr_dat$Class + ) + + keras_pred <- + predict_proba(lr_fit$fit, as.matrix(te_dat[, -1])) %>% + as_tibble() %>% + setNames(paste0(".pred_", lr_fit$lvl)) + + parsnip_pred <- predict(lr_fit, te_dat[, -1], type = "prob") + expect_equal(as.data.frame(keras_pred), as.data.frame(parsnip_pred)) + + set.seed(257) + plrfit <- + fit_xy( + reg_mod, + control = ctrl, + x = tr_dat[, -1], + y = tr_dat$Class + ) + + keras_pred <- + predict_proba(plrfit$fit, as.matrix(te_dat[, -1])) %>% + as_tibble() %>% + setNames(paste0(".pred_", plrfit$lvl)) + parsnip_pred <- predict(plrfit, te_dat[, -1], type = "prob") + expect_equal(as.data.frame(keras_pred), as.data.frame(parsnip_pred)) + +}) + + diff --git a/tests/testthat/test_multinom_reg_keras.R b/tests/testthat/test_multinom_reg_keras.R new file mode 100644 index 000000000..faaa56c0f --- /dev/null +++ b/tests/testthat/test_multinom_reg_keras.R @@ -0,0 +1,182 @@ +library(testthat) +library(parsnip) +library(rlang) +library(tibble) +library(dplyr) + +# ------------------------------------------------------------------------------ + +context("keras multinomial regression") +source("helpers.R") + +# ------------------------------------------------------------------------------ + +set.seed(352) +dat <- iris[order(runif(nrow(iris))),] + +tr_dat <- dat[1:140, ] +te_dat <- dat[141:150, ] + +# ------------------------------------------------------------------------------ + +basic_mod <- + multinom_reg() %>% + set_engine("keras", epochs = 50, verbose = 0) + +reg_mod <- + multinom_reg(penalty = 0.1) %>% + set_engine("keras", epochs = 50, verbose = 0) + +ctrl <- fit_control(verbosity = 0, catch = FALSE) + +# ------------------------------------------------------------------------------ + +test_that('model fitting', { + + skip_if_not_installed("keras") + + set.seed(257) + expect_error( + fit1 <- + fit_xy( + basic_mod, + control = ctrl, + x = tr_dat[, -5], + y = tr_dat$Species + ), + regexp = NA + ) + + set.seed(257) + expect_error( + fit2 <- + fit_xy( + basic_mod, + control = ctrl, + x = tr_dat[, -5], + y = tr_dat$Species + ), + regexp = NA + ) + expect_equal(fit1, fit2) + + expect_error( + fit( + basic_mod, + Species ~ ., + data = tr_dat, + control = ctrl + ), + regexp = NA + ) + + expect_error( + fit1 <- + fit_xy( + reg_mod, + control = ctrl, + x = tr_dat[, -5], + y = tr_dat$Species + ), + regexp = NA + ) + + expect_error( + fit( + reg_mod, + Species ~ ., + data = tr_dat, + control = ctrl + ), + regexp = NA + ) + +}) + + +test_that('classification prediction', { + + skip_if_not_installed("keras") + + library(keras) + + set.seed(257) + lr_fit <- + fit_xy( + basic_mod, + control = ctrl, + x = tr_dat[, -5], + y = tr_dat$Species + ) + + keras_raw <- + predict(lr_fit$fit, as.matrix(te_dat[, -5])) + keras_pred <- + tibble(.pred_class = apply(keras_raw, 1, which.max)) %>% + mutate(.pred_class = factor(lr_fit$lvl[.pred_class], levels = lr_fit$lvl)) + + parsnip_pred <- predict(lr_fit, te_dat[, -5]) + expect_equal(as.data.frame(keras_pred), as.data.frame(parsnip_pred)) + + set.seed(257) + plrfit <- + fit_xy( + reg_mod, + control = ctrl, + x = tr_dat[, -5], + y = tr_dat$Species + ) + + keras_raw <- + predict(plrfit$fit, as.matrix(te_dat[, -5])) + keras_pred <- + tibble(.pred_class = apply(keras_raw, 1, which.max)) %>% 
+ mutate(.pred_class = factor(plrfit$lvl[.pred_class], levels = plrfit$lvl)) + parsnip_pred <- predict(plrfit, te_dat[, -5]) + expect_equal(as.data.frame(keras_pred), as.data.frame(parsnip_pred)) + +}) + + +test_that('classification probabilities', { + + skip_if_not_installed("keras") + + library(keras) + + set.seed(257) + lr_fit <- + fit_xy( + basic_mod, + control = ctrl, + x = tr_dat[, -5], + y = tr_dat$Species + ) + + keras_pred <- + predict_proba(lr_fit$fit, as.matrix(te_dat[, -5])) %>% + as_tibble() %>% + setNames(paste0(".pred_", lr_fit$lvl)) + + parsnip_pred <- predict(lr_fit, te_dat[, -5], type = "prob") + expect_equal(as.data.frame(keras_pred), as.data.frame(parsnip_pred)) + + set.seed(257) + plrfit <- + fit_xy( + reg_mod, + control = ctrl, + x = tr_dat[, -5], + y = tr_dat$Species + ) + + keras_pred <- + predict_proba(plrfit$fit, as.matrix(te_dat[, -5])) %>% + as_tibble() %>% + setNames(paste0(".pred_", plrfit$lvl)) + parsnip_pred <- predict(plrfit, te_dat[, -5], type = "prob") + expect_equal(as.data.frame(keras_pred), as.data.frame(parsnip_pred)) + +}) + + From 90e15147785b088c45a6a526f2e5a6514f329c84 Mon Sep 17 00:00:00 2001 From: topepo Date: Fri, 26 Oct 2018 15:28:36 -0400 Subject: [PATCH 10/10] missing probability module --- R/mlp_data.R | 23 +++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/R/mlp_data.R b/R/mlp_data.R index 0e4df7a65..4eac89dcc 100644 --- a/R/mlp_data.R +++ b/R/mlp_data.R @@ -72,6 +72,17 @@ mlp_keras_data <- ) ) + +nnet_softmax <- function(results, object) { + if (ncol(results) == 1) + results <- cbind(1 - results, results) + + results <- apply(results, 1, function(x) exp(x)/sum(exp(x))) + results <- as_tibble(t(results)) + names(results) <- paste0(".pred_", object$lvl) + results +} + mlp_nnet_data <- list( libs = "nnet", @@ -103,6 +114,17 @@ mlp_nnet_data <- type = "class" ) ), + classprob = list( + pre = NULL, + post = nnet_softmax, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "raw" + ) + ), raw = list( pre = NULL, func = c(fun = "predict"), args = list( object = quote(object$fit), newdata = quote(new_data) ) ) ) + # ------------------------------------------------------------------------------ # keras wrapper for feed-forward nnet
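To see what the new `classprob` module's post-processor does, here is a small self-contained sketch that applies the function above to a fake fit object (only the `lvl` sub-object is used by the function):

```r
library(tibble)

# nnet_softmax(), as defined in the patch above: it softmax-normalizes each
# row of the raw nnet output and labels the columns with the outcome levels.
# (On recent tibble versions, as_tibble() on an unnamed matrix may need
# .name_repair = "minimal".)
nnet_softmax <- function(results, object) {
  if (ncol(results) == 1)
    results <- cbind(1 - results, results)

  results <- apply(results, 1, function(x) exp(x)/sum(exp(x)))
  results <- as_tibble(t(results))
  names(results) <- paste0(".pred_", object$lvl)
  results
}

# a fake fit object and some made-up raw output for two samples
fake_fit <- list(lvl = c("setosa", "versicolor", "virginica"))
raw <- matrix(c(2.0, 0.5, 0.1,
                0.2, 1.5, 0.3), nrow = 2, byrow = TRUE)

nnet_softmax(raw, fake_fit)
# a 2 x 3 tibble; each row sums to one, with columns .pred_setosa,
# .pred_versicolor, and .pred_virginica
```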