From 3367fab47cb8622f67b8e94f4423f2f0326c98db Mon Sep 17 00:00:00 2001 From: topepo Date: Sun, 5 May 2019 14:01:26 -0400 Subject: [PATCH 01/64] prototype registration methods for #167 --- NAMESPACE | 15 +++ R/aaa_models.R | 257 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 272 insertions(+) create mode 100644 R/aaa_models.R diff --git a/NAMESPACE b/NAMESPACE index f24856ce0..92870a2de 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -78,13 +78,19 @@ export(.y) export(C5.0_train) export(add_rowindex) export(boost_tree) +export(check_arg_val) export(check_empty_ellipse) +export(check_engine_val) +export(check_func_val) +export(check_mod_val) +export(check_mode_val) export(decision_tree) export(fit) export(fit.model_spec) export(fit_control) export(fit_xy) export(fit_xy.model_spec) +export(get_model_env) export(keras_mlp) export(linear_reg) export(logistic_reg) @@ -99,16 +105,25 @@ export(null_model) export(nullmodel) export(predict.model_fit) export(rand_forest) +export(register_dependency) +export(register_fit) +export(register_model_arg) +export(register_model_engine) +export(register_model_mode) +export(register_new_model) +export(register_pred) export(rpart_train) export(set_args) export(set_engine) export(set_mode) export(show_call) +export(show_model_info) export(surv_reg) export(svm_poly) export(svm_rbf) export(tidy.model_fit) export(translate) +export(validate_model) export(varying) export(varying_args) export(xgb_train) diff --git a/R/aaa_models.R b/R/aaa_models.R new file mode 100644 index 000000000..785e9ce9a --- /dev/null +++ b/R/aaa_models.R @@ -0,0 +1,257 @@ + +# initialize model environment + +parsnip <- rlang::new_environment() +parsnip$models <- NULL +parsnip$modes <- c("regression", "classification", "unknown") + +# ------------------------------------------------------------------------------ + +#' @export +get_model_env <- function() { + current <- utils::getFromNamespace("parsnip", ns = "parsnip") + # current <- get("parsnip") + current +} + +#' @export +check_mod_val <- function(mod, new = FALSE, existance = FALSE) { + if (is_missing(mod) || length(mod) != 1) + stop("Please supply a character string for a model name (e.g. `'linear_reg'`)", + call. = FALSE) + + if (new | existance) { + current <- get_model_env() + } + + if (new) { + if (any(current$models == mod)) { + stop("Model `", mod, "` already exists", call. = FALSE) + } + } + + if (existance) { + current <- get_model_env() + if (!any(current$models == mod)) { + stop("Model `", mod, "` has not been registered.", call. = FALSE) + } + } + + invisible(NULL) +} + +#' @export +check_mode_val <- function(mode) { + if (is_missing(mode) || length(mode) != 1) + stop("Please supply a character string for a mode (e.g. `'regression'`)", + call. = FALSE) + invisible(NULL) +} + +#' @export +check_engine_val <- function(eng) { + if (is_missing(eng) || length(eng) != 1) + stop("Please supply a character string for an engine (e.g. `'lm'`)", + call. = FALSE) + invisible(NULL) +} + +#' @export +check_arg_val <- function(arg) { + if (is_missing(arg) || length(arg) != 1) + stop("Please supply a character string for the argument", + call. = FALSE) + invisible(NULL) +} + +#' @export +check_func_val <- function(func) { + msg <- + paste( + "`func` should be a named list with names 'pkg' and 'fun' and these", + "should both be single character strings" + ) + + if (is_missing(func) || !is.list(func) || length(func) != 2) + stop(msg, call. 
= FALSE)
+
+  nms <- sort(names(func))
+  if (!isTRUE(all.equal(nms, c("fun", "pkg")))) {
+    stop(msg, call. = FALSE)
+  }
+  if (!all(purrr::map_lgl(func, is.character))) {
+    stop(msg, call. = FALSE)
+  }
+
+  invisible(NULL)
+}
+
+# ------------------------------------------------------------------------------
+
+#' @export
+register_new_model <- function(mod) {
+  check_mod_val(mod, new = TRUE)
+
+  current <- get_model_env()
+
+  current$models <- c(current$models, mod)
+  current[[mod]] <- dplyr::tibble(engine = character(0), mode = character(0))
+  current[[paste0(mod, "_pkg")]] <- character(0)
+  current[[paste0(mod, "_modes")]] <- "unknown"
+  current[[paste0(mod, "_args")]] <-
+    dplyr::tibble(
+      engine = character(0),
+      parsnip = character(0),
+      original = character(0),
+      func = list()
+    )
+  current[[paste0(mod, "_fit")]] <- list()
+  current[[paste0(mod, "_predict")]] <- list()
+
+  invisible(NULL)
+}
+
+# ------------------------------------------------------------------------------
+
+#' @export
+register_model_mode <- function(mod, mode) {
+  check_mod_val(mod, existance = TRUE)
+  check_mode_val(mode)
+
+  current <- get_model_env()
+
+  if (!any(current$modes == mode)) {
+    current$modes <- unique(c(current$modes, mode))
+  }
+  current[[paste0(mod, "_modes")]] <-
+    unique(c(current[[paste0(mod, "_modes")]], mode))
+
+  invisible(NULL)
+}
+
+# ------------------------------------------------------------------------------
+
+#' @export
+register_model_engine <- function(mod, mode, eng) {
+  check_mod_val(mod, existance = TRUE)
+  check_mode_val(mode)
+  check_engine_val(eng)
+
+  current <- get_model_env()
+
+  new_eng <- dplyr::tibble(engine = eng, mode = mode)
+  old_eng <- current[[mod]]
+  engs <-
+    old_eng %>%
+    dplyr::bind_rows(new_eng) %>%
+    dplyr::distinct()
+
+  current[[mod]] <- engs
+
+  invisible(NULL)
+}
+
+
+# ------------------------------------------------------------------------------
+
+#' @export
+register_model_arg <- function(mod, eng, val, original, func) {
+  check_mod_val(mod, existance = TRUE)
+  check_arg_val(val)
+  check_arg_val(original)
+  check_func_val(func)
+
+  current <- get_model_env()
+  old_args <- current[[paste0(mod, "_args")]]
+
+  new_arg <-
+    dplyr::tibble(
+      engine = eng,
+      parsnip = val,
+      original = original,
+      func = list(func)
+    )
+
+  updated <- try(dplyr::bind_rows(old_args, new_arg), silent = TRUE)
+  if (inherits(updated, "try-error")) {
+    stop("An error occurred when adding the new argument.", call. 
= FALSE) + } + + updated <- dplyr::distinct(updated, engine, parsnip, original) + + current[[paste0(mod, "_args")]] <- updated + + invisible(NULL) +} + + +# ------------------------------------------------------------------------------ + +#' @export +register_dependency <- function(mod, pkg) { + +} + +# ------------------------------------------------------------------------------ + +#' @export +register_fit <- function(mod, mode, eng, info) { + +} + + +# ------------------------------------------------------------------------------ + +#' @export +register_pred <- function(mod, mode, eng, type, info) { + +} + +# ------------------------------------------------------------------------------ + +#' @export +validate_model <- function(mod) { + # check for consistency across engines, modes, args, etc +} + +# ------------------------------------------------------------------------------ + +#' @export +show_model_info <- function(mod) { + check_mod_val(mod, existance = TRUE) + current <- get_model_env() + + cat("Information for `", mod, "`\n", sep = "") + + cat( + " modes:", + paste0(current[[paste0(mod, "_modes")]], collapse = ", "), + "\n" + ) + + engines <- current[[paste0(mod)]] + if (nrow(engines) > 0) { + cat(" engines: ") + engines %>% + dplyr::mutate(lab = paste0(engine, " (", mode, ")\n")) %>% + dplyr::pull(lab) %>% + cat(sep = "") + } else { + cat(" no registered engines yet.") + } + + args <- current[[paste0(mod, "_args")]] + if (nrow(args) > 0) { + cat(" arguments: \n") + args %>% + dplyr::select(engine, parsnip, original) %>% + dplyr::distinct() %>% + print() + } else { + cat(" no registered arguments yet.") + } + + invisible(NULL) +} + +# ------------------------------------------------------------------------------ From fee8383dd05f6740c773425d1e75a1b5e9b3ea92 Mon Sep 17 00:00:00 2001 From: topepo Date: Sun, 5 May 2019 17:08:25 -0400 Subject: [PATCH 02/64] a little better printing --- R/aaa_models.R | 28 +++++++++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/R/aaa_models.R b/R/aaa_models.R index 785e9ce9a..c7e0b3926 100644 --- a/R/aaa_models.R +++ b/R/aaa_models.R @@ -231,9 +231,19 @@ show_model_info <- function(mod) { engines <- current[[paste0(mod)]] if (nrow(engines) > 0) { - cat(" engines: ") + cat(" engines: \n") engines %>% - dplyr::mutate(lab = paste0(engine, " (", mode, ")\n")) %>% + dplyr::mutate( + mode = format(paste0(mode, ": ")) + ) %>% + dplyr::group_by(mode) %>% + dplyr::summarize( + engine = paste0(sort(engine), collapse = ", ") + ) %>% + dplyr::mutate( + lab = paste0(" ", mode, engine, "\n") + ) %>% + dplyr::ungroup() %>% dplyr::pull(lab) %>% cat(sep = "") } else { @@ -246,7 +256,19 @@ show_model_info <- function(mod) { args %>% dplyr::select(engine, parsnip, original) %>% dplyr::distinct() %>% - print() + dplyr::mutate( + engine = format(paste0(" ", engine, ": ")), + parsnip = paste0(" ", format(parsnip), " --> ", original, "\n") + ) %>% + dplyr::group_by(engine) %>% + dplyr::mutate( + engine2 = ifelse(dplyr::row_number() == 1, engine, ""), + parsnip = ifelse(dplyr::row_number() == 1, paste0("\n", parsnip), parsnip), + lab = paste0(engine2, parsnip) + ) %>% + dplyr::ungroup() %>% + dplyr::pull(lab) %>% + cat(sep = "") } else { cat(" no registered arguments yet.") } From 435ee57927b4a3c2a077e8743fd9b926f15765a1 Mon Sep 17 00:00:00 2001 From: topepo Date: Sun, 5 May 2019 17:08:47 -0400 Subject: [PATCH 03/64] testing out registration of a model --- R/rand_forest_data.R | 83 ++++++++++++++++++++++++++++++++++++++++++++ 1 file 
changed, 83 insertions(+) diff --git a/R/rand_forest_data.R b/R/rand_forest_data.R index 65eb84864..51ab572b5 100644 --- a/R/rand_forest_data.R +++ b/R/rand_forest_data.R @@ -278,3 +278,86 @@ rand_forest_spark_data <- ) ) ) + + +# ------------------------------------------------------------------------------ + +register_new_model("rand_forest") + +register_model_mode("rand_forest", "classification") +register_model_mode("rand_forest", "regression") + +register_model_engine("rand_forest", "classification", "randomForest") +register_model_engine("rand_forest", "classification", "ranger") +register_model_engine("rand_forest", "classification", "spark") + +register_model_engine("rand_forest", "regression", "randomForest") +register_model_engine("rand_forest", "regression", "ranger") +register_model_engine("rand_forest", "regression", "spark") + +register_model_arg( + mod = "rand_forest", + eng = "randomForest", + val = "mtry", + original = "mtry", + func = list(pkg = "dials", fun = "mtry") +) +register_model_arg( + mod = "rand_forest", + eng = "randomForest", + val = "trees", + original = "ntree", + func = list(pkg = "dials", fun = "trees") +) +register_model_arg( + mod = "rand_forest", + eng = "randomForest", + val = "min_n", + original = "nodesize", + func = list(pkg = "dials", fun = "min_n") +) + +register_model_arg( + mod = "rand_forest", + eng = "ranger", + val = "mtry", + original = "mtry", + func = list(pkg = "dials", fun = "mtry") +) +register_model_arg( + mod = "rand_forest", + eng = "ranger", + val = "trees", + original = "num.trees", + func = list(pkg = "dials", fun = "trees") +) +register_model_arg( + mod = "rand_forest", + eng = "ranger", + val = "min_n", + original = "min.node.size", + func = list(pkg = "dials", fun = "min_n") +) + +register_model_arg( + mod = "rand_forest", + eng = "spark", + val = "mtry", + original = "feature_subset_strategy", + func = list(pkg = "dials", fun = "mtry") +) +register_model_arg( + mod = "rand_forest", + eng = "spark", + val = "trees", + original = "num_trees", + func = list(pkg = "dials", fun = "trees") +) +register_model_arg( + mod = "rand_forest", + eng = "spark", + val = "min_n", + original = "min_instances_per_node", + func = list(pkg = "dials", fun = "min_n") +) + From 67050742e001761d2e2a93fa21a016c859e9c57a Mon Sep 17 00:00:00 2001 From: topepo Date: Sun, 5 May 2019 19:22:39 -0400 Subject: [PATCH 04/64] moved to aaa_models --- R/predict.R | 3 --- 1 file changed, 3 deletions(-) diff --git a/R/predict.R b/R/predict.R index 7cea3e0b7..414096d1a 100644 --- a/R/predict.R +++ b/R/predict.R @@ -154,9 +154,6 @@ predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...) 
res } -pred_types <- - c("raw", "numeric", "class", "link", "prob", "conf_int", "pred_int", "quantile") - #' @importFrom glue glue_collapse check_pred_type <- function(object, type) { if (is.null(type)) { From d8e093a51121093d55a58c7677d98bfe673340b7 Mon Sep 17 00:00:00 2001 From: topepo Date: Sun, 5 May 2019 19:23:00 -0400 Subject: [PATCH 05/64] dependencies, fit, and predict --- NAMESPACE | 2 + R/aaa_models.R | 222 +++++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 216 insertions(+), 8 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index 92870a2de..3f645168b 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -81,9 +81,11 @@ export(boost_tree) export(check_arg_val) export(check_empty_ellipse) export(check_engine_val) +export(check_fit_info) export(check_func_val) export(check_mod_val) export(check_mode_val) +export(check_pred_info) export(decision_tree) export(fit) export(fit.model_spec) diff --git a/R/aaa_models.R b/R/aaa_models.R index c7e0b3926..468bc1478 100644 --- a/R/aaa_models.R +++ b/R/aaa_models.R @@ -7,6 +7,11 @@ parsnip$modes <- c("regression", "classification", "unknown") # ------------------------------------------------------------------------------ +pred_types <- + c("raw", "numeric", "class", "link", "prob", "conf_int", "pred_int", "quantile") + +# ------------------------------------------------------------------------------ + #' @export get_model_env <- function() { current <- utils::getFromNamespace("parsnip", ns = "parsnip") @@ -68,17 +73,30 @@ check_arg_val <- function(arg) { check_func_val <- function(func) { msg <- paste( - "`func` should be a named list with names 'pkg' and 'fun' and these", - "should both be single character strings" + "`func` should be a named vector with element 'fun' and the optional ", + "element 'pkg'. These should both be single character strings." ) - if (is_missing(func) || !is.list(func) || length(func) != 2) + if (is_missing(func) || !is.vector(func) || length(func) > 2) stop(msg, call. = FALSE) nms <- sort(names(func)) - if (!isTRUE(all.equal(nms, c("fun", "pkg")))) { + + if (all(is.null(nms))) { stop(msg, call. = FALSE) } + + if (length(func) == 1) { + if (isTRUE(any(nms != "fun"))) { + stop(msg, call. = FALSE) + } + } else { + if (!isTRUE(all.equal(nms, c("fun", "pkg")))) { + stop(msg, call. = FALSE) + } + } + + if (!all(purrr::map_lgl(func, is.character))) { stop(msg, call. = FALSE) } @@ -86,6 +104,79 @@ check_func_val <- function(func) { invisible(NULL) } +#' @export +check_fit_info <- function(x) { + if (is.null(x)) { + stop("The `fit` module cannot be NULL.", call. = FALSE) + } + exp_nms <- c("defaults", "func", "interface", "protect") + if (!isTRUE(all.equal(sort(names(x)), exp_nms))) { + stop("The `fit` module should have elements: ", + paste0("`", exp_nms, "`", collapse = ", "), + call. = FALSE) + } + + exp_interf <- c("data.frame", "formula", "matrix") + if (length(x$interface) > 1) { + stop("The `interface` element should have a single value of : ", + paste0("`", exp_interf, "`", collapse = ", "), + call. = FALSE) + } + if (!any(x$interface == exp_interf)) { + stop("The `interface` element should have a value of : ", + paste0("`", exp_interf, "`", collapse = ", "), + call. = FALSE) + } + check_func_val(x$func) + + if (!is.list(x$defaults)) { + stop("The `defaults` element should be a list: ", call. 
= FALSE)
+  }
+
+  invisible(NULL)
+}
+
+#' @export
+check_pred_info <- function(x, type) {
+  if (all(type != pred_types)) {
+    stop("The prediction type should be one of: ",
+         paste0("'", pred_types, "'", collapse = ", "),
+         call. = FALSE)
+  }
+
+  exp_nms <- c("args", "func", "post", "pre")
+  if (!isTRUE(all.equal(sort(names(x)), exp_nms))) {
+    stop("The `predict` module should have elements: ",
+         paste0("`", exp_nms, "`", collapse = ", "),
+         call. = FALSE)
+  }
+
+  if (!is.null(x$pre) & !is.function(x$pre)) {
+    stop("The `pre` module should be NULL or a function.",
+         call. = FALSE)
+  }
+  if (!is.null(x$post) & !is.function(x$post)) {
+    stop("The `post` module should be NULL or a function.",
+         call. = FALSE)
+  }
+
+  check_func_val(x$func)
+
+  if (!is.list(x$args)) {
+    stop("The `args` element should be a list.", call. = FALSE)
+  }
+
+  invisible(NULL)
+}
+
+
+#' @export
+check_pkg_val <- function(x) {
+  if (is_missing(x) || length(x) != 1 || !is.character(x))
+    stop("Please supply a single character value for the package name",
+         call. = FALSE)
+  invisible(NULL)
+}
 
 # ------------------------------------------------------------------------------
 
@@ -105,8 +196,19 @@ register_new_model <- function(mod) {
       original = character(0),
       func = list()
     )
-  current[[paste0(mod, "_fit")]] <- list()
-  current[[paste0(mod, "_predict")]] <- list()
+  current[[paste0(mod, "_fit")]] <-
+    dplyr::tibble(
+      engine = character(0),
+      mode = character(0),
+      value = list()
+    )
+  current[[paste0(mod, "_predict")]] <-
+    dplyr::tibble(
+      engine = character(0),
+      mode = character(0),
+      type = character(0),
+      value = list()
+    )
 
   invisible(NULL)
 }
 
 # ------------------------------------------------------------------------------
 
 #' @export
 register_dependency <- function(mod, pkg) {
+  check_mod_val(mod, existance = TRUE)
+  check_pkg_val(pkg)
+
+  current <- get_model_env()
+
+  current[[paste0(mod, "_pkg")]] <-
+    unique(c(current[[paste0(mod, "_pkg")]], pkg))
 
+  invisible(NULL)
 }
 
 # ------------------------------------------------------------------------------
 
 #' @export
-register_fit <- function(mod, mode, eng, info) {
+register_fit <- function(mod, mode, eng, value) {
+  check_mod_val(mod, existance = TRUE)
+  check_mode_val(mode)
+  check_engine_val(eng)
+  check_fit_info(value)
+
+  current <- get_model_env()
+  model_info <- current[[paste0(mod)]]
+  old_fits <- current[[paste0(mod, "_fit")]]
+
+  has_engine <-
+    model_info %>%
+    dplyr::filter(engine == eng & mode == !!mode) %>%
+    nrow()
+  if (has_engine != 1) {
+    stop("The combination of engine '", eng, "' and mode '",
+         mode, "' has not been registered for model '",
+         mod, "'. ", call. = FALSE)
+  }
+
+  has_fit <-
+    old_fits %>%
+    dplyr::filter(engine == eng & mode == !!mode) %>%
+    nrow()
+
+  if (has_fit > 0) {
+    stop("The combination of engine '", eng, "' and mode '",
+         mode, "' already has a fit component for model '",
+         mod, "'. ", call. = FALSE)
+  }
+
+  new_fit <-
+    dplyr::tibble(
+      engine = eng,
+      mode = mode,
+      value = list(value)
+    )
+
+  updated <- try(dplyr::bind_rows(old_fits, new_fit), silent = TRUE)
+  if (inherits(updated, "try-error")) {
+    stop("An error occurred when adding the new fit module", call. 
= FALSE)
+  }
+
+  current[[paste0(mod, "_fit")]] <- updated
 
+  invisible(NULL)
 }
 
 # ------------------------------------------------------------------------------
 
 #' @export
-register_pred <- function(mod, mode, eng, type, info) {
+register_pred <- function(mod, mode, eng, type, value) {
+  check_mod_val(mod, existance = TRUE)
+  check_mode_val(mode)
+  check_engine_val(eng)
+  check_pred_info(value, type)
+
+  current <- get_model_env()
+  model_info <- current[[paste0(mod)]]
+  old_fits <- current[[paste0(mod, "_predict")]]
+
+  has_engine <-
+    model_info %>%
+    dplyr::filter(engine == eng & mode == !!mode) %>%
+    nrow()
+  if (has_engine != 1) {
+    stop("The combination of engine '", eng, "' and mode '",
+         mode, "' has not been registered for model '",
+         mod, "'. ", call. = FALSE)
+  }
+
+  has_pred <-
+    old_fits %>%
+    dplyr::filter(engine == eng & mode == !!mode & type == !!type) %>%
+    nrow()
+  if (has_pred > 0) {
+    stop("The combination of engine '", eng, "', mode '",
+         mode, "', and type '", type,
+         "' already has a prediction component for model '",
+         mod, "'. ", call. = FALSE)
+  }
+
+  new_fit <-
+    dplyr::tibble(
+      engine = eng,
+      mode = mode,
+      type = type,
+      value = list(value)
+    )
+
+  updated <- try(dplyr::bind_rows(old_fits, new_fit), silent = TRUE)
+  if (inherits(updated, "try-error")) {
+    stop("An error occurred when adding the new prediction module", call. = FALSE)
+  }
+  current[[paste0(mod, "_predict")]] <- updated
+
+  invisible(NULL)
 }
 
@@ -273,6 +472,13 @@ show_model_info <- function(mod) {
     cat(" no registered arguments yet.")
   }
 
+  fits <- current[[paste0(mod, "_fit")]]
+  if (nrow(fits) > 0) {
+
+  } else {
+    cat(" no registered fit modules yet.")
+  }
+
   invisible(NULL)
 }
 
From 0dc6e92e98d19a90d89ba8b70c3c80e34c94c151 Mon Sep 17 00:00:00 2001
From: topepo
Date: Sun, 5 May 2019 19:23:15 -0400
Subject: [PATCH 06/64] testing to see if rf could be registered

---
 R/rand_forest_data.R | 31 ++++++++++++++++++++++++++-----
 1 file changed, 26 insertions(+), 5 deletions(-)

diff --git a/R/rand_forest_data.R b/R/rand_forest_data.R
index 51ab572b5..c96597a46 100644
--- a/R/rand_forest_data.R
+++ b/R/rand_forest_data.R
@@ -44,7 +44,7 @@ ranger_num_confint <- function(object, new_data, ...) {
   res$.pred_upper <- res$.pred + const * std_error
   res$.pred <- NULL
 
-  if(object$spec$method$confint$extras$std_error)
+  if (object$spec$method$confint$extras$std_error)
     res$.std_error <- std_error
   res
 }
@@ -73,20 +73,20 @@ ranger_class_confint <- function(object, new_data, ...) {
   col_names <- paste0(c(".pred_lower_", ".pred_upper_"), lvl)
   res <- res[, col_names]
 
-  if(object$spec$method$confint$extras$std_error)
+  if (object$spec$method$confint$extras$std_error)
     res <- bind_cols(res, std_error)
   res
 }
 
 ranger_confint <- function(object, new_data, ...) {
-  if(object$fit$forest$treetype == "Regression") {
+  if (object$fit$forest$treetype == "Regression") {
     res <- ranger_num_confint(object, new_data, ...)
   } else {
-    if(object$fit$forest$treetype == "Probability estimation") {
+    if (object$fit$forest$treetype == "Probability estimation") {
      res <- ranger_class_confint(object, new_data, ...)
     } else {
-      stop ("Cannot compute confidence intervals for a ranger forest ",
+      stop("Cannot compute confidence intervals for a ranger forest ",
           "of type ", object$fit$forest$treetype, ".", call. 
= FALSE) } } @@ -159,6 +159,7 @@ rand_forest_ranger_data <- ), raw = list( pre = NULL, + post = NULL, func = c(fun = "predict"), args = list( @@ -361,3 +362,23 @@ register_model_arg( func = list(pkg = "dials", fun = "min_n") ) +register_fit(mod = "rand_forest", eng = "ranger", mode = "classification", + value = rand_forest_ranger_data$fit) + +register_fit(mod = "rand_forest", eng = "ranger", mode = "regression", + value = rand_forest_ranger_data$fit) + +register_pred(mod = "rand_forest", eng = "ranger", mode = "classification", + type = "class", value = parsnip:::rand_forest_ranger_data$class) + +register_pred(mod = "rand_forest", eng = "ranger", mode = "classification", + type = "prob", value = parsnip:::rand_forest_ranger_data$classprob) + +register_pred(mod = "rand_forest", eng = "ranger", mode = "classification", + type = "raw", value = parsnip:::rand_forest_ranger_data$raw) + +register_pred(mod = "rand_forest", eng = "ranger", mode = "regression", + type = "numeric", value = parsnip:::rand_forest_ranger_data$numeric) + +register_pred(mod = "rand_forest", eng = "ranger", mode = "regression", + type = "raw", value = parsnip:::rand_forest_ranger_data$raw) From 10d51a97177ae765998c691303b2cf643a35426a Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Wed, 8 May 2019 20:39:58 -0400 Subject: [PATCH 07/64] changed register to set; restructured the pkg information --- NAMESPACE | 18 ++++--- R/aaa_models.R | 111 ++++++++++++++++++++++++++++++++++++++----- R/rand_forest_data.R | 68 +++++++++++++------------- 3 files changed, 145 insertions(+), 52 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index 3f645168b..98b71e662 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -85,6 +85,7 @@ export(check_fit_info) export(check_func_val) export(check_mod_val) export(check_mode_val) +export(check_pkg_val) export(check_pred_info) export(decision_tree) export(fit) @@ -92,7 +93,10 @@ export(fit.model_spec) export(fit_control) export(fit_xy) export(fit_xy.model_spec) +export(get_dependency) +export(get_fit) export(get_model_env) +export(get_pred_type) export(keras_mlp) export(linear_reg) export(logistic_reg) @@ -107,17 +111,17 @@ export(null_model) export(nullmodel) export(predict.model_fit) export(rand_forest) -export(register_dependency) -export(register_fit) -export(register_model_arg) -export(register_model_engine) -export(register_model_mode) -export(register_new_model) -export(register_pred) export(rpart_train) export(set_args) +export(set_dependency) export(set_engine) +export(set_fit) export(set_mode) +export(set_model_arg) +export(set_model_engine) +export(set_model_mode) +export(set_new_model) +export(set_pred) export(show_call) export(show_model_info) export(surv_reg) diff --git a/R/aaa_models.R b/R/aaa_models.R index 468bc1478..a53cb2f69 100644 --- a/R/aaa_models.R +++ b/R/aaa_models.R @@ -1,5 +1,28 @@ +# Initialize model environment + +# ------------------------------------------------------------------------------ + +## Rules about model-related information + +### Definitions: + +# - the model is the model type (e.g. "rand_forest", "linear_reg", etc) +# - the model's mode is the species of model such as "classification" or "regression" +# - the engines are within a model and mode and describe the method/implementation +# of the model in question. These are often R package names. + +### The package dependencies are model- and engine-specific. They are used across modes + +### The `fit` information is a list of data that is needed to fit the model. This +### information is specific to an engine and mode. 
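+
+### As a rough sketch of how these pieces are meant to compose, a
+### hypothetical model could be registered like this (a commented-out
+### example only, so nothing runs at load time; `my_model`, `my_engine`,
+### `mypkg`, and `my_fit_fun` are placeholder names):
+###
+###   set_new_model("my_model")
+###   set_model_mode("my_model", "regression")
+###   set_model_engine("my_model", "regression", "my_engine")
+###   set_dependency("my_model", "my_engine", "mypkg")
+###   set_fit(
+###     mod = "my_model", mode = "regression", eng = "my_engine",
+###     value = list(
+###       interface = "formula",
+###       protect = c("formula", "data"),
+###       func = c(pkg = "mypkg", fun = "my_fit_fun"),
+###       defaults = list()
+###     )
+###   )
+###   show_model_info("my_model")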
+ +### The `predict` information is also list of data that is needed to make some sort +### of prediction on the model object. The possible types are contained in `pred_types` +### and this information is specific to the engine, mode, and type (although there +### are no types across different modes). + +# ------------------------------------------------------------------------------ -# initialize model environment parsnip <- rlang::new_environment() parsnip$models <- NULL @@ -177,17 +200,18 @@ check_pkg_val <- function(x) { call. = FALSE) invisible(NULL) } + # ------------------------------------------------------------------------------ #' @export -register_new_model <- function(mod) { +set_new_model <- function(mod) { check_mod_val(mod, new = TRUE) current <- get_model_env() current$models <- c(current$models, mod) current[[mod]] <- dplyr::tibble(engine = character(0), mode = character(0)) - current[[paste0(mod, "_pkg")]] <- character(0) + current[[paste0(mod, "_pkgs")]] <- dplyr::tibble(engine = character(0), pkg = list()) current[[paste0(mod, "_modes")]] <- "unknown" current[[paste0(mod, "_args")]] <- dplyr::tibble( @@ -216,7 +240,7 @@ register_new_model <- function(mod) { # ------------------------------------------------------------------------------ #' @export -register_model_mode <- function(mod, mode) { +set_model_mode <- function(mod, mode) { check_mod_val(mod, existance = TRUE) check_mode_val(mode) @@ -234,7 +258,7 @@ register_model_mode <- function(mod, mode) { # ------------------------------------------------------------------------------ #' @export -register_model_engine <- function(mod, mode, eng) { +set_model_engine <- function(mod, mode, eng) { check_mod_val(mod, existance = TRUE) check_mode_val(mode) check_mode_val(eng) @@ -257,7 +281,7 @@ register_model_engine <- function(mod, mode, eng) { # ------------------------------------------------------------------------------ #' @export -register_model_arg <- function(mod, eng, val, original, func) { +set_model_arg <- function(mod, eng, val, original, func) { check_mod_val(mod, existance = TRUE) check_arg_val(val) check_arg_val(original) @@ -290,22 +314,61 @@ register_model_arg <- function(mod, eng, val, original, func) { # ------------------------------------------------------------------------------ #' @export -register_dependency <- function(mod, pkg) { +set_dependency <- function(mod, eng, pkg) { check_mod_val(mod, existance = TRUE) check_pkg_val(pkg) current <- get_model_env() + model_info <- current[[mod]] + pkg_info <- current[[paste0(mod, "_pkgs")]] + + has_engine <- + model_info %>% + dplyr::distinct(engine) %>% + dplyr::filter(engine == eng) %>% + nrow() + if (has_engine != 1) { + stop("The engine '", eng, "' has not been registered for model '", + mod, "'. ", call. 
= FALSE) + } - current[[paste0(mod, "_pkg")]] <- - unique(c(current[[paste0(mod, "_pkg")]], pkg)) + existing_pkgs <- + pkg_info %>% + dplyr::filter(engine == eng) + + if (nrow(existing_pkgs) == 0) { + pkg_info <- + pkg_info %>% + dplyr::bind_rows(tibble(engine = eng, pkg = list(pkg))) + + } else { + old_pkgs <- existing_pkgs + existing_pkgs$pkg[[1]] <- c(pkg, existing_pkgs$pkg[[1]]) + pkg_info <- + pkg_info %>% + dplyr::filter(engine != eng) %>% + dplyr::bind_rows(existing_pkgs) + } + current[[paste0(mod, "_pkgs")]] <- pkg_info invisible(NULL) } +#' @export +get_dependency <- function(mod) { + check_mod_val(mod, existance = TRUE) + pkg_name <- paste0(mod, "_pkgs") + if (!any(pkg_name != rlang::env_names(get_model_env()))) { + stop("`", mod, "` does not have a dependency list in parsnip.", call. = FALSE) + } + rlang::env_get(get_model_env(), pkg_name) +} + + # ------------------------------------------------------------------------------ #' @export -register_fit <- function(mod, mode, eng, value) { +set_fit <- function(mod, mode, eng, value) { check_mod_val(mod, existance = TRUE) check_mode_val(mode) check_engine_val(eng) @@ -320,7 +383,7 @@ register_fit <- function(mod, mode, eng, value) { dplyr::filter(engine == eng & mode == !!mode) %>% nrow() if (has_engine != 1) { - stop("The combination of engine '", eng, "' and mode '", + stop("set_fit The combination of engine '", eng, "' and mode '", mode, "' has not been registered for model '", mod, "'. ", call. = FALSE) } @@ -353,11 +416,20 @@ register_fit <- function(mod, mode, eng, value) { invisible(NULL) } +#' @export +get_fit <- function(mod) { + check_mod_val(mod, existance = TRUE) + fit_name <- paste0(mod, "_fit") + if (!any(fit_name != rlang::env_names(get_model_env()))) { + stop("`", mod, "` does not have a `fit` method in parsnip.", call. = FALSE) + } + rlang::env_get(get_model_env(), fit_name) +} # ------------------------------------------------------------------------------ #' @export -register_pred <- function(mod, mode, eng, type, value) { +set_pred <- function(mod, mode, eng, type, value) { check_mod_val(mod, existance = TRUE) check_mode_val(mode) check_engine_val(eng) @@ -406,6 +478,21 @@ register_pred <- function(mod, mode, eng, type, value) { invisible(NULL) } +#' @export +get_pred_type <- function(mod, type) { + check_mod_val(mod, existance = TRUE) + pred_name <- paste0(mod, "_predict") + if (!any(pred_name != rlang::env_names(get_model_env()))) { + stop("`", mod, "` does not have any `pred` methods in parsnip.", call. = FALSE) + } + all_preds <- rlang::env_get(get_model_env(), pred_name) + if (!any(all_preds$type == type)) { + stop("`", mod, "` does not have any `", type, + "` prediction methods in parsnip.", call. = FALSE) + } + dplyr::filter(all_preds, type == !!type) +} + # ------------------------------------------------------------------------------ #' @export diff --git a/R/rand_forest_data.R b/R/rand_forest_data.R index c96597a46..cf70e28ac 100644 --- a/R/rand_forest_data.R +++ b/R/rand_forest_data.R @@ -87,7 +87,7 @@ ranger_confint <- function(object, new_data, ...) { res <- ranger_class_confint(object, new_data, ...) } else { stop("Cannot compute confidence intervals for a ranger forest ", - "of type ", object$fit$forest$treetype, ".", call. = FALSE) + "of type ", object$fit$forest$treetype, ".", call. 
= FALSE) } } res @@ -283,34 +283,35 @@ rand_forest_spark_data <- # ------------------------------------------------------------------------------ -register_new_model("rand_forest") -register_model_mode("rand_forest", "classification") -register_model_mode("rand_forest", "regression") +set_new_model("rand_forest") -register_model_engine("rand_forest", "classification", "randomForest") -register_model_engine("rand_forest", "classification", "ranger") -register_model_engine("rand_forest", "classification", "spark") +set_model_mode("rand_forest", "classification") +set_model_mode("rand_forest", "regression") -register_model_engine("rand_forest", "regression", "randomForest") -register_model_engine("rand_forest", "regression", "ranger") -register_model_engine("rand_forest", "regression", "spark") +set_model_engine("rand_forest", "classification", "randomForest") +set_model_engine("rand_forest", "classification", "ranger") +set_model_engine("rand_forest", "classification", "spark") -register_model_arg( +set_model_engine("rand_forest", "regression", "randomForest") +set_model_engine("rand_forest", "regression", "ranger") +set_model_engine("rand_forest", "regression", "spark") + +set_model_arg( mod = "rand_forest", eng = "randomForest", val = "mtry", original = "mtry", func = list(pkg = "dials", fun = "mtry") ) -register_model_arg( +set_model_arg( mod = "rand_forest", eng = "randomForest", val = "trees", original = "ntree", func = list(pkg = "dials", fun = "trees") ) -register_model_arg( +set_model_arg( mod = "rand_forest", eng = "randomForest", val = "min_n", @@ -318,21 +319,21 @@ register_model_arg( func = list(pkg = "dials", fun = "min_n") ) -register_model_arg( +set_model_arg( mod = "rand_forest", eng = "ranger", val = "mtry", original = "mtry", func = list(pkg = "dials", fun = "mtry") ) -register_model_arg( +set_model_arg( mod = "rand_forest", eng = "ranger", val = "trees", original = "num.trees", func = list(pkg = "dials", fun = "trees") ) -register_model_arg( +set_model_arg( mod = "rand_forest", eng = "ranger", val = "min_n", @@ -340,21 +341,21 @@ register_model_arg( func = list(pkg = "dials", fun = "min_n") ) -register_model_arg( +set_model_arg( mod = "rand_forest", eng = "spark", val = "mtry", original = "feature_subset_strategy", func = list(pkg = "dials", fun = "mtry") ) -register_model_arg( +set_model_arg( mod = "rand_forest", eng = "spark", val = "trees", original = "num_trees", func = list(pkg = "dials", fun = "trees") ) -register_model_arg( +set_model_arg( mod = "rand_forest", eng = "spark", val = "min_n", @@ -362,23 +363,24 @@ register_model_arg( func = list(pkg = "dials", fun = "min_n") ) -register_fit(mod = "rand_forest", eng = "ranger", mode = "classification", - value = rand_forest_ranger_data$fit) +set_fit(mod = "rand_forest", eng = "ranger", mode = "classification", + value = rand_forest_ranger_data$fit) + +set_fit(mod = "rand_forest", eng = "ranger", mode = "regression", + value = rand_forest_ranger_data$fit) -register_fit(mod = "rand_forest", eng = "ranger", mode = "regression", - value = rand_forest_ranger_data$fit) +set_pred(mod = "rand_forest", eng = "ranger", mode = "classification", + type = "class", value = parsnip:::rand_forest_ranger_data$class) -register_pred(mod = "rand_forest", eng = "ranger", mode = "classification", - type = "class", value = parsnip:::rand_forest_ranger_data$class) +set_pred(mod = "rand_forest", eng = "ranger", mode = "classification", + type = "prob", value = parsnip:::rand_forest_ranger_data$classprob) -register_pred(mod = "rand_forest", eng 
= "ranger", mode = "classification", - type = "prob", value = parsnip:::rand_forest_ranger_data$classprob) +set_pred(mod = "rand_forest", eng = "ranger", mode = "classification", + type = "raw", value = parsnip:::rand_forest_ranger_data$raw) -register_pred(mod = "rand_forest", eng = "ranger", mode = "classification", - type = "raw", value = parsnip:::rand_forest_ranger_data$raw) +set_pred(mod = "rand_forest", eng = "ranger", mode = "regression", + type = "numeric", value = parsnip:::rand_forest_ranger_data$numeric) -register_pred(mod = "rand_forest", eng = "ranger", mode = "regression", - type = "numeric", value = parsnip:::rand_forest_ranger_data$numeric) +set_pred(mod = "rand_forest", eng = "ranger", mode = "regression", + type = "raw", value = parsnip:::rand_forest_ranger_data$raw) -register_pred(mod = "rand_forest", eng = "ranger", mode = "regression", - type = "raw", value = parsnip:::rand_forest_ranger_data$raw) From a4df557169369240d07082d42a184717d0b59a3b Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Tue, 21 May 2019 11:20:46 -0400 Subject: [PATCH 08/64] Added submodel info for issue #167 --- NAMESPACE | 1 + R/aaa_models.R | 18 +++++++++++++++--- R/rand_forest_data.R | 27 ++++++++++++++++++--------- 3 files changed, 34 insertions(+), 12 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index 98b71e662..7a2834644 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -87,6 +87,7 @@ export(check_mod_val) export(check_mode_val) export(check_pkg_val) export(check_pred_info) +export(check_submodels_val) export(decision_tree) export(fit) export(fit.model_spec) diff --git a/R/aaa_models.R b/R/aaa_models.R index a53cb2f69..2851da8af 100644 --- a/R/aaa_models.R +++ b/R/aaa_models.R @@ -92,6 +92,14 @@ check_arg_val <- function(arg) { invisible(NULL) } +#' @export +check_submodels_val <- function(x) { + if (!is.logical(x) || length(x) != 1) { + stop("The `submodels` argument should be a single logical.", call. = FALSE) + } + invisible(NULL) +} + #' @export check_func_val <- function(func) { msg <- @@ -281,11 +289,12 @@ set_model_engine <- function(mod, mode, eng) { # ------------------------------------------------------------------------------ #' @export -set_model_arg <- function(mod, eng, val, original, func) { +set_model_arg <- function(mod, eng, val, original, func, submodels) { check_mod_val(mod, existance = TRUE) check_arg_val(val) check_arg_val(original) check_func_val(func) + check_submodels_val(submodels) current <- get_model_env() old_args <- current[[paste0(mod, "_args")]] @@ -295,15 +304,18 @@ set_model_arg <- function(mod, eng, val, original, func) { engine = eng, parsnip = val, original = original, - func = list(func) + func = list(func), + submodels = submodels ) + # TODO cant currently use `distinct()` on a list column. + # Use `vctrs::vctrs_duplicated()` instead updated <- try(dplyr::bind_rows(old_args, new_arg), silent = TRUE) if (inherits(updated, "try-error")) { stop("An error occured when adding the new argument.", call. 
= FALSE) } - updated <- dplyr::distinct(updated, engine, parsnip, original) + updated <- dplyr::distinct(updated, engine, parsnip, original, submodels) current[[paste0(mod, "_args")]] <- updated diff --git a/R/rand_forest_data.R b/R/rand_forest_data.R index cf70e28ac..3d6b50d69 100644 --- a/R/rand_forest_data.R +++ b/R/rand_forest_data.R @@ -302,21 +302,24 @@ set_model_arg( eng = "randomForest", val = "mtry", original = "mtry", - func = list(pkg = "dials", fun = "mtry") + func = list(pkg = "dials", fun = "mtry"), + submodels = FALSE ) set_model_arg( mod = "rand_forest", eng = "randomForest", val = "trees", original = "ntree", - func = list(pkg = "dials", fun = "trees") + func = list(pkg = "dials", fun = "trees"), + submodels = FALSE ) set_model_arg( mod = "rand_forest", eng = "randomForest", val = "min_n", original = "nodesize", - func = list(pkg = "dials", fun = "min_n") + func = list(pkg = "dials", fun = "min_n"), + submodels = FALSE ) set_model_arg( @@ -324,21 +327,24 @@ set_model_arg( eng = "ranger", val = "mtry", original = "mtry", - func = list(pkg = "dials", fun = "mtry") + func = list(pkg = "dials", fun = "mtry"), + submodels = FALSE ) set_model_arg( mod = "rand_forest", eng = "ranger", val = "trees", original = "num.trees", - func = list(pkg = "dials", fun = "trees") + func = list(pkg = "dials", fun = "trees"), + submodels = FALSE ) set_model_arg( mod = "rand_forest", eng = "ranger", val = "min_n", original = "min.node.size", - func = list(pkg = "dials", fun = "min_n") + func = list(pkg = "dials", fun = "min_n"), + submodels = FALSE ) set_model_arg( @@ -346,21 +352,24 @@ set_model_arg( eng = "spark", val = "mtry", original = "feature_subset_strategy", - func = list(pkg = "dials", fun = "mtry") + func = list(pkg = "dials", fun = "mtry"), + submodels = FALSE ) set_model_arg( mod = "rand_forest", eng = "spark", val = "trees", original = "num_trees", - func = list(pkg = "dials", fun = "trees") + func = list(pkg = "dials", fun = "trees"), + submodels = FALSE ) set_model_arg( mod = "rand_forest", eng = "spark", val = "min_n", original = "min_instances_per_node", - func = list(pkg = "dials", fun = "min_n") + func = list(pkg = "dials", fun = "min_n"), + submodels = FALSE ) set_fit(mod = "rand_forest", eng = "ranger", mode = "classification", From eda3e97a4f753aed12ab9c3c1b9cb338728e0a1e Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Thu, 23 May 2019 09:28:24 -0400 Subject: [PATCH 09/64] more model conversions --- NAMESPACE | 134 ---------- R/boost_tree_data.R | 539 ++++++++++++++++++++++++++------------ R/decision_tree_data.R | 433 ++++++++++++++++++++----------- R/linear_reg_data.R | 571 +++++++++++++++++++++++------------------ R/mars_data.R | 208 ++++++++++----- R/rand_forest_data.R | 559 ++++++++++++++++++++++++---------------- R/svm_rbf_data.R | 195 +++++++++----- 7 files changed, 1608 insertions(+), 1031 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index 7a2834644..d3557b196 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -1,139 +1,5 @@ # Generated by roxygen2: do not edit by hand -S3method(fit,model_spec) -S3method(fit_xy,model_spec) -S3method(multi_predict,"_C5.0") -S3method(multi_predict,"_earth") -S3method(multi_predict,"_elnet") -S3method(multi_predict,"_lognet") -S3method(multi_predict,"_multnet") -S3method(multi_predict,"_xgb.Booster") -S3method(multi_predict,default) -S3method(nullmodel,default) -S3method(predict,"_elnet") -S3method(predict,"_lognet") -S3method(predict,"_multnet") -S3method(predict,model_fit) -S3method(predict,model_spec) -S3method(predict,nullmodel) 
-S3method(predict_class,"_lognet") -S3method(predict_class,"_multnet") -S3method(predict_classprob,"_lognet") -S3method(predict_classprob,"_multnet") -S3method(predict_numeric,"_elnet") -S3method(predict_raw,"_elnet") -S3method(predict_raw,"_lognet") -S3method(predict_raw,"_multnet") -S3method(print,boost_tree) -S3method(print,decision_tree) -S3method(print,fit_control) -S3method(print,linear_reg) -S3method(print,logistic_reg) -S3method(print,mars) -S3method(print,mlp) -S3method(print,model_fit) -S3method(print,model_spec) -S3method(print,multinom_reg) -S3method(print,nearest_neighbor) -S3method(print,nullmodel) -S3method(print,rand_forest) -S3method(print,surv_reg) -S3method(print,svm_poly) -S3method(print,svm_rbf) -S3method(translate,boost_tree) -S3method(translate,decision_tree) -S3method(translate,default) -S3method(translate,mars) -S3method(translate,mlp) -S3method(translate,rand_forest) -S3method(translate,surv_reg) -S3method(translate,svm_poly) -S3method(translate,svm_rbf) -S3method(type_sum,model_fit) -S3method(type_sum,model_spec) -S3method(update,boost_tree) -S3method(update,decision_tree) -S3method(update,linear_reg) -S3method(update,logistic_reg) -S3method(update,mars) -S3method(update,mlp) -S3method(update,multinom_reg) -S3method(update,nearest_neighbor) -S3method(update,rand_forest) -S3method(update,surv_reg) -S3method(update,svm_poly) -S3method(update,svm_rbf) -S3method(varying_args,model_spec) -S3method(varying_args,recipe) -S3method(varying_args,step) -export("%>%") -export(.cols) -export(.dat) -export(.facts) -export(.lvls) -export(.obs) -export(.preds) -export(.x) -export(.y) -export(C5.0_train) -export(add_rowindex) -export(boost_tree) -export(check_arg_val) -export(check_empty_ellipse) -export(check_engine_val) -export(check_fit_info) -export(check_func_val) -export(check_mod_val) -export(check_mode_val) -export(check_pkg_val) -export(check_pred_info) -export(check_submodels_val) -export(decision_tree) -export(fit) -export(fit.model_spec) -export(fit_control) -export(fit_xy) -export(fit_xy.model_spec) -export(get_dependency) -export(get_fit) -export(get_model_env) -export(get_pred_type) -export(keras_mlp) -export(linear_reg) -export(logistic_reg) -export(make_classes) -export(mars) -export(mlp) -export(model_printer) -export(multi_predict) -export(multinom_reg) -export(nearest_neighbor) -export(null_model) -export(nullmodel) -export(predict.model_fit) -export(rand_forest) -export(rpart_train) -export(set_args) -export(set_dependency) -export(set_engine) -export(set_fit) -export(set_mode) -export(set_model_arg) -export(set_model_engine) -export(set_model_mode) -export(set_new_model) -export(set_pred) -export(show_call) -export(show_model_info) -export(surv_reg) -export(svm_poly) -export(svm_rbf) -export(tidy.model_fit) -export(translate) -export(validate_model) -export(varying) -export(varying_args) -export(xgb_train) importFrom(dplyr,arrange) importFrom(dplyr,as_tibble) importFrom(dplyr,bind_cols) diff --git a/R/boost_tree_data.R b/R/boost_tree_data.R index 559698511..f3688717b 100644 --- a/R/boost_tree_data.R +++ b/R/boost_tree_data.R @@ -1,175 +1,386 @@ +set_new_model("boost_tree") -boost_tree_arg_key <- data.frame( - xgboost = c("max_depth", "nrounds", "eta", "colsample_bytree", "min_child_weight", "gamma", "subsample"), - C5.0 = c( NA, "trials", NA, NA, "minCases", NA, "sample"), - spark = c("max_depth", "max_iter", "step_size", "feature_subset_strategy", "min_instances_per_node", "min_info_gain", "subsampling_rate"), - stringsAsFactors = FALSE, - row.names = 
c("tree_depth", "trees", "learn_rate", "mtry", "min_n", "loss_reduction", "sample_size") +set_model_mode("boost_tree", "classification") +set_model_mode("boost_tree", "regression") + +# ------------------------------------------------------------------------------ + +set_model_engine("boost_tree", "classification", "xgboost") +set_model_engine("boost_tree", "regression", "xgboost") +set_dependency("boost_tree", "xgboost", "xgboost") + +set_model_arg( + mod = "boost_tree", + eng = "xgboost", + val = "tree_depth", + original = "max_depth", + func = list(pkg = "dials", fun = "tree_depth"), + submodels = FALSE +) +set_model_arg( + mod = "boost_tree", + eng = "xgboost", + val = "trees", + original = "nrounds", + func = list(pkg = "dials", fun = "trees"), + submodels = TRUE +) +set_model_arg( + mod = "boost_tree", + eng = "xgboost", + val = "learn_rate", + original = "eta", + func = list(pkg = "dials", fun = "learn_rate"), + submodels = FALSE +) +set_model_arg( + mod = "boost_tree", + eng = "xgboost", + val = "mtry", + original = "colsample_bytree", + func = list(pkg = "dials", fun = "mtry"), + submodels = FALSE +) +set_model_arg( + mod = "boost_tree", + eng = "xgboost", + val = "min_n", + original = "min_child_weight", + func = list(pkg = "dials", fun = "min_n"), + submodels = FALSE +) +set_model_arg( + mod = "boost_tree", + eng = "xgboost", + val = "loss_reduction", + original = "gamma", + func = list(pkg = "dials", fun = "loss_reduction"), + submodels = FALSE +) +set_model_arg( + mod = "boost_tree", + eng = "xgboost", + val = "sample_size", + original = "subsample", + func = list(pkg = "dials", fun = "sample_size"), + submodels = FALSE +) + +set_fit( + mod = "boost_tree", + eng = "xgboost", + mode = "regression", + value = list( + interface = "matrix", + protect = c("x", "y"), + func = c(pkg = "parsnip", fun = "xgb_train"), + defaults = list(nthread = 1, verbose = 0) + ) +) + +set_pred( + mod = "boost_tree", + eng = "xgboost", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "xgb_pred"), + args = list(object = quote(object$fit), newdata = quote(new_data)) + ) +) + +set_pred( + mod = "boost_tree", + eng = "xgboost", + mode = "regression", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "xgb_pred"), + args = list(object = quote(object$fit), newdata = quote(new_data)) + ) ) -boost_tree_modes <- c("classification", "regression", "unknown") +set_fit( + mod = "boost_tree", + eng = "xgboost", + mode = "classification", + value = list( + interface = "matrix", + protect = c("x", "y"), + func = c(pkg = "parsnip", fun = "xgb_train"), + defaults = list(nthread = 1, verbose = 0) + ) +) + +set_pred( + mod = "boost_tree", + eng = "xgboost", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = function(x, object) { + if (is.vector(x)) { + x <- ifelse(x >= 0.5, object$lvl[2], object$lvl[1]) + } else { + x <- object$lvl[apply(x, 1, which.max)] + } + x + }, + func = c(pkg = NULL, fun = "xgb_pred"), + args = list(object = quote(object$fit), newdata = quote(new_data)) + ) +) + +set_pred( + mod = "boost_tree", + eng = "xgboost", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = function(x, object) { + if (is.vector(x)) { + x <- tibble(v1 = 1 - x, v2 = x) + } else { + x <- as_tibble(x) + } + colnames(x) <- object$lvl + x + }, + func = c(pkg = NULL, fun = "xgb_pred"), + args = list(object = quote(object$fit), newdata = quote(new_data)) + ) +) -boost_tree_engines <- 
data.frame( - xgboost = rep(TRUE, 3), - C5.0 = c( TRUE, FALSE, TRUE), - spark = rep(TRUE, 3), - row.names = c("classification", "regression", "unknown") +set_pred( + mod = "boost_tree", + eng = "xgboost", + mode = "classification", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "xgb_pred"), + args = list(object = quote(object$fit), newdata = quote(new_data)) + ) ) # ------------------------------------------------------------------------------ -boost_tree_xgboost_data <- - list( - libs = "xgboost", - fit = list( - interface = "matrix", - protect = c("x", "y"), - func = c(pkg = "parsnip", fun = "xgb_train"), - defaults = - list( - nthread = 1, - verbose = 0 - ) - ), - numeric = list( - pre = NULL, - post = NULL, - func = c(fun = "xgb_pred"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ), - class = list( - pre = NULL, - post = function(x, object) { - if (is.vector(x)) { - x <- ifelse(x >= 0.5, object$lvl[2], object$lvl[1]) - } else { - x <- object$lvl[apply(x, 1, which.max)] - } - x - }, - func = c(pkg = NULL, fun = "xgb_pred"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ), - classprob = list( - pre = NULL, - post = function(x, object) { - if (is.vector(x)) { - x <- tibble(v1 = 1 - x, v2 = x) - } else { - x <- as_tibble(x) - } - colnames(x) <- object$lvl - x - }, - func = c(pkg = NULL, fun = "xgb_pred"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ), - raw = list( - pre = NULL, - func = c(fun = "xgb_pred"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ) - ) - - -boost_tree_C5.0_data <- - list( - libs = "C50", - fit = list( - interface = "data.frame", - protect = c("x", "y", "weights"), - func = c(pkg = "parsnip", fun = "C5.0_train"), - defaults = list() - ), - class = list( - pre = NULL, - post = NULL, - func = c(fun = "predict"), - args = list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ), - classprob = list( - pre = NULL, - post = function(x, object) { - as_tibble(x) - }, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "prob" - ) - ), - raw = list( - pre = NULL, - func = c(fun = "predict"), - args = list( +set_model_engine("boost_tree", "classification", "C5.0") +set_dependency("boost_tree", "C5.0", "C50") + +set_model_arg( + mod = "boost_tree", + eng = "C5.0", + val = "trees", + original = "trials", + func = list(pkg = "dials", fun = "trees"), + submodels = TRUE +) +set_model_arg( + mod = "boost_tree", + eng = "C5.0", + val = "min_n", + original = "minCases", + func = list(pkg = "dials", fun = "min_n"), + submodels = FALSE +) +set_model_arg( + mod = "boost_tree", + eng = "C5.0", + val = "sample_size", + original = "sample", + func = list(pkg = "dials", fun = "sample_size"), + submodels = FALSE +) + +set_fit( + mod = "boost_tree", + eng = "C5.0", + mode = "classification", + value = list( + interface = "data.frame", + protect = c("x", "y", "weights"), + func = c(pkg = "parsnip", fun = "C5.0_train"), + defaults = list() + ) +) + +set_pred( + mod = "boost_tree", + eng = "C5.0", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = list(object = quote(object$fit), newdata = quote(new_data)) + ) +) + +set_pred( + mod = "boost_tree", + eng = "C5.0", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = function(x, object) { + 
as_tibble(x) + }, + func = c(fun = "predict"), + args = + list( object = quote(object$fit), - newdata = quote(new_data) + newdata = quote(new_data), + type = "prob" ) - ) - ) - - -boost_tree_spark_data <- - list( - libs = "sparklyr", - fit = list( - interface = "formula", - protect = c("x", "formula", "type"), - func = c(pkg = "sparklyr", fun = "ml_gradient_boosted_trees"), - defaults = - list( - seed = expr(sample.int(10^5, 1)) - ) - ), - numeric = list( - pre = NULL, - post = format_spark_num, - func = c(pkg = "sparklyr", fun = "ml_predict"), - args = - list( - x = quote(object$fit), - dataset = quote(new_data) - ) - ), - class = list( - pre = NULL, - post = format_spark_class, - func = c(pkg = "sparklyr", fun = "ml_predict"), - args = - list( - x = quote(object$fit), - dataset = quote(new_data) - ) - ), - classprob = list( - pre = NULL, - post = format_spark_probs, - func = c(pkg = "sparklyr", fun = "ml_predict"), - args = - list( - x = quote(object$fit), - dataset = quote(new_data) - ) - ) ) +) + +set_pred( + mod = "boost_tree", + eng = "C5.0", + mode = "classification", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = list(object = quote(object$fit), + newdata = quote(new_data)) + ) +) + +# ------------------------------------------------------------------------------ + +set_model_engine("boost_tree", "classification", "spark") +set_model_engine("boost_tree", "regression", "spark") +set_dependency("boost_tree", "spark", "sparklyr") + +set_model_arg( + mod = "boost_tree", + eng = "spark", + val = "tree_depth", + original = "max_depth", + func = list(pkg = "dials", fun = "tree_depth"), + submodels = FALSE +) +set_model_arg( + mod = "boost_tree", + eng = "spark", + val = "trees", + original = "max_iter", + func = list(pkg = "dials", fun = "trees"), + submodels = TRUE +) +set_model_arg( + mod = "boost_tree", + eng = "spark", + val = "learn_rate", + original = "step_size", + func = list(pkg = "dials", fun = "learn_rate"), + submodels = FALSE +) +set_model_arg( + mod = "boost_tree", + eng = "spark", + val = "mtry", + original = "feature_subset_strategy", + func = list(pkg = "dials", fun = "mtry"), + submodels = FALSE +) +set_model_arg( + mod = "boost_tree", + eng = "spark", + val = "min_n", + original = "min_instances_per_node", + func = list(pkg = "dials", fun = "min_n"), + submodels = FALSE +) +set_model_arg( + mod = "boost_tree", + eng = "spark", + val = "min_info_gain", + original = "gamma", + func = list(pkg = "dials", fun = "loss_reduction"), + submodels = FALSE +) +set_model_arg( + mod = "boost_tree", + eng = "spark", + val = "sample_size", + original = "subsampling_rate", + func = list(pkg = "dials", fun = "sample_size"), + submodels = FALSE +) + +set_fit( + mod = "boost_tree", + eng = "spark", + mode = "regression", + value = list( + interface = "formula", + protect = c("x", "formula", "type"), + func = c(pkg = "sparklyr", fun = "ml_gradient_boosted_trees"), + defaults = list(seed = expr(sample.int(10 ^ 5, 1))) + ) +) + +set_fit( + mod = "boost_tree", + eng = "spark", + mode = "classification", + value = list( + interface = "formula", + protect = c("x", "formula", "type"), + func = c(pkg = "sparklyr", fun = "ml_gradient_boosted_trees"), + defaults = list(seed = expr(sample.int(10 ^ 5, 1))) + ) +) + +set_pred( + mod = "boost_tree", + eng = "spark", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = format_spark_num, + func = c(pkg = "sparklyr", fun = "ml_predict"), + args = list(x = quote(object$fit), 
dataset = quote(new_data)) + ) +) + +set_pred( + mod = "boost_tree", + eng = "spark", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = format_spark_class, + func = c(pkg = "sparklyr", fun = "ml_predict"), + args = list(x = quote(object$fit), dataset = quote(new_data)) + ) +) + +set_pred( + mod = "boost_tree", + eng = "spark", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = format_spark_probs, + func = c(pkg = "sparklyr", fun = "ml_predict"), + args = list(x = quote(object$fit), dataset = quote(new_data)) + ) +) diff --git a/R/decision_tree_data.R b/R/decision_tree_data.R index b4a599063..cd443312a 100644 --- a/R/decision_tree_data.R +++ b/R/decision_tree_data.R @@ -1,160 +1,297 @@ +set_new_model("decision_tree") -decision_tree_arg_key <- data.frame( - rpart = c( "maxdepth", "minsplit", "cp"), - C5.0 = c( NA, "minCases", NA), - spark = c("max_depth", "min_instances_per_node", NA), - stringsAsFactors = FALSE, - row.names = c("tree_depth", "min_n", "cost_complexity") +set_model_mode("decision_tree", "classification") +set_model_mode("decision_tree", "regression") + +# ------------------------------------------------------------------------------ + +set_model_engine("decision_tree", "classification", "rpart") +set_model_engine("decision_tree", "regression", "rpart") +set_dependency("decision_tree", "rpart", "rpart") + +set_model_arg( + mod = "decision_tree", + eng = "rpart", + val = "tree_depth", + original = "maxdepth", + func = list(pkg = "dials", fun = "tree_depth"), + submodels = FALSE ) -decision_tree_modes <- c("classification", "regression", "unknown") +set_model_arg( + mod = "decision_tree", + eng = "rpart", + val = "min_n", + original = "minsplit", + func = list(pkg = "dials", fun = "min_n"), + submodels = FALSE +) -decision_tree_engines <- data.frame( - rpart = rep(TRUE, 3), - C5.0 = c(TRUE, FALSE, TRUE), - spark = rep(TRUE, 3), - row.names = c("classification", "regression", "unknown") +set_model_arg( + mod = "decision_tree", + eng = "rpart", + val = "cost_complexity", + original = "cp", + func = list(pkg = "dials", fun = "cost_complexity"), + submodels = FALSE ) -# ------------------------------------------------------------------------------ +set_fit( + mod = "decision_tree", + eng = "rpart", + mode = "regression", + value = list( + interface = "formula", + protect = c("formula", "data", "weights"), + func = c(pkg = "rpart", fun = "rpart"), + defaults = list() + ) +) + +set_fit( + mod = "decision_tree", + eng = "rpart", + mode = "classification", + value = list( + interface = "formula", + protect = c("formula", "data", "weights"), + func = c(pkg = "rpart", fun = "rpart"), + defaults = list() + ) +) + +set_pred( + mod = "decision_tree", + eng = "rpart", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = list(object = quote(object$fit), newdata = quote(new_data)) + ) +) + +set_pred( + mod = "decision_tree", + eng = "rpart", + mode = "regression", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = list(object = quote(object$fit), newdata = quote(new_data)) + ) +) -decision_tree_rpart_data <- - list( - libs = "rpart", - fit = list( - interface = "formula", - protect = c("formula", "data", "weights"), - func = c(pkg = "rpart", fun = "rpart"), - defaults = - list() - ), - numeric = list( - pre = NULL, - post = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - 
newdata = quote(new_data) - ) - ), - class = list( - pre = NULL, - post = NULL, - func = c(pkg = NULL, fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "class" - ) - ), - classprob = list( - pre = NULL, - post = function(x, object) { - as_tibble(x) - }, - func = c(pkg = NULL, fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ), - raw = list( - pre = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ) - ) - - -decision_tree_C5.0_data <- - list( - libs = "C50", - fit = list( - interface = "data.frame", - protect = c("x", "y", "weights"), - func = c(pkg = "parsnip", fun = "C5.0_train"), - defaults = list(trials = 1) - ), - class = list( - pre = NULL, - post = NULL, - func = c(fun = "predict"), - args = list( +set_pred( + mod = "decision_tree", + eng = "rpart", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = NULL, + func = c(pkg = NULL, fun = "predict"), + args = + list( object = quote(object$fit), - newdata = quote(new_data) + newdata = quote(new_data), + type = "class" ) - ), - classprob = list( - pre = NULL, - post = function(x, object) { - as_tibble(x) - }, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "prob" - ) - ), - raw = list( - pre = NULL, - func = c(fun = "predict"), - args = list( + ) +) + +set_pred( + mod = "decision_tree", + eng = "rpart", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = function(x, object) { + as_tibble(x) + }, + func = c(pkg = NULL, fun = "predict"), + args = list(object = quote(object$fit), newdata = quote(new_data)) + ) +) + +set_pred( + mod = "decision_tree", + eng = "rpart", + mode = "classification", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = list(object = quote(object$fit), newdata = quote(new_data)) + ) +) + +# ------------------------------------------------------------------------------ + +set_model_engine("decision_tree", "classification", "C5.0") +set_dependency("decision_tree", "C5.0", "C5.0") + +set_model_arg( + mod = "decision_tree", + eng = "C5.0", + val = "min_n", + original = "minsplit", + func = list(pkg = "dials", fun = "min_n"), + submodels = FALSE +) + +set_fit( + mod = "decision_tree", + eng = "C5.0", + mode = "classification", + value = list( + interface = "data.frame", + protect = c("x", "y", "weights"), + func = c(pkg = "parsnip", fun = "C5.0_train"), + defaults = list(trials = 1) + ) +) + +set_pred( + mod = "decision_tree", + eng = "C5.0", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = list(object = quote(object$fit), newdata = quote(new_data)) + ) +) + + +set_pred( + mod = "decision_tree", + eng = "C5.0", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = function(x, object) { + as_tibble(x) + }, + func = c(fun = "predict"), + args = + list( object = quote(object$fit), - newdata = quote(new_data) + newdata = quote(new_data), + type = "prob" ) - ) - ) - - -decision_tree_spark_data <- - list( - libs = "sparklyr", - fit = list( - interface = "formula", - protect = c("x", "formula"), - func = c(pkg = "sparklyr", fun = "ml_decision_tree_classifier"), - defaults = - list( - seed = expr(sample.int(10^5, 1)) - ) - ), - numeric = list( - pre = NULL, - post = format_spark_num, 
- func = c(pkg = "sparklyr", fun = "ml_predict"), - args = - list( - x = quote(object$fit), - dataset = quote(new_data) - ) - ), - class = list( - pre = NULL, - post = format_spark_class, - func = c(pkg = "sparklyr", fun = "ml_predict"), - args = - list( - x = quote(object$fit), - dataset = quote(new_data) - ) - ), - classprob = list( - pre = NULL, - post = format_spark_probs, - func = c(pkg = "sparklyr", fun = "ml_predict"), - args = - list( - x = quote(object$fit), - dataset = quote(new_data) - ) - ) ) +) + + +set_pred( + mod = "decision_tree", + eng = "C5.0", + mode = "classification", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = list(object = quote(object$fit), + newdata = quote(new_data)) + ) +) + +# ------------------------------------------------------------------------------ + +set_model_engine("decision_tree", "classification", "spark") +set_model_engine("decision_tree", "regression", "spark") +set_dependency("decision_tree", "spark", "spark") + +set_model_arg( + mod = "decision_tree", + eng = "spark", + val = "tree_depth", + original = "max_depth", + func = list(pkg = "dials", fun = "tree_depth"), + submodels = FALSE +) + +set_model_arg( + mod = "decision_tree", + eng = "spark", + val = "min_n", + original = "min_instances_per_node", + func = list(pkg = "dials", fun = "min_n"), + submodels = FALSE +) + +set_fit( + mod = "decision_tree", + eng = "spark", + mode = "regression", + value = list( + interface = "formula", + protect = c("x", "formula"), + func = c(pkg = "sparklyr", fun = "ml_decision_tree_classifier"), + defaults = + list(seed = expr(sample.int(10 ^ 5, 1))) + ) +) + +set_fit( + mod = "decision_tree", + eng = "spark", + mode = "classification", + value = list( + interface = "formula", + protect = c("x", "formula"), + func = c(pkg = "sparklyr", fun = "ml_decision_tree_classifier"), + defaults = + list(seed = expr(sample.int(10 ^ 5, 1))) + ) +) + +set_pred( + mod = "decision_tree", + eng = "spark", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = format_spark_num, + func = c(pkg = "sparklyr", fun = "ml_predict"), + args = list(object = quote(object$fit), dataset = quote(new_data)) + ) +) + +set_pred( + mod = "decision_tree", + eng = "spark", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = format_spark_class, + func = c(pkg = "sparklyr", fun = "ml_predict"), + args = list(object = quote(object$fit), dataset = quote(new_data)) + ) +) + +set_pred( + mod = "decision_tree", + eng = "spark", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = format_spark_probs, + func = c(pkg = "sparklyr", fun = "ml_predict"), + args = list(object = quote(object$fit), dataset = quote(new_data)) + ) +) diff --git a/R/linear_reg_data.R b/R/linear_reg_data.R index 6225d1023..27cfb71f8 100644 --- a/R/linear_reg_data.R +++ b/R/linear_reg_data.R @@ -1,265 +1,348 @@ +set_new_model("linear_reg") -linear_reg_arg_key <- data.frame( - lm = c( NA, NA), - glmnet = c( "lambda", "alpha"), - spark = c("reg_param", "elastic_net_param"), - stan = c( NA, NA), - keras = c( "decay", NA), - stringsAsFactors = FALSE, - row.names = c("penalty", "mixture") +set_model_mode("linear_reg", "regression") + +# ------------------------------------------------------------------------------ + +set_model_engine("linear_reg", "regression", "lm") +set_dependency("linear_reg", "lm", "stats") + +set_fit( + mod = "linear_reg", + eng = "lm", + mode = "regression", + value = list( + 
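+    # A fit module records how parsnip should build the model call:
+    # `interface` declares what the engine function consumes ("formula",
+    # "data.frame", or "matrix"), `protect` names arguments reserved for
+    # parsnip itself (users cannot override them via set_engine()), and
+    # `defaults` are engine arguments filled in when the user supplies
+    # none. For this engine the translated call is roughly (a sketch):
+    #   stats::lm(formula = formula, data = data, weights = weights)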
interface = "formula", + protect = c("formula", "data", "weights"), + func = c(pkg = "stats", fun = "lm"), + defaults = list() + ) ) -linear_reg_modes <- "regression" +set_pred( + mod = "linear_reg", + eng = "lm", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = expr(object$fit), + newdata = expr(new_data), + type = "response" + ) + ) +) -linear_reg_engines <- data.frame( - lm = TRUE, - glmnet = TRUE, - spark = TRUE, - stan = TRUE, - keras = TRUE, - row.names = c("regression") +set_pred( + mod = "linear_reg", + eng = "lm", + mode = "regression", + type = "conf_int", + value = list( + pre = NULL, + post = function(results, object) { + tibble::as_tibble(results) %>% + dplyr::select(-fit) %>% + setNames(c(".pred_lower", ".pred_upper")) + }, + func = c(fun = "predict"), + args = + list( + object = expr(object$fit), + newdata = expr(new_data), + interval = "confidence", + level = expr(level), + type = "response" + ) + ) +) +set_pred( + mod = "linear_reg", + eng = "lm", + mode = "regression", + type = "pred_int", + value = list( + pre = NULL, + post = function(results, object) { + tibble::as_tibble(results) %>% + dplyr::select(-fit) %>% + setNames(c(".pred_lower", ".pred_upper")) + }, + func = c(fun = "predict"), + args = + list( + object = expr(object$fit), + newdata = expr(new_data), + interval = "prediction", + level = expr(level), + type = "response" + ) + ) ) +set_pred( + mod = "linear_reg", + eng = "lm", + mode = "regression", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = list(object = expr(object$fit), newdata = expr(new_data)) + ) +) # ------------------------------------------------------------------------------ -linear_reg_lm_data <- - list( - libs = "stats", - fit = list( - interface = "formula", - protect = c("formula", "data", "weights"), - func = c(pkg = "stats", fun = "lm"), - defaults = list() - ), - numeric = list( - pre = NULL, - post = NULL, - func = c(fun = "predict"), - args = - list( - object = expr(object$fit), - newdata = expr(new_data), - type = "response" - ) - ), - confint = list( - pre = NULL, - post = function(results, object) { - tibble::as_tibble(results) %>% - dplyr::select(-fit) %>% - setNames(c(".pred_lower", ".pred_upper")) - }, - func = c(fun = "predict"), - args = - list( - object = expr(object$fit), - newdata = expr(new_data), - interval = "confidence", - level = expr(level), - type = "response" - ) - ), - predint = list( - pre = NULL, - post = function(results, object) { - tibble::as_tibble(results) %>% - dplyr::select(-fit) %>% - setNames(c(".pred_lower", ".pred_upper")) - }, - func = c(fun = "predict"), - args = - list( - object = expr(object$fit), - newdata = expr(new_data), - interval = "prediction", - level = expr(level), - type = "response" - ) - ), - raw = list( - pre = NULL, - func = c(fun = "predict"), - args = - list( - object = expr(object$fit), - newdata = expr(new_data) - ) - ) +set_model_engine("linear_reg", "regression", "glmnet") +set_dependency("linear_reg", "glmnet", "glmnet") + +set_model_arg( + mod = "linear_reg", + eng = "glmnet", + val = "penalty", + original = "lambda", + func = list(pkg = "dials", fun = "penalty"), + submodels = TRUE +) + +set_model_arg( + mod = "linear_reg", + eng = "glmnet", + val = "mixture", + original = "alpha", + func = list(pkg = "dials", fun = "mixture"), + submodels = FALSE +) + +set_fit( + mod = "linear_reg", + eng = "glmnet", + mode = "regression", + value = 
list( + interface = "matrix", + protect = c("x", "y", "weights"), + func = c(pkg = "glmnet", fun = "glmnet"), + defaults = list(family = "gaussian") ) +) -# Note: For glmnet, you will need to make model-specific predict methods. -# See linear_reg.R -linear_reg_glmnet_data <- - list( - libs = "glmnet", - fit = list( - interface = "matrix", - protect = c("x", "y", "weights"), - func = c(pkg = "glmnet", fun = "glmnet"), - defaults = - list( - family = "gaussian" - ) - ), - numeric = list( - pre = NULL, - post = organize_glmnet_pred, - func = c(fun = "predict"), - args = - list( - object = expr(object$fit), - newx = expr(as.matrix(new_data)), - type = "response", - s = expr(object$spec$args$penalty) - ) - ), - raw = list( - pre = NULL, - func = c(fun = "predict"), - args = - list( - object = expr(object$fit), - newx = expr(as.matrix(new_data)) - ) - ) +set_pred( + mod = "linear_reg", + eng = "glmnet", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = organize_glmnet_pred, + func = c(fun = "predict"), + args = + list( + object = expr(object$fit), + newx = expr(as.matrix(new_data)), + type = "response", + s = expr(object$spec$args$penalty) + ) ) +) -linear_reg_stan_data <- - list( - libs = "rstanarm", - fit = list( - interface = "formula", - protect = c("formula", "data", "weights"), - func = c(pkg = "rstanarm", fun = "stan_glm"), - defaults = - list( - family = expr(stats::gaussian) - ) - ), - numeric = list( - pre = NULL, - post = NULL, - func = c(fun = "predict"), - args = - list( - object = expr(object$fit), - newdata = expr(new_data) - ) - ), - confint = list( - pre = NULL, - post = function(results, object) { - res <- - tibble( - .pred_lower = - convert_stan_interval( - results, - level = object$spec$method$confint$extras$level - ), - .pred_upper = - convert_stan_interval( - results, - level = object$spec$method$confint$extras$level, - lower = FALSE - ), - ) - if(object$spec$method$confint$extras$std_error) - res$.std_error <- apply(results, 2, sd, na.rm = TRUE) - res - }, - func = c(pkg = "rstanarm", fun = "posterior_linpred"), - args = - list( - object = expr(object$fit), - newdata = expr(new_data), - transform = TRUE, - seed = expr(sample.int(10^5, 1)) - ) - ), - predint = list( - pre = NULL, - post = function(results, object) { - res <- - tibble( - .pred_lower = - convert_stan_interval( - results, - level = object$spec$method$predint$extras$level - ), - .pred_upper = - convert_stan_interval( - results, - level = object$spec$method$predint$extras$level, - lower = FALSE - ), - ) - if(object$spec$method$predint$extras$std_error) - res$.std_error <- apply(results, 2, sd, na.rm = TRUE) - res - }, - func = c(pkg = "rstanarm", fun = "posterior_predict"), - args = - list( - object = expr(object$fit), - newdata = expr(new_data), - seed = expr(sample.int(10^5, 1)) - ) - ), - raw = list( - pre = NULL, - func = c(fun = "predict"), - args = - list( - object = expr(object$fit), - newdata = expr(new_data) - ) - ) +set_pred( + mod = "linear_reg", + eng = "glmnet", + mode = "regression", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list(object = expr(object$fit), + newx = expr(as.matrix(new_data))) ) +) -#' @importFrom dplyr select rename -linear_reg_spark_data <- - list( - libs = "sparklyr", - fit = list( - interface = "formula", - protect = c("x", "formula", "weight_col"), - func = c(pkg = "sparklyr", fun = "ml_linear_regression") - ), - numeric = list( - pre = NULL, - post = function(results, object) { - results 
<- dplyr::rename(results, pred = prediction) - results <- dplyr::select(results, pred) - results - }, - func = c(pkg = "sparklyr", fun = "ml_predict"), - args = - list( - x = expr(object$fit), - dataset = expr(new_data) +# ------------------------------------------------------------------------------ + +set_model_engine("linear_reg", "regression", "stan") +set_dependency("linear_reg", "stan", "rstanarm") + +set_fit( + mod = "linear_reg", + eng = "stan", + mode = "regression", + value = list( + interface = "formula", + protect = c("formula", "data", "weights"), + func = c(pkg = "rstanarm", fun = "stan_glm"), + defaults = list(family = expr(stats::gaussian)) + ) +) + +set_pred( + mod = "linear_reg", + eng = "stan", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = list(object = expr(object$fit), newdata = expr(new_data)) + ) +) + +set_pred( + mod = "linear_reg", + eng = "stan", + mode = "regression", + type = "conf_int", + value = list( + pre = NULL, + post = function(results, object) { + res <- + tibble( + .pred_lower = + convert_stan_interval( + results, + level = object$spec$method$confint$extras$level + ), + .pred_upper = + convert_stan_interval( + results, + level = object$spec$method$confint$extras$level, + lower = FALSE + ), ) - ) + if(object$spec$method$confint$extras$std_error) + res$.std_error <- apply(results, 2, sd, na.rm = TRUE) + res + }, + func = c(pkg = "rstanarm", fun = "posterior_linpred"), + args = + list( + object = expr(object$fit), + newdata = expr(new_data), + transform = TRUE, + seed = expr(sample.int(10^5, 1)) + ) ) +) -linear_reg_keras_data <- - list( - libs = c("keras", "magrittr"), - fit = list( - interface = "matrix", - protect = c("x", "y"), - func = c(pkg = "parsnip", fun = "keras_mlp"), - defaults = list(hidden_units = 1, act = "linear") - ), - numeric = list( - pre = NULL, - post = maybe_multivariate, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - x = quote(as.matrix(new_data)) +set_pred( + mod = "linear_reg", + eng = "stan", + mode = "regression", + type = "pred_int", + value = list( + pre = NULL, + post = function(results, object) { + res <- + tibble( + .pred_lower = + convert_stan_interval( + results, + level = object$spec$method$predint$extras$level + ), + .pred_upper = + convert_stan_interval( + results, + level = object$spec$method$predint$extras$level, + lower = FALSE + ), ) - ) + if(object$spec$method$predint$extras$std_error) + res$.std_error <- apply(results, 2, sd, na.rm = TRUE) + res + }, + func = c(pkg = "rstanarm", fun = "posterior_predict"), + args = + list( + object = expr(object$fit), + newdata = expr(new_data), + seed = expr(sample.int(10^5, 1)) + ) + ) +) + +set_pred( + mod = "linear_reg", + eng = "stan", + mode = "regression", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = list(object = expr(object$fit), newdata = expr(new_data)) + ) +) + +# ------------------------------------------------------------------------------ + +set_model_engine("linear_reg", "regression", "spark") +set_dependency("linear_reg", "spark", "sparklyr") + +set_fit( + mod = "linear_reg", + eng = "spark", + mode = "regression", + value = list( + interface = "formula", + protect = c("x", "formula", "weight_col"), + func = c(pkg = "sparklyr", fun = "ml_linear_regression") + ) +) + +set_pred( + mod = "linear_reg", + eng = "spark", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = 
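+    # A `post` hook receives the raw engine output plus the parsnip model
+    # object and must return predictions in the standardized shape.
+    # ml_predict() returns the input columns along with a `prediction`
+    # column, so the function below keeps only that column, renamed to
+    # `pred`.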
function(results, object) { + results <- dplyr::rename(results, pred = prediction) + results <- dplyr::select(results, pred) + results + }, + func = c(pkg = "sparklyr", fun = "ml_predict"), + args = list(x = expr(object$fit), dataset = expr(new_data)) + ) +) + +# ------------------------------------------------------------------------------ + + +set_model_engine("linear_reg", "regression", "keras") +set_dependency("linear_reg", "keras", c("keras", "magrittr")) + +set_fit( + mod = "linear_reg", + eng = "keras", + mode = "regression", + value = list( + interface = "matrix", + protect = c("x", "y"), + func = c(pkg = "parsnip", fun = "keras_mlp"), + defaults = list(hidden_units = 1, act = "linear") ) +) + +set_pred( + mod = "linear_reg", + eng = "keras", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = maybe_multivariate, + func = c(fun = "predict"), + args = list(object = quote(object$fit), x = quote(as.matrix(new_data))) + ) +) diff --git a/R/mars_data.R b/R/mars_data.R index 0c1076e68..99b5ffc22 100644 --- a/R/mars_data.R +++ b/R/mars_data.R @@ -14,64 +14,154 @@ mars_engines <- data.frame( # ------------------------------------------------------------------------------ -mars_earth_data <- - list( - libs = "earth", - fit = list( - interface = "data.frame", - protect = c("x", "y", "weights"), - func = c(pkg = "earth", fun = "earth"), - defaults = list(keepxy = TRUE) - ), - numeric = list( - pre = NULL, - post = maybe_multivariate, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "response" - ) - ), - class = list( - pre = NULL, - post = function(x, object) { - x <- ifelse(x[,1] >= 0.5, object$lvl[2], object$lvl[1]) - x - }, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "response" - ) - ), - classprob = list( - pre = NULL, - post = function(x, object) { - x <- x[,1] - x <- tibble(v1 = 1 - x, v2 = x) - colnames(x) <- object$lvl - x - }, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "response" - ) - ), - raw = list( - pre = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ) +set_new_model("mars") + +set_model_mode("mars", "classification") +set_model_mode("mars", "regression") + +# ------------------------------------------------------------------------------ + +set_model_engine("mars", "classification", "earth") +set_model_engine("mars", "regression", "earth") +set_dependency("mars", "earth", "earth") + +set_model_arg( + mod = "mars", + eng = "earth", + val = "num_terms", + original = "nprune", + func = list(pkg = "dials", fun = "num_terms"), + submodels = FALSE +) +set_model_arg( + mod = "mars", + eng = "earth", + val = "prod_degree", + original = "degree", + func = list(pkg = "dials", fun = "prod_degree"), + submodels = FALSE +) +set_model_arg( + mod = "mars", + eng = "earth", + val = "prune_method", + original = "pmethod", + func = list(pkg = "dials", fun = "prune_method"), + submodels = FALSE +) + +set_fit( + mod = "mars", + eng = "earth", + mode = "regression", + value = list( + interface = "data.frame", + protect = c("x", "y", "weights"), + func = c(pkg = "earth", fun = "earth"), + defaults = list(keepxy = TRUE) ) +) +set_fit( + mod = "mars", + eng = "earth", + mode = "classification", + value = list( + interface = "data.frame", + protect = c("x", "y", "weights"), + func = c(pkg = "earth", fun 
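+    # The `keepxy = TRUE` default below asks earth() to retain the
+    # training data inside the fitted object; this appears to be what
+    # lets downstream methods such as multi_predict() re-prune over
+    # `nprune` values via update() without refitting from scratch (an
+    # inference from the registrations, not documented here).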
= "earth"), + defaults = list(keepxy = TRUE) + ) +) + +set_pred( + mod = "mars", + eng = "earth", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = maybe_multivariate, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "response" + ) + ) +) + +set_pred( + mod = "mars", + eng = "earth", + mode = "regression", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list(object = quote(object$fit), + newdata = quote(new_data)) + ) +) + +set_pred( + mod = "mars", + eng = "earth", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = function(x, object) { + x <- ifelse(x[, 1] >= 0.5, object$lvl[2], object$lvl[1]) + x + }, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "response" + ) + ) +) + +set_pred( + mod = "mars", + eng = "earth", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = function(x, object) { + x <- x[, 1] + x <- tibble(v1 = 1 - x, v2 = x) + colnames(x) <- object$lvl + x + }, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "response" + ) + ) +) + +set_pred( + mod = "mars", + eng = "earth", + mode = "classification", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list(object = quote(object$fit), + newdata = quote(new_data)) + ) +) diff --git a/R/rand_forest_data.R b/R/rand_forest_data.R index 3d6b50d69..132f5047b 100644 --- a/R/rand_forest_data.R +++ b/R/rand_forest_data.R @@ -95,211 +95,21 @@ ranger_confint <- function(object, new_data, ...) { # ------------------------------------------------------------------------------ - -rand_forest_ranger_data <- - list( - libs = "ranger", - fit = list( - interface = "formula", - protect = c("formula", "data", "case.weights"), - func = c(pkg = "ranger", fun = "ranger"), - defaults = - list( - num.threads = 1, - verbose = FALSE, - seed = expr(sample.int(10^5, 1)) - ) - ), - numeric = list( - pre = NULL, - post = function(results, object) results$predictions, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - data = quote(new_data), - type = "response", - seed = expr(sample.int(10^5, 1)), - verbose = FALSE - ) - ), - class = list( - pre = NULL, - post = ranger_class_pred, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - data = quote(new_data), - type = "response", - seed = expr(sample.int(10^5, 1)), - verbose = FALSE - ) - ), - classprob = list( - pre = function(x, object) { - if (object$fit$forest$treetype != "Probability estimation") - stop("`ranger` model does not appear to use class probabilities. Was ", - "the model fit with `probability = TRUE`?", - call. 
= FALSE) - x - }, - post = function(x, object) { - x <- x$prediction - as_tibble(x) - }, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - data = quote(new_data), - seed = expr(sample.int(10^5, 1)), - verbose = FALSE - ) - ), - raw = list( - pre = NULL, - post = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - data = quote(new_data), - seed = expr(sample.int(10^5, 1)) - ) - ), - confint = list( - pre = NULL, - post = NULL, - func = c(fun = "ranger_confint"), - args = - list( - object = quote(object), - new_data = quote(new_data), - seed = expr(sample.int(10^5, 1)) - ) - ) - ) - -rand_forest_randomForest_data <- - list( - libs = "randomForest", - fit = list( - interface = "data.frame", - protect = c("x", "y"), - func = c(pkg = "randomForest", fun = "randomForest"), - defaults = - list() - ), - numeric = list( - pre = NULL, - post = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ), - class = list( - pre = NULL, - post = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ), - classprob = list( - pre = NULL, - post = function(x, object) { - as_tibble(as.data.frame(x)) - }, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "prob" - ) - ), - raw = list( - pre = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ) - ) - - -rand_forest_spark_data <- - list( - libs = "sparklyr", - fit = list( - interface = "formula", - protect = c("x", "formula", "type"), - func = c(pkg = "sparklyr", fun = "ml_random_forest"), - defaults = - list( - seed = expr(sample.int(10^5, 1)) - ) - ), - numeric = list( - pre = NULL, - post = format_spark_num, - func = c(pkg = "sparklyr", fun = "ml_predict"), - args = - list( - x = quote(object$fit), - dataset = quote(new_data) - ) - ), - class = list( - pre = NULL, - post = format_spark_class, - func = c(pkg = "sparklyr", fun = "ml_predict"), - args = - list( - x = quote(object$fit), - dataset = quote(new_data) - ) - ), - classprob = list( - pre = NULL, - post = format_spark_probs, - func = c(pkg = "sparklyr", fun = "ml_predict"), - args = - list( - x = quote(object$fit), - dataset = quote(new_data) - ) - ) - ) - - -# ------------------------------------------------------------------------------ - - set_new_model("rand_forest") set_model_mode("rand_forest", "classification") set_model_mode("rand_forest", "regression") -set_model_engine("rand_forest", "classification", "randomForest") -set_model_engine("rand_forest", "classification", "ranger") -set_model_engine("rand_forest", "classification", "spark") +# ------------------------------------------------------------------------------ +# ranger components -set_model_engine("rand_forest", "regression", "randomForest") +set_model_engine("rand_forest", "classification", "ranger") set_model_engine("rand_forest", "regression", "ranger") -set_model_engine("rand_forest", "regression", "spark") +set_dependency("rand_forest", "ranger", "ranger") set_model_arg( mod = "rand_forest", - eng = "randomForest", + eng = "ranger", val = "mtry", original = "mtry", func = list(pkg = "dials", fun = "mtry"), @@ -307,24 +117,172 @@ set_model_arg( ) set_model_arg( mod = "rand_forest", - eng = "randomForest", + eng = "ranger", val = "trees", - original = "ntree", + original = "num.trees", func = list(pkg = "dials", fun = "trees"), 
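+  # Each set_model_arg() call maps a parsnip argument (`val`) to the
+  # engine's native name (`original`) and to the dials function used when
+  # tuning. `submodels = FALSE` records that tuning over this argument
+  # requires a fresh fit per value, unlike, say, glmnet's `penalty`.
+  # With this mapping, a spec written as (a sketch)
+  #   rand_forest(trees = 500) %>% set_engine("ranger")
+  # translates to a ranger() call with num.trees = 500.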
submodels = FALSE ) set_model_arg( mod = "rand_forest", - eng = "randomForest", + eng = "ranger", val = "min_n", - original = "nodesize", + original = "min.node.size", func = list(pkg = "dials", fun = "min_n"), submodels = FALSE ) -set_model_arg( +set_fit( + mod = "rand_forest", + eng = "ranger", + mode = "classification", + value = list( + interface = "formula", + protect = c("formula", "data", "case.weights"), + func = c(pkg = "ranger", fun = "ranger"), + defaults = + list( + num.threads = 1, + verbose = FALSE, + seed = expr(sample.int(10 ^ 5, 1)) + ) + ) +) + +set_fit( + mod = "rand_forest", + eng = "ranger", + mode = "regression", + value = list( + interface = "formula", + protect = c("formula", "data", "case.weights"), + func = c(pkg = "ranger", fun = "ranger"), + defaults = + list( + num.threads = 1, + verbose = FALSE, + seed = expr(sample.int(10 ^ 5, 1)) + ) + ) +) + +set_pred( + mod = "rand_forest", + eng = "ranger", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = ranger_class_pred, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + data = quote(new_data), + type = "response", + seed = expr(sample.int(10 ^ 5, 1)), + verbose = FALSE + ) + ) +) + +set_pred( + mod = "rand_forest", + eng = "ranger", + mode = "classification", + type = "prob", + value = list( + pre = function(x, object) { + if (object$fit$forest$treetype != "Probability estimation") + stop( + "`ranger` model does not appear to use class probabilities. Was ", + "the model fit with `probability = TRUE`?", + call. = FALSE + ) + x + }, + post = function(x, object) { + x <- x$prediction + as_tibble(x) + }, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + data = quote(new_data), + seed = expr(sample.int(10 ^ 5, 1)), + verbose = FALSE + ) + ) +) + +set_pred( + mod = "rand_forest", + eng = "ranger", + mode = "classification", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + data = quote(new_data), + seed = expr(sample.int(10 ^ 5, 1)) + ) + ) +) + +set_pred( + mod = "rand_forest", + eng = "ranger", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = function(results, object) + results$predictions, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + data = quote(new_data), + type = "response", + seed = expr(sample.int(10 ^ 5, 1)), + verbose = FALSE + ) + ) +) + +set_pred( mod = "rand_forest", eng = "ranger", + mode = "regression", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + data = quote(new_data), + seed = expr(sample.int(10 ^ 5, 1)) + ) + ) +) + +# ------------------------------------------------------------------------------ +# randomForest components + +set_model_engine("rand_forest", "classification", "randomForest") +set_model_engine("rand_forest", "regression", "randomForest") +set_dependency("rand_forest", "randomForest", "randomForest") + +set_model_arg( + mod = "rand_forest", + eng = "randomForest", val = "mtry", original = "mtry", func = list(pkg = "dials", fun = "mtry"), @@ -332,21 +290,133 @@ set_model_arg( ) set_model_arg( mod = "rand_forest", - eng = "ranger", + eng = "randomForest", val = "trees", - original = "num.trees", + original = "ntree", func = list(pkg = "dials", fun = "trees"), submodels = FALSE ) set_model_arg( mod = "rand_forest", - eng = "ranger", + eng = "randomForest", val = 
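+  # The same parsnip names deliberately resolve to different native
+  # arguments per engine: trees/min_n map to num.trees/min.node.size for
+  # ranger but ntree/nodesize for randomForest, so a specification such
+  # as rand_forest(trees = 1000, min_n = 5) can switch engines without
+  # any change by the user.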
"min_n", - original = "min.node.size", + original = "nodesize", func = list(pkg = "dials", fun = "min_n"), submodels = FALSE ) +set_fit( + mod = "rand_forest", + eng = "randomForest", + mode = "classification", + value = list( + interface = "data.frame", + protect = c("x", "y"), + func = c(pkg = "randomForest", fun = "randomForest"), + defaults = + list() + ) +) + +set_fit( + mod = "rand_forest", + eng = "randomForest", + mode = "regression", + value = list( + interface = "data.frame", + protect = c("x", "y"), + func = c(pkg = "randomForest", fun = "randomForest"), + defaults = + list() + ) +) + +set_pred( + mod = "rand_forest", + eng = "randomForest", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list(object = quote(object$fit), + newdata = quote(new_data)) + ) +) + +set_pred( + mod = "rand_forest", + eng = "randomForest", + mode = "regression", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list(object = quote(object$fit), + newdata = quote(new_data)) + ) +) + +set_pred( + mod = "rand_forest", + eng = "randomForest", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = list(object = quote(object$fit), newdata = quote(new_data)) + ) +) + + +set_pred( + mod = "rand_forest", + eng = "randomForest", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = function(x, object) { + as_tibble(as.data.frame(x)) + }, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "prob" + ) + ) +) + +set_pred( + mod = "rand_forest", + eng = "randomForest", + mode = "classification", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list(object = quote(object$fit), + newdata = quote(new_data)) + ) +) + +# ------------------------------------------------------------------------------ +# spark components + +set_model_engine("rand_forest", "classification", "spark") +set_model_engine("rand_forest", "regression", "spark") +set_dependency("rand_forest", "spark", "sparklyr") + set_model_arg( mod = "rand_forest", eng = "spark", @@ -372,24 +442,71 @@ set_model_arg( submodels = FALSE ) -set_fit(mod = "rand_forest", eng = "ranger", mode = "classification", - value = rand_forest_ranger_data$fit) - -set_fit(mod = "rand_forest", eng = "ranger", mode = "regression", - value = rand_forest_ranger_data$fit) - -set_pred(mod = "rand_forest", eng = "ranger", mode = "classification", - type = "class", value = parsnip:::rand_forest_ranger_data$class) - -set_pred(mod = "rand_forest", eng = "ranger", mode = "classification", - type = "prob", value = parsnip:::rand_forest_ranger_data$classprob) +set_fit( + mod = "rand_forest", + eng = "spark", + mode = "classification", + value = list( + interface = "formula", + protect = c("x", "formula", "type"), + func = c(pkg = "sparklyr", fun = "ml_random_forest"), + defaults = list(seed = expr(sample.int(10 ^ 5, 1))) + ) +) -set_pred(mod = "rand_forest", eng = "ranger", mode = "classification", - type = "raw", value = parsnip:::rand_forest_ranger_data$raw) +set_fit( + mod = "rand_forest", + eng = "spark", + mode = "regression", + value = list( + interface = "formula", + protect = c("x", "formula", "type"), + func = c(pkg = "sparklyr", fun = "ml_random_forest"), + defaults = list(seed = expr(sample.int(10 ^ 5, 1))) + ) +) -set_pred(mod = "rand_forest", 
eng = "ranger", mode = "regression", - type = "numeric", value = parsnip:::rand_forest_ranger_data$numeric) +set_pred( + mod = "rand_forest", + eng = "spark", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = format_spark_num, + func = c(pkg = "sparklyr", fun = "ml_predict"), + args = + list(x = quote(object$fit), + dataset = quote(new_data)) + ) +) -set_pred(mod = "rand_forest", eng = "ranger", mode = "regression", - type = "raw", value = parsnip:::rand_forest_ranger_data$raw) +set_pred( + mod = "rand_forest", + eng = "spark", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = format_spark_class, + func = c(pkg = "sparklyr", fun = "ml_predict"), + args = + list(x = quote(object$fit), + dataset = quote(new_data)) + ) +) +set_pred( + mod = "rand_forest", + eng = "spark", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = format_spark_probs, + func = c(pkg = "sparklyr", fun = "ml_predict"), + args = + list(x = quote(object$fit), + dataset = quote(new_data)) + ) +) diff --git a/R/svm_rbf_data.R b/R/svm_rbf_data.R index fdb12727d..4a1203cf0 100644 --- a/R/svm_rbf_data.R +++ b/R/svm_rbf_data.R @@ -1,69 +1,142 @@ -svm_rbf_arg_key <- data.frame( - kernlab = c( "C", "sigma", "epsilon"), - row.names = c("cost", "rbf_sigma", "margin"), - stringsAsFactors = FALSE +set_new_model("svm_rbf") + +set_model_mode("svm_rbf", "classification") +set_model_mode("svm_rbf", "regression") + +# ------------------------------------------------------------------------------ + +set_model_engine("svm_rbf", "classification", "kernlab") +set_model_engine("svm_rbf", "regression", "kernlab") +set_dependency("svm_rbf", "kernlab", "kernlab") + +set_model_arg( + mod = "svm_rbf", + eng = "kernlab", + val = "cost", + original = "C", + func = list(pkg = "dials", fun = "cost"), + submodels = FALSE ) -svm_rbf_modes <- c("classification", "regression", "unknown") +set_model_arg( + mod = "svm_rbf", + eng = "kernlab", + val = "rbf_sigma", + original = "sigma", + func = list(pkg = "dials", fun = "rbf_sigma"), + submodels = FALSE +) -svm_rbf_engines <- data.frame( - kernlab = c(TRUE, TRUE, FALSE), - row.names = c("classification", "regression", "unknown") +set_model_arg( + mod = "svm_rbf", + eng = "kernlab", + val = "margin", + original = "epsilon", + func = list(pkg = "dials", fun = "margin"), + submodels = FALSE ) -# ------------------------------------------------------------------------------ +set_fit( + mod = "svm_rbf", + eng = "kernlab", + mode = "regression", + value = list( + interface = "matrix", + protect = c("x", "y"), + func = c(pkg = "kernlab", fun = "ksvm"), + defaults = list(kernel = "rbfdot") + ) +) + +set_fit( + mod = "svm_rbf", + eng = "kernlab", + mode = "classification", + value = list( + interface = "matrix", + protect = c("x", "y"), + func = c(pkg = "kernlab", fun = "ksvm"), + defaults = list(kernel = "rbfdot") + ) +) + +set_pred( + mod = "svm_rbf", + eng = "kernlab", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = svm_reg_post, + func = c(pkg = "kernlab", fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "response" + ) + ) +) + +set_pred( + mod = "svm_rbf", + eng = "kernlab", + mode = "regression", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(pkg = "kernlab", fun = "predict"), + args = list(object = quote(object$fit), newdata = quote(new_data)) + ) +) + +set_pred( + mod = "svm_rbf", + eng = "kernlab", + 
mode = "classification", + type = "class", + value = list( + pre = NULL, + post = NULL, + func = c(pkg = "kernlab", fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "response" + ) + ) +) -svm_rbf_kernlab_data <- - list( - libs = "kernlab", - fit = list( - interface = "matrix", - protect = c("x", "y"), - func = c(pkg = "kernlab", fun = "ksvm"), - defaults = list( - kernel = "rbfdot" +set_pred( + mod = "svm_rbf", + eng = "kernlab", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = function(result, object) as_tibble(result), + func = c(pkg = "kernlab", fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "probabilities" ) - ), - numeric = list( - pre = NULL, - post = svm_reg_post, - func = c(pkg = "kernlab", fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "response" - ) - ), - class = list( - pre = NULL, - post = NULL, - func = c(pkg = "kernlab", fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "response" - ) - ), - classprob = list( - pre = NULL, - post = function(result, object) as_tibble(result), - func = c(pkg = "kernlab", fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "probabilities" - ) - ), - raw = list( - pre = NULL, - func = c(pkg = "kernlab", fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ) ) +) + +set_pred( + mod = "svm_rbf", + eng = "kernlab", + mode = "classification", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(pkg = "kernlab", fun = "predict"), + args = list(object = quote(object$fit), newdata = quote(new_data)) + ) +) + From 33f63d198cc2ab3b165d790232369d67e6b7391d Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Thu, 23 May 2019 09:44:42 -0400 Subject: [PATCH 10/64] more model conversions --- NAMESPACE | 134 ++++++++++++++++++++++++++ R/linear_reg_data.R | 6 +- R/mars_data.R | 15 --- R/surv_reg_data.R | 224 +++++++++++++++++++++++--------------------- R/svm_poly_data.R | 203 +++++++++++++++++++++++++++------------ 5 files changed, 397 insertions(+), 185 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index d3557b196..7a2834644 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -1,5 +1,139 @@ # Generated by roxygen2: do not edit by hand +S3method(fit,model_spec) +S3method(fit_xy,model_spec) +S3method(multi_predict,"_C5.0") +S3method(multi_predict,"_earth") +S3method(multi_predict,"_elnet") +S3method(multi_predict,"_lognet") +S3method(multi_predict,"_multnet") +S3method(multi_predict,"_xgb.Booster") +S3method(multi_predict,default) +S3method(nullmodel,default) +S3method(predict,"_elnet") +S3method(predict,"_lognet") +S3method(predict,"_multnet") +S3method(predict,model_fit) +S3method(predict,model_spec) +S3method(predict,nullmodel) +S3method(predict_class,"_lognet") +S3method(predict_class,"_multnet") +S3method(predict_classprob,"_lognet") +S3method(predict_classprob,"_multnet") +S3method(predict_numeric,"_elnet") +S3method(predict_raw,"_elnet") +S3method(predict_raw,"_lognet") +S3method(predict_raw,"_multnet") +S3method(print,boost_tree) +S3method(print,decision_tree) +S3method(print,fit_control) +S3method(print,linear_reg) +S3method(print,logistic_reg) +S3method(print,mars) +S3method(print,mlp) +S3method(print,model_fit) +S3method(print,model_spec) +S3method(print,multinom_reg) +S3method(print,nearest_neighbor) 
+S3method(print,nullmodel) +S3method(print,rand_forest) +S3method(print,surv_reg) +S3method(print,svm_poly) +S3method(print,svm_rbf) +S3method(translate,boost_tree) +S3method(translate,decision_tree) +S3method(translate,default) +S3method(translate,mars) +S3method(translate,mlp) +S3method(translate,rand_forest) +S3method(translate,surv_reg) +S3method(translate,svm_poly) +S3method(translate,svm_rbf) +S3method(type_sum,model_fit) +S3method(type_sum,model_spec) +S3method(update,boost_tree) +S3method(update,decision_tree) +S3method(update,linear_reg) +S3method(update,logistic_reg) +S3method(update,mars) +S3method(update,mlp) +S3method(update,multinom_reg) +S3method(update,nearest_neighbor) +S3method(update,rand_forest) +S3method(update,surv_reg) +S3method(update,svm_poly) +S3method(update,svm_rbf) +S3method(varying_args,model_spec) +S3method(varying_args,recipe) +S3method(varying_args,step) +export("%>%") +export(.cols) +export(.dat) +export(.facts) +export(.lvls) +export(.obs) +export(.preds) +export(.x) +export(.y) +export(C5.0_train) +export(add_rowindex) +export(boost_tree) +export(check_arg_val) +export(check_empty_ellipse) +export(check_engine_val) +export(check_fit_info) +export(check_func_val) +export(check_mod_val) +export(check_mode_val) +export(check_pkg_val) +export(check_pred_info) +export(check_submodels_val) +export(decision_tree) +export(fit) +export(fit.model_spec) +export(fit_control) +export(fit_xy) +export(fit_xy.model_spec) +export(get_dependency) +export(get_fit) +export(get_model_env) +export(get_pred_type) +export(keras_mlp) +export(linear_reg) +export(logistic_reg) +export(make_classes) +export(mars) +export(mlp) +export(model_printer) +export(multi_predict) +export(multinom_reg) +export(nearest_neighbor) +export(null_model) +export(nullmodel) +export(predict.model_fit) +export(rand_forest) +export(rpart_train) +export(set_args) +export(set_dependency) +export(set_engine) +export(set_fit) +export(set_mode) +export(set_model_arg) +export(set_model_engine) +export(set_model_mode) +export(set_new_model) +export(set_pred) +export(show_call) +export(show_model_info) +export(surv_reg) +export(svm_poly) +export(svm_rbf) +export(tidy.model_fit) +export(translate) +export(validate_model) +export(varying) +export(varying_args) +export(xgb_train) importFrom(dplyr,arrange) importFrom(dplyr,as_tibble) importFrom(dplyr,bind_cols) diff --git a/R/linear_reg_data.R b/R/linear_reg_data.R index 27cfb71f8..5e7e22180 100644 --- a/R/linear_reg_data.R +++ b/R/linear_reg_data.R @@ -294,7 +294,8 @@ set_fit( value = list( interface = "formula", protect = c("x", "formula", "weight_col"), - func = c(pkg = "sparklyr", fun = "ml_linear_regression") + func = c(pkg = "sparklyr", fun = "ml_linear_regression"), + defaults = list() ) ) @@ -319,7 +320,8 @@ set_pred( set_model_engine("linear_reg", "regression", "keras") -set_dependency("linear_reg", "keras", c("keras", "magrittr")) +set_dependency("linear_reg", "keras", "keras") +set_dependency("linear_reg", "keras", "magrittr") set_fit( mod = "linear_reg", diff --git a/R/mars_data.R b/R/mars_data.R index 99b5ffc22..f3b4ae589 100644 --- a/R/mars_data.R +++ b/R/mars_data.R @@ -1,19 +1,4 @@ -mars_arg_key <- data.frame( - earth = c( "nprune", "degree", "pmethod"), - stringsAsFactors = FALSE, - row.names = c("num_terms", "prod_degree", "prune_method") -) - -mars_modes <- c("classification", "regression", "unknown") - -mars_engines <- data.frame( - earth = rep(TRUE, 3), - row.names = c("classification", "regression", "unknown") -) - -# 
------------------------------------------------------------------------------ - set_new_model("mars") set_model_mode("mars", "classification") diff --git a/R/surv_reg_data.R b/R/surv_reg_data.R index 046b0a469..4d05aa298 100644 --- a/R/surv_reg_data.R +++ b/R/surv_reg_data.R @@ -1,120 +1,130 @@ -surv_reg_arg_key <- data.frame( - survreg = c("dist"), - flexsurv = c("dist"), - stringsAsFactors = FALSE, - row.names = c("dist") -) +set_new_model("surv_reg") +set_model_mode("surv_reg", "regression") -surv_reg_modes <- "regression" +# ------------------------------------------------------------------------------ -surv_reg_engines <- data.frame( - survreg = TRUE, - flexsurv = TRUE, - stringsAsFactors = TRUE, - row.names = c("regression") -) +set_model_engine("surv_reg", "regression", "flexsurv") +set_dependency("surv_reg", "flexsurv", "flexsurv") +set_dependency("surv_reg", "flexsurv", "survival") -# ------------------------------------------------------------------------------ +set_model_arg( + mod = "surv_reg", + eng = "flexsurv", + val = "dist", + original = "dist", + func = list(pkg = "dials", fun = "dist"), + submodels = FALSE +) -surv_reg_flexsurv_data <- - list( - libs = c("survival", "flexsurv"), - fit = list( - interface = "formula", - protect = c("formula", "data", "weights"), - func = c(pkg = "flexsurv", fun = "flexsurvreg"), - defaults = list() - ), - numeric = list( - pre = NULL, - post = flexsurv_mean, - func = c(fun = "summary"), - args = - list( - object = expr(object$fit), - newdata = expr(new_data), - type = "mean" - ) - ), - quantile = list( - pre = NULL, - post = flexsurv_quant, - func = c(fun = "summary"), - args = - list( - object = expr(object$fit), - newdata = expr(new_data), - type = "quantile", - quantiles = expr(quantile) - ) - ) +set_fit( + mod = "surv_reg", + eng = "flexsurv", + mode = "regression", + value = list( + interface = "formula", + protect = c("formula", "data", "weights"), + func = c(pkg = "flexsurv", fun = "flexsurvreg"), + defaults = list() ) +) -# ------------------------------------------------------------------------------ +set_pred( + mod = "surv_reg", + eng = "flexsurv", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = flexsurv_mean, + func = c(fun = "summary"), + args = + list( + object = expr(object$fit), + newdata = expr(new_data), + type = "mean" + ) + ) +) -surv_reg_survreg_data <- - list( - libs = c("survival"), - fit = list( - interface = "formula", - protect = c("formula", "data", "weights"), - func = c(pkg = "survival", fun = "survreg"), - defaults = list(model = TRUE) - ), - numeric = list( - pre = NULL, - post = NULL, - func = c(fun = "predict"), - args = - list( - object = expr(object$fit), - newdata = expr(new_data), - type = "response" - ) - ), - quantile = list( - pre = NULL, - post = survreg_quant, - func = c(fun = "predict"), - args = - list( - object = expr(object$fit), - newdata = expr(new_data), - type = "quantile", - p = expr(quantile) - ) - ) +set_pred( + mod = "surv_reg", + eng = "flexsurv", + mode = "regression", + type = "quantile", + value = list( + pre = NULL, + post = flexsurv_quant, + func = c(fun = "summary"), + args = + list( + object = expr(object$fit), + newdata = expr(new_data), + type = "quantile", + quantiles = expr(quantile) + ) ) +) # ------------------------------------------------------------------------------ -# surv_reg_stan_data <- -# list( -# libs = c("brms"), -# fit = list( -# interface = "formula", -# protect = c("formula", "data", "weights"), -# func = c(pkg = "brms", 
fun = "brm"), -# defaults = list( -# family = expr(brms::weibull()), -# seed = expr(sample.int(10^5, 1)) -# ) -# ), -# numeric = list( -# pre = NULL, -# post = function(results, object) { -# tibble::as_tibble(results) %>% -# dplyr::select(Estimate) %>% -# setNames(".pred") -# }, -# func = c(fun = "predict"), -# args = -# list( -# object = expr(object$fit), -# newdata = expr(new_data), -# type = "response" -# ) -# ) -# ) +set_model_engine("surv_reg", "regression", "survival") +set_dependency("surv_reg", "survival", "survival") +set_model_arg( + mod = "surv_reg", + eng = "survival", + val = "dist", + original = "dist", + func = list(pkg = "dials", fun = "dist"), + submodels = FALSE +) + +set_fit( + mod = "surv_reg", + eng = "survival", + mode = "regression", + value = list( + interface = "formula", + protect = c("formula", "data", "weights"), + func = c(pkg = "survival", fun = "survreg"), + defaults = list(model = TRUE) + ) +) + +set_pred( + mod = "surv_reg", + eng = "survival", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = expr(object$fit), + newdata = expr(new_data), + type = "response" + ) + ) +) + +set_pred( + mod = "surv_reg", + eng = "survival", + mode = "regression", + type = "quantile", + value = list( + pre = NULL, + post = survreg_quant, + func = c(fun = "predict"), + args = + list( + object = expr(object$fit), + newdata = expr(new_data), + type = "quantile", + p = expr(quantile) + ) + ) +) diff --git a/R/svm_poly_data.R b/R/svm_poly_data.R index 04b0bc55e..89e551c78 100644 --- a/R/svm_poly_data.R +++ b/R/svm_poly_data.R @@ -1,69 +1,150 @@ -svm_poly_arg_key <- data.frame( - kernlab = c( "C", "degree", "scale", "epsilon"), - row.names = c("cost", "degree", "scale_factor", "margin"), - stringsAsFactors = FALSE +set_new_model("svm_poly") + +set_model_mode("svm_poly", "classification") +set_model_mode("svm_poly", "regression") + +# ------------------------------------------------------------------------------ + +set_model_engine("svm_poly", "classification", "kernlab") +set_model_engine("svm_poly", "regression", "kernlab") +set_dependency("svm_poly", "kernlab", "kernlab") + +set_model_arg( + mod = "svm_poly", + eng = "kernlab", + val = "cost", + original = "C", + func = list(pkg = "dials", fun = "cost"), + submodels = FALSE ) -svm_poly_modes <- c("classification", "regression", "unknown") +set_model_arg( + mod = "svm_poly", + eng = "kernlab", + val = "degree", + original = "degree", + func = list(pkg = "dials", fun = "degree"), + submodels = FALSE +) -svm_poly_engines <- data.frame( - kernlab = c(TRUE, TRUE, FALSE), - row.names = c("classification", "regression", "unknown") +set_model_arg( + mod = "svm_poly", + eng = "kernlab", + val = "scale_factor", + original = "scale", + func = list(pkg = "dials", fun = "scale_factor"), + submodels = FALSE +) +set_model_arg( + mod = "svm_poly", + eng = "kernlab", + val = "margin", + original = "epsilon", + func = list(pkg = "dials", fun = "margin"), + submodels = FALSE ) -# ------------------------------------------------------------------------------ +set_fit( + mod = "svm_poly", + eng = "kernlab", + mode = "regression", + value = list( + interface = "matrix", + protect = c("x", "y"), + func = c(pkg = "kernlab", fun = "ksvm"), + defaults = list(kernel = "polydot") + ) +) -svm_poly_kernlab_data <- - list( - libs = "kernlab", - fit = list( - interface = "matrix", - protect = c("x", "y"), - func = c(pkg = "kernlab", fun = "ksvm"), - defaults = list( - 
kernel = "polydot" +set_fit( + mod = "svm_poly", + eng = "kernlab", + mode = "classification", + value = list( + interface = "matrix", + protect = c("x", "y"), + func = c(pkg = "kernlab", fun = "ksvm"), + defaults = list(kernel = "polydot") + ) +) + +set_pred( + mod = "svm_poly", + eng = "kernlab", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = svm_reg_post, + func = c(pkg = "kernlab", fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "response" ) - ), - numeric = list( - pre = NULL, - post = svm_reg_post, - func = c(pkg = "kernlab", fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "response" - ) - ), - class = list( - pre = NULL, - post = NULL, - func = c(pkg = "kernlab", fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "response" - ) - ), - classprob = list( - pre = NULL, - post = function(result, object) as_tibble(result), - func = c(pkg = "kernlab", fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "probabilities" - ) - ), - raw = list( - pre = NULL, - func = c(pkg = "kernlab", fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ) ) +) + +set_pred( + mod = "svm_poly", + eng = "kernlab", + mode = "regression", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(pkg = "kernlab", fun = "predict"), + args = list(object = quote(object$fit), newdata = quote(new_data)) + ) +) + +set_pred( + mod = "svm_poly", + eng = "kernlab", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = NULL, + func = c(pkg = "kernlab", fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "response" + ) + ) +) + +set_pred( + mod = "svm_poly", + eng = "kernlab", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = function(result, object) as_tibble(result), + func = c(pkg = "kernlab", fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "probabilities" + ) + ) +) + +set_pred( + mod = "svm_poly", + eng = "kernlab", + mode = "classification", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(pkg = "kernlab", fun = "predict"), + args = list(object = quote(object$fit), newdata = quote(new_data)) + ) +) + From 613a91e37331eab927350fe1d45550809c255471 Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Thu, 23 May 2019 15:52:39 -0400 Subject: [PATCH 11/64] model conversions and keyword internal --- NAMESPACE | 1 - R/boost_tree.R | 2 + R/decision_tree.R | 1 + R/logistic_reg_data.R | 830 +++++++++++++++++++++++--------------- R/misc.R | 2 +- R/mlp.R | 186 +++++++++ R/mlp_data.R | 558 +++++++++++++------------ R/multinom_reg_data.R | 339 ++++++++++------ R/nearest_neighbor_data.R | 255 ++++++++---- R/nullmodel_data.R | 186 ++++++--- man/C5.0_train.Rd | 1 + man/keras_mlp.Rd | 3 +- man/rpart_train.Rd | 1 + man/xgb_train.Rd | 1 + 14 files changed, 1475 insertions(+), 891 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index 7a2834644..425c9ce72 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -202,7 +202,6 @@ importFrom(stats,na.omit) importFrom(stats,na.pass) importFrom(stats,predict) importFrom(stats,qnorm) -importFrom(stats,qt) importFrom(stats,quantile) importFrom(stats,setNames) importFrom(stats,terms) diff --git a/R/boost_tree.R b/R/boost_tree.R 
index 620648fa0..57f6d5ed3 100644 --- a/R/boost_tree.R +++ b/R/boost_tree.R @@ -261,6 +261,7 @@ check_args.boost_tree <- function(object) { #' @param subsample Subsampling proportion of rows. #' @param ... Other options to pass to `xgb.train`. #' @return A fitted `xgboost` object. +#' @keywords internal #' @export xgb_train <- function( x, y, @@ -432,6 +433,7 @@ xgb_by_tree <- function(tree, object, new_data, type, ...) { #' model in the printed output. #' @param ... Other arguments to pass. #' @return A fitted C5.0 model. +#' @keywords internal #' @export C5.0_train <- function(x, y, weights = NULL, trials = 15, minCases = 2, sample = 0, ...) { diff --git a/R/decision_tree.R b/R/decision_tree.R index b767b965d..afa807405 100644 --- a/R/decision_tree.R +++ b/R/decision_tree.R @@ -252,6 +252,7 @@ check_args.decision_tree <- function(object) { #' those cases. #' @param ... Other arguments to pass to either `rpart` or `rpart.control`. #' @return A fitted rpart model. +#' @keywords internal #' @export rpart_train <- function(formula, data, weights = NULL, cp = 0.01, minsplit = 20, maxdepth = 30, ...) { diff --git a/R/logistic_reg_data.R b/R/logistic_reg_data.R index 6fa1e0c94..972230e16 100644 --- a/R/logistic_reg_data.R +++ b/R/logistic_reg_data.R @@ -1,354 +1,516 @@ +set_new_model("logistic_reg") -logistic_reg_arg_key <- data.frame( - glm = c( NA, NA), - glmnet = c( "lambda", "alpha"), - spark = c("reg_param", "elastic_net_param"), - stan = c( NA, NA), - keras = c( "decay", NA), - stringsAsFactors = FALSE, - row.names = c("penalty", "mixture") +set_model_mode("logistic_reg", "classification") + +# ------------------------------------------------------------------------------ + +set_model_engine("logistic_reg", "classification", "glm") +set_dependency("logistic_reg", "glm", "stats") + +set_fit( + mod = "logistic_reg", + eng = "glm", + mode = "classification", + value = list( + interface = "formula", + protect = c("formula", "data", "weights"), + func = c(pkg = "stats", fun = "glm"), + defaults = list(family = expr(stats::binomial)) + ) ) -logistic_reg_modes <- "classification" +set_pred( + mod = "logistic_reg", + eng = "glm", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = prob_to_class_2, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "response" + ) + ) +) -logistic_reg_engines <- data.frame( - glm = TRUE, - glmnet = TRUE, - spark = TRUE, - stan = TRUE, - keras = TRUE, - row.names = c("classification") +set_pred( + mod = "logistic_reg", + eng = "glm", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = function(x, object) { + x <- tibble(v1 = 1 - x, v2 = x) + colnames(x) <- object$lvl + x + }, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "response" + ) + ) ) -# ------------------------------------------------------------------------------ +set_pred( + mod = "logistic_reg", + eng = "glm", + mode = "classification", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data) + ) + ) +) -#' @importFrom stats qt -logistic_reg_glm_data <- - list( - libs = "stats", - fit = list( - interface = "formula", - protect = c("formula", "data", "weights"), - func = c(pkg = "stats", fun = "glm"), - defaults = - list( - family = expr(stats::binomial) - ) - ), - class = list( - pre = NULL, - post = 
prob_to_class_2, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "response" +set_pred( + mod = "logistic_reg", + eng = "glm", + mode = "classification", + type = "conf_int", + value = list( + pre = NULL, + post = function(results, object) { + hf_lvl <- (1 - object$spec$method$confint$extras$level)/2 + const <- + qt(hf_lvl, df = object$fit$df.residual, lower.tail = FALSE) + trans <- object$fit$family$linkinv + res_2 <- + tibble( + lo = trans(results$fit - const * results$se.fit), + hi = trans(results$fit + const * results$se.fit) ) - ), - classprob = list( - pre = NULL, - post = function(x, object) { - x <- tibble(v1 = 1 - x, v2 = x) - colnames(x) <- object$lvl - x - }, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "response" - ) - ), - raw = list( - pre = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ), - confint = list( - pre = NULL, - post = function(results, object) { - hf_lvl <- (1 - object$spec$method$confint$extras$level)/2 - const <- - qt(hf_lvl, df = object$fit$df.residual, lower.tail = FALSE) - trans <- object$fit$family$linkinv - res_2 <- - tibble( - lo = trans(results$fit - const * results$se.fit), - hi = trans(results$fit + const * results$se.fit) - ) - res_1 <- res_2 - res_1$lo <- 1 - res_2$hi - res_1$hi <- 1 - res_2$lo - res <- bind_cols(res_1, res_2) - lo_nms <- paste0(".pred_lower_", object$lvl) - hi_nms <- paste0(".pred_upper_", object$lvl) - colnames(res) <- c(lo_nms[1], hi_nms[1], lo_nms[2], hi_nms[2]) - - if (object$spec$method$confint$extras$std_error) - res$.std_error <- results$se.fit - res - }, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - se.fit = TRUE, - type = "link" - ) - ) + res_1 <- res_2 + res_1$lo <- 1 - res_2$hi + res_1$hi <- 1 - res_2$lo + res <- bind_cols(res_1, res_2) + lo_nms <- paste0(".pred_lower_", object$lvl) + hi_nms <- paste0(".pred_upper_", object$lvl) + colnames(res) <- c(lo_nms[1], hi_nms[1], lo_nms[2], hi_nms[2]) + + if (object$spec$method$confint$extras$std_error) + res$.std_error <- results$se.fit + res + }, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + se.fit = TRUE, + type = "link" + ) ) +) -# Note: For glmnet, you will need to make model-specific predict methods. 
-# See logistic_reg.R -logistic_reg_glmnet_data <- - list( - libs = "glmnet", - fit = list( - interface = "matrix", - protect = c("x", "y", "weights"), - func = c(pkg = "glmnet", fun = "glmnet"), - defaults = - list( - family = "binomial" - ) - ), - class = list( - pre = NULL, - post = organize_glmnet_class, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newx = quote(as.matrix(new_data)), - type = "response", - s = quote(object$spec$args$penalty) - ) - ), - classprob = list( - pre = NULL, - post = organize_glmnet_prob, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newx = quote(as.matrix(new_data)), - type = "response", - s = quote(object$spec$args$penalty) - ) - ), - raw = list( - pre = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newx = quote(as.matrix(new_data)) - ) - ) +# ------------------------------------------------------------------------------ + +set_model_engine("logistic_reg", "classification", "glmnet") +set_dependency("logistic_reg", "glmnet", "glmnet") + +set_model_arg( + mod = "logistic_reg", + eng = "glmnet", + val = "penalty", + original = "lambda", + func = list(pkg = "dials", fun = "penalty"), + submodels = TRUE +) + +set_model_arg( + mod = "logistic_reg", + eng = "glmnet", + val = "mixture", + original = "alpha", + func = list(pkg = "dials", fun = "mixture"), + submodels = FALSE +) + +set_fit( + mod = "logistic_reg", + eng = "glmnet", + mode = "classification", + value = list( + interface = "matrix", + protect = c("x", "y", "weights"), + func = c(pkg = "glmnet", fun = "glmnet"), + defaults = list(family = "binomial") ) +) -logistic_reg_stan_data <- - list( - libs = "rstanarm", - fit = list( - interface = "formula", - protect = c("formula", "data", "weights"), - func = c(pkg = "rstanarm", fun = "stan_glm"), - defaults = - list( - family = expr(stats::binomial) - ) - ), - class = list( - pre = NULL, - post = function(x, object) { - x <- object$fit$family$linkinv(x) - x <- ifelse(x >= 0.5, object$lvl[2], object$lvl[1]) - unname(x) - }, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ), - classprob = list( - pre = NULL, - post = function(x, object) { - x <- object$fit$family$linkinv(x) - x <- tibble(v1 = 1 - x, v2 = x) - colnames(x) <- object$lvl - x - }, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ), - raw = list( - pre = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ), - confint = list( - pre = NULL, - post = function(results, object) { - res_2 <- - tibble( - lo = - convert_stan_interval( - results, - level = object$spec$method$confint$extras$level - ), - hi = - convert_stan_interval( - results, - level = object$spec$method$confint$extras$level, - lower = FALSE - ), - ) - res_1 <- res_2 - res_1$lo <- 1 - res_2$hi - res_1$hi <- 1 - res_2$lo - res <- bind_cols(res_1, res_2) - lo_nms <- paste0(".pred_lower_", object$lvl) - hi_nms <- paste0(".pred_upper_", object$lvl) - colnames(res) <- c(lo_nms[1], hi_nms[1], lo_nms[2], hi_nms[2]) - - if (object$spec$method$confint$extras$std_error) - res$.std_error <- apply(results, 2, sd, na.rm = TRUE) - res - }, - func = c(pkg = "rstanarm", fun = "posterior_linpred"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - transform = TRUE, - seed = expr(sample.int(10^5, 1)) - ) - ), - predint = list( - pre = NULL, - post 
= function(results, object) {
-      res_2 <-
-        tibble(
-          lo =
-            convert_stan_interval(
-              results,
-              level = object$spec$method$predint$extras$level
-            ),
-          hi =
-            convert_stan_interval(
-              results,
-              level = object$spec$method$predint$extras$level,
-              lower = FALSE
-            ),
-        )
-      res_1 <- res_2
-      res_1$lo <- 1 - res_2$hi
-      res_1$hi <- 1 - res_2$lo
-      res <- bind_cols(res_1, res_2)
-      lo_nms <- paste0(".pred_lower_", object$lvl)
-      hi_nms <- paste0(".pred_upper_", object$lvl)
-      colnames(res) <- c(lo_nms[1], hi_nms[1], lo_nms[2], hi_nms[2])
-
-      if (object$spec$method$predint$extras$std_error)
-        res$.std_error <- apply(results, 2, sd, na.rm = TRUE)
-      res
-    },
-    func = c(pkg = "rstanarm", fun = "posterior_predict"),
-    args =
-      list(
-        object = quote(object$fit),
-        newdata = quote(new_data),
-        seed = expr(sample.int(10^5, 1))
-      )
-    )
-  )
+
+set_pred(
+  mod = "logistic_reg",
+  eng = "glmnet",
+  mode = "classification",
+  type = "class",
+  value = list(
+    pre = NULL,
+    post = organize_glmnet_class,
+    func = c(fun = "predict"),
+    args =
+      list(
+        object = quote(object$fit),
+        newx = quote(as.matrix(new_data)),
+        type = "response",
+        s = quote(object$spec$args$penalty)
+      )
+  )
+)
+set_pred(
+  mod = "logistic_reg",
+  eng = "glmnet",
+  mode = "classification",
+  type = "prob",
+  value = list(
+    pre = NULL,
+    post = organize_glmnet_prob,
+    func = c(fun = "predict"),
+    args =
+      list(
+        object = quote(object$fit),
+        newx = quote(as.matrix(new_data)),
+        type = "response",
+        s = quote(object$spec$args$penalty)
+      )
+  )
+)
-logistic_reg_spark_data <-
-  list(
-    libs = "sparklyr",
-    fit = list(
-      interface = "formula",
-      protect = c("x", "formula", "weight_col"),
-      func = c(pkg = "sparklyr", fun = "ml_logistic_regression"),
-      defaults =
-        list(
-          family = "binomial"
-        )
-    ),
-    class = list(
-      pre = NULL,
-      post = format_spark_class,
-      func = c(pkg = "sparklyr", fun = "ml_predict"),
-      args =
-        list(
-          x = quote(object$fit),
-          dataset = quote(new_data)
-        )
-    ),
-    classprob = list(
-      pre = NULL,
-      post = format_spark_probs,
-      func = c(pkg = "sparklyr", fun = "ml_predict"),
-      args =
-        list(
-          x = quote(object$fit),
-          dataset = quote(new_data)
-        )
-    )
  )
+set_pred(
+  mod = "logistic_reg",
+  eng = "glmnet",
+  mode = "classification",
+  type = "raw",
+  value = list(
+    pre = NULL,
+    post = NULL,
+    func = c(fun = "predict"),
+    args =
+      list(
+        object = quote(object$fit),
+        newx = quote(as.matrix(new_data))
+      )
+  )
+)
+
+# ------------------------------------------------------------------------------
+
+set_model_engine("logistic_reg", "classification", "spark")
+set_dependency("logistic_reg", "spark", "sparklyr")
+
+set_model_arg(
+  mod = "logistic_reg",
+  eng = "spark",
+  val = "penalty",
+  original = "reg_param",
+  func = list(pkg = "dials", fun = "penalty"),
+  submodels = TRUE
+)
-logistic_reg_keras_data <-
-  list(
-    libs = c("keras", "magrittr"),
-    fit = list(
-      interface = "matrix",
-      protect = c("x", "y"),
-      func = c(pkg = "parsnip", fun = "keras_mlp"),
-      defaults = list(hidden_units = 1, act = "linear")
-    ),
-    class = list(
-      pre = NULL,
-      post = function(x, object) {
-        object$lvl[x + 1]
-      },
-      func = c(pkg = "keras", fun = "predict_classes"),
-      args =
-        list(
-          object = quote(object$fit),
-          x = quote(as.matrix(new_data))
+set_model_arg(
+  mod = "logistic_reg",
+  eng = "spark",
+  val = "mixture",
+  original = "elastic_net_param",
+  func = list(pkg = "dials", fun = "mixture"),
+  submodels = FALSE
+)
+
+set_fit(
+  mod = "logistic_reg",
+  eng = "spark",
+  mode = "classification",
+  value = list(
+    interface = "formula",
+    protect = c("x",
"formula", "weight_col"), + func = c(pkg = "sparklyr", fun = "ml_logistic_regression"), + defaults = + list( + family = "binomial" + ) + ) +) + +set_pred( + mod = "logistic_reg", + eng = "spark", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = format_spark_class, + func = c(pkg = "sparklyr", fun = "ml_predict"), + args = + list( + x = quote(object$fit), + dataset = quote(new_data) + ) + ) +) + +set_pred( + mod = "logistic_reg", + eng = "spark", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = format_spark_probs, + func = c(pkg = "sparklyr", fun = "ml_predict"), + args = + list( + x = quote(object$fit), + dataset = quote(new_data) + ) + ) +) + +# ------------------------------------------------------------------------------ + +set_model_engine("logistic_reg", "classification", "keras") +set_model_engine("logistic_reg", "regression", "keras") +set_dependency("logistic_reg", "keras", "keras") +set_dependency("logistic_reg", "keras", "magrittr") + +set_model_arg( + mod = "logistic_reg", + eng = "keras", + val = "decay", + original = "decay", + func = list(pkg = "dials", fun = "weight_decay"), + submodels = FALSE +) + +set_fit( + mod = "logistic_reg", + eng = "keras", + mode = "classification", + value = list( + interface = "matrix", + protect = c("x", "y"), + func = c(pkg = "parsnip", fun = "keras_mlp"), + defaults = list(hidden_units = 1, act = "linear") + ) +) + +set_pred( + mod = "logistic_reg", + eng = "keras", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = function(x, object) { + object$lvl[x + 1] + }, + func = c(pkg = "keras", fun = "predict_classes"), + args = + list( + object = quote(object$fit), + x = quote(as.matrix(new_data)) + ) + ) +) + +set_pred( + mod = "logistic_reg", + eng = "keras", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = function(x, object) { + x <- as_tibble(x) + colnames(x) <- object$lvl + x + }, + func = c(pkg = "keras", fun = "predict_proba"), + args = + list( + object = quote(object$fit), + x = quote(as.matrix(new_data)) + ) + ) +) + +# ------------------------------------------------------------------------------ + +set_model_engine("logistic_reg", "classification", "stan") +set_model_engine("logistic_reg", "regression", "stan") +set_dependency("logistic_reg", "stan", "rstanarm") + +set_fit( + mod = "logistic_reg", + eng = "stan", + mode = "classification", + value = list( + interface = "formula", + protect = c("formula", "data", "weights"), + func = c(pkg = "rstanarm", fun = "stan_glm"), + defaults = list(family = expr(stats::binomial)) + ) +) + +set_pred( + mod = "logistic_reg", + eng = "stan", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = function(x, object) { + x <- object$fit$family$linkinv(x) + x <- ifelse(x >= 0.5, object$lvl[2], object$lvl[1]) + unname(x) + }, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data) + ) + ) +) + +set_pred( + mod = "logistic_reg", + eng = "stan", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = function(x, object) { + x <- object$fit$family$linkinv(x) + x <- tibble(v1 = 1 - x, v2 = x) + colnames(x) <- object$lvl + x + }, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data) + ) + ) +) + + +set_pred( + mod = "logistic_reg", + eng = "stan", + mode = "classification", + type = "raw", + value = list( + pre = NULL, + 
post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data) + ) + ) +) + +set_pred( + mod = "logistic_reg", + eng = "stan", + mode = "classification", + type = "conf_int", + value = list( + pre = NULL, + post = function(results, object) { + res_2 <- + tibble( + lo = + convert_stan_interval( + results, + level = object$spec$method$confint$extras$level + ), + hi = + convert_stan_interval( + results, + level = object$spec$method$confint$extras$level, + lower = FALSE + ), ) - ), - classprob = list( - pre = NULL, - post = function(x, object) { - x <- as_tibble(x) - colnames(x) <- object$lvl - x - }, - func = c(pkg = "keras", fun = "predict_proba"), - args = - list( - object = quote(object$fit), - x = quote(as.matrix(new_data)) + res_1 <- res_2 + res_1$lo <- 1 - res_2$hi + res_1$hi <- 1 - res_2$lo + res <- bind_cols(res_1, res_2) + lo_nms <- paste0(".pred_lower_", object$lvl) + hi_nms <- paste0(".pred_upper_", object$lvl) + colnames(res) <- c(lo_nms[1], hi_nms[1], lo_nms[2], hi_nms[2]) + + if (object$spec$method$confint$extras$std_error) + res$.std_error <- apply(results, 2, sd, na.rm = TRUE) + res + }, + func = c(pkg = "rstanarm", fun = "posterior_linpred"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + transform = TRUE, + seed = expr(sample.int(10^5, 1)) + ) + ) +) + +set_pred( + mod = "logistic_reg", + eng = "stan", + mode = "classification", + type = "pred_int", + value = list( + pre = NULL, + post = function(results, object) { + res_2 <- + tibble( + lo = + convert_stan_interval( + results, + level = object$spec$method$predint$extras$level + ), + hi = + convert_stan_interval( + results, + level = object$spec$method$predint$extras$level, + lower = FALSE + ), ) - ) + res_1 <- res_2 + res_1$lo <- 1 - res_2$hi + res_1$hi <- 1 - res_2$lo + res <- bind_cols(res_1, res_2) + lo_nms <- paste0(".pred_lower_", object$lvl) + hi_nms <- paste0(".pred_upper_", object$lvl) + colnames(res) <- c(lo_nms[1], hi_nms[1], lo_nms[2], hi_nms[2]) + + if (object$spec$method$predint$extras$std_error) + res$.std_error <- apply(results, 2, sd, na.rm = TRUE) + res + }, + func = c(pkg = "rstanarm", fun = "posterior_predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + seed = expr(sample.int(10^5, 1)) + ) ) +) diff --git a/R/misc.R b/R/misc.R index 3885e7817..e04842c02 100644 --- a/R/misc.R +++ b/R/misc.R @@ -201,7 +201,7 @@ update_dot_check <- function(...) 
{
 
 # ------------------------------------------------------------------------------
 
 new_model_spec <- function(cls, args, eng_args, mode, method, engine) {
-  spec_modes <- get(paste0(cls, "_modes"))
+  spec_modes <- rlang::env_get(get_model_env(), paste0(cls, "_modes"))
   if (!(mode %in% spec_modes))
     stop("`mode` should be one of: ",
          paste0("'", spec_modes, "'", collapse = ", "),
diff --git a/R/mlp.R b/R/mlp.R
index c8c94054e..3fe96631f 100644
--- a/R/mlp.R
+++ b/R/mlp.R
@@ -229,3 +229,189 @@ check_args.mlp <- function(object) {
 
   invisible(object)
 }
+
+# keras wrapper for feed-forward nnet
+
+class2ind <- function (x, drop2nd = FALSE) {
+  if (!is.factor(x))
+    stop("`x` should be a factor")
+  y <- model.matrix( ~ x - 1)
+  colnames(y) <- gsub("^x", "", colnames(y))
+  attributes(y)$assign <- NULL
+  attributes(y)$contrasts <- NULL
+  if (length(levels(x)) == 2 & drop2nd) {
+    y <- y[, 1]
+  }
+  y
+}
+
+
+# ------------------------------------------------------------------------------
+
+#' Simple interface to MLP models via keras
+#'
+#' Instead of building a `keras` model sequentially, `keras_mlp` can be used to
+#' create a feedforward network with a single hidden layer. Regularization is
+#' via either weight decay or dropout.
+#'
+#' @param x A data frame or matrix of predictors
+#' @param y A vector (factor or numeric) or matrix (numeric) of outcome data.
+#' @param hidden_units An integer for the number of hidden units.
+#' @param decay A non-negative real number for the amount of weight decay. Either
+#'  this parameter _or_ `dropout` can be specified.
+#' @param dropout The proportion of parameters to set to zero. Either
+#'  this parameter _or_ `decay` can be specified.
+#' @param epochs An integer for the number of passes through the data.
+#' @param act A character string for the type of activation function between layers.
+#' @param seeds A vector of three positive integers to control randomness of the
+#'  calculations.
+#' @param ... Currently ignored.
+#' @return A `keras` model object.
+#' @keywords internal
+#' @export
+keras_mlp <-
+  function(x, y,
+           hidden_units = 5, decay = 0, dropout = 0, epochs = 20, act = "softmax",
+           seeds = sample.int(10^5, size = 3),
+           ...) {
+
+    if(decay > 0 & dropout > 0)
+      stop("Please use either dropout or weight decay.", call.
= FALSE) + + if (!is.matrix(x)) + x <- as.matrix(x) + + if(is.character(y)) + y <- as.factor(y) + factor_y <- is.factor(y) + + if (factor_y) + y <- class2ind(y) + else { + if (isTRUE(ncol(y) > 1)) + y <- as.matrix(y) + else + y <- matrix(y, ncol = 1) + } + + model <- keras::keras_model_sequential() + if(decay > 0) { + model %>% + keras::layer_dense( + units = hidden_units, + activation = act, + input_shape = ncol(x), + kernel_regularizer = keras::regularizer_l2(decay), + kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[1]) + ) + } else { + model %>% + keras::layer_dense( + units = hidden_units, + activation = act, + input_shape = ncol(x), + kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[1]) + ) + } + if(dropout > 0) + model %>% + keras::layer_dense( + units = hidden_units, + activation = act, + input_shape = ncol(x), + kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[1]) + ) %>% + keras::layer_dropout(rate = dropout, seed = seeds[2]) + + if (factor_y) + model <- model %>% + keras::layer_dense( + units = ncol(y), + activation = 'softmax', + kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[3]) + ) + else + model <- model %>% + keras::layer_dense( + units = ncol(y), + activation = 'linear', + kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[3]) + ) + + arg_values <- parse_keras_args(...) + compile_call <- expr( + keras::compile(object = model) + ) + if(!any(names(arg_values$compile) == "loss")) + compile_call$loss <- + if(factor_y) "binary_crossentropy" else "mse" + if(!any(names(arg_values$compile) == "optimizer")) + compile_call$optimizer <- "adam" + for(arg in names(arg_values$compile)) + compile_call[[arg]] <- arg_values$compile[[arg]] + + model <- eval_tidy(compile_call) + + fit_call <- expr( + keras::fit(object = model) + ) + fit_call$x <- quote(x) + fit_call$y <- quote(y) + fit_call$epochs <- epochs + for(arg in names(arg_values$fit)) + fit_call[[arg]] <- arg_values$fit[[arg]] + + history <- eval_tidy(fit_call) + model + } + + +nnet_softmax <- function(results, object) { + if (ncol(results) == 1) + results <- cbind(1 - results, results) + + results <- apply(results, 1, function(x) exp(x)/sum(exp(x))) + results <- as_tibble(t(results)) + names(results) <- paste0(".pred_", object$lvl) + results +} + +parse_keras_args <- function(...) { + exclusions <- c("object", "x", "y", "validation_data", "epochs") + fit_args <- c( + 'batch_size', + 'verbose', + 'callbacks', + 'view_metrics', + 'validation_split', + 'validation_data', + 'shuffle', + 'class_weight', + 'sample_weight', + 'initial_epoch', + 'steps_per_epoch', + 'validation_steps' + ) + compile_args <- c( + 'optimizer', + 'loss', + 'metrics', + 'loss_weights', + 'sample_weight_mode', + 'weighted_metrics', + 'target_tensors' + ) + dots <- list(...) 
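+  # drop the arguments that keras_mlp supplies itself, then split the
+  # remainder into fit() and compile() options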
+ dots <- dots[!(names(dots) %in% exclusions)] + + list( + fit = dots[names(dots) %in% fit_args], + compile = dots[names(dots) %in% compile_args] + ) +} + +mlp_num_weights <- function(p, hidden_units, classes) { + ((p+1) * hidden_units) + ((hidden_units+1) * classes) +} + + diff --git a/R/mlp_data.R b/R/mlp_data.R index 724035e7e..89907fe7c 100644 --- a/R/mlp_data.R +++ b/R/mlp_data.R @@ -1,312 +1,306 @@ -mlp_arg_key <- data.frame( - nnet = c("size", "decay", NA_character_, "maxit", NA_character_), - keras = c("hidden_units", "penalty", "dropout", "epochs", "activation"), - stringsAsFactors = FALSE, - row.names = c("hidden_units", "penalty", "dropout", "epochs", "activation") -) - -mlp_modes <- c("classification", "regression", "unknown") +set_new_model("mlp") -mlp_engines <- data.frame( - nnet = c(TRUE, TRUE, FALSE), - keras = c(TRUE, TRUE, FALSE), - row.names = c("classification", "regression", "unknown") -) +set_model_mode("mlp", "classification") +set_model_mode("mlp", "regression") # ------------------------------------------------------------------------------ -mlp_keras_data <- - list( - libs = c("keras", "magrittr"), - fit = list( - interface = "matrix", - protect = c("x", "y"), - func = c(pkg = "parsnip", fun = "keras_mlp"), - defaults = list() - ), - numeric = list( - pre = NULL, - post = maybe_multivariate, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - x = quote(as.matrix(new_data)) - ) - ), - class = list( - pre = NULL, - post = function(x, object) { - object$lvl[x + 1] - }, - func = c(pkg = "keras", fun = "predict_classes"), - args = - list( - object = quote(object$fit), - x = quote(as.matrix(new_data)) - ) - ), - classprob = list( - pre = NULL, - post = function(x, object) { - x <- as_tibble(x) - colnames(x) <- object$lvl - x - }, - func = c(pkg = "keras", fun = "predict_proba"), - args = - list( - object = quote(object$fit), - x = quote(as.matrix(new_data)) - ) - ), - raw = list( - pre = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - x = quote(as.matrix(new_data)) - ) - ) - ) +set_model_engine("mlp", "classification", "keras") +set_model_engine("mlp", "regression", "keras") +set_dependency("mlp", "keras", "keras") +set_dependency("mlp", "keras", "magrittr") +set_model_arg( + mod = "mlp", + eng = "keras", + val = "hidden_units", + original = "hidden_units", + func = list(pkg = "dials", fun = "hidden_units"), + submodels = FALSE +) +set_model_arg( + mod = "mlp", + eng = "keras", + val = "penalty", + original = "penalty", + func = list(pkg = "dials", fun = "weight_decay"), + submodels = FALSE +) +set_model_arg( + mod = "mlp", + eng = "keras", + val = "dropout", + original = "dropout", + func = list(pkg = "dials", fun = "dropout"), + submodels = FALSE +) +set_model_arg( + mod = "mlp", + eng = "keras", + val = "epochs", + original = "epochs", + func = list(pkg = "dials", fun = "epochs"), + submodels = FALSE +) +set_model_arg( + mod = "mlp", + eng = "keras", + val = "activation", + original = "activation", + func = list(pkg = "dials", fun = "activation"), + submodels = FALSE +) -nnet_softmax <- function(results, object) { - if (ncol(results) == 1) - results <- cbind(1 - results, results) - - results <- apply(results, 1, function(x) exp(x)/sum(exp(x))) - results <- as_tibble(t(results)) - names(results) <- paste0(".pred_", object$lvl) - results -} -mlp_nnet_data <- - list( - libs = "nnet", - fit = list( - interface = "formula", - protect = c("formula", "data", "weights"), - func = c(pkg = "nnet", fun = "nnet"), - 
defaults = list(trace = FALSE) - ), - numeric = list( - pre = NULL, - post = maybe_multivariate, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "raw" - ) - ), - class = list( - pre = NULL, - post = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "class" - ) - ), - classprob = list( - pre = NULL, - post = nnet_softmax, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "raw" - ) - ), - raw = list( - pre = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ) +set_fit( + mod = "mlp", + eng = "keras", + mode = "regression", + value = list( + interface = "matrix", + protect = c("x", "y"), + func = c(pkg = "parsnip", fun = "keras_mlp"), + defaults = list() ) +) +set_fit( + mod = "mlp", + eng = "keras", + mode = "classification", + value = list( + interface = "matrix", + protect = c("x", "y"), + func = c(pkg = "parsnip", fun = "keras_mlp"), + defaults = list() + ) +) -# ------------------------------------------------------------------------------ - -# keras wrapper for feed-forward nnet - -class2ind <- function (x, drop2nd = FALSE) { - if (!is.factor(x)) - stop("`x` should be a factor") - y <- model.matrix( ~ x - 1) - colnames(y) <- gsub("^x", "", colnames(y)) - attributes(y)$assign <- NULL - attributes(y)$contrasts <- NULL - if (length(levels(x)) == 2 & drop2nd) { - y <- y[, 1] - } - y -} - - -#' Simple interface to MLP models via keras -#' -#' Instead of building a `keras` model sequentially, `keras_mlp` can be used to -#' create a feedforward network with a single hidden layer. Regularization is -#' via either weight decay or dropout. -#' -#' @param x A data frame or matrix of predictors -#' @param y A vector (factor or numeric) or matrix (numeric) of outcome data. -#' @param hidden_units An integer for the number of hidden units. -#' @param decay A non-negative real number for the amount of weight decay. Either -#' this parameter _or_ `dropout` can specified. -#' @param dropout The proportion of parameters to set to zero. Either -#' this parameter _or_ `decay` can specified. -#' @param epochs An integer for the number of passes through the data. -#' @param act A character string for the type of activation function between layers. -#' @param seeds A vector of three positive integers to control randomness of the -#' calculations. -#' @param ... Currently ignored. -#' @return A `keras` model object. -#' @export -keras_mlp <- - function(x, y, - hidden_units = 5, decay = 0, dropout = 0, epochs = 20, act = "softmax", - seeds = sample.int(10^5, size = 3), - ...) { - - if(decay > 0 & dropout > 0) - stop("Please use either dropoput or weight decay.", call. 
= FALSE) - - if (!is.matrix(x)) - x <- as.matrix(x) +set_pred( + mod = "mlp", + eng = "keras", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = maybe_multivariate, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + x = quote(as.matrix(new_data)) + ) + ) +) - if(is.character(y)) - y <- as.factor(y) - factor_y <- is.factor(y) +set_pred( + mod = "mlp", + eng = "keras", + mode = "regression", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + x = quote(as.matrix(new_data)) + ) + ) - if (factor_y) - y <- class2ind(y) - else { - if (isTRUE(ncol(y) > 1)) - y <- as.matrix(y) - else - y <- matrix(y, ncol = 1) - } +) - model <- keras::keras_model_sequential() - if(decay > 0) { - model %>% - keras::layer_dense( - units = hidden_units, - activation = act, - input_shape = ncol(x), - kernel_regularizer = keras::regularizer_l2(decay), - kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[1]) - ) - } else { - model %>% - keras::layer_dense( - units = hidden_units, - activation = act, - input_shape = ncol(x), - kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[1]) - ) - } - if(dropout > 0) - model %>% - keras::layer_dense( - units = hidden_units, - activation = act, - input_shape = ncol(x), - kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[1]) - ) %>% - keras::layer_dropout(rate = dropout, seed = seeds[2]) +set_pred( + mod = "mlp", + eng = "keras", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = function(x, object) { + object$lvl[x + 1] + }, + func = c(pkg = "keras", fun = "predict_classes"), + args = + list( + object = quote(object$fit), + x = quote(as.matrix(new_data)) + ) + ) +) - if (factor_y) - model <- model %>% - keras::layer_dense( - units = ncol(y), - activation = 'softmax', - kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[3]) +set_pred( + mod = "mlp", + eng = "keras", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = function(x, object) { + x <- as_tibble(x) + colnames(x) <- object$lvl + x + }, + func = c(pkg = "keras", fun = "predict_proba"), + args = + list( + object = quote(object$fit), + x = quote(as.matrix(new_data)) ) - else - model <- model %>% - keras::layer_dense( - units = ncol(y), - activation = 'linear', - kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[3]) + ) +) + +set_pred( + mod = "mlp", + eng = "keras", + mode = "classification", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + x = quote(as.matrix(new_data)) ) + ) +) - arg_values <- parse_keras_args(...) 
- compile_call <- expr( - keras::compile(object = model) - ) - if(!any(names(arg_values$compile) == "loss")) - compile_call$loss <- - if(factor_y) "binary_crossentropy" else "mse" - if(!any(names(arg_values$compile) == "optimizer")) - compile_call$optimizer <- "adam" - for(arg in names(arg_values$compile)) - compile_call[[arg]] <- arg_values$compile[[arg]] +# ------------------------------------------------------------------------------ - model <- eval_tidy(compile_call) +set_model_engine("mlp", "classification", "nnet") +set_model_engine("mlp", "regression", "nnet") +set_dependency("mlp", "nnet", "nnet") - fit_call <- expr( - keras::fit(object = model) - ) - fit_call$x <- quote(x) - fit_call$y <- quote(y) - fit_call$epochs <- epochs - for(arg in names(arg_values$fit)) - fit_call[[arg]] <- arg_values$fit[[arg]] +set_model_arg( + mod = "mlp", + eng = "nnet", + val = "hidden_units", + original = "size", + func = list(pkg = "dials", fun = "hidden_units"), + submodels = FALSE +) +set_model_arg( + mod = "mlp", + eng = "nnet", + val = "penalty", + original = "penalty", + func = list(pkg = "dials", fun = "weight_decay"), + submodels = FALSE +) - history <- eval_tidy(fit_call) - model - } +set_fit( + mod = "mlp", + eng = "nnet", + mode = "regression", + value = list( + interface = "formula", + protect = c("formula", "data", "weights"), + func = c(pkg = "nnet", fun = "nnet"), + defaults = list(trace = FALSE) + ) +) -parse_keras_args <- function(...) { - exclusions <- c("object", "x", "y", "validation_data", "epochs") - fit_args <- c( - 'batch_size', - 'verbose', - 'callbacks', - 'view_metrics', - 'validation_split', - 'validation_data', - 'shuffle', - 'class_weight', - 'sample_weight', - 'initial_epoch', - 'steps_per_epoch', - 'validation_steps' +set_fit( + mod = "mlp", + eng = "nnet", + mode = "classification", + value = list( + interface = "formula", + protect = c("formula", "data", "weights"), + func = c(pkg = "nnet", fun = "nnet"), + defaults = list(trace = FALSE) ) - compile_args <- c( - 'optimizer', - 'loss', - 'metrics', - 'loss_weights', - 'sample_weight_mode', - 'weighted_metrics', - 'target_tensors' +) + +set_pred( + mod = "mlp", + eng = "nnet", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = maybe_multivariate, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "raw" + ) ) - dots <- list(...) 
- dots <- dots[!(names(dots) %in% exclusions)] +) - list( - fit = dots[names(dots) %in% fit_args], - compile = dots[names(dots) %in% compile_args] +set_pred( + mod = "mlp", + eng = "nnet", + mode = "regression", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data) + ) ) -} -mlp_num_weights <- function(p, hidden_units, classes) - ((p+1) * hidden_units) + ((hidden_units+1) * classes) +) +set_pred( + mod = "mlp", + eng = "nnet", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "class" + ) + ) +) +set_pred( + mod = "mlp", + eng = "nnet", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = nnet_softmax, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "raw" + ) + ) +) +set_pred( + mod = "mlp", + eng = "nnet", + mode = "classification", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data) + ) + ) +) diff --git a/R/multinom_reg_data.R b/R/multinom_reg_data.R index 186291003..e356e36ab 100644 --- a/R/multinom_reg_data.R +++ b/R/multinom_reg_data.R @@ -1,138 +1,229 @@ +set_new_model("multinom_reg") -multinom_reg_arg_key <- data.frame( - glmnet = c( "lambda", "alpha"), - spark = c("reg_param", "elastic_net_param"), - keras = c( "decay", NA), - stringsAsFactors = FALSE, - row.names = c("penalty", "mixture") +set_model_mode("multinom_reg", "classification") + +# ------------------------------------------------------------------------------ + +set_model_engine("multinom_reg", "classification", "glmnet") +set_dependency("multinom_reg", "glmnet", "glmnet") + +set_model_arg( + mod = "multinom_reg", + eng = "glmnet", + val = "penalty", + original = "lambda", + func = list(pkg = "dials", fun = "penalty"), + submodels = TRUE +) + +set_model_arg( + mod = "multinom_reg", + eng = "glmnet", + val = "mixture", + original = "alpha", + func = list(pkg = "dials", fun = "mixture"), + submodels = FALSE +) + +set_fit( + mod = "multinom_reg", + eng = "glmnet", + mode = "classification", + value = list( + interface = "matrix", + protect = c("x", "y", "weights"), + func = c(pkg = "glmnet", fun = "glmnet"), + defaults = list(family = "multinomial") + ) +) + + +set_pred( + mod = "multinom_reg", + eng = "glmnet", + mode = "classification", + type = "class", + value = list( + pre = check_glmnet_lambda, + post = organize_multnet_class, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newx = quote(as.matrix(new_data)), + type = "class", + s = quote(object$spec$args$penalty) + ) + ) ) -multinom_reg_modes <- "classification" +set_pred( + mod = "multinom_reg", + eng = "glmnet", + mode = "classification", + type = "prob", + value = list( + pre = check_glmnet_lambda, + post = organize_multnet_prob, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newx = quote(as.matrix(new_data)), + type = "response", + s = quote(object$spec$args$penalty) + ) + ) +) -multinom_reg_engines <- data.frame( - glmnet = TRUE, - spark = TRUE, - keras = TRUE, - row.names = c("classification") +set_pred( + mod = "multinom_reg", + eng = "glmnet", + mode = "classification", + type = "raw", + value = list( + pre = NULL, + post = NULL, + 
func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newx = quote(as.matrix(new_data)) + ) + ) ) # ------------------------------------------------------------------------------ -multinom_reg_glmnet_data <- - list( - libs = "glmnet", - fit = list( - interface = "matrix", - protect = c("x", "y", "weights"), - func = c(pkg = "glmnet", fun = "glmnet"), - defaults = - list( - family = "multinomial" - ) - ), - class = list( - pre = check_glmnet_lambda, - post = organize_multnet_class, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newx = quote(as.matrix(new_data)), - type = "class", - s = quote(object$spec$args$penalty) - ) - ), - classprob = list( - pre = check_glmnet_lambda, - post = organize_multnet_prob, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newx = quote(as.matrix(new_data)), - type = "response", - s = quote(object$spec$args$penalty) - ) - ), - raw = list( - pre = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newx = quote(as.matrix(new_data)) - ) - ) +set_model_engine("multinom_reg", "classification", "spark") +set_dependency("multinom_reg", "spark", "sparklyr") + +set_model_arg( + mod = "multinom_reg", + eng = "glmnet", + val = "penalty", + original = "reg_param", + func = list(pkg = "dials", fun = "penalty"), + submodels = TRUE +) + +set_model_arg( + mod = "multinom_reg", + eng = "glmnet", + val = "elastic_net_param", + original = "alpha", + func = list(pkg = "dials", fun = "mixture"), + submodels = FALSE +) + +set_fit( + mod = "multinom_reg", + eng = "spark", + mode = "classification", + value = list( + interface = "formula", + protect = c("x", "formula", "weight_col"), + func = c(pkg = "sparklyr", fun = "ml_logistic_regression"), + defaults = list(family = "multinomial") ) +) -multinom_reg_spark_data <- - list( - libs = "sparklyr", - fit = list( - interface = "formula", - protect = c("x", "formula", "weight_col"), - func = c(pkg = "sparklyr", fun = "ml_logistic_regression"), - defaults = - list( - family = "multinomial" - ) - ), - class = list( - pre = NULL, - post = format_spark_class, - func = c(pkg = "sparklyr", fun = "ml_predict"), - args = - list( - x = quote(object$fit), - dataset = quote(new_data) - ) - ), - classprob = list( - pre = NULL, - post = format_spark_probs, - func = c(pkg = "sparklyr", fun = "ml_predict"), - args = - list( - x = quote(object$fit), - dataset = quote(new_data) - ) - ) +set_pred( + mod = "multinom_reg", + eng = "spark", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = format_spark_class, + func = c(pkg = "sparklyr", fun = "ml_predict"), + args = + list( + x = quote(object$fit), + dataset = quote(new_data) + ) ) +) -multinom_reg_keras_data <- - list( - libs = c("keras", "magrittr"), - fit = list( - interface = "matrix", - protect = c("x", "y"), - func = c(pkg = "parsnip", fun = "keras_mlp"), - defaults = list(hidden_units = 1, act = "linear") - ), - class = list( - pre = NULL, - post = function(x, object) { - object$lvl[x + 1] - }, - func = c(pkg = "keras", fun = "predict_classes"), - args = - list( - object = quote(object$fit), - x = quote(as.matrix(new_data)) - ) - ), - classprob = list( - pre = NULL, - post = function(x, object) { - x <- as_tibble(x) - colnames(x) <- object$lvl - x - }, - func = c(pkg = "keras", fun = "predict_proba"), - args = - list( - object = quote(object$fit), - x = quote(as.matrix(new_data)) - ) - ) +set_pred( + mod = "multinom_reg", + eng = "spark", + mode = 
"classification", + type = "prob", + value = list( + pre = NULL, + post = format_spark_probs, + func = c(pkg = "sparklyr", fun = "ml_predict"), + args = + list( + x = quote(object$fit), + dataset = quote(new_data) + ) ) +) + +# ------------------------------------------------------------------------------ + +set_model_engine("multinom_reg", "classification", "keras") +set_model_engine("multinom_reg", "regression", "keras") +set_dependency("multinom_reg", "keras", "keras") +set_dependency("multinom_reg", "keras", "magrittr") + +set_model_arg( + mod = "multinom_reg", + eng = "keras", + val = "decay", + original = "decay", + func = list(pkg = "dials", fun = "weight_decay"), + submodels = FALSE +) + + +set_fit( + mod = "multinom_reg", + eng = "keras", + mode = "classification", + value = list( + interface = "matrix", + protect = c("x", "y"), + func = c(pkg = "parsnip", fun = "keras_mlp"), + defaults = list(hidden_units = 1, act = "linear") + ) +) + +set_pred( + mod = "multinom_reg", + eng = "keras", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = function(x, object) { + object$lvl[x + 1] + }, + func = c(pkg = "keras", fun = "predict_classes"), + args = + list(object = quote(object$fit), + x = quote(as.matrix(new_data))) + ) +) + +set_pred( + mod = "multinom_reg", + eng = "keras", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = function(x, object) { + x <- as_tibble(x) + colnames(x) <- object$lvl + x + }, + func = c(pkg = "keras", fun = "predict_proba"), + args = + list(object = quote(object$fit), + x = quote(as.matrix(new_data))) + ) +) diff --git a/R/nearest_neighbor_data.R b/R/nearest_neighbor_data.R index 0191d8614..98f086277 100644 --- a/R/nearest_neighbor_data.R +++ b/R/nearest_neighbor_data.R @@ -1,90 +1,179 @@ -nearest_neighbor_arg_key <- data.frame( - kknn = c("ks", "kernel", "distance"), - row.names = c("neighbors", "weight_func", "dist_power"), - stringsAsFactors = FALSE + +set_new_model("nearest_neighbor") + +set_model_mode("nearest_neighbor", "classification") +set_model_mode("nearest_neighbor", "regression") + +# ------------------------------------------------------------------------------ + +set_model_engine("nearest_neighbor", "classification", "kknn") +set_model_engine("nearest_neighbor", "regression", "kknn") +set_dependency("nearest_neighbor", "kknn", "kknn") + +set_model_arg( + mod = "nearest_neighbor", + eng = "kknn", + val = "num_terms", + original = "nprune", + func = list(pkg = "dials", fun = "num_terms"), + submodels = FALSE +) +set_model_arg( + mod = "nearest_neighbor", + eng = "kknn", + val = "neighbors", + original = "ks", + func = list(pkg = "dials", fun = "neighbors"), + submodels = FALSE +) +set_model_arg( + mod = "nearest_neighbor", + eng = "kknn", + val = "weight_func", + original = "kernel", + func = list(pkg = "dials", fun = "weight_func"), + submodels = FALSE +) +set_model_arg( + mod = "nearest_neighbor", + eng = "kknn", + val = "distance", + original = "dist_power", + func = list(pkg = "dials", fun = "distance"), + submodels = FALSE ) -nearest_neighbor_modes <- c("classification", "regression", "unknown") +set_fit( + mod = "nearest_neighbor", + eng = "kknn", + mode = "regression", + value = list( + interface = "formula", + protect = c("formula", "data", "kmax"), # kmax is not allowed + func = c(pkg = "kknn", fun = "train.kknn"), + defaults = list() + ) +) -nearest_neighbor_engines <- data.frame( - kknn = c(TRUE, TRUE, FALSE), - row.names = c("classification", "regression", "unknown") 
+set_fit( + mod = "nearest_neighbor", + eng = "kknn", + mode = "classification", + value = list( + interface = "formula", + protect = c("formula", "data", "kmax"), # kmax is not allowed + func = c(pkg = "kknn", fun = "train.kknn"), + defaults = list() + ) ) -# ------------------------------------------------------------------------------ +set_pred( + mod = "nearest_neighbor", + eng = "kknn", + mode = "regression", + type = "numeric", + value = list( + # seems unnecessary here as the predict_numeric catches it based on the + # model mode + pre = function(x, object) { + if (object$fit$response != "continuous") { + stop("`kknn` model does not appear to use numeric predictions. Was ", + "the model fit with a continuous response variable?", + call. = FALSE) + } + x + }, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "raw" + ) + ) +) + +set_pred( + mod = "nearest_neighbor", + eng = "kknn", + mode = "regression", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data) + ) + ) +) + +set_pred( + mod = "nearest_neighbor", + eng = "kknn", + mode = "classification", + type = "class", + value = list( + pre = function(x, object) { + if (!(object$fit$response %in% c("ordinal", "nominal"))) { + stop("`kknn` model does not appear to use class predictions. Was ", + "the model fit with a factor response variable?", + call. = FALSE) + } + x + }, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "raw" + ) + ) +) + +set_pred( + mod = "nearest_neighbor", + eng = "kknn", + mode = "classification", + type = "prob", + value = list( + pre = function(x, object) { + if (!(object$fit$response %in% c("ordinal", "nominal"))) { + stop("`kknn` model does not appear to use class predictions. Was ", + "the model fit with a factor response variable?", + call. = FALSE) + } + x + }, + post = function(result, object) as_tibble(result), + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "prob" + ) + ) +) -nearest_neighbor_kknn_data <- - list( - libs = "kknn", - fit = list( - interface = "formula", - protect = c("formula", "data", "kmax"), # kmax is not allowed - func = c(pkg = "kknn", fun = "train.kknn"), - defaults = list() - ), - numeric = list( - # seems unnecessary here as the predict_numeric catches it based on the - # model mode - pre = function(x, object) { - if (object$fit$response != "continuous") { - stop("`kknn` model does not appear to use numeric predictions. Was ", - "the model fit with a continuous response variable?", - call. = FALSE) - } - x - }, - post = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "raw" - ) - ), - class = list( - pre = function(x, object) { - if (!(object$fit$response %in% c("ordinal", "nominal"))) { - stop("`kknn` model does not appear to use class predictions. Was ", - "the model fit with a factor response variable?", - call. = FALSE) - } - x - }, - post = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "raw" - ) - ), - classprob = list( - pre = function(x, object) { - if (!(object$fit$response %in% c("ordinal", "nominal"))) { - stop("`kknn` model does not appear to use class predictions. 
Was ", - "the model fit with a factor response variable?", - call. = FALSE) - } - x - }, - post = function(result, object) as_tibble(result), - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "prob" - ) - ), - raw = list( - pre = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ) +set_pred( + mod = "nearest_neighbor", + eng = "kknn", + mode = "classification", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data) + ) ) +) diff --git a/R/nullmodel_data.R b/R/nullmodel_data.R index 80380a98a..e09bf398a 100644 --- a/R/nullmodel_data.R +++ b/R/nullmodel_data.R @@ -1,72 +1,128 @@ -null_model_arg_key <- data.frame( - parsnip = NULL, - row.names = NULL, - stringsAsFactors = FALSE +set_new_model("null_model") + +set_model_mode("null_model", "classification") +set_model_mode("null_model", "regression") + +# ------------------------------------------------------------------------------ + +set_model_engine("null_model", "classification", "parsnip") +set_model_engine("null_model", "regression", "parsnip") +set_dependency("null_model", "parsnip", "parsnip") + +set_fit( + mod = "null_model", + eng = "parsnip", + mode = "regression", + value = list( + interface = "matrix", + protect = c("x", "y"), + func = c(fun = "nullmodel"), + defaults = list() + ) ) -null_model_modes <- c("classification", "regression", "unknown") +set_fit( + mod = "null_model", + eng = "parsnip", + mode = "classification", + value = list( + interface = "matrix", + protect = c("x", "y"), + func = c(fun = "nullmodel"), + defaults = list() + ) +) -null_model_engines <- data.frame( - parsnip = c(TRUE, TRUE, FALSE), - row.names = c("classification", "regression", "unknown") +set_pred( + mod = "null_model", + eng = "parsnip", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + new_data = quote(new_data), + type = "numeric" + ) + ) ) -# ------------------------------------------------------------------------------ +set_pred( + mod = "null_model", + eng = "parsnip", + mode = "regression", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + new_data = quote(new_data), + type = "numeric" + ) + ) +) -null_model_parsnip_data <- - list( - libs = "parsnip", - fit = list( - interface = "matrix", - protect = c("x", "y"), - func = c(fun = "nullmodel"), - defaults = list() - ), - class = list( - pre = NULL, - post = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - new_data = quote(new_data), - type = "class" - ) - ), - classprob = list( - pre = NULL, - post = function(x, object) { - str(as_tibble(x)) - as_tibble(x) - }, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - new_data = quote(new_data), - type = "prob" - ) - ), - numeric = list( - pre = NULL, - post = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - new_data = quote(new_data), - type = "numeric" - ) - ), - raw = list( - pre = NULL, - post = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - new_data = quote(new_data), - type = "raw" - ) +set_pred( + mod = "null_model", + eng = "parsnip", + mode = "classification", + type = 
"class", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + new_data = quote(new_data), + type = "class" ) - ) + ) +) + +set_pred( + mod = "null_model", + eng = "parsnip", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = function(x, object) { + str(as_tibble(x)) + as_tibble(x) + }, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + new_data = quote(new_data), + type = "prob" + ) + ) +) + +set_pred( + mod = "null_model", + eng = "parsnip", + mode = "classification", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + new_data = quote(new_data), + type = "raw" + ) + ) +) + diff --git a/man/C5.0_train.Rd b/man/C5.0_train.Rd index b81c415e3..9a3e2f0d1 100644 --- a/man/C5.0_train.Rd +++ b/man/C5.0_train.Rd @@ -41,3 +41,4 @@ A fitted C5.0 model. \pkg{C50} package that fits tree-based models where all of the model arguments are in the main function. } +\keyword{internal} diff --git a/man/keras_mlp.Rd b/man/keras_mlp.Rd index db7ef268c..4972655aa 100644 --- a/man/keras_mlp.Rd +++ b/man/keras_mlp.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/mlp_data.R +% Please edit documentation in R/mlp.R \name{keras_mlp} \alias{keras_mlp} \title{Simple interface to MLP models via keras} @@ -38,3 +38,4 @@ Instead of building a \code{keras} model sequentially, \code{keras_mlp} can be u create a feedforward network with a single hidden layer. Regularization is via either weight decay or dropout. } +\keyword{internal} diff --git a/man/rpart_train.Rd b/man/rpart_train.Rd index cb4e6763d..6ce88fb66 100644 --- a/man/rpart_train.Rd +++ b/man/rpart_train.Rd @@ -42,3 +42,4 @@ A fitted rpart model. \code{rpart_train} is a wrapper for \code{rpart()} tree-based models where all of the model arguments are in the main function. } +\keyword{internal} diff --git a/man/xgb_train.Rd b/man/xgb_train.Rd index b3ed65952..5eeab7841 100644 --- a/man/xgb_train.Rd +++ b/man/xgb_train.Rd @@ -38,3 +38,4 @@ A fitted \code{xgboost} object. \code{xgb_train} is a wrapper for \code{xgboost} tree-based models where all of the model arguments are in the main function. 
} +\keyword{internal} From a4c09557263e6109d18c9f5c7459b97aa884db5d Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Thu, 23 May 2019 17:20:20 -0400 Subject: [PATCH 12/64] temp commented out until new model structures are ready --- R/zzz.R | 118 ++++++++++++++++++++++++++++---------------------------- 1 file changed, 59 insertions(+), 59 deletions(-) diff --git a/R/zzz.R b/R/zzz.R index dbee4e8cc..5978630f9 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -1,59 +1,59 @@ -## nocov start - -data_obj <- ls(pattern = "_data$") -data_obj <- data_obj[data_obj != "prepare_data"] - -#' @importFrom purrr map_dfr -#' @importFrom tibble as_tibble -data_names <- - map_dfr( - data_obj, - function(x) { - module <- names(get(x)) - if (length(module) > 1) { - module <- table(module) - module <- as_tibble(module) - module$object <- x - module - } else - module <- NULL - module - } - ) - -if(any(data_names$n > 1)) { - print(data_names[data_names$n > 1,]) - stop("Some models have duplicate module names.") -} -rm(data_names) - -# ------------------------------------------------------------------------------ - -engine_objects <- ls(pattern = "_engines$") -engine_objects <- engine_objects[engine_objects != "possible_engines"] - -#' @importFrom utils stack -get_engine_info <- function(x) { - y <- x - y <- get(y) - z <- stack(y) - z$mode <- rownames(y) - z$model <- gsub("_engines$", "", x) - z$object <- x - z <- z[z$values,] - z <- z[z$mode != "unknown",] - z$values <- NULL - names(z)[1] <- "engine" - z$engine <- as.character(z$engine) - z -} - -engine_info <- - purrr::map_df( - parsnip:::engine_objects, - get_engine_info -) - -rm(engine_objects) - -## nocov end +#' ## nocov start +#' +#' data_obj <- ls(pattern = "_data$") +#' data_obj <- data_obj[data_obj != "prepare_data"] +#' +#' #' @importFrom purrr map_dfr +#' #' @importFrom tibble as_tibble +#' data_names <- +#' map_dfr( +#' data_obj, +#' function(x) { +#' module <- names(get(x)) +#' if (length(module) > 1) { +#' module <- table(module) +#' module <- as_tibble(module) +#' module$object <- x +#' module +#' } else +#' module <- NULL +#' module +#' } +#' ) +#' +#' if(any(data_names$n > 1)) { +#' print(data_names[data_names$n > 1,]) +#' stop("Some models have duplicate module names.") +#' } +#' rm(data_names) +#' +#' # ------------------------------------------------------------------------------ +#' +#' engine_objects <- ls(pattern = "_engines$") +#' engine_objects <- engine_objects[engine_objects != "possible_engines"] +#' +#' #' @importFrom utils stack +#' get_engine_info <- function(x) { +#' y <- x +#' y <- get(y) +#' z <- stack(y) +#' z$mode <- rownames(y) +#' z$model <- gsub("_engines$", "", x) +#' z$object <- x +#' z <- z[z$values,] +#' z <- z[z$mode != "unknown",] +#' z$values <- NULL +#' names(z)[1] <- "engine" +#' z$engine <- as.character(z$engine) +#' z +#' } +#' +#' engine_info <- +#' purrr::map_df( +#' parsnip:::engine_objects, +#' get_engine_info +#' ) +#' +#' rm(engine_objects) +#' +#' ## nocov end From 5dc131d5ee8e4d02db182339ccdae942bc3a7ef9 Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Thu, 23 May 2019 17:20:45 -0400 Subject: [PATCH 13/64] added some replacement functions for new model structure --- R/translate.R | 55 ++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 52 insertions(+), 3 deletions(-) diff --git a/R/translate.R b/R/translate.R index fd014f770..aaf0317ad 100644 --- a/R/translate.R +++ b/R/translate.R @@ -19,11 +19,11 @@ #' This function can be useful when you need to understand how #' `parsnip` goes from a generic model 
specific to a model fitting #' function. -#' +#' #' **Note**: this function is used internally and users should only use it #' to understand what the underlying syntax would be. It should not be used -#' to modify the model specification. -#' +#' to modify the model specification. +#' #' @examples #' lm_spec <- linear_reg(penalty = 0.01) #' @@ -144,3 +144,52 @@ check_mode <- function(object, lvl) { } object } + +# ------------------------------------------------------------------------------ +# new code for revised model data structures + +get_model_spec <- function(model, mode, engine) { + m_env <- get_model_env() + env_obj <- env_names(m_env) + env_obj <- grep(model, env_obj, value = TRUE) + + res <- list() + res$libs <- + env_get(m_env, paste0(model, "_pkgs")) %>% + purrr::pluck("pkg") %>% + purrr::pluck(1) + + res$fit <- + env_get(m_env, paste0(model, "_fit")) %>% + dplyr::filter(mode == !!mode & engine == !!engine) %>% + dplyr::pull(value) %>% + purrr:::pluck(1) + + pred_code <- + env_get(m_env, paste0(model, "_predict")) %>% + dplyr::filter(mode == !!mode & engine == !!engine) %>% + dplyr::select(-engine, -mode) + + res$pred <- pred_code[["value"]] + names(res$pred) <- pred_code$type + + res +} + +get_args <- function(model, engine) { + m_env <- get_model_env() + env_get(m_env, paste0(model, "_args")) %>% + dplyr::select(-engine) +} + +# to replace harmonize +unionize <- function(args, key) { + parsn <- tibble(parsnip = names(args), order = seq_along(args)) + merged <- + dplyr::left_join(parsn, key, by = "parsnip") %>% + dplyr::arrange(order) + # TODO correct for bad merge? + + names(args) <- merged$original + args[!is.na(merged$original)] +} From 7d9cca784a7f1aa08277680632e1494a1eddb007 Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Thu, 23 May 2019 17:24:30 -0400 Subject: [PATCH 14/64] disabled Rd Sexpr until new data structure ready --- NAMESPACE | 1 - R/aaa_models.R | 4 ++-- R/boost_tree.R | 10 +++++----- R/decision_tree.R | 10 +++++----- R/linear_reg.R | 10 +++++----- R/logistic_reg.R | 10 +++++----- R/mars.R | 4 ++-- R/mlp.R | 8 ++++---- R/multinom_reg.R | 6 +++--- R/nearest_neighbor.R | 2 +- R/nullmodel.R | 4 ++-- R/rand_forest.R | 12 ++++++------ R/surv_reg.R | 4 ++-- R/svm_poly.R | 4 ++-- R/svm_rbf.R | 4 ++-- man/boost_tree.Rd | 10 ---------- man/decision_tree.Rd | 10 ---------- man/linear_reg.Rd | 10 ---------- man/logistic_reg.Rd | 10 ---------- man/mars.Rd | 4 ---- man/mlp.Rd | 8 -------- man/multinom_reg.Rd | 6 ------ man/nearest_neighbor.Rd | 2 -- man/null_model.Rd | 4 ---- man/rand_forest.Rd | 12 ------------ man/surv_reg.Rd | 4 ---- man/svm_poly.Rd | 4 ---- man/svm_rbf.Rd | 4 ---- 28 files changed, 46 insertions(+), 135 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index 425c9ce72..ca7a96068 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -215,4 +215,3 @@ importFrom(utils,capture.output) importFrom(utils,getFromNamespace) importFrom(utils,globalVariables) importFrom(utils,head) -importFrom(utils,stack) diff --git a/R/aaa_models.R b/R/aaa_models.R index 2851da8af..d8e340b70 100644 --- a/R/aaa_models.R +++ b/R/aaa_models.R @@ -38,7 +38,7 @@ pred_types <- #' @export get_model_env <- function() { current <- utils::getFromNamespace("parsnip", ns = "parsnip") - # current <- get("parsnip") + # current <- parsnip current } @@ -581,4 +581,4 @@ show_model_info <- function(mod) { invisible(NULL) } -# ------------------------------------------------------------------------------ + diff --git a/R/boost_tree.R b/R/boost_tree.R index 57f6d5ed3..9c7f84609 100644 --- a/R/boost_tree.R +++ 
b/R/boost_tree.R @@ -67,23 +67,23 @@ #' #' \pkg{xgboost} classification #' -#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::boost_tree(mode = "classification"), "xgboost")} +# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::boost_tree(mode = "classification"), "xgboost")} #' #' \pkg{xgboost} regression #' -#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::boost_tree(mode = "regression"), "xgboost")} +# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::boost_tree(mode = "regression"), "xgboost")} #' #' \pkg{C5.0} classification #' -#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::boost_tree(mode = "classification"), "C5.0")} +# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::boost_tree(mode = "classification"), "C5.0")} #' #' \pkg{spark} classification #' -#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::boost_tree(mode = "classification"), "spark")} +# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::boost_tree(mode = "classification"), "spark")} #' #' \pkg{spark} regression #' -#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::boost_tree(mode = "regression"), "spark")} +# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::boost_tree(mode = "regression"), "spark")} #' #' @note For models created using the spark engine, there are #' several differences to consider. First, only the formula diff --git a/R/decision_tree.R b/R/decision_tree.R index afa807405..5032f990f 100644 --- a/R/decision_tree.R +++ b/R/decision_tree.R @@ -52,23 +52,23 @@ #' #' \pkg{rpart} classification #' -#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::decision_tree(mode = "classification"), "rpart")} +# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::decision_tree(mode = "classification"), "rpart")} #' #' \pkg{rpart} regression #' -#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::decision_tree(mode = "regression"), "rpart")} +# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::decision_tree(mode = "regression"), "rpart")} #' #' \pkg{C5.0} classification #' -#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::decision_tree(mode = "classification"), "C5.0")} +# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::decision_tree(mode = "classification"), "C5.0")} #' #' \pkg{spark} classification #' -#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::decision_tree(mode = "classification"), "spark")} +# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::decision_tree(mode = "classification"), "spark")} #' #' \pkg{spark} regression #' -#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::decision_tree(mode = "regression"), "spark")} +# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::decision_tree(mode = "regression"), "spark")} #' #' @note For models created using the spark engine, there are #' several differences to consider. 
First, only the formula diff --git a/R/linear_reg.R b/R/linear_reg.R index d3bc72c70..3f94a0cce 100644 --- a/R/linear_reg.R +++ b/R/linear_reg.R @@ -50,23 +50,23 @@ #' #' \pkg{lm} #' -#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "lm")} +# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "lm")} #' #' \pkg{glmnet} #' -#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "glmnet")} +# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "glmnet")} #' #' \pkg{stan} #' -#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "stan")} +# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "stan")} #' #' \pkg{spark} #' -#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "spark")} +# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "spark")} #' #' \pkg{keras} #' -#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "keras")} +# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "keras")} #' #' When using `glmnet` models, there is the option to pass #' multiple values (or no values) to the `penalty` argument. This diff --git a/R/logistic_reg.R b/R/logistic_reg.R index 566a0c4ad..62e6f22c1 100644 --- a/R/logistic_reg.R +++ b/R/logistic_reg.R @@ -48,23 +48,23 @@ #' #' \pkg{glm} #' -#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::logistic_reg(), "glm")} +# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::logistic_reg(), "glm")} #' #' \pkg{glmnet} #' -#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::logistic_reg(), "glmnet")} +# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::logistic_reg(), "glmnet")} #' #' \pkg{stan} #' -#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::logistic_reg(), "stan")} +# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::logistic_reg(), "stan")} #' #' \pkg{spark} #' -#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::logistic_reg(), "spark")} +# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::logistic_reg(), "spark")} #' #' \pkg{keras} #' -#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::logistic_reg(), "keras")} +# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::logistic_reg(), "keras")} #' #' When using `glmnet` models, there is the option to pass #' multiple values (or no values) to the `penalty` argument. This diff --git a/R/mars.R b/R/mars.R index bfe7f6cbf..d3c1bfaae 100644 --- a/R/mars.R +++ b/R/mars.R @@ -44,11 +44,11 @@ #' #' \pkg{earth} classification #' -#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::mars(mode = "classification"), "earth")} +# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::mars(mode = "classification"), "earth")} #' #' \pkg{earth} regression #' -#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::mars(mode = "regression"), "earth")} +# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::mars(mode = "regression"), "earth")} #' #' Note that, when the model is fit, the \pkg{earth} package only has its #' namespace loaded. 
However, if `multi_predict` is used, the package is diff --git a/R/mlp.R b/R/mlp.R index 3fe96631f..d31b8726d 100644 --- a/R/mlp.R +++ b/R/mlp.R @@ -62,19 +62,19 @@ #' #' \pkg{keras} classification #' -#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::mlp(mode = "classification"), "keras")} +# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::mlp(mode = "classification"), "keras")} #' #' \pkg{keras} regression #' -#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::mlp(mode = "regression"), "keras")} +# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::mlp(mode = "regression"), "keras")} #' #' \pkg{nnet} classification #' -#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::mlp(mode = "classification"), "nnet")} +# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::mlp(mode = "classification"), "nnet")} #' #' \pkg{nnet} regression #' -#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::mlp(mode = "regression"), "nnet")} +# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::mlp(mode = "regression"), "nnet")} #' #' @importFrom purrr map_lgl #' @seealso [varying()], [fit()] diff --git a/R/multinom_reg.R b/R/multinom_reg.R index b9aa2f40c..3ff94c1c2 100644 --- a/R/multinom_reg.R +++ b/R/multinom_reg.R @@ -47,15 +47,15 @@ #' #' \pkg{glmnet} #' -#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::multinom_reg(), "glmnet")} +# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::multinom_reg(), "glmnet")} #' #' \pkg{spark} #' -#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::multinom_reg(), "spark")} +# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::multinom_reg(), "spark")} #' #' \pkg{keras} #' -#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::multinom_reg(), "keras")} +# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::multinom_reg(), "keras")} #' #' When using `glmnet` models, there is the option to pass #' multiple values (or no values) to the `penalty` argument. This diff --git a/R/nearest_neighbor.R b/R/nearest_neighbor.R index 52c9a9f0e..eaa83188e 100644 --- a/R/nearest_neighbor.R +++ b/R/nearest_neighbor.R @@ -54,7 +54,7 @@ #' #' \pkg{kknn} (classification or regression) #' -#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::nearest_neighbor(), "kknn")} +# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::nearest_neighbor(), "kknn")} #' #' @note #' For `kknn`, the underlying modeling function used is a restricted diff --git a/R/nullmodel.R b/R/nullmodel.R index c37ecfead..59f86d051 100644 --- a/R/nullmodel.R +++ b/R/nullmodel.R @@ -147,11 +147,11 @@ predict.nullmodel <- function (object, new_data = NULL, type = NULL, ...) 
{ #' #' \pkg{parsnip} classification #' -#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::null_model(mode = "classification"), "parsnip")} +# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::null_model(mode = "classification"), "parsnip")} #' #' \pkg{parsnip} regression #' -#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::null_model(mode = "regression"), "parsnip")} +# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::null_model(mode = "regression"), "parsnip")} #' #' @importFrom purrr map_lgl #' @seealso [varying()], [fit()] diff --git a/R/rand_forest.R b/R/rand_forest.R index 352178e2b..23ec1c1f9 100644 --- a/R/rand_forest.R +++ b/R/rand_forest.R @@ -46,27 +46,27 @@ #' #' \pkg{ranger} classification #' -#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::rand_forest(mode = "classification"), "ranger")} +# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::rand_forest(mode = "classification"), "ranger")} #' #' \pkg{ranger} regression #' -#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::rand_forest(mode = "regression"), "ranger")} +# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::rand_forest(mode = "regression"), "ranger")} #' #' \pkg{randomForests} classification #' -#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::rand_forest(mode = "classification"), "randomForest")} +# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::rand_forest(mode = "classification"), "randomForest")} #' #' \pkg{randomForests} regression #' -#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::rand_forest(mode = "regression"), "randomForest")} +# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::rand_forest(mode = "regression"), "randomForest")} #' #' \pkg{spark} classification #' -#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::rand_forest(mode = "classification"), "spark")} +# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::rand_forest(mode = "classification"), "spark")} #' #' \pkg{spark} regression #' -#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::rand_forest(mode = "regression"), "spark")} +# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::rand_forest(mode = "regression"), "spark")} #' #' For \pkg{ranger} confidence intervals, the intervals are #' constructed using the form `estimate +/- z * std_error`. 
For diff --git a/R/surv_reg.R b/R/surv_reg.R index 33ab8b47e..e2f39e14b 100644 --- a/R/surv_reg.R +++ b/R/surv_reg.R @@ -47,11 +47,11 @@ #' #' \pkg{flexsurv} #' -#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::surv_reg(), "flexsurv")} +# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::surv_reg(), "flexsurv")} #' #' \pkg{survreg} #' -#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::surv_reg(), "survreg")} +# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::surv_reg(), "survreg")} #' #' Note that `model = TRUE` is needed to produce quantile #' predictions when there is a stratification variable and can be diff --git a/R/svm_poly.R b/R/svm_poly.R index 5eb071950..a2b1e8fe0 100644 --- a/R/svm_poly.R +++ b/R/svm_poly.R @@ -44,11 +44,11 @@ #' #' \pkg{kernlab} classification #' -#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::svm_poly(mode = "classification"), "kernlab")} +# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::svm_poly(mode = "classification"), "kernlab")} #' #' \pkg{kernlab} regression #' -#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::svm_poly(mode = "regression"), "kernlab")} +# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::svm_poly(mode = "regression"), "kernlab")} #' #' @importFrom purrr map_lgl #' @seealso [varying()], [fit()] diff --git a/R/svm_rbf.R b/R/svm_rbf.R index 0fe3d39a6..b0ab171ba 100644 --- a/R/svm_rbf.R +++ b/R/svm_rbf.R @@ -43,11 +43,11 @@ #' #' \pkg{kernlab} classification #' -#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::svm_rbf(mode = "classification"), "kernlab")} +# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::svm_rbf(mode = "classification"), "kernlab")} #' #' \pkg{kernlab} regression #' -#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::svm_rbf(mode = "regression"), "kernlab")} +# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::svm_rbf(mode = "regression"), "kernlab")} #' #' @importFrom purrr map_lgl #' @seealso [varying()], [fit()] diff --git a/man/boost_tree.Rd b/man/boost_tree.Rd index 904313d00..6ebde7cca 100644 --- a/man/boost_tree.Rd +++ b/man/boost_tree.Rd @@ -112,23 +112,13 @@ fit calls are: \pkg{xgboost} classification -\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::boost_tree(mode = "classification"), "xgboost")} - \pkg{xgboost} regression -\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::boost_tree(mode = "regression"), "xgboost")} - \pkg{C5.0} classification -\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::boost_tree(mode = "classification"), "C5.0")} - \pkg{spark} classification -\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::boost_tree(mode = "classification"), "spark")} - \pkg{spark} regression - -\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::boost_tree(mode = "regression"), "spark")} } \examples{ diff --git a/man/decision_tree.Rd b/man/decision_tree.Rd index 221bcdccb..4fbe4d18b 100644 --- a/man/decision_tree.Rd +++ b/man/decision_tree.Rd @@ -88,23 +88,13 @@ model, the template of the fit calls are:: \pkg{rpart} classification -\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::decision_tree(mode = "classification"), "rpart")} - \pkg{rpart} regression -\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::decision_tree(mode = "regression"), "rpart")} - \pkg{C5.0} classification -\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::decision_tree(mode = "classification"), "C5.0")} - \pkg{spark} classification -\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::decision_tree(mode = "classification"), "spark")} - \pkg{spark} regression - -\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::decision_tree(mode = "regression"), "spark")} } 
\examples{ diff --git a/man/linear_reg.Rd b/man/linear_reg.Rd index b58c1d631..3eecc57fc 100644 --- a/man/linear_reg.Rd +++ b/man/linear_reg.Rd @@ -87,24 +87,14 @@ model, the template of the fit calls are: \pkg{lm} -\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "lm")} - \pkg{glmnet} -\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "glmnet")} - \pkg{stan} -\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "stan")} - \pkg{spark} -\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "spark")} - \pkg{keras} -\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "keras")} - When using \code{glmnet} models, there is the option to pass multiple values (or no values) to the \code{penalty} argument. This can have an effect on the model object results. When using the diff --git a/man/logistic_reg.Rd b/man/logistic_reg.Rd index 43aa599e9..4f23ceb2d 100644 --- a/man/logistic_reg.Rd +++ b/man/logistic_reg.Rd @@ -85,24 +85,14 @@ model, the template of the fit calls are: \pkg{glm} -\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::logistic_reg(), "glm")} - \pkg{glmnet} -\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::logistic_reg(), "glmnet")} - \pkg{stan} -\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::logistic_reg(), "stan")} - \pkg{spark} -\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::logistic_reg(), "spark")} - \pkg{keras} -\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::logistic_reg(), "keras")} - When using \code{glmnet} models, there is the option to pass multiple values (or no values) to the \code{penalty} argument. This can have an effect on the model object results. When using the diff --git a/man/mars.Rd b/man/mars.Rd index b55c24241..5647bdcf0 100644 --- a/man/mars.Rd +++ b/man/mars.Rd @@ -67,12 +67,8 @@ model, the template of the fit calls are: \pkg{earth} classification -\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::mars(mode = "classification"), "earth")} - \pkg{earth} regression -\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::mars(mode = "regression"), "earth")} - Note that, when the model is fit, the \pkg{earth} package only has its namespace loaded. However, if \code{multi_predict} is used, the package is attached. diff --git a/man/mlp.Rd b/man/mlp.Rd index f52a60a80..0acb2655a 100644 --- a/man/mlp.Rd +++ b/man/mlp.Rd @@ -91,19 +91,11 @@ model, the template of the fit calls are: \pkg{keras} classification -\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::mlp(mode = "classification"), "keras")} - \pkg{keras} regression -\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::mlp(mode = "regression"), "keras")} - \pkg{nnet} classification -\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::mlp(mode = "classification"), "nnet")} - \pkg{nnet} regression - -\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::mlp(mode = "regression"), "nnet")} } \examples{ diff --git a/man/multinom_reg.Rd b/man/multinom_reg.Rd index 6f2b4af05..e3f8dd30d 100644 --- a/man/multinom_reg.Rd +++ b/man/multinom_reg.Rd @@ -84,16 +84,10 @@ model, the template of the fit calls are: \pkg{glmnet} -\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::multinom_reg(), "glmnet")} - \pkg{spark} -\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::multinom_reg(), "spark")} - \pkg{keras} -\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::multinom_reg(), "keras")} - When using \code{glmnet} models, there is the option to pass multiple values (or no values) to the \code{penalty} argument. This can have an effect on the model object results. 
When using the diff --git a/man/nearest_neighbor.Rd b/man/nearest_neighbor.Rd index 10a374297..123d8f9ca 100644 --- a/man/nearest_neighbor.Rd +++ b/man/nearest_neighbor.Rd @@ -66,8 +66,6 @@ model fit call. For this type of model, the template of the fit calls are: \pkg{kknn} (classification or regression) - -\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::nearest_neighbor(), "kknn")} } \examples{ diff --git a/man/null_model.Rd b/man/null_model.Rd index b0930770b..0c221dcba 100644 --- a/man/null_model.Rd +++ b/man/null_model.Rd @@ -32,11 +32,7 @@ model, the template of the fit calls are: \pkg{parsnip} classification -\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::null_model(mode = "classification"), "parsnip")} - \pkg{parsnip} regression - -\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::null_model(mode = "regression"), "parsnip")} } \examples{ diff --git a/man/rand_forest.Rd b/man/rand_forest.Rd index 80c6c8028..73e40c618 100644 --- a/man/rand_forest.Rd +++ b/man/rand_forest.Rd @@ -82,28 +82,16 @@ model, the template of the fit calls are:: \pkg{ranger} classification -\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::rand_forest(mode = "classification"), "ranger")} - \pkg{ranger} regression -\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::rand_forest(mode = "regression"), "ranger")} - \pkg{randomForests} classification -\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::rand_forest(mode = "classification"), "randomForest")} - \pkg{randomForests} regression -\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::rand_forest(mode = "regression"), "randomForest")} - \pkg{spark} classification -\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::rand_forest(mode = "classification"), "spark")} - \pkg{spark} regression -\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::rand_forest(mode = "regression"), "spark")} - For \pkg{ranger} confidence intervals, the intervals are constructed using the form \code{estimate +/- z * std_error}. For classification probabilities, these values can fall outside of diff --git a/man/surv_reg.Rd b/man/surv_reg.Rd index 60a275317..f6b307d12 100644 --- a/man/surv_reg.Rd +++ b/man/surv_reg.Rd @@ -70,12 +70,8 @@ model, the template of the fit calls are: \pkg{flexsurv} -\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::surv_reg(), "flexsurv")} - \pkg{survreg} -\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::surv_reg(), "survreg")} - Note that \code{model = TRUE} is needed to produce quantile predictions when there is a stratification variable and can be overridden in other cases. 
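
The `\Sexpr` macros being disabled throughout these Rd files all evaluated the internal `parsnip:::show_fit()` helper at documentation build time, rendering each engine's fit template into the help page; until the new data structure is ready, the pages simply list the engines without the rendered templates. A minimal sketch of the equivalent interactive check, assuming `translate()` keeps its current print behavior once the registration work above lands:

    library(parsnip)
    # Build a spec, attach an engine, and print the templated fit call;
    # the `missing_arg()` placeholders are filled in later by `fit()`.
    spec <- set_engine(svm_rbf(mode = "classification"), "kernlab")
    translate(spec)
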
diff --git a/man/svm_poly.Rd b/man/svm_poly.Rd index 314b55e49..2b31741b6 100644 --- a/man/svm_poly.Rd +++ b/man/svm_poly.Rd @@ -69,11 +69,7 @@ model, the template of the fit calls are:: \pkg{kernlab} classification -\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::svm_poly(mode = "classification"), "kernlab")} - \pkg{kernlab} regression - -\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::svm_poly(mode = "regression"), "kernlab")} } \examples{ diff --git a/man/svm_rbf.Rd b/man/svm_rbf.Rd index e815d9e38..d7b503207 100644 --- a/man/svm_rbf.Rd +++ b/man/svm_rbf.Rd @@ -67,11 +67,7 @@ model, the template of the fit calls are:: \pkg{kernlab} classification -\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::svm_rbf(mode = "classification"), "kernlab")} - \pkg{kernlab} regression - -\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::svm_rbf(mode = "regression"), "kernlab")} } \examples{ From 05f7eff77f9c18f6f7ff78074bb45cac32248f5c Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Thu, 23 May 2019 17:34:32 -0400 Subject: [PATCH 15/64] updated check_engine --- R/engines.R | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/R/engines.R b/R/engines.R index a1b32de5f..851e46755 100644 --- a/R/engines.R +++ b/R/engines.R @@ -15,9 +15,9 @@ specific_model <- function(x) { possible_engines <- function(object, ...) { - cls <- specific_model(object) - key_df <- get(paste(cls, "engines", sep = "_")) - colnames(key_df[object$mode, , drop = FALSE]) + m_env <- get_model_env() + engs <- rlang::env_get(m_env, specific_model(object)) + unique(engs$engine) } check_engine <- function(object) { From 53d40f96a3ea81dd25b605d585ca6e97209e6381 Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Thu, 23 May 2019 17:49:12 -0400 Subject: [PATCH 16/64] basic translate is working --- R/translate.R | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/R/translate.R b/R/translate.R index aaf0317ad..c357a71b0 100644 --- a/R/translate.R +++ b/R/translate.R @@ -53,25 +53,32 @@ translate.default <- function(x, engine = x$engine, ...) { if (is.null(engine)) stop("Please set an engine.", call. = FALSE) + mod_name <- specific_model(x) + x$engine <- engine x <- check_engine(x) + # TODOS + # what to do with unknown mode? + if (x$mode == "unknown") { + stop("Model code depends on the mode; please specify one.", call. = FALSE) + } + # set the classes. Is a constructor not being used? if (is.null(x$method)) - x <- get_method(x, engine, ...) 
+ x$method <- get_model_spec(mod_name, x$mode, engine) - arg_key <- get_module(specific_model(x)) + arg_key <- get_args(mod_name, engine) # deharmonize primary arguments - actual_args <- deharmonize(x$args, arg_key, x$engine) + actual_args <- unionize(x$args, arg_key) # check secondary arguments to see if they are in the final # expression unless there are dots, warn if protected args are # being altered - eng_arg_key <- arg_key[[x$engine]] - x$eng_args <- check_eng_args(x$eng_args, x$method$fit, eng_arg_key) + x$eng_args <- check_eng_args(x$eng_args, x$method$fit, arg_key$original) # keep only modified args - modifed_args <- !vapply(actual_args, null_value, lgl(1)) + modifed_args <- purrr::map_lgl(actual_args, null_value) actual_args <- actual_args[modifed_args] # look for defaults if not modified in other @@ -150,23 +157,23 @@ check_mode <- function(object, lvl) { get_model_spec <- function(model, mode, engine) { m_env <- get_model_env() - env_obj <- env_names(m_env) + env_obj <- rlang::env_names(m_env) env_obj <- grep(model, env_obj, value = TRUE) res <- list() res$libs <- - env_get(m_env, paste0(model, "_pkgs")) %>% + rlang::env_get(m_env, paste0(model, "_pkgs")) %>% purrr::pluck("pkg") %>% purrr::pluck(1) res$fit <- - env_get(m_env, paste0(model, "_fit")) %>% + rlang::env_get(m_env, paste0(model, "_fit")) %>% dplyr::filter(mode == !!mode & engine == !!engine) %>% dplyr::pull(value) %>% purrr:::pluck(1) pred_code <- - env_get(m_env, paste0(model, "_predict")) %>% + rlang::env_get(m_env, paste0(model, "_predict")) %>% dplyr::filter(mode == !!mode & engine == !!engine) %>% dplyr::select(-engine, -mode) @@ -178,7 +185,7 @@ get_model_spec <- function(model, mode, engine) { get_args <- function(model, engine) { m_env <- get_model_env() - env_get(m_env, paste0(model, "_args")) %>% + rlang::env_get(m_env, paste0(model, "_args")) %>% dplyr::select(-engine) } From de5ea8b64e2cc2c5cb55df5b4c452279eff4106f Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Thu, 23 May 2019 17:51:53 -0400 Subject: [PATCH 17/64] added missing engine filter --- R/translate.R | 1 + 1 file changed, 1 insertion(+) diff --git a/R/translate.R b/R/translate.R index c357a71b0..d3911e5b4 100644 --- a/R/translate.R +++ b/R/translate.R @@ -186,6 +186,7 @@ get_model_spec <- function(model, mode, engine) { get_args <- function(model, engine) { m_env <- get_model_env() rlang::env_get(m_env, paste0(model, "_args")) %>% + dplyr::filter(engine == !!engine) %>% dplyr::select(-engine) } From 409fdf1628f135d611813bd4a7e598f4468828c4 Mon Sep 17 00:00:00 2001 From: topepo Date: Sat, 25 May 2019 15:45:12 -0400 Subject: [PATCH 18/64] new structure working up to `fit` and `fit_xy` --- NEWS.md | 5 +++++ R/engines.R | 10 ---------- R/fit.R | 12 +++++------- R/translate.R | 51 +++++++++++---------------------------------------- man/fit.Rd | 2 -- 5 files changed, 21 insertions(+), 59 deletions(-) diff --git a/NEWS.md b/NEWS.md index df598bdee..fff23d1e9 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,10 @@ # parsnip 0.0.2.9000 +## Breaking Changes + + * The method that `parsnip` uses to store the model information has changed. Any custom models from previous versions will need to use the new method for registering models. + * The mode needs to be declared, prior to fitting and/or translation, for models that can be used in more than one mode. + ## New Features * `add_rowindex()` can create a column called `.row` to a data frame. 
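
Both breaking changes show up immediately in interactive use. A minimal sketch of the second one, assuming the check added to `translate.default()` in PATCH 16 above (translation now stops while a multi-mode model's mode is still "unknown"):

    library(parsnip)
    # Errors under the new structure: boost_tree() starts with an
    # "unknown" mode, so no single fit template can be resolved.
    # translate(set_engine(boost_tree(), "xgboost"))

    # Works: declare the mode before translating or fitting.
    spec <- set_engine(boost_tree(mode = "regression", trees = 10), "xgboost")
    translate(spec)
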
diff --git a/R/engines.R b/R/engines.R index 851e46755..327f69dc6 100644 --- a/R/engines.R +++ b/R/engines.R @@ -1,19 +1,9 @@ -get_model_info <- function (x, engine) { - cls <- specific_model(x) - nm <- paste(cls, engine, "data", sep = "_") - res <- try(get(nm), silent = TRUE) - if (inherits(res, "try-error")) - stop("Can't find model object ", nm) - res -} - specific_model <- function(x) { cls <- class(x) cls[cls != "model_spec"] } - possible_engines <- function(object, ...) { m_env <- get_model_env() engs <- rlang::env_get(m_env, specific_model(object)) diff --git a/R/fit.R b/R/fit.R index e5b745b8f..1cb32f300 100644 --- a/R/fit.R +++ b/R/fit.R @@ -52,8 +52,6 @@ #' #' lr_mod <- logistic_reg() #' -#' lr_mod <- logistic_reg() -#' #' using_formula <- #' lr_mod %>% #' set_engine("glm") %>% @@ -125,7 +123,7 @@ fit.model_spec <- ) # populate `method` with the details for this model type - object <- get_method(object, engine = object$engine) + object <- add_methods(object, engine = object$engine) check_installs(object) @@ -215,7 +213,7 @@ fit_xy.model_spec <- ) # populate `method` with the details for this model type - object <- get_method(object, engine = object$engine) + object <- add_methods(object, engine = object$engine) check_installs(object) @@ -239,7 +237,7 @@ fit_xy.model_spec <- ... ), - data.frame_data.frame =, matrix_data.frame = + data.frame_data.frame = , matrix_data.frame = xy_xy( object = object, env = eval_env, @@ -249,7 +247,7 @@ fit_xy.model_spec <- ), # heterogenous combinations - matrix_formula =, data.frame_formula = + matrix_formula = , data.frame_formula = xy_form( object = object, env = eval_env, @@ -367,7 +365,7 @@ check_xy_interface <- function(x, y, cl, model) { print.model_fit <- function(x, ...) { cat("parsnip model object\n\n") - if(inherits(x$fit, "try-error")) { + if (inherits(x$fit, "try-error")) { cat("Model fit failed with error:\n", x$fit, "\n") } else { print(x$fit, ...) diff --git a/R/translate.R b/R/translate.R index d3911e5b4..7fe3e13bf 100644 --- a/R/translate.R +++ b/R/translate.R @@ -41,7 +41,7 @@ #' #' @export -translate <- function (x, ...) +translate <- function(x, ...) UseMethod("translate") #' @importFrom utils getFromNamespace @@ -58,12 +58,10 @@ translate.default <- function(x, engine = x$engine, ...) { x$engine <- engine x <- check_engine(x) - # TODOS - # what to do with unknown mode? if (x$mode == "unknown") { stop("Model code depends on the mode; please specify one.", call. = FALSE) } - # set the classes. Is a constructor not being used? + if (is.null(x$method)) x$method <- get_model_spec(mod_name, x$mode, engine) @@ -78,11 +76,11 @@ translate.default <- function(x, engine = x$engine, ...) { x$eng_args <- check_eng_args(x$eng_args, x$method$fit, arg_key$original) # keep only modified args - modifed_args <- purrr::map_lgl(actual_args, null_value) + modifed_args <- !purrr::map_lgl(actual_args, null_value) actual_args <- actual_args[modifed_args] # look for defaults if not modified in other - if(length(x$method$fit$defaults) > 0) { + if (length(x$method$fit$defaults) > 0) { in_other <- names(x$method$fit$defaults) %in% names(x$eng_args) x$defaults <- x$method$fit$defaults[!in_other] } @@ -96,40 +94,6 @@ translate.default <- function(x, engine = x$engine, ...) { x } -get_method <- function(x, engine = x$engine, ...) { - check_empty_ellipse(...) 
- x$engine <- engine - x <- check_engine(x) - x$method <- get_model_info(x, x$engine) - x -} - - -get_module <- function(nm) { - arg_key <- try( - getFromNamespace( - paste0(nm, "_arg_key"), - ns = "parsnip" - ), - silent = TRUE - ) - if(inherits(arg_key, "try-error")) { - arg_key <- try( - get(paste0(nm, "_arg_key")), - silent = TRUE - ) - } - if(inherits(arg_key, "try-error")) { - stop( - "Cannot find the model code: `", - paste0(nm, "_arg_key"), - "`", call. = FALSE - ) - } - arg_key -} - - #' @export print.model_spec <- function(x, ...) { cat("Model Specification (", x$mode, ")\n\n", sep = "") @@ -201,3 +165,10 @@ unionize <- function(args, key) { names(args) <- merged$original args[!is.na(merged$original)] } + +add_methods <- function(x, engine) { + x$engine <- engine + x <- check_engine(x) + x$method <- get_model_spec(specific_model(x), x$mode, x$engine) + x +} diff --git a/man/fit.Rd b/man/fit.Rd index 42a295324..191e60db4 100644 --- a/man/fit.Rd +++ b/man/fit.Rd @@ -88,8 +88,6 @@ data("lending_club") lr_mod <- logistic_reg() -lr_mod <- logistic_reg() - using_formula <- lr_mod \%>\% set_engine("glm") \%>\% From fbd544a3de24db51df7cbd950768eb14b83d9a45 Mon Sep 17 00:00:00 2001 From: topepo Date: Sun, 26 May 2019 20:29:11 -0400 Subject: [PATCH 19/64] _almost_ all tests pass --- R/aaa.R | 4 ++ R/boost_tree.R | 6 +-- R/decision_tree_data.R | 2 +- R/linear_reg_data.R | 31 +++++++++++--- R/logistic_reg_data.R | 24 +++++------ R/mlp_data.R | 11 ++++- R/multinom_reg_data.R | 8 ++-- R/nearest_neighbor_data.R | 16 ++------ R/nullmodel.R | 1 + R/predict_class.R | 12 +++--- R/predict_classprob.R | 12 +++--- R/predict_interval.R | 28 ++++++------- R/predict_numeric.R | 12 +++--- R/predict_quantile.R | 14 +++---- R/predict_raw.R | 14 +++---- R/rand_forest_data.R | 29 ++----------- R/surv_reg.R | 2 +- R/translate.R | 4 +- tests/testthat/test_boost_tree.R | 2 +- tests/testthat/test_boost_tree_xgboost.R | 2 +- tests/testthat/test_nearest_neighbor.R | 39 +++++++++--------- tests/testthat/test_nearest_neighbor_kknn.R | 10 +++-- tests/testthat/test_rand_forest_ranger.R | 45 +++++++++++---------- tests/testthat/test_surv_reg_survreg.R | 4 +- tests/testthat/test_svm_poly.R | 24 +++++------ tests/testthat/test_svm_rbf.R | 20 ++++----- 26 files changed, 192 insertions(+), 184 deletions(-) diff --git a/R/aaa.R b/R/aaa.R index c8fd5a72e..bb7b4dfa1 100644 --- a/R/aaa.R +++ b/R/aaa.R @@ -18,3 +18,7 @@ convert_stan_interval <- function(x, level = 0.95, lower = TRUE) { res } +# ------------------------------------------------------------------------------ + +#' @importFrom utils globalVariables +utils::globalVariables(c("value", "engine", "lab", "original", "engine2")) diff --git a/R/boost_tree.R b/R/boost_tree.R index 9c7f84609..b49ac1863 100644 --- a/R/boost_tree.R +++ b/R/boost_tree.R @@ -387,15 +387,15 @@ xgb_by_tree <- function(tree, object, new_data, type, ...) 
{ pred <- xgb_pred(object$fit, newdata = new_data, ntreelimit = tree) # switch based on prediction type - if(object$spec$mode == "regression") { + if (object$spec$mode == "regression") { pred <- tibble(.pred = pred) nms <- names(pred) } else { if (type == "class") { - pred <- boost_tree_xgboost_data$class$post(pred, object) + pred <- object$spec$method$pred$class$post(pred, object) pred <- tibble(.pred = factor(pred, levels = object$lvl)) } else { - pred <- boost_tree_xgboost_data$classprob$post(pred, object) + pred <- object$spec$method$pred$prob$post(pred, object) pred <- as_tibble(pred) names(pred) <- paste0(".pred_", names(pred)) } diff --git a/R/decision_tree_data.R b/R/decision_tree_data.R index cd443312a..f319c56dc 100644 --- a/R/decision_tree_data.R +++ b/R/decision_tree_data.R @@ -141,7 +141,7 @@ set_model_arg( mod = "decision_tree", eng = "C5.0", val = "min_n", - original = "minsplit", + original = "minCases", func = list(pkg = "dials", fun = "min_n"), submodels = FALSE ) diff --git a/R/linear_reg_data.R b/R/linear_reg_data.R index 5e7e22180..334d6e91d 100644 --- a/R/linear_reg_data.R +++ b/R/linear_reg_data.R @@ -209,16 +209,16 @@ set_pred( .pred_lower = convert_stan_interval( results, - level = object$spec$method$confint$extras$level + level = object$spec$method$pred$conf_int$extras$level ), .pred_upper = convert_stan_interval( results, - level = object$spec$method$confint$extras$level, + level = object$spec$method$pred$conf_int$extras$level, lower = FALSE ), ) - if(object$spec$method$confint$extras$std_error) + if (object$spec$method$pred$conf_int$extras$std_error) res$.std_error <- apply(results, 2, sd, na.rm = TRUE) res }, @@ -246,16 +246,16 @@ set_pred( .pred_lower = convert_stan_interval( results, - level = object$spec$method$predint$extras$level + level = object$spec$method$pred$pred_int$extras$level ), .pred_upper = convert_stan_interval( results, - level = object$spec$method$predint$extras$level, + level = object$spec$method$pred$pred_int$extras$level, lower = FALSE ), ) - if(object$spec$method$predint$extras$std_error) + if (object$spec$method$pred$pred_int$extras$std_error) res$.std_error <- apply(results, 2, sd, na.rm = TRUE) res }, @@ -287,6 +287,25 @@ set_pred( set_model_engine("linear_reg", "regression", "spark") set_dependency("linear_reg", "spark", "sparklyr") +set_model_arg( + mod = "linear_reg", + eng = "spark", + val = "penalty", + original = "reg_param", + func = list(pkg = "dials", fun = "penalty"), + submodels = TRUE +) + +set_model_arg( + mod = "linear_reg", + eng = "spark", + val = "mixture", + original = "elastic_net_param", + func = list(pkg = "dials", fun = "mixture"), + submodels = FALSE +) + + set_fit( mod = "linear_reg", eng = "spark", diff --git a/R/logistic_reg_data.R b/R/logistic_reg_data.R index 972230e16..c8558d688 100644 --- a/R/logistic_reg_data.R +++ b/R/logistic_reg_data.R @@ -84,7 +84,7 @@ set_pred( value = list( pre = NULL, post = function(results, object) { - hf_lvl <- (1 - object$spec$method$confint$extras$level)/2 + hf_lvl <- (1 - object$spec$method$pred$conf_int$extras$level)/2 const <- qt(hf_lvl, df = object$fit$df.residual, lower.tail = FALSE) trans <- object$fit$family$linkinv @@ -101,7 +101,7 @@ set_pred( hi_nms <- paste0(".pred_upper_", object$lvl) colnames(res) <- c(lo_nms[1], hi_nms[1], lo_nms[2], hi_nms[2]) - if (object$spec$method$confint$extras$std_error) + if (object$spec$method$pred$conf_int$extras$std_error) res$.std_error <- results$se.fit res }, @@ -214,7 +214,7 @@ set_dependency("logistic_reg", "spark", "sparklyr") 
set_model_arg( mod = "logistic_reg", - eng = "glmnet", + eng = "spark", val = "penalty", original = "reg_param", func = list(pkg = "dials", fun = "penalty"), @@ -223,9 +223,9 @@ set_model_arg( set_model_arg( mod = "logistic_reg", - eng = "glmnet", - val = "elastic_net_param", - original = "alpha", + eng = "spark", + val = "mixture", + original = "elastic_net_param", func = list(pkg = "dials", fun = "mixture"), submodels = FALSE ) @@ -439,12 +439,12 @@ set_pred( lo = convert_stan_interval( results, - level = object$spec$method$confint$extras$level + level = object$spec$method$pred$conf_int$extras$level ), hi = convert_stan_interval( results, - level = object$spec$method$confint$extras$level, + level = object$spec$method$pred$conf_int$extras$level, lower = FALSE ), ) @@ -456,7 +456,7 @@ set_pred( hi_nms <- paste0(".pred_upper_", object$lvl) colnames(res) <- c(lo_nms[1], hi_nms[1], lo_nms[2], hi_nms[2]) - if (object$spec$method$confint$extras$std_error) + if (object$spec$method$pred$conf_int$extras$std_error) res$.std_error <- apply(results, 2, sd, na.rm = TRUE) res }, @@ -484,12 +484,12 @@ set_pred( lo = convert_stan_interval( results, - level = object$spec$method$predint$extras$level + level = object$spec$method$pred$pred_int$extras$level ), hi = convert_stan_interval( results, - level = object$spec$method$predint$extras$level, + level = object$spec$method$pred$pred_int$extras$level, lower = FALSE ), ) @@ -501,7 +501,7 @@ set_pred( hi_nms <- paste0(".pred_upper_", object$lvl) colnames(res) <- c(lo_nms[1], hi_nms[1], lo_nms[2], hi_nms[2]) - if (object$spec$method$predint$extras$std_error) + if (object$spec$method$pred$pred_int$extras$std_error) res$.std_error <- apply(results, 2, sd, na.rm = TRUE) res }, diff --git a/R/mlp_data.R b/R/mlp_data.R index 89907fe7c..a707f4cca 100644 --- a/R/mlp_data.R +++ b/R/mlp_data.R @@ -187,11 +187,18 @@ set_model_arg( mod = "mlp", eng = "nnet", val = "penalty", - original = "penalty", + original = "decay", func = list(pkg = "dials", fun = "weight_decay"), submodels = FALSE ) - +set_model_arg( + mod = "mlp", + eng = "nnet", + val = "epochs", + original = "maxit", + func = list(pkg = "dials", fun = "epochs"), + submodels = FALSE +) set_fit( mod = "mlp", eng = "nnet", diff --git a/R/multinom_reg_data.R b/R/multinom_reg_data.R index e356e36ab..418422c43 100644 --- a/R/multinom_reg_data.R +++ b/R/multinom_reg_data.R @@ -100,7 +100,7 @@ set_dependency("multinom_reg", "spark", "sparklyr") set_model_arg( mod = "multinom_reg", - eng = "glmnet", + eng = "spark", val = "penalty", original = "reg_param", func = list(pkg = "dials", fun = "penalty"), @@ -109,9 +109,9 @@ set_model_arg( set_model_arg( mod = "multinom_reg", - eng = "glmnet", - val = "elastic_net_param", - original = "alpha", + eng = "spark", + val = "mixture", + original = "elastic_net_param", func = list(pkg = "dials", fun = "mixture"), submodels = FALSE ) diff --git a/R/nearest_neighbor_data.R b/R/nearest_neighbor_data.R index 98f086277..3c8edb10a 100644 --- a/R/nearest_neighbor_data.R +++ b/R/nearest_neighbor_data.R @@ -10,14 +10,6 @@ set_model_engine("nearest_neighbor", "classification", "kknn") set_model_engine("nearest_neighbor", "regression", "kknn") set_dependency("nearest_neighbor", "kknn", "kknn") -set_model_arg( - mod = "nearest_neighbor", - eng = "kknn", - val = "num_terms", - original = "nprune", - func = list(pkg = "dials", fun = "num_terms"), - submodels = FALSE -) set_model_arg( mod = "nearest_neighbor", eng = "kknn", @@ -37,8 +29,8 @@ set_model_arg( set_model_arg( mod = "nearest_neighbor", 
eng = "kknn", - val = "distance", - original = "dist_power", + val = "dist_power", + original = "distance", func = list(pkg = "dials", fun = "distance"), submodels = FALSE ) @@ -49,7 +41,7 @@ set_fit( mode = "regression", value = list( interface = "formula", - protect = c("formula", "data", "kmax"), # kmax is not allowed + protect = c("formula", "data", "ks"), func = c(pkg = "kknn", fun = "train.kknn"), defaults = list() ) @@ -61,7 +53,7 @@ set_fit( mode = "classification", value = list( interface = "formula", - protect = c("formula", "data", "kmax"), # kmax is not allowed + protect = c("formula", "data", "ks"), func = c(pkg = "kknn", fun = "train.kknn"), defaults = list() ) diff --git a/R/nullmodel.R b/R/nullmodel.R index 59f86d051..f20c7751a 100644 --- a/R/nullmodel.R +++ b/R/nullmodel.R @@ -160,6 +160,7 @@ predict.nullmodel <- function (object, new_data = NULL, type = NULL, ...) { #' @export null_model <- function(mode = "classification") { + null_model_modes <- unique(get_model_env()$null_model$mode) # Check for correct mode if (!(mode %in% null_model_modes)) stop("`mode` should be one of: ", diff --git a/R/predict_class.R b/R/predict_class.R index 0e292a8b3..91c84fc02 100644 --- a/R/predict_class.R +++ b/R/predict_class.R @@ -13,7 +13,7 @@ predict_class.model_fit <- function(object, new_data, ...) { stop("`predict.model_fit()` is for predicting factor outcomes.", call. = FALSE) - if (!any(names(object$spec$method) == "class")) + if (!any(names(object$spec$method$pred) == "class")) stop("No class prediction module defined for this model.", call. = FALSE) if (inherits(object$fit, "try-error")) { @@ -24,17 +24,17 @@ predict_class.model_fit <- function(object, new_data, ...) { new_data <- prepare_data(object, new_data) # preprocess data - if (!is.null(object$spec$method$class$pre)) - new_data <- object$spec$method$class$pre(new_data, object) + if (!is.null(object$spec$method$pred$class$pre)) + new_data <- object$spec$method$pred$class$pre(new_data, object) # create prediction call - pred_call <- make_pred_call(object$spec$method$class) + pred_call <- make_pred_call(object$spec$method$pred$class) res <- eval_tidy(pred_call) # post-process the predictions - if (!is.null(object$spec$method$class$post)) { - res <- object$spec$method$class$post(res, object) + if (!is.null(object$spec$method$pred$class$post)) { + res <- object$spec$method$pred$class$post(res, object) } # coerce levels to those in `object` diff --git a/R/predict_classprob.R b/R/predict_classprob.R index d902e7735..1dbbe5328 100644 --- a/R/predict_classprob.R +++ b/R/predict_classprob.R @@ -10,7 +10,7 @@ predict_classprob.model_fit <- function(object, new_data, ...) { stop("`predict.model_fit()` is for predicting factor outcomes.", call. = FALSE) - if (!any(names(object$spec$method) == "classprob")) + if (!any(names(object$spec$method$pred) == "prob")) stop("No class probability module defined for this model.", call. = FALSE) if (inherits(object$fit, "try-error")) { @@ -21,17 +21,17 @@ predict_classprob.model_fit <- function(object, new_data, ...) 
{ new_data <- prepare_data(object, new_data) # preprocess data - if (!is.null(object$spec$method$classprob$pre)) - new_data <- object$spec$method$classprob$pre(new_data, object) + if (!is.null(object$spec$method$pred$prob$pre)) + new_data <- object$spec$method$pred$prob$pre(new_data, object) # create prediction call - pred_call <- make_pred_call(object$spec$method$classprob) + pred_call <- make_pred_call(object$spec$method$pred$prob) res <- eval_tidy(pred_call) # post-process the predictions - if (!is.null(object$spec$method$classprob$post)) { - res <- object$spec$method$classprob$post(res, object) + if (!is.null(object$spec$method$pred$prob$post)) { + res <- object$spec$method$pred$prob$post(res, object) } # check and sort names diff --git a/R/predict_interval.R b/R/predict_interval.R index f38838e00..b492ef866 100644 --- a/R/predict_interval.R +++ b/R/predict_interval.R @@ -10,7 +10,7 @@ # @export predict_confint.model_fit <- function(object, new_data, level = 0.95, std_error = FALSE, ...) { - if (is.null(object$spec$method$confint)) + if (is.null(object$spec$method$pred$conf_int)) stop("No confidence interval method defined for this ", "engine.", call. = FALSE) @@ -22,19 +22,19 @@ predict_confint.model_fit <- function(object, new_data, level = 0.95, std_error new_data <- prepare_data(object, new_data) # preprocess data - if (!is.null(object$spec$method$confint$pre)) - new_data <- object$spec$method$confint$pre(new_data, object) + if (!is.null(object$spec$method$pred$conf_int$pre)) + new_data <- object$spec$method$pred$conf_int$pre(new_data, object) # Pass some extra arguments to be used in post-processor - object$spec$method$confint$extras <- + object$spec$method$pred$conf_int$extras <- list(level = level, std_error = std_error) - pred_call <- make_pred_call(object$spec$method$confint) + pred_call <- make_pred_call(object$spec$method$pred$conf_int) res <- eval_tidy(pred_call) # post-process the predictions - if (!is.null(object$spec$method$confint$post)) { - res <- object$spec$method$confint$post(res, object) + if (!is.null(object$spec$method$pred$conf_int$post)) { + res <- object$spec$method$pred$conf_int$post(res, object) } attr(res, "level") <- level @@ -59,7 +59,7 @@ predict_confint <- function(object, ...) # @export predict_predint.model_fit <- function(object, new_data, level = 0.95, std_error = FALSE, ...) { - if (is.null(object$spec$method$predint)) + if (is.null(object$spec$method$pred$pred_int)) stop("No prediction interval method defined for this ", "engine.", call. 
= FALSE) @@ -71,20 +71,20 @@ predict_predint.model_fit <- function(object, new_data, level = 0.95, std_error new_data <- prepare_data(object, new_data) # preprocess data - if (!is.null(object$spec$method$predint$pre)) - new_data <- object$spec$method$predint$pre(new_data, object) + if (!is.null(object$spec$method$pred$pred_int$pre)) + new_data <- object$spec$method$pred$pred_int$pre(new_data, object) # create prediction call # Pass some extra arguments to be used in post-processor - object$spec$method$predint$extras <- + object$spec$method$pred$pred_int$extras <- list(level = level, std_error = std_error) - pred_call <- make_pred_call(object$spec$method$predint) + pred_call <- make_pred_call(object$spec$method$pred$pred_int) res <- eval_tidy(pred_call) # post-process the predictions - if (!is.null(object$spec$method$predint$post)) { - res <- object$spec$method$predint$post(res, object) + if (!is.null(object$spec$method$pred$pred_int$post)) { + res <- object$spec$method$pred$pred_int$post(res, object) } attr(res, "level") <- level diff --git a/R/predict_numeric.R b/R/predict_numeric.R index 3a509546b..970107049 100644 --- a/R/predict_numeric.R +++ b/R/predict_numeric.R @@ -11,7 +11,7 @@ predict_numeric.model_fit <- function(object, new_data, ...) { "Use `predict_class()` or `predict_classprob()` for ", "classification models.", call. = FALSE) - if (!any(names(object$spec$method) == "numeric")) + if (!any(names(object$spec$method$pred) == "numeric")) stop("No prediction module defined for this model.", call. = FALSE) if (inherits(object$fit, "try-error")) { @@ -22,17 +22,17 @@ predict_numeric.model_fit <- function(object, new_data, ...) { new_data <- prepare_data(object, new_data) # preprocess data - if (!is.null(object$spec$method$numeric$pre)) - new_data <- object$spec$method$numeric$pre(new_data, object) + if (!is.null(object$spec$method$pred$numeric$pre)) + new_data <- object$spec$method$pred$numeric$pre(new_data, object) # create prediction call - pred_call <- make_pred_call(object$spec$method$numeric) + pred_call <- make_pred_call(object$spec$method$pred$numeric) res <- eval_tidy(pred_call) # post-process the predictions - if (!is.null(object$spec$method$numeric$post)) { - res <- object$spec$method$numeric$post(res, object) + if (!is.null(object$spec$method$pred$numeric$post)) { + res <- object$spec$method$pred$numeric$post(res, object) } if (is.vector(res)) { diff --git a/R/predict_quantile.R b/R/predict_quantile.R index 698ddb4c8..19d22654f 100644 --- a/R/predict_quantile.R +++ b/R/predict_quantile.R @@ -9,7 +9,7 @@ predict_quantile.model_fit <- function (object, new_data, quantile = (1:9)/10, ...) { - if (is.null(object$spec$method$quantile)) + if (is.null(object$spec$method$pred$quantile)) stop("No quantile prediction method defined for this ", "engine.", call. 
= FALSE) @@ -21,18 +21,18 @@ predict_quantile.model_fit <- new_data <- prepare_data(object, new_data) # preprocess data - if (!is.null(object$spec$method$quantile$pre)) - new_data <- object$spec$method$quantile$pre(new_data, object) + if (!is.null(object$spec$method$pred$quantile$pre)) + new_data <- object$spec$method$pred$quantile$pre(new_data, object) # Pass some extra arguments to be used in post-processor - object$spec$method$quantile$args$p <- quantile - pred_call <- make_pred_call(object$spec$method$quantile) + object$spec$method$pred$quantile$args$p <- quantile + pred_call <- make_pred_call(object$spec$method$pred$quantile) res <- eval_tidy(pred_call) # post-process the predictions - if(!is.null(object$spec$method$quantile$post)) { - res <- object$spec$method$quantile$post(res, object) + if(!is.null(object$spec$method$pred$quantile$post)) { + res <- object$spec$method$pred$quantile$post(res, object) } res diff --git a/R/predict_raw.R b/R/predict_raw.R index 315c9dd0a..2d859f8a7 100644 --- a/R/predict_raw.R +++ b/R/predict_raw.R @@ -4,17 +4,17 @@ # @export predict_raw.model_fit # @export predict_raw.model_fit <- function(object, new_data, opts = list(), ...) { - protected_args <- names(object$spec$method$raw$args) + protected_args <- names(object$spec$method$pred$raw$args) dup_args <- names(opts) %in% protected_args if (any(dup_args)) { opts <- opts[[!dup_args]] } if (length(opts) > 0) { - object$spec$method$raw$args <- - c(object$spec$method$raw$args, opts) + object$spec$method$pred$raw$args <- + c(object$spec$method$pred$raw$args, opts) } - if (!any(names(object$spec$method) == "raw")) + if (!any(names(object$spec$method$pred) == "raw")) stop("No raw prediction module defined for this model.", call. = FALSE) if (inherits(object$fit, "try-error")) { @@ -25,11 +25,11 @@ predict_raw.model_fit <- function(object, new_data, opts = list(), ...) { new_data <- prepare_data(object, new_data) # preprocess data - if (!is.null(object$spec$method$raw$pre)) - new_data <- object$spec$method$raw$pre(new_data, object) + if (!is.null(object$spec$method$pred$raw$pre)) + new_data <- object$spec$method$pred$raw$pre(new_data, object) # create prediction call - pred_call <- make_pred_call(object$spec$method$raw) + pred_call <- make_pred_call(object$spec$method$pred$raw) res <- eval_tidy(pred_call) diff --git a/R/rand_forest_data.R b/R/rand_forest_data.R index 132f5047b..64e47b927 100644 --- a/R/rand_forest_data.R +++ b/R/rand_forest_data.R @@ -1,24 +1,3 @@ - -rand_forest_arg_key <- data.frame( - randomForest = c("mtry", "ntree", "nodesize"), - ranger = c("mtry", "num.trees", "min.node.size"), - spark = - c("feature_subset_strategy", "num_trees", "min_instances_per_node"), - stringsAsFactors = FALSE, - row.names = c("mtry", "trees", "min_n") -) - -rand_forest_modes <- c("classification", "regression", "unknown") - -rand_forest_engines <- data.frame( - ranger = c(TRUE, TRUE, FALSE), - randomForest = c(TRUE, TRUE, FALSE), - spark = c(TRUE, TRUE, FALSE), - row.names = c("classification", "regression", "unknown") -) - -# ------------------------------------------------------------------------------ - # wrappers for ranger ranger_class_pred <- function(results, object) { @@ -32,7 +11,7 @@ ranger_class_pred <- #' @importFrom stats qnorm ranger_num_confint <- function(object, new_data, ...) 
{ - hf_lvl <- (1 - object$spec$method$confint$extras$level)/2 + hf_lvl <- (1 - object$spec$method$pred$conf_int$extras$level)/2 const <- qnorm(hf_lvl, lower.tail = FALSE) res <- @@ -44,12 +23,12 @@ ranger_num_confint <- function(object, new_data, ...) { res$.pred_upper <- res$.pred + const * std_error res$.pred <- NULL - if (object$spec$method$confint$extras$std_error) + if (object$spec$method$pred$conf_int$extras$std_error) res$.std_error <- std_error res } ranger_class_confint <- function(object, new_data, ...) { - hf_lvl <- (1 - object$spec$method$confint$extras$level)/2 + hf_lvl <- (1 - object$spec$method$pred$conf_int$extras$level)/2 const <- qnorm(hf_lvl, lower.tail = FALSE) pred <- predict(object$fit, data = new_data, type = "response", ...)$predictions @@ -73,7 +52,7 @@ ranger_class_confint <- function(object, new_data, ...) { col_names <- paste0(c(".pred_lower_", ".pred_upper_"), lvl) res <- res[, col_names] - if (object$spec$method$confint$extras$std_error) + if (object$spec$method$pred$conf_int$extras$std_error) res <- bind_cols(res, std_error) res diff --git a/R/surv_reg.R b/R/surv_reg.R index e2f39e14b..db00ce75e 100644 --- a/R/surv_reg.R +++ b/R/surv_reg.R @@ -171,7 +171,7 @@ check_args.surv_reg <- function(object) { #' @importFrom stats setNames #' @importFrom dplyr mutate survreg_quant <- function(results, object) { - pctl <- object$spec$method$quantile$args$p + pctl <- object$spec$method$pred$quantile$args$p n <- nrow(results) p <- ncol(results) results <- diff --git a/R/translate.R b/R/translate.R index 7fe3e13bf..7de66a2b6 100644 --- a/R/translate.R +++ b/R/translate.R @@ -134,7 +134,7 @@ get_model_spec <- function(model, mode, engine) { rlang::env_get(m_env, paste0(model, "_fit")) %>% dplyr::filter(mode == !!mode & engine == !!engine) %>% dplyr::pull(value) %>% - purrr:::pluck(1) + purrr::pluck(1) pred_code <- rlang::env_get(m_env, paste0(model, "_predict")) %>% @@ -156,6 +156,8 @@ get_args <- function(model, engine) { # to replace harmonize unionize <- function(args, key) { + if (length(args) == 0) + return(args) parsn <- tibble(parsnip = names(args), order = seq_along(args)) merged <- dplyr::left_join(parsn, key, by = "parsnip") %>% diff --git a/tests/testthat/test_boost_tree.R b/tests/testthat/test_boost_tree.R index a159387fd..c37b05cb4 100644 --- a/tests/testthat/test_boost_tree.R +++ b/tests/testthat/test_boost_tree.R @@ -127,7 +127,7 @@ test_that('bad input', { bt <- boost_tree(min_n = -10) fit(bt, Species ~ ., iris) %>% set_engine("xgboost") }) - expect_message(translate(boost_tree(), engine = NULL)) + expect_message(translate(boost_tree(mode = "classification"), engine = NULL)) expect_error(translate(boost_tree(formula = y ~ x))) }) diff --git a/tests/testthat/test_boost_tree_xgboost.R b/tests/testthat/test_boost_tree_xgboost.R index 805e0bfca..ed65bcaad 100644 --- a/tests/testthat/test_boost_tree_xgboost.R +++ b/tests/testthat/test_boost_tree_xgboost.R @@ -8,7 +8,7 @@ context("boosted tree execution with xgboost") num_pred <- names(iris)[1:4] iris_xgboost <- - boost_tree(trees = 2) %>% + boost_tree(trees = 2, mode = "classification") %>% set_engine("xgboost") ctrl <- fit_control(verbosity = 1, catch = FALSE) diff --git a/tests/testthat/test_nearest_neighbor.R b/tests/testthat/test_nearest_neighbor.R index 797c55d0f..68fcb5a7c 100644 --- a/tests/testthat/test_nearest_neighbor.R +++ b/tests/testthat/test_nearest_neighbor.R @@ -10,53 +10,52 @@ source("helpers.R") # ------------------------------------------------------------------------------ test_that('primary 
arguments', { - basic <- nearest_neighbor() - basic_kknn <- translate(basic %>% set_engine( "kknn")) + basic <- nearest_neighbor(mode = "regression") + basic_kknn <- translate(basic %>% set_engine("kknn")) expect_equal( object = basic_kknn$method$fit$args, expected = list( formula = expr(missing_arg()), data = expr(missing_arg()), - kmax = expr(missing_arg()) + ks = expr(missing_arg()) ) ) - neighbors <- nearest_neighbor(neighbors = 5) - neighbors_kknn <- translate(neighbors %>% set_engine( "kknn")) + neighbors <- nearest_neighbor(mode = "classification", neighbors = 5) + neighbors_kknn <- translate(neighbors %>% set_engine("kknn")) expect_equal( object = neighbors_kknn$method$fit$args, expected = list( formula = expr(missing_arg()), data = expr(missing_arg()), - kmax = expr(missing_arg()), ks = new_empty_quosure(5) ) ) - weight_func <- nearest_neighbor(weight_func = "triangular") - weight_func_kknn <- translate(weight_func %>% set_engine( "kknn")) + weight_func <- nearest_neighbor(mode = "classification", weight_func = "triangular") + weight_func_kknn <- translate(weight_func %>% set_engine("kknn")) expect_equal( object = weight_func_kknn$method$fit$args, expected = list( formula = expr(missing_arg()), data = expr(missing_arg()), - kmax = expr(missing_arg()), + ks = expr(missing_arg()), kernel = new_empty_quosure("triangular") ) ) - dist_power <- nearest_neighbor(dist_power = 2) - dist_power_kknn <- translate(dist_power %>% set_engine( "kknn")) + dist_power <- nearest_neighbor(mode = "classification", dist_power = 2) + dist_power_kknn <- translate(dist_power %>% set_engine("kknn")) expect_equal( object = dist_power_kknn$method$fit$args, expected = list( formula = expr(missing_arg()), data = expr(missing_arg()), - kmax = expr(missing_arg()), + ks = expr(missing_arg()), distance = new_empty_quosure(2) ) ) @@ -65,14 +64,14 @@ test_that('primary arguments', { test_that('engine arguments', { - kknn_scale <- nearest_neighbor() %>% set_engine( "kknn", scale = FALSE) + kknn_scale <- nearest_neighbor(mode = "classification") %>% set_engine("kknn", scale = FALSE) expect_equal( object = translate(kknn_scale, "kknn")$method$fit$args, expected = list( formula = expr(missing_arg()), data = expr(missing_arg()), - kmax = expr(missing_arg()), + ks = expr(missing_arg()), scale = new_empty_quosure(FALSE) ) ) @@ -82,14 +81,14 @@ test_that('engine arguments', { test_that('updating', { - expr1 <- nearest_neighbor() %>% set_engine( "kknn", scale = FALSE) - expr1_exp <- nearest_neighbor(neighbors = 5) %>% set_engine( "kknn", scale = FALSE) + expr1 <- nearest_neighbor() %>% set_engine("kknn", scale = FALSE) + expr1_exp <- nearest_neighbor(neighbors = 5) %>% set_engine("kknn", scale = FALSE) - expr2 <- nearest_neighbor(neighbors = varying()) %>% set_engine( "kknn") - expr2_exp <- nearest_neighbor(neighbors = varying(), weight_func = "triangular") %>% set_engine( "kknn") + expr2 <- nearest_neighbor(neighbors = varying()) %>% set_engine("kknn") + expr2_exp <- nearest_neighbor(neighbors = varying(), weight_func = "triangular") %>% set_engine("kknn") - expr3 <- nearest_neighbor(neighbors = 2, weight_func = varying()) %>% set_engine( "kknn") - expr3_exp <- nearest_neighbor(neighbors = 3) %>% set_engine( "kknn") + expr3 <- nearest_neighbor(neighbors = 2, weight_func = varying()) %>% set_engine("kknn") + expr3_exp <- nearest_neighbor(neighbors = 3) %>% set_engine("kknn") expect_equal(update(expr1, neighbors = 5), expr1_exp) expect_equal(update(expr2, weight_func = "triangular"), expr2_exp) diff --git 
a/tests/testthat/test_nearest_neighbor_kknn.R b/tests/testthat/test_nearest_neighbor_kknn.R index fa153d12f..1418c50ae 100644 --- a/tests/testthat/test_nearest_neighbor_kknn.R +++ b/tests/testthat/test_nearest_neighbor_kknn.R @@ -8,7 +8,9 @@ context("nearest neighbor execution with kknn") num_pred <- c("Sepal.Width", "Petal.Width", "Petal.Length") iris_bad_form <- as.formula(Species ~ term) -iris_basic <- nearest_neighbor(neighbors = 8, weight_func = "triangular") %>% +iris_basic <- nearest_neighbor(mode = "classification", + neighbors = 8, + weight_func = "triangular") %>% set_engine("kknn") ctrl <- fit_control(verbosity = 1, catch = FALSE) @@ -31,7 +33,7 @@ test_that('kknn execution', { x = iris[, num_pred], y = iris$Sepal.Length ), - regexp = NA + regexp = "outcome should be a factor" ) # nominal @@ -67,7 +69,7 @@ test_that('kknn prediction', { iris_basic, control = ctrl, x = iris[, num_pred], - y = iris$Sepal.Length + y = iris$Species ) uni_pred <- predict( @@ -75,7 +77,7 @@ test_that('kknn prediction', { newdata = iris[1:5, num_pred] ) - expect_equal(uni_pred, predict(res_xy, iris[1:5, num_pred])$.pred) + expect_equal(tibble(.pred_class = uni_pred), predict(res_xy, iris[1:5, num_pred])) # nominal res_xy_nom <- fit_xy( diff --git a/tests/testthat/test_rand_forest_ranger.R b/tests/testthat/test_rand_forest_ranger.R index 5e8300400..3523613e7 100644 --- a/tests/testthat/test_rand_forest_ranger.R +++ b/tests/testthat/test_rand_forest_ranger.R @@ -13,11 +13,11 @@ data("lending_club") lending_club <- head(lending_club, 200) num_pred <- c("funded_amnt", "annual_inc", "num_il_tl") -lc_basic <- rand_forest() %>% set_engine("ranger") -lc_ranger <- rand_forest() %>% set_engine("ranger", seed = 144) +lc_basic <- rand_forest(mode = "classification") %>% set_engine("ranger") +lc_ranger <- rand_forest(mode = "classification") %>% set_engine("ranger", seed = 144) -bad_ranger_cls <- rand_forest() %>% set_engine("ranger", replace = "bad") -bad_rf_cls <- rand_forest() %>% set_engine("ranger", sampsize = -10) +bad_ranger_cls <- rand_forest(mode = "classification") %>% set_engine("ranger", replace = "bad") +bad_rf_cls <- rand_forest(mode = "classification") %>% set_engine("ranger", sampsize = -10) ctrl <- fit_control(verbosity = 1, catch = FALSE) caught_ctrl <- fit_control(verbosity = 1, catch = TRUE) @@ -34,7 +34,6 @@ test_that('ranger classification execution', { lc_ranger, Class ~ funded_amnt + term, data = lending_club, - control = ctrl ), regexp = NA @@ -63,7 +62,7 @@ test_that('ranger classification execution', { ranger_form_catch <- fit( bad_ranger_cls, - funded_amnt ~ term, + Class ~ term, data = lending_club, control = caught_ctrl @@ -75,7 +74,7 @@ test_that('ranger classification execution', { control = caught_ctrl, x = lending_club[, num_pred], - y = lending_club$total_bal_il + y = lending_club$Class ) expect_true(inherits(ranger_xy_catch$fit, "try-error")) @@ -159,7 +158,7 @@ test_that('ranger classification probabilities', { ) no_prob_model <- fit_xy( - rand_forest() %>% set_engine("ranger", probability = FALSE), + rand_forest(mode = "classification") %>% set_engine("ranger", probability = FALSE), x = lending_club[, num_pred], y = lending_club$Class, control = ctrl @@ -174,10 +173,10 @@ test_that('ranger classification probabilities', { num_pred <- names(mtcars)[3:6] -car_basic <- rand_forest() %>% set_engine("ranger") +car_basic <- rand_forest(mode = "regression") %>% set_engine("ranger") -bad_ranger_reg <- rand_forest() %>% set_engine("ranger", replace = "bad") -bad_rf_reg <- rand_forest() 
%>% set_engine("ranger", sampsize = -10) +bad_ranger_reg <- rand_forest(mode = "regression") %>% set_engine("ranger", replace = "bad") +bad_rf_reg <- rand_forest(mode = "regression") %>% set_engine("ranger", sampsize = -10) ctrl <- list(verbosity = 1, catch = FALSE) caught_ctrl <- list(verbosity = 1, catch = TRUE) @@ -251,7 +250,7 @@ test_that('ranger regression intervals', { skip_if_not_installed("ranger") xy_fit <- fit_xy( - rand_forest() %>% set_engine("ranger", keep.inbag = TRUE), + rand_forest(mode = "regression") %>% set_engine("ranger", keep.inbag = TRUE), x = mtcars[, -1], y = mtcars$mpg, control = ctrl @@ -282,7 +281,8 @@ test_that('additional descriptor tests', { skip_if_not_installed("ranger") descr_xy <- fit_xy( - rand_forest(mtry = floor(sqrt(.cols())) + 1) %>% set_engine("ranger"), + rand_forest(mode = "regression", mtry = floor(sqrt(.cols())) + 1) %>% + set_engine("ranger"), x = mtcars[, -1], y = mtcars$mpg, control = ctrl @@ -290,14 +290,16 @@ test_that('additional descriptor tests', { expect_equal(descr_xy$fit$mtry, 4) descr_f <- fit( - rand_forest(mtry = floor(sqrt(.cols())) + 1) %>% set_engine("ranger"), + rand_forest(mode = "regression", mtry = floor(sqrt(.cols())) + 1) %>% + set_engine("ranger"), mpg ~ ., data = mtcars, control = ctrl ) expect_equal(descr_f$fit$mtry, 4) descr_xy <- fit_xy( - rand_forest(mtry = floor(sqrt(.cols())) + 1) %>% set_engine("ranger"), + rand_forest(mode = "regression", mtry = floor(sqrt(.cols())) + 1) %>% + set_engine("ranger"), x = mtcars[, -1], y = mtcars$mpg, control = ctrl @@ -305,7 +307,8 @@ test_that('additional descriptor tests', { expect_equal(descr_xy$fit$mtry, 4) descr_f <- fit( - rand_forest(mtry = floor(sqrt(.cols())) + 1) %>% set_engine("ranger"), + rand_forest(mode = "regression", mtry = floor(sqrt(.cols())) + 1) %>% + set_engine("ranger"), mpg ~ ., data = mtcars, control = ctrl ) @@ -316,7 +319,7 @@ test_that('additional descriptor tests', { exp_wts <- quo(c(min(.lvls()), 20, 10)) descr_other_xy <- fit_xy( - rand_forest(mtry = 2) %>% + rand_forest(mode = "classification", mtry = 2) %>% set_engine("ranger", class.weights = c(min(.lvls()), 20, 10)), x = iris[, 1:4], y = iris$Species, @@ -326,7 +329,7 @@ test_that('additional descriptor tests', { expect_equal(descr_other_xy$fit$call$class.weights, exp_wts) descr_other_f <- fit( - rand_forest(mtry = 2) %>% + rand_forest(mode = "classification", mtry = 2) %>% set_engine("ranger", class.weights = c(min(.lvls()), 20, 10)), Species ~ ., data = iris, control = ctrl @@ -335,7 +338,7 @@ test_that('additional descriptor tests', { expect_equal(descr_other_f$fit$call$class.weights, exp_wts) descr_other_xy <- fit_xy( - rand_forest(mtry = 2) %>% + rand_forest(mode = "classification", mtry = 2) %>% set_engine("ranger", class.weights = c(min(.lvls()), 20, 10)), x = iris[, 1:4], y = iris$Species, @@ -345,7 +348,7 @@ test_that('additional descriptor tests', { expect_equal(descr_other_xy$fit$call$class.weights, exp_wts) descr_other_f <- fit( - rand_forest(mtry = 2) %>% + rand_forest(mode = "classification", mtry = 2) %>% set_engine("ranger", class.weights = c(min(.lvls()), 20, 10)), Species ~ ., data = iris, control = ctrl @@ -410,7 +413,7 @@ test_that('ranger classification intervals', { skip_if_not_installed("ranger") lc_fit <- fit( - rand_forest() %>% + rand_forest(mode = "classification") %>% set_engine("ranger", keep.inbag = TRUE, probability = TRUE), Class ~ funded_amnt + int_rate, data = lending_club, diff --git a/tests/testthat/test_surv_reg_survreg.R 
b/tests/testthat/test_surv_reg_survreg.R index 5a8e0a66e..6b66b4134 100644 --- a/tests/testthat/test_surv_reg_survreg.R +++ b/tests/testthat/test_surv_reg_survreg.R @@ -8,8 +8,8 @@ library(tibble) basic_form <- Surv(time, status) ~ group complete_form <- Surv(time) ~ group -surv_basic <- surv_reg() %>% set_engine("survreg") -surv_lnorm <- surv_reg(dist = "lognormal") %>% set_engine("survreg") +surv_basic <- surv_reg() %>% set_engine("survival") +surv_lnorm <- surv_reg(dist = "lognormal") %>% set_engine("survival") ctrl <- fit_control(verbosity = 1, catch = FALSE) caught_ctrl <- fit_control(verbosity = 1, catch = TRUE) diff --git a/tests/testthat/test_svm_poly.R b/tests/testthat/test_svm_poly.R index 1835a3f5a..bb5b2a65c 100644 --- a/tests/testthat/test_svm_poly.R +++ b/tests/testthat/test_svm_poly.R @@ -11,7 +11,7 @@ source("helpers.R") # ------------------------------------------------------------------------------ test_that('primary arguments', { - basic <- svm_poly() + basic <- svm_poly(mode = "regression") basic_kernlab <- translate(basic %>% set_engine("kernlab")) expect_equal( @@ -23,7 +23,7 @@ test_that('primary arguments', { ) ) - degree <- svm_poly(degree = 2) + degree <- svm_poly(mode = "regression", degree = 2) degree_kernlab <- translate(degree %>% set_engine("kernlab")) degree_obj <- expr(list()) degree_obj$degree <- new_empty_quosure(2) @@ -38,7 +38,7 @@ test_that('primary arguments', { ) ) - degree_scale <- svm_poly(degree = 2, scale_factor = 1.2) + degree_scale <- svm_poly(mode = "regression", degree = 2, scale_factor = 1.2) degree_scale_kernlab <- translate(degree_scale %>% set_engine("kernlab")) degree_scale_obj <- expr(list()) degree_scale_obj$degree <- new_empty_quosure(2) @@ -58,7 +58,7 @@ test_that('primary arguments', { test_that('engine arguments', { - kernlab_cv <- svm_poly() %>% set_engine("kernlab", cross = 10) + kernlab_cv <- svm_poly(mode = "regression") %>% set_engine("kernlab", cross = 10) expect_equal( object = translate(kernlab_cv, "kernlab")$method$fit$args, @@ -75,14 +75,14 @@ test_that('engine arguments', { test_that('updating', { - expr1 <- svm_poly() %>% set_engine("kernlab", cross = 10) - expr1_exp <- svm_poly(degree = 1) %>% set_engine("kernlab", cross = 10) + expr1 <- svm_poly(mode = "regression") %>% set_engine("kernlab", cross = 10) + expr1_exp <- svm_poly(mode = "regression", degree = 1) %>% set_engine("kernlab", cross = 10) - expr2 <- svm_poly(degree = varying()) %>% set_engine("kernlab") - expr2_exp <- svm_poly(degree = varying(), scale_factor = 1) %>% set_engine("kernlab") + expr2 <- svm_poly(mode = "regression", degree = varying()) %>% set_engine("kernlab") + expr2_exp <- svm_poly(mode = "regression", degree = varying(), scale_factor = 1) %>% set_engine("kernlab") - expr3 <- svm_poly(degree = 2, scale_factor = varying()) %>% set_engine("kernlab") - expr3_exp <- svm_poly(degree = 3) %>% set_engine("kernlab") + expr3 <- svm_poly(mode = "regression", degree = 2, scale_factor = varying()) %>% set_engine("kernlab") + expr3_exp <- svm_poly(mode = "regression", degree = 3) %>% set_engine("kernlab") expect_equal(update(expr1, degree = 1), expr1_exp) expect_equal(update(expr2, scale_factor = 1), expr2_exp) @@ -97,12 +97,12 @@ test_that('bad input', { # ------------------------------------------------------------------------------ reg_mod <- - svm_poly(degree = 1, cost = 1/4) %>% + svm_poly(mode = "regression", degree = 1, cost = 1/4) %>% set_engine("kernlab") %>% set_mode("regression") cls_mod <- - svm_poly(degree = 2, cost = 1/8) %>% + svm_poly(mode 
= "classification", degree = 2, cost = 1/8) %>% set_engine("kernlab") %>% set_mode("classification") diff --git a/tests/testthat/test_svm_rbf.R b/tests/testthat/test_svm_rbf.R index ba78f284b..41523e0a7 100644 --- a/tests/testthat/test_svm_rbf.R +++ b/tests/testthat/test_svm_rbf.R @@ -10,7 +10,7 @@ source("helpers.R") # ------------------------------------------------------------------------------ test_that('primary arguments', { - basic <- svm_rbf() + basic <- svm_rbf(mode = "regression") basic_kernlab <- translate(basic %>% set_engine("kernlab")) expect_equal( @@ -22,7 +22,7 @@ test_that('primary arguments', { ) ) - rbf_sigma <- svm_rbf(rbf_sigma = .2) + rbf_sigma <- svm_rbf(mode = "regression", rbf_sigma = .2) rbf_sigma_kernlab <- translate(rbf_sigma %>% set_engine("kernlab")) rbf_sigma_obj <- expr(list()) rbf_sigma_obj$sigma <- new_empty_quosure(.2) @@ -41,7 +41,7 @@ test_that('primary arguments', { test_that('engine arguments', { - kernlab_cv <- svm_rbf() %>% set_engine("kernlab", cross = 10) + kernlab_cv <- svm_rbf(mode = "regression") %>% set_engine("kernlab", cross = 10) expect_equal( object = translate(kernlab_cv, "kernlab")$method$fit$args, @@ -58,11 +58,11 @@ test_that('engine arguments', { test_that('updating', { - expr1 <- svm_rbf() %>% set_engine("kernlab", cross = 10) - expr1_exp <- svm_rbf(rbf_sigma = .1) %>% set_engine("kernlab", cross = 10) + expr1 <- svm_rbf(mode = "regression") %>% set_engine("kernlab", cross = 10) + expr1_exp <- svm_rbf(mode = "regression", rbf_sigma = .1) %>% set_engine("kernlab", cross = 10) - expr3 <- svm_rbf(rbf_sigma = .2) %>% set_engine("kernlab") - expr3_exp <- svm_rbf(rbf_sigma = .3) %>% set_engine("kernlab") + expr3 <- svm_rbf(mode = "regression", rbf_sigma = .2) %>% set_engine("kernlab") + expr3_exp <- svm_rbf(mode = "regression", rbf_sigma = .3) %>% set_engine("kernlab") expect_equal(update(expr1, rbf_sigma = .1), expr1_exp) expect_equal(update(expr3, rbf_sigma = .3, fresh = TRUE), expr3_exp) @@ -70,18 +70,18 @@ test_that('updating', { test_that('bad input', { expect_error(svm_rbf(mode = "reallyunknown")) - expect_error(translate(svm_rbf() %>% set_engine( NULL))) + expect_error(translate(svm_rbf(mode = "regression") %>% set_engine( NULL))) }) # ------------------------------------------------------------------------------ reg_mod <- - svm_rbf(rbf_sigma = .1, cost = 1/4) %>% + svm_rbf(mode = "regression", rbf_sigma = .1, cost = 1/4) %>% set_engine("kernlab") %>% set_mode("regression") cls_mod <- - svm_rbf(rbf_sigma = .1, cost = 1/8) %>% + svm_rbf(mode = "classification", rbf_sigma = .1, cost = 1/8) %>% set_engine("kernlab") %>% set_mode("classification") From edcddce9adde9326614b39616fbb312943cc5760 Mon Sep 17 00:00:00 2001 From: topepo Date: Sun, 26 May 2019 20:49:53 -0400 Subject: [PATCH 20/64] removed or consolidated functions --- R/aaa.R | 5 ++++- R/aaa_spark_helpers.R | 3 --- R/arguments.R | 43 ------------------------------------------- R/boost_tree.R | 5 ----- R/fit_helpers.R | 6 ------ R/logistic_reg.R | 5 ----- R/misc.R | 8 -------- R/multinom_reg.R | 4 ---- R/surv_reg.R | 5 ----- R/translate.R | 4 ++-- 10 files changed, 6 insertions(+), 82 deletions(-) diff --git a/R/aaa.R b/R/aaa.R index bb7b4dfa1..b91d6f4c5 100644 --- a/R/aaa.R +++ b/R/aaa.R @@ -21,4 +21,7 @@ convert_stan_interval <- function(x, level = 0.95, lower = TRUE) { # ------------------------------------------------------------------------------ #' @importFrom utils globalVariables -utils::globalVariables(c("value", "engine", "lab", "original", "engine2")) 
+utils::globalVariables( + c('.', '.label', '.pred', '.row', 'data', 'engine', 'engine2', 'group', + 'lab', 'original', 'predicted_label', 'prediction', 'value') + ) diff --git a/R/aaa_spark_helpers.R b/R/aaa_spark_helpers.R index 4d233a760..3eba67416 100644 --- a/R/aaa_spark_helpers.R +++ b/R/aaa_spark_helpers.R @@ -20,6 +20,3 @@ format_spark_num <- function(results, object) { results <- dplyr::rename(results, pred = prediction) results } - -#' @importFrom utils globalVariables -utils::globalVariables(c(".", "predicted_label", "prediction")) diff --git a/R/arguments.R b/R/arguments.R index 1ece4725e..39246d762 100644 --- a/R/arguments.R +++ b/R/arguments.R @@ -7,49 +7,6 @@ null_value <- function(x) { res } -deharmonize <- function(args, key, engine) { - nms <- names(args) - orig_names <- key[nms, engine] - names(args) <- orig_names - args[!is.na(orig_names)] -} - -parse_engine_options <- function(x) { - res <- ll() - if (length(x) >= 2) { # in case of NULL - - arg_names <- names(x[[2]]) - arg_names <- arg_names[arg_names != ""] - - if (length(arg_names) > 0) { - # in case of list() - res <- ll() - for (i in arg_names) { - res[[i]] <- x[[2]][[i]] - } # over arg_names - } # length == 0 - } - res -} - -prune_arg_list <- function(x, whitelist = NULL, modified = character(0)) { - nms <- names(x) - if (length(whitelist) > 0) - nms <- nms[!(nms %in% whitelist)] - for (i in nms) { - if ( - is.null(x[[i]]) | - is_null(x[[i]]) | - !(i %in% modified) | - is_missing(x[[i]]) - ) - x[[i]] <- NULL - } - if(any(names(x) == "...")) - x["..."] <- NULL - x -} - check_eng_args <- function(args, obj, core_args) { # Make sure that we are not trying to modify an argument that # is explicitly protected in the method metadata or arg_key diff --git a/R/boost_tree.R b/R/boost_tree.R index b49ac1863..b1cdca096 100644 --- a/R/boost_tree.R +++ b/R/boost_tree.R @@ -503,8 +503,3 @@ C50_by_tree <- function(tree, object, new_data, type, ...) { pred[, c(".row", "trees", nms)] } - -# ------------------------------------------------------------------------------ - -#' @importFrom utils globalVariables -utils::globalVariables(c(".row")) diff --git a/R/fit_helpers.R b/R/fit_helpers.R index 74f614ede..940daebe7 100644 --- a/R/fit_helpers.R +++ b/R/fit_helpers.R @@ -181,9 +181,3 @@ xy_form <- function(object, env, control, ...) { res } -# ------------------------------------------------------------------------------ -## - - -#' @importFrom utils globalVariables -utils::globalVariables("data") diff --git a/R/logistic_reg.R b/R/logistic_reg.R index 62e6f22c1..8365c729a 100644 --- a/R/logistic_reg.R +++ b/R/logistic_reg.R @@ -355,8 +355,3 @@ predict_raw._lognet <- function (object, new_data, opts = list(), ...) { predict_raw.model_fit(object, new_data = new_data, opts = opts, ...) } - -# ------------------------------------------------------------------------------ - -#' @importFrom utils globalVariables -utils::globalVariables(c("group", ".pred")) diff --git a/R/misc.R b/R/misc.R index e04842c02..2e2830429 100644 --- a/R/misc.R +++ b/R/misc.R @@ -132,14 +132,6 @@ make_call <- function(fun, ns, args, ...) { out } -resolve_args <- function(args, ...) { - for (i in seq(along = args)) { - if (!is_missing_arg(args[[i]])) - args[[i]] <- eval_tidy(args[[i]], ...) 
-  }
-  args
-}
-
 levels_from_formula <- function(f, dat) {
   if (inherits(dat, "tbl_spark"))
     res <- NULL
diff --git a/R/multinom_reg.R b/R/multinom_reg.R
index 3ff94c1c2..2c978e9ff 100644
--- a/R/multinom_reg.R
+++ b/R/multinom_reg.R
@@ -322,7 +322,3 @@ check_glmnet_lambda <- function(dat, object) {
   dat
 }
 
-# ------------------------------------------------------------------------------
-
-#' @importFrom utils globalVariables
-utils::globalVariables(c("group", ".pred"))
diff --git a/R/surv_reg.R b/R/surv_reg.R
index db00ce75e..fa0187a10 100644
--- a/R/surv_reg.R
+++ b/R/surv_reg.R
@@ -208,8 +208,3 @@ flexsurv_quant <- function(results, object) {
   results <- map(results, setNames,
                  c(".quantile", ".pred", ".pred_lower", ".pred_upper"))
 }
-# ------------------------------------------------------------------------------
-
-#' @importFrom utils globalVariables
-utils::globalVariables(".label")
-
diff --git a/R/translate.R b/R/translate.R
index 7de66a2b6..3522544cc 100644
--- a/R/translate.R
+++ b/R/translate.R
@@ -68,7 +68,7 @@ translate.default <- function(x, engine = x$engine, ...) {
   arg_key <- get_args(mod_name, engine)
 
   # deharmonize primary arguments
-  actual_args <- unionize(x$args, arg_key)
+  actual_args <- deharmonize(x$args, arg_key)
 
   # check secondary arguments to see if they are in the final
   # expression unless there are dots, warn if protected args are
@@ -155,7 +155,7 @@ get_args <- function(model, engine) {
 }
 
 # to replace harmonize
-unionize <- function(args, key) {
+deharmonize <- function(args, key) {
   if (length(args) == 0)
     return(args)
   parsn <- tibble(parsnip = names(args), order = seq_along(args))

From d6e80c7c11bec934327e0d9faf776fa53789924c Mon Sep 17 00:00:00 2001
From: topepo
Date: Sun, 26 May 2019 21:25:23 -0400
Subject: [PATCH 21/64] some documentation

---
 R/aaa_models.R       | 88 ++++++++++++++++++++++++++++++++++++--------
 man/check_mod_val.Rd | 54 +++++++++++++++++++++++++++
 man/get_model_env.Rd | 45 ++++++++++++++++++++++
 3 files changed, 172 insertions(+), 15 deletions(-)
 create mode 100644 man/check_mod_val.Rd
 create mode 100644 man/get_model_env.Rd

diff --git a/R/aaa_models.R b/R/aaa_models.R
index d8e340b70..1e5853049 100644
--- a/R/aaa_models.R
+++ b/R/aaa_models.R
@@ -1,4 +1,4 @@
-# Initialize model environment
+# Initialize model environments
 
 # ------------------------------------------------------------------------------
 
@@ -35,6 +35,9 @@ pred_types <-
 
 # ------------------------------------------------------------------------------
 
+#' Tools to Register Models
+#'
+#' @keywords internal
 #' @export
 get_model_env <- function() {
   current <- utils::getFromNamespace("parsnip", ns = "parsnip")
@@ -42,13 +45,31 @@ get_model_env <- function() {
   current
 }
 
+
+
+#' Tools to Check Model Elements
+#'
+#' These functions are similar to constructors and can be used to validate
+#' that there are no conflicts with the underlying model structures used by the
+#' package.
+#'
+#' @param mod A single character string for the model type (e.g.
+#' `"rand_forest"`, etc).
+#' @param new A single logical to check to see if the model that you are checking
+#' has not already been registered.
+#' @param existence A single logical to check to see if the model has already
+#' been registered.
+#' @param mode A single character string for the model mode (e.g. "regression").
+#' @param eng A single character string for the model engine.
+#' @param arg A single character string for the model argument name.
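The roxygen block above covers only part of the usage list; hedged, editor-supplied entries for the remaining arguments (wording inferred from the validators themselves, not taken from the commit) might read:

+#' @param x A single object to check: a logical for `check_submodels_val()`,
+#' a fit module for `check_fit_info()`, a prediction module for
+#' `check_pred_info()`, or a package name for `check_pkg_val()`.
+#' @param func A named list with names `pkg` and `fun`, both single
+#' character strings, as enforced by `check_func_val()`.
+#' @param type A single character string for the type of prediction, matched
+#' against the `pred_types` vector (e.g. `"numeric"`).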
+#' @keywords internal #' @export -check_mod_val <- function(mod, new = FALSE, existance = FALSE) { +check_mod_val <- function(mod, new = FALSE, existence = FALSE) { if (is_missing(mod) || length(mod) != 1) stop("Please supply a character string for a model name (e.g. `'linear_reg'`)", call. = FALSE) - if (new | existance) { + if (new | existence) { current <- get_model_env() } @@ -58,7 +79,7 @@ check_mod_val <- function(mod, new = FALSE, existance = FALSE) { } } - if (existance) { + if (existence) { current <- get_model_env() if (!any(current$models == mod)) { stop("Model `", mod, "` has not been registered.", call. = FALSE) @@ -68,6 +89,8 @@ check_mod_val <- function(mod, new = FALSE, existance = FALSE) { invisible(NULL) } +#' @rdname check_mod_val +#' @keywords internal #' @export check_mode_val <- function(mode) { if (is_missing(mode) || length(mode) != 1) @@ -76,6 +99,8 @@ check_mode_val <- function(mode) { invisible(NULL) } +#' @rdname check_mod_val +#' @keywords internal #' @export check_engine_val <- function(eng) { if (is_missing(eng) || length(eng) != 1) @@ -84,6 +109,8 @@ check_engine_val <- function(eng) { invisible(NULL) } +#' @rdname check_mod_val +#' @keywords internal #' @export check_arg_val <- function(arg) { if (is_missing(arg) || length(arg) != 1) @@ -92,6 +119,8 @@ check_arg_val <- function(arg) { invisible(NULL) } +#' @rdname check_mod_val +#' @keywords internal #' @export check_submodels_val <- function(x) { if (!is.logical(x) || length(x) != 1) { @@ -100,6 +129,8 @@ check_submodels_val <- function(x) { invisible(NULL) } +#' @rdname check_mod_val +#' @keywords internal #' @export check_func_val <- function(func) { msg <- @@ -135,6 +166,8 @@ check_func_val <- function(func) { invisible(NULL) } +#' @rdname check_mod_val +#' @keywords internal #' @export check_fit_info <- function(x) { if (is.null(x)) { @@ -167,6 +200,8 @@ check_fit_info <- function(x) { invisible(NULL) } +#' @rdname check_mod_val +#' @keywords internal #' @export check_pred_info <- function(x, type) { if (all(type != pred_types)) { @@ -200,7 +235,8 @@ check_pred_info <- function(x, type) { invisible(NULL) } - +#' @rdname check_mod_val +#' @keywords internal #' @export check_pkg_val <- function(x) { if (is_missing(x) || length(x) != 1 || !is.character(x)) @@ -211,6 +247,8 @@ check_pkg_val <- function(x) { # ------------------------------------------------------------------------------ +#' @rdname get_model_env +#' @keywords internal #' @export set_new_model <- function(mod) { check_mod_val(mod, new = TRUE) @@ -247,9 +285,11 @@ set_new_model <- function(mod) { # ------------------------------------------------------------------------------ +#' @rdname get_model_env +#' @keywords internal #' @export set_model_mode <- function(mod, mode) { - check_mod_val(mod, existance = TRUE) + check_mod_val(mod, existence = TRUE) check_mode_val(mode) current <- get_model_env() @@ -265,9 +305,11 @@ set_model_mode <- function(mod, mode) { # ------------------------------------------------------------------------------ +#' @rdname get_model_env +#' @keywords internal #' @export set_model_engine <- function(mod, mode, eng) { - check_mod_val(mod, existance = TRUE) + check_mod_val(mod, existence = TRUE) check_mode_val(mode) check_mode_val(eng) @@ -288,9 +330,11 @@ set_model_engine <- function(mod, mode, eng) { # ------------------------------------------------------------------------------ +#' @rdname get_model_env +#' @keywords internal #' @export set_model_arg <- function(mod, eng, val, original, func, submodels) { - 
check_mod_val(mod, existance = TRUE) + check_mod_val(mod, existence = TRUE) check_arg_val(val) check_arg_val(original) check_func_val(func) @@ -325,9 +369,11 @@ set_model_arg <- function(mod, eng, val, original, func, submodels) { # ------------------------------------------------------------------------------ +#' @rdname get_model_env +#' @keywords internal #' @export set_dependency <- function(mod, eng, pkg) { - check_mod_val(mod, existance = TRUE) + check_mod_val(mod, existence = TRUE) check_pkg_val(pkg) current <- get_model_env() @@ -366,9 +412,11 @@ set_dependency <- function(mod, eng, pkg) { invisible(NULL) } +#' @rdname get_model_env +#' @keywords internal #' @export get_dependency <- function(mod) { - check_mod_val(mod, existance = TRUE) + check_mod_val(mod, existence = TRUE) pkg_name <- paste0(mod, "_pkgs") if (!any(pkg_name != rlang::env_names(get_model_env()))) { stop("`", mod, "` does not have a dependency list in parsnip.", call. = FALSE) @@ -379,9 +427,11 @@ get_dependency <- function(mod) { # ------------------------------------------------------------------------------ +#' @rdname get_model_env +#' @keywords internal #' @export set_fit <- function(mod, mode, eng, value) { - check_mod_val(mod, existance = TRUE) + check_mod_val(mod, existence = TRUE) check_mode_val(mode) check_engine_val(eng) check_fit_info(value) @@ -428,9 +478,11 @@ set_fit <- function(mod, mode, eng, value) { invisible(NULL) } +#' @rdname get_model_env +#' @keywords internal #' @export get_fit <- function(mod) { - check_mod_val(mod, existance = TRUE) + check_mod_val(mod, existence = TRUE) fit_name <- paste0(mod, "_fit") if (!any(fit_name != rlang::env_names(get_model_env()))) { stop("`", mod, "` does not have a `fit` method in parsnip.", call. = FALSE) @@ -440,9 +492,11 @@ get_fit <- function(mod) { # ------------------------------------------------------------------------------ +#' @rdname get_model_env +#' @keywords internal #' @export set_pred <- function(mod, mode, eng, type, value) { - check_mod_val(mod, existance = TRUE) + check_mod_val(mod, existence = TRUE) check_mode_val(mode) check_engine_val(eng) check_pred_info(value, type) @@ -490,9 +544,11 @@ set_pred <- function(mod, mode, eng, type, value) { invisible(NULL) } +#' @rdname get_model_env +#' @keywords internal #' @export get_pred_type <- function(mod, type) { - check_mod_val(mod, existance = TRUE) + check_mod_val(mod, existence = TRUE) pred_name <- paste0(mod, "_predict") if (!any(pred_name != rlang::env_names(get_model_env()))) { stop("`", mod, "` does not have any `pred` methods in parsnip.", call. 
= FALSE)
   }
@@ -514,9 +570,11 @@ validate_model <- function(mod) {
 
 # ------------------------------------------------------------------------------
 
+#' @rdname get_model_env
+#' @keywords internal
 #' @export
 show_model_info <- function(mod) {
-  check_mod_val(mod, existance = TRUE)
+  check_mod_val(mod, existence = TRUE)
   current <- get_model_env()
 
   cat("Information for `", mod, "`\n", sep = "")
diff --git a/man/check_mod_val.Rd b/man/check_mod_val.Rd
new file mode 100644
index 000000000..585ee385d
--- /dev/null
+++ b/man/check_mod_val.Rd
@@ -0,0 +1,54 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/aaa_models.R
+\name{check_mod_val}
+\alias{check_mod_val}
+\alias{check_mode_val}
+\alias{check_engine_val}
+\alias{check_arg_val}
+\alias{check_submodels_val}
+\alias{check_func_val}
+\alias{check_fit_info}
+\alias{check_pred_info}
+\alias{check_pkg_val}
+\title{Tools to Check Model Elements}
+\usage{
+check_mod_val(mod, new = FALSE, existence = FALSE)
+
+check_mode_val(mode)
+
+check_engine_val(eng)
+
+check_arg_val(arg)
+
+check_submodels_val(x)
+
+check_func_val(func)
+
+check_fit_info(x)
+
+check_pred_info(x, type)
+
+check_pkg_val(x)
+}
+\arguments{
+\item{mod}{A single character string for the model type (e.g.
+\code{"rand_forest"}, etc).}
+
+\item{new}{A single logical to check to see if the model that you are checking
+has not already been registered.}
+
+\item{existence}{A single logical to check to see if the model has already
+been registered.}
+
+\item{mode}{A single character string for the model mode (e.g. "regression").}
+
+\item{eng}{A single character string for the model engine.}
+
+\item{arg}{A single character string for the model argument name.}
+}
+\description{
+These functions are similar to constructors and can be used to validate
+that there are no conflicts with the underlying model structures used by the
+package.
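A worked line that the description could close on, following directly from the behavior of `check_mod_val()` shown above (the example call itself is editorial):

+For example, \code{check_mod_val("linear_reg", existence = TRUE)} returns
+\code{invisible(NULL)} when the model is registered and throws an error
+otherwise.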
+}
+\keyword{internal}
diff --git a/man/get_model_env.Rd b/man/get_model_env.Rd
new file mode 100644
index 000000000..d434d8957
--- /dev/null
+++ b/man/get_model_env.Rd
@@ -0,0 +1,45 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/aaa_models.R
+\name{get_model_env}
+\alias{get_model_env}
+\alias{set_new_model}
+\alias{set_model_mode}
+\alias{set_model_engine}
+\alias{set_model_arg}
+\alias{set_dependency}
+\alias{get_dependency}
+\alias{set_fit}
+\alias{get_fit}
+\alias{set_pred}
+\alias{get_pred_type}
+\alias{show_model_info}
+\title{Tools to Register Models}
+\usage{
+get_model_env()
+
+set_new_model(mod)
+
+set_model_mode(mod, mode)
+
+set_model_engine(mod, mode, eng)
+
+set_model_arg(mod, eng, val, original, func, submodels)
+
+set_dependency(mod, eng, pkg)
+
+get_dependency(mod)
+
+set_fit(mod, mode, eng, value)
+
+get_fit(mod)
+
+set_pred(mod, mode, eng, type, value)
+
+get_pred_type(mod, type)
+
+show_model_info(mod)
+}
+\description{
+Tools to Register Models
+}
+\keyword{internal}

From 91d1296774a3207fda7b0f2f128f0d18077625a7 Mon Sep 17 00:00:00 2001
From: Max Kuhn
Date: Tue, 28 May 2019 04:53:14 -0400
Subject: [PATCH 22/64] changed `mod` argument to `model`

---
 R/aaa_models.R            | 138 +++++++++++++++++++-------------------
 R/boost_tree_data.R       |  66 +++++++++---------
 R/decision_tree_data.R    |  44 ++++++------
 R/linear_reg_data.R       |  42 ++++++------
 R/logistic_reg_data.R     |  52 +++++++-------
 R/mars_data.R             |  20 +++---
 R/misc.R                  |   4 +-
 R/mlp_data.R              |  44 ++++++------
 R/multinom_reg_data.R     |  30 ++++-----
 R/nearest_neighbor_data.R |  20 +++---
 R/nullmodel_data.R        |  14 ++--
 R/rand_forest_data.R      |  56 ++++++++--------
 R/surv_reg_data.R         |  16 ++---
 R/svm_poly_data.R         |  22 +++---
 R/svm_rbf_data.R          |  20 +++---
 man/check_mod_val.Rd      |   4 +-
 man/get_model_env.Rd      |  22 +++---
 17 files changed, 307 insertions(+), 307 deletions(-)

diff --git a/R/aaa_models.R b/R/aaa_models.R
index 1e5853049..be79fe8f1 100644
--- a/R/aaa_models.R
+++ b/R/aaa_models.R
@@ -53,7 +53,7 @@ get_model_env <- function() {
 #' that there are no conflicts with the underlying model structures used by the
 #' package.
 #'
-#' @param mod A single character string for the model type (e.g.
+#' @param model A single character string for the model type (e.g.
 #' `"rand_forest"`, etc).
 #' @param new A single logical to check to see if the model that you are checking
 #' has not already been registered.
@@ -64,8 +64,8 @@ get_model_env <- function() {
 #' @param arg A single character string for the model argument name.
 #' @keywords internal
 #' @export
-check_mod_val <- function(mod, new = FALSE, existence = FALSE) {
-  if (is_missing(mod) || length(mod) != 1)
+check_mod_val <- function(model, new = FALSE, existence = FALSE) {
+  if (is_missing(model) || length(model) != 1)
     stop("Please supply a character string for a model name (e.g. `'linear_reg'`)",
          call. = FALSE)
 
@@ -74,15 +74,15 @@ check_mod_val <- function(mod, new = FALSE, existence = FALSE) {
   }
 
   if (new) {
-    if (any(current$models == mod)) {
-      stop("Model `", mod, "` already exists", call. = FALSE)
+    if (any(current$models == model)) {
+      stop("Model `", model, "` already exists", call. = FALSE)
     }
   }
 
   if (existence) {
     current <- get_model_env()
-    if (!any(current$models == mod)) {
-      stop("Model `", mod, "` has not been registered.", call. = FALSE)
+    if (!any(current$models == model)) {
+      stop("Model `", model, "` has not been registered.", call.
= FALSE) } } @@ -250,29 +250,29 @@ check_pkg_val <- function(x) { #' @rdname get_model_env #' @keywords internal #' @export -set_new_model <- function(mod) { - check_mod_val(mod, new = TRUE) +set_new_model <- function(model) { + check_mod_val(model, new = TRUE) current <- get_model_env() - current$models <- c(current$models, mod) - current[[mod]] <- dplyr::tibble(engine = character(0), mode = character(0)) - current[[paste0(mod, "_pkgs")]] <- dplyr::tibble(engine = character(0), pkg = list()) - current[[paste0(mod, "_modes")]] <- "unknown" - current[[paste0(mod, "_args")]] <- + current$models <- c(current$models, model) + current[[model]] <- dplyr::tibble(engine = character(0), mode = character(0)) + current[[paste0(model, "_pkgs")]] <- dplyr::tibble(engine = character(0), pkg = list()) + current[[paste0(model, "_modes")]] <- "unknown" + current[[paste0(model, "_args")]] <- dplyr::tibble( engine = character(0), parsnip = character(0), original = character(0), func = list() ) - current[[paste0(mod, "_fit")]] <- + current[[paste0(model, "_fit")]] <- dplyr::tibble( engine = character(0), mode = character(0), value = list() ) - current[[paste0(mod, "_predict")]] <- + current[[paste0(model, "_predict")]] <- dplyr::tibble( engine = character(0), mode = character(0), @@ -288,8 +288,8 @@ set_new_model <- function(mod) { #' @rdname get_model_env #' @keywords internal #' @export -set_model_mode <- function(mod, mode) { - check_mod_val(mod, existence = TRUE) +set_model_mode <- function(model, mode) { + check_mod_val(model, existence = TRUE) check_mode_val(mode) current <- get_model_env() @@ -297,8 +297,8 @@ set_model_mode <- function(mod, mode) { if (!any(current$modes == mode)) { current$modes <- unique(c(current$modes, mode)) } - current[[paste0(mod, "_modes")]] <- - unique(c(current[[paste0(mod, "_modes")]], mode)) + current[[paste0(model, "_modes")]] <- + unique(c(current[[paste0(model, "_modes")]], mode)) invisible(NULL) } @@ -308,21 +308,21 @@ set_model_mode <- function(mod, mode) { #' @rdname get_model_env #' @keywords internal #' @export -set_model_engine <- function(mod, mode, eng) { - check_mod_val(mod, existence = TRUE) +set_model_engine <- function(model, mode, eng) { + check_mod_val(model, existence = TRUE) check_mode_val(mode) check_mode_val(eng) current <- get_model_env() new_eng <- dplyr::tibble(engine = eng, mode = mode) - old_eng <- current[[mod]] + old_eng <- current[[model]] engs <- old_eng %>% dplyr::bind_rows(new_eng) %>% dplyr::distinct() - current[[mod]] <- engs + current[[model]] <- engs invisible(NULL) } @@ -333,15 +333,15 @@ set_model_engine <- function(mod, mode, eng) { #' @rdname get_model_env #' @keywords internal #' @export -set_model_arg <- function(mod, eng, val, original, func, submodels) { - check_mod_val(mod, existence = TRUE) +set_model_arg <- function(model, eng, val, original, func, submodels) { + check_mod_val(model, existence = TRUE) check_arg_val(val) check_arg_val(original) check_func_val(func) check_submodels_val(submodels) current <- get_model_env() - old_args <- current[[paste0(mod, "_args")]] + old_args <- current[[paste0(model, "_args")]] new_arg <- dplyr::tibble( @@ -361,7 +361,7 @@ set_model_arg <- function(mod, eng, val, original, func, submodels) { updated <- dplyr::distinct(updated, engine, parsnip, original, submodels) - current[[paste0(mod, "_args")]] <- updated + current[[paste0(model, "_args")]] <- updated invisible(NULL) } @@ -372,13 +372,13 @@ set_model_arg <- function(mod, eng, val, original, func, submodels) { #' @rdname get_model_env #' 
@keywords internal #' @export -set_dependency <- function(mod, eng, pkg) { - check_mod_val(mod, existence = TRUE) +set_dependency <- function(model, eng, pkg) { + check_mod_val(model, existence = TRUE) check_pkg_val(pkg) current <- get_model_env() - model_info <- current[[mod]] - pkg_info <- current[[paste0(mod, "_pkgs")]] + model_info <- current[[model]] + pkg_info <- current[[paste0(model, "_pkgs")]] has_engine <- model_info %>% @@ -387,7 +387,7 @@ set_dependency <- function(mod, eng, pkg) { nrow() if (has_engine != 1) { stop("The engine '", eng, "' has not been registered for model '", - mod, "'. ", call. = FALSE) + model, "'. ", call. = FALSE) } existing_pkgs <- @@ -407,7 +407,7 @@ set_dependency <- function(mod, eng, pkg) { dplyr::filter(engine != eng) %>% dplyr::bind_rows(existing_pkgs) } - current[[paste0(mod, "_pkgs")]] <- pkg_info + current[[paste0(model, "_pkgs")]] <- pkg_info invisible(NULL) } @@ -415,11 +415,11 @@ set_dependency <- function(mod, eng, pkg) { #' @rdname get_model_env #' @keywords internal #' @export -get_dependency <- function(mod) { - check_mod_val(mod, existence = TRUE) - pkg_name <- paste0(mod, "_pkgs") +get_dependency <- function(model) { + check_mod_val(model, existence = TRUE) + pkg_name <- paste0(model, "_pkgs") if (!any(pkg_name != rlang::env_names(get_model_env()))) { - stop("`", mod, "` does not have a dependency list in parsnip.", call. = FALSE) + stop("`", model, "` does not have a dependency list in parsnip.", call. = FALSE) } rlang::env_get(get_model_env(), pkg_name) } @@ -430,15 +430,15 @@ get_dependency <- function(mod) { #' @rdname get_model_env #' @keywords internal #' @export -set_fit <- function(mod, mode, eng, value) { - check_mod_val(mod, existence = TRUE) +set_fit <- function(model, mode, eng, value) { + check_mod_val(model, existence = TRUE) check_mode_val(mode) check_engine_val(eng) check_fit_info(value) current <- get_model_env() - model_info <- current[[paste0(mod)]] - old_fits <- current[[paste0(mod, "_fit")]] + model_info <- current[[paste0(model)]] + old_fits <- current[[paste0(model, "_fit")]] has_engine <- model_info %>% @@ -447,7 +447,7 @@ set_fit <- function(mod, mode, eng, value) { if (has_engine != 1) { stop("set_fit The combination of engine '", eng, "' and mode '", mode, "' has not been registered for model '", - mod, "'. ", call. = FALSE) + model, "'. ", call. = FALSE) } has_fit <- @@ -458,7 +458,7 @@ set_fit <- function(mod, mode, eng, value) { if (has_fit > 0) { stop("The combination of engine '", eng, "' and mode '", mode, "' already has a fit component for model '", - mod, "'. ", call. = FALSE) + model, "'. ", call. = FALSE) } new_fit <- @@ -473,7 +473,7 @@ set_fit <- function(mod, mode, eng, value) { stop("An error occured when adding the new fit module", call. = FALSE) } - current[[paste0(mod, "_fit")]] <- updated + current[[paste0(model, "_fit")]] <- updated invisible(NULL) } @@ -481,11 +481,11 @@ set_fit <- function(mod, mode, eng, value) { #' @rdname get_model_env #' @keywords internal #' @export -get_fit <- function(mod) { - check_mod_val(mod, existence = TRUE) - fit_name <- paste0(mod, "_fit") +get_fit <- function(model) { + check_mod_val(model, existence = TRUE) + fit_name <- paste0(model, "_fit") if (!any(fit_name != rlang::env_names(get_model_env()))) { - stop("`", mod, "` does not have a `fit` method in parsnip.", call. = FALSE) + stop("`", model, "` does not have a `fit` method in parsnip.", call. 
= FALSE) } rlang::env_get(get_model_env(), fit_name) } @@ -495,15 +495,15 @@ get_fit <- function(mod) { #' @rdname get_model_env #' @keywords internal #' @export -set_pred <- function(mod, mode, eng, type, value) { - check_mod_val(mod, existence = TRUE) +set_pred <- function(model, mode, eng, type, value) { + check_mod_val(model, existence = TRUE) check_mode_val(mode) check_engine_val(eng) check_pred_info(value, type) current <- get_model_env() - model_info <- current[[paste0(mod)]] - old_fits <- current[[paste0(mod, "_predict")]] + model_info <- current[[paste0(model)]] + old_fits <- current[[paste0(model, "_predict")]] has_engine <- model_info %>% @@ -512,7 +512,7 @@ set_pred <- function(mod, mode, eng, type, value) { if (has_engine != 1) { stop("The combination of engine '", eng, "' and mode '", mode, "' has not been registered for model '", - mod, "'. ", call. = FALSE) + model, "'. ", call. = FALSE) } has_pred <- @@ -523,7 +523,7 @@ set_pred <- function(mod, mode, eng, type, value) { stop("The combination of engine '", eng, "', mode '", mode, "', and type '", type, "' already has a prediction component for model '", - mod, "'. ", call. = FALSE) + model, "'. ", call. = FALSE) } new_fit <- @@ -539,7 +539,7 @@ set_pred <- function(mod, mode, eng, type, value) { stop("An error occured when adding the new fit module", call. = FALSE) } - current[[paste0(mod, "_predict")]] <- updated + current[[paste0(model, "_predict")]] <- updated invisible(NULL) } @@ -547,15 +547,15 @@ set_pred <- function(mod, mode, eng, type, value) { #' @rdname get_model_env #' @keywords internal #' @export -get_pred_type <- function(mod, type) { - check_mod_val(mod, existence = TRUE) - pred_name <- paste0(mod, "_predict") +get_pred_type <- function(model, type) { + check_mod_val(model, existence = TRUE) + pred_name <- paste0(model, "_predict") if (!any(pred_name != rlang::env_names(get_model_env()))) { - stop("`", mod, "` does not have any `pred` methods in parsnip.", call. = FALSE) + stop("`", model, "` does not have any `pred` methods in parsnip.", call. = FALSE) } all_preds <- rlang::env_get(get_model_env(), pred_name) if (!any(all_preds$type == type)) { - stop("`", mod, "` does not have any `", type, + stop("`", model, "` does not have any `", type, "` prediction methods in parsnip.", call. 
= FALSE) } dplyr::filter(all_preds, type == !!type) @@ -564,7 +564,7 @@ get_pred_type <- function(mod, type) { # ------------------------------------------------------------------------------ #' @export -validate_model <- function(mod) { +validate_model <- function(model) { # check for consistency across engines, modes, args, etc } @@ -573,19 +573,19 @@ validate_model <- function(mod) { #' @rdname get_model_env #' @keywords internal #' @export -show_model_info <- function(mod) { - check_mod_val(mod, existence = TRUE) +show_model_info <- function(model) { + check_mod_val(model, existence = TRUE) current <- get_model_env() - cat("Information for `", mod, "`\n", sep = "") + cat("Information for `", model, "`\n", sep = "") cat( " modes:", - paste0(current[[paste0(mod, "_modes")]], collapse = ", "), + paste0(current[[paste0(model, "_modes")]], collapse = ", "), "\n" ) - engines <- current[[paste0(mod)]] + engines <- current[[paste0(model)]] if (nrow(engines) > 0) { cat(" engines: \n") engines %>% @@ -606,7 +606,7 @@ show_model_info <- function(mod) { cat(" no registered engines yet.") } - args <- current[[paste0(mod, "_args")]] + args <- current[[paste0(model, "_args")]] if (nrow(args) > 0) { cat(" arguments: \n") args %>% @@ -629,7 +629,7 @@ show_model_info <- function(mod) { cat(" no registered arguments yet.") } - fits <- current[[paste0(mod, "_fits")]] + fits <- current[[paste0(model, "_fits")]] if (nrow(fits) > 0) { } else { diff --git a/R/boost_tree_data.R b/R/boost_tree_data.R index f3688717b..01815f313 100644 --- a/R/boost_tree_data.R +++ b/R/boost_tree_data.R @@ -10,7 +10,7 @@ set_model_engine("boost_tree", "regression", "xgboost") set_dependency("boost_tree", "xgboost", "xgboost") set_model_arg( - mod = "boost_tree", + model = "boost_tree", eng = "xgboost", val = "tree_depth", original = "max_depth", @@ -18,7 +18,7 @@ set_model_arg( submodels = FALSE ) set_model_arg( - mod = "boost_tree", + model = "boost_tree", eng = "xgboost", val = "trees", original = "nrounds", @@ -26,7 +26,7 @@ set_model_arg( submodels = TRUE ) set_model_arg( - mod = "boost_tree", + model = "boost_tree", eng = "xgboost", val = "learn_rate", original = "eta", @@ -34,7 +34,7 @@ set_model_arg( submodels = FALSE ) set_model_arg( - mod = "boost_tree", + model = "boost_tree", eng = "xgboost", val = "mtry", original = "colsample_bytree", @@ -42,7 +42,7 @@ set_model_arg( submodels = FALSE ) set_model_arg( - mod = "boost_tree", + model = "boost_tree", eng = "xgboost", val = "min_n", original = "min_child_weight", @@ -50,7 +50,7 @@ set_model_arg( submodels = FALSE ) set_model_arg( - mod = "boost_tree", + model = "boost_tree", eng = "xgboost", val = "loss_reduction", original = "gamma", @@ -58,7 +58,7 @@ set_model_arg( submodels = FALSE ) set_model_arg( - mod = "boost_tree", + model = "boost_tree", eng = "xgboost", val = "sample_size", original = "subsample", @@ -67,7 +67,7 @@ set_model_arg( ) set_fit( - mod = "boost_tree", + model = "boost_tree", eng = "xgboost", mode = "regression", value = list( @@ -79,7 +79,7 @@ set_fit( ) set_pred( - mod = "boost_tree", + model = "boost_tree", eng = "xgboost", mode = "regression", type = "numeric", @@ -92,7 +92,7 @@ set_pred( ) set_pred( - mod = "boost_tree", + model = "boost_tree", eng = "xgboost", mode = "regression", type = "raw", @@ -105,7 +105,7 @@ set_pred( ) set_fit( - mod = "boost_tree", + model = "boost_tree", eng = "xgboost", mode = "classification", value = list( @@ -117,7 +117,7 @@ set_fit( ) set_pred( - mod = "boost_tree", + model = "boost_tree", eng = "xgboost", mode = 
"classification", type = "class", @@ -137,7 +137,7 @@ set_pred( ) set_pred( - mod = "boost_tree", + model = "boost_tree", eng = "xgboost", mode = "classification", type = "prob", @@ -158,7 +158,7 @@ set_pred( ) set_pred( - mod = "boost_tree", + model = "boost_tree", eng = "xgboost", mode = "classification", type = "raw", @@ -176,7 +176,7 @@ set_model_engine("boost_tree", "classification", "C5.0") set_dependency("boost_tree", "C5.0", "C50") set_model_arg( - mod = "boost_tree", + model = "boost_tree", eng = "C5.0", val = "trees", original = "trials", @@ -184,7 +184,7 @@ set_model_arg( submodels = TRUE ) set_model_arg( - mod = "boost_tree", + model = "boost_tree", eng = "C5.0", val = "min_n", original = "minCases", @@ -192,7 +192,7 @@ set_model_arg( submodels = FALSE ) set_model_arg( - mod = "boost_tree", + model = "boost_tree", eng = "C5.0", val = "sample_size", original = "sample", @@ -201,7 +201,7 @@ set_model_arg( ) set_fit( - mod = "boost_tree", + model = "boost_tree", eng = "C5.0", mode = "classification", value = list( @@ -213,7 +213,7 @@ set_fit( ) set_pred( - mod = "boost_tree", + model = "boost_tree", eng = "C5.0", mode = "classification", type = "class", @@ -226,7 +226,7 @@ set_pred( ) set_pred( - mod = "boost_tree", + model = "boost_tree", eng = "C5.0", mode = "classification", type = "prob", @@ -246,7 +246,7 @@ set_pred( ) set_pred( - mod = "boost_tree", + model = "boost_tree", eng = "C5.0", mode = "classification", type = "raw", @@ -266,7 +266,7 @@ set_model_engine("boost_tree", "regression", "spark") set_dependency("boost_tree", "spark", "sparklyr") set_model_arg( - mod = "boost_tree", + model = "boost_tree", eng = "spark", val = "tree_depth", original = "max_depth", @@ -274,7 +274,7 @@ set_model_arg( submodels = FALSE ) set_model_arg( - mod = "boost_tree", + model = "boost_tree", eng = "spark", val = "trees", original = "max_iter", @@ -282,7 +282,7 @@ set_model_arg( submodels = TRUE ) set_model_arg( - mod = "boost_tree", + model = "boost_tree", eng = "spark", val = "learn_rate", original = "step_size", @@ -290,7 +290,7 @@ set_model_arg( submodels = FALSE ) set_model_arg( - mod = "boost_tree", + model = "boost_tree", eng = "spark", val = "mtry", original = "feature_subset_strategy", @@ -298,7 +298,7 @@ set_model_arg( submodels = FALSE ) set_model_arg( - mod = "boost_tree", + model = "boost_tree", eng = "spark", val = "min_n", original = "min_instances_per_node", @@ -306,7 +306,7 @@ set_model_arg( submodels = FALSE ) set_model_arg( - mod = "boost_tree", + model = "boost_tree", eng = "spark", val = "min_info_gain", original = "gamma", @@ -314,7 +314,7 @@ set_model_arg( submodels = FALSE ) set_model_arg( - mod = "boost_tree", + model = "boost_tree", eng = "spark", val = "sample_size", original = "subsampling_rate", @@ -323,7 +323,7 @@ set_model_arg( ) set_fit( - mod = "boost_tree", + model = "boost_tree", eng = "spark", mode = "regression", value = list( @@ -335,7 +335,7 @@ set_fit( ) set_fit( - mod = "boost_tree", + model = "boost_tree", eng = "spark", mode = "classification", value = list( @@ -347,7 +347,7 @@ set_fit( ) set_pred( - mod = "boost_tree", + model = "boost_tree", eng = "spark", mode = "regression", type = "numeric", @@ -360,7 +360,7 @@ set_pred( ) set_pred( - mod = "boost_tree", + model = "boost_tree", eng = "spark", mode = "classification", type = "class", @@ -373,7 +373,7 @@ set_pred( ) set_pred( - mod = "boost_tree", + model = "boost_tree", eng = "spark", mode = "classification", type = "prob", diff --git a/R/decision_tree_data.R b/R/decision_tree_data.R index 
f319c56dc..307ce6b35 100644 --- a/R/decision_tree_data.R +++ b/R/decision_tree_data.R @@ -10,7 +10,7 @@ set_model_engine("decision_tree", "regression", "rpart") set_dependency("decision_tree", "rpart", "rpart") set_model_arg( - mod = "decision_tree", + model = "decision_tree", eng = "rpart", val = "tree_depth", original = "maxdepth", @@ -19,7 +19,7 @@ set_model_arg( ) set_model_arg( - mod = "decision_tree", + model = "decision_tree", eng = "rpart", val = "min_n", original = "minsplit", @@ -28,7 +28,7 @@ set_model_arg( ) set_model_arg( - mod = "decision_tree", + model = "decision_tree", eng = "rpart", val = "cost_complexity", original = "cp", @@ -37,7 +37,7 @@ set_model_arg( ) set_fit( - mod = "decision_tree", + model = "decision_tree", eng = "rpart", mode = "regression", value = list( @@ -49,7 +49,7 @@ set_fit( ) set_fit( - mod = "decision_tree", + model = "decision_tree", eng = "rpart", mode = "classification", value = list( @@ -61,7 +61,7 @@ set_fit( ) set_pred( - mod = "decision_tree", + model = "decision_tree", eng = "rpart", mode = "regression", type = "numeric", @@ -74,7 +74,7 @@ set_pred( ) set_pred( - mod = "decision_tree", + model = "decision_tree", eng = "rpart", mode = "regression", type = "raw", @@ -87,7 +87,7 @@ set_pred( ) set_pred( - mod = "decision_tree", + model = "decision_tree", eng = "rpart", mode = "classification", type = "class", @@ -105,7 +105,7 @@ set_pred( ) set_pred( - mod = "decision_tree", + model = "decision_tree", eng = "rpart", mode = "classification", type = "prob", @@ -120,7 +120,7 @@ set_pred( ) set_pred( - mod = "decision_tree", + model = "decision_tree", eng = "rpart", mode = "classification", type = "raw", @@ -138,7 +138,7 @@ set_model_engine("decision_tree", "classification", "C5.0") set_dependency("decision_tree", "C5.0", "C5.0") set_model_arg( - mod = "decision_tree", + model = "decision_tree", eng = "C5.0", val = "min_n", original = "minCases", @@ -147,7 +147,7 @@ set_model_arg( ) set_fit( - mod = "decision_tree", + model = "decision_tree", eng = "C5.0", mode = "classification", value = list( @@ -159,7 +159,7 @@ set_fit( ) set_pred( - mod = "decision_tree", + model = "decision_tree", eng = "C5.0", mode = "classification", type = "class", @@ -173,7 +173,7 @@ set_pred( set_pred( - mod = "decision_tree", + model = "decision_tree", eng = "C5.0", mode = "classification", type = "prob", @@ -194,7 +194,7 @@ set_pred( set_pred( - mod = "decision_tree", + model = "decision_tree", eng = "C5.0", mode = "classification", type = "raw", @@ -214,7 +214,7 @@ set_model_engine("decision_tree", "regression", "spark") set_dependency("decision_tree", "spark", "spark") set_model_arg( - mod = "decision_tree", + model = "decision_tree", eng = "spark", val = "tree_depth", original = "max_depth", @@ -223,7 +223,7 @@ set_model_arg( ) set_model_arg( - mod = "decision_tree", + model = "decision_tree", eng = "spark", val = "min_n", original = "min_instances_per_node", @@ -232,7 +232,7 @@ set_model_arg( ) set_fit( - mod = "decision_tree", + model = "decision_tree", eng = "spark", mode = "regression", value = list( @@ -245,7 +245,7 @@ set_fit( ) set_fit( - mod = "decision_tree", + model = "decision_tree", eng = "spark", mode = "classification", value = list( @@ -258,7 +258,7 @@ set_fit( ) set_pred( - mod = "decision_tree", + model = "decision_tree", eng = "spark", mode = "regression", type = "numeric", @@ -271,7 +271,7 @@ set_pred( ) set_pred( - mod = "decision_tree", + model = "decision_tree", eng = "spark", mode = "classification", type = "class", @@ -284,7 +284,7 @@ set_pred( 
) set_pred( - mod = "decision_tree", + model = "decision_tree", eng = "spark", mode = "classification", type = "prob", diff --git a/R/linear_reg_data.R b/R/linear_reg_data.R index 334d6e91d..8779ad115 100644 --- a/R/linear_reg_data.R +++ b/R/linear_reg_data.R @@ -8,7 +8,7 @@ set_model_engine("linear_reg", "regression", "lm") set_dependency("linear_reg", "lm", "stats") set_fit( - mod = "linear_reg", + model = "linear_reg", eng = "lm", mode = "regression", value = list( @@ -20,7 +20,7 @@ set_fit( ) set_pred( - mod = "linear_reg", + model = "linear_reg", eng = "lm", mode = "regression", type = "numeric", @@ -38,7 +38,7 @@ set_pred( ) set_pred( - mod = "linear_reg", + model = "linear_reg", eng = "lm", mode = "regression", type = "conf_int", @@ -61,7 +61,7 @@ set_pred( ) ) set_pred( - mod = "linear_reg", + model = "linear_reg", eng = "lm", mode = "regression", type = "pred_int", @@ -85,7 +85,7 @@ set_pred( ) set_pred( - mod = "linear_reg", + model = "linear_reg", eng = "lm", mode = "regression", type = "raw", @@ -103,7 +103,7 @@ set_model_engine("linear_reg", "regression", "glmnet") set_dependency("linear_reg", "glmnet", "glmnet") set_model_arg( - mod = "linear_reg", + model = "linear_reg", eng = "glmnet", val = "penalty", original = "lambda", @@ -112,7 +112,7 @@ set_model_arg( ) set_model_arg( - mod = "linear_reg", + model = "linear_reg", eng = "glmnet", val = "mixture", original = "alpha", @@ -121,7 +121,7 @@ set_model_arg( ) set_fit( - mod = "linear_reg", + model = "linear_reg", eng = "glmnet", mode = "regression", value = list( @@ -133,7 +133,7 @@ set_fit( ) set_pred( - mod = "linear_reg", + model = "linear_reg", eng = "glmnet", mode = "regression", type = "numeric", @@ -152,7 +152,7 @@ set_pred( ) set_pred( - mod = "linear_reg", + model = "linear_reg", eng = "glmnet", mode = "regression", type = "raw", @@ -172,7 +172,7 @@ set_model_engine("linear_reg", "regression", "stan") set_dependency("linear_reg", "stan", "rstanarm") set_fit( - mod = "linear_reg", + model = "linear_reg", eng = "stan", mode = "regression", value = list( @@ -184,7 +184,7 @@ set_fit( ) set_pred( - mod = "linear_reg", + model = "linear_reg", eng = "stan", mode = "regression", type = "numeric", @@ -197,7 +197,7 @@ set_pred( ) set_pred( - mod = "linear_reg", + model = "linear_reg", eng = "stan", mode = "regression", type = "conf_int", @@ -234,7 +234,7 @@ set_pred( ) set_pred( - mod = "linear_reg", + model = "linear_reg", eng = "stan", mode = "regression", type = "pred_int", @@ -270,7 +270,7 @@ set_pred( ) set_pred( - mod = "linear_reg", + model = "linear_reg", eng = "stan", mode = "regression", type = "raw", @@ -288,7 +288,7 @@ set_model_engine("linear_reg", "regression", "spark") set_dependency("linear_reg", "spark", "sparklyr") set_model_arg( - mod = "linear_reg", + model = "linear_reg", eng = "spark", val = "penalty", original = "reg_param", @@ -297,7 +297,7 @@ set_model_arg( ) set_model_arg( - mod = "linear_reg", + model = "linear_reg", eng = "spark", val = "mixture", original = "elastic_net_param", @@ -307,7 +307,7 @@ set_model_arg( set_fit( - mod = "linear_reg", + model = "linear_reg", eng = "spark", mode = "regression", value = list( @@ -319,7 +319,7 @@ set_fit( ) set_pred( - mod = "linear_reg", + model = "linear_reg", eng = "spark", mode = "regression", type = "numeric", @@ -343,7 +343,7 @@ set_dependency("linear_reg", "keras", "keras") set_dependency("linear_reg", "keras", "magrittr") set_fit( - mod = "linear_reg", + model = "linear_reg", eng = "keras", mode = "regression", value = list( @@ -355,7 +355,7 @@ 
set_fit( ) set_pred( - mod = "linear_reg", + model = "linear_reg", eng = "keras", mode = "regression", type = "numeric", diff --git a/R/logistic_reg_data.R b/R/logistic_reg_data.R index c8558d688..ae581aad3 100644 --- a/R/logistic_reg_data.R +++ b/R/logistic_reg_data.R @@ -8,7 +8,7 @@ set_model_engine("logistic_reg", "classification", "glm") set_dependency("logistic_reg", "glm", "stats") set_fit( - mod = "logistic_reg", + model = "logistic_reg", eng = "glm", mode = "classification", value = list( @@ -20,7 +20,7 @@ set_fit( ) set_pred( - mod = "logistic_reg", + model = "logistic_reg", eng = "glm", mode = "classification", type = "class", @@ -38,7 +38,7 @@ set_pred( ) set_pred( - mod = "logistic_reg", + model = "logistic_reg", eng = "glm", mode = "classification", type = "prob", @@ -60,7 +60,7 @@ set_pred( ) set_pred( - mod = "logistic_reg", + model = "logistic_reg", eng = "glm", mode = "classification", type = "raw", @@ -77,7 +77,7 @@ set_pred( ) set_pred( - mod = "logistic_reg", + model = "logistic_reg", eng = "glm", mode = "classification", type = "conf_int", @@ -122,7 +122,7 @@ set_model_engine("logistic_reg", "classification", "glmnet") set_dependency("logistic_reg", "glmnet", "glmnet") set_model_arg( - mod = "logistic_reg", + model = "logistic_reg", eng = "glmnet", val = "penalty", original = "lambda", @@ -131,7 +131,7 @@ set_model_arg( ) set_model_arg( - mod = "logistic_reg", + model = "logistic_reg", eng = "glmnet", val = "mixture", original = "alpha", @@ -140,7 +140,7 @@ set_model_arg( ) set_fit( - mod = "logistic_reg", + model = "logistic_reg", eng = "glmnet", mode = "classification", value = list( @@ -153,7 +153,7 @@ set_fit( set_pred( - mod = "logistic_reg", + model = "logistic_reg", eng = "glmnet", mode = "classification", type = "class", @@ -172,7 +172,7 @@ set_pred( ) set_pred( - mod = "logistic_reg", + model = "logistic_reg", eng = "glmnet", mode = "classification", type = "prob", @@ -191,7 +191,7 @@ set_pred( ) set_pred( - mod = "logistic_reg", + model = "logistic_reg", eng = "glmnet", mode = "classification", type = "raw", @@ -213,7 +213,7 @@ set_model_engine("logistic_reg", "classification", "spark") set_dependency("logistic_reg", "spark", "sparklyr") set_model_arg( - mod = "logistic_reg", + model = "logistic_reg", eng = "spark", val = "penalty", original = "reg_param", @@ -222,7 +222,7 @@ set_model_arg( ) set_model_arg( - mod = "logistic_reg", + model = "logistic_reg", eng = "spark", val = "mixture", original = "elastic_net_param", @@ -231,7 +231,7 @@ set_model_arg( ) set_fit( - mod = "logistic_reg", + model = "logistic_reg", eng = "spark", mode = "classification", value = list( @@ -246,7 +246,7 @@ set_fit( ) set_pred( - mod = "logistic_reg", + model = "logistic_reg", eng = "spark", mode = "classification", type = "class", @@ -263,7 +263,7 @@ set_pred( ) set_pred( - mod = "logistic_reg", + model = "logistic_reg", eng = "spark", mode = "classification", type = "prob", @@ -287,7 +287,7 @@ set_dependency("logistic_reg", "keras", "keras") set_dependency("logistic_reg", "keras", "magrittr") set_model_arg( - mod = "logistic_reg", + model = "logistic_reg", eng = "keras", val = "decay", original = "decay", @@ -296,7 +296,7 @@ set_model_arg( ) set_fit( - mod = "logistic_reg", + model = "logistic_reg", eng = "keras", mode = "classification", value = list( @@ -308,7 +308,7 @@ set_fit( ) set_pred( - mod = "logistic_reg", + model = "logistic_reg", eng = "keras", mode = "classification", type = "class", @@ -327,7 +327,7 @@ set_pred( ) set_pred( - mod = "logistic_reg", + model = 
"logistic_reg", eng = "keras", mode = "classification", type = "prob", @@ -354,7 +354,7 @@ set_model_engine("logistic_reg", "regression", "stan") set_dependency("logistic_reg", "stan", "rstanarm") set_fit( - mod = "logistic_reg", + model = "logistic_reg", eng = "stan", mode = "classification", value = list( @@ -366,7 +366,7 @@ set_fit( ) set_pred( - mod = "logistic_reg", + model = "logistic_reg", eng = "stan", mode = "classification", type = "class", @@ -387,7 +387,7 @@ set_pred( ) set_pred( - mod = "logistic_reg", + model = "logistic_reg", eng = "stan", mode = "classification", type = "prob", @@ -410,7 +410,7 @@ set_pred( set_pred( - mod = "logistic_reg", + model = "logistic_reg", eng = "stan", mode = "classification", type = "raw", @@ -427,7 +427,7 @@ set_pred( ) set_pred( - mod = "logistic_reg", + model = "logistic_reg", eng = "stan", mode = "classification", type = "conf_int", @@ -472,7 +472,7 @@ set_pred( ) set_pred( - mod = "logistic_reg", + model = "logistic_reg", eng = "stan", mode = "classification", type = "pred_int", diff --git a/R/mars_data.R b/R/mars_data.R index f3b4ae589..98dc59133 100644 --- a/R/mars_data.R +++ b/R/mars_data.R @@ -11,7 +11,7 @@ set_model_engine("mars", "regression", "earth") set_dependency("mars", "earth", "earth") set_model_arg( - mod = "mars", + model = "mars", eng = "earth", val = "num_terms", original = "nprune", @@ -19,7 +19,7 @@ set_model_arg( submodels = FALSE ) set_model_arg( - mod = "mars", + model = "mars", eng = "earth", val = "prod_degree", original = "degree", @@ -27,7 +27,7 @@ set_model_arg( submodels = FALSE ) set_model_arg( - mod = "mars", + model = "mars", eng = "earth", val = "prune_method", original = "pmethod", @@ -36,7 +36,7 @@ set_model_arg( ) set_fit( - mod = "mars", + model = "mars", eng = "earth", mode = "regression", value = list( @@ -48,7 +48,7 @@ set_fit( ) set_fit( - mod = "mars", + model = "mars", eng = "earth", mode = "classification", value = list( @@ -60,7 +60,7 @@ set_fit( ) set_pred( - mod = "mars", + model = "mars", eng = "earth", mode = "regression", type = "numeric", @@ -78,7 +78,7 @@ set_pred( ) set_pred( - mod = "mars", + model = "mars", eng = "earth", mode = "regression", type = "raw", @@ -93,7 +93,7 @@ set_pred( ) set_pred( - mod = "mars", + model = "mars", eng = "earth", mode = "classification", type = "class", @@ -114,7 +114,7 @@ set_pred( ) set_pred( - mod = "mars", + model = "mars", eng = "earth", mode = "classification", type = "prob", @@ -137,7 +137,7 @@ set_pred( ) set_pred( - mod = "mars", + model = "mars", eng = "earth", mode = "classification", type = "raw", diff --git a/R/misc.R b/R/misc.R index 2e2830429..7dda7d0df 100644 --- a/R/misc.R +++ b/R/misc.R @@ -144,8 +144,8 @@ is_spark <- function(x) isTRUE(unname(x$method$fit$func["pkg"] == "sparklyr")) -show_fit <- function(mod, eng) { - mod <- translate(x = mod, engine = eng) +show_fit <- function(model, eng) { + mod <- translate(x = model, engine = eng) fit_call <- show_call(mod) call_text <- deparse(fit_call) call_text <- paste0(call_text, collapse = "\n") diff --git a/R/mlp_data.R b/R/mlp_data.R index a707f4cca..a08fe51b9 100644 --- a/R/mlp_data.R +++ b/R/mlp_data.R @@ -12,7 +12,7 @@ set_dependency("mlp", "keras", "keras") set_dependency("mlp", "keras", "magrittr") set_model_arg( - mod = "mlp", + model = "mlp", eng = "keras", val = "hidden_units", original = "hidden_units", @@ -20,7 +20,7 @@ set_model_arg( submodels = FALSE ) set_model_arg( - mod = "mlp", + model = "mlp", eng = "keras", val = "penalty", original = "penalty", @@ -28,7 +28,7 @@ 
set_model_arg( submodels = FALSE ) set_model_arg( - mod = "mlp", + model = "mlp", eng = "keras", val = "dropout", original = "dropout", @@ -36,7 +36,7 @@ set_model_arg( submodels = FALSE ) set_model_arg( - mod = "mlp", + model = "mlp", eng = "keras", val = "epochs", original = "epochs", @@ -44,7 +44,7 @@ set_model_arg( submodels = FALSE ) set_model_arg( - mod = "mlp", + model = "mlp", eng = "keras", val = "activation", original = "activation", @@ -54,7 +54,7 @@ set_model_arg( set_fit( - mod = "mlp", + model = "mlp", eng = "keras", mode = "regression", value = list( @@ -66,7 +66,7 @@ set_fit( ) set_fit( - mod = "mlp", + model = "mlp", eng = "keras", mode = "classification", value = list( @@ -78,7 +78,7 @@ set_fit( ) set_pred( - mod = "mlp", + model = "mlp", eng = "keras", mode = "regression", type = "numeric", @@ -95,7 +95,7 @@ set_pred( ) set_pred( - mod = "mlp", + model = "mlp", eng = "keras", mode = "regression", type = "raw", @@ -113,7 +113,7 @@ set_pred( ) set_pred( - mod = "mlp", + model = "mlp", eng = "keras", mode = "classification", type = "class", @@ -132,7 +132,7 @@ set_pred( ) set_pred( - mod = "mlp", + model = "mlp", eng = "keras", mode = "classification", type = "prob", @@ -153,7 +153,7 @@ set_pred( ) set_pred( - mod = "mlp", + model = "mlp", eng = "keras", mode = "classification", type = "raw", @@ -176,7 +176,7 @@ set_model_engine("mlp", "regression", "nnet") set_dependency("mlp", "nnet", "nnet") set_model_arg( - mod = "mlp", + model = "mlp", eng = "nnet", val = "hidden_units", original = "size", @@ -184,7 +184,7 @@ set_model_arg( submodels = FALSE ) set_model_arg( - mod = "mlp", + model = "mlp", eng = "nnet", val = "penalty", original = "decay", @@ -192,7 +192,7 @@ set_model_arg( submodels = FALSE ) set_model_arg( - mod = "mlp", + model = "mlp", eng = "nnet", val = "epochs", original = "maxit", @@ -200,7 +200,7 @@ set_model_arg( submodels = FALSE ) set_fit( - mod = "mlp", + model = "mlp", eng = "nnet", mode = "regression", value = list( @@ -212,7 +212,7 @@ set_fit( ) set_fit( - mod = "mlp", + model = "mlp", eng = "nnet", mode = "classification", value = list( @@ -224,7 +224,7 @@ set_fit( ) set_pred( - mod = "mlp", + model = "mlp", eng = "nnet", mode = "regression", type = "numeric", @@ -242,7 +242,7 @@ set_pred( ) set_pred( - mod = "mlp", + model = "mlp", eng = "nnet", mode = "regression", type = "raw", @@ -260,7 +260,7 @@ set_pred( ) set_pred( - mod = "mlp", + model = "mlp", eng = "nnet", mode = "classification", type = "class", @@ -278,7 +278,7 @@ set_pred( ) set_pred( - mod = "mlp", + model = "mlp", eng = "nnet", mode = "classification", type = "prob", @@ -296,7 +296,7 @@ set_pred( ) set_pred( - mod = "mlp", + model = "mlp", eng = "nnet", mode = "classification", type = "raw", diff --git a/R/multinom_reg_data.R b/R/multinom_reg_data.R index 418422c43..5116d9900 100644 --- a/R/multinom_reg_data.R +++ b/R/multinom_reg_data.R @@ -8,7 +8,7 @@ set_model_engine("multinom_reg", "classification", "glmnet") set_dependency("multinom_reg", "glmnet", "glmnet") set_model_arg( - mod = "multinom_reg", + model = "multinom_reg", eng = "glmnet", val = "penalty", original = "lambda", @@ -17,7 +17,7 @@ set_model_arg( ) set_model_arg( - mod = "multinom_reg", + model = "multinom_reg", eng = "glmnet", val = "mixture", original = "alpha", @@ -26,7 +26,7 @@ set_model_arg( ) set_fit( - mod = "multinom_reg", + model = "multinom_reg", eng = "glmnet", mode = "classification", value = list( @@ -39,7 +39,7 @@ set_fit( set_pred( - mod = "multinom_reg", + model = "multinom_reg", eng = "glmnet", mode = 
"classification", type = "class", @@ -58,7 +58,7 @@ set_pred( ) set_pred( - mod = "multinom_reg", + model = "multinom_reg", eng = "glmnet", mode = "classification", type = "prob", @@ -77,7 +77,7 @@ set_pred( ) set_pred( - mod = "multinom_reg", + model = "multinom_reg", eng = "glmnet", mode = "classification", type = "raw", @@ -99,7 +99,7 @@ set_model_engine("multinom_reg", "classification", "spark") set_dependency("multinom_reg", "spark", "sparklyr") set_model_arg( - mod = "multinom_reg", + model = "multinom_reg", eng = "spark", val = "penalty", original = "reg_param", @@ -108,7 +108,7 @@ set_model_arg( ) set_model_arg( - mod = "multinom_reg", + model = "multinom_reg", eng = "spark", val = "mixture", original = "elastic_net_param", @@ -117,7 +117,7 @@ set_model_arg( ) set_fit( - mod = "multinom_reg", + model = "multinom_reg", eng = "spark", mode = "classification", value = list( @@ -129,7 +129,7 @@ set_fit( ) set_pred( - mod = "multinom_reg", + model = "multinom_reg", eng = "spark", mode = "classification", type = "class", @@ -147,7 +147,7 @@ set_pred( set_pred( - mod = "multinom_reg", + model = "multinom_reg", eng = "spark", mode = "classification", type = "prob", @@ -171,7 +171,7 @@ set_dependency("multinom_reg", "keras", "keras") set_dependency("multinom_reg", "keras", "magrittr") set_model_arg( - mod = "multinom_reg", + model = "multinom_reg", eng = "keras", val = "decay", original = "decay", @@ -181,7 +181,7 @@ set_model_arg( set_fit( - mod = "multinom_reg", + model = "multinom_reg", eng = "keras", mode = "classification", value = list( @@ -193,7 +193,7 @@ set_fit( ) set_pred( - mod = "multinom_reg", + model = "multinom_reg", eng = "keras", mode = "classification", type = "class", @@ -210,7 +210,7 @@ set_pred( ) set_pred( - mod = "multinom_reg", + model = "multinom_reg", eng = "keras", mode = "classification", type = "prob", diff --git a/R/nearest_neighbor_data.R b/R/nearest_neighbor_data.R index 3c8edb10a..dcac57748 100644 --- a/R/nearest_neighbor_data.R +++ b/R/nearest_neighbor_data.R @@ -11,7 +11,7 @@ set_model_engine("nearest_neighbor", "regression", "kknn") set_dependency("nearest_neighbor", "kknn", "kknn") set_model_arg( - mod = "nearest_neighbor", + model = "nearest_neighbor", eng = "kknn", val = "neighbors", original = "ks", @@ -19,7 +19,7 @@ set_model_arg( submodels = FALSE ) set_model_arg( - mod = "nearest_neighbor", + model = "nearest_neighbor", eng = "kknn", val = "weight_func", original = "kernel", @@ -27,7 +27,7 @@ set_model_arg( submodels = FALSE ) set_model_arg( - mod = "nearest_neighbor", + model = "nearest_neighbor", eng = "kknn", val = "dist_power", original = "distance", @@ -36,7 +36,7 @@ set_model_arg( ) set_fit( - mod = "nearest_neighbor", + model = "nearest_neighbor", eng = "kknn", mode = "regression", value = list( @@ -48,7 +48,7 @@ set_fit( ) set_fit( - mod = "nearest_neighbor", + model = "nearest_neighbor", eng = "kknn", mode = "classification", value = list( @@ -60,7 +60,7 @@ set_fit( ) set_pred( - mod = "nearest_neighbor", + model = "nearest_neighbor", eng = "kknn", mode = "regression", type = "numeric", @@ -87,7 +87,7 @@ set_pred( ) set_pred( - mod = "nearest_neighbor", + model = "nearest_neighbor", eng = "kknn", mode = "regression", type = "raw", @@ -104,7 +104,7 @@ set_pred( ) set_pred( - mod = "nearest_neighbor", + model = "nearest_neighbor", eng = "kknn", mode = "classification", type = "class", @@ -129,7 +129,7 @@ set_pred( ) set_pred( - mod = "nearest_neighbor", + model = "nearest_neighbor", eng = "kknn", mode = "classification", type = "prob", @@ 
-154,7 +154,7 @@ set_pred( ) set_pred( - mod = "nearest_neighbor", + model = "nearest_neighbor", eng = "kknn", mode = "classification", type = "raw", diff --git a/R/nullmodel_data.R b/R/nullmodel_data.R index e09bf398a..1aba4e6bf 100644 --- a/R/nullmodel_data.R +++ b/R/nullmodel_data.R @@ -10,7 +10,7 @@ set_model_engine("null_model", "regression", "parsnip") set_dependency("null_model", "parsnip", "parsnip") set_fit( - mod = "null_model", + model = "null_model", eng = "parsnip", mode = "regression", value = list( @@ -22,7 +22,7 @@ set_fit( ) set_fit( - mod = "null_model", + model = "null_model", eng = "parsnip", mode = "classification", value = list( @@ -34,7 +34,7 @@ set_fit( ) set_pred( - mod = "null_model", + model = "null_model", eng = "parsnip", mode = "regression", type = "numeric", @@ -52,7 +52,7 @@ set_pred( ) set_pred( - mod = "null_model", + model = "null_model", eng = "parsnip", mode = "regression", type = "raw", @@ -70,7 +70,7 @@ set_pred( ) set_pred( - mod = "null_model", + model = "null_model", eng = "parsnip", mode = "classification", type = "class", @@ -88,7 +88,7 @@ set_pred( ) set_pred( - mod = "null_model", + model = "null_model", eng = "parsnip", mode = "classification", type = "prob", @@ -109,7 +109,7 @@ set_pred( ) set_pred( - mod = "null_model", + model = "null_model", eng = "parsnip", mode = "classification", type = "raw", diff --git a/R/rand_forest_data.R b/R/rand_forest_data.R index 64e47b927..2f9d336c9 100644 --- a/R/rand_forest_data.R +++ b/R/rand_forest_data.R @@ -87,7 +87,7 @@ set_model_engine("rand_forest", "regression", "ranger") set_dependency("rand_forest", "ranger", "ranger") set_model_arg( - mod = "rand_forest", + model = "rand_forest", eng = "ranger", val = "mtry", original = "mtry", @@ -95,7 +95,7 @@ set_model_arg( submodels = FALSE ) set_model_arg( - mod = "rand_forest", + model = "rand_forest", eng = "ranger", val = "trees", original = "num.trees", @@ -103,7 +103,7 @@ set_model_arg( submodels = FALSE ) set_model_arg( - mod = "rand_forest", + model = "rand_forest", eng = "ranger", val = "min_n", original = "min.node.size", @@ -112,7 +112,7 @@ set_model_arg( ) set_fit( - mod = "rand_forest", + model = "rand_forest", eng = "ranger", mode = "classification", value = list( @@ -129,7 +129,7 @@ set_fit( ) set_fit( - mod = "rand_forest", + model = "rand_forest", eng = "ranger", mode = "regression", value = list( @@ -146,7 +146,7 @@ set_fit( ) set_pred( - mod = "rand_forest", + model = "rand_forest", eng = "ranger", mode = "classification", type = "class", @@ -166,7 +166,7 @@ set_pred( ) set_pred( - mod = "rand_forest", + model = "rand_forest", eng = "ranger", mode = "classification", type = "prob", @@ -196,7 +196,7 @@ set_pred( ) set_pred( - mod = "rand_forest", + model = "rand_forest", eng = "ranger", mode = "classification", type = "raw", @@ -214,7 +214,7 @@ set_pred( ) set_pred( - mod = "rand_forest", + model = "rand_forest", eng = "ranger", mode = "regression", type = "numeric", @@ -235,7 +235,7 @@ set_pred( ) set_pred( - mod = "rand_forest", + model = "rand_forest", eng = "ranger", mode = "regression", type = "raw", @@ -260,7 +260,7 @@ set_model_engine("rand_forest", "regression", "randomForest") set_dependency("rand_forest", "randomForest", "randomForest") set_model_arg( - mod = "rand_forest", + model = "rand_forest", eng = "randomForest", val = "mtry", original = "mtry", @@ -268,7 +268,7 @@ set_model_arg( submodels = FALSE ) set_model_arg( - mod = "rand_forest", + model = "rand_forest", eng = "randomForest", val = "trees", original = "ntree", @@ 
-276,7 +276,7 @@ set_model_arg( submodels = FALSE ) set_model_arg( - mod = "rand_forest", + model = "rand_forest", eng = "randomForest", val = "min_n", original = "nodesize", @@ -285,7 +285,7 @@ set_model_arg( ) set_fit( - mod = "rand_forest", + model = "rand_forest", eng = "randomForest", mode = "classification", value = list( @@ -298,7 +298,7 @@ set_fit( ) set_fit( - mod = "rand_forest", + model = "rand_forest", eng = "randomForest", mode = "regression", value = list( @@ -311,7 +311,7 @@ set_fit( ) set_pred( - mod = "rand_forest", + model = "rand_forest", eng = "randomForest", mode = "regression", type = "numeric", @@ -326,7 +326,7 @@ set_pred( ) set_pred( - mod = "rand_forest", + model = "rand_forest", eng = "randomForest", mode = "regression", type = "raw", @@ -341,7 +341,7 @@ set_pred( ) set_pred( - mod = "rand_forest", + model = "rand_forest", eng = "randomForest", mode = "classification", type = "class", @@ -355,7 +355,7 @@ set_pred( set_pred( - mod = "rand_forest", + model = "rand_forest", eng = "randomForest", mode = "classification", type = "prob", @@ -375,7 +375,7 @@ set_pred( ) set_pred( - mod = "rand_forest", + model = "rand_forest", eng = "randomForest", mode = "classification", type = "raw", @@ -397,7 +397,7 @@ set_model_engine("rand_forest", "regression", "spark") set_dependency("rand_forest", "spark", "sparklyr") set_model_arg( - mod = "rand_forest", + model = "rand_forest", eng = "spark", val = "mtry", original = "feature_subset_strategy", @@ -405,7 +405,7 @@ set_model_arg( submodels = FALSE ) set_model_arg( - mod = "rand_forest", + model = "rand_forest", eng = "spark", val = "trees", original = "num_trees", @@ -413,7 +413,7 @@ set_model_arg( submodels = FALSE ) set_model_arg( - mod = "rand_forest", + model = "rand_forest", eng = "spark", val = "min_n", original = "min_instances_per_node", @@ -422,7 +422,7 @@ set_model_arg( ) set_fit( - mod = "rand_forest", + model = "rand_forest", eng = "spark", mode = "classification", value = list( @@ -434,7 +434,7 @@ set_fit( ) set_fit( - mod = "rand_forest", + model = "rand_forest", eng = "spark", mode = "regression", value = list( @@ -446,7 +446,7 @@ set_fit( ) set_pred( - mod = "rand_forest", + model = "rand_forest", eng = "spark", mode = "regression", type = "numeric", @@ -461,7 +461,7 @@ set_pred( ) set_pred( - mod = "rand_forest", + model = "rand_forest", eng = "spark", mode = "classification", type = "class", @@ -476,7 +476,7 @@ set_pred( ) set_pred( - mod = "rand_forest", + model = "rand_forest", eng = "spark", mode = "classification", type = "prob", diff --git a/R/surv_reg_data.R b/R/surv_reg_data.R index 4d05aa298..18868f823 100644 --- a/R/surv_reg_data.R +++ b/R/surv_reg_data.R @@ -9,7 +9,7 @@ set_dependency("surv_reg", "flexsurv", "flexsurv") set_dependency("surv_reg", "flexsurv", "survival") set_model_arg( - mod = "surv_reg", + model = "surv_reg", eng = "flexsurv", val = "dist", original = "dist", @@ -18,7 +18,7 @@ set_model_arg( ) set_fit( - mod = "surv_reg", + model = "surv_reg", eng = "flexsurv", mode = "regression", value = list( @@ -30,7 +30,7 @@ set_fit( ) set_pred( - mod = "surv_reg", + model = "surv_reg", eng = "flexsurv", mode = "regression", type = "numeric", @@ -48,7 +48,7 @@ set_pred( ) set_pred( - mod = "surv_reg", + model = "surv_reg", eng = "flexsurv", mode = "regression", type = "quantile", @@ -72,7 +72,7 @@ set_model_engine("surv_reg", "regression", "survival") set_dependency("surv_reg", "survival", "survival") set_model_arg( - mod = "surv_reg", + model = "surv_reg", eng = "survival", val = "dist", 
original = "dist", @@ -81,7 +81,7 @@ set_model_arg( ) set_fit( - mod = "surv_reg", + model = "surv_reg", eng = "survival", mode = "regression", value = list( @@ -93,7 +93,7 @@ set_fit( ) set_pred( - mod = "surv_reg", + model = "surv_reg", eng = "survival", mode = "regression", type = "numeric", @@ -111,7 +111,7 @@ set_pred( ) set_pred( - mod = "surv_reg", + model = "surv_reg", eng = "survival", mode = "regression", type = "quantile", diff --git a/R/svm_poly_data.R b/R/svm_poly_data.R index 89e551c78..8fc2da66d 100644 --- a/R/svm_poly_data.R +++ b/R/svm_poly_data.R @@ -10,7 +10,7 @@ set_model_engine("svm_poly", "regression", "kernlab") set_dependency("svm_poly", "kernlab", "kernlab") set_model_arg( - mod = "svm_poly", + model = "svm_poly", eng = "kernlab", val = "cost", original = "C", @@ -19,7 +19,7 @@ set_model_arg( ) set_model_arg( - mod = "svm_poly", + model = "svm_poly", eng = "kernlab", val = "degree", original = "degree", @@ -28,7 +28,7 @@ set_model_arg( ) set_model_arg( - mod = "svm_poly", + model = "svm_poly", eng = "kernlab", val = "scale_factor", original = "scale", @@ -36,7 +36,7 @@ set_model_arg( submodels = FALSE ) set_model_arg( - mod = "svm_poly", + model = "svm_poly", eng = "kernlab", val = "margin", original = "epsilon", @@ -45,7 +45,7 @@ set_model_arg( ) set_fit( - mod = "svm_poly", + model = "svm_poly", eng = "kernlab", mode = "regression", value = list( @@ -57,7 +57,7 @@ set_fit( ) set_fit( - mod = "svm_poly", + model = "svm_poly", eng = "kernlab", mode = "classification", value = list( @@ -69,7 +69,7 @@ set_fit( ) set_pred( - mod = "svm_poly", + model = "svm_poly", eng = "kernlab", mode = "regression", type = "numeric", @@ -87,7 +87,7 @@ set_pred( ) set_pred( - mod = "svm_poly", + model = "svm_poly", eng = "kernlab", mode = "regression", type = "raw", @@ -100,7 +100,7 @@ set_pred( ) set_pred( - mod = "svm_poly", + model = "svm_poly", eng = "kernlab", mode = "classification", type = "class", @@ -118,7 +118,7 @@ set_pred( ) set_pred( - mod = "svm_poly", + model = "svm_poly", eng = "kernlab", mode = "classification", type = "prob", @@ -136,7 +136,7 @@ set_pred( ) set_pred( - mod = "svm_poly", + model = "svm_poly", eng = "kernlab", mode = "classification", type = "raw", diff --git a/R/svm_rbf_data.R b/R/svm_rbf_data.R index 4a1203cf0..d489222ae 100644 --- a/R/svm_rbf_data.R +++ b/R/svm_rbf_data.R @@ -10,7 +10,7 @@ set_model_engine("svm_rbf", "regression", "kernlab") set_dependency("svm_rbf", "kernlab", "kernlab") set_model_arg( - mod = "svm_rbf", + model = "svm_rbf", eng = "kernlab", val = "cost", original = "C", @@ -19,7 +19,7 @@ set_model_arg( ) set_model_arg( - mod = "svm_rbf", + model = "svm_rbf", eng = "kernlab", val = "rbf_sigma", original = "sigma", @@ -28,7 +28,7 @@ set_model_arg( ) set_model_arg( - mod = "svm_rbf", + model = "svm_rbf", eng = "kernlab", val = "margin", original = "epsilon", @@ -37,7 +37,7 @@ set_model_arg( ) set_fit( - mod = "svm_rbf", + model = "svm_rbf", eng = "kernlab", mode = "regression", value = list( @@ -49,7 +49,7 @@ set_fit( ) set_fit( - mod = "svm_rbf", + model = "svm_rbf", eng = "kernlab", mode = "classification", value = list( @@ -61,7 +61,7 @@ set_fit( ) set_pred( - mod = "svm_rbf", + model = "svm_rbf", eng = "kernlab", mode = "regression", type = "numeric", @@ -79,7 +79,7 @@ set_pred( ) set_pred( - mod = "svm_rbf", + model = "svm_rbf", eng = "kernlab", mode = "regression", type = "raw", @@ -92,7 +92,7 @@ set_pred( ) set_pred( - mod = "svm_rbf", + model = "svm_rbf", eng = "kernlab", mode = "classification", type = "class", @@ -110,7 
+110,7 @@ set_pred( ) set_pred( - mod = "svm_rbf", + model = "svm_rbf", eng = "kernlab", mode = "classification", type = "prob", @@ -128,7 +128,7 @@ set_pred( ) set_pred( - mod = "svm_rbf", + model = "svm_rbf", eng = "kernlab", mode = "classification", type = "raw", diff --git a/man/check_mod_val.Rd b/man/check_mod_val.Rd index 585ee385d..1b4690f40 100644 --- a/man/check_mod_val.Rd +++ b/man/check_mod_val.Rd @@ -12,7 +12,7 @@ \alias{check_pkg_val} \title{Tools to Check Model Elements} \usage{ -check_mod_val(mod, new = FALSE, existence = FALSE) +check_mod_val(model, new = FALSE, existence = FALSE) check_mode_val(mode) @@ -31,7 +31,7 @@ check_pred_info(x, type) check_pkg_val(x) } \arguments{ -\item{mod}{A single character string for the model type (e.g. +\item{model}{A single character string for the model type (e.g. \code{"rand_forest"}, etc).} \item{new}{A single logical to check to see if the model that you are check diff --git a/man/get_model_env.Rd b/man/get_model_env.Rd index d434d8957..978dff676 100644 --- a/man/get_model_env.Rd +++ b/man/get_model_env.Rd @@ -17,27 +17,27 @@ \usage{ get_model_env() -set_new_model(mod) +set_new_model(model) -set_model_mode(mod, mode) +set_model_mode(model, mode) -set_model_engine(mod, mode, eng) +set_model_engine(model, mode, eng) -set_model_arg(mod, eng, val, original, func, submodels) +set_model_arg(model, eng, val, original, func, submodels) -set_dependency(mod, eng, pkg) +set_dependency(model, eng, pkg) -get_dependency(mod) +get_dependency(model) -set_fit(mod, mode, eng, value) +set_fit(model, mode, eng, value) -get_fit(mod) +get_fit(model) -set_pred(mod, mode, eng, type, value) +set_pred(model, mode, eng, type, value) -get_pred_type(mod, type) +get_pred_type(model, type) -show_model_info(mod) +show_model_info(model) } \description{ Tools to Register Models From fc6ccbe9a40138f1b56996fd9cce9d0b99a4addc Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Tue, 28 May 2019 06:11:50 -0400 Subject: [PATCH 23/64] re-enable Sexpr in man files --- NEWS.md | 1 + R/boost_tree.R | 10 +++++----- R/decision_tree.R | 10 +++++----- R/linear_reg.R | 10 +++++----- R/logistic_reg.R | 10 +++++----- R/mars.R | 4 ++-- R/mlp.R | 8 ++++---- R/multinom_reg.R | 6 +++--- R/nearest_neighbor.R | 2 +- R/nullmodel.R | 4 ++-- R/rand_forest.R | 12 ++++++------ R/surv_reg.R | 12 ++++++------ R/svm_poly.R | 4 ++-- R/svm_rbf.R | 4 ++-- man/boost_tree.Rd | 10 ++++++++++ man/decision_tree.Rd | 10 ++++++++++ man/linear_reg.Rd | 10 ++++++++++ man/logistic_reg.Rd | 10 ++++++++++ man/mars.Rd | 4 ++++ man/mlp.Rd | 8 ++++++++ man/multinom_reg.Rd | 6 ++++++ man/nearest_neighbor.Rd | 2 ++ man/null_model.Rd | 4 ++++ man/rand_forest.Rd | 12 ++++++++++++ man/surv_reg.Rd | 8 ++++++-- man/svm_poly.Rd | 4 ++++ man/svm_rbf.Rd | 4 ++++ 27 files changed, 139 insertions(+), 50 deletions(-) diff --git a/NEWS.md b/NEWS.md index fff23d1e9..47461648d 100644 --- a/NEWS.md +++ b/NEWS.md @@ -4,6 +4,7 @@ * The method that `parsnip` uses to store the model information has changed. Any custom models from previous versions will need to use the new method for registering models. * The mode needs to be declared for models that can be used for more than one mode prior to fitting and/or translation. + * For `surv_reg()`, the engine that uses the `survival` package is now called `survival` instead of `survreg`.
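As a concrete companion to the `surv_reg()` bullet above, here is a minimal fitting sketch under the renamed engine. The example is illustrative only: the Weibull distribution and the `ovarian` data shipped with the survival package are stand-ins, not part of this patch.

    library(parsnip)
    library(survival)   # provides Surv() and the ovarian example data

    # the engine formerly requested as "survreg" is now named "survival"
    spec <- surv_reg(dist = "weibull")
    spec <- set_engine(spec, "survival")
    wb_fit <- fit(spec, Surv(futime, fustat) ~ age, data = ovarian)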
## New Features diff --git a/R/boost_tree.R b/R/boost_tree.R index b1cdca096..61c16c87c 100644 --- a/R/boost_tree.R +++ b/R/boost_tree.R @@ -67,23 +67,23 @@ #' #' \pkg{xgboost} classification #' -# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::boost_tree(mode = "classification"), "xgboost")} +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::boost_tree(mode = "classification"), "xgboost")} #' #' \pkg{xgboost} regression #' -# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::boost_tree(mode = "regression"), "xgboost")} +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::boost_tree(mode = "regression"), "xgboost")} #' #' \pkg{C5.0} classification #' -# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::boost_tree(mode = "classification"), "C5.0")} +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::boost_tree(mode = "classification"), "C5.0")} #' #' \pkg{spark} classification #' -# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::boost_tree(mode = "classification"), "spark")} +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::boost_tree(mode = "classification"), "spark")} #' #' \pkg{spark} regression #' -# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::boost_tree(mode = "regression"), "spark")} +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::boost_tree(mode = "regression"), "spark")} #' #' @note For models created using the spark engine, there are #' several differences to consider. First, only the formula diff --git a/R/decision_tree.R b/R/decision_tree.R index 5032f990f..afa807405 100644 --- a/R/decision_tree.R +++ b/R/decision_tree.R @@ -52,23 +52,23 @@ #' #' \pkg{rpart} classification #' -# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::decision_tree(mode = "classification"), "rpart")} +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::decision_tree(mode = "classification"), "rpart")} #' #' \pkg{rpart} regression #' -# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::decision_tree(mode = "regression"), "rpart")} +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::decision_tree(mode = "regression"), "rpart")} #' #' \pkg{C5.0} classification #' -# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::decision_tree(mode = "classification"), "C5.0")} +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::decision_tree(mode = "classification"), "C5.0")} #' #' \pkg{spark} classification #' -# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::decision_tree(mode = "classification"), "spark")} +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::decision_tree(mode = "classification"), "spark")} #' #' \pkg{spark} regression #' -# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::decision_tree(mode = "regression"), "spark")} +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::decision_tree(mode = "regression"), "spark")} #' #' @note For models created using the spark engine, there are #' several differences to consider. 
First, only the formula diff --git a/R/linear_reg.R b/R/linear_reg.R index 3f94a0cce..d3bc72c70 100644 --- a/R/linear_reg.R +++ b/R/linear_reg.R @@ -50,23 +50,23 @@ #' #' \pkg{lm} #' -# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "lm")} +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "lm")} #' #' \pkg{glmnet} #' -# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "glmnet")} +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "glmnet")} #' #' \pkg{stan} #' -# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "stan")} +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "stan")} #' #' \pkg{spark} #' -# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "spark")} +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "spark")} #' #' \pkg{keras} #' -# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "keras")} +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "keras")} #' #' When using `glmnet` models, there is the option to pass #' multiple values (or no values) to the `penalty` argument. This diff --git a/R/logistic_reg.R b/R/logistic_reg.R index 8365c729a..1bc9062b6 100644 --- a/R/logistic_reg.R +++ b/R/logistic_reg.R @@ -48,23 +48,23 @@ #' #' \pkg{glm} #' -# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::logistic_reg(), "glm")} +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::logistic_reg(), "glm")} #' #' \pkg{glmnet} #' -# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::logistic_reg(), "glmnet")} +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::logistic_reg(), "glmnet")} #' #' \pkg{stan} #' -# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::logistic_reg(), "stan")} +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::logistic_reg(), "stan")} #' #' \pkg{spark} #' -# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::logistic_reg(), "spark")} +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::logistic_reg(), "spark")} #' #' \pkg{keras} #' -# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::logistic_reg(), "keras")} +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::logistic_reg(), "keras")} #' #' When using `glmnet` models, there is the option to pass #' multiple values (or no values) to the `penalty` argument. This diff --git a/R/mars.R b/R/mars.R index d3c1bfaae..bfe7f6cbf 100644 --- a/R/mars.R +++ b/R/mars.R @@ -44,11 +44,11 @@ #' #' \pkg{earth} classification #' -# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::mars(mode = "classification"), "earth")} +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::mars(mode = "classification"), "earth")} #' #' \pkg{earth} regression #' -# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::mars(mode = "regression"), "earth")} +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::mars(mode = "regression"), "earth")} #' #' Note that, when the model is fit, the \pkg{earth} package only has its #' namespace loaded. 
However, if `multi_predict` is used, the package is diff --git a/R/mlp.R b/R/mlp.R index d31b8726d..3fe96631f 100644 --- a/R/mlp.R +++ b/R/mlp.R @@ -62,19 +62,19 @@ #' #' \pkg{keras} classification #' -# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::mlp(mode = "classification"), "keras")} +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::mlp(mode = "classification"), "keras")} #' #' \pkg{keras} regression #' -# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::mlp(mode = "regression"), "keras")} +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::mlp(mode = "regression"), "keras")} #' #' \pkg{nnet} classification #' -# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::mlp(mode = "classification"), "nnet")} +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::mlp(mode = "classification"), "nnet")} #' #' \pkg{nnet} regression #' -# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::mlp(mode = "regression"), "nnet")} +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::mlp(mode = "regression"), "nnet")} #' #' @importFrom purrr map_lgl #' @seealso [varying()], [fit()] diff --git a/R/multinom_reg.R b/R/multinom_reg.R index 2c978e9ff..71e367489 100644 --- a/R/multinom_reg.R +++ b/R/multinom_reg.R @@ -47,15 +47,15 @@ #' #' \pkg{glmnet} #' -# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::multinom_reg(), "glmnet")} +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::multinom_reg(), "glmnet")} #' #' \pkg{spark} #' -# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::multinom_reg(), "spark")} +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::multinom_reg(), "spark")} #' #' \pkg{keras} #' -# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::multinom_reg(), "keras")} +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::multinom_reg(), "keras")} #' #' When using `glmnet` models, there is the option to pass #' multiple values (or no values) to the `penalty` argument. This diff --git a/R/nearest_neighbor.R b/R/nearest_neighbor.R index eaa83188e..0bbca127b 100644 --- a/R/nearest_neighbor.R +++ b/R/nearest_neighbor.R @@ -54,7 +54,7 @@ #' #' \pkg{kknn} (classification or regression) #' -# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::nearest_neighbor(), "kknn")} +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::nearest_neighbor(mode = "regression"), "kknn")} #' #' @note #' For `kknn`, the underlying modeling function used is a restricted diff --git a/R/nullmodel.R b/R/nullmodel.R index f20c7751a..10772562d 100644 --- a/R/nullmodel.R +++ b/R/nullmodel.R @@ -147,11 +147,11 @@ predict.nullmodel <- function (object, new_data = NULL, type = NULL, ...) 
{ #' #' \pkg{parsnip} classification #' -# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::null_model(mode = "classification"), "parsnip")} +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::null_model(mode = "classification"), "parsnip")} #' #' \pkg{parsnip} regression #' -# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::null_model(mode = "regression"), "parsnip")} +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::null_model(mode = "regression"), "parsnip")} #' #' @importFrom purrr map_lgl #' @seealso [varying()], [fit()] diff --git a/R/rand_forest.R b/R/rand_forest.R index 23ec1c1f9..352178e2b 100644 --- a/R/rand_forest.R +++ b/R/rand_forest.R @@ -46,27 +46,27 @@ #' #' \pkg{ranger} classification #' -# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::rand_forest(mode = "classification"), "ranger")} +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::rand_forest(mode = "classification"), "ranger")} #' #' \pkg{ranger} regression #' -# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::rand_forest(mode = "regression"), "ranger")} +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::rand_forest(mode = "regression"), "ranger")} #' #' \pkg{randomForest} classification #' -# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::rand_forest(mode = "classification"), "randomForest")} +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::rand_forest(mode = "classification"), "randomForest")} #' #' \pkg{randomForest} regression #' -# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::rand_forest(mode = "regression"), "randomForest")} +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::rand_forest(mode = "regression"), "randomForest")} #' #' \pkg{spark} classification #' -# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::rand_forest(mode = "classification"), "spark")} +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::rand_forest(mode = "classification"), "spark")} #' #' \pkg{spark} regression #' -# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::rand_forest(mode = "regression"), "spark")} +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::rand_forest(mode = "regression"), "spark")} #' #' For \pkg{ranger} confidence intervals, the intervals are #' constructed using the form `estimate +/- z * std_error`. For diff --git a/R/surv_reg.R b/R/surv_reg.R index fa0187a10..878e20a9a 100644 --- a/R/surv_reg.R +++ b/R/surv_reg.R @@ -36,7 +36,7 @@ #' The model can be created using the `fit()` function using the #' following _engines_: #' \itemize{ -#' \item \pkg{R}: `"flexsurv"`, `"survreg"` (the default) +#' \item \pkg{R}: `"flexsurv"`, `"survival"` (the default) #' } #' #' @section Engine Details: @@ -47,11 +47,11 @@ #' #' \pkg{flexsurv} #' -# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::surv_reg(), "flexsurv")} +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::surv_reg(), "flexsurv")} #' -#' \pkg{survreg} +#' \pkg{survival} #' -# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::surv_reg(), "survreg")} +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::surv_reg(), "survival")} #' #' Note that `model = TRUE` is needed to produce quantile #' predictions when there is a stratification variable and can be @@ -143,8 +143,8 @@ update.surv_reg <- function(object, dist = NULL, fresh = FALSE, ...) { #' @export translate.surv_reg <- function(x, engine = x$engine, ...) { if (is.null(engine)) { - message("Used `engine = 'survreg'` for translation.") - engine <- "survreg" + message("Used `engine = 'survival'` for translation.") + engine <- "survival" } x <- translate.default(x, engine, ...)
x diff --git a/R/svm_poly.R b/R/svm_poly.R index a2b1e8fe0..5eb071950 100644 --- a/R/svm_poly.R +++ b/R/svm_poly.R @@ -44,11 +44,11 @@ #' #' \pkg{kernlab} classification #' -# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::svm_poly(mode = "classification"), "kernlab")} +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::svm_poly(mode = "classification"), "kernlab")} #' #' \pkg{kernlab} regression #' -# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::svm_poly(mode = "regression"), "kernlab")} +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::svm_poly(mode = "regression"), "kernlab")} #' #' @importFrom purrr map_lgl #' @seealso [varying()], [fit()] diff --git a/R/svm_rbf.R b/R/svm_rbf.R index b0ab171ba..0fe3d39a6 100644 --- a/R/svm_rbf.R +++ b/R/svm_rbf.R @@ -43,11 +43,11 @@ #' #' \pkg{kernlab} classification #' -# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::svm_rbf(mode = "classification"), "kernlab")} +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::svm_rbf(mode = "classification"), "kernlab")} #' #' \pkg{kernlab} regression #' -# \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::svm_rbf(mode = "regression"), "kernlab")} +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::svm_rbf(mode = "regression"), "kernlab")} #' #' @importFrom purrr map_lgl #' @seealso [varying()], [fit()] diff --git a/man/boost_tree.Rd b/man/boost_tree.Rd index 6ebde7cca..904313d00 100644 --- a/man/boost_tree.Rd +++ b/man/boost_tree.Rd @@ -112,13 +112,23 @@ fit calls are: \pkg{xgboost} classification +\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::boost_tree(mode = "classification"), "xgboost")} + \pkg{xgboost} regression +\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::boost_tree(mode = "regression"), "xgboost")} + \pkg{C5.0} classification +\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::boost_tree(mode = "classification"), "C5.0")} + \pkg{spark} classification +\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::boost_tree(mode = "classification"), "spark")} + \pkg{spark} regression + +\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::boost_tree(mode = "regression"), "spark")} } \examples{ diff --git a/man/decision_tree.Rd b/man/decision_tree.Rd index 4fbe4d18b..221bcdccb 100644 --- a/man/decision_tree.Rd +++ b/man/decision_tree.Rd @@ -88,13 +88,23 @@ model, the template of the fit calls are:: \pkg{rpart} classification +\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::decision_tree(mode = "classification"), "rpart")} + \pkg{rpart} regression +\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::decision_tree(mode = "regression"), "rpart")} + \pkg{C5.0} classification +\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::decision_tree(mode = "classification"), "C5.0")} + \pkg{spark} classification +\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::decision_tree(mode = "classification"), "spark")} + \pkg{spark} regression + +\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::decision_tree(mode = "regression"), "spark")} } \examples{ diff --git a/man/linear_reg.Rd b/man/linear_reg.Rd index 3eecc57fc..b58c1d631 100644 --- a/man/linear_reg.Rd +++ b/man/linear_reg.Rd @@ -87,14 +87,24 @@ model, the template of the fit calls are: \pkg{lm} +\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "lm")} + \pkg{glmnet} +\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "glmnet")} + \pkg{stan} +\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "stan")} + \pkg{spark} +\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "spark")} + \pkg{keras} 
+\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "keras")} + When using \code{glmnet} models, there is the option to pass multiple values (or no values) to the \code{penalty} argument. This can have an effect on the model object results. When using the diff --git a/man/logistic_reg.Rd b/man/logistic_reg.Rd index 4f23ceb2d..43aa599e9 100644 --- a/man/logistic_reg.Rd +++ b/man/logistic_reg.Rd @@ -85,14 +85,24 @@ model, the template of the fit calls are: \pkg{glm} +\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::logistic_reg(), "glm")} + \pkg{glmnet} +\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::logistic_reg(), "glmnet")} + \pkg{stan} +\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::logistic_reg(), "stan")} + \pkg{spark} +\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::logistic_reg(), "spark")} + \pkg{keras} +\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::logistic_reg(), "keras")} + When using \code{glmnet} models, there is the option to pass multiple values (or no values) to the \code{penalty} argument. This can have an effect on the model object results. When using the diff --git a/man/mars.Rd b/man/mars.Rd index 5647bdcf0..b55c24241 100644 --- a/man/mars.Rd +++ b/man/mars.Rd @@ -67,8 +67,12 @@ model, the template of the fit calls are: \pkg{earth} classification +\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::mars(mode = "classification"), "earth")} + \pkg{earth} regression +\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::mars(mode = "regression"), "earth")} + Note that, when the model is fit, the \pkg{earth} package only has its namespace loaded. However, if \code{multi_predict} is used, the package is attached. diff --git a/man/mlp.Rd b/man/mlp.Rd index 0acb2655a..f52a60a80 100644 --- a/man/mlp.Rd +++ b/man/mlp.Rd @@ -91,11 +91,19 @@ model, the template of the fit calls are: \pkg{keras} classification +\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::mlp(mode = "classification"), "keras")} + \pkg{keras} regression +\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::mlp(mode = "regression"), "keras")} + \pkg{nnet} classification +\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::mlp(mode = "classification"), "nnet")} + \pkg{nnet} regression + +\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::mlp(mode = "regression"), "nnet")} } \examples{ diff --git a/man/multinom_reg.Rd b/man/multinom_reg.Rd index e3f8dd30d..6f2b4af05 100644 --- a/man/multinom_reg.Rd +++ b/man/multinom_reg.Rd @@ -84,10 +84,16 @@ model, the template of the fit calls are: \pkg{glmnet} +\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::multinom_reg(), "glmnet")} + \pkg{spark} +\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::multinom_reg(), "spark")} + \pkg{keras} +\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::multinom_reg(), "keras")} + When using \code{glmnet} models, there is the option to pass multiple values (or no values) to the \code{penalty} argument. This can have an effect on the model object results. When using the diff --git a/man/nearest_neighbor.Rd b/man/nearest_neighbor.Rd index 123d8f9ca..757de6369 100644 --- a/man/nearest_neighbor.Rd +++ b/man/nearest_neighbor.Rd @@ -66,6 +66,8 @@ model fit call. 
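A note on the `\Sexpr[results=rd]` macros re-enabled throughout these Rd files: they run R code when the documentation is built, and `show_fit()` (the helper revised in `R/misc.R` earlier in this series) renders the translated fit template for one engine. A rough sketch of what a single macro evaluates follows; the exact Rd text it returns depends on `show_call()` and is not reproduced here.

    library(parsnip)

    # translate the linear_reg() spec for the "lm" engine and print the
    # deparsed fit template, as the \Sexpr macros do at Rd build time
    cat(parsnip:::show_fit(parsnip:::linear_reg(), "lm"))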
For this type of model, the template of the fit calls are: \pkg{kknn} (classification or regression) + +\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::nearest_neighbor(mode = "regression"), "kknn")} } \examples{ diff --git a/man/null_model.Rd b/man/null_model.Rd index 0c221dcba..b0930770b 100644 --- a/man/null_model.Rd +++ b/man/null_model.Rd @@ -32,7 +32,11 @@ model, the template of the fit calls are: \pkg{parsnip} classification +\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::null_model(mode = "classification"), "parsnip")} + \pkg{parsnip} regression + +\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::null_model(mode = "regression"), "parsnip")} } \examples{ diff --git a/man/rand_forest.Rd b/man/rand_forest.Rd index 73e40c618..80c6c8028 100644 --- a/man/rand_forest.Rd +++ b/man/rand_forest.Rd @@ -82,16 +82,28 @@ model, the template of the fit calls are:: \pkg{ranger} classification +\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::rand_forest(mode = "classification"), "ranger")} + \pkg{ranger} regression +\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::rand_forest(mode = "regression"), "ranger")} + \pkg{randomForest} classification +\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::rand_forest(mode = "classification"), "randomForest")} + \pkg{randomForest} regression +\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::rand_forest(mode = "regression"), "randomForest")} + \pkg{spark} classification +\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::rand_forest(mode = "classification"), "spark")} + \pkg{spark} regression +\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::rand_forest(mode = "regression"), "spark")} + For \pkg{ranger} confidence intervals, the intervals are constructed using the form \code{estimate +/- z * std_error}. For classification probabilities, these values can fall outside of diff --git a/man/surv_reg.Rd b/man/surv_reg.Rd index f6b307d12..5302daccc 100644 --- a/man/surv_reg.Rd +++ b/man/surv_reg.Rd @@ -58,7 +58,7 @@ For \code{surv_reg()}, the mode will always be "regression".
The model can be created using the \code{fit()} function using the following \emph{engines}: \itemize{ -\item \pkg{R}: \code{"flexsurv"}, \code{"survreg"} (the default) +\item \pkg{R}: \code{"flexsurv"}, \code{"survival"} (the default) } } \section{Engine Details}{ @@ -70,7 +70,11 @@ model, the template of the fit calls are: \pkg{flexsurv} -\pkg{survreg} +\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::surv_reg(), "flexsurv")} + +\pkg{survival} + +\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::surv_reg(), "survival")} Note that \code{model = TRUE} is needed to produce quantile predictions when there is a stratification variable and can be diff --git a/man/svm_poly.Rd b/man/svm_poly.Rd index 2b31741b6..314b55e49 100644 --- a/man/svm_poly.Rd +++ b/man/svm_poly.Rd @@ -69,7 +69,11 @@ model, the template of the fit calls are:: \pkg{kernlab} classification +\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::svm_poly(mode = "classification"), "kernlab")} + \pkg{kernlab} regression + +\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::svm_poly(mode = "regression"), "kernlab")} } \examples{ diff --git a/man/svm_rbf.Rd b/man/svm_rbf.Rd index d7b503207..e815d9e38 100644 --- a/man/svm_rbf.Rd +++ b/man/svm_rbf.Rd @@ -67,7 +67,11 @@ model, the template of the fit calls are:: \pkg{kernlab} classification +\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::svm_rbf(mode = "classification"), "kernlab")} + \pkg{kernlab} regression + +\Sexpr[results=rd]{parsnip:::show_fit(parsnip:::svm_rbf(mode = "regression"), "kernlab")} } \examples{ From 41307450a8e3bee48546569248f3497ece4219f5 Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Tue, 28 May 2019 08:11:05 -0400 Subject: [PATCH 24/64] changed some argument names to be a little more specific --- NAMESPACE | 1 + R/aaa_models.R | 73 +++++++++++++++++++++++++++------------ R/boost_tree_data.R | 68 ++++++++++++++++++------------------ R/decision_tree_data.R | 24 ++++++------- R/linear_reg_data.R | 16 ++++----- R/logistic_reg_data.R | 20 +++++------ R/mars_data.R | 12 +++---- R/mlp_data.R | 32 ++++++++--------- R/multinom_reg_data.R | 20 +++++------ R/nearest_neighbor_data.R | 12 +++---- R/rand_forest_data.R | 36 +++++++++---------- R/surv_reg_data.R | 8 ++--- R/svm_poly_data.R | 16 ++++----- R/svm_rbf_data.R | 12 +++---- man/check_mod_val.Rd | 48 ++++++++++++++++++++++--- man/get_model_env.Rd | 2 +- 16 files changed, 233 insertions(+), 167 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index ca7a96068..d0ee74cf0 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -110,6 +110,7 @@ export(multinom_reg) export(nearest_neighbor) export(null_model) export(nullmodel) +export(pred_types) export(predict.model_fit) export(rand_forest) export(rpart_train) diff --git a/R/aaa_models.R b/R/aaa_models.R index be79fe8f1..50c9a0161 100644 --- a/R/aaa_models.R +++ b/R/aaa_models.R @@ -30,6 +30,9 @@ parsnip$modes <- c("regression", "classification", "unknown") # ------------------------------------------------------------------------------ +#' @rdname check_mod_val +#' @keywords internal +#' @export pred_types <- c("raw", "numeric", "class", "link", "prob", "conf_int", "pred_int", "quantile") @@ -62,6 +65,30 @@ get_model_env <- function() { #' @param mode A single character string for the model mode (e.g. "regression"). #' @param eng A single character string for the model engine. #' @param arg A single character string for the model argument name. +#' @param has_submodel A single logical for whether the argument +#' can make predictions on multiple submodels at once.
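To make the renamed registration arguments concrete, here is a sketch of a call under the new scheme. It mirrors the `boost_tree()`/xgboost registration that appears later in this patch; `parsnip` replaces the old `val` argument and `has_submodel` replaces `submodels`.

    library(parsnip)

    set_model_arg(
      model = "boost_tree",      # registered model type
      eng = "xgboost",           # engine the argument belongs to
      parsnip = "trees",         # harmonized name exposed to users
      original = "nrounds",      # native argument name in the engine's fit function
      func = list(pkg = "dials", fun = "trees"),  # where the tuning object lives
      has_submodel = TRUE        # one fit can yield predictions for many submodels
    )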
+#' @param func A named character vector that describes how to call +#' a function. `func` should have elements `pkg` and `fun`. The +#' former is optional but is recommended and the latter is +#' required. For example, `c(pkg = "stats", fun = "lm")` would be +#' used to invoke the usual linear regression function. In some +#' cases, it is helpful to use `c(fun = "predict")` when using a +#' package's `predict` method. +#' @param fit_obj A list with elements `interface`, `protect`, +#' `func` and `defaults`. See the package vignette "Making a +#' `parsnip` model from scratch". +#' @param pred_obj A list with elements `pre`, `post`, `func`, and `args`. +#' See the package vignette "Making a `parsnip` model from scratch". +#' @param type A single character value for the type of prediction. Possible +#' values are: +#' \Sexpr[results=rd]{paste0("'", parsnip::pred_types, "'", collapse = ", ")}. +#' @param pkg An optional character string for a package name. +#' @param parsnip A single character string for the "harmonized" argument name +#' that `parsnip` exposes. +#' @param original A single character string for the argument name that +#' the underlying model function uses. +#' @param value A list that conforms to the `fit_obj` or `pred_obj` description +#' above, depending on context. #' @keywords internal #' @export check_mod_val <- function(model, new = FALSE, existence = FALSE) { @@ -122,8 +149,8 @@ check_arg_val <- function(arg) { #' @rdname check_mod_val #' @keywords internal #' @export -check_submodels_val <- function(x) { - if (!is.logical(x) || length(x) != 1) { +check_submodels_val <- function(has_submodel) { + if (!is.logical(has_submodel) || length(has_submodel) != 1) { stop("The `has_submodel` argument should be a single logical.", call. = FALSE) } invisible(NULL) @@ -169,31 +196,31 @@ check_func_val <- function(func) { #' @rdname check_mod_val #' @keywords internal #' @export -check_fit_info <- function(x) { - if (is.null(x)) { +check_fit_info <- function(fit_obj) { + if (is.null(fit_obj)) { stop("The `fit` module cannot be NULL.", call. = FALSE) } exp_nms <- c("defaults", "func", "interface", "protect") - if (!isTRUE(all.equal(sort(names(x)), exp_nms))) { + if (!isTRUE(all.equal(sort(names(fit_obj)), exp_nms))) { stop("The `fit` module should have elements: ", paste0("`", exp_nms, "`", collapse = ", "), call. = FALSE) } exp_interf <- c("data.frame", "formula", "matrix") - if (length(x$interface) > 1) { + if (length(fit_obj$interface) > 1) { stop("The `interface` element should have a single value of : ", paste0("`", exp_interf, "`", collapse = ", "), call. = FALSE) } - if (!any(x$interface == exp_interf)) { + if (!any(fit_obj$interface == exp_interf)) { stop("The `interface` element should have a value of : ", paste0("`", exp_interf, "`", collapse = ", "), call. = FALSE) } - check_func_val(x$func) + check_func_val(fit_obj$func) - if (!is.list(x$defaults)) { + if (!is.list(fit_obj$defaults)) { stop("The `defaults` element should be a list: ", call.
= FALSE) } @@ -203,7 +230,7 @@ check_fit_info <- function(x) { #' @rdname check_mod_val #' @keywords internal #' @export -check_pred_info <- function(x, type) { +check_pred_info <- function(pred_obj, type) { if (all(type != pred_types)) { stop("The prediction type should be one of: ", paste0("'", pred_types, "'", collapse = ", "), @@ -211,24 +238,24 @@ } exp_nms <- c("args", "func", "post", "pre") - if (!isTRUE(all.equal(sort(names(x)), exp_nms))) { + if (!isTRUE(all.equal(sort(names(pred_obj)), exp_nms))) { stop("The `predict` module should have elements: ", paste0("`", exp_nms, "`", collapse = ", "), call. = FALSE) } - if (!is.null(x$pre) & !is.function(x$pre)) { + if (!is.null(pred_obj$pre) & !is.function(pred_obj$pre)) { stop("The `pre` module should be null or a function: ", call. = FALSE) } - if (!is.null(x$post) & !is.function(x$post)) { + if (!is.null(pred_obj$post) & !is.function(pred_obj$post)) { stop("The `post` module should be null or a function: ", call. = FALSE) } - check_func_val(x$func) + check_func_val(pred_obj$func) - if (!is.list(x$args)) { + if (!is.list(pred_obj$args)) { stop("The `args` element should be a list. ", call. = FALSE) } @@ -238,8 +265,8 @@ #' @rdname check_mod_val #' @keywords internal #' @export -check_pkg_val <- function(x) { - if (is_missing(x) || length(x) != 1 || !is.character(x)) +check_pkg_val <- function(pkg) { + if (is_missing(pkg) || length(pkg) != 1 || !is.character(pkg)) stop("Please supply a single character value for the package name", call. = FALSE) invisible(NULL) @@ -333,12 +360,12 @@ set_model_engine <- function(model, mode, eng) { #' @rdname get_model_env #' @keywords internal #' @export -set_model_arg <- function(model, eng, val, original, func, submodels) { +set_model_arg <- function(model, eng, parsnip, original, func, has_submodel) { check_mod_val(model, existence = TRUE) - check_arg_val(val) + check_arg_val(parsnip) check_arg_val(original) check_func_val(func) - check_submodels_val(submodels) + check_submodels_val(has_submodel) current <- get_model_env() old_args <- current[[paste0(model, "_args")]] @@ -346,10 +373,10 @@ set_model_arg <- function(model, eng, val, original, func, submodels) { new_arg <- dplyr::tibble( engine = eng, - parsnip = val, + parsnip = parsnip, original = original, func = list(func), - submodels = submodels + has_submodel = has_submodel ) # TODO can't currently use `distinct()` on a list column. @@ -359,7 +386,7 @@ stop("An error occurred when adding the new argument.", call.
= FALSE) } - updated <- dplyr::distinct(updated, engine, parsnip, original, submodels) + updated <- dplyr::distinct(updated, engine, parsnip, original, has_submodel) current[[paste0(model, "_args")]] <- updated diff --git a/R/boost_tree_data.R b/R/boost_tree_data.R index 01815f313..98356f96b 100644 --- a/R/boost_tree_data.R +++ b/R/boost_tree_data.R @@ -12,58 +12,58 @@ set_dependency("boost_tree", "xgboost", "xgboost") set_model_arg( model = "boost_tree", eng = "xgboost", - val = "tree_depth", + parsnip = "tree_depth", original = "max_depth", func = list(pkg = "dials", fun = "tree_depth"), - submodels = FALSE + has_submodel = FALSE ) set_model_arg( model = "boost_tree", eng = "xgboost", - val = "trees", + parsnip = "trees", original = "nrounds", func = list(pkg = "dials", fun = "trees"), - submodels = TRUE + has_submodel = TRUE ) set_model_arg( model = "boost_tree", eng = "xgboost", - val = "learn_rate", + parsnip = "learn_rate", original = "eta", func = list(pkg = "dials", fun = "learn_rate"), - submodels = FALSE + has_submodel = FALSE ) set_model_arg( model = "boost_tree", eng = "xgboost", - val = "mtry", + parsnip = "mtry", original = "colsample_bytree", func = list(pkg = "dials", fun = "mtry"), - submodels = FALSE + has_submodel = FALSE ) set_model_arg( model = "boost_tree", eng = "xgboost", - val = "min_n", + parsnip = "min_n", original = "min_child_weight", func = list(pkg = "dials", fun = "min_n"), - submodels = FALSE + has_submodel = FALSE ) set_model_arg( model = "boost_tree", eng = "xgboost", - val = "loss_reduction", + parsnip = "loss_reduction", original = "gamma", func = list(pkg = "dials", fun = "loss_reduction"), - submodels = FALSE + has_submodel = FALSE ) set_model_arg( model = "boost_tree", eng = "xgboost", - val = "sample_size", + parsnip = "sample_size", original = "subsample", func = list(pkg = "dials", fun = "sample_size"), - submodels = FALSE + has_submodel = FALSE ) set_fit( @@ -178,26 +178,26 @@ set_dependency("boost_tree", "C5.0", "C50") set_model_arg( model = "boost_tree", eng = "C5.0", - val = "trees", + parsnip = "trees", original = "trials", func = list(pkg = "dials", fun = "trees"), - submodels = TRUE + has_submodel = TRUE ) set_model_arg( model = "boost_tree", eng = "C5.0", - val = "min_n", + parsnip = "min_n", original = "minCases", func = list(pkg = "dials", fun = "min_n"), - submodels = FALSE + has_submodel = FALSE ) set_model_arg( model = "boost_tree", eng = "C5.0", - val = "sample_size", + parsnip = "sample_size", original = "sample", func = list(pkg = "dials", fun = "sample_size"), - submodels = FALSE + has_submodel = FALSE ) set_fit( @@ -268,58 +268,58 @@ set_dependency("boost_tree", "spark", "sparklyr") set_model_arg( model = "boost_tree", eng = "spark", - val = "tree_depth", + parsnip = "tree_depth", original = "max_depth", func = list(pkg = "dials", fun = "tree_depth"), - submodels = FALSE + has_submodel = FALSE ) set_model_arg( model = "boost_tree", eng = "spark", - val = "trees", + parsnip = "trees", original = "max_iter", func = list(pkg = "dials", fun = "trees"), - submodels = TRUE + has_submodel = TRUE ) set_model_arg( model = "boost_tree", eng = "spark", - val = "learn_rate", + parsnip = "learn_rate", original = "step_size", func = list(pkg = "dials", fun = "learn_rate"), - submodels = FALSE + has_submodel = FALSE ) set_model_arg( model = "boost_tree", eng = "spark", - val = "mtry", + parsnip = "mtry", original = "feature_subset_strategy", func = list(pkg = "dials", fun = "mtry"), - submodels = FALSE + has_submodel = FALSE ) set_model_arg( 
model = "boost_tree", eng = "spark", - val = "min_n", + parsnip = "min_n", original = "min_instances_per_node", func = list(pkg = "dials", fun = "min_n"), - submodels = FALSE + has_submodel = FALSE ) set_model_arg( model = "boost_tree", eng = "spark", - val = "min_info_gain", + parsnip = "min_info_gain", original = "gamma", func = list(pkg = "dials", fun = "loss_reduction"), - submodels = FALSE + has_submodel = FALSE ) set_model_arg( model = "boost_tree", eng = "spark", - val = "sample_size", + parsnip = "sample_size", original = "subsampling_rate", func = list(pkg = "dials", fun = "sample_size"), - submodels = FALSE + has_submodel = FALSE ) set_fit( diff --git a/R/decision_tree_data.R b/R/decision_tree_data.R index 307ce6b35..a8e8016e4 100644 --- a/R/decision_tree_data.R +++ b/R/decision_tree_data.R @@ -12,28 +12,28 @@ set_dependency("decision_tree", "rpart", "rpart") set_model_arg( model = "decision_tree", eng = "rpart", - val = "tree_depth", + parsnip = "tree_depth", original = "maxdepth", func = list(pkg = "dials", fun = "tree_depth"), - submodels = FALSE + has_submodel = FALSE ) set_model_arg( model = "decision_tree", eng = "rpart", - val = "min_n", + parsnip = "min_n", original = "minsplit", func = list(pkg = "dials", fun = "min_n"), - submodels = FALSE + has_submodel = FALSE ) set_model_arg( model = "decision_tree", eng = "rpart", - val = "cost_complexity", + parsnip = "cost_complexity", original = "cp", func = list(pkg = "dials", fun = "cost_complexity"), - submodels = FALSE + has_submodel = FALSE ) set_fit( @@ -140,10 +140,10 @@ set_dependency("decision_tree", "C5.0", "C5.0") set_model_arg( model = "decision_tree", eng = "C5.0", - val = "min_n", + parsnip = "min_n", original = "minCases", func = list(pkg = "dials", fun = "min_n"), - submodels = FALSE + has_submodel = FALSE ) set_fit( @@ -216,19 +216,19 @@ set_dependency("decision_tree", "spark", "spark") set_model_arg( model = "decision_tree", eng = "spark", - val = "tree_depth", + parsnip = "tree_depth", original = "max_depth", func = list(pkg = "dials", fun = "tree_depth"), - submodels = FALSE + has_submodel = FALSE ) set_model_arg( model = "decision_tree", eng = "spark", - val = "min_n", + parsnip = "min_n", original = "min_instances_per_node", func = list(pkg = "dials", fun = "min_n"), - submodels = FALSE + has_submodel = FALSE ) set_fit( diff --git a/R/linear_reg_data.R b/R/linear_reg_data.R index 8779ad115..54dbd3338 100644 --- a/R/linear_reg_data.R +++ b/R/linear_reg_data.R @@ -105,19 +105,19 @@ set_dependency("linear_reg", "glmnet", "glmnet") set_model_arg( model = "linear_reg", eng = "glmnet", - val = "penalty", + parsnip = "penalty", original = "lambda", func = list(pkg = "dials", fun = "penalty"), - submodels = TRUE + has_submodel = TRUE ) set_model_arg( model = "linear_reg", eng = "glmnet", - val = "mixture", + parsnip = "mixture", original = "alpha", func = list(pkg = "dials", fun = "mixture"), - submodels = FALSE + has_submodel = FALSE ) set_fit( @@ -290,19 +290,19 @@ set_dependency("linear_reg", "spark", "sparklyr") set_model_arg( model = "linear_reg", eng = "spark", - val = "penalty", + parsnip = "penalty", original = "reg_param", func = list(pkg = "dials", fun = "penalty"), - submodels = TRUE + has_submodel = TRUE ) set_model_arg( model = "linear_reg", eng = "spark", - val = "mixture", + parsnip = "mixture", original = "elastic_net_param", func = list(pkg = "dials", fun = "mixture"), - submodels = FALSE + has_submodel = FALSE ) diff --git a/R/logistic_reg_data.R b/R/logistic_reg_data.R index ae581aad3..59645a4b6 
100644 --- a/R/logistic_reg_data.R +++ b/R/logistic_reg_data.R @@ -124,19 +124,19 @@ set_dependency("logistic_reg", "glmnet", "glmnet") set_model_arg( model = "logistic_reg", eng = "glmnet", - val = "penalty", + parsnip = "penalty", original = "lambda", func = list(pkg = "dials", fun = "penalty"), - submodels = TRUE + has_submodel = TRUE ) set_model_arg( model = "logistic_reg", eng = "glmnet", - val = "mixture", + parsnip = "mixture", original = "alpha", func = list(pkg = "dials", fun = "mixture"), - submodels = FALSE + has_submodel = FALSE ) set_fit( @@ -215,19 +215,19 @@ set_dependency("logistic_reg", "spark", "sparklyr") set_model_arg( model = "logistic_reg", eng = "spark", - val = "penalty", + parsnip = "penalty", original = "reg_param", func = list(pkg = "dials", fun = "penalty"), - submodels = TRUE + has_submodel = TRUE ) set_model_arg( model = "logistic_reg", eng = "spark", - val = "mixture", + parsnip = "mixture", original = "elastic_net_param", func = list(pkg = "dials", fun = "mixture"), - submodels = FALSE + has_submodel = FALSE ) set_fit( @@ -289,10 +289,10 @@ set_dependency("logistic_reg", "keras", "magrittr") set_model_arg( model = "logistic_reg", eng = "keras", - val = "decay", + parsnip = "decay", original = "decay", func = list(pkg = "dials", fun = "weight_decay"), - submodels = FALSE + has_submodel = FALSE ) set_fit( diff --git a/R/mars_data.R b/R/mars_data.R index 98dc59133..a4a84e268 100644 --- a/R/mars_data.R +++ b/R/mars_data.R @@ -13,26 +13,26 @@ set_dependency("mars", "earth", "earth") set_model_arg( model = "mars", eng = "earth", - val = "num_terms", + parsnip = "num_terms", original = "nprune", func = list(pkg = "dials", fun = "num_terms"), - submodels = FALSE + has_submodel = FALSE ) set_model_arg( model = "mars", eng = "earth", - val = "prod_degree", + parsnip = "prod_degree", original = "degree", func = list(pkg = "dials", fun = "prod_degree"), - submodels = FALSE + has_submodel = FALSE ) set_model_arg( model = "mars", eng = "earth", - val = "prune_method", + parsnip = "prune_method", original = "pmethod", func = list(pkg = "dials", fun = "prune_method"), - submodels = FALSE + has_submodel = FALSE ) set_fit( diff --git a/R/mlp_data.R b/R/mlp_data.R index a08fe51b9..4df47377a 100644 --- a/R/mlp_data.R +++ b/R/mlp_data.R @@ -14,42 +14,42 @@ set_dependency("mlp", "keras", "magrittr") set_model_arg( model = "mlp", eng = "keras", - val = "hidden_units", + parsnip = "hidden_units", original = "hidden_units", func = list(pkg = "dials", fun = "hidden_units"), - submodels = FALSE + has_submodel = FALSE ) set_model_arg( model = "mlp", eng = "keras", - val = "penalty", + parsnip = "penalty", original = "penalty", func = list(pkg = "dials", fun = "weight_decay"), - submodels = FALSE + has_submodel = FALSE ) set_model_arg( model = "mlp", eng = "keras", - val = "dropout", + parsnip = "dropout", original = "dropout", func = list(pkg = "dials", fun = "dropout"), - submodels = FALSE + has_submodel = FALSE ) set_model_arg( model = "mlp", eng = "keras", - val = "epochs", + parsnip = "epochs", original = "epochs", func = list(pkg = "dials", fun = "epochs"), - submodels = FALSE + has_submodel = FALSE ) set_model_arg( model = "mlp", eng = "keras", - val = "activation", + parsnip = "activation", original = "activation", func = list(pkg = "dials", fun = "activation"), - submodels = FALSE + has_submodel = FALSE ) @@ -178,26 +178,26 @@ set_dependency("mlp", "nnet", "nnet") set_model_arg( model = "mlp", eng = "nnet", - val = "hidden_units", + parsnip = "hidden_units", original = "size", 
func = list(pkg = "dials", fun = "hidden_units"), - submodels = FALSE + has_submodel = FALSE ) set_model_arg( model = "mlp", eng = "nnet", - val = "penalty", + parsnip = "penalty", original = "decay", func = list(pkg = "dials", fun = "weight_decay"), - submodels = FALSE + has_submodel = FALSE ) set_model_arg( model = "mlp", eng = "nnet", - val = "epochs", + parsnip = "epochs", original = "maxit", func = list(pkg = "dials", fun = "epochs"), - submodels = FALSE + has_submodel = FALSE ) set_fit( model = "mlp", diff --git a/R/multinom_reg_data.R b/R/multinom_reg_data.R index 5116d9900..b87bf022a 100644 --- a/R/multinom_reg_data.R +++ b/R/multinom_reg_data.R @@ -10,19 +10,19 @@ set_dependency("multinom_reg", "glmnet", "glmnet") set_model_arg( model = "multinom_reg", eng = "glmnet", - val = "penalty", + parsnip = "penalty", original = "lambda", func = list(pkg = "dials", fun = "penalty"), - submodels = TRUE + has_submodel = TRUE ) set_model_arg( model = "multinom_reg", eng = "glmnet", - val = "mixture", + parsnip = "mixture", original = "alpha", func = list(pkg = "dials", fun = "mixture"), - submodels = FALSE + has_submodel = FALSE ) set_fit( @@ -101,19 +101,19 @@ set_dependency("multinom_reg", "spark", "sparklyr") set_model_arg( model = "multinom_reg", eng = "spark", - val = "penalty", + parsnip = "penalty", original = "reg_param", func = list(pkg = "dials", fun = "penalty"), - submodels = TRUE + has_submodel = TRUE ) set_model_arg( model = "multinom_reg", eng = "spark", - val = "mixture", + parsnip = "mixture", original = "elastic_net_param", func = list(pkg = "dials", fun = "mixture"), - submodels = FALSE + has_submodel = FALSE ) set_fit( @@ -173,10 +173,10 @@ set_dependency("multinom_reg", "keras", "magrittr") set_model_arg( model = "multinom_reg", eng = "keras", - val = "decay", + parsnip = "decay", original = "decay", func = list(pkg = "dials", fun = "weight_decay"), - submodels = FALSE + has_submodel = FALSE ) diff --git a/R/nearest_neighbor_data.R b/R/nearest_neighbor_data.R index dcac57748..5c5f091ec 100644 --- a/R/nearest_neighbor_data.R +++ b/R/nearest_neighbor_data.R @@ -13,26 +13,26 @@ set_dependency("nearest_neighbor", "kknn", "kknn") set_model_arg( model = "nearest_neighbor", eng = "kknn", - val = "neighbors", + parsnip = "neighbors", original = "ks", func = list(pkg = "dials", fun = "neighbors"), - submodels = FALSE + has_submodel = FALSE ) set_model_arg( model = "nearest_neighbor", eng = "kknn", - val = "weight_func", + parsnip = "weight_func", original = "kernel", func = list(pkg = "dials", fun = "weight_func"), - submodels = FALSE + has_submodel = FALSE ) set_model_arg( model = "nearest_neighbor", eng = "kknn", - val = "dist_power", + parsnip = "dist_power", original = "distance", func = list(pkg = "dials", fun = "distance"), - submodels = FALSE + has_submodel = FALSE ) set_fit( diff --git a/R/rand_forest_data.R b/R/rand_forest_data.R index 2f9d336c9..fd258f003 100644 --- a/R/rand_forest_data.R +++ b/R/rand_forest_data.R @@ -89,26 +89,26 @@ set_dependency("rand_forest", "ranger", "ranger") set_model_arg( model = "rand_forest", eng = "ranger", - val = "mtry", + parsnip = "mtry", original = "mtry", func = list(pkg = "dials", fun = "mtry"), - submodels = FALSE + has_submodel = FALSE ) set_model_arg( model = "rand_forest", eng = "ranger", - val = "trees", + parsnip = "trees", original = "num.trees", func = list(pkg = "dials", fun = "trees"), - submodels = FALSE + has_submodel = FALSE ) set_model_arg( model = "rand_forest", eng = "ranger", - val = "min_n", + parsnip = "min_n", 
original = "min.node.size", func = list(pkg = "dials", fun = "min_n"), - submodels = FALSE + has_submodel = FALSE ) set_fit( @@ -262,26 +262,26 @@ set_dependency("rand_forest", "randomForest", "randomForest") set_model_arg( model = "rand_forest", eng = "randomForest", - val = "mtry", + parsnip = "mtry", original = "mtry", func = list(pkg = "dials", fun = "mtry"), - submodels = FALSE + has_submodel = FALSE ) set_model_arg( model = "rand_forest", eng = "randomForest", - val = "trees", + parsnip = "trees", original = "ntree", func = list(pkg = "dials", fun = "trees"), - submodels = FALSE + has_submodel = FALSE ) set_model_arg( model = "rand_forest", eng = "randomForest", - val = "min_n", + parsnip = "min_n", original = "nodesize", func = list(pkg = "dials", fun = "min_n"), - submodels = FALSE + has_submodel = FALSE ) set_fit( @@ -399,26 +399,26 @@ set_dependency("rand_forest", "spark", "sparklyr") set_model_arg( model = "rand_forest", eng = "spark", - val = "mtry", + parsnip = "mtry", original = "feature_subset_strategy", func = list(pkg = "dials", fun = "mtry"), - submodels = FALSE + has_submodel = FALSE ) set_model_arg( model = "rand_forest", eng = "spark", - val = "trees", + parsnip = "trees", original = "num_trees", func = list(pkg = "dials", fun = "trees"), - submodels = FALSE + has_submodel = FALSE ) set_model_arg( model = "rand_forest", eng = "spark", - val = "min_n", + parsnip = "min_n", original = "min_instances_per_node", func = list(pkg = "dials", fun = "min_n"), - submodels = FALSE + has_submodel = FALSE ) set_fit( diff --git a/R/surv_reg_data.R b/R/surv_reg_data.R index 18868f823..e52fdb2f9 100644 --- a/R/surv_reg_data.R +++ b/R/surv_reg_data.R @@ -11,10 +11,10 @@ set_dependency("surv_reg", "flexsurv", "survival") set_model_arg( model = "surv_reg", eng = "flexsurv", - val = "dist", + parsnip = "dist", original = "dist", func = list(pkg = "dials", fun = "dist"), - submodels = FALSE + has_submodel = FALSE ) set_fit( @@ -74,10 +74,10 @@ set_dependency("surv_reg", "survival", "survival") set_model_arg( model = "surv_reg", eng = "survival", - val = "dist", + parsnip = "dist", original = "dist", func = list(pkg = "dials", fun = "dist"), - submodels = FALSE + has_submodel = FALSE ) set_fit( diff --git a/R/svm_poly_data.R b/R/svm_poly_data.R index 8fc2da66d..565c52d87 100644 --- a/R/svm_poly_data.R +++ b/R/svm_poly_data.R @@ -12,36 +12,36 @@ set_dependency("svm_poly", "kernlab", "kernlab") set_model_arg( model = "svm_poly", eng = "kernlab", - val = "cost", + parsnip = "cost", original = "C", func = list(pkg = "dials", fun = "cost"), - submodels = FALSE + has_submodel = FALSE ) set_model_arg( model = "svm_poly", eng = "kernlab", - val = "degree", + parsnip = "degree", original = "degree", func = list(pkg = "dials", fun = "degree"), - submodels = FALSE + has_submodel = FALSE ) set_model_arg( model = "svm_poly", eng = "kernlab", - val = "scale_factor", + parsnip = "scale_factor", original = "scale", func = list(pkg = "dials", fun = "scale_factor"), - submodels = FALSE + has_submodel = FALSE ) set_model_arg( model = "svm_poly", eng = "kernlab", - val = "margin", + parsnip = "margin", original = "epsilon", func = list(pkg = "dials", fun = "margin"), - submodels = FALSE + has_submodel = FALSE ) set_fit( diff --git a/R/svm_rbf_data.R b/R/svm_rbf_data.R index d489222ae..3c905e856 100644 --- a/R/svm_rbf_data.R +++ b/R/svm_rbf_data.R @@ -12,28 +12,28 @@ set_dependency("svm_rbf", "kernlab", "kernlab") set_model_arg( model = "svm_rbf", eng = "kernlab", - val = "cost", + parsnip = "cost", original = 
"C", func = list(pkg = "dials", fun = "cost"), - submodels = FALSE + has_submodel = FALSE ) set_model_arg( model = "svm_rbf", eng = "kernlab", - val = "rbf_sigma", + parsnip = "rbf_sigma", original = "sigma", func = list(pkg = "dials", fun = "rbf_sigma"), - submodels = FALSE + has_submodel = FALSE ) set_model_arg( model = "svm_rbf", eng = "kernlab", - val = "margin", + parsnip = "margin", original = "epsilon", func = list(pkg = "dials", fun = "margin"), - submodels = FALSE + has_submodel = FALSE ) set_fit( diff --git a/man/check_mod_val.Rd b/man/check_mod_val.Rd index 1b4690f40..e46103fc1 100644 --- a/man/check_mod_val.Rd +++ b/man/check_mod_val.Rd @@ -1,6 +1,8 @@ % Generated by roxygen2: do not edit by hand % Please edit documentation in R/aaa_models.R -\name{check_mod_val} +\docType{data} +\name{pred_types} +\alias{pred_types} \alias{check_mod_val} \alias{check_mode_val} \alias{check_engine_val} @@ -11,7 +13,10 @@ \alias{check_pred_info} \alias{check_pkg_val} \title{Tools to Check Model Elements} +\format{An object of class \code{character} of length 8.} \usage{ +pred_types + check_mod_val(model, new = FALSE, existence = FALSE) check_mode_val(mode) @@ -20,15 +25,15 @@ check_engine_val(eng) check_arg_val(arg) -check_submodels_val(x) +check_submodels_val(has_submodel) check_func_val(func) -check_fit_info(x) +check_fit_info(fit_obj) -check_pred_info(x, type) +check_pred_info(pred_obj, type) -check_pkg_val(x) +check_pkg_val(pkg) } \arguments{ \item{model}{A single character string for the model type (e.g. @@ -45,6 +50,39 @@ been registered.} \item{eng}{A single character string for the model engine.} \item{arg}{A single character string for the model argument name.} + +\item{has_submodel}{A single logical for whether the argument +can make predictions on mutiple submodels at once.} + +\item{func}{A named character vector that describes how to call +a function. \code{func} should have elements \code{pkg} and \code{fun}. The +former is optional but is recommended and the latter is +required. For example, \code{c(pkg = "stats", fun = "lm")} would be +used to invoke the usual linear regression function. In some +cases, it is helpful to use \code{c(fun = "predict")} when using a +package's \code{predict} method.} + +\item{fit_obj}{A list with elements \code{interface}, \code{protect}, +\code{func} and \code{defaults}. See the package vignette "Making a +\code{parsnip} model from scratch".} + +\item{pred_obj}{A list with elements \code{pre}, \code{post}, \code{func}, and \code{args}. +See the package vignette "Making a \code{parsnip} model from scratch".} + +\item{type}{A single character value for the type of prediction. 
Possible +values are: +\Sexpr[results=rd]{paste0("'", parsnip::pred_types, "'", collapse = ", ")}.} + +\item{pkg}{An optional character string for a package name.} + +\item{parsnip}{A single character string for the "harmonized" argument name +that \code{parsnip} exposes.} + +\item{original}{A single character string for the argument name that +the underlying model function uses.} + +\item{value}{A list that conforms to the \code{fit_obj} or \code{pred_obj} description +above, depending on context.} } \description{ These functions are similar to constructors and can be used to validate diff --git a/man/get_model_env.Rd b/man/get_model_env.Rd index 978dff676..76ff615f2 100644 --- a/man/get_model_env.Rd +++ b/man/get_model_env.Rd @@ -23,7 +23,7 @@ set_model_mode(model, mode) set_model_engine(model, mode, eng) -set_model_arg(model, eng, val, original, func, submodels) +set_model_arg(model, eng, parsnip, original, func, has_submodel) set_dependency(model, eng, pkg) From d0aaee9b8358399fcd14da2cd016a79407003535 Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Thu, 30 May 2019 11:11:31 -0700 Subject: [PATCH 25/64] better model summary code and examples --- R/aaa_models.R | 51 ++++++++++++++++++++++++++++++++++++-------- man/check_mod_val.Rd | 10 +++++++++ 2 files changed, 52 insertions(+), 9 deletions(-) diff --git a/R/aaa_models.R b/R/aaa_models.R index 50c9a0161..665c4ed7b 100644 --- a/R/aaa_models.R +++ b/R/aaa_models.R @@ -90,9 +90,18 @@ get_model_env <- function() { #' @param value A list that conforms to the `fit_obj` or `pred_obj` description #' above, depending on context. #' @keywords internal +#' @examples +#' # Show the infomration about a model: +#' show_model_info("rand_forest") +#' +#' # Access the model data: +#' +#' current_code <- get_model_env() +#' ls(envir = current_code) +#' #' @export check_mod_val <- function(model, new = FALSE, existence = FALSE) { - if (is_missing(model) || length(model) != 1) + if (rlang::is_missing(model) || length(model) != 1) stop("Please supply a character string for a model name (e.g. `'linear_reg'`)", call. = FALSE) @@ -120,7 +129,7 @@ check_mod_val <- function(model, new = FALSE, existence = FALSE) { #' @keywords internal #' @export check_mode_val <- function(mode) { - if (is_missing(mode) || length(mode) != 1) + if (rlang::is_missing(mode) || length(mode) != 1) stop("Please supply a character string for a mode (e.g. `'regression'`)", call. = FALSE) invisible(NULL) @@ -130,7 +139,7 @@ #' @keywords internal #' @export check_engine_val <- function(eng) { - if (is_missing(eng) || length(eng) != 1) + if (rlang::is_missing(eng) || length(eng) != 1) stop("Please supply a character string for an engine (e.g. `'lm'`)", call. = FALSE) invisible(NULL) @@ -140,7 +149,7 @@ #' @keywords internal #' @export check_arg_val <- function(arg) { - if (is_missing(arg) || length(arg) != 1) + if (rlang::is_missing(arg) || length(arg) != 1) stop("Please supply a character string for the argument", call. = FALSE) invisible(NULL) @@ -166,7 +175,7 @@ check_func_val <- function(func) { "element 'pkg'. These should both be single character strings." ) - if (is_missing(func) || !is.vector(func) || length(func) > 2) + if (rlang::is_missing(func) || !is.vector(func) || length(func) > 2) stop(msg, call. 
= FALSE) nms <- sort(names(func)) @@ -266,7 +275,7 @@ check_pred_info <- function(pred_obj, type) { #' @keywords internal #' @export check_pkg_val <- function(pkg) { - if (is_missing(pkg) || length(pkg) != 1 || !is.character(pkg)) + if (rlang::is_missing(pkg) || length(pkg) != 1 || !is.character(pkg)) stop("Please supply a single character vale for the package name", call. = FALSE) invisible(NULL) @@ -609,7 +618,7 @@ show_model_info <- function(model) { cat( " modes:", paste0(current[[paste0(model, "_modes")]], collapse = ", "), - "\n" + "\n\n" ) engines <- current[[paste0(model)]] @@ -629,6 +638,7 @@ show_model_info <- function(model) { dplyr::ungroup() %>% dplyr::pull(lab) %>% cat(sep = "") + cat("\n") } else { cat(" no registered engines yet.") } @@ -652,17 +662,40 @@ show_model_info <- function(model) { dplyr::ungroup() %>% dplyr::pull(lab) %>% cat(sep = "") + cat("\n") } else { cat(" no registered arguments yet.") } - fits <- current[[paste0(model, "_fits")]] + fits <- current[[paste0(model, "_fit")]] if (nrow(fits) > 0) { - + cat(" fit modules:\n") + fits %>% + dplyr::select(-value) %>% + mutate(engine = paste0(" ", engine)) %>% + as.data.frame() %>% + print(row.names = FALSE) + cat("\n") } else { cat(" no registered fit modules yet.") } + preds <- current[[paste0(model, "_predict")]] + if (nrow(preds) > 0) { + cat(" prediction modules:\n") + preds %>% + dplyr::group_by(mode, engine) %>% + dplyr::summarize(methods = paste0(sort(type), collapse = ", ")) %>% + dplyr::ungroup() %>% + mutate(mode = paste0(" ", mode)) %>% + as.data.frame() %>% + print(row.names = FALSE) + cat("\n") + } else { + cat(" no registered prediction modules yet.") + } + + invisible(NULL) } diff --git a/man/check_mod_val.Rd b/man/check_mod_val.Rd index e46103fc1..f998791a2 100644 --- a/man/check_mod_val.Rd +++ b/man/check_mod_val.Rd @@ -88,5 +88,15 @@ above, depending on context.} These functions are similar to constructors and can be used to validate that there are no conflicts with the underlying model structures used by the package. +} +\examples{ +# Show the infomration about a model: +show_model_info("rand_forest") + +# Access the model data: + +current_code <- get_model_env() +ls(envir = current_code) + } \keyword{internal} From 75f5c47a3b9673d419192f3754a92ef32935c20b Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Thu, 30 May 2019 11:23:04 -0700 Subject: [PATCH 26/64] a little more documentation --- R/aaa_models.R | 15 ++++++++++++++- man/check_mod_val.Rd | 19 ++++++++++++++++++- 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/R/aaa_models.R b/R/aaa_models.R index 665c4ed7b..546419f0f 100644 --- a/R/aaa_models.R +++ b/R/aaa_models.R @@ -90,8 +90,21 @@ get_model_env <- function() { #' @param value A list that conforms to the `fit_obj` or `pred_obj` description #' above, depending on context. #' @keywords internal +#' @details These functions are available for users to add their +#' own models or engines (in package or otherwise) so that they can +#' be accessed using `parsnip`. These are more thoroughly documented +#' on the package web site (see references below). +#' +#' In short, `parsnip` stores an environment object that contains +#' all of the information and code about how models are used (e.g. +#' fitting, predicting, etc). These functions can be used to add +#' models to that environment as well as helper functions that can +#' be used to make sure that the model data is in the right +#' format. 
+#' @references "Making a parsnip model from scratch" +#' \url{https://tidymodels.github.io/parsnip/articles/articles/Scratch.html} #' @examples -#' # Show the infomration about a model: +#' # Show the information about a model: #' show_model_info("rand_forest") #' #' # Access the model data: diff --git a/man/check_mod_val.Rd b/man/check_mod_val.Rd index f998791a2..7779985dc 100644 --- a/man/check_mod_val.Rd +++ b/man/check_mod_val.Rd @@ -89,8 +89,21 @@ These functions are similar to constructors and can be used to validate that there are no conflicts with the underlying model structures used by the package. } +\details{ +These functions are available for users to add their +own models or engines (in package or otherwise) so that they can +be accessed using \code{parsnip}. These are more thoroughly documented +on the package web site (see references below). + +In short, \code{parsnip} stores an environment object that contains +all of the information and code about how models are used (e.g. +fitting, predicting, etc). These functions can be used to add +models to that environment as well as helper functions that can +be used to make sure that the model data is in the right +format. +} \examples{ -# Show the infomration about a model: +# Show the information about a model: show_model_info("rand_forest") # Access the model data: @@ -98,5 +111,9 @@ show_model_info("rand_forest") current_code <- get_model_env() ls(envir = current_code) +} +\references{ +"Making a parsnip model from scratch" +\url{https://tidymodels.github.io/parsnip/articles/articles/Scratch.html} } \keyword{internal} From ef96a468faf6053a27d928bea1247b1ae103cd59 Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Thu, 30 May 2019 20:57:04 -0700 Subject: [PATCH 27/64] updated vignette for custom models --- DESCRIPTION | 2 +- NEWS.md | 4 +- R/aaa_models.R | 9 +- vignettes/articles/Scratch.Rmd | 282 ++++++++++++++++----------------- 4 files changed, 145 insertions(+), 152 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 8de5c838c..55691537e 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,5 +1,5 @@ Package: parsnip -Version: 0.0.2 +Version: 0.0.2.9000 Title: A Common API to Modeling and Analysis Functions Description: A common interface is provided to allow users to specify a model without having to remember the different argument names across different functions or computational engines (e.g. 'R', 'Spark', 'Stan', etc). Authors@R: c( diff --git a/NEWS.md b/NEWS.md index 47461648d..b135f4295 100644 --- a/NEWS.md +++ b/NEWS.md @@ -2,9 +2,9 @@ ## Breaking Changes - * The method that `parsnip` stores the model information has changed. Any custom models from previous versions will need to use the new method for registering models. + * The method that `parsnip` uses to store the model information has changed. Any custom models from previous versions will need to use the new method for registering models. The methods are detailed in `?get_model_env()` and the [package vignette for adding models](https://tidymodels.github.io/parsnip/articles/articles/Scratch.html). * The mode needs to be declared for models that can be used for more than one mode prior to fitting and/or translation. - * For `surv_reg()`, the engine that uses the `survival` package is now called `survival` instead of `survreg`. + * For `surv_reg()`, the engine that uses the `survival` package is now called `survival` instead of `survreg`. 
## New Features

diff --git a/R/aaa_models.R b/R/aaa_models.R index 546419f0f..a2cf0fde8 100644 --- a/R/aaa_models.R +++ b/R/aaa_models.R @@ -108,7 +108,6 @@ get_model_env <- function() { #' show_model_info("rand_forest") #' #' # Access the model data: -#' current_code <- get_model_env() #' ls(envir = current_code) #' @@ -653,7 +652,7 @@ show_model_info <- function(model) { cat(sep = "") cat("\n") } else { - cat(" no registered engines yet.") + cat(" no registered engines.\n\n") } args <- current[[paste0(model, "_args")]] @@ -677,7 +676,7 @@ show_model_info <- function(model) { cat(sep = "") cat("\n") } else { - cat(" no registered arguments yet.") + cat(" no registered arguments.\n\n") } fits <- current[[paste0(model, "_fit")]] @@ -690,7 +689,7 @@ show_model_info <- function(model) { print(row.names = FALSE) cat("\n") } else { - cat(" no registered fit modules yet.") + cat(" no registered fit modules.\n\n") } preds <- current[[paste0(model, "_predict")]] @@ -705,7 +704,7 @@ show_model_info <- function(model) { print(row.names = FALSE) cat("\n") } else { - cat(" no registered prediction modules yet.") + cat(" no registered prediction modules.\n\n") } diff --git a/vignettes/articles/Scratch.Rmd b/vignettes/articles/Scratch.Rmd index ecccb90d6..81e0282e3 100644 --- a/vignettes/articles/Scratch.Rmd +++ b/vignettes/articles/Scratch.Rmd @@ -1,8 +1,8 @@ --- -title: Making a `parsnip` model from scratch +title: Making a parsnip model from scratch vignette: > %\VignetteEngine{knitr::rmarkdown} - %\VignetteIndexEntry{Making a `parsnip` model from scratch} + %\VignetteIndexEntry{Making a parsnip model from scratch} output: knitr:::html_vignette: toc: yes @@ -28,7 +28,7 @@ library(mda) * It eliminates a lot of duplicate code. * Since the expressions are not evaluated until fitting, it eliminates a large amount of package dependencies. -A `parsnip` model function is itself very general. For example, the `logistic_reg` function itself doesn't have any model code within it. Instead, each model function is associated with one or more computational _engines_. These might be different R packages or some function in another language (that can be evaluated by R). +A `parsnip` model function is itself very general. For example, the `logistic_reg()` function itself doesn't have any model code within it. Instead, each model function is associated with one or more computational _engines_. These might be different R packages or some function in another language (that can be evaluated by R). This vignette describes the process of creating a new model function. Before proceeding, take a minute and read our [guidelines on creating modeling packages](https://tidymodels.github.io/model-implementation-principles/) to get the general themes and conventions that we use. @@ -40,59 +40,76 @@ str(mda::mda) The main hyper-parameter is the number of subclasses. We'll name our function `mixture_da`. -## Step 1. Make the objects for the general method +## Aspects of Models -There are three objects that define the parameters and other characteristics of the model function. +Before proceeding, it helps to review how `parsnip` categorizes models: -First, is the object that describes the model's mode(s). The modes are the type of model and the two main values are "classification" and "regression". A third mode, "unknown", is used for initializing objects but models will fail if it is used further. +* The model _type_ is related to the structural aspect of the model. 
For example, the model type `linear_reg` represents linear models (slopes and intercepts) that model a numeric outcome. Other model types in the package are `nearest_neighbor`, `decision_tree`, and so on. -The convention in `parsnip` is to use the name `{model name}_modes`. In our case, we have: +* Within a model type is the _mode_. This relates to the modeling goal. Currently the two modes in the package are "regression" and "classification". Some models have methods for both modes (e.g. nearest neighbors) while others are specific to a single mode (e.g. logistic regression). -```{r modes} -mixture_da_modes <- c("classification", "unknown") -``` +* The computation _engine_ is a combination of the estimation method and the implementation. For example, for linear regression, one engine is `"lm"` and this uses ordinary least squares analysis via the `lm()` function. Another engine is `"stan"` which uses the Stan infrastructure to estimate parameters using Bayes rule. -Next, we define the engines used by the model and the associated mode. Here, the columns correspond to the engine names and rows are the modes (via row names). We have two engines and one effective mode, so our object will have the suffix `_engines`: +When adding a model into `parsnip`, the user has to specify which modes and engines are used. The package also enables users to add a new mode or engine to an existing model. -```{r engines} -mixture_da_engines <- data.frame( - mda = TRUE, - row.names = c("classification") -) -mixture_da_engines -``` +## The General Process -A row for "unknown" modes is not needed in this object. +`parsnip` stores information about the models in an internal environment object. The environment can be accessed via the function `get_model_env()`. The package includes a variety of functions that can get or set the different aspects of the models. -Now, we enumerate the _main arguments_ for each engine. `parsnip` standardizes the names of arguments across different models and engines. For example, random forest and boosting use multiple trees to create the ensemble. Instead of using different argument names, `parsnip` standardizes on `trees` and the underlying code translates to the actual arguments used by the different functions. +If you are adding a new model from your own package, you can use these functions to add new entries into the model environment. -In our case, the MDA argument name will be "sub_classes". +## Step 1. Register the Model, Modes, and Arguments. -Here, the object name will have the suffix `_arg_key` and will have columns for the engines and rows for the arguments. The entries for the data frame are the actual arguments for each engine (and is `NA` when an engine doesn't have that argument). Ours: +We will add the MDA model using the model type `mixture_da`. Since this is a classification method, we only have to register a single mode: -```{r arg-key} -mixture_da_arg_key <- data.frame( - mda = "sub_classes", - row.names = "sub_classes", - stringsAsFactors = FALSE +```{r mda-reg} +library(parsnip) +set_new_model("mixture_da") +set_model_mode(model = "mixture_da", mode = "classification") +set_model_engine( + "mixture_da", + mode = "classification", + eng = "mda" ) ``` -As an example of a model with multiple engines, here is the object for logistic regression: +These functions should silently finish. 
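If a registration script like the one above might be sourced more than once, it can guard against re-registering the model type. A minimal sketch, not part of the original vignette, assuming `set_new_model()` should only run when `"mixture_da"` is absent from the `models` vector of the model environment:

```{r mda-reg-guard, eval = FALSE}
# Sketch only: skip registration when "mixture_da" is already present.
# get_model_env()$models is the character vector of registered model types.
if (!any(get_model_env()$models == "mixture_da")) {
  set_new_model("mixture_da")
  set_model_mode(model = "mixture_da", mode = "classification")
  set_model_engine("mixture_da", mode = "classification", eng = "mda")
}
```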
There is also a function that can be used to show what aspects of the model have been added to `parsnip`: -```{r lr-key} -parsnip:::logistic_reg_arg_key +```{r mda-show-1} +show_model_info("mixture_da") ``` -The internals of `parsnip` will use these objects during the creation of the model code. +The next step would be to declare the main arguments to the model. These are declared independent of the mode. To specify the argument, there are a few slots to fill in: + + * The name that `parsnip` uses for the argument. In general, we try to use non-jargony names for arguments (e.g. "penalty" instead of "lambda" for regularised regression). We recommend consulting [this page]() to see if an existing argument name can be used before creating a new one. + + * The argument name that is used by the underlying modeling function. + + * A function reference for a _constructor_ that will be used to generate tuning parameter values. This should be a character vector that has a named element called `fun` that is the constructor function. There is an optional element `pkg` that can be used to call the function using its namespace. + + * A logical value for whether the argument can be used to generate multiple predictions for a single R object. For example, for boosted trees, if a model is fit with 10 boosting iterations, many modeling packages allow the model object to make predictions for any iterations less than the one used to fit the model. In general this is not the case so one would use `has_submodel = FALSE`. + +For `mda::mda()`, the main tuning parameter is `subclasses` which we will rewrite as `sub_classes`. + +```{r mda-args} +set_model_arg( + model = "mixture_da", + eng = "mda", + parsnip = "sub_classes", + original = "subclasses", + func = list(pkg = "foo", fun = "bar"), + has_submodel = FALSE +) +show_model_info("mixture_da") ``` -## Step 2. Create the model function +## Step 2. Create the model function This is a fairly simple function that can follow a basic template. The main arguments to our function will be: - * The mode. If the model can do more than one mode, you might default this to "unknown". In our case, since it is only a classification model, it makes sense to default it to that mode. + * The mode. If the model can do more than one mode, you might default this to "unknown". In our case, since it is only a classification model, it makes sense to default it to that mode so that the users won't have to specify it. + * The argument names (`sub_classes` here). These should be defaulted to `NULL`. - * `...` are _not_ used in the main model function. A basic version of the function is: ```{r model-fun} mixture_da <- function(mode = "classification", sub_classes = NULL) { # Check for correct mode - if (!(mode %in% mixture_da_modes)) - stop("`mode` should be one of: ", - paste0("'", mixture_da_modes, "'", collapse = ", "), - call. = FALSE) + if (mode != "classification") { + stop("`mode` should be 'classification'", call. = FALSE) + } # Capture the arguments in quosures args <- list(sub_classes = rlang::enquo(sub_classes)) @@ -120,62 +136,50 @@ mixture_da <- This is pretty simple since the data are not exposed to this function. -## Step 3. Make the model object - -This is where the details of the models are specified. This will be a list that has a few different elements: - - * `libs` is a character string that has any package names that will be required for the model fit. - * `fit` has details for the model fit function. 
- `pred`, `prob`, and `classes`. These are lists of details for making predictions on numbers, class probabilities, or hard class predictions (respectively). - -We'll look at each. The convention here is to name this `{model name}_{engine}_data`. We'll start with: - -```{r mda-start} -mixture_da_mda_data <- list(libs = "mda") -``` +## Step 3. Add a Fit Module -### The `fit` module +Now that `parsnip` knows about the model, mode, and engine, we can give it the information on fitting the model for our engine. +The information needed to fit the model is contained in another list. The elements are: -The main arguments are: * `interface` a single character value that could be "formula", "data.frame", or "matrix". This defines the type of interface used by the underlying fit function (`mda::mda`, in this case). This helps the translation of the data to be in an appropriate format for that function. + * `protect` is an optional list of function arguments that **should not be changeable** by the user. In this case, we probably don't want users to pass data values to these arguments (until the `fit` function is called). + * `func` is the package and name of the function that will be called. If you are using a locally defined function, only `fun` is required. + * `defaults` is an optional list of arguments to the fit function that the user can change, but whose defaults can be set here. This isn't needed in this case, but is described later in this document. For the first engine: ```{r fit-mod} -mixture_da_mda_data$fit <- - list( +set_fit( + model = "mixture_da", + eng = "mda", + mode = "classification", + value = list( interface = "formula", - protect = c("formula", "data", "weights"), + protect = c("formula", "data"), func = c(pkg = "mda", fun = "mda"), defaults = list() ) +) +show_model_info("mixture_da") ``` -### The `numeric` module - -This is defined only for regression models (so is not added to the list). The convention used here is very similar to the two that are detailed in the next section. For `numeric`, the model requires an unnamed numeric vector output (usually). - -Examples are [here](https://github.com/topepo/parsnip/blob/master/R/linear_reg_data.R) and [here](https://github.com/topepo/parsnip/blob/master/R/rand_forest_data.R). - -For multivariate models, the return value should be a matrix or data frame (otherwise a vector should be the results). - -Note that the `numeric` module maps to the `predict_numeric` function in `parsnip`. However, the user-facing `predict` function is used to generate predictions and returns a tibble with a column named `.pred` (see the example below). When creating new models, you don't have to write code for that part. - -### The `class` module - -To make hard class predictions, the `class` object contains the details. The elements of the list are: +## Step 4. Add Modules for Prediction + +Similar to the fitting module, we specify the code for making different types of predictions. To make hard class predictions, the `class` object contains the details. The elements of the list are: * `pre` and `post` are optional functions that can preprocess the data being fed to the prediction code and to postprocess the raw output of the predictions. These won't be needed for this example, but a section below has examples of how these can be used when the model code is not easy to use. If the data being predicted has a simple type requirement, you can avoid using a `pre` function with the `args` below. 
* `func` is the prediction function (in the same format as above). In many cases, packages have a predict method for their model's class but this is typically not exported. In this case (and the example below), it is simple enough to make a generic call to `predict` with no associated package. - * `args` is a list of arguments to pass to the prediction function. These will mostly likely be wrapped in `rlang::expr` so that they are not evaluated when defining the method. For `mda`, the code would be `predict(object, newdata, type = "class")`. What is actually given to the function is the `parsnip` model fit object, which includes a sub-object called `fit` and this houses the `mda` model object. If the data need to be a matrix or data frame, you could also use `new_data = quote(as.data.frame(new_data))` and so on. + * `args` is a list of arguments to pass to the prediction function. These will most likely be wrapped in `rlang::expr` so that they are not evaluated when defining the method. For `mda`, the code would be `predict(object, newdata, type = "class")`. What is actually given to the function is the `parsnip` model fit object, which includes a sub-object called `fit` and this houses the `mda` model object. If the data need to be a matrix or data frame, you could also use `newdata = quote(as.data.frame(newdata))` and so on. + +The `parsnip` prediction code will expect the result to be an unnamed character string or factor. This will be coerced to a factor with the same levels as the original data. + +To add this method to the model environment, a similar `set` function is used: -```{r mda-class} -mixture_da_mda_data$class <- +```{r mda-class} +class_info <- list( pre = NULL, post = NULL, @@ -192,18 +196,24 @@ mixture_da_mda_data$class <- type = "class" ) ) -``` - -The `predict_class` function will expect the result to be an unnamed character string or factor. This will be coerced to a factor with the same levels as the original data. As with the `pred` module, the user doesn't call `predict_class` but uses `predict` instead and this produces a tibble with a column named `.pred_class` [per the model guidlines](https://tidymodels.github.io/model-implementation-principles/model-predictions.html#return-values). -### The `classprob` module +set_pred( + model = "mixture_da", + eng = "mda", + mode = "classification", + type = "class", + value = class_info +) +``` -This defines the class probabilities (if they can be computed). The format is identical to the `class` module but the output is expected to be a tibble with columns for each factor level. +A similar call can be used to define the class probability module (if they can be computed). The format is identical to the `class` module but the output is expected to be a tibble with columns for each factor level. As an example of the `post` function, the data frame created by `mda:::predict.mda` will be converted to a tibble. The arguments are `x` (the raw results coming from the predict method) and `object` (the `parsnip` model fit object). The latter has a sub-object called `lvl` which is a character string of the outcome's factor levels (if any). 
-```{r mda-classprob} -mixture_da_mda_data$classprob <- +We register the probability module: + +```{r mda-prob} +prob_info <- list( pre = NULL, post = function(x, object) { @@ -217,9 +227,22 @@ mixture_da_mda_data$classprob <- type = "posterior" ) ) + +set_pred( + model = "mixture_da", + eng = "mda", + mode = "classification", + type = "prob", + value = prob_info +) + +show_model_info("mixture_da") ``` -The `post` element converts the output to a tibble but the main `predict` method does proper naming of the column names. +If this model could be used for regression situations, we could also add a "numeric" module. The convention used here is very similar to the two that are detailed in the next section. For `pred`, the model requires an unnamed numeric vector output (usually). + +Examples are [here](https://github.com/topepo/parsnip/blob/master/R/linear_reg_data.R) and [here](https://github.com/topepo/parsnip/blob/master/R/rand_forest_data.R). + ## Does it Work? @@ -228,98 +251,81 @@ As a developer, one thing that may come in handy is the `translate` function. Th For example: ```{r mda-code} -library(tidymodels) +library(parsnip) +library(tidyverse) mixture_da(sub_classes = 2) %>% - set_engine("mda") %>% - translate() + translate(engine = "mda") ``` Let's try it on the iris data: ```{r mda-data} +library(rsample) +library(tibble) + set.seed(4622) -iris_split <- initial_split(iris, prop = 0.90) +iris_split <- initial_split(iris, prop = 0.95) iris_train <- training(iris_split) iris_test <- testing(iris_split) mda_spec <- mixture_da(sub_classes = 2) mda_fit <- mda_spec %>% - set_engine("mda") %>% - fit(Species ~ ., data = iris_train) + fit(Species ~ ., data = iris_train, engine = "mda") mda_fit -predict(mda_fit, new_data = iris_test) %>% - bind_cols(iris_test %>% select(Species)) +predict(mda_fit, new_data = iris_test, type = "prob") %>% + mutate(Species = iris_test$Species) -predict(mda_fit, new_data = iris_test, type = "prob") %>% - bind_cols(iris_test %>% select(Species)) +predict(mda_fit, new_data = iris_test) %>% + mutate(Species = iris_test$Species) ``` # Pro-tips, what-ifs, exceptions, FAQ, and minutiae There are various things that came to mind while writing this document. -### Do I have to return a simple vector for `predict_num` and `predict_class`? - -Previously, when discussing the `numeric` information: +### Do I have to return a simple vector for `predict` and `predict_class`? -> For `numeric`, the model requires an unnamed numeric vector output **(usually)**. +Previously, when discussing the `pred` information: -There are some occasions where a prediction for a single new sample may be multidimensional. Examples are enumerated [here](https://tidymodels.github.io/model-implementation-principles/notes.html#list-cols) but some easy examples are: +> For `pred`, the model requires an unnamed numeric vector output **(usually)**. - * confidence or prediction intervals - * quantile regression predictions. - -and so on. These can be accomodated via `predict.model_fit` using different `type` arguments. +There are some models (e.g. `glmnet`, `plsr`, `Cubust`, etc.) that can make predictions for different models from the same fitted model object. We want to facilitate that here so that, for these cases, the current convention is to return a tibble with the prediction in a column called `values` and have extra columns for any parameters that define the different sub-models. -However, there are some models (e.g. `glmnet`, `plsr`, `Cubist`, etc.) 
that can make predictions for different models from the same fitted model object. The regular `predict` method requires prediction from a single model but the `multi_predict` can. The guideline is to _always return the same number of rows as in `new_data`_. This means that the `.pred` column is a list-column of tibbles. +that can make predictions for different models from the same fitted model object. We want to facilitate that here so that, for these cases, the current convention is to return a tibble with the prediction in a column called `values` and have extra columns for any parameters that define the different sub-models. -For example, for a multinomial `glmnet` model, we leave `penalty` unspecified when fitting and get predictions on a sequence of values: +For example, if I fit a linear regression model via `glmnet` and get four values of the regularization parameter (`lambda`): -```{r mnom-glmnet-fit} -mod <- multinom_reg(mixture = 1/3) %>% - set_engine("glmnet") -mod_fit <- fit(mod, Species ~ ., data = iris) - -preds <- multi_predict(mod_fit, iris[1:3, -5], penalty = c(0, 0.01, 0.1), type = "prob") -preds -preds[[".pred"]][1] -``` -This can be easily expanded to remove the list columns: - -```{r mnom-glmnet-expand} -preds %>% - mutate(.row = 1:nrow(preds)) %>% - tidyr::unnest() +```{r glmnet, eval = FALSE} +linear_reg(others = list(nlambda = 4)) %>% + fit(mpg ~ ., data = mtcars, engine = "glmnet") %>% + predict(newdata = mtcars[1:3, -1]) ``` -`multi_predict` doesn't exist for every model and needs to be implmented by the developer. See `methods("multi_predict")` for examples in this package. +_However_, the API is still being developed. Currently, there is not an interface in the prediction functions to pass in the values of the parameters to make predictions with (`lambda`, in this case). -### What is the `defaults` slot and why do I need it? +### What is the `defaults` slot and why do I need them? You might want to set defaults that can be overridden by the user. For example, for logistic regression with `glm`, it makes sense to default `family = binomial`. However, if someone wants to use a different link function, they should be able to do that. For that model/engine definition, it has ```{r glm-alt, eval = FALSE} -defaults = list(family = expr(stats::binomial)) +defaults = list(family = expr(binomial)) ``` so that is the default: -```{r glm-alt-show} +```{r glm-alt-show, eval = FALSE} logistic_reg() %>% translate(engine = "glm") # but you can change it: -logistic_reg() %>% - set_engine("glm", family = stats::binomial(link = "probit")) %>% - translate() +logistic_reg(others = list(family = expr(binomial(link = "probit")))) %>% + translate(engine = "glm") ``` That's what `defaults` are for. Note that I wrapped `binomial` inside of `expr`. If I didn't, it would substitute the results of executing `binomial` inside of the expression (and that's a mess). Using namespaces is a good idea here. ### What if I want more complex defaults? For example, the `ranger` and `randomForest` package functions have arguments fo ```{r rf-trans, eval = FALSE} # Simplified version -translate.rand_forest <- function (x, engine = x$engine, ...){ +translate.rand_forest <- function (x, engine, ...){ # Run the general method to get the real arguments in place x <- translate.default(x, engine, ...) - # Make code easier to read - arg_vals <- x$method$fit$args - # Check and see if they make sense for the engine and/or mode: - if (engine == "ranger") { - if (any(names(arg_vals) == "importance")) - # We want to check the type of `importance` but it is a quosure. We first - # get the expression. 
It is is logical, the value of `quo_get_expr` will - not be an expression but the actual logical. The wrapping of `isTRUE` - # is there in case it is not an atomic value. - if (isTRUE(is.logical(quo_get_expr(arg_vals$importance)))) + if (x$engine == "ranger") { + if (any(names(x$method$fit$args) == "importance")) + if (is.logical(x$method$fit$args$importance)) stop("`importance` should be a character value. See ?ranger::ranger.", call. = FALSE) - if (x$mode == "classification" && !any(names(arg_vals) == "probability")) - arg_vals$probability <- TRUE } - x$method$fit$args <- arg_vals x } ``` As another example, `nnet::nnet` has an option for the final layer to be linear (called `linout`). If `mode = "regression"`, that should probably be set to `TRUE`. You couldn't do this with the `args` (described above) since you need the function translated first. -In cases where the model requires different defaults, the `translate` method can also be used. See the code for the `mars` function to see how to check and potentially switch arguments for classification models. - ### My model fit requires more than one function call. So....? -The best course of action is to write wrapper so that it can be one call. This was the case with `xgboost`, `C5.0`, and `keras`. +The best course of action is to write a wrapper so that it can be one call. This was the case with `xgboost` and `keras`. ### Why would I preprocess my data? @@ -372,7 +366,7 @@ This would **not** include making dummy variables and `model.matrix` stuff. `par ### Why would I postprocess my predictions? -What comes back from some R functions make be somewhat... arcane or problematic. As an example, for `xgboost`, if you fit a multiclass boosted tree, you might expect the class probabilities to come back as a matrix^[_narrator_: they don't]. If you have four classes and make predictions on three samples, you get a vector of 12 probability values. You need to convert these to a rectangular data set. +What comes back from some R functions may be somewhat... arcane or problematic. As an example, for `xgboost`, if you fit a multiclass boosted tree, you might expect the class probabilities to come back as a matrix (narrator: they don't). If you have four classes and make predictions on three samples, you get a vector of 12 probability values. You need to convert these to a rectangular data set. Another example is the predict method for `ranger`, which encapsulates the actual predictions in a more complex object structure. 
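To make that reshaping concrete, here is a hypothetical `post` function for the multiclass-probability case described above. It is illustrative only and not part of this patch; it assumes the flat vector is ordered class-within-sample and that the `parsnip` fit object carries the outcome's factor levels in `lvl` (as described earlier):

```{r xgb-post-sketch, eval = FALSE}
# Sketch only: turn a flat vector of multiclass probabilities into a
# tibble with one row per sample and one column per class level.
xgb_prob_post <- function(x, object) {
  n_classes <- length(object$lvl)
  probs <- matrix(x, ncol = n_classes, byrow = TRUE)
  colnames(probs) <- object$lvl
  tibble::as_tibble(probs)
}
```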
From 1113ed52b9e64857722bfd7e7c2d77c0a75e0d8a Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Thu, 30 May 2019 20:59:47 -0700 Subject: [PATCH 28/64] more banners --- README.md | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index a4253bf86..a095fe62d 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,12 @@ -[![Travis build status](https://travis-ci.org/tidymodels/parsnip.svg?branch=master)](https://travis-ci.org/tidymodels/parsnip) -[![Coverage status](https://codecov.io/gh/tidymodels/parsnip/branch/master/graph/badge.svg)](https://codecov.io/github/tidymodels/parsnip?branch=master) -![](https://img.shields.io/badge/lifecycle-maturing-blue.svg) +[![Build +Status](https://travis-ci.org/tidymodels/parsnip.svg?branch=master)](https://travis-ci.org/tidymodels/parsnip) +[![Coverage +status](https://codecov.io/gh/tidymodels/parsnip/branch/master/graph/badge.svg)](https://codecov.io/github/tidymodels/parsnip?branch=master) +[![CRAN\_Status\_Badge](http://www.r-pkg.org/badges/version/parsnip)](http://cran.r-project.org/web/packages/parsnip) +[![Downloads](http://cranlogs.r-pkg.org/badges/parsnip)](http://cran.rstudio.com/package=parsnip) +[![lifecycle](https://img.shields.io/badge/lifecycle-maturing-blue.svg)](https://www.tidyverse.org/lifecycle/#maturing) + One issue with different functions available in R _that do the same thing_ is that they can have different interfaces and arguments. For example, to fit a random forest _classification_ model, we might have: From eab824e8d52c9e52428338b71afe71e23d223472 Mon Sep 17 00:00:00 2001 From: topepo Date: Fri, 31 May 2019 18:46:43 -0400 Subject: [PATCH 29/64] added a translate function for knn --- NAMESPACE | 1 + R/nearest_neighbor.R | 26 +++++++++++++-- R/nearest_neighbor_data.R | 4 +-- R/rand_forest_data.R | 36 +++++++++++++++++++++ man/nearest_neighbor.Rd | 3 +- tests/testthat/test_nearest_neighbor.R | 10 +++--- tests/testthat/test_nearest_neighbor_kknn.R | 1 + 7 files changed, 70 insertions(+), 11 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index ca7a96068..a1d6047e4 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -45,6 +45,7 @@ S3method(translate,decision_tree) S3method(translate,default) S3method(translate,mars) S3method(translate,mlp) +S3method(translate,nearest_neighbor) S3method(translate,rand_forest) S3method(translate,surv_reg) S3method(translate,svm_poly) diff --git a/R/nearest_neighbor.R b/R/nearest_neighbor.R index eaa83188e..6cb9b07d1 100644 --- a/R/nearest_neighbor.R +++ b/R/nearest_neighbor.R @@ -29,7 +29,8 @@ #' `"classification"`. #' #' @param neighbors A single integer for the number of neighbors -#' to consider (often called `k`). +#' to consider (often called `k`). For \pkg{kknn}, a value of 5 +#' is used if `neighbors` is not specified. #' #' @param weight_func A *single* character for the type of kernel function used #' to weight distances between samples. Valid choices are: `"rectangular"`, @@ -148,13 +149,32 @@ check_args.nearest_neighbor <- function(object) { args <- lapply(object$args, rlang::eval_tidy) - if(is.numeric(args$neighbors) && !positive_int_scalar(args$neighbors)) { + if (is.numeric(args$neighbors) && !positive_int_scalar(args$neighbors)) { stop("`neighbors` must be a length 1 positive integer.", call. = FALSE) } - if(is.character(args$weight_func) && length(args$weight_func) > 1) { + if (is.character(args$weight_func) && length(args$weight_func) > 1) { stop("The length of `weight_func` must be 1.", call. 
= FALSE) } invisible(object) } + +# ------------------------------------------------------------------------------ + +#' @export +translate.nearest_neighbor <- function(x, engine = x$engine, ...) { + if (is.null(engine)) { + message("Used `engine = 'kknn'` for translation.") + engine <- "kknn" + } + x <- translate.default(x, engine, ...) + + if (engine == "kknn") { + if (!any(names(x$method$fit$args) == "ks") || + is_missing_arg(x$method$fit$args$ks)) { + x$method$fit$args$ks <- 5 + } + } + x +} diff --git a/R/nearest_neighbor_data.R b/R/nearest_neighbor_data.R index 3c8edb10a..b9fe886cb 100644 --- a/R/nearest_neighbor_data.R +++ b/R/nearest_neighbor_data.R @@ -41,7 +41,7 @@ set_fit( mode = "regression", value = list( interface = "formula", - protect = c("formula", "data", "ks"), + protect = c("formula", "data"), func = c(pkg = "kknn", fun = "train.kknn"), defaults = list() ) @@ -53,7 +53,7 @@ set_fit( mode = "classification", value = list( interface = "formula", - protect = c("formula", "data", "ks"), + protect = c("formula", "data"), func = c(pkg = "kknn", fun = "train.kknn"), defaults = list() ) diff --git a/R/rand_forest_data.R b/R/rand_forest_data.R index 64e47b927..7e3cd4462 100644 --- a/R/rand_forest_data.R +++ b/R/rand_forest_data.R @@ -195,6 +195,24 @@ set_pred( ) ) +set_pred( + mod = "rand_forest", + eng = "ranger", + mode = "classification", + type = "conf_int", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "ranger_confint"), + args = + list( + object = quote(object), + new_data = quote(new_data), + seed = expr(sample.int(10^5, 1)) + ) + ) +) + set_pred( mod = "rand_forest", eng = "ranger", @@ -234,6 +252,24 @@ set_pred( ) ) + +set_pred( + mod = "rand_forest", + eng = "ranger", + mode = "regression", + type = "conf_int", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "ranger_confint"), + args = + list( + object = quote(object), + new_data = quote(new_data), + seed = expr(sample.int(10^5, 1)) + ) + ) +) set_pred( mod = "rand_forest", eng = "ranger", diff --git a/man/nearest_neighbor.Rd b/man/nearest_neighbor.Rd index 123d8f9ca..133ac7161 100644 --- a/man/nearest_neighbor.Rd +++ b/man/nearest_neighbor.Rd @@ -13,7 +13,8 @@ Possible values for this model are \code{"unknown"}, \code{"regression"}, or \code{"classification"}.} \item{neighbors}{A single integer for the number of neighbors -to consider (often called \code{k}).} +to consider (often called \code{k}). For \pkg{kknn}, a value of 5 +is used if \code{neighbors} is not specified.} \item{weight_func}{A \emph{single} character for the type of kernel function used to weight distances between samples. 
Valid choices are: \code{"rectangular"}, diff --git a/tests/testthat/test_nearest_neighbor.R b/tests/testthat/test_nearest_neighbor.R index 68fcb5a7c..85b3df178 100644 --- a/tests/testthat/test_nearest_neighbor.R +++ b/tests/testthat/test_nearest_neighbor.R @@ -18,11 +18,11 @@ test_that('primary arguments', { expected = list( formula = expr(missing_arg()), data = expr(missing_arg()), - ks = expr(missing_arg()) + ks = 5 ) ) - neighbors <- nearest_neighbor(mode = "classification", neighbors = 5) + neighbors <- nearest_neighbor(mode = "classification", neighbors = 2) neighbors_kknn <- translate(neighbors %>% set_engine("kknn")) expect_equal( @@ -30,7 +30,7 @@ test_that('primary arguments', { expected = list( formula = expr(missing_arg()), data = expr(missing_arg()), - ks = new_empty_quosure(5) + ks = new_empty_quosure(2) ) ) @@ -42,7 +42,7 @@ test_that('primary arguments', { expected = list( formula = expr(missing_arg()), data = expr(missing_arg()), - ks = expr(missing_arg()), + ks = 5, kernel = new_empty_quosure("triangular") ) ) @@ -55,7 +55,7 @@ test_that('primary arguments', { expected = list( formula = expr(missing_arg()), data = expr(missing_arg()), - ks = expr(missing_arg()), + ks = 5, distance = new_empty_quosure(2) ) ) diff --git a/tests/testthat/test_nearest_neighbor_kknn.R b/tests/testthat/test_nearest_neighbor_kknn.R index 1418c50ae..37ea2e262 100644 --- a/tests/testthat/test_nearest_neighbor_kknn.R +++ b/tests/testthat/test_nearest_neighbor_kknn.R @@ -63,6 +63,7 @@ test_that('kknn execution', { test_that('kknn prediction', { skip_if_not_installed("kknn") + library(kknn) # continuous res_xy <- fit_xy( From 4b0808f9f5270f06a6d5e8e7d5162dd843ef9852 Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Fri, 31 May 2019 19:19:25 -0400 Subject: [PATCH 30/64] start of unit tests --- NAMESPACE | 1 + R/aaa_models.R | 8 ++ man/check_mod_val.Rd | 1 - man/get_model_env.Rd | 6 ++ tests/testthat/test_registration.R | 130 +++++++++++++++++++++++++++++ 5 files changed, 145 insertions(+), 1 deletion(-) create mode 100644 tests/testthat/test_registration.R diff --git a/NAMESPACE b/NAMESPACE index d0ee74cf0..427150a53 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -96,6 +96,7 @@ export(fit_xy) export(fit_xy.model_spec) export(get_dependency) export(get_fit) +export(get_from_env) export(get_model_env) export(get_pred_type) export(keras_mlp) diff --git a/R/aaa_models.R b/R/aaa_models.R index a2cf0fde8..c5552f533 100644 --- a/R/aaa_models.R +++ b/R/aaa_models.R @@ -711,4 +711,12 @@ show_model_info <- function(model) { invisible(NULL) } +#' @rdname get_model_env +#' @keywords internal +#' @export +#' @param items A character string of objects in the model environment. +get_from_env <- function(items) { + mod_env <- get_model_env() + env_get(mod_env, items) +} diff --git a/man/check_mod_val.Rd b/man/check_mod_val.Rd index 7779985dc..6efe8d8d8 100644 --- a/man/check_mod_val.Rd +++ b/man/check_mod_val.Rd @@ -107,7 +107,6 @@ format. 
show_model_info("rand_forest") # Access the model data: - current_code <- get_model_env() ls(envir = current_code) diff --git a/man/get_model_env.Rd b/man/get_model_env.Rd index 76ff615f2..c2aa37bbb 100644 --- a/man/get_model_env.Rd +++ b/man/get_model_env.Rd @@ -13,6 +13,7 @@ \alias{set_pred} \alias{get_pred_type} \alias{show_model_info} +\alias{get_from_env} \title{Tools to Register Models} \usage{ get_model_env() @@ -38,6 +39,11 @@ set_pred(model, mode, eng, type, value) get_pred_type(model, type) show_model_info(model) + +get_from_env(items) +} +\arguments{ +\item{items}{A character string of objects in the model environment.} } \description{ Tools to Register Models diff --git a/tests/testthat/test_registration.R b/tests/testthat/test_registration.R new file mode 100644 index 000000000..ceaa6f655 --- /dev/null +++ b/tests/testthat/test_registration.R @@ -0,0 +1,130 @@ +library(parsnip) +library(dplyr) +library(rlang) +library(testthat) + +# ------------------------------------------------------------------------------ + +context("model registration") +#source("helpers.R") + +test_by_col <- function(a, b) { + for(i in union(names(a), names(b))) { + expect_equal(a[[i]], b[[i]]) + } +} + +# ------------------------------------------------------------------------------ + +test_that('adding a new model', { + set_new_model("sponge") + + mod_items <- get_model_env() %>% env_names() + sponges <- grep("sponge", mod_items, value = TRUE) + exp_obj <- c('sponge_modes', 'sponge_fit', 'sponge_args', + 'sponge_predict', 'sponge_pkgs', 'sponge') + expect_equal(sort(sponges), sort(exp_obj)) + + expect_equal( + get_from_env("sponge"), + tibble(engine = character(0), mode = character(0)) + ) + +test_by_col( + get_from_env("sponge_pkgs"), + tibble(engine = character(0), pkg = character(0)) +) + +expect_equal( + get_from_env("sponge_modes"), "unknown" +) + +test_by_col( + get_from_env("sponge_args"), + tibble(engine = character(0), parsnip = character(0), + original = character(0), func = vector("list")) +) + +test_by_col( + get_from_env("sponge_fit"), + tibble(engine = character(0), mode = character(0), value = vector("list")) +) + +test_by_col( + get_from_env("sponge_predict"), + tibble(engine = character(0), mode = character(0), + type = character(0), value = vector("list")) +) + +expect_error(set_new_model()) +# TODO expect_error(set_new_model(2)) +# TODO expect_error(set_new_model(letters[1:2])) +}) + + +# ------------------------------------------------------------------------------ + +test_that('adding a new mode', { + set_model_mode("sponge", "classification") + + expect_equal(get_from_env("sponge_modes"), c("unknown", "classification")) + + # TODO expect_error(set_model_mode("sponge", "banana")) + # TODO expect_error(set_model_mode("sponge", "classification")) + +}) + + +# ------------------------------------------------------------------------------ + +test_that('adding a new engine', { + set_model_engine("sponge", "classification", "gum") + + test_by_col( + get_from_env("sponge"), + tibble(engine = "gum", mode = "classification") + ) + + + expect_equal(get_from_env("sponge_modes"), c("unknown", "classification")) + + # TODO check for bad mode, check for duplicate + +}) + + +# ------------------------------------------------------------------------------ + +test_that('adding a new package', { + set_dependency("sponge", "gum", "trident") + + expect_error(set_dependency("sponge", "gum", letters[1:2])) + + test_by_col( + get_from_env("sponge_pkgs"), + tibble(engine = "gum", pkg = list("trident")) + 
) +}) + + +# ------------------------------------------------------------------------------ + +test_that('adding a new argument', { + +}) + + + +# ------------------------------------------------------------------------------ + +test_that('adding a new fit', { + +}) + + +# ------------------------------------------------------------------------------ + +test_that('adding a new predict method', { + +}) + From e29daa2c5c01318d3db994b6d397040d4956d407 Mon Sep 17 00:00:00 2001 From: topepo Date: Fri, 31 May 2019 21:21:01 -0400 Subject: [PATCH 31/64] fixed some test cases --- R/aaa.R | 2 +- R/aaa_models.R | 2 +- tests/testthat/test_nearest_neighbor.R | 12 +- tests/testthat/test_registration.R | 292 ++++++++++++++++++++++++- 4 files changed, 298 insertions(+), 10 deletions(-) diff --git a/R/aaa.R b/R/aaa.R index b91d6f4c5..230c2b08c 100644 --- a/R/aaa.R +++ b/R/aaa.R @@ -23,5 +23,5 @@ convert_stan_interval <- function(x, level = 0.95, lower = TRUE) { #' @importFrom utils globalVariables utils::globalVariables( c('.', '.label', '.pred', '.row', 'data', 'engine', 'engine2', 'group', - 'lab', 'original', 'predicted_label', 'prediction', 'value') + 'lab', 'original', 'predicted_label', 'prediction', 'value', 'type') ) diff --git a/R/aaa_models.R b/R/aaa_models.R index c5552f533..08f5a9514 100644 --- a/R/aaa_models.R +++ b/R/aaa_models.R @@ -717,6 +717,6 @@ show_model_info <- function(model) { #' @param items A character string of objects in the model environment. get_from_env <- function(items) { mod_env <- get_model_env() - env_get(mod_env, items) + rlang::env_get(mod_env, items) } diff --git a/tests/testthat/test_nearest_neighbor.R b/tests/testthat/test_nearest_neighbor.R index 85b3df178..c2a52196f 100644 --- a/tests/testthat/test_nearest_neighbor.R +++ b/tests/testthat/test_nearest_neighbor.R @@ -42,8 +42,8 @@ test_that('primary arguments', { expected = list( formula = expr(missing_arg()), data = expr(missing_arg()), - ks = 5, - kernel = new_empty_quosure("triangular") + kernel = new_empty_quosure("triangular"), + ks = 5 ) ) @@ -55,8 +55,8 @@ test_that('primary arguments', { expected = list( formula = expr(missing_arg()), data = expr(missing_arg()), - ks = 5, - distance = new_empty_quosure(2) + distance = new_empty_quosure(2), + ks = 5 ) ) @@ -71,8 +71,8 @@ test_that('engine arguments', { expected = list( formula = expr(missing_arg()), data = expr(missing_arg()), - ks = expr(missing_arg()), - scale = new_empty_quosure(FALSE) + scale = new_empty_quosure(FALSE), + ks = 5 ) ) diff --git a/tests/testthat/test_registration.R b/tests/testthat/test_registration.R index ceaa6f655..20941061d 100644 --- a/tests/testthat/test_registration.R +++ b/tests/testthat/test_registration.R @@ -8,8 +8,9 @@ library(testthat) context("model registration") #source("helpers.R") +# There's currently an issue comparing tibbles so we do it col-by-col test_by_col <- function(a, b) { - for(i in union(names(a), names(b))) { + for (i in union(names(a), names(b))) { expect_equal(a[[i]], b[[i]]) } } @@ -32,7 +33,7 @@ test_that('adding a new model', { test_by_col( get_from_env("sponge_pkgs"), - tibble(engine = character(0), pkg = character(0)) + tibble(engine = character(0), pkg = list()) ) expect_equal( @@ -110,6 +111,83 @@ test_that('adding a new package', { # ------------------------------------------------------------------------------ test_that('adding a new argument', { + set_model_arg( + model = "sponge", + eng = "gum", + parsnip = "modeling", + original = "modelling", + func = list(pkg = "foo", fun = "bar"), + has_submodel 
= FALSE + ) + + test_by_col( + get_from_env("sponge_args"), + tibble(engine = "gum", parsnip = "modeling", original = "modelling", + has_submodel = FALSE) + ) + + expect_error( + set_model_arg( + model = "lunchroom", + eng = "gum", + parsnip = "modeling", + original = "modelling", + func = list(pkg = "foo", fun = "bar"), + has_submodel = FALSE + ) + ) + + expect_error( + set_model_arg( + model = "sponge", + eng = "gum", + parsnip = "modeling", + func = list(pkg = "foo", fun = "bar"), + has_submodel = FALSE + ) + ) + + expect_error( + set_model_arg( + model = "sponge", + eng = "gum", + original = "modelling", + func = list(pkg = "foo", fun = "bar"), + has_submodel = FALSE + ) + ) + + expect_error( + set_model_arg( + model = "sponge", + eng = "gum", + parsnip = "modeling", + original = "modelling", + func = "foo::bar", + has_submodel = FALSE + ) + ) + + expect_error( + set_model_arg( + model = "sponge", + eng = "gum", + parsnip = "modeling", + original = "modelling", + func = list(pkg = "foo", fun = "bar"), + has_submodel = 2 + ) + ) + + expect_error( + set_model_arg( + model = "sponge", + eng = "gum", + parsnip = "modeling", + original = "modelling", + func = list(pkg = "foo", fun = "bar") + ) + ) }) @@ -118,13 +196,223 @@ test_that('adding a new argument', { # ------------------------------------------------------------------------------ test_that('adding a new fit', { + fit_vals <- + list( + interface = "formula", + protect = c("formula", "data"), + func = c(pkg = "foo", fun = "bar"), + defaults = list() + ) + + set_fit( + model = "sponge", + eng = "gum", + mode = "classification", + value = fit_vals + ) + + fit_env_data <- get_from_env("sponge_fit") + test_by_col( + fit_env_data[ 1:2], + tibble(engine = "gum", mode = "classification") + ) + + expect_equal( + fit_env_data$value[[1]], + fit_vals + ) + expect_error( + set_fit( + model = "cactus", + eng = "gum", + mode = "classification", + value = fit_vals + ) + ) + + expect_error( + set_fit( + model = "sponge", + eng = "nose", + mode = "classification", + value = fit_vals + ) + ) + + expect_error( + set_fit( + model = "sponge", + eng = "gum", + mode = "frog", + value = fit_vals + ) + ) + + for (i in 1:length(fit_vals)) { + expect_error( + set_fit( + model = "sponge", + eng = "gum", + mode = "classification", + value = fit_vals[-i] + ) + ) + } + + fit_vals_0 <- fit_vals + fit_vals_0$interface <- "loaf" + expect_error( + set_fit( + model = "sponge", + eng = "gum", + mode = "classification", + value = fit_vals_0 + ) + ) + + fit_vals_1 <- fit_vals + fit_vals_1$defaults <- 2 + expect_error( + set_fit( + model = "sponge", + eng = "gum", + mode = "classification", + value = fit_vals_1 + ) + ) + + fit_vals_2 <- fit_vals + fit_vals_2$func <- "foo:bar" + expect_error( + set_fit( + model = "sponge", + eng = "gum", + mode = "classification", + value = fit_vals_2 + ) + ) }) # ------------------------------------------------------------------------------ test_that('adding a new predict method', { + class_vals <- + list( + pre = I, + post = NULL, + func = c(fun = "predict"), + args = list(x = quote(2)) + ) + + set_pred( + model = "sponge", + eng = "gum", + mode = "classification", + type = "class", + value = class_vals + ) + + pred_env_data <- get_from_env("sponge_predict") + test_by_col( + pred_env_data[ 1:3], + tibble(engine = "gum", mode = "classification", type = "class") + ) + + expect_equal( + pred_env_data$value[[1]], + class_vals + ) + + expect_error( + set_pred( + model = "cactus", + eng = "gum", + mode = "classification", + type = "class", + 
value = class_vals + ) + ) + + expect_error( + set_pred( + model = "sponge", + eng = "nose", + mode = "classification", + type = "class", + value = class_vals + ) + ) + + + expect_error( + set_pred( + model = "sponge", + eng = "gum", + mode = "classification", + type = "eggs", + value = class_vals + ) + ) + + expect_error( + set_pred( + model = "sponge", + eng = "gum", + mode = "frog", + type = "class", + value = class_vals + ) + ) + + for (i in 1:length(class_vals)) { + expect_error( + set_pred( + model = "sponge", + eng = "gum", + mode = "classification", + type = "class", + value = class_vals[-i] + ) + ) + } + + class_vals_0 <- class_vals + class_vals_0$pre <- "I" + expect_error( + set_pred( + model = "sponge", + eng = "gum", + mode = "classification", + type = "class", + value = class_vals_0 + ) + ) + + class_vals_1 <- class_vals + class_vals_1$post <- "I" + expect_error( + set_pred( + model = "sponge", + eng = "gum", + mode = "classification", + type = "class", + value = class_vals_1 + ) + ) + + class_vals_2 <- class_vals + class_vals_2$func <- "foo:bar" + expect_error( + set_pred( + model = "sponge", + eng = "gum", + mode = "classification", + type = "class", + value = class_vals_2 + ) + ) }) From 697ca0ec8414a908a97047936ba4245b3b81c3de Mon Sep 17 00:00:00 2001 From: topepo Date: Fri, 31 May 2019 21:33:42 -0400 Subject: [PATCH 32/64] figuring out travis issues --- tests/testthat/test_surv_reg_survreg.R | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/testthat/test_surv_reg_survreg.R b/tests/testthat/test_surv_reg_survreg.R index 6b66b4134..a8e05b135 100644 --- a/tests/testthat/test_surv_reg_survreg.R +++ b/tests/testthat/test_surv_reg_survreg.R @@ -19,6 +19,8 @@ quiet_ctrl <- fit_control(verbosity = 0, catch = TRUE) test_that('survival execution', { + skip() + expect_error( res <- fit( surv_basic, @@ -48,6 +50,7 @@ test_that('survival execution', { }) test_that('survival prediction', { + skip() res <- fit( surv_basic, From 13d43ab972f53ba48d45fa818c582a2c1ff17773 Mon Sep 17 00:00:00 2001 From: topepo Date: Fri, 31 May 2019 21:33:56 -0400 Subject: [PATCH 33/64] added mode specification --- vignettes/articles/Classification.Rmd | 1 + 1 file changed, 1 insertion(+) diff --git a/vignettes/articles/Classification.Rmd b/vignettes/articles/Classification.Rmd index 132344126..4963c7aca 100644 --- a/vignettes/articles/Classification.Rmd +++ b/vignettes/articles/Classification.Rmd @@ -58,6 +58,7 @@ test_normalized <- bake(credit_rec, new_data = credit_test, all_predictors()) set.seed(57974) nnet_fit <- mlp(epochs = 100, hidden_units = 5, dropout = 0.1) %>% + set_mode("classification") %>% # Also set engine-specific arguments: set_engine("keras", verbose = 0, validation_split = .20) %>% fit(Status ~ ., data = juice(credit_rec)) From b869d6af0d807478cc610f64d9319b620a8261cf Mon Sep 17 00:00:00 2001 From: topepo Date: Fri, 31 May 2019 21:45:56 -0400 Subject: [PATCH 34/64] bug fixes --- vignettes/articles/Models.Rmd | 28 +++++++--------------------- vignettes/articles/Scratch.Rmd | 3 ++- 2 files changed, 9 insertions(+), 22 deletions(-) diff --git a/vignettes/articles/Models.Rmd b/vignettes/articles/Models.Rmd index 8fd627782..653893c72 100644 --- a/vignettes/articles/Models.Rmd +++ b/vignettes/articles/Models.Rmd @@ -22,27 +22,15 @@ library(cli) library(kableExtra) ``` -```{r modelinfo, include = FALSE} -mod_names <- function(model, engine) { - obj_name <- paste(model, engine, "data", sep = "_") - tibble(module = getFromNamespace(obj_name, "parsnip") %>% names(), - model = model, - 
engine = engine) -} - -engine_info <- - parsnip:::engine_info %>% - distinct(model, engine) %>% - mutate(obj_name = paste(model, engine, "data", sep = "_")) -``` - `parsnip` contains wrappers for a number of models. For example, the `parsnip` function `rand_forest()` can be used to create a random forest model. The **mode** of a model is related to its goal. Examples would be regression and classification. The list of models accessible via `parsnip` is: ```{r model-table, results = 'asis', echo = FALSE} +mod_names <- get_from_env("models") + mod_list <- - parsnip:::engine_info %>% + map_dfr(mod_names, ~ get_from_env(.x) %>% mutate(model = .x)) %>% distinct(mode, model) %>% mutate(model = paste0("`", model, "()`")) %>% arrange(mode, model) %>% @@ -59,16 +47,14 @@ for (i in 1:nrow(mod_list)) { _How_ the model is created is related to the _engine_. In many cases, this is an R modeling package. In others, it may be a connection to an external system (such as Spark or Tensorflow). This table lists the engines for each model type along with the type of prediction that it can make (see `predict.model_fit()`). ```{r pred-table, results = 'asis', echo = FALSE} - map2_dfr(engine_info$model, engine_info$engine, mod_names) %>% - dplyr::filter(!(module %in% c("libs", "fit"))) %>% +map_dfr(mod_names, ~ get_from_env(paste0(.x, "_predict")) %>% mutate(model = .x)) %>% + dplyr::select(-value) %>% mutate( - module = ifelse(module == "confint", "conf_int", module), - module = ifelse(module == "predint", "pred_int", module), - module = paste0("`", module, "`"), + type = paste0("`", type, "`"), model = paste0("`", model, "()`"), ) %>% mutate(check = cli::symbol$tick) %>% - spread(module, check, fill = cli::symbol$times) %>% + spread(type, check, fill = cli::symbol$times) %>% kable(format = "html") %>% kable_styling(full_width = FALSE) %>% collapse_rows(columns = 1) diff --git a/vignettes/articles/Scratch.Rmd b/vignettes/articles/Scratch.Rmd index 81e0282e3..bc9a90d58 100644 --- a/vignettes/articles/Scratch.Rmd +++ b/vignettes/articles/Scratch.Rmd @@ -269,7 +269,8 @@ iris_split <- initial_split(iris, prop = 0.95) iris_train <- training(iris_split) iris_test <- testing(iris_split) -mda_spec <- mixture_da(sub_classes = 2) +mda_spec <- mixture_da(sub_classes = 2) %>% + set_engine("mda") mda_fit <- mda_spec %>% fit(Species ~ ., data = iris_train, engine = "mda") From 2700f96f9d6779e9fbf55e3e1616c3b39e182473 Mon Sep 17 00:00:00 2001 From: topepo Date: Fri, 31 May 2019 21:46:08 -0400 Subject: [PATCH 35/64] more travis investigations --- tests/testthat/test_surv_reg_survreg.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/testthat/test_surv_reg_survreg.R b/tests/testthat/test_surv_reg_survreg.R index a8e05b135..1d2ca9dd3 100644 --- a/tests/testthat/test_surv_reg_survreg.R +++ b/tests/testthat/test_surv_reg_survreg.R @@ -19,7 +19,7 @@ quiet_ctrl <- fit_control(verbosity = 0, catch = TRUE) test_that('survival execution', { - skip() + skip("troubleshooting travis") expect_error( res <- fit( @@ -50,7 +50,7 @@ test_that('survival execution', { }) test_that('survival prediction', { - skip() + skip("troubleshooting travis") res <- fit( surv_basic, From c69037b9e3bdb9f49a65f24835865ca5efd17fac Mon Sep 17 00:00:00 2001 From: topepo Date: Fri, 31 May 2019 22:04:00 -0400 Subject: [PATCH 36/64] fixed rd files --- R/aaa_models.R | 47 +++++++++++++++++++--------------------- _pkgdown.yml | 14 +++++++++++- man/check_mod_val.Rd | 43 ++++++++++++++++++++++++++++++++++++- man/get_model_env.Rd | 51 
-------------------------------------------- 4 files changed, 77 insertions(+), 78 deletions(-) delete mode 100644 man/get_model_env.Rd diff --git a/R/aaa_models.R b/R/aaa_models.R index 08f5a9514..de5b87c77 100644 --- a/R/aaa_models.R +++ b/R/aaa_models.R @@ -40,18 +40,6 @@ pred_types <- #' Tools to Register Models #' -#' @keywords internal -#' @export -get_model_env <- function() { - current <- utils::getFromNamespace("parsnip", ns = "parsnip") - # current <- parsnip - current -} - - - -#' Tools to Check Model Elements -#' #' These functions are similar to constructors and can be used to validate #' that there are no conflicts with the underlying model structures used by the #' package. @@ -89,6 +77,7 @@ get_model_env <- function() { #' underlying model function uses. #' @param value A list that conforms to the `fit_obj` or `pred_obj` description #' above, depending on context. +#' @param items A character string of objects in the model environment. #' @keywords internal #' @details These functions are available for users to add their #' own models or engines (in package or otherwise) so that they can @@ -137,6 +126,15 @@ check_mod_val <- function(model, new = FALSE, existence = FALSE) { invisible(NULL) } +#' @rdname check_mod_val +#' @keywords internal +#' @export +get_model_env <- function() { + current <- utils::getFromNamespace("parsnip", ns = "parsnip") + # current <- parsnip + current +} + #' @rdname check_mod_val #' @keywords internal #' @export @@ -295,7 +293,7 @@ check_pkg_val <- function(pkg) { # ------------------------------------------------------------------------------ -#' @rdname get_model_env +#' @rdname check_mod_val #' @keywords internal #' @export set_new_model <- function(model) { @@ -333,7 +331,7 @@ set_new_model <- function(model) { # ------------------------------------------------------------------------------ -#' @rdname get_model_env +#' @rdname check_mod_val #' @keywords internal #' @export set_model_mode <- function(model, mode) { @@ -353,7 +351,7 @@ set_model_mode <- function(model, mode) { # ------------------------------------------------------------------------------ -#' @rdname get_model_env +#' @rdname check_mod_val #' @keywords internal #' @export set_model_engine <- function(model, mode, eng) { @@ -378,7 +376,7 @@ set_model_engine <- function(model, mode, eng) { # ------------------------------------------------------------------------------ -#' @rdname get_model_env +#' @rdname check_mod_val #' @keywords internal #' @export set_model_arg <- function(model, eng, parsnip, original, func, has_submodel) { @@ -417,7 +415,7 @@ set_model_arg <- function(model, eng, parsnip, original, func, has_submodel) { # ------------------------------------------------------------------------------ -#' @rdname get_model_env +#' @rdname check_mod_val #' @keywords internal #' @export set_dependency <- function(model, eng, pkg) { @@ -460,7 +458,7 @@ set_dependency <- function(model, eng, pkg) { invisible(NULL) } -#' @rdname get_model_env +#' @rdname check_mod_val #' @keywords internal #' @export get_dependency <- function(model) { @@ -475,7 +473,7 @@ get_dependency <- function(model) { # ------------------------------------------------------------------------------ -#' @rdname get_model_env +#' @rdname check_mod_val #' @keywords internal #' @export set_fit <- function(model, mode, eng, value) { @@ -526,7 +524,7 @@ set_fit <- function(model, mode, eng, value) { invisible(NULL) } -#' @rdname get_model_env +#' @rdname check_mod_val #' @keywords internal #' @export get_fit <- 
function(model) { @@ -540,7 +538,7 @@ get_fit <- function(model) { # ------------------------------------------------------------------------------ -#' @rdname get_model_env +#' @rdname check_mod_val #' @keywords internal #' @export set_pred <- function(model, mode, eng, type, value) { @@ -592,7 +590,7 @@ set_pred <- function(model, mode, eng, type, value) { invisible(NULL) } -#' @rdname get_model_env +#' @rdname check_mod_val #' @keywords internal #' @export get_pred_type <- function(model, type) { @@ -618,7 +616,7 @@ validate_model <- function(model) { # ------------------------------------------------------------------------------ -#' @rdname get_model_env +#' @rdname check_mod_val #' @keywords internal #' @export show_model_info <- function(model) { @@ -711,10 +709,9 @@ show_model_info <- function(model) { invisible(NULL) } -#' @rdname get_model_env +#' @rdname check_mod_val #' @keywords internal #' @export -#' @param items A character string of objects in the model environment. get_from_env <- function(items) { mod_env <- get_model_env() rlang::env_get(mod_env, items) diff --git a/_pkgdown.yml b/_pkgdown.yml index 0f4430bf4..644889998 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -53,7 +53,19 @@ reference: contents: - lending_club - wa_churn - - check_times + - check_times + + - title: Developer Tools + contents: + - set_new_model + - starts_with("set_model_") + - set_dependency + - set_fit + - set_pred + - show_model_info + - starts_with("get_") + + navbar: left: diff --git a/man/check_mod_val.Rd b/man/check_mod_val.Rd index 6efe8d8d8..ed32406d9 100644 --- a/man/check_mod_val.Rd +++ b/man/check_mod_val.Rd @@ -4,6 +4,7 @@ \name{pred_types} \alias{pred_types} \alias{check_mod_val} +\alias{get_model_env} \alias{check_mode_val} \alias{check_engine_val} \alias{check_arg_val} @@ -12,13 +13,27 @@ \alias{check_fit_info} \alias{check_pred_info} \alias{check_pkg_val} -\title{Tools to Check Model Elements} +\alias{set_new_model} +\alias{set_model_mode} +\alias{set_model_engine} +\alias{set_model_arg} +\alias{set_dependency} +\alias{get_dependency} +\alias{set_fit} +\alias{get_fit} +\alias{set_pred} +\alias{get_pred_type} +\alias{show_model_info} +\alias{get_from_env} +\title{Tools to Register Models} \format{An object of class \code{character} of length 8.} \usage{ pred_types check_mod_val(model, new = FALSE, existence = FALSE) +get_model_env() + check_mode_val(mode) check_engine_val(eng) @@ -34,6 +49,30 @@ check_fit_info(fit_obj) check_pred_info(pred_obj, type) check_pkg_val(pkg) + +set_new_model(model) + +set_model_mode(model, mode) + +set_model_engine(model, mode, eng) + +set_model_arg(model, eng, parsnip, original, func, has_submodel) + +set_dependency(model, eng, pkg) + +get_dependency(model) + +set_fit(model, mode, eng, value) + +get_fit(model) + +set_pred(model, mode, eng, type, value) + +get_pred_type(model, type) + +show_model_info(model) + +get_from_env(items) } \arguments{ \item{model}{A single character string for the model type (e.g. 
@@ -83,6 +122,8 @@ underlying model function uses.} \item{value}{A list that conforms to the \code{fit_obj} or \code{pred_obj} description above, depending on context.} + +\item{items}{A character string of objects in the model environment.} } \description{ These functions are similar to constructors and can be used to validate diff --git a/man/get_model_env.Rd b/man/get_model_env.Rd deleted file mode 100644 index c2aa37bbb..000000000 --- a/man/get_model_env.Rd +++ /dev/null @@ -1,51 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/aaa_models.R -\name{get_model_env} -\alias{get_model_env} -\alias{set_new_model} -\alias{set_model_mode} -\alias{set_model_engine} -\alias{set_model_arg} -\alias{set_dependency} -\alias{get_dependency} -\alias{set_fit} -\alias{get_fit} -\alias{set_pred} -\alias{get_pred_type} -\alias{show_model_info} -\alias{get_from_env} -\title{Tools to Register Models} -\usage{ -get_model_env() - -set_new_model(model) - -set_model_mode(model, mode) - -set_model_engine(model, mode, eng) - -set_model_arg(model, eng, parsnip, original, func, has_submodel) - -set_dependency(model, eng, pkg) - -get_dependency(model) - -set_fit(model, mode, eng, value) - -get_fit(model) - -set_pred(model, mode, eng, type, value) - -get_pred_type(model, type) - -show_model_info(model) - -get_from_env(items) -} -\arguments{ -\item{items}{A character string of objects in the model environment.} -} -\description{ -Tools to Register Models -} -\keyword{internal} From c7054fedff482fb50b5d91ae85d02aae1aea1829 Mon Sep 17 00:00:00 2001 From: topepo Date: Fri, 31 May 2019 22:05:48 -0400 Subject: [PATCH 37/64] version exclusions from rmarkdown and partykit --- .travis.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.travis.yml b/.travis.yml index f7f839773..279c2ee9a 100644 --- a/.travis.yml +++ b/.travis.yml @@ -17,6 +17,7 @@ r: matrix: allow_failures: - r: 3.2 + - r: oldrel r_binary_packages: - RCurl From 648a9d0af5662929477a0495dd408e492689619b Mon Sep 17 00:00:00 2001 From: topepo Date: Sat, 1 Jun 2019 12:30:30 -0400 Subject: [PATCH 38/64] test change and pkgdown update --- tests/testthat/test_surv_reg_survreg.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/testthat/test_surv_reg_survreg.R b/tests/testthat/test_surv_reg_survreg.R index 1d2ca9dd3..5dbc14ce7 100644 --- a/tests/testthat/test_surv_reg_survreg.R +++ b/tests/testthat/test_surv_reg_survreg.R @@ -19,7 +19,7 @@ quiet_ctrl <- fit_control(verbosity = 0, catch = TRUE) test_that('survival execution', { - skip("troubleshooting travis") + skip_on_travis() expect_error( res <- fit( @@ -50,7 +50,7 @@ test_that('survival execution', { }) test_that('survival prediction', { - skip("troubleshooting travis") + skip_on_travis() res <- fit( surv_basic, From 3c23479436192b21641f5fa4b2f09b34da825a16 Mon Sep 17 00:00:00 2001 From: topepo Date: Sat, 1 Jun 2019 12:30:55 -0400 Subject: [PATCH 39/64] pkgdown docs --- .../dev/articles/articles/Classification.html | 71 +- docs/dev/articles/articles/Models.html | 803 ++++++++++++++---- docs/dev/articles/articles/Regression.html | 46 +- .../figure-html/glmn-pred-1.png | Bin 224760 -> 225062 bytes docs/dev/articles/articles/Scratch.html | 663 +++++++-------- docs/dev/articles/index.html | 2 +- docs/dev/articles/parsnip_Intro.html | 16 +- docs/dev/index.html | 28 +- docs/dev/news/index.html | 41 +- docs/dev/pkgdown.css | 27 +- docs/dev/pkgdown.js | 9 +- docs/dev/pkgdown.yml | 6 +- docs/dev/reference/add_rowindex.html | 219 +++++ 
docs/dev/reference/boost_tree.html | 20 +- docs/dev/reference/check_mod_val.html | 418 +++++++++ docs/dev/reference/check_times.html | 4 +- docs/dev/reference/decision_tree.html | 20 +- docs/dev/reference/descriptors.html | 6 +- docs/dev/reference/fit.html | 25 +- docs/dev/reference/fit_control.html | 4 +- docs/dev/reference/get_model_env.html | 191 +++++ docs/dev/reference/index.html | 21 + docs/dev/reference/keras_mlp.html | 2 +- docs/dev/reference/lending_club.html | 4 +- docs/dev/reference/linear_reg.html | 24 +- docs/dev/reference/logistic_reg.html | 24 +- docs/dev/reference/mars.html | 16 +- docs/dev/reference/mlp.html | 20 +- docs/dev/reference/model_fit.html | 6 +- docs/dev/reference/model_spec.html | 4 +- docs/dev/reference/multinom_reg.html | 24 +- docs/dev/reference/nearest_neighbor.html | 13 +- docs/dev/reference/null_model.html | 4 +- docs/dev/reference/nullmodel.html | 10 +- docs/dev/reference/predict.model_fit.html | 72 +- docs/dev/reference/rand_forest.html | 18 +- docs/dev/reference/reexports.html | 2 +- docs/dev/reference/surv_reg.html | 16 +- docs/dev/reference/svm_poly.html | 16 +- docs/dev/reference/svm_rbf.html | 16 +- docs/dev/reference/translate.html | 2 +- .../reference/varying_args.model_spec.html | 60 +- docs/dev/reference/wa_churn.html | 4 +- 43 files changed, 2190 insertions(+), 807 deletions(-) create mode 100644 docs/dev/reference/add_rowindex.html create mode 100644 docs/dev/reference/check_mod_val.html create mode 100644 docs/dev/reference/get_model_env.html diff --git a/docs/dev/articles/articles/Classification.html b/docs/dev/articles/articles/Classification.html index f702a65ee..9f5ef1ac1 100644 --- a/docs/dev/articles/articles/Classification.html +++ b/docs/dev/articles/articles/Classification.html @@ -101,24 +101,24 @@

Classification Example

To demonstrate parsnip for classification models, the credit data will be used.

-
library(tidymodels)
+
library(tidymodels)
 #> Registered S3 method overwritten by 'xts':
 #>   method     from
 #>   as.zoo.xts zoo
-#> ── Attaching packages ───────────────────────────────── tidymodels 0.0.2 ──
+#> ── Attaching packages ──────────────────────────────────────────────────────────────────────────────────── tidymodels 0.0.2 ──
 #> ✔ broom     0.5.1       ✔ purrr     0.3.2  
-#> ✔ dials     0.0.2       ✔ recipes   0.1.4  
+#> ✔ dials     0.0.2       ✔ recipes   0.1.5  
 #> ✔ dplyr     0.8.0.1     ✔ rsample   0.0.4  
 #> ✔ infer     0.4.0       ✔ yardstick 0.0.2
-#> ── Conflicts ──────────────────────────────────── tidymodels_conflicts() ──
+#> ── Conflicts ─────────────────────────────────────────────────────────────────────────────────────── tidymodels_conflicts() ──
 #> ✖ purrr::discard() masks scales::discard()
 #> ✖ dplyr::filter()  masks stats::filter()
 #> ✖ dplyr::lag()     masks stats::lag()
 #> ✖ recipes::step()  masks stats::step()
 
-data(credit_data)
+data(credit_data)
 
-set.seed(7075)
+set.seed(7075)
 data_split <- initial_split(credit_data, strata = "Status", p = 0.75)
 
 credit_train <- training(data_split)
@@ -136,41 +136,42 @@ 

Classification Example

test_normalized <- bake(credit_rec, new_data = credit_test, all_predictors())

keras will be used to fit a model with 5 hidden units that uses a 10% dropout rate to regularize the model. At each training iteration (aka epoch), a random 20% of the data will be used to measure the cross-entropy of the model.

-
set.seed(57974)
+
+  set_mode("classification") %>% 
+  # Also set engine-specific arguments: 
+  set_engine("keras", verbose = 0, validation_split = .20) %>%
+  fit(Status ~ ., data = juice(credit_rec))
+
+nnet_fit
+#> parsnip model object
+#> 
+#> Model
+#> ___________________________________________________________________________
+#> Layer (type)                     Output Shape                  Param #     
+#> ===========================================================================
+#> dense_1 (Dense)                  (None, 5)                     115         
+#> ___________________________________________________________________________
+#> dense_2 (Dense)                  (None, 5)                     30          
+#> ___________________________________________________________________________
+#> dropout_1 (Dropout)              (None, 5)                     0           
+#> ___________________________________________________________________________
+#> dense_3 (Dense)                  (None, 2)                     12          
+#> ===========================================================================
+#> Total params: 157
+#> Trainable params: 157
+#> Non-trainable params: 0
+#> ___________________________________________________________________________

In parsnip, the predict function can be used:
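A sketch of the kind of call that produces the table below, assuming the nnet_fit and test_normalized objects created above:

test_results <- predict(nnet_fit, new_data = test_normalized, type = "class")
table(test_results$.pred_class, credit_test$Status)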

+#> bad 185 89
+#> good 128 711





General Interface for Boosted Trees

time that the model is fit. Other options and arguments can be set using the set_engine() function. If left to their defaults here (NULL), the values are taken from the underlying model -functions. If parameters need to be modified, update() can be used +functions. If parameters need to be modified, update() can be used in lieu of recreating the object from scratch.

@@ -181,7 +181,7 @@

General Interface for Boosted Trees

loss_reduction = NULL, sample_size = NULL) # S3 method for boost_tree -update(object, mtry = NULL, trees = NULL, +update(object, mtry = NULL, trees = NULL, min_n = NULL, tree_depth = NULL, learn_rate = NULL, loss_reduction = NULL, sample_size = NULL, fresh = FALSE, ...) @@ -242,7 +242,7 @@

Arg ... -

Not used for update().

+

Not used for update().

@@ -255,9 +255,9 @@

Details

The data given to the function are not saved and are only used to determine the mode of the model. For boost_tree(), the possible modes are "regression" and "classification".

-

The model can be created using the fit() function using the +

The model can be created using the fit() function using the following engines:

    -
  • R: "xgboost", "C5.0"

  • +
  • R: "xgboost" (the default), "C5.0"

  • Spark: "spark"

@@ -265,13 +265,13 @@

Note

For models created using the spark engine, there are several differences to consider. First, only the formula -interface to fit() is available; using fit_xy() will +interface to fit() is available; using fit_xy() will generate an error. Second, the predictions will always be in a spark table format. The names will be the same as documented but without the dots. Third, there is no equivalent to factor columns in spark tables so class predictions are returned as character columns. Fourth, to retain the model object for a new -R session (via save()), the model$fit element of the parsnip +R session (via save()), the model$fit element of the parsnip object should be serialized via ml_save(object$fit) and separately saved to disk. In a new session, the object can be reloaded and reattached to the parsnip object.
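A sketch of that last step, assuming a fitted spark-based parsnip model named spark_fit and an active connection sc (both hypothetical names):

library(sparklyr)

# persist only the underlying Spark model, not the parsnip wrapper
ml_save(spark_fit$fit, "path/to/model_dir")

# in a new session: reload and reattach to the parsnip object
spark_fit$fit <- ml_load(sc, "path/to/model_dir")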

@@ -309,7 +309,7 @@

See also

- +

Examples

@@ -328,12 +328,12 @@

Examp #> Main Arguments: #> mtry = 10 #> min_n = 3 -#>
update(model, mtry = 1)
#> Boosted Tree Model Specification (unknown) +#>
update(model, mtry = 1)
#> Boosted Tree Model Specification (unknown) #> #> Main Arguments: #> mtry = 1 #> min_n = 3 -#>
update(model, mtry = 1, fresh = TRUE)
#> Boosted Tree Model Specification (unknown) +#>
update(model, mtry = 1, fresh = TRUE)
#> Boosted Tree Model Specification (unknown) #> #> Main Arguments: #> mtry = 1 diff --git a/docs/dev/reference/check_mod_val.html b/docs/dev/reference/check_mod_val.html new file mode 100644 index 000000000..5de0f14f1 --- /dev/null +++ b/docs/dev/reference/check_mod_val.html @@ -0,0 +1,418 @@ + + + + + + + + +Tools to Register Models — pred_types • parsnip + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+
+ + + +
+ +
+
+ + +
+ +

These functions are similar to constructors and can be used to validate +that there are no conflicts with the underlying model structures used by the +package.

+ +
+ +
pred_types
+
+check_mod_val(model, new = FALSE, existence = FALSE)
+
+get_model_env()
+
+check_mode_val(mode)
+
+check_engine_val(eng)
+
+check_arg_val(arg)
+
+check_submodels_val(has_submodel)
+
+check_func_val(func)
+
+check_fit_info(fit_obj)
+
+check_pred_info(pred_obj, type)
+
+check_pkg_val(pkg)
+
+set_new_model(model)
+
+set_model_mode(model, mode)
+
+set_model_engine(model, mode, eng)
+
+set_model_arg(model, eng, parsnip, original, func, has_submodel)
+
+set_dependency(model, eng, pkg)
+
+get_dependency(model)
+
+set_fit(model, mode, eng, value)
+
+get_fit(model)
+
+set_pred(model, mode, eng, type, value)
+
+get_pred_type(model, type)
+
+show_model_info(model)
+
+get_from_env(items)
+ +

Arguments

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
model

A single character string for the model type (e.g. +"rand_forest", etc).

new

A single logical to check to see if the model that you are checking +has not already been registered.

existence

A single logical to check to see if the model has already +been registered.

mode

A single character string for the model mode (e.g. "regression").

eng

A single character string for the model engine.

arg

A single character string for the model argument name.

has_submodel

A single logical for whether the argument +can make predictions on multiple submodels at once.

func

A named character vector that describes how to call +a function. func should have elements pkg and fun. The +former is optional but is recommended and the latter is +required. For example, c(pkg = "stats", fun = "lm") would be +used to invoke the usual linear regression function. In some +cases, it is helpful to use c(fun = "predict") when using a +package's predict method.

fit_obj

A list with elements interface, protect, +func and defaults. See the package vignette "Making a +parsnip model from scratch".

pred_obj

A list with elements pre, post, func, and args. +See the package vignette "Making a parsnip model from scratch".

type

A single character value for the type of prediction. Possible +values are: +'raw', 'numeric', 'class', 'link', 'prob', 'conf_int', 'pred_int', 'quantile' +.

pkg

An optional character string for a package name.

parsnip

A single character string for the "harmonized" argument name +that parsnip exposes.

original

A single character string for the argument name that +underlying model function uses.

value

A list that conforms to the fit_obj or pred_obj description +above, depending on context.

items

A character string of objects in the model environment.

+ +

Format

+ +

An object of class character of length 8.

+ +

Details

+ +

These functions are available for users to add their +own models or engines (in package or otherwise) so that they can +be accessed using parsnip. These are more thoroughly documented +on the package web site (see references below).

+

In short, parsnip stores an environment object that contains +all of the information and code about how models are used (e.g. +fitting, predicting, etc). These functions can be used to add +models to that environment as well as helper functions that can +be used to make sure that the model data is in the right +format.
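For instance, a minimal sketch of seeding that environment with a brand-new model; every name below is hypothetical:

set_new_model("my_model")
set_model_mode("my_model", "classification")
set_model_engine("my_model", mode = "classification", eng = "my_engine")
set_dependency("my_model", eng = "my_engine", pkg = "my_pkg")
show_model_info("my_model")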

+ +

References

+ +

"Making a parsnip model from scratch" +https://tidymodels.github.io/parsnip/articles/articles/Scratch.html

+ + +

Examples

+
# Show the information about a model: +show_model_info("rand_forest")
#> Information for `rand_forest` +#> modes: unknown, classification, regression +#> +#> engines: +#> classification: randomForest, ranger, spark +#> regression: randomForest, ranger, spark +#> +#> arguments: +#> ranger: +#> mtry --> mtry +#> trees --> num.trees +#> min_n --> min.node.size +#> randomForest: +#> mtry --> mtry +#> trees --> ntree +#> min_n --> nodesize +#> spark: +#> mtry --> feature_subset_strategy +#> trees --> num_trees +#> min_n --> min_instances_per_node +#> +#> fit modules: +#> engine mode +#> ranger classification +#> ranger regression +#> randomForest classification +#> randomForest regression +#> spark classification +#> spark regression +#> +#> prediction modules: +#> mode engine methods +#> classification randomForest class, prob, raw +#> classification ranger class, conf_int, prob, raw +#> classification spark class, prob +#> regression randomForest numeric, raw +#> regression ranger conf_int, numeric, raw +#> regression spark numeric +#>
+# Access the model data: +current_code <- get_model_env() +ls(envir = current_code)
#> [1] "boost_tree" "boost_tree_args" +#> [3] "boost_tree_fit" "boost_tree_modes" +#> [5] "boost_tree_pkgs" "boost_tree_predict" +#> [7] "decision_tree" "decision_tree_args" +#> [9] "decision_tree_fit" "decision_tree_modes" +#> [11] "decision_tree_pkgs" "decision_tree_predict" +#> [13] "linear_reg" "linear_reg_args" +#> [15] "linear_reg_fit" "linear_reg_modes" +#> [17] "linear_reg_pkgs" "linear_reg_predict" +#> [19] "logistic_reg" "logistic_reg_args" +#> [21] "logistic_reg_fit" "logistic_reg_modes" +#> [23] "logistic_reg_pkgs" "logistic_reg_predict" +#> [25] "mars" "mars_args" +#> [27] "mars_fit" "mars_modes" +#> [29] "mars_pkgs" "mars_predict" +#> [31] "mlp" "mlp_args" +#> [33] "mlp_fit" "mlp_modes" +#> [35] "mlp_pkgs" "mlp_predict" +#> [37] "models" "modes" +#> [39] "multinom_reg" "multinom_reg_args" +#> [41] "multinom_reg_fit" "multinom_reg_modes" +#> [43] "multinom_reg_pkgs" "multinom_reg_predict" +#> [45] "nearest_neighbor" "nearest_neighbor_args" +#> [47] "nearest_neighbor_fit" "nearest_neighbor_modes" +#> [49] "nearest_neighbor_pkgs" "nearest_neighbor_predict" +#> [51] "null_model" "null_model_args" +#> [53] "null_model_fit" "null_model_modes" +#> [55] "null_model_pkgs" "null_model_predict" +#> [57] "rand_forest" "rand_forest_args" +#> [59] "rand_forest_fit" "rand_forest_modes" +#> [61] "rand_forest_pkgs" "rand_forest_predict" +#> [63] "surv_reg" "surv_reg_args" +#> [65] "surv_reg_fit" "surv_reg_modes" +#> [67] "surv_reg_pkgs" "surv_reg_predict" +#> [69] "svm_poly" "svm_poly_args" +#> [71] "svm_poly_fit" "svm_poly_modes" +#> [73] "svm_poly_pkgs" "svm_poly_predict" +#> [75] "svm_rbf" "svm_rbf_args" +#> [77] "svm_rbf_fit" "svm_rbf_modes" +#> [79] "svm_rbf_pkgs" "svm_rbf_predict"
+
+
+ +
+ +
+
+


+
+ +
+


+
+
+
+ + + + + + diff --git a/docs/dev/reference/check_times.html b/docs/dev/reference/check_times.html index 08fff9fe5..0b6c131d5 100644 --- a/docs/dev/reference/check_times.html +++ b/docs/dev/reference/check_times.html @@ -195,8 +195,8 @@

Details

Examples

-
data(check_times) -str(check_times)
#> Classes ‘tbl_df’, ‘tbl’ and 'data.frame': 13626 obs. of 25 variables: +
data(check_times) +str(check_times)
#> Classes ‘tbl_df’, ‘tbl’ and 'data.frame': 13626 obs. of 25 variables: #> $ package : chr "A3" "abbyyR" "abc" "abc.data" ... #> $ authors : int 1 1 1 1 5 3 2 1 4 6 ... #> $ imports : num 0 6 0 0 3 1 0 4 0 7 ... diff --git a/docs/dev/reference/decision_tree.html b/docs/dev/reference/decision_tree.html index bbb7cf34b..22825643e 100644 --- a/docs/dev/reference/decision_tree.html +++ b/docs/dev/reference/decision_tree.html @@ -159,7 +159,7 @@

General Interface for Decision Tree Models

time that the model is fit. Other options and arguments can be set using set_engine(). If left to their defaults here (NULL), the values are taken from the underlying model -functions. If parameters need to be modified, update() can be used +functions. If parameters need to be modified, update() can be used in lieu of recreating the object from scratch.

@@ -168,7 +168,7 @@

General Interface for Decision Tree Models

tree_depth = NULL, min_n = NULL) # S3 method for decision_tree -update(object, cost_complexity = NULL, +update(object, cost_complexity = NULL, tree_depth = NULL, min_n = NULL, fresh = FALSE, ...)

Arguments

@@ -205,15 +205,15 @@

Arg ... -

Not used for update().

+

Not used for update().

Details

-

The model can be created using the fit() function using the +

The model can be created using the fit() function using the following engines:

    -
  • R: "rpart" or "C5.0" (classification only)

  • +
  • R: "rpart" (the default) or "C5.0" (classification only)

  • Spark: "spark"

Note that, for rpart models, but cost_complexity and @@ -226,13 +226,13 @@

Note

For models created using the spark engine, there are several differences to consider. First, only the formula -interface to fit() is available; using fit_xy() will +interface to fit() is available; using fit_xy() will generate an error. Second, the predictions will always be in a spark table format. The names will be the same as documented but without the dots. Third, there is no equivalent to factor columns in spark tables so class predictions are returned as character columns. Fourth, to retain the model object for a new -R session (via save()), the model$fit element of the parsnip +R session (via save()), the model$fit element of the parsnip object should be serialized via ml_save(object$fit) and separately saved to disk. In a new session, the object can be reloaded and reattached to the parsnip object.

@@ -269,7 +269,7 @@

See also

- +

Examples

@@ -288,12 +288,12 @@

Examp #> Main Arguments: #> cost_complexity = 10 #> min_n = 3 -#>

update(model, cost_complexity = 1)
#> Decision Tree Model Specification (unknown) +#>
update(model, cost_complexity = 1)
#> Decision Tree Model Specification (unknown) #> #> Main Arguments: #> cost_complexity = 1 #> min_n = 3 -#>
update(model, cost_complexity = 1, fresh = TRUE)
#> Decision Tree Model Specification (unknown) +#>
update(model, cost_complexity = 1, fresh = TRUE)
#> Decision Tree Model Specification (unknown) #> #> Main Arguments: #> cost_complexity = 1 diff --git a/docs/dev/reference/descriptors.html b/docs/dev/reference/descriptors.html index cf0396989..22694b7e7 100644 --- a/docs/dev/reference/descriptors.html +++ b/docs/dev/reference/descriptors.html @@ -134,7 +134,7 @@

Data Set Characteristics Available when Fitting Models

-

When using the fit() functions there are some +

When using the fit() functions there are some variables that will be available for use in arguments. For example, if the user would like to choose an argument value based on the current number of rows in a data set, the .obs() @@ -174,7 +174,7 @@

Details
  • .y(): The known outcomes returned in the format given. Either a vector, matrix, or data frame.

  • .dat(): A data frame containing all of the predictors and the -outcomes. If fit_xy() was used, the outcomes are attached as the +outcomes. If fit_xy() was used, the outcomes are attached as the column, ..y.

  • For example, if you use the model formula Sepal.Width ~ . with the iris @@ -200,7 +200,7 @@

    Details

    To use these in a model fit, pass them to a model specification. The evaluation is delayed until the time when the -model is run via fit() (and the variables listed above are available). +model is run via fit() (and the variables listed above are available). For example:

         data("lending_club")
    diff --git a/docs/dev/reference/fit.html b/docs/dev/reference/fit.html
    index 59c76b4e1..8bd1bbf94 100644
    --- a/docs/dev/reference/fit.html
    +++ b/docs/dev/reference/fit.html
    @@ -132,18 +132,18 @@ 

    Fit a Model Specification to a Dataset

    -

    fit() and fit_xy() take a model specification, translate the required +

    fit() and fit_xy() take a model specification, translate the required code by substituting arguments, and execute the model fit routine.

    # S3 method for model_spec
    -fit(object, formula = NULL, data = NULL,
    +fit(object, formula = NULL, data = NULL,
       control = fit_control(), ...)
     
     # S3 method for model_spec
    -fit_xy(object, x = NULL, y = NULL,
    +fit_xy(object, x = NULL, y = NULL,
       control = fit_control(), ...)

    Arguments

    @@ -206,21 +206,24 @@

    Value

    Details

    -

fit() and fit_xy() substitute the current arguments in the model +specification into the computational engine's code, check them for validity, then fit the model using the data and the engine-specific code. Different model functions have different interfaces (e.g. formula or x/y) and these functions translate -between the interface used when fit() or fit_xy() were invoked and the one +between the interface used when fit() or fit_xy() were invoked and the one required by the underlying model.

    fit() and fit_xy() substitute the current arguments in the model specification into the computational engine's code, checks them for validity, then fits the model using the data and the engine-specific code. Different model functions have different interfaces (e.g. formula or x/y) and these functions translate -between the interface used when fit() or fit_xy() were invoked and the one +between the interface used when fit() or fit_xy() were invoked and the one required by the underlying model.

    When possible, these functions attempt to avoid making copies of the data. For example, if the underlying model uses a formula and -fit() is invoked, the original data are references +fit() is invoked, the original data are references when the model is fit. However, if the underlying model uses something else, such as x/y, the formula is evaluated and the data are converted to the required format. In this case, any calls in the resulting model objects reference the temporary objects used to fit the model.

    +

    If the model engine has not been set, the model's default engine will be used +(as discussed on each model page). If the verbosity option of +fit_control() is greater than zero, a warning will be produced.

    See also

    @@ -231,28 +234,26 @@

    Examp
    # Although `glm()` only has a formula interface, different # methods for specifying the model can be used -library(dplyr)
    #> +library(dplyr)
    #> #> Attaching package: ‘dplyr’
    #> The following object is masked from ‘package:testthat’: #> #> matches
    #> The following objects are masked from ‘package:stats’: #> #> filter, lag
    #> The following objects are masked from ‘package:base’: #> -#> intersect, setdiff, setequal, union
    data("lending_club") - -lr_mod <- logistic_reg() +#> intersect, setdiff, setequal, union
    data("lending_club") lr_mod <- logistic_reg() using_formula <- lr_mod %>% set_engine("glm") %>% - fit(Class ~ funded_amnt + int_rate, data = lending_club) + fit(Class ~ funded_amnt + int_rate, data = lending_club) using_xy <- lr_mod %>% set_engine("glm") %>% - fit_xy(x = lending_club[, c("funded_amnt", "int_rate")], + fit_xy(x = lending_club[, c("funded_amnt", "int_rate")], y = lending_club$Class) using_formula
    #> parsnip model object diff --git a/docs/dev/reference/fit_control.html b/docs/dev/reference/fit_control.html index aa73f18a4..d93ee2181 100644 --- a/docs/dev/reference/fit_control.html +++ b/docs/dev/reference/fit_control.html @@ -131,7 +131,7 @@

    Control the fit function

    -

    Options can be passed to the fit() function that control the output and +

    Options can be passed to the fit() function that control the output and computations

    @@ -154,7 +154,7 @@

    Arg catch

    A logical where a value of TRUE will evaluate -the model inside of try(, silent = TRUE). If the model fails, +the model inside of try(, silent = TRUE). If the model fails, an object is still returned (without an error) that inherits the class "try-error".
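A short sketch of that behavior; the broken formula is deliberate and hypothetical:

ctrl <- fit_control(verbosity = 0, catch = TRUE)

# not_a_column does not exist in mtcars, but no error is thrown; the
# captured condition is kept on the returned object instead
bad_fit <- linear_reg() %>%
  set_engine("lm") %>%
  fit(not_a_column ~ ., data = mtcars, control = ctrl)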

diff --git a/docs/dev/reference/get_model_env.html b/docs/dev/reference/get_model_env.html
new file mode 100644
index 000000000..0c2b8d9bb
--- /dev/null
+++ b/docs/dev/reference/get_model_env.html
@@ -0,0 +1,191 @@

Tools to Register Models — get_model_env • parsnip
    Tools to Register Models

get_model_env()

set_new_model(model)

set_model_mode(model, mode)

set_model_engine(model, mode, eng)

set_model_arg(model, eng, parsnip, original, func, has_submodel)

set_dependency(model, eng, pkg)

get_dependency(model)

set_fit(model, mode, eng, value)

get_fit(model)

set_pred(model, mode, eng, type, value)

get_pred_type(model, type)

show_model_info(model)

get_from_env(items)
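A minimal sketch of how these registration functions compose (the model name "my_lm" and its settings are hypothetical; a full registration would also supply fit and prediction modules via set_fit() and set_pred()):

library(parsnip)

# Register a new model, give it a mode, attach an engine, and declare
# the package that engine needs at fit time.
set_new_model("my_lm")
set_model_mode("my_lm", mode = "regression")
set_model_engine("my_lm", mode = "regression", eng = "lm")
set_dependency("my_lm", eng = "lm", pkg = "stats")

# Inspect what has been registered so far.
show_model_info("my_lm")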
diff --git a/docs/dev/reference/index.html b/docs/dev/reference/index.html
index 3252d3df9..0540feb14 100644
--- a/docs/dev/reference/index.html
+++ b/docs/dev/reference/index.html
@@ -226,6 +226,12 @@

    add_rowindex()


    Add a column of row numbers to a data frame


    .cols() .preds() .obs() .lvls() .facts() .x() .y() .dat()

    @@ -329,6 +335,20 @@


    Execution Time Data


    Developer Tools



    pred_types check_mod_val() get_model_env() check_mode_val() check_engine_val() check_arg_val() check_submodels_val() check_func_val() check_fit_info() check_pred_info() check_pkg_val() set_new_model() set_model_mode() set_model_engine() set_model_arg() set_dependency() get_dependency() set_fit() get_fit() set_pred() get_pred_type() show_model_info() get_from_env()


    Tools to Register Models


    @@ -339,6 +359,7 @@

    Contents

  • Models
  • Infrastructure
  • Data
  • Developer Tools

diff --git a/docs/dev/reference/keras_mlp.html b/docs/dev/reference/keras_mlp.html
index 51ffd154c..ed96c15ed 100644
--- a/docs/dev/reference/keras_mlp.html
+++ b/docs/dev/reference/keras_mlp.html
@@ -139,7 +139,7 @@

    Simple interface to MLP models via keras

keras_mlp(x, y, hidden_units = 5, decay = 0, dropout = 0,
  epochs = 20, act = "softmax", seeds = sample.int(10^5, size = 3),
  ...)

    Arguments

diff --git a/docs/dev/reference/lending_club.html b/docs/dev/reference/lending_club.html
index 30bcfe282..bb62cc077 100644
--- a/docs/dev/reference/lending_club.html
+++ b/docs/dev/reference/lending_club.html
@@ -158,8 +158,8 @@

    Details

    Examples

data(lending_club)
str(lending_club)
#> Classes ‘tbl_df’, ‘tbl’ and 'data.frame':  9857 obs. of 23 variables:
#>  $ funded_amnt : int  16100 32000 10000 16800 3500 10000 11000 15000 6000 20000 ...
#>  $ term        : Factor w/ 2 levels "term_36","term_60": 1 2 1 2 1 1 1 1 1 2 ...
#>  $ int_rate    : num  13.99 11.99 16.29 13.67 7.39 ...

diff --git a/docs/dev/reference/linear_reg.html b/docs/dev/reference/linear_reg.html
index 35a3375f8..31dfc2fbf 100644
--- a/docs/dev/reference/linear_reg.html
+++ b/docs/dev/reference/linear_reg.html
@@ -155,7 +155,7 @@

    General Interface for Linear Regression Models

time that the model is fit. Other options and arguments can be set using set_engine(). If left to their defaults here (NULL), the values are taken from the underlying model functions. If parameters need to be modified, update() can be used in lieu of recreating the object from scratch.

    @@ -163,7 +163,7 @@

    General Interface for Linear Regression Models

    linear_reg(mode = "regression", penalty = NULL, mixture = NULL)
     
     # S3 method for linear_reg
update(object, penalty = NULL, mixture = NULL,
  fresh = FALSE, ...)

    Arguments

    @@ -200,7 +200,7 @@

...

Not used for update().

    @@ -209,9 +209,9 @@

    Details

    The data given to the function are not saved and are only used to determine the mode of the model. For linear_reg(), the mode will always be "regression".

The model can be created using the fit() function with the following engines (see the sketch after this list):

-• R: "lm" or "glmnet"
+• R: "lm" (the default) or "glmnet"
• Stan: "stan"
• Spark: "spark"
• keras: "keras"

@@ -221,13 +221,13 @@
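A short sketch of picking one of these engines (a hypothetical specification; glmnet itself is only needed at fit time):

library(parsnip)

spec <- linear_reg(penalty = 0.01, mixture = 0.5) %>%
  set_engine("glmnet")

# translate() shows the glmnet call that fit() would build.
translate(spec)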

      Note

For models created using the spark engine, there are several differences to consider. First, only the formula interface via fit() is available; using fit_xy() will generate an error. Second, the predictions will always be in a spark table format. The names will be the same as documented but without the dots. Third, there is no equivalent to factor columns in spark tables so class predictions are returned as character columns. Fourth, to retain the model object for a new R session (via save()), the model$fit element of the parsnip object should be serialized via ml_save(object$fit) and separately saved to disk. In a new session, the object can be reloaded and reattached to the parsnip object.
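A hedged sketch of that serialization caveat (assumes a sparklyr connection sc and a fitted parsnip object spark_fit; the paths are hypothetical):

# Save the spark-side model and the parsnip wrapper separately.
sparklyr::ml_save(spark_fit$fit, path = "lr_model_dir")
saveRDS(spark_fit, "spark_fit.rds")

# In a new session: reload both pieces and reattach.
spark_fit <- readRDS("spark_fit.rds")
spark_fit$fit <- sparklyr::ml_load(sc, path = "lr_model_dir")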

      @@ -265,8 +265,8 @@

      See also


      Examples

      @@ -296,12 +296,12 @@

#> Main Arguments:
#>   penalty = 10
#>   mixture = 0.1
#>
update(model, penalty = 1)
#> Linear Regression Model Specification (regression)
#>
#> Main Arguments:
#>   penalty = 1
#>   mixture = 0.1
#>
update(model, penalty = 1, fresh = TRUE)
#> Linear Regression Model Specification (regression)
#>
#> Main Arguments:
#>   penalty = 1

diff --git a/docs/dev/reference/logistic_reg.html b/docs/dev/reference/logistic_reg.html
index 7dde682bf..ca4332916 100644
--- a/docs/dev/reference/logistic_reg.html
+++ b/docs/dev/reference/logistic_reg.html
@@ -155,7 +155,7 @@

    General Interface for Logistic Regression Models

time that the model is fit. Other options and arguments can be set using set_engine(). If left to their defaults here (NULL), the values are taken from the underlying model functions. If parameters need to be modified, update() can be used in lieu of recreating the object from scratch.

    @@ -163,7 +163,7 @@

    General Interface for Logistic Regression Models

    logistic_reg(mode = "classification", penalty = NULL, mixture = NULL)
     
     # S3 method for logistic_reg
update(object, penalty = NULL, mixture = NULL,
  fresh = FALSE, ...)

    Arguments

    @@ -200,16 +200,16 @@

...

Not used for update().

    Details

    For logistic_reg(), the mode will always be "classification".

The model can be created using the fit() function with the following engines:

-• R: "glm" or "glmnet"
+• R: "glm" (the default) or "glmnet"
• Stan: "stan"
• Spark: "spark"
• keras: "keras"

@@ -219,13 +219,13 @@

      Note

For models created using the spark engine, there are several differences to consider. First, only the formula interface via fit() is available; using fit_xy() will generate an error. Second, the predictions will always be in a spark table format. The names will be the same as documented but without the dots. Third, there is no equivalent to factor columns in spark tables so class predictions are returned as character columns. Fourth, to retain the model object for a new R session (via save()), the model$fit element of the parsnip object should be serialized via ml_save(object$fit) and separately saved to disk. In a new session, the object can be reloaded and reattached to the parsnip object.

      @@ -264,8 +264,8 @@

      See also


      Examples

      @@ -296,12 +296,12 @@

#> Main Arguments:
#>   penalty = 10
#>   mixture = 0.1
#>
update(model, penalty = 1)
#> Logistic Regression Model Specification (classification)
#>
#> Main Arguments:
#>   penalty = 1
#>   mixture = 0.1
#>
update(model, penalty = 1, fresh = TRUE)
#> Logistic Regression Model Specification (classification)
#>
#> Main Arguments:
#>   penalty = 1

diff --git a/docs/dev/reference/mars.html b/docs/dev/reference/mars.html
index 58bc0fe89..fafee1c3a 100644
--- a/docs/dev/reference/mars.html
+++ b/docs/dev/reference/mars.html
@@ -161,7 +161,7 @@

      General Interface for MARS

time that the model is fit. Other options and arguments can be set using set_engine(). If left to their defaults here (NULL), the values are taken from the underlying model functions. If parameters need to be modified, update() can be used in lieu of recreating the object from scratch.

      @@ -170,7 +170,7 @@

      General Interface for MARS

  prune_method = NULL)

# S3 method for mars
update(object, num_terms = NULL, prod_degree = NULL,
  prune_method = NULL, fresh = FALSE, ...)

    Arguments

    @@ -206,15 +206,15 @@

...

Not used for update().

    Details

The model can be created using the fit() function with the following engines:

-• R: "earth"
+• R: "earth" (the default)

    Engine Details

    @@ -239,7 +239,7 @@

    See also


    Examples

    @@ -253,12 +253,12 @@

#> Main Arguments:
#>   num_terms = 10
#>   prune_method = none
#>
update(model, num_terms = 1)
#> MARS Model Specification (unknown)
#>
#> Main Arguments:
#>   num_terms = 1
#>   prune_method = none
#>
update(model, num_terms = 1, fresh = TRUE)
#> MARS Model Specification (unknown)
#>
#> Main Arguments:
#>   num_terms = 1

diff --git a/docs/dev/reference/mlp.html b/docs/dev/reference/mlp.html
index 851e6f3d7..e9b746b82 100644
--- a/docs/dev/reference/mlp.html
+++ b/docs/dev/reference/mlp.html
@@ -170,7 +170,7 @@

    General Interface for Single Layer Neural Network

  dropout = NULL, epochs = NULL, activation = NULL)

# S3 method for mlp
update(object, hidden_units = NULL, penalty = NULL,
  dropout = NULL, epochs = NULL, activation = NULL,
  fresh = FALSE, ...)

@@ -220,7 +220,7 @@

...

Not used for update().

    @@ -230,15 +230,15 @@

Details

time that the model is fit. Other options and arguments can be set using set_engine(). If left to their defaults here (see above), the values are taken from the underlying model functions. One exception is hidden_units when nnet::nnet is used; that function's size argument has no default so a value of 5 units will be used. Also, unless otherwise specified, the linout argument to nnet::nnet() will be set to TRUE when a regression model is created. If parameters need to be modified, update() can be used in lieu of recreating the object from scratch (a short sketch follows the engine list below).

The model can be created using the fit() function with the following engines:

-• R: "nnet"
+• R: "nnet" (the default)

    • keras: "keras"

An error is thrown if both penalty and dropout are specified for keras models.

@@ -271,7 +271,7 @@
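A small sketch of the nnet defaults described above (a hypothetical specification; translate() fills in the values):

library(parsnip)

# With no hidden_units given, the nnet engine template supplies
# size = 5, and linout = TRUE is added for regression models.
mlp(mode = "regression") %>%
  set_engine("nnet") %>%
  translate()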

    See also


    Examples

    @@ -290,12 +290,12 @@

#> Main Arguments:
#>   hidden_units = 10
#>   dropout = 0.3
#>
update(model, hidden_units = 2)
#> Single Layer Neural Network Specification (unknown)
#>
#> Main Arguments:
#>   hidden_units = 2
#>   dropout = 0.3
#>
update(model, hidden_units = 2, fresh = TRUE)
#> Single Layer Neural Network Specification (unknown)
#>
#> Main Arguments:
#>   hidden_units = 2

diff --git a/docs/dev/reference/model_fit.html b/docs/dev/reference/model_fit.html
index 23afa17ec..a800662ad 100644
--- a/docs/dev/reference/model_fit.html
+++ b/docs/dev/reference/model_fit.html
@@ -168,7 +168,7 @@

Examples

# Keep the `x` matrix if the data are not too big.
spec_obj <-
  linear_reg() %>%
  set_engine("lm", x = ifelse(.obs() < 500, TRUE, FALSE))
spec_obj

#> Linear Regression Model Specification (regression)
#>
#> Engine-Specific Arguments:

@@ -176,7 +176,7 @@

#>
#> Computational engine: lm
#>

fit_obj <- fit(spec_obj, mpg ~ ., data = mtcars)
fit_obj
#> parsnip model object
#>
#>

@@ -190,7 +190,7 @@

#>     qsec       vs       am     gear     carb
#>  0.82104  0.31776  2.52023  0.65541 -0.19942
#>

nrow(fit_obj$fit$x)
#> [1] 32

update(model, penalty = 1)
#> Multinomial Regression Model Specification (classification)
#>
#> Main Arguments:
#>   penalty = 1
#>   mixture = 0.1
#>
update(model, penalty = 1, fresh = TRUE)
#> Multinomial Regression Model Specification (classification)
#>
#> Main Arguments:
#>   penalty = 1

diff --git a/docs/dev/reference/nearest_neighbor.html b/docs/dev/reference/nearest_neighbor.html
index 4b9d49f44..d22c15948 100644
--- a/docs/dev/reference/nearest_neighbor.html
+++ b/docs/dev/reference/nearest_neighbor.html
@@ -161,7 +161,7 @@

    General Interface for K-Nearest Neighbor Models

time that the model is fit. Other options and arguments can be set using set_engine(). If left to their defaults here (NULL), the values are taken from the underlying model functions. If parameters need to be modified, update() can be used in lieu of recreating the object from scratch.

    @@ -181,7 +181,8 @@

neighbors

-A single integer for the number of neighbors to consider (often called k).
+A single integer for the number of neighbors to consider (often called k). For kknn, a value of 5 is used if neighbors is not specified.

weight_func

@@ -199,9 +200,9 @@


    Details

The model can be created using the fit() function with the following engines:

-• R: "kknn"
+• R: "kknn" (the default)

    Note

    @@ -221,12 +222,12 @@

    kknn (classification or regression)

     kknn::train.kknn(formula = missing_arg(), data = missing_arg(), 
    -    kmax = missing_arg())
    +    ks = 5)
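A brief sketch showing the default above being filled in (a hypothetical specification):

library(parsnip)

# With neighbors unspecified, the kknn template uses ks = 5.
nearest_neighbor(mode = "regression") %>%
  set_engine("kknn") %>%
  translate()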
     

    See also


    Examples

diff --git a/docs/dev/reference/null_model.html b/docs/dev/reference/null_model.html
index a7e171c4d..53eb934a3 100644
--- a/docs/dev/reference/null_model.html
+++ b/docs/dev/reference/null_model.html
@@ -153,7 +153,7 @@


    Details

The model can be created using the fit() function with the following engines:

    • R: "parsnip"

    @@ -175,7 +175,7 @@

    See also


    Examples

diff --git a/docs/dev/reference/nullmodel.html b/docs/dev/reference/nullmodel.html
index b38efb183..ed7909a78 100644
--- a/docs/dev/reference/nullmodel.html
+++ b/docs/dev/reference/nullmodel.html
@@ -142,10 +142,10 @@

    Fit a simple, non-informative model

nullmodel(x = NULL, y, ...)

# S3 method for nullmodel
print(x, ...)

# S3 method for nullmodel
predict(object, new_data = NULL, type = NULL, ...)

    Arguments

    @@ -207,13 +207,13 @@

    Details

    Examples

outcome <- factor(sample(letters[1:2],
                         size = 100,
                         prob = c(.1, .9),
                         replace = TRUE))
useless <- nullmodel(y = outcome)
useless
#> Null Regression Model
#> Predicted Value: b
predict(useless, matrix(NA, nrow = 5))
#> [1] b b b b b
#> Levels: a b

diff --git a/docs/dev/reference/predict.model_fit.html b/docs/dev/reference/predict.model_fit.html
index 6a6c0aa45..4c6df5f30 100644
--- a/docs/dev/reference/predict.model_fit.html
+++ b/docs/dev/reference/predict.model_fit.html
@@ -133,14 +133,14 @@

    Model predictions

Apply a model to create different types of predictions. predict() can be used for all types of models and uses the "type" argument for more specificity.

# S3 method for model_fit
predict(object, new_data, type = NULL,
  opts = list(), ...)

    Arguments

    @@ -157,7 +157,7 @@


    @@ -213,16 +213,16 @@

    Value

    appear in names and 2) vectors are never returned but type-specific prediction functions.

When the model fit failed and the error was captured, the predict() function will return the same structure as above but filled with missing values. This does not currently work for multivariate models.

    Details

If "type" is not supplied to predict(), then a choice is made (type = "numeric" for regression models and type = "class" for classification).

predict() is designed to provide a tidy result (see "Value" section below) in a tibble output format.

    When using type = "conf_int" and type = "pred_int", the options level and std_error can be used. The latter is a logical for an @@ -230,54 +230,54 @@

    Details

    Examples

library(dplyr)

lm_model <-
  linear_reg() %>%
  set_engine("lm") %>%
  fit(mpg ~ ., data = mtcars %>% slice(11:32))

pred_cars <-
  mtcars %>%
  slice(1:10) %>%
  select(-mpg)

predict(lm_model, pred_cars)
#> # A tibble: 10 x 1
#>    .pred
#>    <dbl>
#>  1  23.4
#>  2  23.3
#>  3  27.6
#>  4  21.5
#>  5  17.6
#>  6  21.6
#>  7  13.9
#>  8  21.7
#>  9  25.6
#> 10  17.1

predict(
  lm_model,
  pred_cars,
  type = "conf_int",
  level = 0.90
)
#> # A tibble: 10 x 2
#>    .pred_lower .pred_upper
#>          <dbl>       <dbl>
#>  1       17.9         29.0
#>  2       18.1         28.5
#>  3       24.0         31.3
#>  4       17.5         25.6
#>  5       14.3         20.8
#>  6       17.0         26.2
#>  7        9.65        18.2
#>  8       16.2         27.2
#>  9       14.2         37.0
#> 10       11.5         22.7

predict(
  lm_model,
  pred_cars,
  type = "raw",
  opts = list(type = "terms")
)
#>             cyl       disp        hp        drat        wt      qsec
#> 1  -0.001433177 -0.8113275 0.6303467 -0.06120265 2.4139815 -1.567729
#> 2  -0.001433177 -0.8113275 0.6303467 -0.06120265 1.4488706 -0.736286

diff --git a/docs/dev/reference/rand_forest.html b/docs/dev/reference/rand_forest.html
index 663b28935..1dea25281 100644
--- a/docs/dev/reference/rand_forest.html
+++ b/docs/dev/reference/rand_forest.html
@@ -157,7 +157,7 @@

    General Interface for Random Forest Models

time that the model is fit. Other options and arguments can be set using set_engine(). If left to their defaults here (NULL), the values are taken from the underlying model functions. If parameters need to be modified, update() can be used in lieu of recreating the object from scratch.

    @@ -166,7 +166,7 @@

    General Interface for Random Forest Models

  min_n = NULL)

# S3 method for rand_forest
update(object, mtry = NULL, trees = NULL,
  min_n = NULL, fresh = FALSE, ...)

    Arguments

    @@ -204,15 +204,15 @@

type

A single character value or NULL. Possible values are "numeric", "class", "prob", "conf_int", "pred_int", "quantile", or "raw". When NULL, predict() will choose an appropriate value based on the model's mode.

...

Not used for update().

    Details

The model can be created using the fit() function with the following engines:

-• R: "ranger" or "randomForest"
+• R: "ranger" (the default) or "randomForest"
• Spark: "spark"

    @@ -220,7 +220,7 @@

    Note

For models created using the spark engine, there are several differences to consider. First, only the formula interface via fit() is available; using fit_xy() will generate an error. Second, the predictions will always be in a spark table format. The names will be the same as documented but without the dots. Third, there is no equivalent to factor columns in spark tables so class predictions are returned as character columns.

@@ -274,7 +274,7 @@

    See also


    Examples

    @@ -293,12 +293,12 @@

#> Main Arguments:
#>   mtry = 10
#>   min_n = 3
#>
update(model, mtry = 1)
#> Random Forest Model Specification (unknown)
#>
#> Main Arguments:
#>   mtry = 1
#>   min_n = 3
#>
update(model, mtry = 1, fresh = TRUE)
#> Random Forest Model Specification (unknown)
#>
#> Main Arguments:
#>   mtry = 1

diff --git a/docs/dev/reference/reexports.html b/docs/dev/reference/reexports.html
index 699b80a7d..3f8af8092 100644
--- a/docs/dev/reference/reexports.html
+++ b/docs/dev/reference/reexports.html
@@ -139,7 +139,7 @@

    Objects exported from other packages

    These objects are imported from other packages. Follow the links below to see their documentation.

generics

fit, fit_xy, varying_args

    magrittr

    %>%

diff --git a/docs/dev/reference/surv_reg.html b/docs/dev/reference/surv_reg.html
index 476112894..31b3d2a9a 100644
--- a/docs/dev/reference/surv_reg.html
+++ b/docs/dev/reference/surv_reg.html
@@ -159,7 +159,7 @@

    General Interface for Parametric Survival Models

    surv_reg(mode = "regression", dist = NULL)
     
     # S3 method for surv_reg
update(object, dist = NULL, fresh = FALSE, ...)

    Arguments

    @@ -185,7 +185,7 @@

...

Not used for update().

    @@ -195,15 +195,15 @@

Details

to determine the mode of the model. For surv_reg(), the mode will always be "regression".

Since survival models typically involve censoring (and require the use of survival::Surv() objects), the fit() function will require that the survival model be specified via the formula interface.
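A minimal sketch of that formula requirement (assuming the ovarian data from the survival package):

library(parsnip)
library(survival)

# Censored outcomes are supplied as a Surv() object in a formula;
# there is no x/y interface for these models.
surv_reg(dist = "weibull") %>%
  set_engine("survival") %>%
  fit(Surv(futime, fustat) ~ age, data = ovarian)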

    Also, for the flexsurv::flexsurvfit engine, the typical strata function cannot be used. To achieve the same effect, the extra parameter roles can be used (as described above).

    For surv_reg(), the mode will always be "regression".

The model can be created using the fit() function with the following engines:

-• R: "flexsurv", "survreg"
+• R: "flexsurv", "survival" (the default)

    Engine Details

    @@ -217,7 +217,7 @@

-survreg
+survival

     survival::survreg(formula = missing_arg(), data = missing_arg(), 
         weights = missing_arg(), model = TRUE)
    @@ -233,7 +233,7 @@ 


    See also


    Examples

    @@ -249,7 +249,7 @@

#>
#> Main Arguments:
#>   dist = weibull
#>
update(model, dist = "lnorm")
#> Parametric Survival Regression Model Specification (regression)
#>
#> Main Arguments:
#>   dist = lnorm

diff --git a/docs/dev/reference/svm_poly.html b/docs/dev/reference/svm_poly.html
index 4b5ea1519..3d59e9275 100644
--- a/docs/dev/reference/svm_poly.html
+++ b/docs/dev/reference/svm_poly.html
@@ -159,7 +159,7 @@

    General interface for polynomial support vector machines

time that the model is fit. Other options and arguments can be set using set_engine(). If left to their defaults here (NULL), the values are taken from the underlying model functions. If parameters need to be modified, update() can be used in lieu of recreating the object from scratch.

    @@ -168,7 +168,7 @@

    General interface for polynomial support vector machines

  scale_factor = NULL, margin = NULL)

# S3 method for svm_poly
update(object, cost = NULL, degree = NULL,
  scale_factor = NULL, margin = NULL, fresh = FALSE, ...)

    Arguments

    @@ -209,15 +209,15 @@

...

Not used for update().

    Details

The model can be created using the fit() function with the following engines:

-• R: "kernlab"
+• R: "kernlab" (the default)

    Engine Details

    @@ -238,7 +238,7 @@

    See also


    Examples

    @@ -257,12 +257,12 @@

#> Main Arguments:
#>   cost = 10
#>   scale_factor = 0.1
#>
update(model, cost = 1)
#> Polynomial Support Vector Machine Specification (unknown)
#>
#> Main Arguments:
#>   cost = 1
#>   scale_factor = 0.1
#>
update(model, cost = 1, fresh = TRUE)
#> Polynomial Support Vector Machine Specification (unknown)
#>
#> Main Arguments:
#>   cost = 1

diff --git a/docs/dev/reference/svm_rbf.html b/docs/dev/reference/svm_rbf.html
index b1ae58eb4..7b4e682a4 100644
--- a/docs/dev/reference/svm_rbf.html
+++ b/docs/dev/reference/svm_rbf.html
@@ -159,7 +159,7 @@

    General interface for radial basis function support vector machines

time that the model is fit. Other options and arguments can be set using set_engine(). If left to their defaults here (NULL), the values are taken from the underlying model functions. If parameters need to be modified, update() can be used in lieu of recreating the object from scratch.

    @@ -168,7 +168,7 @@

    General interface for radial basis function support vector machines

  margin = NULL)

# S3 method for svm_rbf
update(object, cost = NULL, rbf_sigma = NULL,
  margin = NULL, fresh = FALSE, ...)

    Arguments

    @@ -205,15 +205,15 @@

...

Not used for update().

    Details

The model can be created using the fit() function with the following engines:

-• R: "kernlab"
+• R: "kernlab" (the default)

    Engine Details

    @@ -234,7 +234,7 @@

    See also


    Examples

    @@ -253,12 +253,12 @@

#> Main Arguments:
#>   cost = 10
#>   rbf_sigma = 0.1
#>
update(model, cost = 1)
#> Radial Basis Function Support Vector Machine Specification (unknown)
#>
#> Main Arguments:
#>   cost = 1
#>   rbf_sigma = 0.1
#>
update(model, cost = 1, fresh = TRUE)
#> Radial Basis Function Support Vector Machine Specification (unknown)
#>
#> Main Arguments:
#>   cost = 1

diff --git a/docs/dev/reference/translate.html b/docs/dev/reference/translate.html
index 26329274a..0b00a588c 100644
--- a/docs/dev/reference/translate.html
+++ b/docs/dev/reference/translate.html
@@ -157,7 +157,7 @@

    Details

translate() produces a template call that lacks the specific argument values (such as data, etc). These are filled in once fit() is called with the specifics of the data for the model. The call may also include varying arguments if these are in the specification.
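A quick sketch of such a template (a hypothetical specification):

library(parsnip)

# data and weights stay as missing_arg() placeholders until fit() runs.
linear_reg() %>%
  set_engine("lm") %>%
  translate()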

It does contain the resolved argument names that are specific to the model fitting function/engine.

diff --git a/docs/dev/reference/varying_args.model_spec.html b/docs/dev/reference/varying_args.model_spec.html
index 22a240d23..526546ac9 100644
--- a/docs/dev/reference/varying_args.model_spec.html
+++ b/docs/dev/reference/varying_args.model_spec.html
@@ -132,20 +132,20 @@

    Determine varying arguments

varying_args() takes a model specification or a recipe and returns a tibble of information on all possible varying arguments and whether or not they are actually varying.

# S3 method for model_spec
varying_args(object, full = TRUE, ...)

# S3 method for recipe
varying_args(object, full = TRUE, ...)

# S3 method for step
varying_args(object, full = TRUE, ...)

    Arguments

    @@ -182,50 +182,50 @@

    Details

    Examples

# List all possible varying args for the random forest spec
rand_forest() %>% varying_args()
#> # A tibble: 3 x 4
#>   name  varying id          type
#>   <chr> <lgl>   <chr>       <chr>
#> 1 mtry  FALSE   rand_forest model_spec
#> 2 trees FALSE   rand_forest model_spec
#> 3 min_n FALSE   rand_forest model_spec

# mtry is now recognized as varying
rand_forest(mtry = varying()) %>% varying_args()
#> # A tibble: 3 x 4
#>   name  varying id          type
#>   <chr> <lgl>   <chr>       <chr>
#> 1 mtry  TRUE    rand_forest model_spec
#> 2 trees FALSE   rand_forest model_spec
#> 3 min_n FALSE   rand_forest model_spec

# Even engine specific arguments can vary
rand_forest() %>%
  set_engine("ranger", sample.fraction = varying()) %>%
  varying_args()
#> # A tibble: 4 x 4
#>   name            varying id          type
#>   <chr>           <lgl>   <chr>       <chr>
#> 1 mtry            FALSE   rand_forest model_spec
#> 2 trees           FALSE   rand_forest model_spec
#> 3 min_n           FALSE   rand_forest model_spec
#> 4 sample.fraction TRUE    rand_forest model_spec

# List only the arguments that actually vary
rand_forest() %>%
  set_engine("ranger", sample.fraction = varying()) %>%
  varying_args(full = FALSE)
#> # A tibble: 1 x 4
#>   name            varying id          type
#>   <chr>           <lgl>   <chr>       <chr>
#> 1 sample.fraction TRUE    rand_forest model_spec

rand_forest() %>%
  set_engine(
    "randomForest",
    strata = Class,
    sampsize = varying()
  ) %>%
  varying_args()
#> # A tibble: 5 x 4
#>   name     varying id          type
#>   <chr>    <lgl>   <chr>       <chr>
#> 1 mtry     FALSE   rand_forest model_spec
#> 2 trees    FALSE   rand_forest model_spec
#> 3 min_n    FALSE   rand_forest model_spec
#> 4 strata   FALSE   rand_forest model_spec
#> 5 sampsize TRUE    rand_forest model_spec