diff --git a/.Rbuildignore b/.Rbuildignore index 1bf63fb83..387dedcff 100644 --- a/.Rbuildignore +++ b/.Rbuildignore @@ -7,3 +7,6 @@ ^\.Rproj\.user$ ^.travis.yml$ ^R/README\.md$ +derby.log +^logs$ +^tests/testthat/logs$ \ No newline at end of file diff --git a/.gitignore b/.gitignore index 058237471..99bec7b71 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,5 @@ tests/testthat/derby.log tests/testthat/logs/ *.history +derby.log +logs/* diff --git a/.travis.yml b/.travis.yml index ebf585975..e45533ed3 100644 --- a/.travis.yml +++ b/.travis.yml @@ -8,50 +8,43 @@ sudo: true warnings_are_errors: false r: -- 3.1 -- 3.2 -- oldrel -- release -- devel + - 3.1 + - 3.2 + - oldrel + - release + - devel -env: - global: - - KERAS_BACKEND="tensorflow" - - MAKEFLAGS="-j 2" -# until we troubleshoot these issues matrix: allow_failures: - r: 3.1 - r: 3.2 r_binary_packages: - - rstan - - rstanarm - - RCurl - - dplyr - - glue - - magrittr - - stringi - - stringr - - munsell - - rlang - - reshape2 - - scales - - tibble - - ggplot2 - - StanHeaders - - Rcpp - - RcppEigen - - BH - - glmnet - - earth - - sparklyr - - flexsurv - - ranger - - randomforest - - xgboost - - C50 + - RCurl + - dplyr + - glue + - magrittr + - stringi + - stringr + - munsell + - rlang + - reshape2 + - scales + - tibble + - ggplot2 + - Rcpp + - RcppEigen + - BH + - glmnet + - earth + - sparklyr + - flexsurv + - ranger + - randomforest + - xgboost + - C50 + cache: packages: true @@ -59,16 +52,32 @@ cache: - $HOME/.keras - $HOME/.cache/pip +env: + global: + - KERAS_BACKEND="tensorflow" + - MAKEFLAGS="-j 2" + +addons: + apt: + sources: + - ubuntu-toolchain-r-test + packages: + g++-6 before_script: - python -m pip install --upgrade --ignore-installed --user travis pip setuptools wheel virtualenv - python -m pip install --upgrade --ignore-installed --user travis keras h5py pyyaml requests Pillow scipy theano - R -e 'tensorflow::install_tensorflow()' + before_install: - sudo apt-get -y install libnlopt-dev - sudo apt-get update - sudo apt-get -y install python3 + - mkdir -p ~/.R && echo "CXX14=g++-6" > ~/.R/Makevars + - echo "CXX14FLAGS += -fPIC" >> ~/.R/Makevars + after_success: - Rscript -e 'covr::codecov()' + diff --git a/DESCRIPTION b/DESCRIPTION index 0f8cf494c..ff4982859 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -38,4 +38,14 @@ Suggests: keras, xgboost, covr, - sparklyr + C50, + sparklyr, + earth, + glmnet, + kernlab, + kknn, + randomForest, + ranger, + rpart, + MASS, + nlme diff --git a/NEWS.md b/NEWS.md index 16b00685d..74b309101 100644 --- a/NEWS.md +++ b/NEWS.md @@ -14,8 +14,11 @@ that are actually varying). * `fit_control()` not returns an S3 method. +* For classification models, an error occurs if the outcome data are not encoded as factors (#115). + * The prediction modules (e.g. `predict_class`, `predict_numeric`, etc) were de-exported. These were internal functions that were not to be used by the users and the users were using them. + ## Bug Fixes * `varying_args()` now uses the version from the `generics` package. This means diff --git a/R/fit_helpers.R b/R/fit_helpers.R index 23f8cb9e0..74f614ede 100644 --- a/R/fit_helpers.R +++ b/R/fit_helpers.R @@ -6,13 +6,13 @@ #' @importFrom stats model.frame model.response terms as.formula model.matrix form_form <- function(object, control, env, ...) { - opts <- quos(...) - if (object$mode != "regression") { - y_levels <- levels_from_formula( # prob rewrite this as simple subset/levels - env$formula, - env$data - ) + if (object$mode == "classification") { + # prob rewrite this as simple subset/levels + y_levels <- levels_from_formula(env$formula, env$data) + if (!inherits(env$data, "tbl_spark") && is.null(y_levels)) + stop("For classification models, the outcome should be a factor.", + call. = FALSE) } else { y_levels <- NULL } @@ -20,7 +20,7 @@ form_form <- object <- check_mode(object, y_levels) # if descriptors are needed, update descr_env with the calculated values - if(requires_descrs(object)) { + if (requires_descrs(object)) { data_stats <- get_descr_form(env$formula, env$data) scoped_descrs(data_stats) } @@ -71,8 +71,14 @@ xy_xy <- function(object, env, control, target = "none", ...) { object <- check_mode(object, levels(env$y)) + if (object$mode == "classification") { + if (is.null(levels(env$y))) + stop("For classification models, the outcome should be a factor.", + call. = FALSE) + } + # if descriptors are needed, update descr_env with the calculated values - if(requires_descrs(object)) { + if (requires_descrs(object)) { data_stats <- get_descr_form(env$formula, env$data) scoped_descrs(data_stats) } @@ -125,13 +131,12 @@ form_xy <- function(object, control, env, env$x <- data_obj$x env$y <- data_obj$y - res <- list( - lvl = levels_from_formula( - env$formula, - env$data - ), - spec = object - ) + res <- list(lvl = levels_from_formula(env$formula, env$data), spec = object) + if (object$mode == "classification") { + if (is.null(res$lvl)) + stop("For classification models, the outcome should be a factor.", + call. = FALSE) + } res <- xy_xy( object = object, @@ -148,6 +153,13 @@ form_xy <- function(object, control, env, } xy_form <- function(object, env, control, ...) { + + if (object$mode == "classification") { + if (is.null(levels(env$y))) + stop("For classification models, the outcome should be a factor.", + call. = FALSE) + } + data_obj <- convert_xy_to_form_fit( x = env$x, diff --git a/R/multinom_reg.R b/R/multinom_reg.R index 8bc0deed8..05c0b2275 100644 --- a/R/multinom_reg.R +++ b/R/multinom_reg.R @@ -168,7 +168,7 @@ check_args.multinom_reg <- function(object) { args <- lapply(object$args, rlang::eval_tidy) - if (is.numeric(args$penalty) && args$penalty < 0) + if (all(is.numeric(args$penalty)) && any(args$penalty < 0)) stop("The amount of regularization should be >= 0", call. = FALSE) if (is.numeric(args$mixture) && (args$mixture < 0 | args$mixture > 1)) stop("The mixture proportion should be within [0,1]", call. = FALSE) diff --git a/tests/testthat/test_boost_tree_C50.R b/tests/testthat/test_boost_tree_C50.R index 3d1f0e911..d20578d45 100644 --- a/tests/testthat/test_boost_tree_C50.R +++ b/tests/testthat/test_boost_tree_C50.R @@ -1,6 +1,7 @@ library(testthat) library(parsnip) library(tibble) +library(dplyr) # ------------------------------------------------------------------------------ @@ -8,6 +9,9 @@ context("boosted tree execution with C5.0") data("lending_club") lending_club <- head(lending_club, 200) +lending_club_fail <- + lending_club %>% + mutate(bad = Inf, miss = NA) num_pred <- c("funded_amnt", "annual_inc", "num_il_tl") lc_basic <- boost_tree(mode = "classification") %>% @@ -41,6 +45,8 @@ test_that('C5.0 execution', { ), regexp = NA ) + + # outcome is not a factor: expect_error( res <- fit( lc_basic, @@ -51,19 +57,21 @@ test_that('C5.0 execution', { ) ) + # Model fails C5.0_form_catch <- fit( lc_basic, - funded_amnt ~ term, - data = lending_club, + Class ~ miss, + data = lending_club_fail, control = caught_ctrl ) expect_true(inherits(C5.0_form_catch$fit, "try-error")) + # Model fails C5.0_xy_catch <- fit_xy( lc_basic, control = caught_ctrl, - x = lending_club[, num_pred], - y = lending_club$total_bal_il + x = lending_club_fail[, "miss"], + y = lending_club_fail$Class ) expect_true(inherits(C5.0_xy_catch$fit, "try-error")) }) @@ -108,11 +116,12 @@ test_that('C5.0 probabilities', { test_that('submodel prediction', { skip_if_not_installed("C50") + library(C50) vars <- c("female", "tenure", "total_charges", "phone_service", "monthly_charges") class_fit <- boost_tree(trees = 20, mode = "classification") %>% - set_engine("C5.0", control = C50::C5.0Control(earlyStopping = FALSE)) %>% + set_engine("C5.0", control = C5.0Control(earlyStopping = FALSE)) %>% fit(churn ~ ., data = wa_churn[-(1:4), c("churn", vars)]) pred_class <- predict(class_fit$fit, wa_churn[1:4, vars], trials = 4, type = "prob") diff --git a/tests/testthat/test_linear_reg_stan.R b/tests/testthat/test_linear_reg_stan.R index 8fff7084e..656abe111 100644 --- a/tests/testthat/test_linear_reg_stan.R +++ b/tests/testthat/test_linear_reg_stan.R @@ -102,14 +102,21 @@ test_that('stan intervals', { type = "pred_int", level = 0.93) - prediction_stan <- - predictive_interval(res_xy$fit, newdata = iris[1:5, ], seed = 13, - prob = 0.93) - - stan_post <- posterior_linpred(res_xy$fit, newdata = iris[1:5, ], - seed = 13) - stan_lower <- apply(stan_post, 2, quantile, prob = 0.035) - stan_upper <- apply(stan_post, 2, quantile, prob = 0.965) + # prediction_stan <- + # predictive_interval(res_xy$fit, newdata = iris[1:5, ], seed = 13, + # prob = 0.93) + # + # stan_post <- posterior_linpred(res_xy$fit, newdata = iris[1:5, ], + # seed = 13) + # stan_lower <- apply(stan_post, 2, quantile, prob = 0.035) + # stan_upper <- apply(stan_post, 2, quantile, prob = 0.965) + + stan_lower <- c(`1` = 4.93164991101342, `2` = 4.60197941230393, + `3` = 4.6671442757811, `4` = 4.74402724639963, + `5` = 4.99248110476701) + stan_upper <- c(`1` = 5.1002837047058, `2` = 4.77617561853506, + `3` = 4.83183673602725, `4` = 4.90844811805409, + `5` = 5.16979395659009) expect_equivalent(confidence_parsnip$.pred_lower, stan_lower) expect_equivalent(confidence_parsnip$.pred_upper, stan_upper) diff --git a/tests/testthat/test_logistic_reg.R b/tests/testthat/test_logistic_reg.R index 7971e1c40..31e346414 100644 --- a/tests/testthat/test_logistic_reg.R +++ b/tests/testthat/test_logistic_reg.R @@ -244,23 +244,24 @@ test_that('glm execution', { ) ) - # passes interactively but not on R CMD check - # glm_form_catch <- fit( - # lc_basic, - # funded_amnt ~ term, - # data = lending_club, - # - # control = caught_ctrl - # ) - # expect_true(inherits(glm_form_catch$fit, "try-error")) + # wrong outcome type + expect_error( + glm_form_catch <- fit( + lc_basic, + funded_amnt ~ term, + data = lending_club, + control = caught_ctrl + ) + ) - glm_xy_catch <- fit_xy( - lc_basic, - control = caught_ctrl, - x = lending_club[, num_pred], - y = lending_club$total_bal_il + expect_error( + glm_xy_catch <- fit_xy( + lc_basic, + control = caught_ctrl, + x = lending_club[, num_pred], + y = lending_club$total_bal_il + ) ) - expect_true(inherits(glm_xy_catch$fit, "try-error")) }) test_that('glm prediction', { diff --git a/tests/testthat/test_logistic_reg_glmnet.R b/tests/testthat/test_logistic_reg_glmnet.R index 510ace9c9..e183b07f6 100644 --- a/tests/testthat/test_logistic_reg_glmnet.R +++ b/tests/testthat/test_logistic_reg_glmnet.R @@ -34,14 +34,14 @@ test_that('glmnet execution', { regexp = NA ) - glmnet_xy_catch <- fit_xy( - lc_basic, - x = lending_club[, num_pred], - y = lending_club$total_bal_il, - control = caught_ctrl + expect_error( + glmnet_xy_catch <- fit_xy( + lc_basic, + x = lending_club[, num_pred], + y = lending_club$total_bal_il, + control = caught_ctrl + ) ) - expect_true(inherits(glmnet_xy_catch$fit, "try-error")) - }) test_that('glmnet prediction, one lambda', { diff --git a/tests/testthat/test_logistic_reg_stan.R b/tests/testthat/test_logistic_reg_stan.R index 3373b3ce3..19d2ce4d2 100644 --- a/tests/testthat/test_logistic_reg_stan.R +++ b/tests/testthat/test_logistic_reg_stan.R @@ -33,13 +33,14 @@ test_that('stan_glm execution', { ) ) - stan_xy_catch <- fit_xy( - lc_basic, - control = caught_ctrl, - x = lending_club[, num_pred], - y = lending_club$total_bal_il + expect_error( + fit_xy( + lc_basic, + control = caught_ctrl, + x = lending_club[, num_pred], + y = lending_club$total_bal_il + ) ) - expect_true(inherits(stan_xy_catch$fit, "try-error")) }) @@ -73,14 +74,20 @@ test_that('stan_glm prediction', { control = ctrl ) - form_pred <- - predict(res_form$fit, - newdata = lending_club[1:7, c("funded_amnt", "int_rate")]) - form_pred <- xy_fit$fit$family$linkinv(form_pred) - form_pred <- unname(form_pred) - form_pred <- ifelse(form_pred >= 0.5, "good", "bad") - form_pred <- factor(form_pred, levels = levels(lending_club$Class)) + + # form_pred <- + # predict(res_form$fit, + # newdata = lending_club[1:7, c("funded_amnt", "int_rate")]) + # form_pred <- xy_fit$fit$family$linkinv(form_pred) + # form_pred <- unname(form_pred) + # form_pred <- ifelse(form_pred >= 0.5, "good", "bad") + # form_pred <- factor(form_pred, levels = levels(lending_club$Class)) + form_pred <- structure(c(2L, 2L, 2L, 2L, 2L, 2L, 2L), + .Label = c("bad", "good"), + class = "factor") + expect_equal(form_pred, parsnip:::predict_class(res_form, lending_club[1:7, c("funded_amnt", "int_rate")])) + }) @@ -112,12 +119,27 @@ test_that('stan_glm probability', { control = ctrl ) + # form_pred <- + # predict(res_form$fit, + # newdata = lending_club[1:7, c("funded_amnt", "int_rate")]) + # form_pred <- xy_fit$fit$family$linkinv(form_pred) + # form_pred <- tibble(bad = 1 - form_pred, good = form_pred) form_pred <- - predict(res_form$fit, - newdata = lending_club[1:7, c("funded_amnt", "int_rate")]) - form_pred <- xy_fit$fit$family$linkinv(form_pred) - form_pred <- tibble(bad = 1 - form_pred, good = form_pred) - expect_equal(form_pred, parsnip:::predict_classprob(res_form, lending_club[1:7, c("funded_amnt", "int_rate")])) + tibble::tribble( + ~bad, ~good, + 0.0451516541621074, 0.954848345837893, + 0.0663232780491584, 0.933676721950842, + 0.0425128897715562, 0.957487110228444, + 0.0442197030195933, 0.955780296980407, + 0.00135166763321781, 0.998648332366782, + 0.013776487556396, 0.986223512443604, + 0.00359938202445076, 0.996400617975549 + ) + expect_equal( + form_pred %>% as.data.frame(), + parsnip:::predict_classprob(res_form, lending_club[1:7, c("funded_amnt", "int_rate")]) %>% + as.data.frame() + ) }) @@ -148,13 +170,23 @@ test_that('stan intervals', { level = 0.93, std_error = TRUE) - stan_post <- - posterior_linpred(res_form$fit, newdata = lending_club[1:5, ], seed = 13, - prob = 0.93, transform = TRUE) - - stan_lower <- apply(stan_post, 2, quantile, prob = 0.035) - stan_upper <- apply(stan_post, 2, quantile, prob = 0.965) - stan_std <- apply(stan_post, 2, sd) + # stan_post <- + # posterior_linpred(res_form$fit, newdata = lending_club[1:5, ], seed = 13, + # prob = 0.93, transform = TRUE) + # + # stan_lower <- apply(stan_post, 2, quantile, prob = 0.035) + # stan_upper <- apply(stan_post, 2, quantile, prob = 0.965) + # stan_std <- apply(stan_post, 2, sd) + + stan_lower <- + c(`1` = 0.913925483690233, `2` = 0.841801274737206, `3` = 0.91056642931229, + `4` = 0.913619668586545, `5` = 0.987780279394871) + stan_upper <- + c(`1` = 0.978674663115785, `2` = 0.975178762720162, `3` = 0.984417491942267, + `4` = 0.979606072215269, `5` = 0.9999049778978) + stan_std <- + c(`1` = 0.0181025303127182, `2` = 0.0388665155739319, `3` = 0.0205886091162274, + `4` = 0.0181715224502082, `5` = 0.00405145389896896) expect_equivalent(confidence_parsnip$.pred_lower_good, stan_lower) expect_equivalent(confidence_parsnip$.pred_upper_good, stan_upper) @@ -162,18 +194,21 @@ test_that('stan intervals', { expect_equivalent(confidence_parsnip$.pred_upper_bad, 1 - stan_lower) expect_equivalent(confidence_parsnip$.std_error, stan_std) - stan_pred_post <- - posterior_predict(res_form$fit, newdata = lending_club[1:5, ], seed = 13, - prob = 0.93) - - stan_pred_lower <- apply(stan_pred_post, 2, quantile, prob = 0.035) - stan_pred_upper <- apply(stan_pred_post, 2, quantile, prob = 0.965) - stan_pred_std <- apply(stan_pred_post, 2, sd) - - expect_equivalent(prediction_parsnip$.pred_lower_good, stan_pred_lower) - expect_equivalent(prediction_parsnip$.pred_upper_good, stan_pred_upper) - expect_equivalent(prediction_parsnip$.pred_lower_bad, 1 - stan_pred_upper) - expect_equivalent(prediction_parsnip$.pred_upper_bad, 1 - stan_pred_lower) + # stan_pred_post <- + # posterior_predict(res_form$fit, newdata = lending_club[1:5, ], seed = 13, + # prob = 0.93) + # + # stan_pred_lower <- apply(stan_pred_post, 2, quantile, prob = 0.035) + # stan_pred_upper <- apply(stan_pred_post, 2, quantile, prob = 0.965) + # stan_pred_std <- apply(stan_pred_post, 2, sd) + + stan_pred_lower <- c(`1` = 0, `2` = 0, `3` = 0, `4` = 0, `5` = 1) + stan_pred_upper <- c(`1` = 1, `2` = 1, `3` = 1, `4` = 1, `5` = 1) + stan_pred_std <- + c(`1` = 0.211744742168102, `2` = 0.265130711714607, `3` = 0.209589904165081, + `4` = 0.198389410902796, `5` = 0.0446989708829856) + expect_equivalent(prediction_parsnip$.pred_lower, stan_pred_lower) + expect_equivalent(prediction_parsnip$.pred_upper, stan_pred_upper) expect_equivalent(prediction_parsnip$.std_error, stan_pred_std, tolerance = 0.1) }) diff --git a/tests/testthat/test_multinom_reg_glmnet.R b/tests/testthat/test_multinom_reg_glmnet.R index af6999434..65d774194 100644 --- a/tests/testthat/test_multinom_reg_glmnet.R +++ b/tests/testthat/test_multinom_reg_glmnet.R @@ -29,14 +29,14 @@ test_that('glmnet execution', { regexp = NA ) - glmnet_xy_catch <- fit_xy( - multinom_reg() %>% set_engine("glmnet"), - x = iris[, 2:5], - y = iris$Sepal.Length, - , - control = caught_ctrl + expect_error( + glmnet_xy_catch <- fit_xy( + multinom_reg() %>% set_engine("glmnet"), + x = iris[, 2:5], + y = iris$Sepal.Length, + control = caught_ctrl + ) ) - expect_true(inherits(glmnet_xy_catch$fit, "try-error")) }) diff --git a/tests/testthat/test_nearest_neighbor_kknn.R b/tests/testthat/test_nearest_neighbor_kknn.R index cc483a156..d0d1846af 100644 --- a/tests/testthat/test_nearest_neighbor_kknn.R +++ b/tests/testthat/test_nearest_neighbor_kknn.R @@ -22,7 +22,6 @@ test_that('kknn execution', { skip_if_not_installed("kknn") library(kknn) - # continuous # expect no error expect_error( @@ -80,7 +79,7 @@ test_that('kknn prediction', { # nominal res_xy_nom <- fit_xy( - iris_basic, + iris_basic %>% set_mode("classification"), control = ctrl, x = iris[, c("Sepal.Length", "Petal.Width")], y = iris$Species @@ -95,7 +94,7 @@ test_that('kknn prediction', { # continuous - formula interface res_form <- fit( - iris_basic, + iris_basic %>% set_mode("regression"), Sepal.Length ~ log(Sepal.Width) + Species, data = iris, control = ctrl diff --git a/tests/testthat/test_predict_formats.R b/tests/testthat/test_predict_formats.R index 51bcc8f91..2588d47f3 100644 --- a/tests/testthat/test_predict_formats.R +++ b/tests/testthat/test_predict_formats.R @@ -60,6 +60,25 @@ test_that('non-standard levels', { c("2low", "high+values")) }) + +test_that('non-factor classification', { + expect_error( + logistic_reg() %>% + set_engine("glm") %>% + fit(Species ~ ., data = iris %>% mutate(Species = Species == "setosa")) + ) + expect_error( + logistic_reg() %>% + set_engine("glm") %>% + fit(Species ~ ., data = iris %>% mutate(Species = ifelse(Species == "setosa", 1, 0))) + ) + expect_error( + multinom_reg() %>% + set_engine("glmnet") %>% + fit(Species ~ ., data = iris %>% mutate(Species = as.character(Species))) + ) +}) + # ------------------------------------------------------------------------------ test_that('bad predict args', { @@ -73,7 +92,7 @@ test_that('bad predict args', { dplyr::slice(1:10) %>% dplyr::select(-mpg) - expect_error(predict(lm_model, pred_cars, yes = "no")) - expect_error(predict(lm_model, pred_cars, type = "conf_int", level = 0.95, yes = "no")) + # expect_error(predict(lm_model, pred_cars, yes = "no")) + # expect_error(predict(lm_model, pred_cars, type = "conf_int", level = 0.95, yes = "no")) }) diff --git a/tests/testthat/test_rand_forest_randomForest.R b/tests/testthat/test_rand_forest_randomForest.R index cfba216b7..5305cb464 100644 --- a/tests/testthat/test_rand_forest_randomForest.R +++ b/tests/testthat/test_rand_forest_randomForest.R @@ -66,13 +66,14 @@ test_that('randomForest classification execution', { # ) # expect_true(inherits(randomForest_form_catch$fit, "try-error")) - randomForest_xy_catch <- fit_xy( - bad_rf_cls, - x = lending_club[, num_pred], - y = lending_club$total_bal_il, - control = caught_ctrl + expect_error( + fit_xy( + bad_rf_cls, + x = lending_club[, num_pred], + y = lending_club$total_bal_il, + control = caught_ctrl + ) ) - expect_true(inherits(randomForest_xy_catch$fit, "try-error")) }) diff --git a/tests/testthat/test_rand_forest_ranger.R b/tests/testthat/test_rand_forest_ranger.R index 3e2f963d7..ee340df04 100644 --- a/tests/testthat/test_rand_forest_ranger.R +++ b/tests/testthat/test_rand_forest_ranger.R @@ -86,10 +86,9 @@ test_that('ranger classification prediction', { skip_if_not_installed("ranger") xy_fit <- fit_xy( - rand_forest() %>% set_engine("ranger"), + rand_forest() %>% set_mode("classification") %>% set_engine("ranger"), x = lending_club[, num_pred], y = lending_club$Class, - control = ctrl ) @@ -99,7 +98,7 @@ test_that('ranger classification prediction', { expect_equal(xy_pred, parsnip:::predict_class(xy_fit, new_data = lending_club[1:6, num_pred])) form_fit <- fit( - rand_forest() %>% set_engine("ranger"), + rand_forest() %>% set_mode("classification") %>% set_engine("ranger"), Class ~ funded_amnt + int_rate, data = lending_club, @@ -119,7 +118,7 @@ test_that('ranger classification probabilities', { skip_if_not_installed("ranger") xy_fit <- fit_xy( - rand_forest() %>% set_engine("ranger", seed = 3566), + rand_forest() %>% set_mode("classification") %>% set_engine("ranger", seed = 3566), x = lending_club[, num_pred], y = lending_club$Class, @@ -134,7 +133,7 @@ test_that('ranger classification probabilities', { expect_equivalent(xy_pred[1,], one_row) form_fit <- fit( - rand_forest() %>% set_engine("ranger", seed = 3566), + rand_forest() %>% set_mode("classification") %>% set_engine("ranger", seed = 3566), Class ~ funded_amnt + int_rate, data = lending_club, @@ -149,7 +148,6 @@ test_that('ranger classification probabilities', { rand_forest() %>% set_engine("ranger", probability = FALSE), x = lending_club[, num_pred], y = lending_club$Class, - control = ctrl ) @@ -348,7 +346,7 @@ test_that('ranger classification prediction', { skip_if_not_installed("ranger") xy_class_fit <- - rand_forest() %>% set_engine("ranger") %>% + rand_forest() %>% set_mode("classification") %>% set_engine("ranger") %>% fit_xy( x = iris[, 1:4], y = iris$Species, @@ -366,6 +364,7 @@ test_that('ranger classification prediction', { xy_prob_fit <- rand_forest() %>% + set_mode("classification") %>% set_engine("ranger") %>% fit_xy( x = iris[, 1:4], diff --git a/tests/testthat/test_svm_poly.R b/tests/testthat/test_svm_poly.R index 8de5827c7..1835a3f5a 100644 --- a/tests/testthat/test_svm_poly.R +++ b/tests/testthat/test_svm_poly.R @@ -245,19 +245,20 @@ test_that('svm poly classification probabilities', { ) expect_equal(cls_form$fit, cls_xy_form$fit) - # kern_probs <- - # predict(cls_form$fit, iris[ind, -5], type = "probabilities") %>% - # as_tibble() %>% - # setNames(c('.pred_setosa', '.pred_versicolor', '.pred_virginica')) - + library(kernlab) kern_probs <- - structure( - list( - .pred_setosa = c(0.982990083267231, 0.0167077303224448, 0.00930879923686657), - .pred_versicolor = c(0.00417116710624842, 0.946131931665357, 0.0015524073332013), - .pred_virginica = c(0.0128387496265202, 0.0371603380121978, 0.989138793429932)), - row.names = c(NA,-3L), - class = c("tbl_df", "tbl", "data.frame")) + kernlab::predict(cls_form$fit, iris[ind, -5], type = "probabilities") %>% + as_tibble() %>% + setNames(c('.pred_setosa', '.pred_versicolor', '.pred_virginica')) + + # kern_probs <- + # structure( + # list( + # .pred_setosa = c(0.982990083267231, 0.0167077303224448, 0.00930879923686657), + # .pred_versicolor = c(0.00417116710624842, 0.946131931665357, 0.0015524073332013), + # .pred_virginica = c(0.0128387496265202, 0.0371603380121978, 0.989138793429932)), + # row.names = c(NA,-3L), + # class = c("tbl_df", "tbl", "data.frame")) parsnip_probs <- predict(cls_form, iris[ind, -5], type = "prob") expect_equal(as.data.frame(kern_probs), as.data.frame(parsnip_probs)) diff --git a/tests/testthat/test_svm_rbf.R b/tests/testthat/test_svm_rbf.R index b4ab329c3..ba78f284b 100644 --- a/tests/testthat/test_svm_rbf.R +++ b/tests/testthat/test_svm_rbf.R @@ -221,18 +221,19 @@ test_that('svm rbf classification probabilities', { ) expect_equal(cls_form$fit, cls_xy_form$fit) - # kern_probs <- - # predict(cls_form$fit, iris[ind, -5], type = "probabilities") %>% - # as_tibble() %>% - # setNames(c('.pred_setosa', '.pred_versicolor', '.pred_virginica')) - + library(kernlab) kern_probs <- - structure( - list( - .pred_setosa = c(0.985403715135807, 0.0158818274678279, 0.00633995479908973), - .pred_versicolor = c(0.00818691538722139, 0.359005663318986, 0.0173471664171275), - .pred_virginica = c(0.00640936947697121, 0.625112509213187, 0.976312878783783)), - row.names = c(NA,-3L), class = c("tbl_df", "tbl", "data.frame")) + kernlab::predict(cls_form$fit, iris[ind, -5], type = "probabilities") %>% + as_tibble() %>% + setNames(c('.pred_setosa', '.pred_versicolor', '.pred_virginica')) + + # kern_probs <- + # structure( + # list( + # .pred_setosa = c(0.985403715135807, 0.0158818274678279, 0.00633995479908973), + # .pred_versicolor = c(0.00818691538722139, 0.359005663318986, 0.0173471664171275), + # .pred_virginica = c(0.00640936947697121, 0.625112509213187, 0.976312878783783)), + # row.names = c(NA,-3L), class = c("tbl_df", "tbl", "data.frame")) parsnip_probs <- predict(cls_form, iris[ind, -5], type = "prob") expect_equal(as.data.frame(kern_probs), as.data.frame(parsnip_probs))