diff --git a/NAMESPACE b/NAMESPACE index d33da2365..d238df397 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -177,6 +177,7 @@ export(tidy) export(translate) export(translate.default) export(update_dot_check) +export(update_engine_parameters) export(update_main_parameters) export(varying) export(varying_args) diff --git a/R/boost_tree.R b/R/boost_tree.R index 77fd30c76..d3f1cc31d 100644 --- a/R/boost_tree.R +++ b/R/boost_tree.R @@ -163,7 +163,8 @@ update.boost_tree <- loss_reduction = NULL, sample_size = NULL, stop_iter = NULL, fresh = FALSE, ...) { - update_dot_check(...) + + eng_args <- update_engine_parameters(object$eng_args, ...) if (!is.null(parameters)) { parameters <- check_final_param(parameters) @@ -185,12 +186,15 @@ update.boost_tree <- # TODO make these blocks into a function and document well if (fresh) { object$args <- args + object$eng_args <- eng_args } else { null_args <- map_lgl(args, null_value) if (any(null_args)) args <- args[!null_args] if (length(args) > 0) object$args[names(args)] <- args + if (length(eng_args) > 0) + object$eng_args[names(eng_args)] <- eng_args } new_model_spec( diff --git a/R/decision_tree.R b/R/decision_tree.R index 4bb0c7271..5c518771f 100644 --- a/R/decision_tree.R +++ b/R/decision_tree.R @@ -116,7 +116,8 @@ update.decision_tree <- parameters = NULL, cost_complexity = NULL, tree_depth = NULL, min_n = NULL, fresh = FALSE, ...) { - update_dot_check(...) + + eng_args <- update_engine_parameters(object$eng_args, ...) if (!is.null(parameters)) { parameters <- check_final_param(parameters) @@ -131,12 +132,15 @@ update.decision_tree <- if (fresh) { object$args <- args + object$eng_args <- eng_args } else { null_args <- map_lgl(args, null_value) if (any(null_args)) args <- args[!null_args] if (length(args) > 0) object$args[names(args)] <- args + if (length(eng_args) > 0) + object$eng_args[names(eng_args)] <- eng_args } new_model_spec( diff --git a/R/linear_reg.R b/R/linear_reg.R index b2e9e7892..619a3c761 100644 --- a/R/linear_reg.R +++ b/R/linear_reg.R @@ -131,7 +131,8 @@ update.linear_reg <- parameters = NULL, penalty = NULL, mixture = NULL, fresh = FALSE, ...) { - update_dot_check(...) + + eng_args <- update_engine_parameters(object$eng_args, ...) if (!is.null(parameters)) { parameters <- check_final_param(parameters) @@ -145,12 +146,15 @@ update.linear_reg <- if (fresh) { object$args <- args + object$eng_args <- eng_args } else { null_args <- map_lgl(args, null_value) if (any(null_args)) args <- args[!null_args] if (length(args) > 0) object$args[names(args)] <- args + if (length(eng_args) > 0) + object$eng_args[names(eng_args)] <- eng_args } new_model_spec( diff --git a/R/logistic_reg.R b/R/logistic_reg.R index 46f379658..049a19685 100644 --- a/R/logistic_reg.R +++ b/R/logistic_reg.R @@ -115,7 +115,8 @@ update.logistic_reg <- parameters = NULL, penalty = NULL, mixture = NULL, fresh = FALSE, ...) { - update_dot_check(...) + + eng_args <- update_engine_parameters(object$eng_args, ...) if (!is.null(parameters)) { parameters <- check_final_param(parameters) @@ -129,12 +130,15 @@ update.logistic_reg <- if (fresh) { object$args <- args + object$eng_args <- eng_args } else { null_args <- map_lgl(args, null_value) if (any(null_args)) args <- args[!null_args] if (length(args) > 0) object$args[names(args)] <- args + if (length(eng_args) > 0) + object$eng_args[names(eng_args)] <- eng_args } new_model_spec( diff --git a/R/mars.R b/R/mars.R index be8fb0e12..1e68efb95 100644 --- a/R/mars.R +++ b/R/mars.R @@ -93,7 +93,8 @@ update.mars <- parameters = NULL, num_terms = NULL, prod_degree = NULL, prune_method = NULL, fresh = FALSE, ...) { - update_dot_check(...) + + eng_args <- update_engine_parameters(object$eng_args, ...) if (!is.null(parameters)) { parameters <- check_final_param(parameters) @@ -109,12 +110,15 @@ update.mars <- if (fresh) { object$args <- args + object$eng_args <- eng_args } else { null_args <- map_lgl(args, null_value) if (any(null_args)) args <- args[!null_args] if (length(args) > 0) object$args[names(args)] <- args + if (length(eng_args) > 0) + object$eng_args[names(eng_args)] <- eng_args } new_model_spec( diff --git a/R/misc.R b/R/misc.R index 546ac1612..b8e4c329d 100644 --- a/R/misc.R +++ b/R/misc.R @@ -171,8 +171,7 @@ names0 <- function (num, prefix = "x") { #' @export #' @keywords internal #' @rdname add_on_exports -update_dot_check <- function(...) { - dots <- enquos(...) +update_dot_check <- function(dots) { if (length(dots) > 0) rlang::abort( glue::glue( @@ -282,5 +281,25 @@ update_main_parameters <- function(args, param) { args <- utils::modifyList(args, param) } +#' @export +#' @keywords internal +#' @rdname add_on_exports +update_engine_parameters <- function(eng_args, ...) { + + dots <- enquos(...) + + ## only update from dots when there are eng args in original model spec + if (is_null(eng_args)) { + ret <- NULL + } else { + ret <- utils::modifyList(eng_args, dots) + } + + has_extra_dots <- !(names(dots) %in% names(eng_args)) + dots <- dots[has_extra_dots] + update_dot_check(dots) + + ret +} diff --git a/R/mlp.R b/R/mlp.R index 2ca3565ac..422e1fcd3 100644 --- a/R/mlp.R +++ b/R/mlp.R @@ -120,7 +120,8 @@ update.mlp <- hidden_units = NULL, penalty = NULL, dropout = NULL, epochs = NULL, activation = NULL, fresh = FALSE, ...) { - update_dot_check(...) + + eng_args <- update_engine_parameters(object$eng_args, ...) if (!is.null(parameters)) { parameters <- check_final_param(parameters) @@ -139,12 +140,15 @@ update.mlp <- # TODO make these blocks into a function and document well if (fresh) { object$args <- args + object$eng_args <- eng_args } else { null_args <- map_lgl(args, null_value) if (any(null_args)) args <- args[!null_args] if (length(args) > 0) object$args[names(args)] <- args + if (length(eng_args) > 0) + object$eng_args[names(eng_args)] <- eng_args } new_model_spec( diff --git a/R/multinom_reg.R b/R/multinom_reg.R index d6245c980..48310ca48 100644 --- a/R/multinom_reg.R +++ b/R/multinom_reg.R @@ -114,7 +114,8 @@ update.multinom_reg <- parameters = NULL, penalty = NULL, mixture = NULL, fresh = FALSE, ...) { - update_dot_check(...) + + eng_args <- update_engine_parameters(object$eng_args, ...) if (!is.null(parameters)) { parameters <- check_final_param(parameters) @@ -128,12 +129,15 @@ update.multinom_reg <- if (fresh) { object$args <- args + object$eng_args <- eng_args } else { null_args <- map_lgl(args, null_value) if (any(null_args)) args <- args[!null_args] if (length(args) > 0) object$args[names(args)] <- args + if (length(eng_args) > 0) + object$eng_args[names(eng_args)] <- eng_args } new_model_spec( diff --git a/R/nearest_neighbor.R b/R/nearest_neighbor.R index 73e7dd81d..6988697a6 100644 --- a/R/nearest_neighbor.R +++ b/R/nearest_neighbor.R @@ -96,7 +96,8 @@ update.nearest_neighbor <- function(object, weight_func = NULL, dist_power = NULL, fresh = FALSE, ...) { - update_dot_check(...) + + eng_args <- update_engine_parameters(object$eng_args, ...) if (!is.null(parameters)) { parameters <- check_final_param(parameters) @@ -112,12 +113,15 @@ update.nearest_neighbor <- function(object, if (fresh) { object$args <- args + object$eng_args <- eng_args } else { null_args <- map_lgl(args, null_value) if (any(null_args)) args <- args[!null_args] if (length(args) > 0) object$args[names(args)] <- args + if (length(eng_args) > 0) + object$eng_args[names(eng_args)] <- eng_args } new_model_spec( diff --git a/R/rand_forest.R b/R/rand_forest.R index d3d1b690a..98177eae5 100644 --- a/R/rand_forest.R +++ b/R/rand_forest.R @@ -110,7 +110,8 @@ update.rand_forest <- parameters = NULL, mtry = NULL, trees = NULL, min_n = NULL, fresh = FALSE, ...) { - update_dot_check(...) + + eng_args <- update_engine_parameters(object$eng_args, ...) if (!is.null(parameters)) { parameters <- check_final_param(parameters) @@ -126,12 +127,15 @@ update.rand_forest <- # TODO make these blocks into a function and document well if (fresh) { object$args <- args + object$eng_args <- eng_args } else { null_args <- map_lgl(args, null_value) if (any(null_args)) args <- args[!null_args] if (length(args) > 0) object$args[names(args)] <- args + if (length(eng_args) > 0) + object$eng_args[names(eng_args)] <- eng_args } new_model_spec( diff --git a/R/surv_reg.R b/R/surv_reg.R index 3a85667dc..cced97abf 100644 --- a/R/surv_reg.R +++ b/R/surv_reg.R @@ -96,7 +96,8 @@ print.surv_reg <- function(x, ...) { #' @rdname surv_reg #' @export update.surv_reg <- function(object, parameters = NULL, dist = NULL, fresh = FALSE, ...) { - update_dot_check(...) + + eng_args <- update_engine_parameters(object$eng_args, ...) if (!is.null(parameters)) { parameters <- check_final_param(parameters) @@ -110,12 +111,15 @@ update.surv_reg <- function(object, parameters = NULL, dist = NULL, fresh = FALS if (fresh) { object$args <- args + object$eng_args <- eng_args } else { null_args <- map_lgl(args, null_value) if (any(null_args)) args <- args[!null_args] if (length(args) > 0) object$args[names(args)] <- args + if (length(eng_args) > 0) + object$eng_args[names(eng_args)] <- eng_args } new_model_spec( diff --git a/R/svm_poly.R b/R/svm_poly.R index b46cb63b8..5edc8c631 100644 --- a/R/svm_poly.R +++ b/R/svm_poly.R @@ -98,7 +98,8 @@ update.svm_poly <- cost = NULL, degree = NULL, scale_factor = NULL, margin = NULL, fresh = FALSE, ...) { - update_dot_check(...) + + eng_args <- update_engine_parameters(object$eng_args, ...) if (!is.null(parameters)) { parameters <- check_final_param(parameters) @@ -115,12 +116,15 @@ update.svm_poly <- if (fresh) { object$args <- args + object$eng_args <- eng_args } else { null_args <- map_lgl(args, null_value) if (any(null_args)) args <- args[!null_args] if (length(args) > 0) object$args[names(args)] <- args + if (length(eng_args) > 0) + object$eng_args[names(eng_args)] <- eng_args } new_model_spec( diff --git a/R/svm_rbf.R b/R/svm_rbf.R index a2568e21e..52503f72a 100644 --- a/R/svm_rbf.R +++ b/R/svm_rbf.R @@ -98,7 +98,8 @@ update.svm_rbf <- cost = NULL, rbf_sigma = NULL, margin = NULL, fresh = FALSE, ...) { - update_dot_check(...) + + eng_args <- update_engine_parameters(object$eng_args, ...) if (!is.null(parameters)) { parameters <- check_final_param(parameters) @@ -114,12 +115,15 @@ update.svm_rbf <- if (fresh) { object$args <- args + object$eng_args <- eng_args } else { null_args <- map_lgl(args, null_value) if (any(null_args)) args <- args[!null_args] if (length(args) > 0) object$args[names(args)] <- args + if (length(eng_args) > 0) + object$eng_args[names(eng_args)] <- eng_args } new_model_spec( diff --git a/man/add_on_exports.Rd b/man/add_on_exports.Rd index a1a47dd74..bbbdfd8db 100644 --- a/man/add_on_exports.Rd +++ b/man/add_on_exports.Rd @@ -7,6 +7,7 @@ \alias{new_model_spec} \alias{check_final_param} \alias{update_main_parameters} +\alias{update_engine_parameters} \alias{is_varying} \title{Functions required for parsnip-adjacent packages} \usage{ @@ -14,7 +15,7 @@ null_value(x) show_fit(model, eng) -update_dot_check(...) +update_dot_check(dots) new_model_spec(cls, args, eng_args, mode, method, engine) @@ -22,6 +23,8 @@ check_final_param(x) update_main_parameters(args, param) +update_engine_parameters(eng_args, ...) + is_varying(x) } \description{ diff --git a/tests/testthat/test_boost_tree.R b/tests/testthat/test_boost_tree.R index 724de2831..f368220d9 100644 --- a/tests/testthat/test_boost_tree.R +++ b/tests/testthat/test_boost_tree.R @@ -109,14 +109,19 @@ test_that('updating', { expr1 <- boost_tree() %>% set_engine("xgboost", verbose = 0) expr1_exp <- boost_tree(trees = 10) %>% set_engine("xgboost", verbose = 0) - expr2 <- boost_tree(trees = varying()) %>% set_engine("xgboost") - expr2_exp <- boost_tree(trees = varying()) %>% set_engine("xgboost", verbose = 0) + expr2 <- boost_tree(trees = varying()) %>% set_engine("C5.0", bands = varying()) + expr2_exp <- boost_tree(trees = varying()) %>% set_engine("C5.0", bands = 10) expr3 <- boost_tree(trees = 1, sample_size = varying()) expr3_exp <- boost_tree(trees = 1) + expr4 <- boost_tree() %>% set_engine("C5.0", noGlobalPruning = varying()) + expr4_exp <- boost_tree() %>% set_engine("C5.0", noGlobalPruning = TRUE) + expect_equal(update(expr1, trees = 10), expr1_exp) + expect_equal(update(expr2, bands = 10), expr2_exp) expect_equal(update(expr3, trees = 1, fresh = TRUE), expr3_exp) + expect_equal(update(expr4, noGlobalPruning = TRUE), expr4_exp) param_tibb <- tibble::tibble(trees = 7, mtry = 1) param_list <- as.list(param_tibb) diff --git a/tests/testthat/test_decision_tree.R b/tests/testthat/test_decision_tree.R index 18720e7af..da356a0a7 100644 --- a/tests/testthat/test_decision_tree.R +++ b/tests/testthat/test_decision_tree.R @@ -99,13 +99,14 @@ test_that('updating', { expr1 <- decision_tree() %>% set_engine("rpart", model = FALSE) expr1_exp <- decision_tree(cost_complexity = .1) %>% set_engine("rpart", model = FALSE) - expr2 <- decision_tree(cost_complexity = varying()) %>% set_engine("rpart") + expr2 <- decision_tree(cost_complexity = varying()) %>% set_engine("rpart", model = varying()) expr2_exp <- decision_tree(cost_complexity = varying()) %>% set_engine("rpart", model = FALSE) expr3 <- decision_tree(cost_complexity = 1, min_n = varying()) expr3_exp <- decision_tree(cost_complexity = 1) expect_equal(update(expr1, cost_complexity = .1), expr1_exp) + expect_equal(update(expr2, model = FALSE), expr2_exp) expect_equal(update(expr3, cost_complexity = 1, fresh = TRUE), expr3_exp) param_tibb <- tibble::tibble(cost_complexity = 0.1, min_n = 1) diff --git a/tests/testthat/test_linear_reg.R b/tests/testthat/test_linear_reg.R index d2574988d..ca304d0ea 100644 --- a/tests/testthat/test_linear_reg.R +++ b/tests/testthat/test_linear_reg.R @@ -166,11 +166,12 @@ test_that('updating', { expr1 <- linear_reg() %>% set_engine("lm", model = FALSE) expr1_exp <- linear_reg(mixture = 0) %>% set_engine("lm", model = FALSE) - expr2 <- linear_reg(mixture = varying()) %>% set_engine("glmnet") - expr2_exp <- linear_reg(mixture = varying()) %>% set_engine("glmnet", nlambda = 10) + expr2 <- linear_reg() %>% set_engine("glmnet", nlambda = varying()) + expr2_exp <- linear_reg() %>% set_engine("glmnet", nlambda = 10) - expr3 <- linear_reg(mixture = 0, penalty = varying()) %>% set_engine("glmnet") - expr3_exp <- linear_reg(mixture = 1) %>% set_engine("glmnet") + expr3 <- linear_reg(mixture = 0, penalty = varying()) %>% set_engine("glmnet", nlambda = varying()) + expr3_exp <- linear_reg(mixture = 0, penalty = varying()) %>% set_engine("glmnet", nlambda = 10) + expr3_fre <- linear_reg(mixture = 1) %>% set_engine("glmnet", nlambda = 10) expr4 <- linear_reg(mixture = 0) %>% set_engine("glmnet", nlambda = 10) expr4_exp <- linear_reg(mixture = 0) %>% set_engine("glmnet", nlambda = 10, pmax = 2) @@ -179,7 +180,9 @@ test_that('updating', { expr5_exp <- linear_reg(mixture = 1) %>% set_engine("glmnet", nlambda = 10, pmax = 2) expect_equal(update(expr1, mixture = 0), expr1_exp) - expect_equal(update(expr3, mixture = 1, fresh = TRUE), expr3_exp) + expect_equal(update(expr2, nlambda = 10), expr2_exp) + expect_equal(update(expr3, mixture = 1, fresh = TRUE, nlambda = 10), expr3_fre) + expect_equal(update(expr3, nlambda = 10), expr3_exp) param_tibb <- tibble::tibble(mixture = 1/3, penalty = 1) param_list <- as.list(param_tibb) diff --git a/tests/testthat/test_logistic_reg.R b/tests/testthat/test_logistic_reg.R index c23ca110e..ae3cca49d 100644 --- a/tests/testthat/test_logistic_reg.R +++ b/tests/testthat/test_logistic_reg.R @@ -181,11 +181,11 @@ test_that('updating', { expr1_exp <- logistic_reg(mixture = 0) %>% set_engine("glm", family = expr(binomial(link = "probit"))) - expr2 <- logistic_reg(mixture = varying()) %>% set_engine("glmnet") + expr2 <- logistic_reg(mixture = varying()) %>% set_engine("glmnet", nlambda = varying()) expr2_exp <- logistic_reg(mixture = varying()) %>% set_engine("glmnet", nlambda = 10) - expr3 <- logistic_reg(mixture = 0, penalty = varying()) - expr3_exp <- logistic_reg(mixture = 1) + expr3 <- logistic_reg(mixture = 0, penalty = varying()) %>% set_engine("glmnet", nlambda = varying()) + expr3_exp <- logistic_reg(mixture = 1) %>% set_engine("glmnet", nlambda = 10) expr4 <- logistic_reg(mixture = 0) %>% set_engine("glmnet", nlambda = 10) expr4_exp <- logistic_reg(mixture = 0) %>% set_engine("glmnet", nlambda = 10, pmax = 2) @@ -194,7 +194,8 @@ test_that('updating', { expr5_exp <- logistic_reg(mixture = 1) %>% set_engine("glmnet", nlambda = 10, pmax = 2) expect_equal(update(expr1, mixture = 0), expr1_exp) - expect_equal(update(expr3, mixture = 1, fresh = TRUE), expr3_exp) + expect_equal(update(expr2, nlambda = 10), expr2_exp) + expect_equal(update(expr3, mixture = 1, fresh = TRUE, nlambda = 10), expr3_exp) param_tibb <- tibble::tibble(mixture = 1/3, penalty = 1) param_list <- as.list(param_tibb) diff --git a/tests/testthat/test_mars.R b/tests/testthat/test_mars.R index a95f95d64..c8e4c3f82 100644 --- a/tests/testthat/test_mars.R +++ b/tests/testthat/test_mars.R @@ -78,11 +78,12 @@ test_that('updating', { expr1 <- mars() %>% set_engine("earth", model = FALSE) expr1_exp <- mars(num_terms = 1) %>% set_engine("earth", model = FALSE) - expr2 <- mars(num_terms = varying()) %>% set_engine("earth") + expr2 <- mars(num_terms = varying()) %>% set_engine("earth", nk = varying()) expr2_exp <- mars(num_terms = varying()) %>% set_engine("earth", nk = 10) - expr3 <- mars(num_terms = 1, prod_degree = varying()) %>% set_engine("earth") - expr3_exp <- mars(num_terms = 1) %>% set_engine("earth") + expr3 <- mars(num_terms = 1, prod_degree = varying()) %>% set_engine("earth", nk = varying()) + expr3_fre <- mars(num_terms = 1) %>% set_engine("earth", nk = varying()) + expr3_exp <- mars(num_terms = 1) %>% set_engine("earth", nk = 10) expr4 <- mars(num_terms = 0) %>% set_engine("earth", nk = 10) expr4_exp <- mars(num_terms = 0) %>% set_engine("earth", nk = 10, trace = 2) @@ -91,7 +92,9 @@ test_that('updating', { expr5_exp <- mars(num_terms = 1) %>% set_engine("earth", nk = 10, trace = 2) expect_equal(update(expr1, num_terms = 1), expr1_exp) - expect_equal(update(expr3, num_terms = 1, fresh = TRUE), expr3_exp) + expect_equal(update(expr2, nk = 10), expr2_exp) + expect_equal(update(expr3, num_terms = 1, fresh = TRUE), expr3_fre) + expect_equal(update(expr3, num_terms = 1, fresh = TRUE, nk = 10), expr3_exp) param_tibb <- tibble::tibble(num_terms = 3, prod_degree = 1) param_list <- as.list(param_tibb) diff --git a/tests/testthat/test_mlp.R b/tests/testthat/test_mlp.R index 80b63e114..92ac55e31 100644 --- a/tests/testthat/test_mlp.R +++ b/tests/testthat/test_mlp.R @@ -131,21 +131,23 @@ test_that('updating', { expr1_exp <- mlp(mode = "regression", hidden_units = 2) %>% set_engine("nnet", Hess = FALSE, abstol = varying()) - expr2 <- mlp(mode = "regression", hidden_units = 7) %>% set_engine("nnet") - expr2_exp <- mlp(mode = "regression", hidden_units = 7) %>% set_engine("nnet", Hess = FALSE) + expr2 <- mlp(mode = "regression") %>% set_engine("nnet", Hess = varying()) + expr2_exp <- mlp(mode = "regression") %>% set_engine("nnet", Hess = FALSE) expr3 <- mlp(mode = "regression", hidden_units = 7, epochs = varying()) %>% set_engine("keras") expr3_exp <- mlp(mode = "regression", hidden_units = 2) %>% set_engine("keras") expr4 <- mlp(mode = "classification", hidden_units = 2) %>% set_engine("nnet", Hess = FALSE, abstol = varying()) - expr4_exp <- mlp(mode = "classification", hidden_units = 2) %>% set_engine("nnet", Hess = FALSE, abstol = varying()) + expr4_exp <- mlp(mode = "classification", hidden_units = 2) %>% set_engine("nnet", Hess = FALSE, abstol = 1e-3) expr5 <- mlp(mode = "classification", hidden_units = 2) %>% set_engine("nnet", Hess = FALSE) expr5_exp <- mlp(mode = "classification", hidden_units = 2) %>% set_engine("nnet", Hess = FALSE, abstol = varying()) expect_equal(update(expr1, hidden_units = 2), expr1_exp) + expect_equal(update(expr2, Hess = FALSE), expr2_exp) expect_equal(update(expr3, hidden_units = 2, fresh = TRUE), expr3_exp) + expect_equal(update(expr4, abstol = 1e-3), expr4_exp) param_tibb <- tibble::tibble(hidden_units = 3, dropout = .1) param_list <- as.list(param_tibb) diff --git a/tests/testthat/test_multinom_reg.R b/tests/testthat/test_multinom_reg.R index fbec18322..a373159b6 100644 --- a/tests/testthat/test_multinom_reg.R +++ b/tests/testthat/test_multinom_reg.R @@ -80,7 +80,7 @@ test_that('updating', { expr1 <- multinom_reg() %>% set_engine("glmnet", intercept = TRUE) expr1_exp <- multinom_reg(mixture = 0) %>% set_engine("glmnet", intercept = TRUE) - expr2 <- multinom_reg(mixture = varying()) %>% set_engine("glmnet") + expr2 <- multinom_reg(mixture = varying()) %>% set_engine("glmnet", nlambda = varying()) expr2_exp <- multinom_reg(mixture = varying()) %>% set_engine("glmnet", nlambda = 10) expr3 <- multinom_reg(mixture = 0, penalty = varying()) %>% set_engine("glmnet") @@ -92,8 +92,8 @@ test_that('updating', { expr5 <- multinom_reg(mixture = 1) %>% set_engine("glmnet", nlambda = 10) expr5_exp <- multinom_reg(mixture = 1) %>% set_engine("glmnet", nlambda = 10, pmax = 2) - # expect_equal(update(expr1 %>% set_engine("glmnet"), mixture = 0), expr1_exp) - expect_equal(update(expr2) %>% set_engine("glmnet", nlambda = 10), expr2_exp) + expect_equal(update(expr1, mixture = 0), expr1_exp) + expect_equal(update(expr2, nlambda = 10), expr2_exp) expect_equal(update(expr3, mixture = 1, fresh = TRUE) %>% set_engine("glmnet"), expr3_exp) # expect_equal(update(expr4 %>% set_engine("glmnet", pmax = 2)), expr4_exp) expect_equal(update(expr5) %>% set_engine("glmnet", nlambda = 10, pmax = 2), expr5_exp) diff --git a/tests/testthat/test_nearest_neighbor.R b/tests/testthat/test_nearest_neighbor.R index 9d67c6d89..9bb06b4a0 100644 --- a/tests/testthat/test_nearest_neighbor.R +++ b/tests/testthat/test_nearest_neighbor.R @@ -84,15 +84,15 @@ test_that('updating', { expr1 <- nearest_neighbor() %>% set_engine("kknn", scale = FALSE) expr1_exp <- nearest_neighbor(neighbors = 5) %>% set_engine("kknn", scale = FALSE) - expr2 <- nearest_neighbor(neighbors = varying()) %>% set_engine("kknn") - expr2_exp <- nearest_neighbor(neighbors = varying(), weight_func = "triangular") %>% set_engine("kknn") + expr2 <- nearest_neighbor(neighbors = varying()) %>% set_engine("kknn", scale = varying()) + expr2_exp <- nearest_neighbor(neighbors = varying(), weight_func = "triangular") %>% set_engine("kknn", scale = FALSE) - expr3 <- nearest_neighbor(neighbors = 2, weight_func = varying()) %>% set_engine("kknn") - expr3_exp <- nearest_neighbor(neighbors = 3) %>% set_engine("kknn") + expr3 <- nearest_neighbor(neighbors = 2, weight_func = varying()) %>% set_engine("kknn", scale = varying()) + expr3_exp <- nearest_neighbor(neighbors = 3) %>% set_engine("kknn", scale = FALSE) - expect_equal(update(expr1, neighbors = 5), expr1_exp) - expect_equal(update(expr2, weight_func = "triangular"), expr2_exp) - expect_equal(update(expr3, neighbors = 3, fresh = TRUE), expr3_exp) + expect_equal(update(expr1, neighbors = 5, scale = FALSE), expr1_exp) + expect_equal(update(expr2, weight_func = "triangular", scale = FALSE), expr2_exp) + expect_equal(update(expr3, neighbors = 3, fresh = TRUE, scale = FALSE), expr3_exp) param_tibb <- tibble::tibble(neighbors = 7, dist_power = 1) param_list <- as.list(param_tibb) diff --git a/tests/testthat/test_rand_forest.R b/tests/testthat/test_rand_forest.R index 2ff65e0c2..c6654967a 100644 --- a/tests/testthat/test_rand_forest.R +++ b/tests/testthat/test_rand_forest.R @@ -307,15 +307,17 @@ test_that('updating', { expr4 <- rand_forest(mode = "regression", mtry = 2) %>% set_engine("randomForest", norm.votes = FALSE, sampsize = varying()) expr4_exp <- rand_forest(mode = "regression", mtry = 2) %>% - set_engine("randomForest", norm.votes = TRUE, sampsize = varying()) + set_engine("randomForest", norm.votes = TRUE, sampsize = 10) - expr5 <- rand_forest(mode = "regression", mtry = 2) %>% - set_engine("randomForest", norm.votes = FALSE) - expr5_exp <- rand_forest(mode = "regression", mtry = 2) %>% - set_engine("randomForest", norm.votes = TRUE, sampsize = varying()) + expr5 <- rand_forest(mode = "regression") %>% + set_engine("randomForest", norm.votes = varying()) + expr5_exp <- rand_forest(mode = "regression") %>% + set_engine("randomForest", norm.votes = TRUE) expect_equal(update(expr1, mtry = 2), expr1_exp) expect_equal(update(expr3, mtry = 2, fresh = TRUE), expr3_exp) + expect_equal(update(expr4, sampsize = 10, norm.votes = TRUE), expr4_exp) + expect_equal(update(expr5, norm.votes = TRUE), expr5_exp) param_tibb <- tibble::tibble(mtry = 3, trees = 10) param_list <- as.list(param_tibb) diff --git a/tests/testthat/test_surv_reg.R b/tests/testthat/test_surv_reg.R index 9cd0c1872..2e5674c9d 100644 --- a/tests/testthat/test_surv_reg.R +++ b/tests/testthat/test_surv_reg.R @@ -60,9 +60,9 @@ test_that('engine arguments', { test_that('updating', { - expr1 <- surv_reg() %>% set_engine("flexsurv", cl = .99) + expr1 <- surv_reg() %>% set_engine("flexsurv", cl = varying()) expr1_exp <- surv_reg(dist = "lnorm") %>% set_engine("flexsurv", cl = .99) - expect_equal(update(expr1, dist = "lnorm"), expr1_exp) + expect_equal(update(expr1, dist = "lnorm", cl = 0.99), expr1_exp) param_tibb <- tibble::tibble(dist = "weibull") param_list <- as.list(param_tibb) diff --git a/tests/testthat/test_svm_poly.R b/tests/testthat/test_svm_poly.R index f22caac9d..fef562ccd 100644 --- a/tests/testthat/test_svm_poly.R +++ b/tests/testthat/test_svm_poly.R @@ -80,14 +80,14 @@ test_that('updating', { expr1 <- svm_poly(mode = "regression") %>% set_engine("kernlab", cross = 10) expr1_exp <- svm_poly(mode = "regression", degree = 1) %>% set_engine("kernlab", cross = 10) - expr2 <- svm_poly(mode = "regression", degree = varying()) %>% set_engine("kernlab") - expr2_exp <- svm_poly(mode = "regression", degree = varying(), scale_factor = 1) %>% set_engine("kernlab") + expr2 <- svm_poly(mode = "regression", degree = varying()) %>% set_engine("kernlab", cross = varying()) + expr2_exp <- svm_poly(mode = "regression", degree = varying(), scale_factor = 1) %>% set_engine("kernlab", cross = 10) expr3 <- svm_poly(mode = "regression", degree = 2, scale_factor = varying()) %>% set_engine("kernlab") expr3_exp <- svm_poly(mode = "regression", degree = 3) %>% set_engine("kernlab") expect_equal(update(expr1, degree = 1), expr1_exp) - expect_equal(update(expr2, scale_factor = 1), expr2_exp) + expect_equal(update(expr2, scale_factor = 1, cross = 10), expr2_exp) expect_equal(update(expr3, degree = 3, fresh = TRUE), expr3_exp) param_tibb <- tibble::tibble(degree = 3, cost = 10) diff --git a/tests/testthat/test_svm_rbf.R b/tests/testthat/test_svm_rbf.R index 663c5d0e2..059754fbf 100644 --- a/tests/testthat/test_svm_rbf.R +++ b/tests/testthat/test_svm_rbf.R @@ -62,11 +62,13 @@ test_that('updating', { expr1 <- svm_rbf(mode = "regression") %>% set_engine("kernlab", cross = 10) expr1_exp <- svm_rbf(mode = "regression", rbf_sigma = .1) %>% set_engine("kernlab", cross = 10) - + expr2 <- svm_rbf(mode = "regression") %>% set_engine("kernlab", cross = varying()) + expr2_exp <- svm_rbf(mode = "regression") %>% set_engine("kernlab", cross = 10) expr3 <- svm_rbf(mode = "regression", rbf_sigma = .2) %>% set_engine("kernlab") expr3_exp <- svm_rbf(mode = "regression", rbf_sigma = .3) %>% set_engine("kernlab") expect_equal(update(expr1, rbf_sigma = .1), expr1_exp) + expect_equal(update(expr2, cross = 10), expr2_exp) expect_equal(update(expr3, rbf_sigma = .3, fresh = TRUE), expr3_exp) param_tibb <- tibble::tibble(rbf_sigma = 3, cost = 10)