diff --git a/.travis.yml b/.travis.yml index f7f839773..8cd134b71 100644 --- a/.travis.yml +++ b/.travis.yml @@ -4,48 +4,27 @@ language: r dist: trusty sudo: true +cran: https://cran.rstudio.com + # until generics is finalized warnings_are_errors: false r: - 3.2 - - oldrel - - release + - 3.3 + - 3.4 + - 3.5 + - 3.6 - devel matrix: allow_failures: - - r: 3.2 - -r_binary_packages: - - RCurl - - dplyr - - glue - - magrittr - - stringi - - stringr - - munsell - - rlang - - reshape2 - - scales - - tibble - - ggplot2 - - Rcpp - - RcppEigen - - BH - - glmnet - - earth - - sparklyr - - flexsurv - - ranger - - randomforest - - xgboost - - C50 - + - r: 3.3 # inum install failure (seg fault) + - r: 3.2 # partykit install failure (libcoin needs >= 3.4.0) + - r: 3.4 # mvtnorm requires >= 3.5.0 cache: - packages: true directories: - $HOME/.keras - $HOME/.cache/pip @@ -69,12 +48,12 @@ before_script: before_install: - - sudo apt-get -y install libnlopt-dev + - sudo apt-get -y install libnlopt-dev - sudo apt-get update - sudo apt-get -y install python3 - mkdir -p ~/.R && echo "CXX14=g++-6" > ~/.R/Makevars - echo "CXX14FLAGS += -fPIC" >> ~/.R/Makevars - + after_success: - Rscript -e 'covr::codecov()' diff --git a/DESCRIPTION b/DESCRIPTION index 8de5c838c..4d305e962 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,5 +1,5 @@ Package: parsnip -Version: 0.0.2 +Version: 0.0.2.9000 Title: A Common API to Modeling and Analysis Functions Description: A common interface is provided to allow users to specify a model without having to remember the different argument names across different functions or computational engines (e.g. 'R', 'Spark', 'Stan', etc). Authors@R: c( @@ -27,7 +27,8 @@ Imports: magrittr, stats, tidyr, - globals + globals, + vctrs Roxygen: list(markdown = TRUE) RoxygenNote: 6.1.1 Suggests: @@ -49,3 +50,4 @@ Suggests: rpart, MASS, nlme +Remotes: r-lib/vctrs diff --git a/NAMESPACE b/NAMESPACE index f24856ce0..a214f738d 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -45,6 +45,7 @@ S3method(translate,decision_tree) S3method(translate,default) S3method(translate,mars) S3method(translate,mlp) +S3method(translate,nearest_neighbor) S3method(translate,rand_forest) S3method(translate,surv_reg) S3method(translate,svm_poly) @@ -85,6 +86,11 @@ export(fit.model_spec) export(fit_control) export(fit_xy) export(fit_xy.model_spec) +export(get_dependency) +export(get_fit) +export(get_from_env) +export(get_model_env) +export(get_pred_type) export(keras_mlp) export(linear_reg) export(logistic_reg) @@ -97,13 +103,24 @@ export(multinom_reg) export(nearest_neighbor) export(null_model) export(nullmodel) +export(pred_value_template) export(predict.model_fit) export(rand_forest) export(rpart_train) export(set_args) +export(set_dependency) export(set_engine) +export(set_env_val) +export(set_fit) +export(set_in_env) export(set_mode) +export(set_model_arg) +export(set_model_engine) +export(set_model_mode) +export(set_new_model) +export(set_pred) export(show_call) +export(show_model_info) export(surv_reg) export(svm_poly) export(svm_rbf) @@ -180,7 +197,6 @@ importFrom(stats,na.omit) importFrom(stats,na.pass) importFrom(stats,predict) importFrom(stats,qnorm) -importFrom(stats,qt) importFrom(stats,quantile) importFrom(stats,setNames) importFrom(stats,terms) @@ -194,4 +210,4 @@ importFrom(utils,capture.output) importFrom(utils,getFromNamespace) importFrom(utils,globalVariables) importFrom(utils,head) -importFrom(utils,stack) +importFrom(vctrs,vec_unique) diff --git a/NEWS.md b/NEWS.md index df598bdee..b135f4295 100644 --- a/NEWS.md +++ 
b/NEWS.md @@ -1,5 +1,11 @@ # parsnip 0.0.2.9000 +## Breaking Changes + + * The method by which `parsnip` stores the model information has changed. Any custom models from previous versions will need to use the new method for registering models. The methods are detailed in `?get_model_env()` and the [package vignette for adding models](https://tidymodels.github.io/parsnip/articles/articles/Scratch.html). + * The mode needs to be declared, prior to fitting and/or translation, for models that can be used in more than one mode. + * For `surv_reg()`, the engine that uses the `survival` package is now called `survival` instead of `survreg`. + ## New Features * `add_rowindex()` can add a column called `.row` to a data frame. diff --git a/R/aaa.R b/R/aaa.R index c8fd5a72e..230c2b08c 100644 --- a/R/aaa.R +++ b/R/aaa.R @@ -18,3 +18,10 @@ convert_stan_interval <- function(x, level = 0.95, lower = TRUE) { res } +# ------------------------------------------------------------------------------ + +#' @importFrom utils globalVariables +utils::globalVariables( + c('.', '.label', '.pred', '.row', 'data', 'engine', 'engine2', 'group', + 'lab', 'original', 'predicted_label', 'prediction', 'value', 'type') + ) diff --git a/R/aaa_models.R b/R/aaa_models.R new file mode 100644 index 000000000..9b8781078 --- /dev/null +++ b/R/aaa_models.R @@ -0,0 +1,765 @@ +# Initialize model environments + +# ------------------------------------------------------------------------------ + +## Rules about model-related information + +### Definitions: + +# - the model is the model type (e.g. "rand_forest", "linear_reg", etc) +# - the model's mode is the species of model such as "classification" or "regression" +# - the engines are within a model and mode and describe the method/implementation +# of the model in question. These are often R package names. + +### The package dependencies are model- and engine-specific. They are used across modes. + +### The `fit` information is a list of data that is needed to fit the model. This +### information is specific to an engine and mode. + +### The `predict` information is also a list of data that is needed to make some sort +### of prediction on the model object. The possible types are contained in `pred_types` +### and this information is specific to the engine, mode, and type (although there +### are no types shared across different modes). + +# ------------------------------------------------------------------------------ + + +parsnip <- rlang::new_environment() +parsnip$models <- NULL +parsnip$modes <- c("regression", "classification", "unknown") + +# ------------------------------------------------------------------------------ + +pred_types <- + c("raw", "numeric", "class", "prob", "conf_int", "pred_int", "quantile") + +# ------------------------------------------------------------------------------ + +#' Working with the parsnip model environment +#' +#' These functions read and write to the environment where the package stores +#' information about model specifications. +#' +#' @param items A character string of objects in the model environment. +#' @param ... Named values that will be assigned to the model environment. +#' @param name A single character value for a new symbol in the model environment. +#' @param value A single value for a new value in the model environment.
+#' @keywords internal +#' @references "Making a parsnip model from scratch" +#' \url{https://tidymodels.github.io/parsnip/articles/articles/Scratch.html} +#' @examples +#' # Access the model data: +#' current_code <- get_model_env() +#' ls(envir = current_code) +#' +#' @keywords internal +#' @export +get_model_env <- function() { + current <- utils::getFromNamespace("parsnip", ns = "parsnip") + current +} + +#' @rdname get_model_env +#' @keywords internal +#' @export +get_from_env <- function(items) { + mod_env <- get_model_env() + rlang::env_get(mod_env, items) +} + +#' @rdname get_model_env +#' @keywords internal +#' @export +set_in_env <- function(...) { + mod_env <- get_model_env() + rlang::env_bind(mod_env, ...) +} + +#' @rdname get_model_env +#' @keywords internal +#' @export +set_env_val <- function(name, value) { + if (length(name) != 1 || !is.character(name)) { + stop("`name` should be a single character value.", call. = FALSE) + } + mod_env <- get_model_env() + x <- list(value) + names(x) <- name + rlang::env_bind(mod_env, !!!x) +} + +# ------------------------------------------------------------------------------ + +check_eng_val <- function(eng) { + if (rlang::is_missing(eng) || length(eng) != 1 || !is.character(eng)) + stop("Please supply a character string for an engine name (e.g. `'lm'`)", + call. = FALSE) + invisible(NULL) +} + + +check_model_exists <- function(model) { + if (rlang::is_missing(model) || length(model) != 1 || !is.character(model)) { + stop("Please supply a character string for a model name (e.g. `'linear_reg'`)", + call. = FALSE) + } + + current <- get_model_env() + + if (!any(current$models == model)) { + stop("Model `", model, "` has not been registered.", call. = FALSE) + } + + invisible(NULL) +} + +check_model_doesnt_exist <- function(model) { + if (rlang::is_missing(model) || length(model) != 1 || !is.character(model)) { + stop("Please supply a character string for a model name (e.g. `'linear_reg'`)", + call. = FALSE) + } + + current <- get_model_env() + + if (any(current$models == model)) { + stop("Model `", model, "` already exists", call. = FALSE) + } + + invisible(NULL) +} + +check_mode_val <- function(mode) { + if (rlang::is_missing(mode) || length(mode) != 1 || !is.character(mode)) + stop("Please supply a character string for a mode (e.g. `'regression'`)", + call. = FALSE) + invisible(NULL) +} + +check_engine_val <- function(eng) { + if (rlang::is_missing(eng) || length(eng) != 1 || !is.character(eng)) + stop("Please supply a character string for an engine (e.g. `'lm'`)", + call. = FALSE) + invisible(NULL) +} + +check_arg_val <- function(arg) { + if (rlang::is_missing(arg) || length(arg) != 1 || !is.character(arg)) + stop("Please supply a character string for the argument", + call. = FALSE) + invisible(NULL) +} + +check_submodels_val <- function(has_submodel) { + if (!is.logical(has_submodel) || length(has_submodel) != 1) { + stop("The `submodels` argument should be a single logical.", call. = FALSE) + } + invisible(NULL) +} + +check_func_val <- function(func) { + msg <- + paste( + "`func` should be a named vector with element 'fun' and the optional ", + "element 'pkg'. These should both be single character strings." + ) + + if (rlang::is_missing(func) || !is.vector(func) || length(func) > 2) + stop(msg, call. = FALSE) + + nms <- sort(names(func)) + + if (all(is.null(nms))) { + stop(msg, call. = FALSE) + } + + if (length(func) == 1) { + if (isTRUE(any(nms != "fun"))) { + stop(msg, call. 
= FALSE) + } + } else { + if (!isTRUE(all.equal(nms, c("fun", "pkg")))) { + stop(msg, call. = FALSE) + } + } + + + if (!all(purrr::map_lgl(func, is.character))) { + stop(msg, call. = FALSE) + } + + invisible(NULL) +} + +check_fit_info <- function(fit_obj) { + if (is.null(fit_obj)) { + stop("The `fit` module cannot be NULL.", call. = FALSE) + } + exp_nms <- c("defaults", "func", "interface", "protect") + if (!isTRUE(all.equal(sort(names(fit_obj)), exp_nms))) { + stop("The `fit` module should have elements: ", + paste0("`", exp_nms, "`", collapse = ", "), + call. = FALSE) + } + + check_interface_val(fit_obj$interface) + check_func_val(fit_obj$func) + + if (!is.list(fit_obj$defaults)) { + stop("The `defaults` element should be a list.", call. = FALSE) + } + + invisible(NULL) +} + +check_pred_info <- function(pred_obj, type) { + if (all(type != pred_types)) { + stop("The prediction type should be one of: ", + paste0("'", pred_types, "'", collapse = ", "), + call. = FALSE) + } + + exp_nms <- c("args", "func", "post", "pre") + if (!isTRUE(all.equal(sort(names(pred_obj)), exp_nms))) { + stop("The `predict` module should have elements: ", + paste0("`", exp_nms, "`", collapse = ", "), + call. = FALSE) + } + + if (!is.null(pred_obj$pre) & !is.function(pred_obj$pre)) { + stop("The `pre` module should be NULL or a function.", + call. = FALSE) + } + if (!is.null(pred_obj$post) & !is.function(pred_obj$post)) { + stop("The `post` module should be NULL or a function.", + call. = FALSE) + } + + check_func_val(pred_obj$func) + + if (!is.list(pred_obj$args)) { + stop("The `args` element should be a list.", call. = FALSE) + } + + invisible(NULL) +} + +check_pkg_val <- function(pkg) { + if (rlang::is_missing(pkg) || length(pkg) != 1 || !is.character(pkg)) + stop("Please supply a single character value for the package name", + call. = FALSE) + invisible(NULL) +} + +check_interface_val <- function(x) { + exp_interf <- c("data.frame", "formula", "matrix") + if (length(x) != 1 || !(x %in% exp_interf)) { + stop("The `interface` element should have a single value of: ", + paste0("`", exp_interf, "`", collapse = ", "), + call. = FALSE) + } + invisible(NULL) +} + + +# ------------------------------------------------------------------------------ + +#' Tools to Register Models +#' +#' These functions are similar to constructors and can be used to validate +#' that there are no conflicts with the underlying model structures used by the +#' package. +#' +#' @param model A single character string for the model type (e.g. +#' `"rand_forest"`, etc). +#' @param mode A single character string for the model mode (e.g. "regression"). +#' @param eng A single character string for the model engine. +#' @param arg A single character string for the model argument name. +#' @param has_submodel A single logical for whether the argument +#' can make predictions on multiple submodels at once. +#' @param func A named character vector that describes how to call +#' a function. `func` should have elements `pkg` and `fun`. The +#' former is optional but is recommended and the latter is +#' required. For example, `c(pkg = "stats", fun = "lm")` would be +#' used to invoke the usual linear regression function. In some +#' cases, it is helpful to use `c(fun = "predict")` when using a +#' package's `predict` method. +#' @param fit_obj A list with elements `interface`, `protect`, +#' `func` and `defaults`. See the package vignette "Making a +#' `parsnip` model from scratch".
+#' @param pred_obj A list with elements `pre`, `post`, `func`, and `args`. +#' See the package vignette "Making a `parsnip` model from scratch". +#' @param type A single character value for the type of prediction. Possible +#' values are: `class`, `conf_int`, `numeric`, `pred_int`, `prob`, `quantile`, +#' and `raw`. +#' @param pkg An optional character string for a package name. +#' @param parsnip A single character string for the "harmonized" argument name +#' that `parsnip` exposes. +#' @param original A single character string for the argument name that the +#' underlying model function uses. +#' @param value A list that conforms to the `fit_obj` or `pred_obj` description +#' above, depending on context. +#' @param pre,post Optional functions for pre- and post-processing of prediction +#' results. +#' @param ... Optional arguments that should be passed into the `args` slot for +#' prediction objects. +#' @keywords internal +#' @details These functions are available for users to add their +#' own models or engines (in package or otherwise) so that they can +#' be accessed using `parsnip`. These are more thoroughly documented +#' on the package web site (see references below). +#' +#' In short, `parsnip` stores an environment object that contains +#' all of the information and code about how models are used (e.g. +#' fitting, predicting, etc). These functions can be used to add +#' models to that environment as well as helper functions that can +#' be used to make sure that the model data is in the right +#' format. +#' +#' `check_model_exists()` checks the model value and ensures that the model has +#' already been registered. `check_model_doesnt_exist()` checks the model value +#' and also checks to see if it is novel in the environment. +#' +#' @references "Making a parsnip model from scratch" +#' \url{https://tidymodels.github.io/parsnip/articles/articles/Scratch.html} +#' @examples +#' # set_new_model("shallow_learning_model") +#' +#' # Show the information about a model: +#' show_model_info("rand_forest") +#' @keywords internal +#' @export +set_new_model <- function(model) { + check_model_doesnt_exist(model) + + current <- get_model_env() + + set_env_val("models", c(current$models, model)) + set_env_val(model, dplyr::tibble(engine = character(0), mode = character(0))) + set_env_val( + paste0(model, "_pkgs"), + dplyr::tibble(engine = character(0), pkg = list()) + ) + set_env_val(paste0(model, "_modes"), "unknown") + set_env_val( + paste0(model, "_args"), + dplyr::tibble( + engine = character(0), + parsnip = character(0), + original = character(0), + func = list(), + has_submodel = logical(0) + ) + ) + set_env_val( + paste0(model, "_fit"), + dplyr::tibble( + engine = character(0), + mode = character(0), + value = list() + ) + ) + set_env_val( + paste0(model, "_predict"), + dplyr::tibble( + engine = character(0), + mode = character(0), + type = character(0), + value = list() + ) + ) + + invisible(NULL) +} + +# ------------------------------------------------------------------------------ + +#' @rdname set_new_model +#' @keywords internal +#' @export +set_model_mode <- function(model, mode) { + check_model_exists(model) + check_mode_val(mode) + + current <- get_model_env() + + if (!any(current$modes == mode)) { + current$modes <- unique(c(current$modes, mode)) + } + + set_env_val( + paste0(model, "_modes"), + unique(c(get_from_env(paste0(model, "_modes")), mode)) + ) + invisible(NULL) +} + +# ------------------------------------------------------------------------------ + +#' @rdname
set_new_model +#' @keywords internal +#' @export +set_model_engine <- function(model, mode, eng) { + check_model_exists(model) + check_mode_val(mode) + check_eng_val(eng) + + current <- get_model_env() + + new_eng <- dplyr::tibble(engine = eng, mode = mode) + old_eng <- get_from_env(model) + + engs <- + old_eng %>% + dplyr::bind_rows(new_eng) %>% + dplyr::distinct() + + set_env_val(model, engs) + set_model_mode(model, mode) + invisible(NULL) +} + + +# ------------------------------------------------------------------------------ +#' @importFrom vctrs vec_unique +#' @rdname set_new_model +#' @keywords internal +#' @export +set_model_arg <- function(model, eng, parsnip, original, func, has_submodel) { + check_model_exists(model) + check_eng_val(eng) + check_arg_val(parsnip) + check_arg_val(original) + check_func_val(func) + check_submodels_val(has_submodel) + + current <- get_model_env() + old_args <- get_from_env(paste0(model, "_args")) + + new_arg <- + dplyr::tibble( + engine = eng, + parsnip = parsnip, + original = original, + func = list(func), + has_submodel = has_submodel + ) + + updated <- try(dplyr::bind_rows(old_args, new_arg), silent = TRUE) + if (inherits(updated, "try-error")) { + stop("An error occurred when adding the new argument.", call. = FALSE) + } + + updated <- vctrs::vec_unique(updated) + set_env_val(paste0(model, "_args"), updated) + + invisible(NULL) +} + + +# ------------------------------------------------------------------------------ + +#' @rdname set_new_model +#' @keywords internal +#' @export +set_dependency <- function(model, eng, pkg) { + check_model_exists(model) + check_eng_val(eng) + check_pkg_val(pkg) + + current <- get_model_env() + model_info <- get_from_env(model) + pkg_info <- get_from_env(paste0(model, "_pkgs")) + + has_engine <- + model_info %>% + dplyr::distinct(engine) %>% + dplyr::filter(engine == eng) %>% + nrow() + if (has_engine != 1) { + stop("The engine '", eng, "' has not been registered for model '", + model, "'. ", call. = FALSE) + } + + existing_pkgs <- + pkg_info %>% + dplyr::filter(engine == eng) + + if (nrow(existing_pkgs) == 0) { + pkg_info <- + pkg_info %>% + dplyr::bind_rows(tibble(engine = eng, pkg = list(pkg))) + + } else { + existing_pkgs$pkg[[1]] <- c(pkg, existing_pkgs$pkg[[1]]) + pkg_info <- + pkg_info %>% + dplyr::filter(engine != eng) %>% + dplyr::bind_rows(existing_pkgs) + } + + set_env_val(paste0(model, "_pkgs"), pkg_info) + + invisible(NULL) +} + +#' @rdname set_new_model +#' @keywords internal +#' @export +get_dependency <- function(model) { + check_model_exists(model) + pkg_name <- paste0(model, "_pkgs") + if (!any(pkg_name == rlang::env_names(get_model_env()))) { + stop("`", model, "` does not have a dependency list in parsnip.", call. = FALSE) + } + rlang::env_get(get_model_env(), pkg_name) +} + + +# ------------------------------------------------------------------------------
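+ +# Illustrative sketch (comments only, not run): once an engine is registered, +# `set_dependency()` records its package requirements and `get_dependency()` +# retrieves them, e.g. for the `linear_reg()`/`lm` pairing that is registered +# in linear_reg_data.R: +# +# set_dependency("linear_reg", eng = "lm", pkg = "stats") +# get_dependency("linear_reg")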
+ +#' @rdname set_new_model +#' @keywords internal +#' @export +set_fit <- function(model, mode, eng, value) { + check_model_exists(model) + check_eng_val(eng) + check_mode_val(mode) + check_engine_val(eng) + check_fit_info(value) + + current <- get_model_env() + model_info <- get_from_env(model) + old_fits <- get_from_env(paste0(model, "_fit")) + + has_engine <- + model_info %>% + dplyr::filter(engine == eng & mode == !!mode) %>% + nrow() + if (has_engine != 1) { + stop("The combination of engine '", eng, "' and mode '", + mode, "' has not been registered for model '", + model, "'. ", call. = FALSE) + } + + has_fit <- + old_fits %>% + dplyr::filter(engine == eng & mode == !!mode) %>% + nrow() + + if (has_fit > 0) { + stop("The combination of engine '", eng, "' and mode '", + mode, "' already has a fit component for model '", + model, "'. ", call. = FALSE) + } + + new_fit <- + dplyr::tibble( + engine = eng, + mode = mode, + value = list(value) + ) + + updated <- try(dplyr::bind_rows(old_fits, new_fit), silent = TRUE) + if (inherits(updated, "try-error")) { + stop("An error occurred when adding the new fit module.", call. = FALSE) + } + + set_env_val( + paste0(model, "_fit"), + updated + ) + + invisible(NULL) +} + +#' @rdname set_new_model +#' @keywords internal +#' @export +get_fit <- function(model) { + check_model_exists(model) + fit_name <- paste0(model, "_fit") + if (!any(fit_name == rlang::env_names(get_model_env()))) { + stop("`", model, "` does not have a `fit` method in parsnip.", call. = FALSE) + } + rlang::env_get(get_model_env(), fit_name) +} + +# ------------------------------------------------------------------------------ + +#' @rdname set_new_model +#' @keywords internal +#' @export +set_pred <- function(model, mode, eng, type, value) { + check_model_exists(model) + check_eng_val(eng) + check_mode_val(mode) + check_engine_val(eng) + check_pred_info(value, type) + + current <- get_model_env() + model_info <- get_from_env(model) + old_fits <- get_from_env(paste0(model, "_predict")) + + has_engine <- + model_info %>% + dplyr::filter(engine == eng & mode == !!mode) %>% + nrow() + if (has_engine != 1) { + stop("The combination of engine '", eng, "' and mode '", + mode, "' has not been registered for model '", + model, "'. ", call. = FALSE) + } + + has_pred <- + old_fits %>% + dplyr::filter(engine == eng & mode == !!mode & type == !!type) %>% + nrow() + if (has_pred > 0) { + stop("The combination of engine '", eng, "', mode '", + mode, "', and type '", type, + "' already has a prediction component for model '", + model, "'. ", call. = FALSE) + } + + new_fit <- + dplyr::tibble( + engine = eng, + mode = mode, + type = type, + value = list(value) + ) + + updated <- try(dplyr::bind_rows(old_fits, new_fit), silent = TRUE) + if (inherits(updated, "try-error")) { + stop("An error occurred when adding the new prediction module.", call. = FALSE) + } + + set_env_val(paste0(model, "_predict"), updated) + + invisible(NULL) +} + +#' @rdname set_new_model +#' @keywords internal +#' @export +get_pred_type <- function(model, type) { + check_model_exists(model) + pred_name <- paste0(model, "_predict") + if (!any(pred_name == rlang::env_names(get_model_env()))) { + stop("`", model, "` does not have any `pred` methods in parsnip.", call. = FALSE) + } + all_preds <- rlang::env_get(get_model_env(), pred_name) + if (!any(all_preds$type == type)) { + stop("`", model, "` does not have any `", type, + "` prediction methods in parsnip.", call.
= FALSE) + } + dplyr::filter(all_preds, type == !!type) +} + +# ------------------------------------------------------------------------------ + +#' @rdname set_new_model +#' @keywords internal +#' @export +show_model_info <- function(model) { + check_model_exists(model) + current <- get_model_env() + + cat("Information for `", model, "`\n", sep = "") + + cat( + " modes:", + paste0(get_from_env(paste0(model, "_modes")), collapse = ", "), + "\n\n" + ) + + engines <- get_from_env(model) + if (nrow(engines) > 0) { + cat(" engines: \n") + engines %>% + dplyr::mutate( + mode = format(paste0(mode, ": ")) + ) %>% + dplyr::group_by(mode) %>% + dplyr::summarize( + engine = paste0(sort(engine), collapse = ", ") + ) %>% + dplyr::mutate( + lab = paste0(" ", mode, engine, "\n") + ) %>% + dplyr::ungroup() %>% + dplyr::pull(lab) %>% + cat(sep = "") + cat("\n") + } else { + cat(" no registered engines.\n\n") + } + + args <- get_from_env(paste0(model, "_args")) + if (nrow(args) > 0) { + cat(" arguments: \n") + args %>% + dplyr::select(engine, parsnip, original) %>% + dplyr::distinct() %>% + dplyr::mutate( + engine = format(paste0(" ", engine, ": ")), + parsnip = paste0(" ", format(parsnip), " --> ", original, "\n") + ) %>% + dplyr::group_by(engine) %>% + dplyr::mutate( + engine2 = ifelse(dplyr::row_number() == 1, engine, ""), + parsnip = ifelse(dplyr::row_number() == 1, paste0("\n", parsnip), parsnip), + lab = paste0(engine2, parsnip) + ) %>% + dplyr::ungroup() %>% + dplyr::pull(lab) %>% + cat(sep = "") + cat("\n") + } else { + cat(" no registered arguments.\n\n") + } + + fits <- get_from_env(paste0(model, "_fit")) + if (nrow(fits) > 0) { + cat(" fit modules:\n") + fits %>% + dplyr::select(-value) %>% + mutate(engine = paste0(" ", engine)) %>% + as.data.frame() %>% + print(row.names = FALSE) + cat("\n") + } else { + cat(" no registered fit modules.\n\n") + } + + preds <- get_from_env(paste0(model, "_predict")) + if (nrow(preds) > 0) { + cat(" prediction modules:\n") + preds %>% + dplyr::group_by(mode, engine) %>% + dplyr::summarize(methods = paste0(sort(type), collapse = ", ")) %>% + dplyr::ungroup() %>% + mutate(mode = paste0(" ", mode)) %>% + as.data.frame() %>% + print(row.names = FALSE) + cat("\n") + } else { + cat(" no registered prediction modules.\n\n") + } + invisible(NULL) +} + +# ------------------------------------------------------------------------------ + +#' @rdname set_new_model +#' @keywords internal +#' @export +pred_value_template <- function(pre = NULL, post = NULL, func, ...) { + if (rlang::is_missing(func)) { + stop("Please supply a value to `func`. See `?set_pred`.", call. 
= FALSE) + } + list(pre = pre, post = post, func = func, args = list(...)) +} + diff --git a/R/aaa_spark_helpers.R b/R/aaa_spark_helpers.R index 4d233a760..3eba67416 100644 --- a/R/aaa_spark_helpers.R +++ b/R/aaa_spark_helpers.R @@ -20,6 +20,3 @@ format_spark_num <- function(results, object) { results <- dplyr::rename(results, pred = prediction) results } - -#' @importFrom utils globalVariables -utils::globalVariables(c(".", "predicted_label", "prediction")) diff --git a/R/arguments.R b/R/arguments.R index 1ece4725e..39246d762 100644 --- a/R/arguments.R +++ b/R/arguments.R @@ -7,49 +7,6 @@ null_value <- function(x) { res } -deharmonize <- function(args, key, engine) { - nms <- names(args) - orig_names <- key[nms, engine] - names(args) <- orig_names - args[!is.na(orig_names)] -} - -parse_engine_options <- function(x) { - res <- ll() - if (length(x) >= 2) { # in case of NULL - - arg_names <- names(x[[2]]) - arg_names <- arg_names[arg_names != ""] - - if (length(arg_names) > 0) { - # in case of list() - res <- ll() - for (i in arg_names) { - res[[i]] <- x[[2]][[i]] - } # over arg_names - } # length == 0 - } - res -} - -prune_arg_list <- function(x, whitelist = NULL, modified = character(0)) { - nms <- names(x) - if (length(whitelist) > 0) - nms <- nms[!(nms %in% whitelist)] - for (i in nms) { - if ( - is.null(x[[i]]) | - is_null(x[[i]]) | - !(i %in% modified) | - is_missing(x[[i]]) - ) - x[[i]] <- NULL - } - if(any(names(x) == "...")) - x["..."] <- NULL - x -} - check_eng_args <- function(args, obj, core_args) { # Make sure that we are not trying to modify an argument that # is explicitly protected in the method metadata or arg_key diff --git a/R/boost_tree.R b/R/boost_tree.R index 620648fa0..61c16c87c 100644 --- a/R/boost_tree.R +++ b/R/boost_tree.R @@ -261,6 +261,7 @@ check_args.boost_tree <- function(object) { #' @param subsample Subsampling proportion of rows. #' @param ... Other options to pass to `xgb.train`. #' @return A fitted `xgboost` object. +#' @keywords internal #' @export xgb_train <- function( x, y, @@ -386,15 +387,15 @@ xgb_by_tree <- function(tree, object, new_data, type, ...) { pred <- xgb_pred(object$fit, newdata = new_data, ntreelimit = tree) # switch based on prediction type - if(object$spec$mode == "regression") { + if (object$spec$mode == "regression") { pred <- tibble(.pred = pred) nms <- names(pred) } else { if (type == "class") { - pred <- boost_tree_xgboost_data$class$post(pred, object) + pred <- object$spec$method$pred$class$post(pred, object) pred <- tibble(.pred = factor(pred, levels = object$lvl)) } else { - pred <- boost_tree_xgboost_data$classprob$post(pred, object) + pred <- object$spec$method$pred$prob$post(pred, object) pred <- as_tibble(pred) names(pred) <- paste0(".pred_", names(pred)) } @@ -432,6 +433,7 @@ xgb_by_tree <- function(tree, object, new_data, type, ...) { #' model in the printed output. #' @param ... Other arguments to pass. #' @return A fitted C5.0 model. +#' @keywords internal #' @export C5.0_train <- function(x, y, weights = NULL, trials = 15, minCases = 2, sample = 0, ...) { @@ -501,8 +503,3 @@ C50_by_tree <- function(tree, object, new_data, type, ...) 
{ pred[, c(".row", "trees", nms)] } - -# ------------------------------------------------------------------------------ - -#' @importFrom utils globalVariables -utils::globalVariables(c(".row")) diff --git a/R/boost_tree_data.R b/R/boost_tree_data.R index 559698511..98356f96b 100644 --- a/R/boost_tree_data.R +++ b/R/boost_tree_data.R @@ -1,175 +1,386 @@ +set_new_model("boost_tree") -boost_tree_arg_key <- data.frame( - xgboost = c("max_depth", "nrounds", "eta", "colsample_bytree", "min_child_weight", "gamma", "subsample"), - C5.0 = c( NA, "trials", NA, NA, "minCases", NA, "sample"), - spark = c("max_depth", "max_iter", "step_size", "feature_subset_strategy", "min_instances_per_node", "min_info_gain", "subsampling_rate"), - stringsAsFactors = FALSE, - row.names = c("tree_depth", "trees", "learn_rate", "mtry", "min_n", "loss_reduction", "sample_size") +set_model_mode("boost_tree", "classification") +set_model_mode("boost_tree", "regression") + +# ------------------------------------------------------------------------------ + +set_model_engine("boost_tree", "classification", "xgboost") +set_model_engine("boost_tree", "regression", "xgboost") +set_dependency("boost_tree", "xgboost", "xgboost") + +set_model_arg( + model = "boost_tree", + eng = "xgboost", + parsnip = "tree_depth", + original = "max_depth", + func = list(pkg = "dials", fun = "tree_depth"), + has_submodel = FALSE +) +set_model_arg( + model = "boost_tree", + eng = "xgboost", + parsnip = "trees", + original = "nrounds", + func = list(pkg = "dials", fun = "trees"), + has_submodel = TRUE +) +set_model_arg( + model = "boost_tree", + eng = "xgboost", + parsnip = "learn_rate", + original = "eta", + func = list(pkg = "dials", fun = "learn_rate"), + has_submodel = FALSE +) +set_model_arg( + model = "boost_tree", + eng = "xgboost", + parsnip = "mtry", + original = "colsample_bytree", + func = list(pkg = "dials", fun = "mtry"), + has_submodel = FALSE +) +set_model_arg( + model = "boost_tree", + eng = "xgboost", + parsnip = "min_n", + original = "min_child_weight", + func = list(pkg = "dials", fun = "min_n"), + has_submodel = FALSE +) +set_model_arg( + model = "boost_tree", + eng = "xgboost", + parsnip = "loss_reduction", + original = "gamma", + func = list(pkg = "dials", fun = "loss_reduction"), + has_submodel = FALSE +) +set_model_arg( + model = "boost_tree", + eng = "xgboost", + parsnip = "sample_size", + original = "subsample", + func = list(pkg = "dials", fun = "sample_size"), + has_submodel = FALSE +) + +set_fit( + model = "boost_tree", + eng = "xgboost", + mode = "regression", + value = list( + interface = "matrix", + protect = c("x", "y"), + func = c(pkg = "parsnip", fun = "xgb_train"), + defaults = list(nthread = 1, verbose = 0) + ) +) + +set_pred( + model = "boost_tree", + eng = "xgboost", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "xgb_pred"), + args = list(object = quote(object$fit), newdata = quote(new_data)) + ) +) + +set_pred( + model = "boost_tree", + eng = "xgboost", + mode = "regression", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "xgb_pred"), + args = list(object = quote(object$fit), newdata = quote(new_data)) + ) ) -boost_tree_modes <- c("classification", "regression", "unknown") +set_fit( + model = "boost_tree", + eng = "xgboost", + mode = "classification", + value = list( + interface = "matrix", + protect = c("x", "y"), + func = c(pkg = "parsnip", fun = "xgb_train"), + defaults = list(nthread = 1, verbose = 0) + ) +) + 
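+# Illustrative sketch (comments only, not run): with the xgboost modules +# registered above, the engine is driven through the usual parsnip verbs. +# The formula and data set below are placeholders: +# +# boost_tree(trees = 20) %>% +# set_mode("classification") %>% +# set_engine("xgboost") %>% +# fit(Species ~ ., data = iris) +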
+set_pred( + model = "boost_tree", + eng = "xgboost", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = function(x, object) { + if (is.vector(x)) { + x <- ifelse(x >= 0.5, object$lvl[2], object$lvl[1]) + } else { + x <- object$lvl[apply(x, 1, which.max)] + } + x + }, + func = c(pkg = NULL, fun = "xgb_pred"), + args = list(object = quote(object$fit), newdata = quote(new_data)) + ) +) + +set_pred( + model = "boost_tree", + eng = "xgboost", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = function(x, object) { + if (is.vector(x)) { + x <- tibble(v1 = 1 - x, v2 = x) + } else { + x <- as_tibble(x) + } + colnames(x) <- object$lvl + x + }, + func = c(pkg = NULL, fun = "xgb_pred"), + args = list(object = quote(object$fit), newdata = quote(new_data)) + ) +) -boost_tree_engines <- data.frame( - xgboost = rep(TRUE, 3), - C5.0 = c( TRUE, FALSE, TRUE), - spark = rep(TRUE, 3), - row.names = c("classification", "regression", "unknown") +set_pred( + model = "boost_tree", + eng = "xgboost", + mode = "classification", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "xgb_pred"), + args = list(object = quote(object$fit), newdata = quote(new_data)) + ) ) # ------------------------------------------------------------------------------ -boost_tree_xgboost_data <- - list( - libs = "xgboost", - fit = list( - interface = "matrix", - protect = c("x", "y"), - func = c(pkg = "parsnip", fun = "xgb_train"), - defaults = - list( - nthread = 1, - verbose = 0 - ) - ), - numeric = list( - pre = NULL, - post = NULL, - func = c(fun = "xgb_pred"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ), - class = list( - pre = NULL, - post = function(x, object) { - if (is.vector(x)) { - x <- ifelse(x >= 0.5, object$lvl[2], object$lvl[1]) - } else { - x <- object$lvl[apply(x, 1, which.max)] - } - x - }, - func = c(pkg = NULL, fun = "xgb_pred"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ), - classprob = list( - pre = NULL, - post = function(x, object) { - if (is.vector(x)) { - x <- tibble(v1 = 1 - x, v2 = x) - } else { - x <- as_tibble(x) - } - colnames(x) <- object$lvl - x - }, - func = c(pkg = NULL, fun = "xgb_pred"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ), - raw = list( - pre = NULL, - func = c(fun = "xgb_pred"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ) - ) - - -boost_tree_C5.0_data <- - list( - libs = "C50", - fit = list( - interface = "data.frame", - protect = c("x", "y", "weights"), - func = c(pkg = "parsnip", fun = "C5.0_train"), - defaults = list() - ), - class = list( - pre = NULL, - post = NULL, - func = c(fun = "predict"), - args = list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ), - classprob = list( - pre = NULL, - post = function(x, object) { - as_tibble(x) - }, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "prob" - ) - ), - raw = list( - pre = NULL, - func = c(fun = "predict"), - args = list( +set_model_engine("boost_tree", "classification", "C5.0") +set_dependency("boost_tree", "C5.0", "C50") + +set_model_arg( + model = "boost_tree", + eng = "C5.0", + parsnip = "trees", + original = "trials", + func = list(pkg = "dials", fun = "trees"), + has_submodel = TRUE +) +set_model_arg( + model = "boost_tree", + eng = "C5.0", + parsnip = "min_n", + original = "minCases", + func = list(pkg 
= "dials", fun = "min_n"), + has_submodel = FALSE +) +set_model_arg( + model = "boost_tree", + eng = "C5.0", + parsnip = "sample_size", + original = "sample", + func = list(pkg = "dials", fun = "sample_size"), + has_submodel = FALSE +) + +set_fit( + model = "boost_tree", + eng = "C5.0", + mode = "classification", + value = list( + interface = "data.frame", + protect = c("x", "y", "weights"), + func = c(pkg = "parsnip", fun = "C5.0_train"), + defaults = list() + ) +) + +set_pred( + model = "boost_tree", + eng = "C5.0", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = list(object = quote(object$fit), newdata = quote(new_data)) + ) +) + +set_pred( + model = "boost_tree", + eng = "C5.0", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = function(x, object) { + as_tibble(x) + }, + func = c(fun = "predict"), + args = + list( object = quote(object$fit), - newdata = quote(new_data) + newdata = quote(new_data), + type = "prob" ) - ) - ) - - -boost_tree_spark_data <- - list( - libs = "sparklyr", - fit = list( - interface = "formula", - protect = c("x", "formula", "type"), - func = c(pkg = "sparklyr", fun = "ml_gradient_boosted_trees"), - defaults = - list( - seed = expr(sample.int(10^5, 1)) - ) - ), - numeric = list( - pre = NULL, - post = format_spark_num, - func = c(pkg = "sparklyr", fun = "ml_predict"), - args = - list( - x = quote(object$fit), - dataset = quote(new_data) - ) - ), - class = list( - pre = NULL, - post = format_spark_class, - func = c(pkg = "sparklyr", fun = "ml_predict"), - args = - list( - x = quote(object$fit), - dataset = quote(new_data) - ) - ), - classprob = list( - pre = NULL, - post = format_spark_probs, - func = c(pkg = "sparklyr", fun = "ml_predict"), - args = - list( - x = quote(object$fit), - dataset = quote(new_data) - ) - ) ) +) + +set_pred( + model = "boost_tree", + eng = "C5.0", + mode = "classification", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = list(object = quote(object$fit), + newdata = quote(new_data)) + ) +) + +# ------------------------------------------------------------------------------ + +set_model_engine("boost_tree", "classification", "spark") +set_model_engine("boost_tree", "regression", "spark") +set_dependency("boost_tree", "spark", "sparklyr") + +set_model_arg( + model = "boost_tree", + eng = "spark", + parsnip = "tree_depth", + original = "max_depth", + func = list(pkg = "dials", fun = "tree_depth"), + has_submodel = FALSE +) +set_model_arg( + model = "boost_tree", + eng = "spark", + parsnip = "trees", + original = "max_iter", + func = list(pkg = "dials", fun = "trees"), + has_submodel = TRUE +) +set_model_arg( + model = "boost_tree", + eng = "spark", + parsnip = "learn_rate", + original = "step_size", + func = list(pkg = "dials", fun = "learn_rate"), + has_submodel = FALSE +) +set_model_arg( + model = "boost_tree", + eng = "spark", + parsnip = "mtry", + original = "feature_subset_strategy", + func = list(pkg = "dials", fun = "mtry"), + has_submodel = FALSE +) +set_model_arg( + model = "boost_tree", + eng = "spark", + parsnip = "min_n", + original = "min_instances_per_node", + func = list(pkg = "dials", fun = "min_n"), + has_submodel = FALSE +) +set_model_arg( + model = "boost_tree", + eng = "spark", + parsnip = "min_info_gain", + original = "gamma", + func = list(pkg = "dials", fun = "loss_reduction"), + has_submodel = FALSE +) +set_model_arg( + model = "boost_tree", + eng = 
"spark", + parsnip = "sample_size", + original = "subsampling_rate", + func = list(pkg = "dials", fun = "sample_size"), + has_submodel = FALSE +) + +set_fit( + model = "boost_tree", + eng = "spark", + mode = "regression", + value = list( + interface = "formula", + protect = c("x", "formula", "type"), + func = c(pkg = "sparklyr", fun = "ml_gradient_boosted_trees"), + defaults = list(seed = expr(sample.int(10 ^ 5, 1))) + ) +) + +set_fit( + model = "boost_tree", + eng = "spark", + mode = "classification", + value = list( + interface = "formula", + protect = c("x", "formula", "type"), + func = c(pkg = "sparklyr", fun = "ml_gradient_boosted_trees"), + defaults = list(seed = expr(sample.int(10 ^ 5, 1))) + ) +) + +set_pred( + model = "boost_tree", + eng = "spark", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = format_spark_num, + func = c(pkg = "sparklyr", fun = "ml_predict"), + args = list(x = quote(object$fit), dataset = quote(new_data)) + ) +) + +set_pred( + model = "boost_tree", + eng = "spark", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = format_spark_class, + func = c(pkg = "sparklyr", fun = "ml_predict"), + args = list(x = quote(object$fit), dataset = quote(new_data)) + ) +) + +set_pred( + model = "boost_tree", + eng = "spark", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = format_spark_probs, + func = c(pkg = "sparklyr", fun = "ml_predict"), + args = list(x = quote(object$fit), dataset = quote(new_data)) + ) +) diff --git a/R/decision_tree.R b/R/decision_tree.R index b767b965d..afa807405 100644 --- a/R/decision_tree.R +++ b/R/decision_tree.R @@ -252,6 +252,7 @@ check_args.decision_tree <- function(object) { #' those cases. #' @param ... Other arguments to pass to either `rpart` or `rpart.control`. #' @return A fitted rpart model. +#' @keywords internal #' @export rpart_train <- function(formula, data, weights = NULL, cp = 0.01, minsplit = 20, maxdepth = 30, ...) 
{ diff --git a/R/decision_tree_data.R b/R/decision_tree_data.R index b4a599063..a8e8016e4 100644 --- a/R/decision_tree_data.R +++ b/R/decision_tree_data.R @@ -1,160 +1,297 @@ +set_new_model("decision_tree") -decision_tree_arg_key <- data.frame( - rpart = c( "maxdepth", "minsplit", "cp"), - C5.0 = c( NA, "minCases", NA), - spark = c("max_depth", "min_instances_per_node", NA), - stringsAsFactors = FALSE, - row.names = c("tree_depth", "min_n", "cost_complexity") +set_model_mode("decision_tree", "classification") +set_model_mode("decision_tree", "regression") + +# ------------------------------------------------------------------------------ + +set_model_engine("decision_tree", "classification", "rpart") +set_model_engine("decision_tree", "regression", "rpart") +set_dependency("decision_tree", "rpart", "rpart") + +set_model_arg( + model = "decision_tree", + eng = "rpart", + parsnip = "tree_depth", + original = "maxdepth", + func = list(pkg = "dials", fun = "tree_depth"), + has_submodel = FALSE ) -decision_tree_modes <- c("classification", "regression", "unknown") +set_model_arg( + model = "decision_tree", + eng = "rpart", + parsnip = "min_n", + original = "minsplit", + func = list(pkg = "dials", fun = "min_n"), + has_submodel = FALSE +) -decision_tree_engines <- data.frame( - rpart = rep(TRUE, 3), - C5.0 = c(TRUE, FALSE, TRUE), - spark = rep(TRUE, 3), - row.names = c("classification", "regression", "unknown") +set_model_arg( + model = "decision_tree", + eng = "rpart", + parsnip = "cost_complexity", + original = "cp", + func = list(pkg = "dials", fun = "cost_complexity"), + has_submodel = FALSE ) -# ------------------------------------------------------------------------------ +set_fit( + model = "decision_tree", + eng = "rpart", + mode = "regression", + value = list( + interface = "formula", + protect = c("formula", "data", "weights"), + func = c(pkg = "rpart", fun = "rpart"), + defaults = list() + ) +) + +set_fit( + model = "decision_tree", + eng = "rpart", + mode = "classification", + value = list( + interface = "formula", + protect = c("formula", "data", "weights"), + func = c(pkg = "rpart", fun = "rpart"), + defaults = list() + ) +) + +set_pred( + model = "decision_tree", + eng = "rpart", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = list(object = quote(object$fit), newdata = quote(new_data)) + ) +) + +set_pred( + model = "decision_tree", + eng = "rpart", + mode = "regression", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = list(object = quote(object$fit), newdata = quote(new_data)) + ) +) -decision_tree_rpart_data <- - list( - libs = "rpart", - fit = list( - interface = "formula", - protect = c("formula", "data", "weights"), - func = c(pkg = "rpart", fun = "rpart"), - defaults = - list() - ), - numeric = list( - pre = NULL, - post = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ), - class = list( - pre = NULL, - post = NULL, - func = c(pkg = NULL, fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "class" - ) - ), - classprob = list( - pre = NULL, - post = function(x, object) { - as_tibble(x) - }, - func = c(pkg = NULL, fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ), - raw = list( - pre = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = 
quote(new_data) - ) - ) - ) - - -decision_tree_C5.0_data <- - list( - libs = "C50", - fit = list( - interface = "data.frame", - protect = c("x", "y", "weights"), - func = c(pkg = "parsnip", fun = "C5.0_train"), - defaults = list(trials = 1) - ), - class = list( - pre = NULL, - post = NULL, - func = c(fun = "predict"), - args = list( +set_pred( + model = "decision_tree", + eng = "rpart", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = NULL, + func = c(pkg = NULL, fun = "predict"), + args = + list( object = quote(object$fit), - newdata = quote(new_data) + newdata = quote(new_data), + type = "class" ) - ), - classprob = list( - pre = NULL, - post = function(x, object) { - as_tibble(x) - }, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "prob" - ) - ), - raw = list( - pre = NULL, - func = c(fun = "predict"), - args = list( + ) +) + +set_pred( + model = "decision_tree", + eng = "rpart", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = function(x, object) { + as_tibble(x) + }, + func = c(pkg = NULL, fun = "predict"), + args = list(object = quote(object$fit), newdata = quote(new_data)) + ) +) + +set_pred( + model = "decision_tree", + eng = "rpart", + mode = "classification", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = list(object = quote(object$fit), newdata = quote(new_data)) + ) +) + +# ------------------------------------------------------------------------------ + +set_model_engine("decision_tree", "classification", "C5.0") +set_dependency("decision_tree", "C5.0", "C50") + +set_model_arg( + model = "decision_tree", + eng = "C5.0", + parsnip = "min_n", + original = "minCases", + func = list(pkg = "dials", fun = "min_n"), + has_submodel = FALSE +) + +set_fit( + model = "decision_tree", + eng = "C5.0", + mode = "classification", + value = list( + interface = "data.frame", + protect = c("x", "y", "weights"), + func = c(pkg = "parsnip", fun = "C5.0_train"), + defaults = list(trials = 1) + ) +) + +set_pred( + model = "decision_tree", + eng = "C5.0", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = list(object = quote(object$fit), newdata = quote(new_data)) + ) +) + + +set_pred( + model = "decision_tree", + eng = "C5.0", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = function(x, object) { + as_tibble(x) + }, + func = c(fun = "predict"), + args = + list( object = quote(object$fit), - newdata = quote(new_data) + newdata = quote(new_data), + type = "prob" ) - ) - ) - - -decision_tree_spark_data <- - list( - libs = "sparklyr", - fit = list( - interface = "formula", - protect = c("x", "formula"), - func = c(pkg = "sparklyr", fun = "ml_decision_tree_classifier"), - defaults = - list( - seed = expr(sample.int(10^5, 1)) - ) - ), - numeric = list( - pre = NULL, - post = format_spark_num, - func = c(pkg = "sparklyr", fun = "ml_predict"), - args = - list( - x = quote(object$fit), - dataset = quote(new_data) - ) - ), - class = list( - pre = NULL, - post = format_spark_class, - func = c(pkg = "sparklyr", fun = "ml_predict"), - args = - list( - x = quote(object$fit), - dataset = quote(new_data) - ) - ), - classprob = list( - pre = NULL, - post = format_spark_probs, - func = c(pkg = "sparklyr", fun = "ml_predict"), - args = - list( - x = quote(object$fit), - dataset = quote(new_data) - ) - ) ) +) + + +set_pred(
+ model = "decision_tree", + eng = "C5.0", + mode = "classification", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = list(object = quote(object$fit), + newdata = quote(new_data)) + ) +) + +# ------------------------------------------------------------------------------ + +set_model_engine("decision_tree", "classification", "spark") +set_model_engine("decision_tree", "regression", "spark") +set_dependency("decision_tree", "spark", "sparklyr") + +set_model_arg( + model = "decision_tree", + eng = "spark", + parsnip = "tree_depth", + original = "max_depth", + func = list(pkg = "dials", fun = "tree_depth"), + has_submodel = FALSE +) + +set_model_arg( + model = "decision_tree", + eng = "spark", + parsnip = "min_n", + original = "min_instances_per_node", + func = list(pkg = "dials", fun = "min_n"), + has_submodel = FALSE +) + +set_fit( + model = "decision_tree", + eng = "spark", + mode = "regression", + value = list( + interface = "formula", + protect = c("x", "formula"), + func = c(pkg = "sparklyr", fun = "ml_decision_tree_classifier"), + defaults = + list(seed = expr(sample.int(10 ^ 5, 1))) + ) +) + +set_fit( + model = "decision_tree", + eng = "spark", + mode = "classification", + value = list( + interface = "formula", + protect = c("x", "formula"), + func = c(pkg = "sparklyr", fun = "ml_decision_tree_classifier"), + defaults = + list(seed = expr(sample.int(10 ^ 5, 1))) + ) +) + +set_pred( + model = "decision_tree", + eng = "spark", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = format_spark_num, + func = c(pkg = "sparklyr", fun = "ml_predict"), + args = list(x = quote(object$fit), dataset = quote(new_data)) + ) +) + +set_pred( + model = "decision_tree", + eng = "spark", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = format_spark_class, + func = c(pkg = "sparklyr", fun = "ml_predict"), + args = list(x = quote(object$fit), dataset = quote(new_data)) + ) +) + +set_pred( + model = "decision_tree", + eng = "spark", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = format_spark_probs, + func = c(pkg = "sparklyr", fun = "ml_predict"), + args = list(x = quote(object$fit), dataset = quote(new_data)) + ) +) diff --git a/R/engines.R b/R/engines.R index a1b32de5f..327f69dc6 100644 --- a/R/engines.R +++ b/R/engines.R @@ -1,23 +1,13 @@ -get_model_info <- function (x, engine) { - cls <- specific_model(x) - nm <- paste(cls, engine, "data", sep = "_") - res <- try(get(nm), silent = TRUE) - if (inherits(res, "try-error")) - stop("Can't find model object ", nm) - res -} - specific_model <- function(x) { cls <- class(x) cls[cls != "model_spec"] } - possible_engines <- function(object, ...)
{ - cls <- specific_model(object) - key_df <- get(paste(cls, "engines", sep = "_")) - colnames(key_df[object$mode, , drop = FALSE]) + m_env <- get_model_env() + engs <- rlang::env_get(m_env, specific_model(object)) + unique(engs$engine) } check_engine <- function(object) { diff --git a/R/fit.R b/R/fit.R index e5b745b8f..1cb32f300 100644 --- a/R/fit.R +++ b/R/fit.R @@ -52,8 +52,6 @@ #' #' lr_mod <- logistic_reg() #' -#' lr_mod <- logistic_reg() -#' #' using_formula <- #' lr_mod %>% #' set_engine("glm") %>% @@ -125,7 +123,7 @@ fit.model_spec <- ) # populate `method` with the details for this model type - object <- get_method(object, engine = object$engine) + object <- add_methods(object, engine = object$engine) check_installs(object) @@ -215,7 +213,7 @@ fit_xy.model_spec <- ) # populate `method` with the details for this model type - object <- get_method(object, engine = object$engine) + object <- add_methods(object, engine = object$engine) check_installs(object) @@ -239,7 +237,7 @@ fit_xy.model_spec <- ... ), - data.frame_data.frame =, matrix_data.frame = + data.frame_data.frame = , matrix_data.frame = xy_xy( object = object, env = eval_env, @@ -249,7 +247,7 @@ fit_xy.model_spec <- ), # heterogenous combinations - matrix_formula =, data.frame_formula = + matrix_formula = , data.frame_formula = xy_form( object = object, env = eval_env, @@ -367,7 +365,7 @@ check_xy_interface <- function(x, y, cl, model) { print.model_fit <- function(x, ...) { cat("parsnip model object\n\n") - if(inherits(x$fit, "try-error")) { + if (inherits(x$fit, "try-error")) { cat("Model fit failed with error:\n", x$fit, "\n") } else { print(x$fit, ...) diff --git a/R/fit_helpers.R b/R/fit_helpers.R index 74f614ede..940daebe7 100644 --- a/R/fit_helpers.R +++ b/R/fit_helpers.R @@ -181,9 +181,3 @@ xy_form <- function(object, env, control, ...) 
{ res } -# ------------------------------------------------------------------------------ -## - - -#' @importFrom utils globalVariables -utils::globalVariables("data") diff --git a/R/linear_reg_data.R b/R/linear_reg_data.R index 6225d1023..54dbd3338 100644 --- a/R/linear_reg_data.R +++ b/R/linear_reg_data.R @@ -1,265 +1,369 @@ +set_new_model("linear_reg") -linear_reg_arg_key <- data.frame( - lm = c( NA, NA), - glmnet = c( "lambda", "alpha"), - spark = c("reg_param", "elastic_net_param"), - stan = c( NA, NA), - keras = c( "decay", NA), - stringsAsFactors = FALSE, - row.names = c("penalty", "mixture") +set_model_mode("linear_reg", "regression") + +# ------------------------------------------------------------------------------ + +set_model_engine("linear_reg", "regression", "lm") +set_dependency("linear_reg", "lm", "stats") + +set_fit( + model = "linear_reg", + eng = "lm", + mode = "regression", + value = list( + interface = "formula", + protect = c("formula", "data", "weights"), + func = c(pkg = "stats", fun = "lm"), + defaults = list() + ) ) -linear_reg_modes <- "regression" +set_pred( + model = "linear_reg", + eng = "lm", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = expr(object$fit), + newdata = expr(new_data), + type = "response" + ) + ) +) -linear_reg_engines <- data.frame( - lm = TRUE, - glmnet = TRUE, - spark = TRUE, - stan = TRUE, - keras = TRUE, - row.names = c("regression") +set_pred( + model = "linear_reg", + eng = "lm", + mode = "regression", + type = "conf_int", + value = list( + pre = NULL, + post = function(results, object) { + tibble::as_tibble(results) %>% + dplyr::select(-fit) %>% + setNames(c(".pred_lower", ".pred_upper")) + }, + func = c(fun = "predict"), + args = + list( + object = expr(object$fit), + newdata = expr(new_data), + interval = "confidence", + level = expr(level), + type = "response" + ) + ) +) +set_pred( + model = "linear_reg", + eng = "lm", + mode = "regression", + type = "pred_int", + value = list( + pre = NULL, + post = function(results, object) { + tibble::as_tibble(results) %>% + dplyr::select(-fit) %>% + setNames(c(".pred_lower", ".pred_upper")) + }, + func = c(fun = "predict"), + args = + list( + object = expr(object$fit), + newdata = expr(new_data), + interval = "prediction", + level = expr(level), + type = "response" + ) + ) ) +set_pred( + model = "linear_reg", + eng = "lm", + mode = "regression", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = list(object = expr(object$fit), newdata = expr(new_data)) + ) +) # ------------------------------------------------------------------------------ -linear_reg_lm_data <- - list( - libs = "stats", - fit = list( - interface = "formula", - protect = c("formula", "data", "weights"), - func = c(pkg = "stats", fun = "lm"), - defaults = list() - ), - numeric = list( - pre = NULL, - post = NULL, - func = c(fun = "predict"), - args = - list( - object = expr(object$fit), - newdata = expr(new_data), - type = "response" - ) - ), - confint = list( - pre = NULL, - post = function(results, object) { - tibble::as_tibble(results) %>% - dplyr::select(-fit) %>% - setNames(c(".pred_lower", ".pred_upper")) - }, - func = c(fun = "predict"), - args = - list( - object = expr(object$fit), - newdata = expr(new_data), - interval = "confidence", - level = expr(level), - type = "response" - ) - ), - predint = list( - pre = NULL, - post = function(results, object) { - 
tibble::as_tibble(results) %>% - dplyr::select(-fit) %>% - setNames(c(".pred_lower", ".pred_upper")) - }, - func = c(fun = "predict"), - args = - list( - object = expr(object$fit), - newdata = expr(new_data), - interval = "prediction", - level = expr(level), - type = "response" - ) - ), - raw = list( - pre = NULL, - func = c(fun = "predict"), - args = - list( - object = expr(object$fit), - newdata = expr(new_data) - ) - ) +set_model_engine("linear_reg", "regression", "glmnet") +set_dependency("linear_reg", "glmnet", "glmnet") + +set_model_arg( + model = "linear_reg", + eng = "glmnet", + parsnip = "penalty", + original = "lambda", + func = list(pkg = "dials", fun = "penalty"), + has_submodel = TRUE +) + +set_model_arg( + model = "linear_reg", + eng = "glmnet", + parsnip = "mixture", + original = "alpha", + func = list(pkg = "dials", fun = "mixture"), + has_submodel = FALSE +) + +set_fit( + model = "linear_reg", + eng = "glmnet", + mode = "regression", + value = list( + interface = "matrix", + protect = c("x", "y", "weights"), + func = c(pkg = "glmnet", fun = "glmnet"), + defaults = list(family = "gaussian") ) +) -# Note: For glmnet, you will need to make model-specific predict methods. -# See linear_reg.R -linear_reg_glmnet_data <- - list( - libs = "glmnet", - fit = list( - interface = "matrix", - protect = c("x", "y", "weights"), - func = c(pkg = "glmnet", fun = "glmnet"), - defaults = - list( - family = "gaussian" - ) - ), - numeric = list( - pre = NULL, - post = organize_glmnet_pred, - func = c(fun = "predict"), - args = - list( - object = expr(object$fit), - newx = expr(as.matrix(new_data)), - type = "response", - s = expr(object$spec$args$penalty) - ) - ), - raw = list( - pre = NULL, - func = c(fun = "predict"), - args = - list( - object = expr(object$fit), - newx = expr(as.matrix(new_data)) - ) - ) +set_pred( + model = "linear_reg", + eng = "glmnet", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = organize_glmnet_pred, + func = c(fun = "predict"), + args = + list( + object = expr(object$fit), + newx = expr(as.matrix(new_data)), + type = "response", + s = expr(object$spec$args$penalty) + ) ) +) -linear_reg_stan_data <- - list( - libs = "rstanarm", - fit = list( - interface = "formula", - protect = c("formula", "data", "weights"), - func = c(pkg = "rstanarm", fun = "stan_glm"), - defaults = - list( - family = expr(stats::gaussian) - ) - ), - numeric = list( - pre = NULL, - post = NULL, - func = c(fun = "predict"), - args = - list( - object = expr(object$fit), - newdata = expr(new_data) - ) - ), - confint = list( - pre = NULL, - post = function(results, object) { - res <- - tibble( - .pred_lower = - convert_stan_interval( - results, - level = object$spec$method$confint$extras$level - ), - .pred_upper = - convert_stan_interval( - results, - level = object$spec$method$confint$extras$level, - lower = FALSE - ), - ) - if(object$spec$method$confint$extras$std_error) - res$.std_error <- apply(results, 2, sd, na.rm = TRUE) - res - }, - func = c(pkg = "rstanarm", fun = "posterior_linpred"), - args = - list( - object = expr(object$fit), - newdata = expr(new_data), - transform = TRUE, - seed = expr(sample.int(10^5, 1)) - ) - ), - predint = list( - pre = NULL, - post = function(results, object) { - res <- - tibble( - .pred_lower = - convert_stan_interval( - results, - level = object$spec$method$predint$extras$level - ), - .pred_upper = - convert_stan_interval( - results, - level = object$spec$method$predint$extras$level, - lower = FALSE - ), - ) - 
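
The `has_submodel = TRUE` flag on `penalty` records a glmnet-specific property: one fitted object can produce predictions at any number of penalty values without refitting. A standalone illustration, using mtcars purely as a stand-in dataset:

library(glmnet)

glmnet_fit <- glmnet(x = as.matrix(mtcars[, -1]), y = mtcars$mpg, family = "gaussian")
# One fit, three penalty values: predict() returns one column per value of `s`.
dim(predict(glmnet_fit, newx = as.matrix(mtcars[1:3, -1]), s = c(0.01, 0.1, 1)))
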
if(object$spec$method$predint$extras$std_error) - res$.std_error <- apply(results, 2, sd, na.rm = TRUE) - res - }, - func = c(pkg = "rstanarm", fun = "posterior_predict"), - args = - list( - object = expr(object$fit), - newdata = expr(new_data), - seed = expr(sample.int(10^5, 1)) - ) - ), - raw = list( - pre = NULL, - func = c(fun = "predict"), - args = - list( - object = expr(object$fit), - newdata = expr(new_data) - ) - ) +set_pred( + model = "linear_reg", + eng = "glmnet", + mode = "regression", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list(object = expr(object$fit), + newx = expr(as.matrix(new_data))) + ) +) + +# ------------------------------------------------------------------------------ + +set_model_engine("linear_reg", "regression", "stan") +set_dependency("linear_reg", "stan", "rstanarm") + +set_fit( + model = "linear_reg", + eng = "stan", + mode = "regression", + value = list( + interface = "formula", + protect = c("formula", "data", "weights"), + func = c(pkg = "rstanarm", fun = "stan_glm"), + defaults = list(family = expr(stats::gaussian)) ) +) -#' @importFrom dplyr select rename -linear_reg_spark_data <- - list( - libs = "sparklyr", - fit = list( - interface = "formula", - protect = c("x", "formula", "weight_col"), - func = c(pkg = "sparklyr", fun = "ml_linear_regression") - ), - numeric = list( - pre = NULL, - post = function(results, object) { - results <- dplyr::rename(results, pred = prediction) - results <- dplyr::select(results, pred) - results - }, - func = c(pkg = "sparklyr", fun = "ml_predict"), - args = - list( - x = expr(object$fit), - dataset = expr(new_data) +set_pred( + model = "linear_reg", + eng = "stan", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = list(object = expr(object$fit), newdata = expr(new_data)) + ) +) + +set_pred( + model = "linear_reg", + eng = "stan", + mode = "regression", + type = "conf_int", + value = list( + pre = NULL, + post = function(results, object) { + res <- + tibble( + .pred_lower = + convert_stan_interval( + results, + level = object$spec$method$pred$conf_int$extras$level + ), + .pred_upper = + convert_stan_interval( + results, + level = object$spec$method$pred$conf_int$extras$level, + lower = FALSE + ), ) - ) + if (object$spec$method$pred$conf_int$extras$std_error) + res$.std_error <- apply(results, 2, sd, na.rm = TRUE) + res + }, + func = c(pkg = "rstanarm", fun = "posterior_linpred"), + args = + list( + object = expr(object$fit), + newdata = expr(new_data), + transform = TRUE, + seed = expr(sample.int(10^5, 1)) + ) ) +) -linear_reg_keras_data <- - list( - libs = c("keras", "magrittr"), - fit = list( - interface = "matrix", - protect = c("x", "y"), - func = c(pkg = "parsnip", fun = "keras_mlp"), - defaults = list(hidden_units = 1, act = "linear") - ), - numeric = list( - pre = NULL, - post = maybe_multivariate, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - x = quote(as.matrix(new_data)) +set_pred( + model = "linear_reg", + eng = "stan", + mode = "regression", + type = "pred_int", + value = list( + pre = NULL, + post = function(results, object) { + res <- + tibble( + .pred_lower = + convert_stan_interval( + results, + level = object$spec$method$pred$pred_int$extras$level + ), + .pred_upper = + convert_stan_interval( + results, + level = object$spec$method$pred$pred_int$extras$level, + lower = FALSE + ), ) - ) + if 
(object$spec$method$pred$pred_int$extras$std_error) + res$.std_error <- apply(results, 2, sd, na.rm = TRUE) + res + }, + func = c(pkg = "rstanarm", fun = "posterior_predict"), + args = + list( + object = expr(object$fit), + newdata = expr(new_data), + seed = expr(sample.int(10^5, 1)) + ) + ) +) + +set_pred( + model = "linear_reg", + eng = "stan", + mode = "regression", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = list(object = expr(object$fit), newdata = expr(new_data)) ) +) + +# ------------------------------------------------------------------------------ + +set_model_engine("linear_reg", "regression", "spark") +set_dependency("linear_reg", "spark", "sparklyr") + +set_model_arg( + model = "linear_reg", + eng = "spark", + parsnip = "penalty", + original = "reg_param", + func = list(pkg = "dials", fun = "penalty"), + has_submodel = TRUE +) + +set_model_arg( + model = "linear_reg", + eng = "spark", + parsnip = "mixture", + original = "elastic_net_param", + func = list(pkg = "dials", fun = "mixture"), + has_submodel = FALSE +) + + +set_fit( + model = "linear_reg", + eng = "spark", + mode = "regression", + value = list( + interface = "formula", + protect = c("x", "formula", "weight_col"), + func = c(pkg = "sparklyr", fun = "ml_linear_regression"), + defaults = list() + ) +) + +set_pred( + model = "linear_reg", + eng = "spark", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = function(results, object) { + results <- dplyr::rename(results, pred = prediction) + results <- dplyr::select(results, pred) + results + }, + func = c(pkg = "sparklyr", fun = "ml_predict"), + args = list(x = expr(object$fit), dataset = expr(new_data)) + ) +) + +# ------------------------------------------------------------------------------ + + +set_model_engine("linear_reg", "regression", "keras") +set_dependency("linear_reg", "keras", "keras") +set_dependency("linear_reg", "keras", "magrittr") + +set_fit( + model = "linear_reg", + eng = "keras", + mode = "regression", + value = list( + interface = "matrix", + protect = c("x", "y"), + func = c(pkg = "parsnip", fun = "keras_mlp"), + defaults = list(hidden_units = 1, act = "linear") + ) +) + +set_pred( + model = "linear_reg", + eng = "keras", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = maybe_multivariate, + func = c(fun = "predict"), + args = list(object = quote(object$fit), x = quote(as.matrix(new_data))) + ) +) diff --git a/R/logistic_reg.R b/R/logistic_reg.R index 566a0c4ad..1bc9062b6 100644 --- a/R/logistic_reg.R +++ b/R/logistic_reg.R @@ -355,8 +355,3 @@ predict_raw._lognet <- function (object, new_data, opts = list(), ...) { predict_raw.model_fit(object, new_data = new_data, opts = opts, ...) 
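
Taken together, the registrations in this file are what drive the user-facing pipeline. A minimal sketch with the `lm` engine, using mtcars purely as a stand-in dataset:

library(parsnip)

lm_fit <-
  linear_reg() %>%
  set_engine("lm") %>%
  fit(mpg ~ disp + wt, data = mtcars)

predict(lm_fit, new_data = mtcars[1:3, ])                     # routed to the "numeric" module
predict(lm_fit, new_data = mtcars[1:3, ], type = "conf_int")  # routed to the "conf_int" module
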
} - -# ------------------------------------------------------------------------------ - -#' @importFrom utils globalVariables -utils::globalVariables(c("group", ".pred")) diff --git a/R/logistic_reg_data.R b/R/logistic_reg_data.R index 6fa1e0c94..d04f0100e 100644 --- a/R/logistic_reg_data.R +++ b/R/logistic_reg_data.R @@ -1,354 +1,514 @@ +set_new_model("logistic_reg") -logistic_reg_arg_key <- data.frame( - glm = c( NA, NA), - glmnet = c( "lambda", "alpha"), - spark = c("reg_param", "elastic_net_param"), - stan = c( NA, NA), - keras = c( "decay", NA), - stringsAsFactors = FALSE, - row.names = c("penalty", "mixture") +set_model_mode("logistic_reg", "classification") + +# ------------------------------------------------------------------------------ + +set_model_engine("logistic_reg", "classification", "glm") +set_dependency("logistic_reg", "glm", "stats") + +set_fit( + model = "logistic_reg", + eng = "glm", + mode = "classification", + value = list( + interface = "formula", + protect = c("formula", "data", "weights"), + func = c(pkg = "stats", fun = "glm"), + defaults = list(family = expr(stats::binomial)) + ) ) -logistic_reg_modes <- "classification" +set_pred( + model = "logistic_reg", + eng = "glm", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = prob_to_class_2, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "response" + ) + ) +) -logistic_reg_engines <- data.frame( - glm = TRUE, - glmnet = TRUE, - spark = TRUE, - stan = TRUE, - keras = TRUE, - row.names = c("classification") +set_pred( + model = "logistic_reg", + eng = "glm", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = function(x, object) { + x <- tibble(v1 = 1 - x, v2 = x) + colnames(x) <- object$lvl + x + }, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "response" + ) + ) ) -# ------------------------------------------------------------------------------ +set_pred( + model = "logistic_reg", + eng = "glm", + mode = "classification", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data) + ) + ) +) -#' @importFrom stats qt -logistic_reg_glm_data <- - list( - libs = "stats", - fit = list( - interface = "formula", - protect = c("formula", "data", "weights"), - func = c(pkg = "stats", fun = "glm"), - defaults = - list( - family = expr(stats::binomial) - ) - ), - class = list( - pre = NULL, - post = prob_to_class_2, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "response" +set_pred( + model = "logistic_reg", + eng = "glm", + mode = "classification", + type = "conf_int", + value = list( + pre = NULL, + post = function(results, object) { + hf_lvl <- (1 - object$spec$method$pred$conf_int$extras$level)/2 + const <- + qt(hf_lvl, df = object$fit$df.residual, lower.tail = FALSE) + trans <- object$fit$family$linkinv + res_2 <- + tibble( + lo = trans(results$fit - const * results$se.fit), + hi = trans(results$fit + const * results$se.fit) ) - ), - classprob = list( - pre = NULL, - post = function(x, object) { - x <- tibble(v1 = 1 - x, v2 = x) - colnames(x) <- object$lvl - x - }, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "response" - ) - ), - raw = list( - pre = NULL, - func = c(fun = "predict"), - 
args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ), - confint = list( - pre = NULL, - post = function(results, object) { - hf_lvl <- (1 - object$spec$method$confint$extras$level)/2 - const <- - qt(hf_lvl, df = object$fit$df.residual, lower.tail = FALSE) - trans <- object$fit$family$linkinv - res_2 <- - tibble( - lo = trans(results$fit - const * results$se.fit), - hi = trans(results$fit + const * results$se.fit) - ) - res_1 <- res_2 - res_1$lo <- 1 - res_2$hi - res_1$hi <- 1 - res_2$lo - res <- bind_cols(res_1, res_2) - lo_nms <- paste0(".pred_lower_", object$lvl) - hi_nms <- paste0(".pred_upper_", object$lvl) - colnames(res) <- c(lo_nms[1], hi_nms[1], lo_nms[2], hi_nms[2]) - - if (object$spec$method$confint$extras$std_error) - res$.std_error <- results$se.fit - res - }, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - se.fit = TRUE, - type = "link" - ) - ) + res_1 <- res_2 + res_1$lo <- 1 - res_2$hi + res_1$hi <- 1 - res_2$lo + res <- bind_cols(res_1, res_2) + lo_nms <- paste0(".pred_lower_", object$lvl) + hi_nms <- paste0(".pred_upper_", object$lvl) + colnames(res) <- c(lo_nms[1], hi_nms[1], lo_nms[2], hi_nms[2]) + + if (object$spec$method$pred$conf_int$extras$std_error) + res$.std_error <- results$se.fit + res + }, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + se.fit = TRUE, + type = "link" + ) ) +) -# Note: For glmnet, you will need to make model-specific predict methods. -# See logistic_reg.R -logistic_reg_glmnet_data <- - list( - libs = "glmnet", - fit = list( - interface = "matrix", - protect = c("x", "y", "weights"), - func = c(pkg = "glmnet", fun = "glmnet"), - defaults = - list( - family = "binomial" - ) - ), - class = list( - pre = NULL, - post = organize_glmnet_class, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newx = quote(as.matrix(new_data)), - type = "response", - s = quote(object$spec$args$penalty) - ) - ), - classprob = list( - pre = NULL, - post = organize_glmnet_prob, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newx = quote(as.matrix(new_data)), - type = "response", - s = quote(object$spec$args$penalty) - ) - ), - raw = list( - pre = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newx = quote(as.matrix(new_data)) - ) - ) +# ------------------------------------------------------------------------------ + +set_model_engine("logistic_reg", "classification", "glmnet") +set_dependency("logistic_reg", "glmnet", "glmnet") + +set_model_arg( + model = "logistic_reg", + eng = "glmnet", + parsnip = "penalty", + original = "lambda", + func = list(pkg = "dials", fun = "penalty"), + has_submodel = TRUE +) + +set_model_arg( + model = "logistic_reg", + eng = "glmnet", + parsnip = "mixture", + original = "alpha", + func = list(pkg = "dials", fun = "mixture"), + has_submodel = FALSE +) + +set_fit( + model = "logistic_reg", + eng = "glmnet", + mode = "classification", + value = list( + interface = "matrix", + protect = c("x", "y", "weights"), + func = c(pkg = "glmnet", fun = "glmnet"), + defaults = list(family = "binomial") ) +) -logistic_reg_stan_data <- - list( - libs = "rstanarm", - fit = list( - interface = "formula", - protect = c("formula", "data", "weights"), - func = c(pkg = "rstanarm", fun = "stan_glm"), - defaults = - list( - family = expr(stats::binomial) - ) - ), - class = list( - pre = NULL, - post = function(x, object) { - x 
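
The `conf_int` post-processor registered above for the `glm` engine is a Wald-style interval computed on the link scale and then back-transformed. The same arithmetic written out standalone, with mtcars as a stand-in dataset:

glm_fit <- glm(am ~ disp + wt, data = mtcars, family = binomial)
preds   <- predict(glm_fit, newdata = mtcars[1:3, ], se.fit = TRUE, type = "link")
level   <- 0.95
const   <- qt((1 - level) / 2, df = glm_fit$df.residual, lower.tail = FALSE)
trans   <- glm_fit$family$linkinv
# Bounds for the second factor level; the registered post-processor also derives
# the first level's bounds as one minus these.
data.frame(
  .pred_lower = trans(preds$fit - const * preds$se.fit),
  .pred_upper = trans(preds$fit + const * preds$se.fit)
)
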
<- object$fit$family$linkinv(x) - x <- ifelse(x >= 0.5, object$lvl[2], object$lvl[1]) - unname(x) - }, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ), - classprob = list( - pre = NULL, - post = function(x, object) { - x <- object$fit$family$linkinv(x) - x <- tibble(v1 = 1 - x, v2 = x) - colnames(x) <- object$lvl - x - }, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ), - raw = list( - pre = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ), - confint = list( - pre = NULL, - post = function(results, object) { - res_2 <- - tibble( - lo = - convert_stan_interval( - results, - level = object$spec$method$confint$extras$level - ), - hi = - convert_stan_interval( - results, - level = object$spec$method$confint$extras$level, - lower = FALSE - ), - ) - res_1 <- res_2 - res_1$lo <- 1 - res_2$hi - res_1$hi <- 1 - res_2$lo - res <- bind_cols(res_1, res_2) - lo_nms <- paste0(".pred_lower_", object$lvl) - hi_nms <- paste0(".pred_upper_", object$lvl) - colnames(res) <- c(lo_nms[1], hi_nms[1], lo_nms[2], hi_nms[2]) - - if (object$spec$method$confint$extras$std_error) - res$.std_error <- apply(results, 2, sd, na.rm = TRUE) - res - }, - func = c(pkg = "rstanarm", fun = "posterior_linpred"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - transform = TRUE, - seed = expr(sample.int(10^5, 1)) - ) - ), - predint = list( - pre = NULL, - post = function(results, object) { - res_2 <- - tibble( - lo = - convert_stan_interval( - results, - level = object$spec$method$predint$extras$level - ), - hi = - convert_stan_interval( - results, - level = object$spec$method$predint$extras$level, - lower = FALSE - ), - ) - res_1 <- res_2 - res_1$lo <- 1 - res_2$hi - res_1$hi <- 1 - res_2$lo - res <- bind_cols(res_1, res_2) - lo_nms <- paste0(".pred_lower_", object$lvl) - hi_nms <- paste0(".pred_upper_", object$lvl) - colnames(res) <- c(lo_nms[1], hi_nms[1], lo_nms[2], hi_nms[2]) - - if (object$spec$method$predint$extras$std_error) - res$.std_error <- apply(results, 2, sd, na.rm = TRUE) - res - }, - func = c(pkg = "rstanarm", fun = "posterior_predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - seed = expr(sample.int(10^5, 1)) - ) - ) + +set_pred( + model = "logistic_reg", + eng = "glmnet", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = organize_glmnet_class, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newx = quote(as.matrix(new_data)), + type = "response", + s = quote(object$spec$args$penalty) + ) ) +) +set_pred( + model = "logistic_reg", + eng = "glmnet", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = organize_glmnet_prob, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newx = quote(as.matrix(new_data)), + type = "response", + s = quote(object$spec$args$penalty) + ) + ) +) -logistic_reg_spark_data <- - list( - libs = "sparklyr", - fit = list( - interface = "formula", - protect = c("x", "formula", "weight_col"), - func = c(pkg = "sparklyr", fun = "ml_logistic_regression"), - defaults = - list( - family = "binomial" - ) - ), - class = list( - pre = NULL, - post = format_spark_class, - func = c(pkg = "sparklyr", fun = "ml_predict"), - args = - list( - x = quote(object$fit), - dataset = quote(new_data) - ) - ), - classprob = list( - pre = 
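
Both the old `confint`/`predint` code being removed here and the new `conf_int`/`pred_int` modules share one idea: `rstanarm::posterior_linpred()` and `rstanarm::posterior_predict()` return a draws-by-observation matrix, and the post-processing reduces it column by column. A conceptual stand-in using simulated draws rather than a fitted model:

level <- 0.95
draws <- matrix(rnorm(4000 * 3), ncol = 3)  # rows = posterior draws, columns = new observations
tibble::tibble(
  .pred_lower = apply(draws, 2, quantile, probs = (1 - level) / 2),
  .pred_upper = apply(draws, 2, quantile, probs = 1 - (1 - level) / 2),
  .std_error  = apply(draws, 2, sd)
)
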
NULL, - post = format_spark_probs, - func = c(pkg = "sparklyr", fun = "ml_predict"), - args = - list( - x = quote(object$fit), - dataset = quote(new_data) - ) - ) +set_pred( + model = "logistic_reg", + eng = "glmnet", + mode = "classification", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newx = quote(as.matrix(new_data)) + ) ) +) + +# ------------------------------------------------------------------------------ + +set_model_engine("logistic_reg", "classification", "spark") +set_dependency("logistic_reg", "spark", "sparklyr") + +set_model_arg( + model = "logistic_reg", + eng = "spark", + parsnip = "penalty", + original = "reg_param", + func = list(pkg = "dials", fun = "penalty"), + has_submodel = TRUE +) -logistic_reg_keras_data <- - list( - libs = c("keras", "magrittr"), - fit = list( - interface = "matrix", - protect = c("x", "y"), - func = c(pkg = "parsnip", fun = "keras_mlp"), - defaults = list(hidden_units = 1, act = "linear") - ), - class = list( - pre = NULL, - post = function(x, object) { - object$lvl[x + 1] - }, - func = c(pkg = "keras", fun = "predict_classes"), - args = - list( - object = quote(object$fit), - x = quote(as.matrix(new_data)) +set_model_arg( + model = "logistic_reg", + eng = "spark", + parsnip = "mixture", + original = "elastic_net_param", + func = list(pkg = "dials", fun = "mixture"), + has_submodel = FALSE +) + +set_fit( + model = "logistic_reg", + eng = "spark", + mode = "classification", + value = list( + interface = "formula", + protect = c("x", "formula", "weight_col"), + func = c(pkg = "sparklyr", fun = "ml_logistic_regression"), + defaults = + list( + family = "binomial" + ) + ) +) + +set_pred( + model = "logistic_reg", + eng = "spark", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = format_spark_class, + func = c(pkg = "sparklyr", fun = "ml_predict"), + args = + list( + x = quote(object$fit), + dataset = quote(new_data) + ) + ) +) + +set_pred( + model = "logistic_reg", + eng = "spark", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = format_spark_probs, + func = c(pkg = "sparklyr", fun = "ml_predict"), + args = + list( + x = quote(object$fit), + dataset = quote(new_data) + ) + ) +) + +# ------------------------------------------------------------------------------ + +set_model_engine("logistic_reg", "classification", "keras") +set_dependency("logistic_reg", "keras", "keras") +set_dependency("logistic_reg", "keras", "magrittr") + +set_model_arg( + model = "logistic_reg", + eng = "keras", + parsnip = "decay", + original = "decay", + func = list(pkg = "dials", fun = "weight_decay"), + has_submodel = FALSE +) + +set_fit( + model = "logistic_reg", + eng = "keras", + mode = "classification", + value = list( + interface = "matrix", + protect = c("x", "y"), + func = c(pkg = "parsnip", fun = "keras_mlp"), + defaults = list(hidden_units = 1, act = "linear") + ) +) + +set_pred( + model = "logistic_reg", + eng = "keras", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = function(x, object) { + object$lvl[x + 1] + }, + func = c(pkg = "keras", fun = "predict_classes"), + args = + list( + object = quote(object$fit), + x = quote(as.matrix(new_data)) + ) + ) +) + +set_pred( + model = "logistic_reg", + eng = "keras", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = function(x, object) { + x <- as_tibble(x) + colnames(x) <- object$lvl + x + }, + 
func = c(pkg = "keras", fun = "predict_proba"), + args = + list( + object = quote(object$fit), + x = quote(as.matrix(new_data)) + ) + ) +) + +# ------------------------------------------------------------------------------ + +set_model_engine("logistic_reg", "classification", "stan") +set_dependency("logistic_reg", "stan", "rstanarm") + +set_fit( + model = "logistic_reg", + eng = "stan", + mode = "classification", + value = list( + interface = "formula", + protect = c("formula", "data", "weights"), + func = c(pkg = "rstanarm", fun = "stan_glm"), + defaults = list(family = expr(stats::binomial)) + ) +) + +set_pred( + model = "logistic_reg", + eng = "stan", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = function(x, object) { + x <- object$fit$family$linkinv(x) + x <- ifelse(x >= 0.5, object$lvl[2], object$lvl[1]) + unname(x) + }, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data) + ) + ) +) + +set_pred( + model = "logistic_reg", + eng = "stan", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = function(x, object) { + x <- object$fit$family$linkinv(x) + x <- tibble(v1 = 1 - x, v2 = x) + colnames(x) <- object$lvl + x + }, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data) + ) + ) +) + + +set_pred( + model = "logistic_reg", + eng = "stan", + mode = "classification", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data) + ) + ) +) + +set_pred( + model = "logistic_reg", + eng = "stan", + mode = "classification", + type = "conf_int", + value = list( + pre = NULL, + post = function(results, object) { + res_2 <- + tibble( + lo = + convert_stan_interval( + results, + level = object$spec$method$pred$conf_int$extras$level + ), + hi = + convert_stan_interval( + results, + level = object$spec$method$pred$conf_int$extras$level, + lower = FALSE + ), ) - ), - classprob = list( - pre = NULL, - post = function(x, object) { - x <- as_tibble(x) - colnames(x) <- object$lvl - x - }, - func = c(pkg = "keras", fun = "predict_proba"), - args = - list( - object = quote(object$fit), - x = quote(as.matrix(new_data)) + res_1 <- res_2 + res_1$lo <- 1 - res_2$hi + res_1$hi <- 1 - res_2$lo + res <- bind_cols(res_1, res_2) + lo_nms <- paste0(".pred_lower_", object$lvl) + hi_nms <- paste0(".pred_upper_", object$lvl) + colnames(res) <- c(lo_nms[1], hi_nms[1], lo_nms[2], hi_nms[2]) + + if (object$spec$method$pred$conf_int$extras$std_error) + res$.std_error <- apply(results, 2, sd, na.rm = TRUE) + res + }, + func = c(pkg = "rstanarm", fun = "posterior_linpred"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + transform = TRUE, + seed = expr(sample.int(10^5, 1)) + ) + ) +) + +set_pred( + model = "logistic_reg", + eng = "stan", + mode = "classification", + type = "pred_int", + value = list( + pre = NULL, + post = function(results, object) { + res_2 <- + tibble( + lo = + convert_stan_interval( + results, + level = object$spec$method$pred$pred_int$extras$level + ), + hi = + convert_stan_interval( + results, + level = object$spec$method$pred$pred_int$extras$level, + lower = FALSE + ), ) - ) + res_1 <- res_2 + res_1$lo <- 1 - res_2$hi + res_1$hi <- 1 - res_2$lo + res <- bind_cols(res_1, res_2) + lo_nms <- paste0(".pred_lower_", object$lvl) + hi_nms <- paste0(".pred_upper_", object$lvl) + colnames(res) <- c(lo_nms[1], 
hi_nms[1], lo_nms[2], hi_nms[2]) + + if (object$spec$method$pred$pred_int$extras$std_error) + res$.std_error <- apply(results, 2, sd, na.rm = TRUE) + res + }, + func = c(pkg = "rstanarm", fun = "posterior_predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + seed = expr(sample.int(10^5, 1)) + ) ) +) diff --git a/R/mars_data.R b/R/mars_data.R index 0c1076e68..a4a84e268 100644 --- a/R/mars_data.R +++ b/R/mars_data.R @@ -1,77 +1,152 @@ -mars_arg_key <- data.frame( - earth = c( "nprune", "degree", "pmethod"), - stringsAsFactors = FALSE, - row.names = c("num_terms", "prod_degree", "prune_method") +set_new_model("mars") + +set_model_mode("mars", "classification") +set_model_mode("mars", "regression") + +# ------------------------------------------------------------------------------ + +set_model_engine("mars", "classification", "earth") +set_model_engine("mars", "regression", "earth") +set_dependency("mars", "earth", "earth") + +set_model_arg( + model = "mars", + eng = "earth", + parsnip = "num_terms", + original = "nprune", + func = list(pkg = "dials", fun = "num_terms"), + has_submodel = FALSE +) +set_model_arg( + model = "mars", + eng = "earth", + parsnip = "prod_degree", + original = "degree", + func = list(pkg = "dials", fun = "prod_degree"), + has_submodel = FALSE +) +set_model_arg( + model = "mars", + eng = "earth", + parsnip = "prune_method", + original = "pmethod", + func = list(pkg = "dials", fun = "prune_method"), + has_submodel = FALSE ) -mars_modes <- c("classification", "regression", "unknown") +set_fit( + model = "mars", + eng = "earth", + mode = "regression", + value = list( + interface = "data.frame", + protect = c("x", "y", "weights"), + func = c(pkg = "earth", fun = "earth"), + defaults = list(keepxy = TRUE) + ) +) -mars_engines <- data.frame( - earth = rep(TRUE, 3), - row.names = c("classification", "regression", "unknown") +set_fit( + model = "mars", + eng = "earth", + mode = "classification", + value = list( + interface = "data.frame", + protect = c("x", "y", "weights"), + func = c(pkg = "earth", fun = "earth"), + defaults = list(keepxy = TRUE) + ) ) -# ------------------------------------------------------------------------------ +set_pred( + model = "mars", + eng = "earth", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = maybe_multivariate, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "response" + ) + ) +) + +set_pred( + model = "mars", + eng = "earth", + mode = "regression", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list(object = quote(object$fit), + newdata = quote(new_data)) + ) +) -mars_earth_data <- - list( - libs = "earth", - fit = list( - interface = "data.frame", - protect = c("x", "y", "weights"), - func = c(pkg = "earth", fun = "earth"), - defaults = list(keepxy = TRUE) - ), - numeric = list( - pre = NULL, - post = maybe_multivariate, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "response" - ) - ), - class = list( - pre = NULL, - post = function(x, object) { - x <- ifelse(x[,1] >= 0.5, object$lvl[2], object$lvl[1]) - x - }, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "response" - ) - ), - classprob = list( - pre = NULL, - post = function(x, object) { - x <- x[,1] - x <- tibble(v1 = 1 - x, v2 = x) - colnames(x) <- object$lvl - x - }, 
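
With both modes registered for the `earth` engine above, the same specification function covers regression and classification. A minimal usage sketch, assuming the earth package is installed and using mtcars as a stand-in dataset:

library(parsnip)

mars_fit <-
  mars(mode = "regression", num_terms = 5) %>%
  set_engine("earth") %>%
  fit(mpg ~ ., data = mtcars)

predict(mars_fit, new_data = mtcars[1:3, ])
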
- func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "response" - ) - ), - raw = list( - pre = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ) +set_pred( + model = "mars", + eng = "earth", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = function(x, object) { + x <- ifelse(x[, 1] >= 0.5, object$lvl[2], object$lvl[1]) + x + }, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "response" + ) ) +) +set_pred( + model = "mars", + eng = "earth", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = function(x, object) { + x <- x[, 1] + x <- tibble(v1 = 1 - x, v2 = x) + colnames(x) <- object$lvl + x + }, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "response" + ) + ) +) + +set_pred( + model = "mars", + eng = "earth", + mode = "classification", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list(object = quote(object$fit), + newdata = quote(new_data)) + ) +) diff --git a/R/misc.R b/R/misc.R index 3885e7817..7dda7d0df 100644 --- a/R/misc.R +++ b/R/misc.R @@ -132,14 +132,6 @@ make_call <- function(fun, ns, args, ...) { out } -resolve_args <- function(args, ...) { - for (i in seq(along = args)) { - if (!is_missing_arg(args[[i]])) - args[[i]] <- eval_tidy(args[[i]], ...) - } - args -} - levels_from_formula <- function(f, dat) { if (inherits(dat, "tbl_spark")) res <- NULL @@ -152,8 +144,8 @@ is_spark <- function(x) isTRUE(unname(x$method$fit$func["pkg"] == "sparklyr")) -show_fit <- function(mod, eng) { - mod <- translate(x = mod, engine = eng) +show_fit <- function(model, eng) { + mod <- translate(x = model, engine = eng) fit_call <- show_call(mod) call_text <- deparse(fit_call) call_text <- paste0(call_text, collapse = "\n") @@ -201,7 +193,7 @@ update_dot_check <- function(...) { # ------------------------------------------------------------------------------ new_model_spec <- function(cls, args, eng_args, mode, method, engine) { - spec_modes <- get(paste0(cls, "_modes")) + spec_modes <- rlang::env_get(get_model_env(), paste0(cls, "_modes")) if (!(mode %in% spec_modes)) stop("`mode` should be one of: ", paste0("'", spec_modes, "'", collapse = ", "), diff --git a/R/mlp.R b/R/mlp.R index c8c94054e..3fe96631f 100644 --- a/R/mlp.R +++ b/R/mlp.R @@ -229,3 +229,189 @@ check_args.mlp <- function(object) { invisible(object) } + +# keras wrapper for feed-forward nnet + +class2ind <- function (x, drop2nd = FALSE) { + if (!is.factor(x)) + stop("`x` should be a factor") + y <- model.matrix( ~ x - 1) + colnames(y) <- gsub("^x", "", colnames(y)) + attributes(y)$assign <- NULL + attributes(y)$contrasts <- NULL + if (length(levels(x)) == 2 & drop2nd) { + y <- y[, 1] + } + y +} + + +# ------------------------------------------------------------------------------ + +#' Simple interface to MLP models via keras +#' +#' Instead of building a `keras` model sequentially, `keras_mlp` can be used to +#' create a feedforward network with a single hidden layer. Regularization is +#' via either weight decay or dropout. +#' +#' @param x A data frame or matrix of predictors +#' @param y A vector (factor or numeric) or matrix (numeric) of outcome data. +#' @param hidden_units An integer for the number of hidden units. 
+#' @param decay A non-negative real number for the amount of weight decay. Either +#' this parameter _or_ `dropout` can be specified. +#' @param dropout The proportion of parameters to set to zero. Either +#' this parameter _or_ `decay` can be specified. +#' @param epochs An integer for the number of passes through the data. +#' @param act A character string for the type of activation function between layers. +#' @param seeds A vector of three positive integers to control randomness of the +#' calculations. +#' @param ... Currently ignored. +#' @return A `keras` model object. +#' @keywords internal +#' @export +keras_mlp <- + function(x, y, + hidden_units = 5, decay = 0, dropout = 0, epochs = 20, act = "softmax", + seeds = sample.int(10^5, size = 3), + ...) { + + if(decay > 0 & dropout > 0) + stop("Please use either dropout or weight decay.", call. = FALSE) + + if (!is.matrix(x)) + x <- as.matrix(x) + + if(is.character(y)) + y <- as.factor(y) + factor_y <- is.factor(y) + + if (factor_y) + y <- class2ind(y) + else { + if (isTRUE(ncol(y) > 1)) + y <- as.matrix(y) + else + y <- matrix(y, ncol = 1) + } + + model <- keras::keras_model_sequential() + if(decay > 0) { + model %>% + keras::layer_dense( + units = hidden_units, + activation = act, + input_shape = ncol(x), + kernel_regularizer = keras::regularizer_l2(decay), + kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[1]) + ) + } else { + model %>% + keras::layer_dense( + units = hidden_units, + activation = act, + input_shape = ncol(x), + kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[1]) + ) + } + if(dropout > 0) + model %>% + keras::layer_dense( + units = hidden_units, + activation = act, + input_shape = ncol(x), + kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[1]) + ) %>% + keras::layer_dropout(rate = dropout, seed = seeds[2]) + + if (factor_y) + model <- model %>% + keras::layer_dense( + units = ncol(y), + activation = 'softmax', + kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[3]) + ) + else + model <- model %>% + keras::layer_dense( + units = ncol(y), + activation = 'linear', + kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[3]) + ) + + arg_values <- parse_keras_args(...) + compile_call <- expr( + keras::compile(object = model) + ) + if(!any(names(arg_values$compile) == "loss")) + compile_call$loss <- + if(factor_y) "binary_crossentropy" else "mse" + if(!any(names(arg_values$compile) == "optimizer")) + compile_call$optimizer <- "adam" + for(arg in names(arg_values$compile)) + compile_call[[arg]] <- arg_values$compile[[arg]] + + model <- eval_tidy(compile_call) + + fit_call <- expr( + keras::fit(object = model) + ) + fit_call$x <- quote(x) + fit_call$y <- quote(y) + fit_call$epochs <- epochs + for(arg in names(arg_values$fit)) + fit_call[[arg]] <- arg_values$fit[[arg]] + + history <- eval_tidy(fit_call) + model + } + + +nnet_softmax <- function(results, object) { + if (ncol(results) == 1) + results <- cbind(1 - results, results) + + results <- apply(results, 1, function(x) exp(x)/sum(exp(x))) + results <- as_tibble(t(results)) + names(results) <- paste0(".pred_", object$lvl) + results +} + +parse_keras_args <- function(...)
{ + exclusions <- c("object", "x", "y", "validation_data", "epochs") + fit_args <- c( + 'batch_size', + 'verbose', + 'callbacks', + 'view_metrics', + 'validation_split', + 'validation_data', + 'shuffle', + 'class_weight', + 'sample_weight', + 'initial_epoch', + 'steps_per_epoch', + 'validation_steps' + ) + compile_args <- c( + 'optimizer', + 'loss', + 'metrics', + 'loss_weights', + 'sample_weight_mode', + 'weighted_metrics', + 'target_tensors' + ) + dots <- list(...) + dots <- dots[!(names(dots) %in% exclusions)] + + list( + fit = dots[names(dots) %in% fit_args], + compile = dots[names(dots) %in% compile_args] + ) +} + +mlp_num_weights <- function(p, hidden_units, classes) { + ((p+1) * hidden_units) + ((hidden_units+1) * classes) +} + + diff --git a/R/mlp_data.R b/R/mlp_data.R index 724035e7e..4df47377a 100644 --- a/R/mlp_data.R +++ b/R/mlp_data.R @@ -1,312 +1,313 @@ -mlp_arg_key <- data.frame( - nnet = c("size", "decay", NA_character_, "maxit", NA_character_), - keras = c("hidden_units", "penalty", "dropout", "epochs", "activation"), - stringsAsFactors = FALSE, - row.names = c("hidden_units", "penalty", "dropout", "epochs", "activation") -) - -mlp_modes <- c("classification", "regression", "unknown") +set_new_model("mlp") -mlp_engines <- data.frame( - nnet = c(TRUE, TRUE, FALSE), - keras = c(TRUE, TRUE, FALSE), - row.names = c("classification", "regression", "unknown") -) +set_model_mode("mlp", "classification") +set_model_mode("mlp", "regression") # ------------------------------------------------------------------------------ -mlp_keras_data <- - list( - libs = c("keras", "magrittr"), - fit = list( - interface = "matrix", - protect = c("x", "y"), - func = c(pkg = "parsnip", fun = "keras_mlp"), - defaults = list() - ), - numeric = list( - pre = NULL, - post = maybe_multivariate, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - x = quote(as.matrix(new_data)) - ) - ), - class = list( - pre = NULL, - post = function(x, object) { - object$lvl[x + 1] - }, - func = c(pkg = "keras", fun = "predict_classes"), - args = - list( - object = quote(object$fit), - x = quote(as.matrix(new_data)) - ) - ), - classprob = list( - pre = NULL, - post = function(x, object) { - x <- as_tibble(x) - colnames(x) <- object$lvl - x - }, - func = c(pkg = "keras", fun = "predict_proba"), - args = - list( - object = quote(object$fit), - x = quote(as.matrix(new_data)) - ) - ), - raw = list( - pre = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - x = quote(as.matrix(new_data)) - ) - ) - ) +set_model_engine("mlp", "classification", "keras") +set_model_engine("mlp", "regression", "keras") +set_dependency("mlp", "keras", "keras") +set_dependency("mlp", "keras", "magrittr") +set_model_arg( + model = "mlp", + eng = "keras", + parsnip = "hidden_units", + original = "hidden_units", + func = list(pkg = "dials", fun = "hidden_units"), + has_submodel = FALSE +) +set_model_arg( + model = "mlp", + eng = "keras", + parsnip = "penalty", + original = "penalty", + func = list(pkg = "dials", fun = "weight_decay"), + has_submodel = FALSE +) +set_model_arg( + model = "mlp", + eng = "keras", + parsnip = "dropout", + original = "dropout", + func = list(pkg = "dials", fun = "dropout"), + has_submodel = FALSE +) +set_model_arg( + model = "mlp", + eng = "keras", + parsnip = "epochs", + original = "epochs", + func = list(pkg = "dials", fun = "epochs"), + has_submodel = FALSE +) +set_model_arg( + model = "mlp", + eng = "keras", + parsnip = "activation", + original = "activation", + 
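
For orientation, `parse_keras_args()` above only splits the engine arguments by name between `keras::fit()` and `keras::compile()`, dropping the protected ones. A hypothetical call showing the routing (the helper is unexported, hence the `:::`):

# `batch_size` is a fit() argument, `loss` a compile() argument, and `x` is
# excluded, so this returns
# list(fit = list(batch_size = 128), compile = list(loss = "mse")).
str(parsnip:::parse_keras_args(batch_size = 128, loss = "mse", x = 1))
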
func = list(pkg = "dials", fun = "activation"), + has_submodel = FALSE +) -nnet_softmax <- function(results, object) { - if (ncol(results) == 1) - results <- cbind(1 - results, results) - - results <- apply(results, 1, function(x) exp(x)/sum(exp(x))) - results <- as_tibble(t(results)) - names(results) <- paste0(".pred_", object$lvl) - results -} -mlp_nnet_data <- - list( - libs = "nnet", - fit = list( - interface = "formula", - protect = c("formula", "data", "weights"), - func = c(pkg = "nnet", fun = "nnet"), - defaults = list(trace = FALSE) - ), - numeric = list( - pre = NULL, - post = maybe_multivariate, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "raw" - ) - ), - class = list( - pre = NULL, - post = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "class" - ) - ), - classprob = list( - pre = NULL, - post = nnet_softmax, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "raw" - ) - ), - raw = list( - pre = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ) +set_fit( + model = "mlp", + eng = "keras", + mode = "regression", + value = list( + interface = "matrix", + protect = c("x", "y"), + func = c(pkg = "parsnip", fun = "keras_mlp"), + defaults = list() ) +) +set_fit( + model = "mlp", + eng = "keras", + mode = "classification", + value = list( + interface = "matrix", + protect = c("x", "y"), + func = c(pkg = "parsnip", fun = "keras_mlp"), + defaults = list() + ) +) -# ------------------------------------------------------------------------------ - -# keras wrapper for feed-forward nnet - -class2ind <- function (x, drop2nd = FALSE) { - if (!is.factor(x)) - stop("`x` should be a factor") - y <- model.matrix( ~ x - 1) - colnames(y) <- gsub("^x", "", colnames(y)) - attributes(y)$assign <- NULL - attributes(y)$contrasts <- NULL - if (length(levels(x)) == 2 & drop2nd) { - y <- y[, 1] - } - y -} - - -#' Simple interface to MLP models via keras -#' -#' Instead of building a `keras` model sequentially, `keras_mlp` can be used to -#' create a feedforward network with a single hidden layer. Regularization is -#' via either weight decay or dropout. -#' -#' @param x A data frame or matrix of predictors -#' @param y A vector (factor or numeric) or matrix (numeric) of outcome data. -#' @param hidden_units An integer for the number of hidden units. -#' @param decay A non-negative real number for the amount of weight decay. Either -#' this parameter _or_ `dropout` can specified. -#' @param dropout The proportion of parameters to set to zero. Either -#' this parameter _or_ `decay` can specified. -#' @param epochs An integer for the number of passes through the data. -#' @param act A character string for the type of activation function between layers. -#' @param seeds A vector of three positive integers to control randomness of the -#' calculations. -#' @param ... Currently ignored. -#' @return A `keras` model object. -#' @export -keras_mlp <- - function(x, y, - hidden_units = 5, decay = 0, dropout = 0, epochs = 20, act = "softmax", - seeds = sample.int(10^5, size = 3), - ...) { - - if(decay > 0 & dropout > 0) - stop("Please use either dropoput or weight decay.", call. 
= FALSE) - - if (!is.matrix(x)) - x <- as.matrix(x) - - if(is.character(y)) - y <- as.factor(y) - factor_y <- is.factor(y) +set_pred( + model = "mlp", + eng = "keras", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = maybe_multivariate, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + x = quote(as.matrix(new_data)) + ) + ) +) - if (factor_y) - y <- class2ind(y) - else { - if (isTRUE(ncol(y) > 1)) - y <- as.matrix(y) - else - y <- matrix(y, ncol = 1) - } +set_pred( + model = "mlp", + eng = "keras", + mode = "regression", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + x = quote(as.matrix(new_data)) + ) + ) - model <- keras::keras_model_sequential() - if(decay > 0) { - model %>% - keras::layer_dense( - units = hidden_units, - activation = act, - input_shape = ncol(x), - kernel_regularizer = keras::regularizer_l2(decay), - kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[1]) - ) - } else { - model %>% - keras::layer_dense( - units = hidden_units, - activation = act, - input_shape = ncol(x), - kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[1]) - ) - } - if(dropout > 0) - model %>% - keras::layer_dense( - units = hidden_units, - activation = act, - input_shape = ncol(x), - kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[1]) - ) %>% - keras::layer_dropout(rate = dropout, seed = seeds[2]) +) - if (factor_y) - model <- model %>% - keras::layer_dense( - units = ncol(y), - activation = 'softmax', - kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[3]) +set_pred( + model = "mlp", + eng = "keras", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = function(x, object) { + object$lvl[x + 1] + }, + func = c(pkg = "keras", fun = "predict_classes"), + args = + list( + object = quote(object$fit), + x = quote(as.matrix(new_data)) ) - else - model <- model %>% - keras::layer_dense( - units = ncol(y), - activation = 'linear', - kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[3]) + ) +) + +set_pred( + model = "mlp", + eng = "keras", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = function(x, object) { + x <- as_tibble(x) + colnames(x) <- object$lvl + x + }, + func = c(pkg = "keras", fun = "predict_proba"), + args = + list( + object = quote(object$fit), + x = quote(as.matrix(new_data)) ) + ) +) - arg_values <- parse_keras_args(...) 
- compile_call <- expr( - keras::compile(object = model) - ) - if(!any(names(arg_values$compile) == "loss")) - compile_call$loss <- - if(factor_y) "binary_crossentropy" else "mse" - if(!any(names(arg_values$compile) == "optimizer")) - compile_call$optimizer <- "adam" - for(arg in names(arg_values$compile)) - compile_call[[arg]] <- arg_values$compile[[arg]] +set_pred( + model = "mlp", + eng = "keras", + mode = "classification", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + x = quote(as.matrix(new_data)) + ) + ) +) - model <- eval_tidy(compile_call) +# ------------------------------------------------------------------------------ - fit_call <- expr( - keras::fit(object = model) - ) - fit_call$x <- quote(x) - fit_call$y <- quote(y) - fit_call$epochs <- epochs - for(arg in names(arg_values$fit)) - fit_call[[arg]] <- arg_values$fit[[arg]] +set_model_engine("mlp", "classification", "nnet") +set_model_engine("mlp", "regression", "nnet") +set_dependency("mlp", "nnet", "nnet") - history <- eval_tidy(fit_call) - model - } +set_model_arg( + model = "mlp", + eng = "nnet", + parsnip = "hidden_units", + original = "size", + func = list(pkg = "dials", fun = "hidden_units"), + has_submodel = FALSE +) +set_model_arg( + model = "mlp", + eng = "nnet", + parsnip = "penalty", + original = "decay", + func = list(pkg = "dials", fun = "weight_decay"), + has_submodel = FALSE +) +set_model_arg( + model = "mlp", + eng = "nnet", + parsnip = "epochs", + original = "maxit", + func = list(pkg = "dials", fun = "epochs"), + has_submodel = FALSE +) +set_fit( + model = "mlp", + eng = "nnet", + mode = "regression", + value = list( + interface = "formula", + protect = c("formula", "data", "weights"), + func = c(pkg = "nnet", fun = "nnet"), + defaults = list(trace = FALSE) + ) +) -parse_keras_args <- function(...) { - exclusions <- c("object", "x", "y", "validation_data", "epochs") - fit_args <- c( - 'batch_size', - 'verbose', - 'callbacks', - 'view_metrics', - 'validation_split', - 'validation_data', - 'shuffle', - 'class_weight', - 'sample_weight', - 'initial_epoch', - 'steps_per_epoch', - 'validation_steps' +set_fit( + model = "mlp", + eng = "nnet", + mode = "classification", + value = list( + interface = "formula", + protect = c("formula", "data", "weights"), + func = c(pkg = "nnet", fun = "nnet"), + defaults = list(trace = FALSE) ) - compile_args <- c( - 'optimizer', - 'loss', - 'metrics', - 'loss_weights', - 'sample_weight_mode', - 'weighted_metrics', - 'target_tensors' +) + +set_pred( + model = "mlp", + eng = "nnet", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = maybe_multivariate, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "raw" + ) ) - dots <- list(...) 
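
As with the other models, these modules expose the keras engine through the usual specification interface. A minimal sketch; it assumes a working keras/TensorFlow installation, so it is illustrative only:

library(parsnip)

mlp_fit <-
  mlp(mode = "classification", hidden_units = 4, epochs = 10) %>%
  set_engine("keras") %>%
  fit(Species ~ ., data = iris)

predict(mlp_fit, new_data = iris[1:3, ])  # routed to the "class" module above
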
- dots <- dots[!(names(dots) %in% exclusions)] +) - list( - fit = dots[names(dots) %in% fit_args], - compile = dots[names(dots) %in% compile_args] +set_pred( + model = "mlp", + eng = "nnet", + mode = "regression", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data) + ) ) -} -mlp_num_weights <- function(p, hidden_units, classes) - ((p+1) * hidden_units) + ((hidden_units+1) * classes) +) +set_pred( + model = "mlp", + eng = "nnet", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "class" + ) + ) +) +set_pred( + model = "mlp", + eng = "nnet", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = nnet_softmax, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "raw" + ) + ) +) +set_pred( + model = "mlp", + eng = "nnet", + mode = "classification", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data) + ) + ) +) diff --git a/R/multinom_reg.R b/R/multinom_reg.R index b9aa2f40c..71e367489 100644 --- a/R/multinom_reg.R +++ b/R/multinom_reg.R @@ -322,7 +322,3 @@ check_glmnet_lambda <- function(dat, object) { dat } -# ------------------------------------------------------------------------------ - -#' @importFrom utils globalVariables -utils::globalVariables(c("group", ".pred")) diff --git a/R/multinom_reg_data.R b/R/multinom_reg_data.R index 186291003..7b7b16b3b 100644 --- a/R/multinom_reg_data.R +++ b/R/multinom_reg_data.R @@ -1,138 +1,228 @@ +set_new_model("multinom_reg") -multinom_reg_arg_key <- data.frame( - glmnet = c( "lambda", "alpha"), - spark = c("reg_param", "elastic_net_param"), - keras = c( "decay", NA), - stringsAsFactors = FALSE, - row.names = c("penalty", "mixture") +set_model_mode("multinom_reg", "classification") + +# ------------------------------------------------------------------------------ + +set_model_engine("multinom_reg", "classification", "glmnet") +set_dependency("multinom_reg", "glmnet", "glmnet") + +set_model_arg( + model = "multinom_reg", + eng = "glmnet", + parsnip = "penalty", + original = "lambda", + func = list(pkg = "dials", fun = "penalty"), + has_submodel = TRUE +) + +set_model_arg( + model = "multinom_reg", + eng = "glmnet", + parsnip = "mixture", + original = "alpha", + func = list(pkg = "dials", fun = "mixture"), + has_submodel = FALSE +) + +set_fit( + model = "multinom_reg", + eng = "glmnet", + mode = "classification", + value = list( + interface = "matrix", + protect = c("x", "y", "weights"), + func = c(pkg = "glmnet", fun = "glmnet"), + defaults = list(family = "multinomial") + ) +) + + +set_pred( + model = "multinom_reg", + eng = "glmnet", + mode = "classification", + type = "class", + value = list( + pre = check_glmnet_lambda, + post = organize_multnet_class, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newx = quote(as.matrix(new_data)), + type = "class", + s = quote(object$spec$args$penalty) + ) + ) ) -multinom_reg_modes <- "classification" +set_pred( + model = "multinom_reg", + eng = "glmnet", + mode = "classification", + type = "prob", + value = list( + pre = check_glmnet_lambda, + post = organize_multnet_prob, + func = c(fun = "predict"), + args = + list( + 
object = quote(object$fit), + newx = quote(as.matrix(new_data)), + type = "response", + s = quote(object$spec$args$penalty) + ) + ) +) -multinom_reg_engines <- data.frame( - glmnet = TRUE, - spark = TRUE, - keras = TRUE, - row.names = c("classification") +set_pred( + model = "multinom_reg", + eng = "glmnet", + mode = "classification", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newx = quote(as.matrix(new_data)) + ) + ) ) # ------------------------------------------------------------------------------ -multinom_reg_glmnet_data <- - list( - libs = "glmnet", - fit = list( - interface = "matrix", - protect = c("x", "y", "weights"), - func = c(pkg = "glmnet", fun = "glmnet"), - defaults = - list( - family = "multinomial" - ) - ), - class = list( - pre = check_glmnet_lambda, - post = organize_multnet_class, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newx = quote(as.matrix(new_data)), - type = "class", - s = quote(object$spec$args$penalty) - ) - ), - classprob = list( - pre = check_glmnet_lambda, - post = organize_multnet_prob, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newx = quote(as.matrix(new_data)), - type = "response", - s = quote(object$spec$args$penalty) - ) - ), - raw = list( - pre = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newx = quote(as.matrix(new_data)) - ) - ) +set_model_engine("multinom_reg", "classification", "spark") +set_dependency("multinom_reg", "spark", "sparklyr") + +set_model_arg( + model = "multinom_reg", + eng = "spark", + parsnip = "penalty", + original = "reg_param", + func = list(pkg = "dials", fun = "penalty"), + has_submodel = TRUE +) + +set_model_arg( + model = "multinom_reg", + eng = "spark", + parsnip = "mixture", + original = "elastic_net_param", + func = list(pkg = "dials", fun = "mixture"), + has_submodel = FALSE +) + +set_fit( + model = "multinom_reg", + eng = "spark", + mode = "classification", + value = list( + interface = "formula", + protect = c("x", "formula", "weight_col"), + func = c(pkg = "sparklyr", fun = "ml_logistic_regression"), + defaults = list(family = "multinomial") ) +) -multinom_reg_spark_data <- - list( - libs = "sparklyr", - fit = list( - interface = "formula", - protect = c("x", "formula", "weight_col"), - func = c(pkg = "sparklyr", fun = "ml_logistic_regression"), - defaults = - list( - family = "multinomial" - ) - ), - class = list( - pre = NULL, - post = format_spark_class, - func = c(pkg = "sparklyr", fun = "ml_predict"), - args = - list( - x = quote(object$fit), - dataset = quote(new_data) - ) - ), - classprob = list( - pre = NULL, - post = format_spark_probs, - func = c(pkg = "sparklyr", fun = "ml_predict"), - args = - list( - x = quote(object$fit), - dataset = quote(new_data) - ) - ) +set_pred( + model = "multinom_reg", + eng = "spark", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = format_spark_class, + func = c(pkg = "sparklyr", fun = "ml_predict"), + args = + list( + x = quote(object$fit), + dataset = quote(new_data) + ) ) +) -multinom_reg_keras_data <- - list( - libs = c("keras", "magrittr"), - fit = list( - interface = "matrix", - protect = c("x", "y"), - func = c(pkg = "parsnip", fun = "keras_mlp"), - defaults = list(hidden_units = 1, act = "linear") - ), - class = list( - pre = NULL, - post = function(x, object) { - object$lvl[x + 1] - }, - func = c(pkg = "keras", fun = 
"predict_classes"), - args = - list( - object = quote(object$fit), - x = quote(as.matrix(new_data)) - ) - ), - classprob = list( - pre = NULL, - post = function(x, object) { - x <- as_tibble(x) - colnames(x) <- object$lvl - x - }, - func = c(pkg = "keras", fun = "predict_proba"), - args = - list( - object = quote(object$fit), - x = quote(as.matrix(new_data)) - ) - ) +set_pred( + model = "multinom_reg", + eng = "spark", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = format_spark_probs, + func = c(pkg = "sparklyr", fun = "ml_predict"), + args = + list( + x = quote(object$fit), + dataset = quote(new_data) + ) ) +) + +# ------------------------------------------------------------------------------ + +set_model_engine("multinom_reg", "classification", "keras") +set_dependency("multinom_reg", "keras", "keras") +set_dependency("multinom_reg", "keras", "magrittr") + +set_model_arg( + model = "multinom_reg", + eng = "keras", + parsnip = "decay", + original = "decay", + func = list(pkg = "dials", fun = "weight_decay"), + has_submodel = FALSE +) + + +set_fit( + model = "multinom_reg", + eng = "keras", + mode = "classification", + value = list( + interface = "matrix", + protect = c("x", "y"), + func = c(pkg = "parsnip", fun = "keras_mlp"), + defaults = list(hidden_units = 1, act = "linear") + ) +) + +set_pred( + model = "multinom_reg", + eng = "keras", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = function(x, object) { + object$lvl[x + 1] + }, + func = c(pkg = "keras", fun = "predict_classes"), + args = + list(object = quote(object$fit), + x = quote(as.matrix(new_data))) + ) +) + +set_pred( + model = "multinom_reg", + eng = "keras", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = function(x, object) { + x <- as_tibble(x) + colnames(x) <- object$lvl + x + }, + func = c(pkg = "keras", fun = "predict_proba"), + args = + list(object = quote(object$fit), + x = quote(as.matrix(new_data))) + ) +) diff --git a/R/nearest_neighbor.R b/R/nearest_neighbor.R index 52c9a9f0e..4f0cbb165 100644 --- a/R/nearest_neighbor.R +++ b/R/nearest_neighbor.R @@ -29,7 +29,8 @@ #' `"classification"`. #' #' @param neighbors A single integer for the number of neighbors -#' to consider (often called `k`). +#' to consider (often called `k`). For \pkg{kknn}, a value of 5 +#' is used if `neighbors` is not specified. #' #' @param weight_func A *single* character for the type of kernel function used #' to weight distances between samples. Valid choices are: `"rectangular"`, @@ -54,7 +55,7 @@ #' #' \pkg{kknn} (classification or regression) #' -#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::nearest_neighbor(), "kknn")} +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::nearest_neighbor(mode = "regression"), "kknn")} #' #' @note #' For `kknn`, the underlying modeling function used is a restricted @@ -148,13 +149,32 @@ check_args.nearest_neighbor <- function(object) { args <- lapply(object$args, rlang::eval_tidy) - if(is.numeric(args$neighbors) && !positive_int_scalar(args$neighbors)) { + if (is.numeric(args$neighbors) && !positive_int_scalar(args$neighbors)) { stop("`neighbors` must be a length 1 positive integer.", call. = FALSE) } - if(is.character(args$weight_func) && length(args$weight_func) > 1) { + if (is.character(args$weight_func) && length(args$weight_func) > 1) { stop("The length of `weight_func` must be 1.", call. 
= FALSE) } invisible(object) } + +# ------------------------------------------------------------------------------ + +#' @export +translate.nearest_neighbor <- function(x, engine = x$engine, ...) { + if (is.null(engine)) { + message("Used `engine = 'kknn'` for translation.") + engine <- "kknn" + } + x <- translate.default(x, engine, ...) + + if (engine == "kknn") { + if (!any(names(x$method$fit$args) == "ks") || + is_missing_arg(x$method$fit$args$ks)) { + x$method$fit$args$ks <- 5 + } + } + x +} diff --git a/R/nearest_neighbor_data.R b/R/nearest_neighbor_data.R index 0191d8614..2a85f70d0 100644 --- a/R/nearest_neighbor_data.R +++ b/R/nearest_neighbor_data.R @@ -1,90 +1,171 @@ -nearest_neighbor_arg_key <- data.frame( - kknn = c("ks", "kernel", "distance"), - row.names = c("neighbors", "weight_func", "dist_power"), - stringsAsFactors = FALSE + +set_new_model("nearest_neighbor") + +set_model_mode("nearest_neighbor", "classification") +set_model_mode("nearest_neighbor", "regression") + +# ------------------------------------------------------------------------------ + +set_model_engine("nearest_neighbor", "classification", "kknn") +set_model_engine("nearest_neighbor", "regression", "kknn") +set_dependency("nearest_neighbor", "kknn", "kknn") + +set_model_arg( + model = "nearest_neighbor", + eng = "kknn", + parsnip = "neighbors", + original = "ks", + func = list(pkg = "dials", fun = "neighbors"), + has_submodel = FALSE +) +set_model_arg( + model = "nearest_neighbor", + eng = "kknn", + parsnip = "weight_func", + original = "kernel", + func = list(pkg = "dials", fun = "weight_func"), + has_submodel = FALSE +) +set_model_arg( + model = "nearest_neighbor", + eng = "kknn", + parsnip = "dist_power", + original = "distance", + func = list(pkg = "dials", fun = "distance"), + has_submodel = FALSE ) -nearest_neighbor_modes <- c("classification", "regression", "unknown") +set_fit( + model = "nearest_neighbor", + eng = "kknn", + mode = "regression", + value = list( + interface = "formula", + protect = c("formula", "data"), + func = c(pkg = "kknn", fun = "train.kknn"), + defaults = list() + ) +) -nearest_neighbor_engines <- data.frame( - kknn = c(TRUE, TRUE, FALSE), - row.names = c("classification", "regression", "unknown") +set_fit( + model = "nearest_neighbor", + eng = "kknn", + mode = "classification", + value = list( + interface = "formula", + protect = c("formula", "data"), + func = c(pkg = "kknn", fun = "train.kknn"), + defaults = list() + ) ) -# ------------------------------------------------------------------------------ +set_pred( + model = "nearest_neighbor", + eng = "kknn", + mode = "regression", + type = "numeric", + value = list( + # seems unnecessary here as the predict_numeric catches it based on the + # model mode + pre = function(x, object) { + if (object$fit$response != "continuous") { + stop("`kknn` model does not appear to use numeric predictions. Was ", + "the model fit with a continuous response variable?", + call. 
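# The `translate.nearest_neighbor()` method defined above backs the documented
# kknn behavior: when `neighbors` is unset, `ks = 5` is injected into the fit
# call. A commented usage sketch (the printed template is approximate):
#
#   nearest_neighbor(mode = "regression") %>%
#     set_engine("kknn") %>%
#     translate()
#   # fit template: kknn::train.kknn(formula, data, ks = 5)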
= FALSE) + } + x + }, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "raw" + ) + ) +) + +set_pred( + model = "nearest_neighbor", + eng = "kknn", + mode = "regression", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data) + ) + ) +) + +set_pred( + model = "nearest_neighbor", + eng = "kknn", + mode = "classification", + type = "class", + value = list( + pre = function(x, object) { + if (!(object$fit$response %in% c("ordinal", "nominal"))) { + stop("`kknn` model does not appear to use class predictions. Was ", + "the model fit with a factor response variable?", + call. = FALSE) + } + x + }, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "raw" + ) + ) +) + +set_pred( + model = "nearest_neighbor", + eng = "kknn", + mode = "classification", + type = "prob", + value = list( + pre = function(x, object) { + if (!(object$fit$response %in% c("ordinal", "nominal"))) { + stop("`kknn` model does not appear to use class predictions. Was ", + "the model fit with a factor response variable?", + call. = FALSE) + } + x + }, + post = function(result, object) as_tibble(result), + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "prob" + ) + ) +) -nearest_neighbor_kknn_data <- - list( - libs = "kknn", - fit = list( - interface = "formula", - protect = c("formula", "data", "kmax"), # kmax is not allowed - func = c(pkg = "kknn", fun = "train.kknn"), - defaults = list() - ), - numeric = list( - # seems unnecessary here as the predict_numeric catches it based on the - # model mode - pre = function(x, object) { - if (object$fit$response != "continuous") { - stop("`kknn` model does not appear to use numeric predictions. Was ", - "the model fit with a continuous response variable?", - call. = FALSE) - } - x - }, - post = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "raw" - ) - ), - class = list( - pre = function(x, object) { - if (!(object$fit$response %in% c("ordinal", "nominal"))) { - stop("`kknn` model does not appear to use class predictions. Was ", - "the model fit with a factor response variable?", - call. = FALSE) - } - x - }, - post = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "raw" - ) - ), - classprob = list( - pre = function(x, object) { - if (!(object$fit$response %in% c("ordinal", "nominal"))) { - stop("`kknn` model does not appear to use class predictions. Was ", - "the model fit with a factor response variable?", - call. 
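# The `pre` hooks in these kknn modules lean on the `response` field that
# kknn::train.kknn() stores on its result ("continuous", "nominal", or
# "ordinal") to catch mode mismatches early. A commented sketch, assuming the
# kknn package is installed:
#
#   knn_fit <- kknn::train.kknn(mpg ~ ., data = mtcars, ks = 5)
#   knn_fit$response  # "continuous", so class/prob predictions are refused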
= FALSE) - } - x - }, - post = function(result, object) as_tibble(result), - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "prob" - ) - ), - raw = list( - pre = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ) +set_pred( + model = "nearest_neighbor", + eng = "kknn", + mode = "classification", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data) + ) ) +) diff --git a/R/nullmodel.R b/R/nullmodel.R index c37ecfead..10772562d 100644 --- a/R/nullmodel.R +++ b/R/nullmodel.R @@ -160,6 +160,7 @@ predict.nullmodel <- function (object, new_data = NULL, type = NULL, ...) { #' @export null_model <- function(mode = "classification") { + null_model_modes <- unique(get_model_env()$null_model$mode) # Check for correct mode if (!(mode %in% null_model_modes)) stop("`mode` should be one of: ", diff --git a/R/nullmodel_data.R b/R/nullmodel_data.R index 80380a98a..1aba4e6bf 100644 --- a/R/nullmodel_data.R +++ b/R/nullmodel_data.R @@ -1,72 +1,128 @@ -null_model_arg_key <- data.frame( - parsnip = NULL, - row.names = NULL, - stringsAsFactors = FALSE +set_new_model("null_model") + +set_model_mode("null_model", "classification") +set_model_mode("null_model", "regression") + +# ------------------------------------------------------------------------------ + +set_model_engine("null_model", "classification", "parsnip") +set_model_engine("null_model", "regression", "parsnip") +set_dependency("null_model", "parsnip", "parsnip") + +set_fit( + model = "null_model", + eng = "parsnip", + mode = "regression", + value = list( + interface = "matrix", + protect = c("x", "y"), + func = c(fun = "nullmodel"), + defaults = list() + ) ) -null_model_modes <- c("classification", "regression", "unknown") +set_fit( + model = "null_model", + eng = "parsnip", + mode = "classification", + value = list( + interface = "matrix", + protect = c("x", "y"), + func = c(fun = "nullmodel"), + defaults = list() + ) +) -null_model_engines <- data.frame( - parsnip = c(TRUE, TRUE, FALSE), - row.names = c("classification", "regression", "unknown") +set_pred( + model = "null_model", + eng = "parsnip", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + new_data = quote(new_data), + type = "numeric" + ) + ) ) -# ------------------------------------------------------------------------------ +set_pred( + model = "null_model", + eng = "parsnip", + mode = "regression", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + new_data = quote(new_data), + type = "numeric" + ) + ) +) -null_model_parsnip_data <- - list( - libs = "parsnip", - fit = list( - interface = "matrix", - protect = c("x", "y"), - func = c(fun = "nullmodel"), - defaults = list() - ), - class = list( - pre = NULL, - post = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - new_data = quote(new_data), - type = "class" - ) - ), - classprob = list( - pre = NULL, - post = function(x, object) { - str(as_tibble(x)) - as_tibble(x) - }, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - new_data = quote(new_data), - type = "prob" - ) - ), - numeric = list( - pre = NULL, - post = NULL, - func = c(fun = 
"predict"), - args = - list( - object = quote(object$fit), - new_data = quote(new_data), - type = "numeric" - ) - ), - raw = list( - pre = NULL, - post = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - new_data = quote(new_data), - type = "raw" - ) +set_pred( + model = "null_model", + eng = "parsnip", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + new_data = quote(new_data), + type = "class" ) - ) + ) +) + +set_pred( + model = "null_model", + eng = "parsnip", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = function(x, object) { + str(as_tibble(x)) + as_tibble(x) + }, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + new_data = quote(new_data), + type = "prob" + ) + ) +) + +set_pred( + model = "null_model", + eng = "parsnip", + mode = "classification", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + new_data = quote(new_data), + type = "raw" + ) + ) +) + diff --git a/R/predict.R b/R/predict.R index 7cea3e0b7..414096d1a 100644 --- a/R/predict.R +++ b/R/predict.R @@ -154,9 +154,6 @@ predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...) res } -pred_types <- - c("raw", "numeric", "class", "link", "prob", "conf_int", "pred_int", "quantile") - #' @importFrom glue glue_collapse check_pred_type <- function(object, type) { if (is.null(type)) { diff --git a/R/predict_class.R b/R/predict_class.R index 0e292a8b3..91c84fc02 100644 --- a/R/predict_class.R +++ b/R/predict_class.R @@ -13,7 +13,7 @@ predict_class.model_fit <- function(object, new_data, ...) { stop("`predict.model_fit()` is for predicting factor outcomes.", call. = FALSE) - if (!any(names(object$spec$method) == "class")) + if (!any(names(object$spec$method$pred) == "class")) stop("No class prediction module defined for this model.", call. = FALSE) if (inherits(object$fit, "try-error")) { @@ -24,17 +24,17 @@ predict_class.model_fit <- function(object, new_data, ...) { new_data <- prepare_data(object, new_data) # preprocess data - if (!is.null(object$spec$method$class$pre)) - new_data <- object$spec$method$class$pre(new_data, object) + if (!is.null(object$spec$method$pred$class$pre)) + new_data <- object$spec$method$pred$class$pre(new_data, object) # create prediction call - pred_call <- make_pred_call(object$spec$method$class) + pred_call <- make_pred_call(object$spec$method$pred$class) res <- eval_tidy(pred_call) # post-process the predictions - if (!is.null(object$spec$method$class$post)) { - res <- object$spec$method$class$post(res, object) + if (!is.null(object$spec$method$pred$class$post)) { + res <- object$spec$method$pred$class$post(res, object) } # coerce levels to those in `object` diff --git a/R/predict_classprob.R b/R/predict_classprob.R index d902e7735..1dbbe5328 100644 --- a/R/predict_classprob.R +++ b/R/predict_classprob.R @@ -10,7 +10,7 @@ predict_classprob.model_fit <- function(object, new_data, ...) { stop("`predict.model_fit()` is for predicting factor outcomes.", call. = FALSE) - if (!any(names(object$spec$method) == "classprob")) + if (!any(names(object$spec$method$pred) == "prob")) stop("No class probability module defined for this model.", call. = FALSE) if (inherits(object$fit, "try-error")) { @@ -21,17 +21,17 @@ predict_classprob.model_fit <- function(object, new_data, ...) 
{ new_data <- prepare_data(object, new_data) # preprocess data - if (!is.null(object$spec$method$classprob$pre)) - new_data <- object$spec$method$classprob$pre(new_data, object) + if (!is.null(object$spec$method$pred$prob$pre)) + new_data <- object$spec$method$pred$prob$pre(new_data, object) # create prediction call - pred_call <- make_pred_call(object$spec$method$classprob) + pred_call <- make_pred_call(object$spec$method$pred$prob) res <- eval_tidy(pred_call) # post-process the predictions - if (!is.null(object$spec$method$classprob$post)) { - res <- object$spec$method$classprob$post(res, object) + if (!is.null(object$spec$method$pred$prob$post)) { + res <- object$spec$method$pred$prob$post(res, object) } # check and sort names diff --git a/R/predict_interval.R b/R/predict_interval.R index f38838e00..b492ef866 100644 --- a/R/predict_interval.R +++ b/R/predict_interval.R @@ -10,7 +10,7 @@ # @export predict_confint.model_fit <- function(object, new_data, level = 0.95, std_error = FALSE, ...) { - if (is.null(object$spec$method$confint)) + if (is.null(object$spec$method$pred$conf_int)) stop("No confidence interval method defined for this ", "engine.", call. = FALSE) @@ -22,19 +22,19 @@ predict_confint.model_fit <- function(object, new_data, level = 0.95, std_error new_data <- prepare_data(object, new_data) # preprocess data - if (!is.null(object$spec$method$confint$pre)) - new_data <- object$spec$method$confint$pre(new_data, object) + if (!is.null(object$spec$method$pred$conf_int$pre)) + new_data <- object$spec$method$pred$conf_int$pre(new_data, object) # Pass some extra arguments to be used in post-processor - object$spec$method$confint$extras <- + object$spec$method$pred$conf_int$extras <- list(level = level, std_error = std_error) - pred_call <- make_pred_call(object$spec$method$confint) + pred_call <- make_pred_call(object$spec$method$pred$conf_int) res <- eval_tidy(pred_call) # post-process the predictions - if (!is.null(object$spec$method$confint$post)) { - res <- object$spec$method$confint$post(res, object) + if (!is.null(object$spec$method$pred$conf_int$post)) { + res <- object$spec$method$pred$conf_int$post(res, object) } attr(res, "level") <- level @@ -59,7 +59,7 @@ predict_confint <- function(object, ...) # @export predict_predint.model_fit <- function(object, new_data, level = 0.95, std_error = FALSE, ...) { - if (is.null(object$spec$method$predint)) + if (is.null(object$spec$method$pred$pred_int)) stop("No prediction interval method defined for this ", "engine.", call. 
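# `predict_confint()` and `predict_predint()` pass `level` and `std_error` to
# the engine's post-processor through the module's `extras` slot rather than
# through the prediction call itself. A commented sketch of the user-facing
# side, assuming an engine that registers a "conf_int" module (ranger does,
# later in this diff):
#
#   predict(fitted, new_data, type = "conf_int", level = 0.90, std_error = TRUE)
#   # tibble with .pred_lower / .pred_upper, plus .std_error when requested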
= FALSE) @@ -71,20 +71,20 @@ predict_predint.model_fit <- function(object, new_data, level = 0.95, std_error new_data <- prepare_data(object, new_data) # preprocess data - if (!is.null(object$spec$method$predint$pre)) - new_data <- object$spec$method$predint$pre(new_data, object) + if (!is.null(object$spec$method$pred$pred_int$pre)) + new_data <- object$spec$method$pred$pred_int$pre(new_data, object) # create prediction call # Pass some extra arguments to be used in post-processor - object$spec$method$predint$extras <- + object$spec$method$pred$pred_int$extras <- list(level = level, std_error = std_error) - pred_call <- make_pred_call(object$spec$method$predint) + pred_call <- make_pred_call(object$spec$method$pred$pred_int) res <- eval_tidy(pred_call) # post-process the predictions - if (!is.null(object$spec$method$predint$post)) { - res <- object$spec$method$predint$post(res, object) + if (!is.null(object$spec$method$pred$pred_int$post)) { + res <- object$spec$method$pred$pred_int$post(res, object) } attr(res, "level") <- level diff --git a/R/predict_numeric.R b/R/predict_numeric.R index 3a509546b..970107049 100644 --- a/R/predict_numeric.R +++ b/R/predict_numeric.R @@ -11,7 +11,7 @@ predict_numeric.model_fit <- function(object, new_data, ...) { "Use `predict_class()` or `predict_classprob()` for ", "classification models.", call. = FALSE) - if (!any(names(object$spec$method) == "numeric")) + if (!any(names(object$spec$method$pred) == "numeric")) stop("No prediction module defined for this model.", call. = FALSE) if (inherits(object$fit, "try-error")) { @@ -22,17 +22,17 @@ predict_numeric.model_fit <- function(object, new_data, ...) { new_data <- prepare_data(object, new_data) # preprocess data - if (!is.null(object$spec$method$numeric$pre)) - new_data <- object$spec$method$numeric$pre(new_data, object) + if (!is.null(object$spec$method$pred$numeric$pre)) + new_data <- object$spec$method$pred$numeric$pre(new_data, object) # create prediction call - pred_call <- make_pred_call(object$spec$method$numeric) + pred_call <- make_pred_call(object$spec$method$pred$numeric) res <- eval_tidy(pred_call) # post-process the predictions - if (!is.null(object$spec$method$numeric$post)) { - res <- object$spec$method$numeric$post(res, object) + if (!is.null(object$spec$method$pred$numeric$post)) { + res <- object$spec$method$pred$numeric$post(res, object) } if (is.vector(res)) { diff --git a/R/predict_quantile.R b/R/predict_quantile.R index 698ddb4c8..19d22654f 100644 --- a/R/predict_quantile.R +++ b/R/predict_quantile.R @@ -9,7 +9,7 @@ predict_quantile.model_fit <- function (object, new_data, quantile = (1:9)/10, ...) { - if (is.null(object$spec$method$quantile)) + if (is.null(object$spec$method$pred$quantile)) stop("No quantile prediction method defined for this ", "engine.", call. 
= FALSE) @@ -21,18 +21,18 @@ predict_quantile.model_fit <- new_data <- prepare_data(object, new_data) # preprocess data - if (!is.null(object$spec$method$quantile$pre)) - new_data <- object$spec$method$quantile$pre(new_data, object) + if (!is.null(object$spec$method$pred$quantile$pre)) + new_data <- object$spec$method$pred$quantile$pre(new_data, object) # Pass some extra arguments to be used in post-processor - object$spec$method$quantile$args$p <- quantile - pred_call <- make_pred_call(object$spec$method$quantile) + object$spec$method$pred$quantile$args$p <- quantile + pred_call <- make_pred_call(object$spec$method$pred$quantile) res <- eval_tidy(pred_call) # post-process the predictions - if(!is.null(object$spec$method$quantile$post)) { - res <- object$spec$method$quantile$post(res, object) + if(!is.null(object$spec$method$pred$quantile$post)) { + res <- object$spec$method$pred$quantile$post(res, object) } res diff --git a/R/predict_raw.R b/R/predict_raw.R index 315c9dd0a..2d859f8a7 100644 --- a/R/predict_raw.R +++ b/R/predict_raw.R @@ -4,17 +4,17 @@ # @export predict_raw.model_fit # @export predict_raw.model_fit <- function(object, new_data, opts = list(), ...) { - protected_args <- names(object$spec$method$raw$args) + protected_args <- names(object$spec$method$pred$raw$args) dup_args <- names(opts) %in% protected_args if (any(dup_args)) { opts <- opts[[!dup_args]] } if (length(opts) > 0) { - object$spec$method$raw$args <- - c(object$spec$method$raw$args, opts) + object$spec$method$pred$raw$args <- + c(object$spec$method$pred$raw$args, opts) } - if (!any(names(object$spec$method) == "raw")) + if (!any(names(object$spec$method$pred) == "raw")) stop("No raw prediction module defined for this model.", call. = FALSE) if (inherits(object$fit, "try-error")) { @@ -25,11 +25,11 @@ predict_raw.model_fit <- function(object, new_data, opts = list(), ...) { new_data <- prepare_data(object, new_data) # preprocess data - if (!is.null(object$spec$method$raw$pre)) - new_data <- object$spec$method$raw$pre(new_data, object) + if (!is.null(object$spec$method$pred$raw$pre)) + new_data <- object$spec$method$pred$raw$pre(new_data, object) # create prediction call - pred_call <- make_pred_call(object$spec$method$raw) + pred_call <- make_pred_call(object$spec$method$pred$raw) res <- eval_tidy(pred_call) diff --git a/R/rand_forest_data.R b/R/rand_forest_data.R index 65eb84864..77ddea12e 100644 --- a/R/rand_forest_data.R +++ b/R/rand_forest_data.R @@ -1,24 +1,3 @@ - -rand_forest_arg_key <- data.frame( - randomForest = c("mtry", "ntree", "nodesize"), - ranger = c("mtry", "num.trees", "min.node.size"), - spark = - c("feature_subset_strategy", "num_trees", "min_instances_per_node"), - stringsAsFactors = FALSE, - row.names = c("mtry", "trees", "min_n") -) - -rand_forest_modes <- c("classification", "regression", "unknown") - -rand_forest_engines <- data.frame( - ranger = c(TRUE, TRUE, FALSE), - randomForest = c(TRUE, TRUE, FALSE), - spark = c(TRUE, TRUE, FALSE), - row.names = c("classification", "regression", "unknown") -) - -# ------------------------------------------------------------------------------ - # wrappers for ranger ranger_class_pred <- function(results, object) { @@ -32,7 +11,7 @@ ranger_class_pred <- #' @importFrom stats qnorm ranger_num_confint <- function(object, new_data, ...) 
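# (A reading note on ranger_num_confint() below: the interval level is turned
# into a normal quantile, `const <- qnorm((1 - level) / 2, lower.tail = FALSE)`,
# about 1.96 at level = 0.95, and approximate Wald-style bounds are then formed
# as `.pred +/- const * std_error` from ranger's standard-error estimates.)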
{ - hf_lvl <- (1 - object$spec$method$confint$extras$level)/2 + hf_lvl <- (1 - object$spec$method$pred$conf_int$extras$level)/2 const <- qnorm(hf_lvl, lower.tail = FALSE) res <- @@ -44,12 +23,12 @@ ranger_num_confint <- function(object, new_data, ...) { res$.pred_upper <- res$.pred + const * std_error res$.pred <- NULL - if(object$spec$method$confint$extras$std_error) + if (object$spec$method$pred$conf_int$extras$std_error) res$.std_error <- std_error res } ranger_class_confint <- function(object, new_data, ...) { - hf_lvl <- (1 - object$spec$method$confint$extras$level)/2 + hf_lvl <- (1 - object$spec$method$pred$conf_int$extras$level)/2 const <- qnorm(hf_lvl, lower.tail = FALSE) pred <- predict(object$fit, data = new_data, type = "response", ...)$predictions @@ -73,21 +52,21 @@ ranger_class_confint <- function(object, new_data, ...) { col_names <- paste0(c(".pred_lower_", ".pred_upper_"), lvl) res <- res[, col_names] - if(object$spec$method$confint$extras$std_error) + if (object$spec$method$pred$conf_int$extras$std_error) res <- bind_cols(res, std_error) res } ranger_confint <- function(object, new_data, ...) { - if(object$fit$forest$treetype == "Regression") { + if (object$fit$forest$treetype == "Regression") { res <- ranger_num_confint(object, new_data, ...) } else { - if(object$fit$forest$treetype == "Probability estimation") { + if (object$fit$forest$treetype == "Probability estimation") { res <- ranger_class_confint(object, new_data, ...) } else { - stop ("Cannot compute confidence intervals for a ranger forest ", - "of type ", object$fit$forest$treetype, ".", call. = FALSE) + stop("Cannot compute confidence intervals for a ranger forest ", + "of type ", object$fit$forest$treetype, ".", call. = FALSE) } } res @@ -95,186 +74,454 @@ ranger_confint <- function(object, new_data, ...) { # ------------------------------------------------------------------------------ +set_new_model("rand_forest") -rand_forest_ranger_data <- - list( - libs = "ranger", - fit = list( - interface = "formula", - protect = c("formula", "data", "case.weights"), - func = c(pkg = "ranger", fun = "ranger"), - defaults = - list( - num.threads = 1, - verbose = FALSE, - seed = expr(sample.int(10^5, 1)) - ) - ), - numeric = list( - pre = NULL, - post = function(results, object) results$predictions, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - data = quote(new_data), - type = "response", - seed = expr(sample.int(10^5, 1)), - verbose = FALSE - ) - ), - class = list( - pre = NULL, - post = ranger_class_pred, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - data = quote(new_data), - type = "response", - seed = expr(sample.int(10^5, 1)), - verbose = FALSE - ) - ), - classprob = list( - pre = function(x, object) { - if (object$fit$forest$treetype != "Probability estimation") - stop("`ranger` model does not appear to use class probabilities. Was ", - "the model fit with `probability = TRUE`?", - call. 
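# This probability-forest check (kept in the new `set_pred()` registration
# below) exists because ranger only emits class probabilities when trained
# with `probability = TRUE`. A commented sketch, assuming the ranger package:
#
#   rf <- ranger::ranger(Species ~ ., data = iris, probability = TRUE)
#   rf$forest$treetype                          # "Probability estimation"
#   head(predict(rf, data = iris)$predictions)  # per-class probabilities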
= FALSE) - x - }, - post = function(x, object) { - x <- x$prediction - as_tibble(x) - }, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - data = quote(new_data), - seed = expr(sample.int(10^5, 1)), - verbose = FALSE - ) - ), - raw = list( - pre = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - data = quote(new_data), - seed = expr(sample.int(10^5, 1)) - ) - ), - confint = list( - pre = NULL, - post = NULL, - func = c(fun = "ranger_confint"), - args = - list( - object = quote(object), - new_data = quote(new_data), - seed = expr(sample.int(10^5, 1)) - ) - ) +set_model_mode("rand_forest", "classification") +set_model_mode("rand_forest", "regression") + +# ------------------------------------------------------------------------------ +# ranger components + +set_model_engine("rand_forest", "classification", "ranger") +set_model_engine("rand_forest", "regression", "ranger") +set_dependency("rand_forest", "ranger", "ranger") + +set_model_arg( + model = "rand_forest", + eng = "ranger", + parsnip = "mtry", + original = "mtry", + func = list(pkg = "dials", fun = "mtry"), + has_submodel = FALSE +) +set_model_arg( + model = "rand_forest", + eng = "ranger", + parsnip = "trees", + original = "num.trees", + func = list(pkg = "dials", fun = "trees"), + has_submodel = FALSE +) +set_model_arg( + model = "rand_forest", + eng = "ranger", + parsnip = "min_n", + original = "min.node.size", + func = list(pkg = "dials", fun = "min_n"), + has_submodel = FALSE +) + +set_fit( + model = "rand_forest", + eng = "ranger", + mode = "classification", + value = list( + interface = "formula", + protect = c("formula", "data", "case.weights"), + func = c(pkg = "ranger", fun = "ranger"), + defaults = + list( + num.threads = 1, + verbose = FALSE, + seed = expr(sample.int(10 ^ 5, 1)) + ) ) +) -rand_forest_randomForest_data <- - list( - libs = "randomForest", - fit = list( - interface = "data.frame", - protect = c("x", "y"), - func = c(pkg = "randomForest", fun = "randomForest"), - defaults = - list() - ), - numeric = list( - pre = NULL, - post = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ), - class = list( - pre = NULL, - post = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ), - classprob = list( - pre = NULL, - post = function(x, object) { - as_tibble(as.data.frame(x)) - }, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "prob" - ) - ), - raw = list( - pre = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ) +set_fit( + model = "rand_forest", + eng = "ranger", + mode = "regression", + value = list( + interface = "formula", + protect = c("formula", "data", "case.weights"), + func = c(pkg = "ranger", fun = "ranger"), + defaults = + list( + num.threads = 1, + verbose = FALSE, + seed = expr(sample.int(10 ^ 5, 1)) + ) ) +) +set_pred( + model = "rand_forest", + eng = "ranger", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = ranger_class_pred, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + data = quote(new_data), + type = "response", + seed = expr(sample.int(10 ^ 5, 1)), + verbose = FALSE + ) + ) +) -rand_forest_spark_data <- - list( - libs = "sparklyr", - fit = list( - interface = "formula", - protect = c("x", "formula", "type"), - 
func = c(pkg = "sparklyr", fun = "ml_random_forest"), - defaults = - list( - seed = expr(sample.int(10^5, 1)) - ) - ), - numeric = list( - pre = NULL, - post = format_spark_num, - func = c(pkg = "sparklyr", fun = "ml_predict"), - args = - list( - x = quote(object$fit), - dataset = quote(new_data) +set_pred( + model = "rand_forest", + eng = "ranger", + mode = "classification", + type = "prob", + value = list( + pre = function(x, object) { + if (object$fit$forest$treetype != "Probability estimation") + stop( + "`ranger` model does not appear to use class probabilities. Was ", + "the model fit with `probability = TRUE`?", + call. = FALSE ) - ), - class = list( - pre = NULL, - post = format_spark_class, - func = c(pkg = "sparklyr", fun = "ml_predict"), - args = - list( - x = quote(object$fit), - dataset = quote(new_data) - ) - ), - classprob = list( - pre = NULL, - post = format_spark_probs, - func = c(pkg = "sparklyr", fun = "ml_predict"), - args = - list( - x = quote(object$fit), - dataset = quote(new_data) - ) - ) + x + }, + post = function(x, object) { + x <- x$prediction + as_tibble(x) + }, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + data = quote(new_data), + seed = expr(sample.int(10 ^ 5, 1)), + verbose = FALSE + ) + ) +) + +set_pred( + model = "rand_forest", + eng = "ranger", + mode = "classification", + type = "conf_int", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "ranger_confint"), + args = + list( + object = quote(object), + new_data = quote(new_data), + seed = expr(sample.int(10^5, 1)) + ) ) +) + +set_pred( + mod = "rand_forest", + eng = "ranger", + mode = "classification", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + data = quote(new_data), + seed = expr(sample.int(10 ^ 5, 1)) + ) + ) +) + +set_pred( + model = "rand_forest", + eng = "ranger", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = function(results, object) + results$predictions, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + data = quote(new_data), + type = "response", + seed = expr(sample.int(10 ^ 5, 1)), + verbose = FALSE + ) + ) +) + + +set_pred( + mod = "rand_forest", + eng = "ranger", + mode = "regression", + type = "conf_int", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "ranger_confint"), + args = + list( + object = quote(object), + new_data = quote(new_data), + seed = expr(sample.int(10^5, 1)) + ) + ) +) +set_pred( + model = "rand_forest", + eng = "ranger", + mode = "regression", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + data = quote(new_data), + seed = expr(sample.int(10 ^ 5, 1)) + ) + ) +) + +# ------------------------------------------------------------------------------ +# randomForest components + +set_model_engine("rand_forest", "classification", "randomForest") +set_model_engine("rand_forest", "regression", "randomForest") +set_dependency("rand_forest", "randomForest", "randomForest") + +set_model_arg( + model = "rand_forest", + eng = "randomForest", + parsnip = "mtry", + original = "mtry", + func = list(pkg = "dials", fun = "mtry"), + has_submodel = FALSE +) +set_model_arg( + model = "rand_forest", + eng = "randomForest", + parsnip = "trees", + original = "ntree", + func = list(pkg = "dials", fun = "trees"), + has_submodel = FALSE +) +set_model_arg( + model = "rand_forest", + eng = 
"randomForest", + parsnip = "min_n", + original = "nodesize", + func = list(pkg = "dials", fun = "min_n"), + has_submodel = FALSE +) + +set_fit( + model = "rand_forest", + eng = "randomForest", + mode = "classification", + value = list( + interface = "data.frame", + protect = c("x", "y"), + func = c(pkg = "randomForest", fun = "randomForest"), + defaults = + list() + ) +) + +set_fit( + model = "rand_forest", + eng = "randomForest", + mode = "regression", + value = list( + interface = "data.frame", + protect = c("x", "y"), + func = c(pkg = "randomForest", fun = "randomForest"), + defaults = + list() + ) +) + +set_pred( + model = "rand_forest", + eng = "randomForest", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list(object = quote(object$fit), + newdata = quote(new_data)) + ) +) + +set_pred( + model = "rand_forest", + eng = "randomForest", + mode = "regression", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list(object = quote(object$fit), + newdata = quote(new_data)) + ) +) + +set_pred( + model = "rand_forest", + eng = "randomForest", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = list(object = quote(object$fit), newdata = quote(new_data)) + ) +) + + +set_pred( + model = "rand_forest", + eng = "randomForest", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = function(x, object) { + as_tibble(as.data.frame(x)) + }, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "prob" + ) + ) +) + +set_pred( + model = "rand_forest", + eng = "randomForest", + mode = "classification", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list(object = quote(object$fit), + newdata = quote(new_data)) + ) +) + +# ------------------------------------------------------------------------------ +# spark components + +set_model_engine("rand_forest", "classification", "spark") +set_model_engine("rand_forest", "regression", "spark") +set_dependency("rand_forest", "spark", "sparklyr") + +set_model_arg( + model = "rand_forest", + eng = "spark", + parsnip = "mtry", + original = "feature_subset_strategy", + func = list(pkg = "dials", fun = "mtry"), + has_submodel = FALSE +) +set_model_arg( + model = "rand_forest", + eng = "spark", + parsnip = "trees", + original = "num_trees", + func = list(pkg = "dials", fun = "trees"), + has_submodel = FALSE +) +set_model_arg( + model = "rand_forest", + eng = "spark", + parsnip = "min_n", + original = "min_instances_per_node", + func = list(pkg = "dials", fun = "min_n"), + has_submodel = FALSE +) + +set_fit( + model = "rand_forest", + eng = "spark", + mode = "classification", + value = list( + interface = "formula", + protect = c("x", "formula", "type"), + func = c(pkg = "sparklyr", fun = "ml_random_forest"), + defaults = list(seed = expr(sample.int(10 ^ 5, 1))) + ) +) + +set_fit( + model = "rand_forest", + eng = "spark", + mode = "regression", + value = list( + interface = "formula", + protect = c("x", "formula", "type"), + func = c(pkg = "sparklyr", fun = "ml_random_forest"), + defaults = list(seed = expr(sample.int(10 ^ 5, 1))) + ) +) + +set_pred( + model = "rand_forest", + eng = "spark", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = format_spark_num, + func = c(pkg = "sparklyr", fun = 
"ml_predict"), + args = + list(x = quote(object$fit), + dataset = quote(new_data)) + ) +) + +set_pred( + model = "rand_forest", + eng = "spark", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = format_spark_class, + func = c(pkg = "sparklyr", fun = "ml_predict"), + args = + list(x = quote(object$fit), + dataset = quote(new_data)) + ) +) + +set_pred( + model = "rand_forest", + eng = "spark", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = format_spark_probs, + func = c(pkg = "sparklyr", fun = "ml_predict"), + args = + list(x = quote(object$fit), + dataset = quote(new_data)) + ) +) diff --git a/R/surv_reg.R b/R/surv_reg.R index 33ab8b47e..878e20a9a 100644 --- a/R/surv_reg.R +++ b/R/surv_reg.R @@ -36,7 +36,7 @@ #' The model can be created using the `fit()` function using the #' following _engines_: #' \itemize{ -#' \item \pkg{R}: `"flexsurv"`, `"survreg"` (the default) +#' \item \pkg{R}: `"flexsurv"`, `"survival"` (the default) #' } #' #' @section Engine Details: @@ -49,9 +49,9 @@ #' #' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::surv_reg(), "flexsurv")} #' -#' \pkg{survreg} +#' \pkg{survival} #' -#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::surv_reg(), "survreg")} +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::surv_reg(), "survival")} #' #' Note that `model = TRUE` is needed to produce quantile #' predictions when there is a stratification variable and can be @@ -143,8 +143,8 @@ update.surv_reg <- function(object, dist = NULL, fresh = FALSE, ...) { #' @export translate.surv_reg <- function(x, engine = x$engine, ...) { if (is.null(engine)) { - message("Used `engine = 'survreg'` for translation.") - engine <- "survreg" + message("Used `engine = 'survival'` for translation.") + engine <- "survival" } x <- translate.default(x, engine, ...) 
x @@ -171,7 +171,7 @@ check_args.surv_reg <- function(object) { #' @importFrom stats setNames #' @importFrom dplyr mutate survreg_quant <- function(results, object) { - pctl <- object$spec$method$quantile$args$p + pctl <- object$spec$method$pred$quantile$args$p n <- nrow(results) p <- ncol(results) results <- @@ -208,8 +208,3 @@ flexsurv_quant <- function(results, object) { results <- map(results, setNames, c(".quantile", ".pred", ".pred_lower", ".pred_upper")) } -# ------------------------------------------------------------------------------ - -#' @importFrom utils globalVariables -utils::globalVariables(".label") - diff --git a/R/surv_reg_data.R b/R/surv_reg_data.R index 046b0a469..e52fdb2f9 100644 --- a/R/surv_reg_data.R +++ b/R/surv_reg_data.R @@ -1,120 +1,130 @@ -surv_reg_arg_key <- data.frame( - survreg = c("dist"), - flexsurv = c("dist"), - stringsAsFactors = FALSE, - row.names = c("dist") -) +set_new_model("surv_reg") +set_model_mode("surv_reg", "regression") -surv_reg_modes <- "regression" +# ------------------------------------------------------------------------------ -surv_reg_engines <- data.frame( - survreg = TRUE, - flexsurv = TRUE, - stringsAsFactors = TRUE, - row.names = c("regression") -) +set_model_engine("surv_reg", "regression", "flexsurv") +set_dependency("surv_reg", "flexsurv", "flexsurv") +set_dependency("surv_reg", "flexsurv", "survival") -# ------------------------------------------------------------------------------ +set_model_arg( + model = "surv_reg", + eng = "flexsurv", + parsnip = "dist", + original = "dist", + func = list(pkg = "dials", fun = "dist"), + has_submodel = FALSE +) -surv_reg_flexsurv_data <- - list( - libs = c("survival", "flexsurv"), - fit = list( - interface = "formula", - protect = c("formula", "data", "weights"), - func = c(pkg = "flexsurv", fun = "flexsurvreg"), - defaults = list() - ), - numeric = list( - pre = NULL, - post = flexsurv_mean, - func = c(fun = "summary"), - args = - list( - object = expr(object$fit), - newdata = expr(new_data), - type = "mean" - ) - ), - quantile = list( - pre = NULL, - post = flexsurv_quant, - func = c(fun = "summary"), - args = - list( - object = expr(object$fit), - newdata = expr(new_data), - type = "quantile", - quantiles = expr(quantile) - ) - ) +set_fit( + model = "surv_reg", + eng = "flexsurv", + mode = "regression", + value = list( + interface = "formula", + protect = c("formula", "data", "weights"), + func = c(pkg = "flexsurv", fun = "flexsurvreg"), + defaults = list() ) +) -# ------------------------------------------------------------------------------ +set_pred( + model = "surv_reg", + eng = "flexsurv", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = flexsurv_mean, + func = c(fun = "summary"), + args = + list( + object = expr(object$fit), + newdata = expr(new_data), + type = "mean" + ) + ) +) -surv_reg_survreg_data <- - list( - libs = c("survival"), - fit = list( - interface = "formula", - protect = c("formula", "data", "weights"), - func = c(pkg = "survival", fun = "survreg"), - defaults = list(model = TRUE) - ), - numeric = list( - pre = NULL, - post = NULL, - func = c(fun = "predict"), - args = - list( - object = expr(object$fit), - newdata = expr(new_data), - type = "response" - ) - ), - quantile = list( - pre = NULL, - post = survreg_quant, - func = c(fun = "predict"), - args = - list( - object = expr(object$fit), - newdata = expr(new_data), - type = "quantile", - p = expr(quantile) - ) - ) +set_pred( + model = "surv_reg", + eng = "flexsurv", + mode = 
"regression", + type = "quantile", + value = list( + pre = NULL, + post = flexsurv_quant, + func = c(fun = "summary"), + args = + list( + object = expr(object$fit), + newdata = expr(new_data), + type = "quantile", + quantiles = expr(quantile) + ) ) +) # ------------------------------------------------------------------------------ -# surv_reg_stan_data <- -# list( -# libs = c("brms"), -# fit = list( -# interface = "formula", -# protect = c("formula", "data", "weights"), -# func = c(pkg = "brms", fun = "brm"), -# defaults = list( -# family = expr(brms::weibull()), -# seed = expr(sample.int(10^5, 1)) -# ) -# ), -# numeric = list( -# pre = NULL, -# post = function(results, object) { -# tibble::as_tibble(results) %>% -# dplyr::select(Estimate) %>% -# setNames(".pred") -# }, -# func = c(fun = "predict"), -# args = -# list( -# object = expr(object$fit), -# newdata = expr(new_data), -# type = "response" -# ) -# ) -# ) +set_model_engine("surv_reg", "regression", "survival") +set_dependency("surv_reg", "survival", "survival") +set_model_arg( + model = "surv_reg", + eng = "survival", + parsnip = "dist", + original = "dist", + func = list(pkg = "dials", fun = "dist"), + has_submodel = FALSE +) + +set_fit( + model = "surv_reg", + eng = "survival", + mode = "regression", + value = list( + interface = "formula", + protect = c("formula", "data", "weights"), + func = c(pkg = "survival", fun = "survreg"), + defaults = list(model = TRUE) + ) +) + +set_pred( + model = "surv_reg", + eng = "survival", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = expr(object$fit), + newdata = expr(new_data), + type = "response" + ) + ) +) + +set_pred( + model = "surv_reg", + eng = "survival", + mode = "regression", + type = "quantile", + value = list( + pre = NULL, + post = survreg_quant, + func = c(fun = "predict"), + args = + list( + object = expr(object$fit), + newdata = expr(new_data), + type = "quantile", + p = expr(quantile) + ) + ) +) diff --git a/R/svm_poly_data.R b/R/svm_poly_data.R index 04b0bc55e..565c52d87 100644 --- a/R/svm_poly_data.R +++ b/R/svm_poly_data.R @@ -1,69 +1,150 @@ -svm_poly_arg_key <- data.frame( - kernlab = c( "C", "degree", "scale", "epsilon"), - row.names = c("cost", "degree", "scale_factor", "margin"), - stringsAsFactors = FALSE +set_new_model("svm_poly") + +set_model_mode("svm_poly", "classification") +set_model_mode("svm_poly", "regression") + +# ------------------------------------------------------------------------------ + +set_model_engine("svm_poly", "classification", "kernlab") +set_model_engine("svm_poly", "regression", "kernlab") +set_dependency("svm_poly", "kernlab", "kernlab") + +set_model_arg( + model = "svm_poly", + eng = "kernlab", + parsnip = "cost", + original = "C", + func = list(pkg = "dials", fun = "cost"), + has_submodel = FALSE ) -svm_poly_modes <- c("classification", "regression", "unknown") +set_model_arg( + model = "svm_poly", + eng = "kernlab", + parsnip = "degree", + original = "degree", + func = list(pkg = "dials", fun = "degree"), + has_submodel = FALSE +) -svm_poly_engines <- data.frame( - kernlab = c(TRUE, TRUE, FALSE), - row.names = c("classification", "regression", "unknown") +set_model_arg( + model = "svm_poly", + eng = "kernlab", + parsnip = "scale_factor", + original = "scale", + func = list(pkg = "dials", fun = "scale_factor"), + has_submodel = FALSE +) +set_model_arg( + model = "svm_poly", + eng = "kernlab", + parsnip = "margin", + original = "epsilon", + func = 
list(pkg = "dials", fun = "margin"), + has_submodel = FALSE ) -# ------------------------------------------------------------------------------ +set_fit( + model = "svm_poly", + eng = "kernlab", + mode = "regression", + value = list( + interface = "matrix", + protect = c("x", "y"), + func = c(pkg = "kernlab", fun = "ksvm"), + defaults = list(kernel = "polydot") + ) +) -svm_poly_kernlab_data <- - list( - libs = "kernlab", - fit = list( - interface = "matrix", - protect = c("x", "y"), - func = c(pkg = "kernlab", fun = "ksvm"), - defaults = list( - kernel = "polydot" +set_fit( + model = "svm_poly", + eng = "kernlab", + mode = "classification", + value = list( + interface = "matrix", + protect = c("x", "y"), + func = c(pkg = "kernlab", fun = "ksvm"), + defaults = list(kernel = "polydot") + ) +) + +set_pred( + model = "svm_poly", + eng = "kernlab", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = svm_reg_post, + func = c(pkg = "kernlab", fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "response" ) - ), - numeric = list( - pre = NULL, - post = svm_reg_post, - func = c(pkg = "kernlab", fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "response" - ) - ), - class = list( - pre = NULL, - post = NULL, - func = c(pkg = "kernlab", fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "response" - ) - ), - classprob = list( - pre = NULL, - post = function(result, object) as_tibble(result), - func = c(pkg = "kernlab", fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "probabilities" - ) - ), - raw = list( - pre = NULL, - func = c(pkg = "kernlab", fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ) ) +) + +set_pred( + model = "svm_poly", + eng = "kernlab", + mode = "regression", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(pkg = "kernlab", fun = "predict"), + args = list(object = quote(object$fit), newdata = quote(new_data)) + ) +) + +set_pred( + model = "svm_poly", + eng = "kernlab", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = NULL, + func = c(pkg = "kernlab", fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "response" + ) + ) +) + +set_pred( + model = "svm_poly", + eng = "kernlab", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = function(result, object) as_tibble(result), + func = c(pkg = "kernlab", fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "probabilities" + ) + ) +) + +set_pred( + model = "svm_poly", + eng = "kernlab", + mode = "classification", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(pkg = "kernlab", fun = "predict"), + args = list(object = quote(object$fit), newdata = quote(new_data)) + ) +) + diff --git a/R/svm_rbf_data.R b/R/svm_rbf_data.R index fdb12727d..3c905e856 100644 --- a/R/svm_rbf_data.R +++ b/R/svm_rbf_data.R @@ -1,69 +1,142 @@ -svm_rbf_arg_key <- data.frame( - kernlab = c( "C", "sigma", "epsilon"), - row.names = c("cost", "rbf_sigma", "margin"), - stringsAsFactors = FALSE +set_new_model("svm_rbf") + +set_model_mode("svm_rbf", "classification") +set_model_mode("svm_rbf", "regression") + +# 
------------------------------------------------------------------------------ + +set_model_engine("svm_rbf", "classification", "kernlab") +set_model_engine("svm_rbf", "regression", "kernlab") +set_dependency("svm_rbf", "kernlab", "kernlab") + +set_model_arg( + model = "svm_rbf", + eng = "kernlab", + parsnip = "cost", + original = "C", + func = list(pkg = "dials", fun = "cost"), + has_submodel = FALSE ) -svm_rbf_modes <- c("classification", "regression", "unknown") +set_model_arg( + model = "svm_rbf", + eng = "kernlab", + parsnip = "rbf_sigma", + original = "sigma", + func = list(pkg = "dials", fun = "rbf_sigma"), + has_submodel = FALSE +) -svm_rbf_engines <- data.frame( - kernlab = c(TRUE, TRUE, FALSE), - row.names = c("classification", "regression", "unknown") +set_model_arg( + model = "svm_rbf", + eng = "kernlab", + parsnip = "margin", + original = "epsilon", + func = list(pkg = "dials", fun = "margin"), + has_submodel = FALSE ) -# ------------------------------------------------------------------------------ +set_fit( + model = "svm_rbf", + eng = "kernlab", + mode = "regression", + value = list( + interface = "matrix", + protect = c("x", "y"), + func = c(pkg = "kernlab", fun = "ksvm"), + defaults = list(kernel = "rbfdot") + ) +) + +set_fit( + model = "svm_rbf", + eng = "kernlab", + mode = "classification", + value = list( + interface = "matrix", + protect = c("x", "y"), + func = c(pkg = "kernlab", fun = "ksvm"), + defaults = list(kernel = "rbfdot") + ) +) + +set_pred( + model = "svm_rbf", + eng = "kernlab", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = svm_reg_post, + func = c(pkg = "kernlab", fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "response" + ) + ) +) + +set_pred( + model = "svm_rbf", + eng = "kernlab", + mode = "regression", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(pkg = "kernlab", fun = "predict"), + args = list(object = quote(object$fit), newdata = quote(new_data)) + ) +) + +set_pred( + model = "svm_rbf", + eng = "kernlab", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = NULL, + func = c(pkg = "kernlab", fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "response" + ) + ) +) -svm_rbf_kernlab_data <- - list( - libs = "kernlab", - fit = list( - interface = "matrix", - protect = c("x", "y"), - func = c(pkg = "kernlab", fun = "ksvm"), - defaults = list( - kernel = "rbfdot" +set_pred( + model = "svm_rbf", + eng = "kernlab", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = function(result, object) as_tibble(result), + func = c(pkg = "kernlab", fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "probabilities" ) - ), - numeric = list( - pre = NULL, - post = svm_reg_post, - func = c(pkg = "kernlab", fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "response" - ) - ), - class = list( - pre = NULL, - post = NULL, - func = c(pkg = "kernlab", fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "response" - ) - ), - classprob = list( - pre = NULL, - post = function(result, object) as_tibble(result), - func = c(pkg = "kernlab", fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "probabilities" - ) - ), - raw = list( - pre = NULL, - func = 
c(pkg = "kernlab", fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ) ) +) + +set_pred( + model = "svm_rbf", + eng = "kernlab", + mode = "classification", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(pkg = "kernlab", fun = "predict"), + args = list(object = quote(object$fit), newdata = quote(new_data)) + ) +) + diff --git a/R/translate.R b/R/translate.R index fd014f770..3522544cc 100644 --- a/R/translate.R +++ b/R/translate.R @@ -19,11 +19,11 @@ #' This function can be useful when you need to understand how #' `parsnip` goes from a generic model specific to a model fitting #' function. -#' +#' #' **Note**: this function is used internally and users should only use it #' to understand what the underlying syntax would be. It should not be used -#' to modify the model specification. -#' +#' to modify the model specification. +#' #' @examples #' lm_spec <- linear_reg(penalty = 0.01) #' @@ -41,7 +41,7 @@ #' #' @export -translate <- function (x, ...) +translate <- function(x, ...) UseMethod("translate") #' @importFrom utils getFromNamespace @@ -53,29 +53,34 @@ translate.default <- function(x, engine = x$engine, ...) { if (is.null(engine)) stop("Please set an engine.", call. = FALSE) + mod_name <- specific_model(x) + x$engine <- engine x <- check_engine(x) + if (x$mode == "unknown") { + stop("Model code depends on the mode; please specify one.", call. = FALSE) + } + if (is.null(x$method)) - x <- get_method(x, engine, ...) + x$method <- get_model_spec(mod_name, x$mode, engine) - arg_key <- get_module(specific_model(x)) + arg_key <- get_args(mod_name, engine) # deharmonize primary arguments - actual_args <- deharmonize(x$args, arg_key, x$engine) + actual_args <- deharmonize(x$args, arg_key) # check secondary arguments to see if they are in the final # expression unless there are dots, warn if protected args are # being altered - eng_arg_key <- arg_key[[x$engine]] - x$eng_args <- check_eng_args(x$eng_args, x$method$fit, eng_arg_key) + x$eng_args <- check_eng_args(x$eng_args, x$method$fit, arg_key$original) # keep only modified args - modifed_args <- !vapply(actual_args, null_value, lgl(1)) + modifed_args <- !purrr::map_lgl(actual_args, null_value) actual_args <- actual_args[modifed_args] # look for defaults if not modified in other - if(length(x$method$fit$defaults) > 0) { + if (length(x$method$fit$defaults) > 0) { in_other <- names(x$method$fit$defaults) %in% names(x$eng_args) x$defaults <- x$method$fit$defaults[!in_other] } @@ -89,40 +94,6 @@ translate.default <- function(x, engine = x$engine, ...) { x } -get_method <- function(x, engine = x$engine, ...) { - check_empty_ellipse(...) - x$engine <- engine - x <- check_engine(x) - x$method <- get_model_info(x, x$engine) - x -} - - -get_module <- function(nm) { - arg_key <- try( - getFromNamespace( - paste0(nm, "_arg_key"), - ns = "parsnip" - ), - silent = TRUE - ) - if(inherits(arg_key, "try-error")) { - arg_key <- try( - get(paste0(nm, "_arg_key")), - silent = TRUE - ) - } - if(inherits(arg_key, "try-error")) { - stop( - "Cannot find the model code: `", - paste0(nm, "_arg_key"), - "`", call. = FALSE - ) - } - arg_key -} - - #' @export print.model_spec <- function(x, ...) 
{ cat("Model Specification (", x$mode, ")\n\n", sep = "") @@ -144,3 +115,62 @@ check_mode <- function(object, lvl) { } object } + +# ------------------------------------------------------------------------------ +# new code for revised model data structures + +get_model_spec <- function(model, mode, engine) { + m_env <- get_model_env() + env_obj <- rlang::env_names(m_env) + env_obj <- grep(model, env_obj, value = TRUE) + + res <- list() + res$libs <- + rlang::env_get(m_env, paste0(model, "_pkgs")) %>% + purrr::pluck("pkg") %>% + purrr::pluck(1) + + res$fit <- + rlang::env_get(m_env, paste0(model, "_fit")) %>% + dplyr::filter(mode == !!mode & engine == !!engine) %>% + dplyr::pull(value) %>% + purrr::pluck(1) + + pred_code <- + rlang::env_get(m_env, paste0(model, "_predict")) %>% + dplyr::filter(mode == !!mode & engine == !!engine) %>% + dplyr::select(-engine, -mode) + + res$pred <- pred_code[["value"]] + names(res$pred) <- pred_code$type + + res +} + +get_args <- function(model, engine) { + m_env <- get_model_env() + rlang::env_get(m_env, paste0(model, "_args")) %>% + dplyr::filter(engine == !!engine) %>% + dplyr::select(-engine) +} + +# to replace harmonize +deharmonize <- function(args, key) { + if (length(args) == 0) + return(args) + parsn <- tibble(parsnip = names(args), order = seq_along(args)) + merged <- + dplyr::left_join(parsn, key, by = "parsnip") %>% + dplyr::arrange(order) + # TODO correct for bad merge? + + names(args) <- merged$original + args[!is.na(merged$original)] +} + +add_methods <- function(x, engine) { + x$engine <- engine + x <- check_engine(x) + x$method <- get_model_spec(specific_model(x), x$mode, x$engine) + x +} diff --git a/R/zzz.R b/R/zzz.R index dbee4e8cc..5978630f9 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -1,59 +1,59 @@ -## nocov start - -data_obj <- ls(pattern = "_data$") -data_obj <- data_obj[data_obj != "prepare_data"] - -#' @importFrom purrr map_dfr -#' @importFrom tibble as_tibble -data_names <- - map_dfr( - data_obj, - function(x) { - module <- names(get(x)) - if (length(module) > 1) { - module <- table(module) - module <- as_tibble(module) - module$object <- x - module - } else - module <- NULL - module - } - ) - -if(any(data_names$n > 1)) { - print(data_names[data_names$n > 1,]) - stop("Some models have duplicate module names.") -} -rm(data_names) - -# ------------------------------------------------------------------------------ - -engine_objects <- ls(pattern = "_engines$") -engine_objects <- engine_objects[engine_objects != "possible_engines"] - -#' @importFrom utils stack -get_engine_info <- function(x) { - y <- x - y <- get(y) - z <- stack(y) - z$mode <- rownames(y) - z$model <- gsub("_engines$", "", x) - z$object <- x - z <- z[z$values,] - z <- z[z$mode != "unknown",] - z$values <- NULL - names(z)[1] <- "engine" - z$engine <- as.character(z$engine) - z -} - -engine_info <- - purrr::map_df( - parsnip:::engine_objects, - get_engine_info -) - -rm(engine_objects) - -## nocov end +#' ## nocov start +#' +#' data_obj <- ls(pattern = "_data$") +#' data_obj <- data_obj[data_obj != "prepare_data"] +#' +#' #' @importFrom purrr map_dfr +#' #' @importFrom tibble as_tibble +#' data_names <- +#' map_dfr( +#' data_obj, +#' function(x) { +#' module <- names(get(x)) +#' if (length(module) > 1) { +#' module <- table(module) +#' module <- as_tibble(module) +#' module$object <- x +#' module +#' } else +#' module <- NULL +#' module +#' } +#' ) +#' +#' if(any(data_names$n > 1)) { +#' print(data_names[data_names$n > 1,]) +#' stop("Some models have duplicate module names.") 
+#' } +#' rm(data_names) +#' +#' # ------------------------------------------------------------------------------ +#' +#' engine_objects <- ls(pattern = "_engines$") +#' engine_objects <- engine_objects[engine_objects != "possible_engines"] +#' +#' #' @importFrom utils stack +#' get_engine_info <- function(x) { +#' y <- x +#' y <- get(y) +#' z <- stack(y) +#' z$mode <- rownames(y) +#' z$model <- gsub("_engines$", "", x) +#' z$object <- x +#' z <- z[z$values,] +#' z <- z[z$mode != "unknown",] +#' z$values <- NULL +#' names(z)[1] <- "engine" +#' z$engine <- as.character(z$engine) +#' z +#' } +#' +#' engine_info <- +#' purrr::map_df( +#' parsnip:::engine_objects, +#' get_engine_info +#' ) +#' +#' rm(engine_objects) +#' +#' ## nocov end diff --git a/README.md b/README.md index a4253bf86..a095fe62d 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,12 @@ -[![Travis build status](https://travis-ci.org/tidymodels/parsnip.svg?branch=master)](https://travis-ci.org/tidymodels/parsnip) -[![Coverage status](https://codecov.io/gh/tidymodels/parsnip/branch/master/graph/badge.svg)](https://codecov.io/github/tidymodels/parsnip?branch=master) -![](https://img.shields.io/badge/lifecycle-maturing-blue.svg) +[![Build +Status](https://travis-ci.org/tidymodels/parsnip.svg?branch=master)](https://travis-ci.org/tidymodels/parsnip) +[![Coverage +status](https://codecov.io/gh/tidymodels/parsnip/branch/master/graph/badge.svg)](https://codecov.io/github/tidymodels/parsnip?branch=master) +[![CRAN\_Status\_Badge](http://www.r-pkg.org/badges/version/parsnip)](http://cran.r-project.org/web/packages/parsnip) +[![Downloads](http://cranlogs.r-pkg.org/badges/parsnip)](http://cran.rstudio.com/package=parsnip) +[![lifecycle](https://img.shields.io/badge/lifecycle-maturing-blue.svg)](https://www.tidyverse.org/lifecycle/#maturing) + One issue with different functions available in R _that do the same thing_ is that they can have different interfaces and arguments. For example, to fit a random forest _classification_ model, we might have: diff --git a/_pkgdown.yml b/_pkgdown.yml index 0f4430bf4..644889998 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -53,7 +53,19 @@ reference: contents: - lending_club - wa_churn - - check_times + - check_times + + - title: Developer Tools + contents: + - set_new_model + - starts_with("set_model_") + - set_dependency + - set_fit + - set_pred + - show_model_info + - starts_with("get_") + + navbar: left: diff --git a/docs/dev/articles/articles/Classification.html b/docs/dev/articles/articles/Classification.html index f702a65ee..9f5ef1ac1 100644 --- a/docs/dev/articles/articles/Classification.html +++ b/docs/dev/articles/articles/Classification.html @@ -101,24 +101,24 @@
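The `set_model_engine()`, `set_model_arg()`, `set_fit()`, and `set_pred()` calls in R/svm_rbf.R above are the new registration API that replaces the old `svm_rbf_modes`/`svm_rbf_engines`/`svm_rbf_kernlab_data` objects (and makes the startup checks formerly in R/zzz.R unnecessary). A minimal sketch of the same workflow for a hypothetical model; the name `my_lin_reg` and every value below are illustrative only, not part of this changeset:

```r
library(parsnip)

# A sketch only, assuming the registration functions exported in this
# diff (set_new_model() and friends); "my_lin_reg" is hypothetical.
set_new_model("my_lin_reg")
set_model_mode("my_lin_reg", mode = "regression")
set_model_engine("my_lin_reg", mode = "regression", eng = "lm")
set_dependency("my_lin_reg", eng = "lm", pkg = "stats")

# How to fit: mirrors the structure registered for svm_rbf above.
set_fit(
  model = "my_lin_reg",
  eng = "lm",
  mode = "regression",
  value = list(
    interface = "formula",
    protect = c("formula", "data"),
    func = c(pkg = "stats", fun = "lm"),
    defaults = list()
  )
)

# How to produce numeric predictions from the fitted object.
set_pred(
  model = "my_lin_reg",
  eng = "lm",
  mode = "regression",
  type = "numeric",
  value = list(
    pre = NULL,
    post = NULL,
    func = c(fun = "predict"),
    args = list(object = quote(object$fit), newdata = quote(new_data))
  )
)

show_model_info("my_lin_reg")
```

Tunable arguments would be registered the same way via `set_model_arg()`; once a model is registered, `translate()` and `fit()` resolve everything from the model environment.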

Classification Example

To demonstrate parsnip for classification models, the credit data will be used.

-
library(tidymodels)
+
library(tidymodels)
 #> Registered S3 method overwritten by 'xts':
 #>   method     from
 #>   as.zoo.xts zoo
-#> ── Attaching packages ───────────────────────────────── tidymodels 0.0.2 ──
+#> ── Attaching packages ──────────────────────────────────────────────────────────────────────────────────── tidymodels 0.0.2 ──
 #> ✔ broom     0.5.1       ✔ purrr     0.3.2  
-#> ✔ dials     0.0.2       ✔ recipes   0.1.4  
+#> ✔ dials     0.0.2       ✔ recipes   0.1.5  
 #> ✔ dplyr     0.8.0.1     ✔ rsample   0.0.4  
 #> ✔ infer     0.4.0       ✔ yardstick 0.0.2
-#> ── Conflicts ──────────────────────────────────── tidymodels_conflicts() ──
+#> ── Conflicts ─────────────────────────────────────────────────────────────────────────────────────── tidymodels_conflicts() ──
 #> ✖ purrr::discard() masks scales::discard()
 #> ✖ dplyr::filter()  masks stats::filter()
 #> ✖ dplyr::lag()     masks stats::lag()
 #> ✖ recipes::step()  masks stats::step()
 
-data(credit_data)
+data(credit_data)
 
-set.seed(7075)
+set.seed(7075)
 data_split <- initial_split(credit_data, strata = "Status", p = 0.75)
 
 credit_train <- training(data_split)
@@ -136,41 +136,42 @@ 

Classification Example

test_normalized <- bake(credit_rec, new_data = credit_test, all_predictors())

keras will be used to fit a model with 5 hidden units and a 10% dropout rate to regularize the model. At each training iteration (aka epoch), a random 20% of the data will be used to measure the cross-entropy of the model.

-
set.seed(57974)
+
+  set_mode("classification") %>% 
+  # Also set engine-specific arguments: 
+  set_engine("keras", verbose = 0, validation_split = .20) %>%
+  fit(Status ~ ., data = juice(credit_rec))
+
+nnet_fit
+#> parsnip model object
+#> 
+#> Model
+#> ___________________________________________________________________________
+#> Layer (type)                     Output Shape                  Param #     
+#> ===========================================================================
+#> dense_1 (Dense)                  (None, 5)                     115         
+#> ___________________________________________________________________________
+#> dense_2 (Dense)                  (None, 5)                     30          
+#> ___________________________________________________________________________
+#> dropout_1 (Dropout)              (None, 5)                     0           
+#> ___________________________________________________________________________
+#> dense_3 (Dense)                  (None, 2)                     12          
+#> ===========================================================================
+#> Total params: 157
+#> Trainable params: 157
+#> Non-trainable params: 0
+#> ___________________________________________________________________________

In parsnip, the predict() function can be used:

+#> bad 185 89 +#> good 128 711
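The confusion matrix above is produced by a prediction step whose code is truncated in this rendering. A hedged reconstruction of what that step plausibly looks like (`test_results` is an assumed intermediate name; `nnet_fit` and `test_normalized` come from the preceding chunks):

```r
library(dplyr)

# Sketch only: predict class labels on the baked test set and
# cross-tabulate them against the observed Status values.
test_results <- credit_test %>%
  select(Status) %>%
  mutate(nnet_class = predict(nnet_fit, new_data = test_normalized)$.pred_class)

table(test_results$nnet_class, test_results$Status)
```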
diff --git a/docs/dev/reference/boost_tree.html b/docs/dev/reference/boost_tree.html index b5450dd08..cf07a7ff1 100644 --- a/docs/dev/reference/boost_tree.html +++ b/docs/dev/reference/boost_tree.html @@ -171,7 +171,7 @@

General Interface for Boosted Trees

time that the model is fit. Other options and argument can be set using the set_engine() function. If left to their defaults here (NULL), the values are taken from the underlying model -functions. If parameters need to be modified, update() can be used +functions. If parameters need to be modified, update() can be used in lieu of recreating the object from scratch.

@@ -181,7 +181,7 @@

General Interface for Boosted Trees

loss_reduction = NULL, sample_size = NULL) # S3 method for boost_tree -update(object, mtry = NULL, trees = NULL, +update(object, mtry = NULL, trees = NULL, min_n = NULL, tree_depth = NULL, learn_rate = NULL, loss_reduction = NULL, sample_size = NULL, fresh = FALSE, ...) @@ -242,7 +242,7 @@

Arg ... -

Not used for update().

+

Not used for update().

@@ -255,9 +255,9 @@

Details

The data given to the function are not saved and are only used to determine the mode of the model. For boost_tree(), the possible modes are "regression" and "classification".

-

The model can be created using the fit() function using the +

The model can be created using the fit() function using the following engines:

    -
  • R: "xgboost", "C5.0"

  • +
  • R: "xgboost" (the default), "C5.0"

  • Spark: "spark"

@@ -265,13 +265,13 @@

Note

For models created using the spark engine, there are several differences to consider. First, only the formula -interface to via fit() is available; using fit_xy() will +interface to via fit() is available; using fit_xy() will generate an error. Second, the predictions will always be in a spark table format. The names will be the same as documented but without the dots. Third, there is no equivalent to factor columns in spark tables so class predictions are returned as character columns. Fourth, to retain the model object for a new -R session (via save()), the model$fit element of the parsnip +R session (via save()), the model$fit element of the parsnip object should be serialized via ml_save(object$fit) and separately saved to disk. In a new session, the object can be reloaded and reattached to the parsnip object.

@@ -309,7 +309,7 @@

See also

- +

Examples

@@ -328,12 +328,12 @@

Examp #> Main Arguments: #> mtry = 10 #> min_n = 3 -#>
update(model, mtry = 1)
#> Boosted Tree Model Specification (unknown) +#>
update(model, mtry = 1)
#> Boosted Tree Model Specification (unknown) #> #> Main Arguments: #> mtry = 1 #> min_n = 3 -#>
update(model, mtry = 1, fresh = TRUE)
#> Boosted Tree Model Specification (unknown) +#>
update(model, mtry = 1, fresh = TRUE)
#> Boosted Tree Model Specification (unknown) #> #> Main Arguments: #> mtry = 1 diff --git a/docs/dev/reference/check_mod_val.html b/docs/dev/reference/check_mod_val.html new file mode 100644 index 000000000..5de0f14f1 --- /dev/null +++ b/docs/dev/reference/check_mod_val.html @@ -0,0 +1,418 @@ + + + + + + + + +Tools to Register Models — pred_types • parsnip + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
These functions are similar to constructors and can be used to validate +that there are no conflicts with the underlying model structures used by the +package.

+ +
+ +
pred_types
+
+check_mod_val(model, new = FALSE, existence = FALSE)
+
+get_model_env()
+
+check_mode_val(mode)
+
+check_engine_val(eng)
+
+check_arg_val(arg)
+
+check_submodels_val(has_submodel)
+
+check_func_val(func)
+
+check_fit_info(fit_obj)
+
+check_pred_info(pred_obj, type)
+
+check_pkg_val(pkg)
+
+set_new_model(model)
+
+set_model_mode(model, mode)
+
+set_model_engine(model, mode, eng)
+
+set_model_arg(model, eng, parsnip, original, func, has_submodel)
+
+set_dependency(model, eng, pkg)
+
+get_dependency(model)
+
+set_fit(model, mode, eng, value)
+
+get_fit(model)
+
+set_pred(model, mode, eng, type, value)
+
+get_pred_type(model, type)
+
+show_model_info(model)
+
+get_from_env(items)
+ +

Arguments

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
model

A single character string for the model type (e.g. +"rand_forest", etc).

new

A single logical to check whether the model you are registering +has not already been registered.

existence

A single logical to check whether the model has already +been registered.

mode

A single character string for the model mode (e.g. "regression").

eng

A single character string for the model engine.

arg

A single character string for the model argument name.

has_submodel

A single logical for whether the argument +can make predictions on multiple submodels at once.

func

A named character vector that describes how to call +a function. func should have elements pkg and fun. The +former is optional but is recommended and the latter is +required. For example, c(pkg = "stats", fun = "lm") would be +used to invoke the usual linear regression function. In some +cases, it is helpful to use c(fun = "predict") when using a +package's predict method.

fit_obj

A list with elements interface, protect, +func and defaults. See the package vignette "Making a +parsnip model from scratch".

pred_obj

A list with elements pre, post, func, and args. +See the package vignette "Making a parsnip model from scratch".

type

A single character value for the type of prediction. Possible +values are 'raw', 'numeric', 'class', 'link', 'prob', 'conf_int', +'pred_int', and 'quantile'.

pkg

An optional character string for a package name.

parsnip

A single character string for the "harmonized" argument name +that parsnip exposes.

original

A single character string for the argument name that the +underlying model function uses.

value

A list that conforms to the fit_obj or pred_obj description +above, depending on context.

items

A character string of objects in the model environment.

+ +

Format

+ +

An object of class character of length 8.

+ +

Details

+ +

These functions are available for users to add their +own models or engines (in a package or otherwise) so that they can +be accessed using parsnip. These are more thoroughly documented +on the package website (see references below).

+

In short, parsnip stores an environment object that contains +all of the information and code about how models are used (e.g. +fitting, predicting, etc). These functions can be used to add +models to that environment, as well as helper functions that can +be used to make sure that the model data is in the right +format.

+ +

References

+ +

"Making a parsnip model from scratch" +https://tidymodels.github.io/parsnip/articles/articles/Scratch.html

+ + +

Examples

+
# Show the information about a model: +show_model_info("rand_forest")
#> Information for `rand_forest` +#> modes: unknown, classification, regression +#> +#> engines: +#> classification: randomForest, ranger, spark +#> regression: randomForest, ranger, spark +#> +#> arguments: +#> ranger: +#> mtry --> mtry +#> trees --> num.trees +#> min_n --> min.node.size +#> randomForest: +#> mtry --> mtry +#> trees --> ntree +#> min_n --> nodesize +#> spark: +#> mtry --> feature_subset_strategy +#> trees --> num_trees +#> min_n --> min_instances_per_node +#> +#> fit modules: +#> engine mode +#> ranger classification +#> ranger regression +#> randomForest classification +#> randomForest regression +#> spark classification +#> spark regression +#> +#> prediction modules: +#> mode engine methods +#> classification randomForest class, prob, raw +#> classification ranger class, conf_int, prob, raw +#> classification spark class, prob +#> regression randomForest numeric, raw +#> regression ranger conf_int, numeric, raw +#> regression spark numeric +#>
+# Access the model data: +current_code <- get_model_env() +ls(envir = current_code)
#> [1] "boost_tree" "boost_tree_args" +#> [3] "boost_tree_fit" "boost_tree_modes" +#> [5] "boost_tree_pkgs" "boost_tree_predict" +#> [7] "decision_tree" "decision_tree_args" +#> [9] "decision_tree_fit" "decision_tree_modes" +#> [11] "decision_tree_pkgs" "decision_tree_predict" +#> [13] "linear_reg" "linear_reg_args" +#> [15] "linear_reg_fit" "linear_reg_modes" +#> [17] "linear_reg_pkgs" "linear_reg_predict" +#> [19] "logistic_reg" "logistic_reg_args" +#> [21] "logistic_reg_fit" "logistic_reg_modes" +#> [23] "logistic_reg_pkgs" "logistic_reg_predict" +#> [25] "mars" "mars_args" +#> [27] "mars_fit" "mars_modes" +#> [29] "mars_pkgs" "mars_predict" +#> [31] "mlp" "mlp_args" +#> [33] "mlp_fit" "mlp_modes" +#> [35] "mlp_pkgs" "mlp_predict" +#> [37] "models" "modes" +#> [39] "multinom_reg" "multinom_reg_args" +#> [41] "multinom_reg_fit" "multinom_reg_modes" +#> [43] "multinom_reg_pkgs" "multinom_reg_predict" +#> [45] "nearest_neighbor" "nearest_neighbor_args" +#> [47] "nearest_neighbor_fit" "nearest_neighbor_modes" +#> [49] "nearest_neighbor_pkgs" "nearest_neighbor_predict" +#> [51] "null_model" "null_model_args" +#> [53] "null_model_fit" "null_model_modes" +#> [55] "null_model_pkgs" "null_model_predict" +#> [57] "rand_forest" "rand_forest_args" +#> [59] "rand_forest_fit" "rand_forest_modes" +#> [61] "rand_forest_pkgs" "rand_forest_predict" +#> [63] "surv_reg" "surv_reg_args" +#> [65] "surv_reg_fit" "surv_reg_modes" +#> [67] "surv_reg_pkgs" "surv_reg_predict" +#> [69] "svm_poly" "svm_poly_args" +#> [71] "svm_poly_fit" "svm_poly_modes" +#> [73] "svm_poly_pkgs" "svm_poly_predict" +#> [75] "svm_rbf" "svm_rbf_args" +#> [77] "svm_rbf_fit" "svm_rbf_modes" +#> [79] "svm_rbf_pkgs" "svm_rbf_predict"
diff --git a/docs/dev/reference/check_times.html b/docs/dev/reference/check_times.html index 08fff9fe5..0b6c131d5 100644 --- a/docs/dev/reference/check_times.html +++ b/docs/dev/reference/check_times.html @@ -195,8 +195,8 @@

Details

Examples

-
data(check_times) -str(check_times)
#> Classes ‘tbl_df’, ‘tbl’ and 'data.frame': 13626 obs. of 25 variables: +
data(check_times) +str(check_times)
#> Classes ‘tbl_df’, ‘tbl’ and 'data.frame': 13626 obs. of 25 variables: #> $ package : chr "A3" "abbyyR" "abc" "abc.data" ... #> $ authors : int 1 1 1 1 5 3 2 1 4 6 ... #> $ imports : num 0 6 0 0 3 1 0 4 0 7 ... diff --git a/docs/dev/reference/decision_tree.html b/docs/dev/reference/decision_tree.html index bbb7cf34b..22825643e 100644 --- a/docs/dev/reference/decision_tree.html +++ b/docs/dev/reference/decision_tree.html @@ -159,7 +159,7 @@

General Interface for Decision Tree Models

time that the model is fit. Other options and argument can be set using set_engine(). If left to their defaults here (NULL), the values are taken from the underlying model -functions. If parameters need to be modified, update() can be used +functions. If parameters need to be modified, update() can be used in lieu of recreating the object from scratch.

@@ -168,7 +168,7 @@

General Interface for Decision Tree Models

tree_depth = NULL, min_n = NULL) # S3 method for decision_tree -update(object, cost_complexity = NULL, +update(object, cost_complexity = NULL, tree_depth = NULL, min_n = NULL, fresh = FALSE, ...)

Arguments

@@ -205,15 +205,15 @@

Arg ... -

Not used for update().

+

Not used for update().

Details

-

The model can be created using the fit() function using the +

The model can be created using the fit() function using the following engines:

    -
  • R: "rpart" or "C5.0" (classification only)

  • +
  • R: "rpart" (the default) or "C5.0" (classification only)

  • Spark: "spark"

Note that, for rpart models, but cost_complexity and @@ -226,13 +226,13 @@

Note

For models created using the spark engine, there are several differences to consider. First, only the formula -interface to via fit() is available; using fit_xy() will +interface to via fit() is available; using fit_xy() will generate an error. Second, the predictions will always be in a spark table format. The names will be the same as documented but without the dots. Third, there is no equivalent to factor columns in spark tables so class predictions are returned as character columns. Fourth, to retain the model object for a new -R session (via save()), the model$fit element of the parsnip +R session (via save()), the model$fit element of the parsnip object should be serialized via ml_save(object$fit) and separately saved to disk. In a new session, the object can be reloaded and reattached to the parsnip object.

@@ -269,7 +269,7 @@

See also

- +

Examples

@@ -288,12 +288,12 @@

Examp #> Main Arguments: #> cost_complexity = 10 #> min_n = 3 -#>

update(model, cost_complexity = 1)
#> Decision Tree Model Specification (unknown) +#>
update(model, cost_complexity = 1)
#> Decision Tree Model Specification (unknown) #> #> Main Arguments: #> cost_complexity = 1 #> min_n = 3 -#>
update(model, cost_complexity = 1, fresh = TRUE)
#> Decision Tree Model Specification (unknown) +#>
update(model, cost_complexity = 1, fresh = TRUE)
#> Decision Tree Model Specification (unknown) #> #> Main Arguments: #> cost_complexity = 1 diff --git a/docs/dev/reference/descriptors.html b/docs/dev/reference/descriptors.html index cf0396989..22694b7e7 100644 --- a/docs/dev/reference/descriptors.html +++ b/docs/dev/reference/descriptors.html @@ -134,7 +134,7 @@

Data Set Characteristics Available when Fitting Models

-

When using the fit() functions there are some +

When using the fit() functions there are some variables that will be available for use in arguments. For example, if the user would like to choose an argument value based on the current number of rows in a data set, the .obs() @@ -174,7 +174,7 @@

Details
  • .y(): The known outcomes returned in the format given. Either a vector, matrix, or data frame.

  • .dat(): A data frame containing all of the predictors and the -outcomes. If fit_xy() was used, the outcomes are attached as the +outcomes. If fit_xy() was used, the outcomes are attached as the column, ..y.

  • For example, if you use the model formula Sepal.Width ~ . with the iris @@ -200,7 +200,7 @@

    Details

    To use these in a model fit, pass them to a model specification. The evaluation is delayed until the time when the -model is run via fit() (and the variables listed above are available). +model is run via fit() (and the variables listed above are available). For example:

         data("lending_club")
    diff --git a/docs/dev/reference/fit.html b/docs/dev/reference/fit.html
    index 59c76b4e1..8bd1bbf94 100644
    --- a/docs/dev/reference/fit.html
    +++ b/docs/dev/reference/fit.html
    @@ -132,18 +132,18 @@ 

    Fit a Model Specification to a Dataset

    -

    fit() and fit_xy() take a model specification, translate the required +

    fit() and fit_xy() take a model specification, translate the required code by substituting arguments, and execute the model fit routine.

    # S3 method for model_spec
    -fit(object, formula = NULL, data = NULL,
    +fit(object, formula = NULL, data = NULL,
       control = fit_control(), ...)
     
     # S3 method for model_spec
    -fit_xy(object, x = NULL, y = NULL,
    +fit_xy(object, x = NULL, y = NULL,
       control = fit_control(), ...)

    Arguments

    @@ -206,21 +206,24 @@

    Value

    Details

    -

    fit() and fit_xy() substitute the current arguments in the model +

    fit() and fit_xy() substitute the current arguments in the model specification into the computational engine's code, checks them for validity, then fits the model using the data and the engine-specific code. Different model functions have different interfaces (e.g. formula or x/y) and these functions translate -between the interface used when fit() or fit_xy() were invoked and the one +between the interface used when fit() or fit_xy() were invoked and the one required by the underlying model.

    When possible, these functions attempt to avoid making copies of the data. For example, if the underlying model uses a formula and -fit() is invoked, the original data are references +fit() is invoked, the original data are references when the model is fit. However, if the underlying model uses something else, such as x/y, the formula is evaluated and the data are converted to the required format. In this case, any calls in the resulting model objects reference the temporary objects used to fit the model.

    +

    If the model engine has not been set, the model's default engine will be used +(as discussed on each model page). If the verbosity option of +fit_control() is greater than zero, a warning will be produced.
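A short sketch of that fallback, assuming `"lm"` is the default engine for `linear_reg()` as its model page states:

```r
library(parsnip)

# No set_engine() call: the model's default engine is used, and with a
# verbosity above zero fit_control() makes fit() warn about the fallback.
lm_fit <- fit(
  linear_reg(),
  mpg ~ .,
  data = mtcars,
  control = fit_control(verbosity = 1L)
)
```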

    See also

    @@ -231,28 +234,26 @@

    Examp
    # Although `glm()` only has a formula interface, different # methods for specifying the model can be used -library(dplyr)
    #> +library(dplyr)
    #> #> Attaching package: ‘dplyr’
    #> The following object is masked from ‘package:testthat’: #> #> matches
    #> The following objects are masked from ‘package:stats’: #> #> filter, lag
    #> The following objects are masked from ‘package:base’: #> -#> intersect, setdiff, setequal, union
    data("lending_club") - -lr_mod <- logistic_reg() +#> intersect, setdiff, setequal, union
    data("lending_club") lr_mod <- logistic_reg() using_formula <- lr_mod %>% set_engine("glm") %>% - fit(Class ~ funded_amnt + int_rate, data = lending_club) + fit(Class ~ funded_amnt + int_rate, data = lending_club) using_xy <- lr_mod %>% set_engine("glm") %>% - fit_xy(x = lending_club[, c("funded_amnt", "int_rate")], + fit_xy(x = lending_club[, c("funded_amnt", "int_rate")], y = lending_club$Class) using_formula
    #> parsnip model object diff --git a/docs/dev/reference/fit_control.html b/docs/dev/reference/fit_control.html index aa73f18a4..d93ee2181 100644 --- a/docs/dev/reference/fit_control.html +++ b/docs/dev/reference/fit_control.html @@ -131,7 +131,7 @@

    Control the fit function

    -

    Options can be passed to the fit() function that control the output and +

    Options can be passed to the fit() function that control the output and computations

    @@ -154,7 +154,7 @@

    Arg catch

    A logical where a value of TRUE will evaluate -the model inside of try(, silent = TRUE). If the model fails, +the model inside of try(, silent = TRUE). If the model fails, an object is still returned (without an error) that inherits the class "try-error".

    diff --git a/docs/dev/reference/get_model_env.html b/docs/dev/reference/get_model_env.html new file mode 100644 index 000000000..0c2b8d9bb --- /dev/null +++ b/docs/dev/reference/get_model_env.html @@ -0,0 +1,191 @@ + + + + + + + + +Tools to Register Models — get_model_env • parsnip + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

    Tools to Register Models

    + +
    + +
    get_model_env()
    +
    +set_new_model(model)
    +
    +set_model_mode(model, mode)
    +
    +set_model_engine(model, mode, eng)
    +
    +set_model_arg(model, eng, parsnip, original, func, has_submodel)
    +
    +set_dependency(model, eng, pkg)
    +
    +get_dependency(model)
    +
    +set_fit(model, mode, eng, value)
    +
    +get_fit(model)
    +
    +set_pred(model, mode, eng, type, value)
    +
    +get_pred_type(model, type)
    +
    +show_model_info(model)
    +
    +get_from_env(items)
diff --git a/docs/dev/reference/index.html b/docs/dev/reference/index.html index 3252d3df9..0540feb14 100644 --- a/docs/dev/reference/index.html +++ b/docs/dev/reference/index.html @@ -226,6 +226,12 @@

    add_rowindex()

    + +

    Add a column of row numbers to a data frame

    + +

    .cols() .preds() .obs() .lvls() .facts() .x() .y() .dat()

    @@ -329,6 +335,20 @@

    <

    Execution Time Data

    + + + +

    Developer Tools

    +

    + + + + + +

    pred_types check_mod_val() get_model_env() check_mode_val() check_engine_val() check_arg_val() check_submodels_val() check_func_val() check_fit_info() check_pred_info() check_pkg_val() set_new_model() set_model_mode() set_model_engine() set_model_arg() set_dependency() get_dependency() set_fit() get_fit() set_pred() get_pred_type() show_model_info() get_from_env()

    + +

    Tools to Register Models

    +

    @@ -339,6 +359,7 @@

    Contents

  • Models
  • Infrastructure
  • Data
  • +
  • Developer Tools
  • diff --git a/docs/dev/reference/keras_mlp.html b/docs/dev/reference/keras_mlp.html index 51ffd154c..ed96c15ed 100644 --- a/docs/dev/reference/keras_mlp.html +++ b/docs/dev/reference/keras_mlp.html @@ -139,7 +139,7 @@

    Simple interface to MLP models via keras

    keras_mlp(x, y, hidden_units = 5, decay = 0, dropout = 0,
    -  epochs = 20, act = "softmax", seeds = sample.int(10^5, size = 3),
    +  epochs = 20, act = "softmax", seeds = sample.int(10^5, size = 3),
       ...)

    Arguments

    diff --git a/docs/dev/reference/lending_club.html b/docs/dev/reference/lending_club.html index 30bcfe282..bb62cc077 100644 --- a/docs/dev/reference/lending_club.html +++ b/docs/dev/reference/lending_club.html @@ -158,8 +158,8 @@

    Details

    Examples

    -
    data(lending_club) -str(lending_club)
    #> Classes ‘tbl_df’, ‘tbl’ and 'data.frame': 9857 obs. of 23 variables: +
    data(lending_club) +str(lending_club)
    #> Classes ‘tbl_df’, ‘tbl’ and 'data.frame': 9857 obs. of 23 variables: #> $ funded_amnt : int 16100 32000 10000 16800 3500 10000 11000 15000 6000 20000 ... #> $ term : Factor w/ 2 levels "term_36","term_60": 1 2 1 2 1 1 1 1 1 2 ... #> $ int_rate : num 13.99 11.99 16.29 13.67 7.39 ... diff --git a/docs/dev/reference/linear_reg.html b/docs/dev/reference/linear_reg.html index 35a3375f8..31dfc2fbf 100644 --- a/docs/dev/reference/linear_reg.html +++ b/docs/dev/reference/linear_reg.html @@ -155,7 +155,7 @@

    General Interface for Linear Regression Models

    time that the model is fit. Other options and argument can be set using set_engine(). If left to their defaults here (NULL), the values are taken from the underlying model -functions. If parameters need to be modified, update() can be used +functions. If parameters need to be modified, update() can be used in lieu of recreating the object from scratch.

    @@ -163,7 +163,7 @@

    General Interface for Linear Regression Models

    linear_reg(mode = "regression", penalty = NULL, mixture = NULL)
     
     # S3 method for linear_reg
    -update(object, penalty = NULL, mixture = NULL,
    +update(object, penalty = NULL, mixture = NULL,
       fresh = FALSE, ...)

    Arguments

    @@ -200,7 +200,7 @@

    Arg ... -

    Not used for update().

    +

    Not used for update().

    @@ -209,9 +209,9 @@

    Details

    The data given to the function are not saved and are only used to determine the mode of the model. For linear_reg(), the mode will always be "regression".

    -

    The model can be created using the fit() function using the +

    The model can be created using the fit() function using the following engines:

      -
    • R: "lm" or "glmnet"

    • +
    • R: "lm" (the default) or "glmnet"

    • Stan: "stan"

    • Spark: "spark"

    • keras: "keras"

    • @@ -221,13 +221,13 @@

      Note

      For models created using the spark engine, there are several differences to consider. First, only the formula -interface to via fit() is available; using fit_xy() will +interface to via fit() is available; using fit_xy() will generate an error. Second, the predictions will always be in a spark table format. The names will be the same as documented but without the dots. Third, there is no equivalent to factor columns in spark tables so class predictions are returned as character columns. Fourth, to retain the model object for a new -R session (via save()), the model$fit element of the parsnip +R session (via save()), the model$fit element of the parsnip object should be serialized via ml_save(object$fit) and separately saved to disk. In a new session, the object can be reloaded and reattached to the parsnip object.

      @@ -265,8 +265,8 @@

      See also

      - +

      Examples

      @@ -296,12 +296,12 @@

      Examp #> Main Arguments: #> penalty = 10 #> mixture = 0.1 -#>

    update(model, penalty = 1)
    #> Linear Regression Model Specification (regression) +#>
    update(model, penalty = 1)
    #> Linear Regression Model Specification (regression) #> #> Main Arguments: #> penalty = 1 #> mixture = 0.1 -#>
    update(model, penalty = 1, fresh = TRUE)
    #> Linear Regression Model Specification (regression) +#>
    update(model, penalty = 1, fresh = TRUE)
    #> Linear Regression Model Specification (regression) #> #> Main Arguments: #> penalty = 1 diff --git a/docs/dev/reference/logistic_reg.html b/docs/dev/reference/logistic_reg.html index 7dde682bf..ca4332916 100644 --- a/docs/dev/reference/logistic_reg.html +++ b/docs/dev/reference/logistic_reg.html @@ -155,7 +155,7 @@

    General Interface for Logistic Regression Models

    time that the model is fit. Other options and argument can be set using set_engine(). If left to their defaults here (NULL), the values are taken from the underlying model -functions. If parameters need to be modified, update() can be used +functions. If parameters need to be modified, update() can be used in lieu of recreating the object from scratch.

    @@ -163,7 +163,7 @@

    General Interface for Logistic Regression Models

    logistic_reg(mode = "classification", penalty = NULL, mixture = NULL)
     
     # S3 method for logistic_reg
    -update(object, penalty = NULL, mixture = NULL,
    +update(object, penalty = NULL, mixture = NULL,
       fresh = FALSE, ...)

    Arguments

    @@ -200,16 +200,16 @@

    Arg ... -

    Not used for update().

    +

    Not used for update().

    Details

    For logistic_reg(), the mode will always be "classification".

    -

    The model can be created using the fit() function using the +

    The model can be created using the fit() function using the following engines:

    Arguments

    @@ -206,15 +206,15 @@

    Arg ... -

    Not used for update().

    +

    Not used for update().

    Details

    -

    The model can be created using the fit() function using the +

    The model can be created using the fit() function using the following engines:

    Engine Details

    @@ -239,7 +239,7 @@

    See also

    -

    varying(), fit()

    +

    varying(), fit()

    Examples

    @@ -253,12 +253,12 @@

    Examp #> Main Arguments: #> num_terms = 10 #> prune_method = none -#>
    update(model, num_terms = 1)
    #> MARS Model Specification (unknown) +#>
    update(model, num_terms = 1)
    #> MARS Model Specification (unknown) #> #> Main Arguments: #> num_terms = 1 #> prune_method = none -#>
    update(model, num_terms = 1, fresh = TRUE)
    #> MARS Model Specification (unknown) +#>
    update(model, num_terms = 1, fresh = TRUE)
    #> MARS Model Specification (unknown) #> #> Main Arguments: #> num_terms = 1 diff --git a/docs/dev/reference/mlp.html b/docs/dev/reference/mlp.html index 851e6f3d7..e9b746b82 100644 --- a/docs/dev/reference/mlp.html +++ b/docs/dev/reference/mlp.html @@ -170,7 +170,7 @@

    General Interface for Single Layer Neural Network

    dropout = NULL, epochs = NULL, activation = NULL) # S3 method for mlp -update(object, hidden_units = NULL, penalty = NULL, +update(object, hidden_units = NULL, penalty = NULL, dropout = NULL, epochs = NULL, activation = NULL, fresh = FALSE, ...) @@ -220,7 +220,7 @@

    Arg ... -

    Not used for update().

    +

    Not used for update().

    @@ -230,15 +230,15 @@

    Details time that the model is fit. Other options and argument can be set using set_engine(). If left to their defaults here (see above), the values are taken from the underlying model -functions. One exception is hidden_units when nnet::nnet is used; that +functions. One exception is hidden_units when nnet::nnet is used; that function's size argument has no default so a value of 5 units will be used. Also, unless otherwise specified, the linout argument to -nnet::nnet() will be set to TRUE when a regression model is created. -If parameters need to be modified, update() can be used +nnet::nnet() will be set to TRUE when a regression model is created. +If parameters need to be modified, update() can be used in lieu of recreating the object from scratch.

    -

    The model can be created using the fit() function using the +

    The model can be created using the fit() function using the following engines:

      -
    • R: "nnet"

    • +
    • R: "nnet" (the default)

    • keras: "keras"

    An error is thrown if both penalty and dropout are specified for @@ -271,7 +271,7 @@

    See also

    - +

    Examples

    @@ -290,12 +290,12 @@

    Examp #> Main Arguments: #> hidden_units = 10 #> dropout = 0.3 -#>

    update(model, hidden_units = 2)
    #> Single Layer Neural Network Specification (unknown) +#>
    update(model, hidden_units = 2)
    #> Single Layer Neural Network Specification (unknown) #> #> Main Arguments: #> hidden_units = 2 #> dropout = 0.3 -#>
    update(model, hidden_units = 2, fresh = TRUE)
    #> Single Layer Neural Network Specification (unknown) +#>
    update(model, hidden_units = 2, fresh = TRUE)
    #> Single Layer Neural Network Specification (unknown) #> #> Main Arguments: #> hidden_units = 2 diff --git a/docs/dev/reference/model_fit.html b/docs/dev/reference/model_fit.html index 23afa17ec..a800662ad 100644 --- a/docs/dev/reference/model_fit.html +++ b/docs/dev/reference/model_fit.html @@ -168,7 +168,7 @@

    Examp # Keep the `x` matrix if the data are not too big. spec_obj <- linear_reg() %>% - set_engine("lm", x = ifelse(.obs() < 500, TRUE, FALSE)) + set_engine("lm", x = ifelse(.obs() < 500, TRUE, FALSE)) spec_obj

    #> Linear Regression Model Specification (regression) #> #> Engine-Specific Arguments: @@ -176,7 +176,7 @@

    Examp #> #> Computational engine: lm #>

    -fit_obj <- fit(spec_obj, mpg ~ ., data = mtcars) +fit_obj <- fit(spec_obj, mpg ~ ., data = mtcars) fit_obj
    #> parsnip model object #> #> @@ -190,7 +190,7 @@

    Examp #> qsec vs am gear carb #> 0.82104 0.31776 2.52023 0.65541 -0.19942 #>

    -nrow(fit_obj$fit$x)
    #> [1] 32
    +nrow(fit_obj$fit$x)
    #> [1] 32
    update(model, penalty = 1)
    #> Multinomial Regression Model Specification (classification) +#>
    update(model, penalty = 1)
    #> Multinomial Regression Model Specification (classification) #> #> Main Arguments: #> penalty = 1 #> mixture = 0.1 -#>
    update(model, penalty = 1, fresh = TRUE)
    #> Multinomial Regression Model Specification (classification) +#>
    update(model, penalty = 1, fresh = TRUE)
    #> Multinomial Regression Model Specification (classification) #> #> Main Arguments: #> penalty = 1 diff --git a/docs/dev/reference/nearest_neighbor.html b/docs/dev/reference/nearest_neighbor.html index 4b9d49f44..d22c15948 100644 --- a/docs/dev/reference/nearest_neighbor.html +++ b/docs/dev/reference/nearest_neighbor.html @@ -161,7 +161,7 @@

    General Interface for K-Nearest Neighbor Models

    time that the model is fit. Other options and argument can be set using set_engine(). If left to their defaults here (NULL), the values are taken from the underlying model -functions. If parameters need to be modified, update() can be used +functions. If parameters need to be modified, update() can be used in lieu of recreating the object from scratch.

    @@ -181,7 +181,8 @@

    Arg neighbors

    A single integer for the number of neighbors -to consider (often called k).

    +to consider (often called k). For kknn, a value of 5 +is used if neighbors is not specified.

    weight_func @@ -199,9 +200,9 @@

    Arg

    Details

    -

    The model can be created using the fit() function using the +

    The model can be created using the fit() function using the following engines:

    Note

    @@ -221,12 +222,12 @@

    kknn (classification or regression)

     kknn::train.kknn(formula = missing_arg(), data = missing_arg(), 
    -    kmax = missing_arg())
    +    ks = 5)
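A sketch of how that default surfaces, assuming the template above: leaving `neighbors` unset now fills in `ks = 5` rather than a missing `kmax`:

```r
library(parsnip)

# The generated train.kknn() call should show ks = 5.
nearest_neighbor() %>%
  set_mode("regression") %>%
  set_engine("kknn") %>%
  translate()
```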
     

    See also

    - +

    Examples

    diff --git a/docs/dev/reference/null_model.html b/docs/dev/reference/null_model.html index a7e171c4d..53eb934a3 100644 --- a/docs/dev/reference/null_model.html +++ b/docs/dev/reference/null_model.html @@ -153,7 +153,7 @@

    Arg

    Details

    -

    The model can be created using the fit() function using the +

    The model can be created using the fit() function using the following engines:

    @@ -175,7 +175,7 @@

    See also

    -

    varying(), fit()

    +

    varying(), fit()

    Examples

    diff --git a/docs/dev/reference/nullmodel.html b/docs/dev/reference/nullmodel.html index b38efb183..ed7909a78 100644 --- a/docs/dev/reference/nullmodel.html +++ b/docs/dev/reference/nullmodel.html @@ -142,10 +142,10 @@

    Fit a simple, non-informative model

    nullmodel(x = NULL, y, ...) # S3 method for nullmodel -print(x, ...) +print(x, ...) # S3 method for nullmodel -predict(object, new_data = NULL, type = NULL, ...) +predict(object, new_data = NULL, type = NULL, ...)

    Arguments

    @@ -207,13 +207,13 @@

    Details

    Examples

    -outcome <- factor(sample(letters[1:2], +outcome <- factor(sample(letters[1:2], size = 100, - prob = c(.1, .9), + prob = c(.1, .9), replace = TRUE)) useless <- nullmodel(y = outcome) useless
    #> Null Regression Model -#> Predicted Value: b
    predict(useless, matrix(NA, nrow = 5))
    #> [1] b b b b b +#> Predicted Value: b
    predict(useless, matrix(NA, nrow = 5))
    #> [1] b b b b b #> Levels: a b
    diff --git a/docs/dev/reference/predict.model_fit.html b/docs/dev/reference/predict.model_fit.html index 6a6c0aa45..4c6df5f30 100644 --- a/docs/dev/reference/predict.model_fit.html +++ b/docs/dev/reference/predict.model_fit.html @@ -133,14 +133,14 @@

    Model predictions

Apply a model to create different types of predictions. -predict() can be used for all types of models and used the +predict() can be used for all types of models and uses the "type" argument for more specificity.

    # S3 method for model_fit
    -predict(object, new_data, type = NULL,
    -  opts = list(), ...)
+predict(object, new_data, type = NULL, +  opts = list(), ...)

    Arguments

    @@ -157,7 +157,7 @@

    Arg

    @@ -213,16 +213,16 @@

    Value

    appear in names and 2) vectors are never returned but type-specific prediction functions.

    When the model fit failed and the error was captured, the -predict() function will return the same structure as above but +predict() function will return the same structure as above but filled with missing values. This does not currently work for multivariate models.

    Details

    -

    If "type" is not supplied to predict(), then a choice +

    If "type" is not supplied to predict(), then a choice is made (type = "numeric" for regression models and type = "class" for classification).

    -

    predict() is designed to provide a tidy result (see "Value" +

    predict() is designed to provide a tidy result (see "Value" section below) in a tibble output format.

    When using type = "conf_int" and type = "pred_int", the options level and std_error can be used. The latter is a logical for an @@ -230,54 +230,54 @@

    Details

    Examples

    -
    library(dplyr) +
    library(dplyr) lm_model <- linear_reg() %>% set_engine("lm") %>% - fit(mpg ~ ., data = mtcars %>% slice(11:32)) + fit(mpg ~ ., data = mtcars %>% slice(11:32)) pred_cars <- mtcars %>% slice(1:10) %>% select(-mpg) -predict(lm_model, pred_cars)
    #> # A tibble: 10 x 1 +predict(lm_model, pred_cars)
    #> # A tibble: 10 x 1 #> .pred -#> <dbl> -#> 1 23.4 -#> 2 23.3 -#> 3 27.6 -#> 4 21.5 -#> 5 17.6 -#> 6 21.6 -#> 7 13.9 -#> 8 21.7 -#> 9 25.6 -#> 10 17.1
    -predict( +#> <dbl> +#> 1 23.4 +#> 2 23.3 +#> 3 27.6 +#> 4 21.5 +#> 5 17.6 +#> 6 21.6 +#> 7 13.9 +#> 8 21.7 +#> 9 25.6 +#> 10 17.1
    +predict( lm_model, pred_cars, type = "conf_int", level = 0.90 -)
    #> # A tibble: 10 x 2 +)
    #> # A tibble: 10 x 2 #> .pred_lower .pred_upper -#> <dbl> <dbl> -#> 1 17.9 29.0 -#> 2 18.1 28.5 -#> 3 24.0 31.3 -#> 4 17.5 25.6 -#> 5 14.3 20.8 -#> 6 17.0 26.2 -#> 7 9.65 18.2 -#> 8 16.2 27.2 -#> 9 14.2 37.0 -#> 10 11.5 22.7
    -predict( +#> <dbl> <dbl> +#> 1 17.9 29.0 +#> 2 18.1 28.5 +#> 3 24.0 31.3 +#> 4 17.5 25.6 +#> 5 14.3 20.8 +#> 6 17.0 26.2 +#> 7 9.65 18.2 +#> 8 16.2 27.2 +#> 9 14.2 37.0 +#> 10 11.5 22.7
    +predict( lm_model, pred_cars, type = "raw", - opts = list(type = "terms") + opts = list(type = "terms") )
    #> cyl disp hp drat wt qsec #> 1 -0.001433177 -0.8113275 0.6303467 -0.06120265 2.4139815 -1.567729 #> 2 -0.001433177 -0.8113275 0.6303467 -0.06120265 1.4488706 -0.736286 diff --git a/docs/dev/reference/rand_forest.html b/docs/dev/reference/rand_forest.html index 663b28935..1dea25281 100644 --- a/docs/dev/reference/rand_forest.html +++ b/docs/dev/reference/rand_forest.html @@ -157,7 +157,7 @@

    General Interface for Random Forest Models

    time that the model is fit. Other options and argument can be set using set_engine(). If left to their defaults here (NULL), the values are taken from the underlying model -functions. If parameters need to be modified, update() can be used +functions. If parameters need to be modified, update() can be used in lieu of recreating the object from scratch.

    @@ -166,7 +166,7 @@

    General Interface for Random Forest Models

    min_n = NULL) # S3 method for rand_forest -update(object, mtry = NULL, trees = NULL, +update(object, mtry = NULL, trees = NULL, min_n = NULL, fresh = FALSE, ...)

    Arguments

    @@ -204,15 +204,15 @@

    Arg

    - +
    type

    A single character value or NULL. Possible values are "numeric", "class", "prob", "conf_int", "pred_int", "quantile", -or "raw". When NULL, predict() will choose an appropriate value +or "raw". When NULL, predict() will choose an appropriate value based on the model's mode.

    ...

    Not used for update().

    Not used for update().

    Details

    -

    The model can be created using the fit() function using the +

    The model can be created using the fit() function using the following engines:

    @@ -220,7 +220,7 @@

    Note

    For models created using the spark engine, there are several differences to consider. First, only the formula -interface to via fit() is available; using fit_xy() will +interface to via fit() is available; using fit_xy() will generate an error. Second, the predictions will always be in a spark table format. The names will be the same as documented but without the dots. Third, there is no equivalent to factor @@ -274,7 +274,7 @@

    See also

    -

    varying(), fit()

    +

    varying(), fit()

    Examples

    @@ -293,12 +293,12 @@

    Examp #> Main Arguments: #> mtry = 10 #> min_n = 3 -#>
    update(model, mtry = 1)
    #> Random Forest Model Specification (unknown) +#>
    update(model, mtry = 1)
    #> Random Forest Model Specification (unknown) #> #> Main Arguments: #> mtry = 1 #> min_n = 3 -#>
    update(model, mtry = 1, fresh = TRUE)
    #> Random Forest Model Specification (unknown) +#>
    update(model, mtry = 1, fresh = TRUE)
    #> Random Forest Model Specification (unknown) #> #> Main Arguments: #> mtry = 1 diff --git a/docs/dev/reference/reexports.html b/docs/dev/reference/reexports.html index 699b80a7d..3f8af8092 100644 --- a/docs/dev/reference/reexports.html +++ b/docs/dev/reference/reexports.html @@ -139,7 +139,7 @@

    Objects exported from other packages

    These objects are imported from other packages. Follow the links below to see their documentation.

    -
    generics

    fit, fit_xy, varying_args

    +
    generics

    fit, fit_xy, varying_args

    magrittr

    %>%

    diff --git a/docs/dev/reference/surv_reg.html b/docs/dev/reference/surv_reg.html index 476112894..31b3d2a9a 100644 --- a/docs/dev/reference/surv_reg.html +++ b/docs/dev/reference/surv_reg.html @@ -159,7 +159,7 @@

    General Interface for Parametric Survival Models

    surv_reg(mode = "regression", dist = NULL)
     
     # S3 method for surv_reg
    -update(object, dist = NULL, fresh = FALSE, ...)
    +update(object, dist = NULL, fresh = FALSE, ...)

    Arguments

    @@ -185,7 +185,7 @@

    Arg

    - +
    ...

    Not used for update().

    Not used for update().

    @@ -195,15 +195,15 @@

    Details to determine the mode of the model. For surv_reg(),the mode will always be "regression".

    Since survival models typically involve censoring (and require the use of -survival::Surv() objects), the fit() function will require that the +survival::Surv() objects), the fit() function will require that the survival model be specified via the formula interface.

    Also, for the flexsurv::flexsurvfit engine, the typical strata function cannot be used. To achieve the same effect, the extra parameter roles can be used (as described above).

    For surv_reg(), the mode will always be "regression".

    -

    The model can be created using the fit() function using the +

    The model can be created using the fit() function using the following engines:

      -
    • R: "flexsurv", "survreg"

    • +
    • R: "flexsurv", "survival" (the default)

    Engine Details

    @@ -217,7 +217,7 @@

    survreg

    +

    survival

     survival::survreg(formula = missing_arg(), data = missing_arg(), 
         weights = missing_arg(), model = TRUE)
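Given the engine rename noted in this diff, a sketch of requesting the survival-package engine under its new name (the generated `survival::survreg()` call above is unchanged):

```r
library(parsnip)

# "survival" replaces the old engine name "survreg".
surv_reg(dist = "weibull") %>%
  set_engine("survival") %>%
  translate()
```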
    @@ -233,7 +233,7 @@ 

    R

    See also

    - +

    Examples

    @@ -249,7 +249,7 @@

    Examp #> #> Main Arguments: #> dist = weibull -#>

    update(model, dist = "lnorm")
    #> Parametric Survival Regression Model Specification (regression) +#>
    update(model, dist = "lnorm")
    #> Parametric Survival Regression Model Specification (regression) #> #> Main Arguments: #> dist = lnorm diff --git a/docs/dev/reference/svm_poly.html b/docs/dev/reference/svm_poly.html index 4b5ea1519..3d59e9275 100644 --- a/docs/dev/reference/svm_poly.html +++ b/docs/dev/reference/svm_poly.html @@ -159,7 +159,7 @@

    General interface for polynomial support vector machines

    time that the model is fit. Other options and argument can be set using set_engine(). If left to their defaults here (NULL), the values are taken from the underlying model -functions. If parameters need to be modified, update() can be used +functions. If parameters need to be modified, update() can be used in lieu of recreating the object from scratch.

    @@ -168,7 +168,7 @@

    General interface for polynomial support vector machines

    scale_factor = NULL, margin = NULL) # S3 method for svm_poly -update(object, cost = NULL, degree = NULL, +update(object, cost = NULL, degree = NULL, scale_factor = NULL, margin = NULL, fresh = FALSE, ...)

    Arguments

    @@ -209,15 +209,15 @@

    Arg ... -

    Not used for update().

    +

    Not used for update().

    Details

    -

    The model can be created using the fit() function using the +

    The model can be created using the fit() function using the following engines:

    Engine Details

    @@ -238,7 +238,7 @@

    See also

    -

    varying(), fit()

    +

    varying(), fit()

    Examples

    @@ -257,12 +257,12 @@

    Examp #> Main Arguments: #> cost = 10 #> scale_factor = 0.1 -#>
    update(model, cost = 1)
    #> Polynomial Support Vector Machine Specification (unknown) +#>
    update(model, cost = 1)
    #> Polynomial Support Vector Machine Specification (unknown) #> #> Main Arguments: #> cost = 1 #> scale_factor = 0.1 -#>
    update(model, cost = 1, fresh = TRUE)
    #> Polynomial Support Vector Machine Specification (unknown) +#>
    update(model, cost = 1, fresh = TRUE)
    #> Polynomial Support Vector Machine Specification (unknown) #> #> Main Arguments: #> cost = 1 diff --git a/docs/dev/reference/svm_rbf.html b/docs/dev/reference/svm_rbf.html index b1ae58eb4..7b4e682a4 100644 --- a/docs/dev/reference/svm_rbf.html +++ b/docs/dev/reference/svm_rbf.html @@ -159,7 +159,7 @@

    General interface for radial basis function support vector machines

    time that the model is fit. Other options and argument can be set using set_engine(). If left to their defaults here (NULL), the values are taken from the underlying model -functions. If parameters need to be modified, update() can be used +functions. If parameters need to be modified, update() can be used in lieu of recreating the object from scratch.

    @@ -168,7 +168,7 @@

    General interface for radial basis function support vector machines

    margin = NULL) # S3 method for svm_rbf -update(object, cost = NULL, rbf_sigma = NULL, +update(object, cost = NULL, rbf_sigma = NULL, margin = NULL, fresh = FALSE, ...)

    Arguments

    @@ -205,15 +205,15 @@

    Arg ... -

    Not used for update().

    +

    Not used for update().

    Details

    -

    The model can be created using the fit() function using the +

    The model can be created using the fit() function using the following engines:

    Engine Details

    @@ -234,7 +234,7 @@

    See also

    -

    varying(), fit()

    +

    varying(), fit()

    Examples

    @@ -253,12 +253,12 @@

    Examp #> Main Arguments: #> cost = 10 #> rbf_sigma = 0.1 -#>
    update(model, cost = 1)
    #> Radial Basis Function Support Vector Machine Specification (unknown) +#>
    update(model, cost = 1)
    #> Radial Basis Function Support Vector Machine Specification (unknown) #> #> Main Arguments: #> cost = 1 #> rbf_sigma = 0.1 -#>
    update(model, cost = 1, fresh = TRUE)
    #> Radial Basis Function Support Vector Machine Specification (unknown) +#>
    update(model, cost = 1, fresh = TRUE)
    #> Radial Basis Function Support Vector Machine Specification (unknown) #> #> Main Arguments: #> cost = 1 diff --git a/docs/dev/reference/translate.html b/docs/dev/reference/translate.html index 26329274a..0b00a588c 100644 --- a/docs/dev/reference/translate.html +++ b/docs/dev/reference/translate.html @@ -157,7 +157,7 @@

    Details

    translate() produces a template call that lacks the specific argument values (such as data, etc). These are filled in once -fit() is called with the specifics of the data for the model. +fit() is called with the specifics of the data for the model. The call may also include varying arguments if these are in the specification.
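Tying this back to the registration code earlier in the diff, a sketch of such a template for `svm_rbf()` with kernlab; the protected `x`/`y` arguments stay as `missing_arg()` placeholders until `fit()` supplies them, and `cost` resolves to kernlab's `C`:

```r
library(parsnip)

# Under the new code path the mode must be set before translation.
svm_rbf(cost = 10) %>%
  set_mode("classification") %>%
  set_engine("kernlab") %>%
  translate()
```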

    It does contain the resolved argument names that are specific to diff --git a/docs/dev/reference/varying_args.model_spec.html b/docs/dev/reference/varying_args.model_spec.html index 22a240d23..526546ac9 100644 --- a/docs/dev/reference/varying_args.model_spec.html +++ b/docs/dev/reference/varying_args.model_spec.html @@ -132,20 +132,20 @@

    Determine varying arguments

    -

    varying_args() takes a model specification or a recipe and returns a tibble +

    varying_args() takes a model specification or a recipe and returns a tibble of information on all possible varying arguments and whether or not they are actually varying.

    # S3 method for model_spec
    -varying_args(object, full = TRUE, ...)
    +varying_args(object, full = TRUE, ...)
     
     # S3 method for recipe
    -varying_args(object, full = TRUE, ...)
    +varying_args(object, full = TRUE, ...)
     
     # S3 method for step
    -varying_args(object, full = TRUE, ...)
    +varying_args(object, full = TRUE, ...)

    Arguments

    @@ -182,50 +182,50 @@

    Details

    Examples

    # List all possible varying args for the random forest spec -rand_forest() %>% varying_args()
    #> # A tibble: 3 x 4 +rand_forest() %>% varying_args()
    #> # A tibble: 3 x 4 #> name varying id type -#> <chr> <lgl> <chr> <chr> -#> 1 mtry FALSE rand_forest model_spec -#> 2 trees FALSE rand_forest model_spec -#> 3 min_n FALSE rand_forest model_spec
    +#> <chr> <lgl> <chr> <chr> +#> 1 mtry FALSE rand_forest model_spec +#> 2 trees FALSE rand_forest model_spec +#> 3 min_n FALSE rand_forest model_spec
    # mtry is now recognized as varying -rand_forest(mtry = varying()) %>% varying_args()
    #> # A tibble: 3 x 4 +rand_forest(mtry = varying()) %>% varying_args()
    #> # A tibble: 3 x 4 #> name varying id type -#> <chr> <lgl> <chr> <chr> -#> 1 mtry TRUE rand_forest model_spec -#> 2 trees FALSE rand_forest model_spec -#> 3 min_n FALSE rand_forest model_spec
    +#> <chr> <lgl> <chr> <chr> +#> 1 mtry TRUE rand_forest model_spec +#> 2 trees FALSE rand_forest model_spec +#> 3 min_n FALSE rand_forest model_spec
    # Even engine specific arguments can vary rand_forest() %>% set_engine("ranger", sample.fraction = varying()) %>% - varying_args()
    #> # A tibble: 4 x 4 + varying_args()
    #> # A tibble: 4 x 4 #> name varying id type -#> <chr> <lgl> <chr> <chr> -#> 1 mtry FALSE rand_forest model_spec -#> 2 trees FALSE rand_forest model_spec -#> 3 min_n FALSE rand_forest model_spec -#> 4 sample.fraction TRUE rand_forest model_spec
    +#> <chr> <lgl> <chr> <chr> +#> 1 mtry FALSE rand_forest model_spec +#> 2 trees FALSE rand_forest model_spec +#> 3 min_n FALSE rand_forest model_spec +#> 4 sample.fraction TRUE rand_forest model_spec
    # List only the arguments that actually vary rand_forest() %>% set_engine("ranger", sample.fraction = varying()) %>% - varying_args(full = FALSE)
    #> # A tibble: 1 x 4 + varying_args(full = FALSE)
    #> # A tibble: 1 x 4 #> name varying id type -#> <chr> <lgl> <chr> <chr> -#> 1 sample.fraction TRUE rand_forest model_spec
    +#> <chr> <lgl> <chr> <chr> +#> 1 sample.fraction TRUE rand_forest model_spec
    rand_forest() %>% set_engine( "randomForest", strata = Class, sampsize = varying() ) %>% - varying_args()
    #> # A tibble: 5 x 4 + varying_args()
    #> # A tibble: 5 x 4 #> name varying id type -#> <chr> <lgl> <chr> <chr> -#> 1 mtry FALSE rand_forest model_spec -#> 2 trees FALSE rand_forest model_spec -#> 3 min_n FALSE rand_forest model_spec -#> 4 strata FALSE rand_forest model_spec -#> 5 sampsize TRUE rand_forest model_spec
    +#> <chr> <lgl> <chr> <chr> +#> 1 mtry FALSE rand_forest model_spec +#> 2 trees FALSE rand_forest model_spec +#> 3 min_n FALSE rand_forest model_spec +#> 4 strata FALSE rand_forest model_spec +#> 5 sampsize TRUE rand_forest model_spec