diff --git a/.travis.yml b/.travis.yml index f7f839773..8cd134b71 100644 --- a/.travis.yml +++ b/.travis.yml @@ -4,48 +4,27 @@ language: r dist: trusty sudo: true +cran: https://cran.rstudio.com + # until generics is finalized warnings_are_errors: false r: - 3.2 - - oldrel - - release + - 3.3 + - 3.4 + - 3.5 + - 3.6 - devel matrix: allow_failures: - - r: 3.2 - -r_binary_packages: - - RCurl - - dplyr - - glue - - magrittr - - stringi - - stringr - - munsell - - rlang - - reshape2 - - scales - - tibble - - ggplot2 - - Rcpp - - RcppEigen - - BH - - glmnet - - earth - - sparklyr - - flexsurv - - ranger - - randomforest - - xgboost - - C50 - + - r: 3.3 # inum install failure (seg fault) + - r: 3.2 # partykit install failure (libcoin needs >= 3.4.0) + - r: 3.4 # mvtnorm requires >= 3.5.0 cache: - packages: true directories: - $HOME/.keras - $HOME/.cache/pip @@ -69,12 +48,12 @@ before_script: before_install: - - sudo apt-get -y install libnlopt-dev + - sudo apt-get -y install libnlopt-dev - sudo apt-get update - sudo apt-get -y install python3 - mkdir -p ~/.R && echo "CXX14=g++-6" > ~/.R/Makevars - echo "CXX14FLAGS += -fPIC" >> ~/.R/Makevars - + after_success: - Rscript -e 'covr::codecov()' diff --git a/DESCRIPTION b/DESCRIPTION index 8de5c838c..4d305e962 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,5 +1,5 @@ Package: parsnip -Version: 0.0.2 +Version: 0.0.2.9000 Title: A Common API to Modeling and Analysis Functions Description: A common interface is provided to allow users to specify a model without having to remember the different argument names across different functions or computational engines (e.g. 'R', 'Spark', 'Stan', etc). Authors@R: c( @@ -27,7 +27,8 @@ Imports: magrittr, stats, tidyr, - globals + globals, + vctrs Roxygen: list(markdown = TRUE) RoxygenNote: 6.1.1 Suggests: @@ -49,3 +50,4 @@ Suggests: rpart, MASS, nlme +Remotes: r-lib/vctrs diff --git a/NAMESPACE b/NAMESPACE index f24856ce0..a214f738d 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -45,6 +45,7 @@ S3method(translate,decision_tree) S3method(translate,default) S3method(translate,mars) S3method(translate,mlp) +S3method(translate,nearest_neighbor) S3method(translate,rand_forest) S3method(translate,surv_reg) S3method(translate,svm_poly) @@ -85,6 +86,11 @@ export(fit.model_spec) export(fit_control) export(fit_xy) export(fit_xy.model_spec) +export(get_dependency) +export(get_fit) +export(get_from_env) +export(get_model_env) +export(get_pred_type) export(keras_mlp) export(linear_reg) export(logistic_reg) @@ -97,13 +103,24 @@ export(multinom_reg) export(nearest_neighbor) export(null_model) export(nullmodel) +export(pred_value_template) export(predict.model_fit) export(rand_forest) export(rpart_train) export(set_args) +export(set_dependency) export(set_engine) +export(set_env_val) +export(set_fit) +export(set_in_env) export(set_mode) +export(set_model_arg) +export(set_model_engine) +export(set_model_mode) +export(set_new_model) +export(set_pred) export(show_call) +export(show_model_info) export(surv_reg) export(svm_poly) export(svm_rbf) @@ -180,7 +197,6 @@ importFrom(stats,na.omit) importFrom(stats,na.pass) importFrom(stats,predict) importFrom(stats,qnorm) -importFrom(stats,qt) importFrom(stats,quantile) importFrom(stats,setNames) importFrom(stats,terms) @@ -194,4 +210,4 @@ importFrom(utils,capture.output) importFrom(utils,getFromNamespace) importFrom(utils,globalVariables) importFrom(utils,head) -importFrom(utils,stack) +importFrom(vctrs,vec_unique) diff --git a/NEWS.md b/NEWS.md index df598bdee..b135f4295 100644 --- a/NEWS.md +++ 
b/NEWS.md
@@ -1,5 +1,11 @@
 # parsnip 0.0.2.9000
 
+## Breaking Changes
+
+ * The method by which `parsnip` stores the model information has changed. Any custom models from previous versions will need to use the new method for registering models. The methods are detailed in `?get_model_env()` and the [package vignette for adding models](https://tidymodels.github.io/parsnip/articles/articles/Scratch.html).
+ * The mode must now be declared prior to fitting and/or translation for models that can be used in more than one mode.
+ * For `surv_reg()`, the engine that uses the `survival` package is now called `survival` instead of `survreg`.
+
 ## New Features
 
 * `add_rowindex()` can add a column called `.row` to a data frame.
diff --git a/R/aaa.R b/R/aaa.R
index c8fd5a72e..230c2b08c 100644
--- a/R/aaa.R
+++ b/R/aaa.R
@@ -18,3 +18,10 @@ convert_stan_interval <- function(x, level = 0.95, lower = TRUE) {
   res
 }
+# ------------------------------------------------------------------------------
+
+#' @importFrom utils globalVariables
+utils::globalVariables(
+  c('.', '.label', '.pred', '.row', 'data', 'engine', 'engine2', 'group',
+    'lab', 'original', 'predicted_label', 'prediction', 'value', 'type')
+  )
diff --git a/R/aaa_models.R b/R/aaa_models.R
new file mode 100644
index 000000000..9b8781078
--- /dev/null
+++ b/R/aaa_models.R
@@ -0,0 +1,765 @@
+# Initialize model environments
+
+# ------------------------------------------------------------------------------
+
+## Rules about model-related information
+
+### Definitions:
+
+# - the model is the model type (e.g. "rand_forest", "linear_reg", etc)
+# - the model's mode is the species of model such as "classification" or "regression"
+# - the engines are within a model and mode and describe the method/implementation
+#   of the model in question. These are often R package names.
+
+### The package dependencies are model- and engine-specific. They are used across modes.
+
+### The `fit` information is a list of data that is needed to fit the model. This
+### information is specific to an engine and mode.
+
+### The `predict` information is also a list of data that is needed to make some sort
+### of prediction on the model object. The possible types are contained in `pred_types`
+### and this information is specific to the engine, mode, and type (a given type does
+### not span different modes).
+
+# ------------------------------------------------------------------------------
+
+
+parsnip <- rlang::new_environment()
+parsnip$models <- NULL
+parsnip$modes <- c("regression", "classification", "unknown")
+
+# ------------------------------------------------------------------------------
+
+pred_types <-
+  c("raw", "numeric", "class", "prob", "conf_int", "pred_int", "quantile")
+
+# ------------------------------------------------------------------------------
+
+#' Working with the parsnip model environment
+#'
+#' These functions read and write to the environment where the package stores
+#' information about model specifications.
+#'
+#' @param items A character string for an object in the model environment.
+#' @param ... Named values that will be assigned to the model environment.
+#' @param name A single character value for a new symbol in the model environment.
+#' @param value A value to assign to that symbol in the model environment.
+#' @keywords internal
+#' @references "Making a parsnip model from scratch"
+#' \url{https://tidymodels.github.io/parsnip/articles/articles/Scratch.html}
+#' @examples
+#' # Access the model data:
+#' current_code <- get_model_env()
+#' ls(envir = current_code)
+#'
+#' @export
+get_model_env <- function() {
+  current <- utils::getFromNamespace("parsnip", ns = "parsnip")
+  current
+}
+
+#' @rdname get_model_env
+#' @keywords internal
+#' @export
+get_from_env <- function(items) {
+  mod_env <- get_model_env()
+  rlang::env_get(mod_env, items)
+}
+
+#' @rdname get_model_env
+#' @keywords internal
+#' @export
+set_in_env <- function(...) {
+  mod_env <- get_model_env()
+  rlang::env_bind(mod_env, ...)
+}
+
+#' @rdname get_model_env
+#' @keywords internal
+#' @export
+set_env_val <- function(name, value) {
+  if (length(name) != 1 || !is.character(name)) {
+    stop("`name` should be a single character value.", call. = FALSE)
+  }
+  mod_env <- get_model_env()
+  x <- list(value)
+  names(x) <- name
+  rlang::env_bind(mod_env, !!!x)
+}
+
+# ------------------------------------------------------------------------------
+
+check_eng_val <- function(eng) {
+  if (rlang::is_missing(eng) || length(eng) != 1 || !is.character(eng))
+    stop("Please supply a character string for an engine name (e.g. `'lm'`)",
+         call. = FALSE)
+  invisible(NULL)
+}
+
+
+check_model_exists <- function(model) {
+  if (rlang::is_missing(model) || length(model) != 1 || !is.character(model)) {
+    stop("Please supply a character string for a model name (e.g. `'linear_reg'`)",
+         call. = FALSE)
+  }
+
+  current <- get_model_env()
+
+  if (!any(current$models == model)) {
+    stop("Model `", model, "` has not been registered.", call. = FALSE)
+  }
+
+  invisible(NULL)
+}
+
+check_model_doesnt_exist <- function(model) {
+  if (rlang::is_missing(model) || length(model) != 1 || !is.character(model)) {
+    stop("Please supply a character string for a model name (e.g. `'linear_reg'`)",
+         call. = FALSE)
+  }
+
+  current <- get_model_env()
+
+  if (any(current$models == model)) {
+    stop("Model `", model, "` already exists.", call. = FALSE)
+  }
+
+  invisible(NULL)
+}
+
+check_mode_val <- function(mode) {
+  if (rlang::is_missing(mode) || length(mode) != 1 || !is.character(mode))
+    stop("Please supply a character string for a mode (e.g. `'regression'`)",
+         call. = FALSE)
+  invisible(NULL)
+}
+
+check_engine_val <- function(eng) {
+  if (rlang::is_missing(eng) || length(eng) != 1 || !is.character(eng))
+    stop("Please supply a character string for an engine (e.g. `'lm'`)",
+         call. = FALSE)
+  invisible(NULL)
+}
+
+check_arg_val <- function(arg) {
+  if (rlang::is_missing(arg) || length(arg) != 1 || !is.character(arg))
+    stop("Please supply a character string for the argument.",
+         call. = FALSE)
+  invisible(NULL)
+}
+
+check_submodels_val <- function(has_submodel) {
+  if (!is.logical(has_submodel) || length(has_submodel) != 1) {
+    stop("The `submodels` argument should be a single logical.", call. = FALSE)
+  }
+  invisible(NULL)
+}
+
+check_func_val <- function(func) {
+  msg <-
+    paste(
+      "`func` should be a named vector with element 'fun' and the optional",
+      "element 'pkg'. These should both be single character strings."
+    )
+
+  if (rlang::is_missing(func) || !is.vector(func) || length(func) > 2)
+    stop(msg, call.
= FALSE)
+
+  nms <- sort(names(func))
+
+  if (all(is.null(nms))) {
+    stop(msg, call. = FALSE)
+  }
+
+  if (length(func) == 1) {
+    if (isTRUE(any(nms != "fun"))) {
+      stop(msg, call. = FALSE)
+    }
+  } else {
+    if (!isTRUE(all.equal(nms, c("fun", "pkg")))) {
+      stop(msg, call. = FALSE)
+    }
+  }
+
+  if (!all(purrr::map_lgl(func, is.character))) {
+    stop(msg, call. = FALSE)
+  }
+
+  invisible(NULL)
+}
+
+check_fit_info <- function(fit_obj) {
+  if (is.null(fit_obj)) {
+    stop("The `fit` module cannot be NULL.", call. = FALSE)
+  }
+  exp_nms <- c("defaults", "func", "interface", "protect")
+  if (!isTRUE(all.equal(sort(names(fit_obj)), exp_nms))) {
+    stop("The `fit` module should have elements: ",
+         paste0("`", exp_nms, "`", collapse = ", "),
+         call. = FALSE)
+  }
+
+  check_interface_val(fit_obj$interface)
+  check_func_val(fit_obj$func)
+
+  if (!is.list(fit_obj$defaults)) {
+    stop("The `defaults` element should be a list.", call. = FALSE)
+  }
+
+  invisible(NULL)
+}
+
+check_pred_info <- function(pred_obj, type) {
+  if (all(type != pred_types)) {
+    stop("The prediction type should be one of: ",
+         paste0("'", pred_types, "'", collapse = ", "),
+         call. = FALSE)
+  }
+
+  exp_nms <- c("args", "func", "post", "pre")
+  if (!isTRUE(all.equal(sort(names(pred_obj)), exp_nms))) {
+    stop("The `predict` module should have elements: ",
+         paste0("`", exp_nms, "`", collapse = ", "),
+         call. = FALSE)
+  }
+
+  if (!is.null(pred_obj$pre) && !is.function(pred_obj$pre)) {
+    stop("The `pre` module should be NULL or a function.",
+         call. = FALSE)
+  }
+  if (!is.null(pred_obj$post) && !is.function(pred_obj$post)) {
+    stop("The `post` module should be NULL or a function.",
+         call. = FALSE)
+  }
+
+  check_func_val(pred_obj$func)
+
+  if (!is.list(pred_obj$args)) {
+    stop("The `args` element should be a list.", call. = FALSE)
+  }
+
+  invisible(NULL)
+}
+
+check_pkg_val <- function(pkg) {
+  if (rlang::is_missing(pkg) || length(pkg) != 1 || !is.character(pkg))
+    stop("Please supply a single character value for the package name.",
+         call. = FALSE)
+  invisible(NULL)
+}
+
+check_interface_val <- function(x) {
+  exp_interf <- c("data.frame", "formula", "matrix")
+  if (length(x) != 1 || !(x %in% exp_interf)) {
+    stop("The `interface` element should have a single value of: ",
+         paste0("`", exp_interf, "`", collapse = ", "),
+         call. = FALSE)
+  }
+  invisible(NULL)
+}
+
+
+# ------------------------------------------------------------------------------
+
+#' Tools to Register Models
+#'
+#' These functions are similar to constructors and can be used to validate
+#' that there are no conflicts with the underlying model structures used by the
+#' package.
+#'
+#' @param model A single character string for the model type (e.g.
+#'  `"rand_forest"`, etc).
+#' @param mode A single character string for the model mode (e.g. "regression").
+#' @param eng A single character string for the model engine.
+#' @param arg A single character string for the model argument name.
+#' @param has_submodel A single logical for whether the argument
+#'  can make predictions on multiple submodels at once.
+#' @param func A named character vector that describes how to call
+#'  a function. `func` should have elements `pkg` and `fun`. The
+#'  former is optional but is recommended and the latter is
+#'  required. For example, `c(pkg = "stats", fun = "lm")` would be
+#'  used to invoke the usual linear regression function. In some
+#'  cases, it is helpful to use `c(fun = "predict")` when using a
+#'  package's `predict` method.
+#' @param fit_obj A list with elements `interface`, `protect`,
+#'  `func` and `defaults`. See the package vignette "Making a
+#'  `parsnip` model from scratch".
+#' @param pred_obj A list with elements `pre`, `post`, `func`, and `args`.
+#'  See the package vignette "Making a `parsnip` model from scratch".
+#' @param type A single character value for the type of prediction. Possible
+#'  values are: `class`, `conf_int`, `numeric`, `pred_int`, `prob`, `quantile`,
+#'  and `raw`.
+#' @param pkg An optional character string for a package name.
+#' @param parsnip A single character string for the "harmonized" argument name
+#'  that `parsnip` exposes.
+#' @param original A single character string for the argument name that the
+#'  underlying model function uses.
+#' @param value A list that conforms to the `fit_obj` or `pred_obj` description
+#'  above, depending on context.
+#' @param pre,post Optional functions for pre- and post-processing of prediction
+#'  results.
+#' @param ... Optional arguments that should be passed into the `args` slot for
+#'  prediction objects.
+#' @details These functions are available for users to add their
+#'  own models or engines (in a package or otherwise) so that they can
+#'  be accessed using `parsnip`. These are more thoroughly documented
+#'  on the package web site (see references below).
+#'
+#'  In short, `parsnip` stores an environment object that contains
+#'  all of the information and code about how models are used (e.g.
+#'  fitting, predicting, etc). These functions can be used to add
+#'  models to that environment as well as helper functions that can
+#'  be used to make sure that the model data is in the right
+#'  format.
+#'
+#'  `check_model_exists()` checks the model value and ensures that the model has
+#'  already been registered. `check_model_doesnt_exist()` checks the model value
+#'  and also checks to see if it is novel in the environment.
+#'
+#' @references "Making a parsnip model from scratch"
+#'  \url{https://tidymodels.github.io/parsnip/articles/articles/Scratch.html}
+#' @examples
+#' # set_new_model("shallow_learning_model")
+#'
+#' # Show the information about a model:
+#' show_model_info("rand_forest")
+#' @keywords internal
+#' @export
+set_new_model <- function(model) {
+  check_model_doesnt_exist(model)
+
+  current <- get_model_env()
+
+  set_env_val("models", c(current$models, model))
+  set_env_val(model, dplyr::tibble(engine = character(0), mode = character(0)))
+  set_env_val(
+    paste0(model, "_pkgs"),
+    dplyr::tibble(engine = character(0), pkg = list())
+  )
+  set_env_val(paste0(model, "_modes"), "unknown")
+  set_env_val(
+    paste0(model, "_args"),
+    dplyr::tibble(
+      engine = character(0),
+      parsnip = character(0),
+      original = character(0),
+      func = list(),
+      has_submodel = logical(0)
+    )
+  )
+  set_env_val(
+    paste0(model, "_fit"),
+    dplyr::tibble(
+      engine = character(0),
+      mode = character(0),
+      value = list()
+    )
+  )
+  set_env_val(
+    paste0(model, "_predict"),
+    dplyr::tibble(
+      engine = character(0),
+      mode = character(0),
+      type = character(0),
+      value = list()
+    )
+  )
+
+  invisible(NULL)
+}
+
+# ------------------------------------------------------------------------------
+
+#' @rdname set_new_model
+#' @keywords internal
+#' @export
+set_model_mode <- function(model, mode) {
+  check_model_exists(model)
+  check_mode_val(mode)
+
+  current <- get_model_env()
+
+  if (!any(current$modes == mode)) {
+    current$modes <- unique(c(current$modes, mode))
+  }
+
+  set_env_val(
+    paste0(model, "_modes"),
+    unique(c(get_from_env(paste0(model, "_modes")), mode))
+  )
+  invisible(NULL)
+}
+
+# ------------------------------------------------------------------------------
+
+#' @rdname
set_new_model
+#' @keywords internal
+#' @export
+set_model_engine <- function(model, mode, eng) {
+  check_model_exists(model)
+  check_mode_val(mode)
+  check_eng_val(eng)
+
+  current <- get_model_env()
+
+  new_eng <- dplyr::tibble(engine = eng, mode = mode)
+  old_eng <- get_from_env(model)
+
+  engs <-
+    old_eng %>%
+    dplyr::bind_rows(new_eng) %>%
+    dplyr::distinct()
+
+  set_env_val(model, engs)
+  set_model_mode(model, mode)
+  invisible(NULL)
+}
+
+
+# ------------------------------------------------------------------------------
+#' @importFrom vctrs vec_unique
+#' @rdname set_new_model
+#' @keywords internal
+#' @export
+set_model_arg <- function(model, eng, parsnip, original, func, has_submodel) {
+  check_model_exists(model)
+  check_eng_val(eng)
+  check_arg_val(parsnip)
+  check_arg_val(original)
+  check_func_val(func)
+  check_submodels_val(has_submodel)
+
+  current <- get_model_env()
+  old_args <- get_from_env(paste0(model, "_args"))
+
+  new_arg <-
+    dplyr::tibble(
+      engine = eng,
+      parsnip = parsnip,
+      original = original,
+      func = list(func),
+      has_submodel = has_submodel
+    )
+
+  updated <- try(dplyr::bind_rows(old_args, new_arg), silent = TRUE)
+  if (inherits(updated, "try-error")) {
+    stop("An error occurred when adding the new argument.", call. = FALSE)
+  }
+
+  updated <- vctrs::vec_unique(updated)
+  set_env_val(paste0(model, "_args"), updated)
+
+  invisible(NULL)
+}
+
+
+# ------------------------------------------------------------------------------
+
+#' @rdname set_new_model
+#' @keywords internal
+#' @export
+set_dependency <- function(model, eng, pkg) {
+  check_model_exists(model)
+  check_eng_val(eng)
+  check_pkg_val(pkg)
+
+  current <- get_model_env()
+  model_info <- get_from_env(model)
+  pkg_info <- get_from_env(paste0(model, "_pkgs"))
+
+  has_engine <-
+    model_info %>%
+    dplyr::distinct(engine) %>%
+    dplyr::filter(engine == eng) %>%
+    nrow()
+  if (has_engine != 1) {
+    stop("The engine '", eng, "' has not been registered for model '",
+         model, "'.", call. = FALSE)
+  }
+
+  existing_pkgs <-
+    pkg_info %>%
+    dplyr::filter(engine == eng)
+
+  if (nrow(existing_pkgs) == 0) {
+    pkg_info <-
+      pkg_info %>%
+      dplyr::bind_rows(tibble(engine = eng, pkg = list(pkg)))
+
+  } else {
+    existing_pkgs$pkg[[1]] <- c(pkg, existing_pkgs$pkg[[1]])
+    pkg_info <-
+      pkg_info %>%
+      dplyr::filter(engine != eng) %>%
+      dplyr::bind_rows(existing_pkgs)
+  }
+
+  set_env_val(paste0(model, "_pkgs"), pkg_info)
+
+  invisible(NULL)
+}
+
+#' @rdname set_new_model
+#' @keywords internal
+#' @export
+get_dependency <- function(model) {
+  check_model_exists(model)
+  pkg_name <- paste0(model, "_pkgs")
+  if (!any(pkg_name == rlang::env_names(get_model_env()))) {
+    stop("`", model, "` does not have a dependency list in parsnip.", call. = FALSE)
+  }
+  rlang::env_get(get_model_env(), pkg_name)
+}
+
+
+# ------------------------------------------------------------------------------
+
+#' @rdname set_new_model
+#' @keywords internal
+#' @export
+set_fit <- function(model, mode, eng, value) {
+  check_model_exists(model)
+  check_eng_val(eng)
+  check_mode_val(mode)
+  check_engine_val(eng)
+  check_fit_info(value)
+
+  current <- get_model_env()
+  model_info <- get_from_env(model)
+  old_fits <- get_from_env(paste0(model, "_fit"))
+
+  has_engine <-
+    model_info %>%
+    dplyr::filter(engine == eng & mode == !!mode) %>%
+    nrow()
+  if (has_engine != 1) {
+    stop("The combination of engine '", eng, "' and mode '",
+         mode, "' has not been registered for model '",
+         model, "'.", call.
= FALSE)
+  }
+
+  has_fit <-
+    old_fits %>%
+    dplyr::filter(engine == eng & mode == !!mode) %>%
+    nrow()
+
+  if (has_fit > 0) {
+    stop("The combination of engine '", eng, "' and mode '",
+         mode, "' already has a fit component for model '",
+         model, "'.", call. = FALSE)
+  }
+
+  new_fit <-
+    dplyr::tibble(
+      engine = eng,
+      mode = mode,
+      value = list(value)
+    )
+
+  updated <- try(dplyr::bind_rows(old_fits, new_fit), silent = TRUE)
+  if (inherits(updated, "try-error")) {
+    stop("An error occurred when adding the new fit module.", call. = FALSE)
+  }
+
+  set_env_val(
+    paste0(model, "_fit"),
+    updated
+  )
+
+  invisible(NULL)
+}
+
+#' @rdname set_new_model
+#' @keywords internal
+#' @export
+get_fit <- function(model) {
+  check_model_exists(model)
+  fit_name <- paste0(model, "_fit")
+  if (!any(fit_name == rlang::env_names(get_model_env()))) {
+    stop("`", model, "` does not have a `fit` method in parsnip.", call. = FALSE)
+  }
+  rlang::env_get(get_model_env(), fit_name)
+}
+
+# ------------------------------------------------------------------------------
+
+#' @rdname set_new_model
+#' @keywords internal
+#' @export
+set_pred <- function(model, mode, eng, type, value) {
+  check_model_exists(model)
+  check_eng_val(eng)
+  check_mode_val(mode)
+  check_engine_val(eng)
+  check_pred_info(value, type)
+
+  current <- get_model_env()
+  model_info <- get_from_env(model)
+  old_fits <- get_from_env(paste0(model, "_predict"))
+
+  has_engine <-
+    model_info %>%
+    dplyr::filter(engine == eng & mode == !!mode) %>%
+    nrow()
+  if (has_engine != 1) {
+    stop("The combination of engine '", eng, "' and mode '",
+         mode, "' has not been registered for model '",
+         model, "'.", call. = FALSE)
+  }
+
+  has_pred <-
+    old_fits %>%
+    dplyr::filter(engine == eng & mode == !!mode & type == !!type) %>%
+    nrow()
+  if (has_pred > 0) {
+    stop("The combination of engine '", eng, "', mode '",
+         mode, "', and type '", type,
+         "' already has a prediction component for model '",
+         model, "'.", call. = FALSE)
+  }
+
+  new_fit <-
+    dplyr::tibble(
+      engine = eng,
+      mode = mode,
+      type = type,
+      value = list(value)
+    )
+
+  updated <- try(dplyr::bind_rows(old_fits, new_fit), silent = TRUE)
+  if (inherits(updated, "try-error")) {
+    stop("An error occurred when adding the new prediction module.", call. = FALSE)
+  }
+
+  set_env_val(paste0(model, "_predict"), updated)
+
+  invisible(NULL)
+}
+
+#' @rdname set_new_model
+#' @keywords internal
+#' @export
+get_pred_type <- function(model, type) {
+  check_model_exists(model)
+  pred_name <- paste0(model, "_predict")
+  if (!any(pred_name == rlang::env_names(get_model_env()))) {
+    stop("`", model, "` does not have any `pred` methods in parsnip.", call. = FALSE)
+  }
+  all_preds <- rlang::env_get(get_model_env(), pred_name)
+  if (!any(all_preds$type == type)) {
+    stop("`", model, "` does not have any `", type,
+         "` prediction methods in parsnip.", call.
= FALSE)
+  }
+  dplyr::filter(all_preds, type == !!type)
+}
+
+# ------------------------------------------------------------------------------
+
+#' @rdname set_new_model
+#' @keywords internal
+#' @export
+show_model_info <- function(model) {
+  check_model_exists(model)
+  current <- get_model_env()
+
+  cat("Information for `", model, "`\n", sep = "")
+
+  cat(
+    " modes:",
+    paste0(get_from_env(paste0(model, "_modes")), collapse = ", "),
+    "\n\n"
+  )
+
+  engines <- get_from_env(model)
+  if (nrow(engines) > 0) {
+    cat(" engines: \n")
+    engines %>%
+      dplyr::mutate(
+        mode = format(paste0(mode, ": "))
+      ) %>%
+      dplyr::group_by(mode) %>%
+      dplyr::summarize(
+        engine = paste0(sort(engine), collapse = ", ")
+      ) %>%
+      dplyr::mutate(
+        lab = paste0("   ", mode, engine, "\n")
+      ) %>%
+      dplyr::ungroup() %>%
+      dplyr::pull(lab) %>%
+      cat(sep = "")
+    cat("\n")
+  } else {
+    cat(" no registered engines.\n\n")
+  }
+
+  args <- get_from_env(paste0(model, "_args"))
+  if (nrow(args) > 0) {
+    cat(" arguments: \n")
+    args %>%
+      dplyr::select(engine, parsnip, original) %>%
+      dplyr::distinct() %>%
+      dplyr::mutate(
+        engine = format(paste0("   ", engine, ": ")),
+        parsnip = paste0("      ", format(parsnip), " --> ", original, "\n")
+      ) %>%
+      dplyr::group_by(engine) %>%
+      dplyr::mutate(
+        engine2 = ifelse(dplyr::row_number() == 1, engine, ""),
+        parsnip = ifelse(dplyr::row_number() == 1, paste0("\n", parsnip), parsnip),
+        lab = paste0(engine2, parsnip)
+      ) %>%
+      dplyr::ungroup() %>%
+      dplyr::pull(lab) %>%
+      cat(sep = "")
+    cat("\n")
+  } else {
+    cat(" no registered arguments.\n\n")
+  }
+
+  fits <- get_from_env(paste0(model, "_fit"))
+  if (nrow(fits) > 0) {
+    cat(" fit modules:\n")
+    fits %>%
+      dplyr::select(-value) %>%
+      dplyr::mutate(engine = paste0("  ", engine)) %>%
+      as.data.frame() %>%
+      print(row.names = FALSE)
+    cat("\n")
+  } else {
+    cat(" no registered fit modules.\n\n")
+  }
+
+  preds <- get_from_env(paste0(model, "_predict"))
+  if (nrow(preds) > 0) {
+    cat(" prediction modules:\n")
+    preds %>%
+      dplyr::group_by(mode, engine) %>%
+      dplyr::summarize(methods = paste0(sort(type), collapse = ", ")) %>%
+      dplyr::ungroup() %>%
+      dplyr::mutate(mode = paste0("  ", mode)) %>%
+      as.data.frame() %>%
+      print(row.names = FALSE)
+    cat("\n")
+  } else {
+    cat(" no registered prediction modules.\n\n")
+  }
+  invisible(NULL)
+}
+
+# ------------------------------------------------------------------------------
+
+#' @rdname set_new_model
+#' @keywords internal
+#' @export
+pred_value_template <- function(pre = NULL, post = NULL, func, ...) {
+  if (rlang::is_missing(func)) {
+    stop("Please supply a value to `func`. See `?set_pred`.", call. = FALSE)
+  }
+  list(pre = pre, post = post, func = func, args = list(...))
+}
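+
+# ------------------------------------------------------------------------------
+
+# A minimal end-to-end sketch of the registration workflow above, assuming a
+# hypothetical `my_model` with a single lm-style engine. It is kept commented
+# out (registering a throwaway model at load time would pollute the model
+# environment) and only illustrates the calling order of the functions in
+# this file:
+#
+# set_new_model("my_model")
+# set_model_mode("my_model", "regression")
+# set_model_engine("my_model", "regression", "lm")
+# set_dependency("my_model", "lm", "stats")
+# set_fit(
+#   model = "my_model",
+#   mode = "regression",
+#   eng = "lm",
+#   value = list(
+#     interface = "formula",
+#     protect = c("formula", "data"),
+#     func = c(pkg = "stats", fun = "lm"),
+#     defaults = list()
+#   )
+# )
+# set_pred(
+#   model = "my_model",
+#   mode = "regression",
+#   eng = "lm",
+#   type = "numeric",
+#   value = pred_value_template(
+#     func = c(fun = "predict"),
+#     object = quote(object$fit),
+#     newdata = quote(new_data)
+#   )
+# )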
diff --git a/R/aaa_spark_helpers.R b/R/aaa_spark_helpers.R
index 4d233a760..3eba67416 100644
--- a/R/aaa_spark_helpers.R
+++ b/R/aaa_spark_helpers.R
@@ -20,6 +20,3 @@ format_spark_num <- function(results, object) {
   results <- dplyr::rename(results, pred = prediction)
   results
 }
-
-#' @importFrom utils globalVariables
-utils::globalVariables(c(".", "predicted_label", "prediction"))
diff --git a/R/arguments.R b/R/arguments.R
index 1ece4725e..39246d762 100644
--- a/R/arguments.R
+++ b/R/arguments.R
@@ -7,49 +7,6 @@ null_value <- function(x) {
   res
 }
 
-deharmonize <- function(args, key, engine) {
-  nms <- names(args)
-  orig_names <- key[nms, engine]
-  names(args) <- orig_names
-  args[!is.na(orig_names)]
-}
-
-parse_engine_options <- function(x) {
-  res <- ll()
-  if (length(x) >= 2) { # in case of NULL
-
-    arg_names <- names(x[[2]])
-    arg_names <- arg_names[arg_names != ""]
-
-    if (length(arg_names) > 0) {
-      # in case of list()
-      res <- ll()
-      for (i in arg_names) {
-        res[[i]] <- x[[2]][[i]]
-      } # over arg_names
-    } # length == 0
-  }
-  res
-}
-
-prune_arg_list <- function(x, whitelist = NULL, modified = character(0)) {
-  nms <- names(x)
-  if (length(whitelist) > 0)
-    nms <- nms[!(nms %in% whitelist)]
-  for (i in nms) {
-    if (
-      is.null(x[[i]]) |
-      is_null(x[[i]]) |
-      !(i %in% modified) |
-      is_missing(x[[i]])
-    )
-      x[[i]] <- NULL
-  }
-  if (any(names(x) == "..."))
-    x["..."] <- NULL
-  x
-}
-
 check_eng_args <- function(args, obj, core_args) {
   # Make sure that we are not trying to modify an argument that
   # is explicitly protected in the method metadata or arg_key
diff --git a/R/boost_tree.R b/R/boost_tree.R
index 620648fa0..61c16c87c 100644
--- a/R/boost_tree.R
+++ b/R/boost_tree.R
@@ -261,6 +261,7 @@ check_args.boost_tree <- function(object) {
 #' @param subsample Subsampling proportion of rows.
 #' @param ... Other options to pass to `xgb.train`.
 #' @return A fitted `xgboost` object.
+#' @keywords internal
 #' @export
 xgb_train <- function(
   x, y,
@@ -386,15 +387,15 @@ xgb_by_tree <- function(tree, object, new_data, type, ...) {
   pred <- xgb_pred(object$fit, newdata = new_data, ntreelimit = tree)
 
   # switch based on prediction type
-  if(object$spec$mode == "regression") {
+  if (object$spec$mode == "regression") {
     pred <- tibble(.pred = pred)
     nms <- names(pred)
   } else {
     if (type == "class") {
-      pred <- boost_tree_xgboost_data$class$post(pred, object)
+      pred <- object$spec$method$pred$class$post(pred, object)
       pred <- tibble(.pred = factor(pred, levels = object$lvl))
     } else {
-      pred <- boost_tree_xgboost_data$classprob$post(pred, object)
+      pred <- object$spec$method$pred$prob$post(pred, object)
       pred <- as_tibble(pred)
       names(pred) <- paste0(".pred_", names(pred))
     }
@@ -432,6 +433,7 @@ xgb_by_tree <- function(tree, object, new_data, type, ...) {
 #'  model in the printed output.
 #' @param ... Other arguments to pass.
 #' @return A fitted C5.0 model.
+#' @keywords internal
 #' @export
 C5.0_train <- function(x, y, weights = NULL, trials = 15, minCases = 2, sample = 0, ...) {
@@ -501,8 +503,3 @@ C50_by_tree <- function(tree, object, new_data, type, ...)
{ pred[, c(".row", "trees", nms)] } - -# ------------------------------------------------------------------------------ - -#' @importFrom utils globalVariables -utils::globalVariables(c(".row")) diff --git a/R/boost_tree_data.R b/R/boost_tree_data.R index 559698511..98356f96b 100644 --- a/R/boost_tree_data.R +++ b/R/boost_tree_data.R @@ -1,175 +1,386 @@ +set_new_model("boost_tree") -boost_tree_arg_key <- data.frame( - xgboost = c("max_depth", "nrounds", "eta", "colsample_bytree", "min_child_weight", "gamma", "subsample"), - C5.0 = c( NA, "trials", NA, NA, "minCases", NA, "sample"), - spark = c("max_depth", "max_iter", "step_size", "feature_subset_strategy", "min_instances_per_node", "min_info_gain", "subsampling_rate"), - stringsAsFactors = FALSE, - row.names = c("tree_depth", "trees", "learn_rate", "mtry", "min_n", "loss_reduction", "sample_size") +set_model_mode("boost_tree", "classification") +set_model_mode("boost_tree", "regression") + +# ------------------------------------------------------------------------------ + +set_model_engine("boost_tree", "classification", "xgboost") +set_model_engine("boost_tree", "regression", "xgboost") +set_dependency("boost_tree", "xgboost", "xgboost") + +set_model_arg( + model = "boost_tree", + eng = "xgboost", + parsnip = "tree_depth", + original = "max_depth", + func = list(pkg = "dials", fun = "tree_depth"), + has_submodel = FALSE +) +set_model_arg( + model = "boost_tree", + eng = "xgboost", + parsnip = "trees", + original = "nrounds", + func = list(pkg = "dials", fun = "trees"), + has_submodel = TRUE +) +set_model_arg( + model = "boost_tree", + eng = "xgboost", + parsnip = "learn_rate", + original = "eta", + func = list(pkg = "dials", fun = "learn_rate"), + has_submodel = FALSE +) +set_model_arg( + model = "boost_tree", + eng = "xgboost", + parsnip = "mtry", + original = "colsample_bytree", + func = list(pkg = "dials", fun = "mtry"), + has_submodel = FALSE +) +set_model_arg( + model = "boost_tree", + eng = "xgboost", + parsnip = "min_n", + original = "min_child_weight", + func = list(pkg = "dials", fun = "min_n"), + has_submodel = FALSE +) +set_model_arg( + model = "boost_tree", + eng = "xgboost", + parsnip = "loss_reduction", + original = "gamma", + func = list(pkg = "dials", fun = "loss_reduction"), + has_submodel = FALSE +) +set_model_arg( + model = "boost_tree", + eng = "xgboost", + parsnip = "sample_size", + original = "subsample", + func = list(pkg = "dials", fun = "sample_size"), + has_submodel = FALSE +) + +set_fit( + model = "boost_tree", + eng = "xgboost", + mode = "regression", + value = list( + interface = "matrix", + protect = c("x", "y"), + func = c(pkg = "parsnip", fun = "xgb_train"), + defaults = list(nthread = 1, verbose = 0) + ) +) + +set_pred( + model = "boost_tree", + eng = "xgboost", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "xgb_pred"), + args = list(object = quote(object$fit), newdata = quote(new_data)) + ) +) + +set_pred( + model = "boost_tree", + eng = "xgboost", + mode = "regression", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "xgb_pred"), + args = list(object = quote(object$fit), newdata = quote(new_data)) + ) ) -boost_tree_modes <- c("classification", "regression", "unknown") +set_fit( + model = "boost_tree", + eng = "xgboost", + mode = "classification", + value = list( + interface = "matrix", + protect = c("x", "y"), + func = c(pkg = "parsnip", fun = "xgb_train"), + defaults = list(nthread = 1, verbose = 0) + ) +) + 
+set_pred( + model = "boost_tree", + eng = "xgboost", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = function(x, object) { + if (is.vector(x)) { + x <- ifelse(x >= 0.5, object$lvl[2], object$lvl[1]) + } else { + x <- object$lvl[apply(x, 1, which.max)] + } + x + }, + func = c(pkg = NULL, fun = "xgb_pred"), + args = list(object = quote(object$fit), newdata = quote(new_data)) + ) +) + +set_pred( + model = "boost_tree", + eng = "xgboost", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = function(x, object) { + if (is.vector(x)) { + x <- tibble(v1 = 1 - x, v2 = x) + } else { + x <- as_tibble(x) + } + colnames(x) <- object$lvl + x + }, + func = c(pkg = NULL, fun = "xgb_pred"), + args = list(object = quote(object$fit), newdata = quote(new_data)) + ) +) -boost_tree_engines <- data.frame( - xgboost = rep(TRUE, 3), - C5.0 = c( TRUE, FALSE, TRUE), - spark = rep(TRUE, 3), - row.names = c("classification", "regression", "unknown") +set_pred( + model = "boost_tree", + eng = "xgboost", + mode = "classification", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "xgb_pred"), + args = list(object = quote(object$fit), newdata = quote(new_data)) + ) ) # ------------------------------------------------------------------------------ -boost_tree_xgboost_data <- - list( - libs = "xgboost", - fit = list( - interface = "matrix", - protect = c("x", "y"), - func = c(pkg = "parsnip", fun = "xgb_train"), - defaults = - list( - nthread = 1, - verbose = 0 - ) - ), - numeric = list( - pre = NULL, - post = NULL, - func = c(fun = "xgb_pred"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ), - class = list( - pre = NULL, - post = function(x, object) { - if (is.vector(x)) { - x <- ifelse(x >= 0.5, object$lvl[2], object$lvl[1]) - } else { - x <- object$lvl[apply(x, 1, which.max)] - } - x - }, - func = c(pkg = NULL, fun = "xgb_pred"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ), - classprob = list( - pre = NULL, - post = function(x, object) { - if (is.vector(x)) { - x <- tibble(v1 = 1 - x, v2 = x) - } else { - x <- as_tibble(x) - } - colnames(x) <- object$lvl - x - }, - func = c(pkg = NULL, fun = "xgb_pred"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ), - raw = list( - pre = NULL, - func = c(fun = "xgb_pred"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ) - ) - - -boost_tree_C5.0_data <- - list( - libs = "C50", - fit = list( - interface = "data.frame", - protect = c("x", "y", "weights"), - func = c(pkg = "parsnip", fun = "C5.0_train"), - defaults = list() - ), - class = list( - pre = NULL, - post = NULL, - func = c(fun = "predict"), - args = list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ), - classprob = list( - pre = NULL, - post = function(x, object) { - as_tibble(x) - }, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "prob" - ) - ), - raw = list( - pre = NULL, - func = c(fun = "predict"), - args = list( +set_model_engine("boost_tree", "classification", "C5.0") +set_dependency("boost_tree", "C5.0", "C50") + +set_model_arg( + model = "boost_tree", + eng = "C5.0", + parsnip = "trees", + original = "trials", + func = list(pkg = "dials", fun = "trees"), + has_submodel = TRUE +) +set_model_arg( + model = "boost_tree", + eng = "C5.0", + parsnip = "min_n", + original = "minCases", + func = list(pkg 
= "dials", fun = "min_n"), + has_submodel = FALSE +) +set_model_arg( + model = "boost_tree", + eng = "C5.0", + parsnip = "sample_size", + original = "sample", + func = list(pkg = "dials", fun = "sample_size"), + has_submodel = FALSE +) + +set_fit( + model = "boost_tree", + eng = "C5.0", + mode = "classification", + value = list( + interface = "data.frame", + protect = c("x", "y", "weights"), + func = c(pkg = "parsnip", fun = "C5.0_train"), + defaults = list() + ) +) + +set_pred( + model = "boost_tree", + eng = "C5.0", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = list(object = quote(object$fit), newdata = quote(new_data)) + ) +) + +set_pred( + model = "boost_tree", + eng = "C5.0", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = function(x, object) { + as_tibble(x) + }, + func = c(fun = "predict"), + args = + list( object = quote(object$fit), - newdata = quote(new_data) + newdata = quote(new_data), + type = "prob" ) - ) - ) - - -boost_tree_spark_data <- - list( - libs = "sparklyr", - fit = list( - interface = "formula", - protect = c("x", "formula", "type"), - func = c(pkg = "sparklyr", fun = "ml_gradient_boosted_trees"), - defaults = - list( - seed = expr(sample.int(10^5, 1)) - ) - ), - numeric = list( - pre = NULL, - post = format_spark_num, - func = c(pkg = "sparklyr", fun = "ml_predict"), - args = - list( - x = quote(object$fit), - dataset = quote(new_data) - ) - ), - class = list( - pre = NULL, - post = format_spark_class, - func = c(pkg = "sparklyr", fun = "ml_predict"), - args = - list( - x = quote(object$fit), - dataset = quote(new_data) - ) - ), - classprob = list( - pre = NULL, - post = format_spark_probs, - func = c(pkg = "sparklyr", fun = "ml_predict"), - args = - list( - x = quote(object$fit), - dataset = quote(new_data) - ) - ) ) +) + +set_pred( + model = "boost_tree", + eng = "C5.0", + mode = "classification", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = list(object = quote(object$fit), + newdata = quote(new_data)) + ) +) + +# ------------------------------------------------------------------------------ + +set_model_engine("boost_tree", "classification", "spark") +set_model_engine("boost_tree", "regression", "spark") +set_dependency("boost_tree", "spark", "sparklyr") + +set_model_arg( + model = "boost_tree", + eng = "spark", + parsnip = "tree_depth", + original = "max_depth", + func = list(pkg = "dials", fun = "tree_depth"), + has_submodel = FALSE +) +set_model_arg( + model = "boost_tree", + eng = "spark", + parsnip = "trees", + original = "max_iter", + func = list(pkg = "dials", fun = "trees"), + has_submodel = TRUE +) +set_model_arg( + model = "boost_tree", + eng = "spark", + parsnip = "learn_rate", + original = "step_size", + func = list(pkg = "dials", fun = "learn_rate"), + has_submodel = FALSE +) +set_model_arg( + model = "boost_tree", + eng = "spark", + parsnip = "mtry", + original = "feature_subset_strategy", + func = list(pkg = "dials", fun = "mtry"), + has_submodel = FALSE +) +set_model_arg( + model = "boost_tree", + eng = "spark", + parsnip = "min_n", + original = "min_instances_per_node", + func = list(pkg = "dials", fun = "min_n"), + has_submodel = FALSE +) +set_model_arg( + model = "boost_tree", + eng = "spark", + parsnip = "min_info_gain", + original = "gamma", + func = list(pkg = "dials", fun = "loss_reduction"), + has_submodel = FALSE +) +set_model_arg( + model = "boost_tree", + eng = 
"spark", + parsnip = "sample_size", + original = "subsampling_rate", + func = list(pkg = "dials", fun = "sample_size"), + has_submodel = FALSE +) + +set_fit( + model = "boost_tree", + eng = "spark", + mode = "regression", + value = list( + interface = "formula", + protect = c("x", "formula", "type"), + func = c(pkg = "sparklyr", fun = "ml_gradient_boosted_trees"), + defaults = list(seed = expr(sample.int(10 ^ 5, 1))) + ) +) + +set_fit( + model = "boost_tree", + eng = "spark", + mode = "classification", + value = list( + interface = "formula", + protect = c("x", "formula", "type"), + func = c(pkg = "sparklyr", fun = "ml_gradient_boosted_trees"), + defaults = list(seed = expr(sample.int(10 ^ 5, 1))) + ) +) + +set_pred( + model = "boost_tree", + eng = "spark", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = format_spark_num, + func = c(pkg = "sparklyr", fun = "ml_predict"), + args = list(x = quote(object$fit), dataset = quote(new_data)) + ) +) + +set_pred( + model = "boost_tree", + eng = "spark", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = format_spark_class, + func = c(pkg = "sparklyr", fun = "ml_predict"), + args = list(x = quote(object$fit), dataset = quote(new_data)) + ) +) + +set_pred( + model = "boost_tree", + eng = "spark", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = format_spark_probs, + func = c(pkg = "sparklyr", fun = "ml_predict"), + args = list(x = quote(object$fit), dataset = quote(new_data)) + ) +) diff --git a/R/decision_tree.R b/R/decision_tree.R index b767b965d..afa807405 100644 --- a/R/decision_tree.R +++ b/R/decision_tree.R @@ -252,6 +252,7 @@ check_args.decision_tree <- function(object) { #' those cases. #' @param ... Other arguments to pass to either `rpart` or `rpart.control`. #' @return A fitted rpart model. +#' @keywords internal #' @export rpart_train <- function(formula, data, weights = NULL, cp = 0.01, minsplit = 20, maxdepth = 30, ...) 
{ diff --git a/R/decision_tree_data.R b/R/decision_tree_data.R index b4a599063..a8e8016e4 100644 --- a/R/decision_tree_data.R +++ b/R/decision_tree_data.R @@ -1,160 +1,297 @@ +set_new_model("decision_tree") -decision_tree_arg_key <- data.frame( - rpart = c( "maxdepth", "minsplit", "cp"), - C5.0 = c( NA, "minCases", NA), - spark = c("max_depth", "min_instances_per_node", NA), - stringsAsFactors = FALSE, - row.names = c("tree_depth", "min_n", "cost_complexity") +set_model_mode("decision_tree", "classification") +set_model_mode("decision_tree", "regression") + +# ------------------------------------------------------------------------------ + +set_model_engine("decision_tree", "classification", "rpart") +set_model_engine("decision_tree", "regression", "rpart") +set_dependency("decision_tree", "rpart", "rpart") + +set_model_arg( + model = "decision_tree", + eng = "rpart", + parsnip = "tree_depth", + original = "maxdepth", + func = list(pkg = "dials", fun = "tree_depth"), + has_submodel = FALSE ) -decision_tree_modes <- c("classification", "regression", "unknown") +set_model_arg( + model = "decision_tree", + eng = "rpart", + parsnip = "min_n", + original = "minsplit", + func = list(pkg = "dials", fun = "min_n"), + has_submodel = FALSE +) -decision_tree_engines <- data.frame( - rpart = rep(TRUE, 3), - C5.0 = c(TRUE, FALSE, TRUE), - spark = rep(TRUE, 3), - row.names = c("classification", "regression", "unknown") +set_model_arg( + model = "decision_tree", + eng = "rpart", + parsnip = "cost_complexity", + original = "cp", + func = list(pkg = "dials", fun = "cost_complexity"), + has_submodel = FALSE ) -# ------------------------------------------------------------------------------ +set_fit( + model = "decision_tree", + eng = "rpart", + mode = "regression", + value = list( + interface = "formula", + protect = c("formula", "data", "weights"), + func = c(pkg = "rpart", fun = "rpart"), + defaults = list() + ) +) + +set_fit( + model = "decision_tree", + eng = "rpart", + mode = "classification", + value = list( + interface = "formula", + protect = c("formula", "data", "weights"), + func = c(pkg = "rpart", fun = "rpart"), + defaults = list() + ) +) + +set_pred( + model = "decision_tree", + eng = "rpart", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = list(object = quote(object$fit), newdata = quote(new_data)) + ) +) + +set_pred( + model = "decision_tree", + eng = "rpart", + mode = "regression", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = list(object = quote(object$fit), newdata = quote(new_data)) + ) +) -decision_tree_rpart_data <- - list( - libs = "rpart", - fit = list( - interface = "formula", - protect = c("formula", "data", "weights"), - func = c(pkg = "rpart", fun = "rpart"), - defaults = - list() - ), - numeric = list( - pre = NULL, - post = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ), - class = list( - pre = NULL, - post = NULL, - func = c(pkg = NULL, fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "class" - ) - ), - classprob = list( - pre = NULL, - post = function(x, object) { - as_tibble(x) - }, - func = c(pkg = NULL, fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ), - raw = list( - pre = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = 
quote(new_data)
-      )
-    )
-  )
-
-decision_tree_C5.0_data <-
-  list(
-    libs = "C50",
-    fit = list(
-      interface = "data.frame",
-      protect = c("x", "y", "weights"),
-      func = c(pkg = "parsnip", fun = "C5.0_train"),
-      defaults = list(trials = 1)
-    ),
-    class = list(
-      pre = NULL,
-      post = NULL,
-      func = c(fun = "predict"),
-      args = list(
+set_pred(
+  model = "decision_tree",
+  eng = "rpart",
+  mode = "classification",
+  type = "class",
+  value = list(
+    pre = NULL,
+    post = NULL,
+    func = c(pkg = NULL, fun = "predict"),
+    args =
+      list(
         object = quote(object$fit),
-        newdata = quote(new_data)
+        newdata = quote(new_data),
+        type = "class"
       )
-    ),
-    classprob = list(
-      pre = NULL,
-      post = function(x, object) {
-        as_tibble(x)
-      },
-      func = c(fun = "predict"),
-      args =
-        list(
-          object = quote(object$fit),
-          newdata = quote(new_data),
-          type = "prob"
-        )
-    ),
-    raw = list(
-      pre = NULL,
-      func = c(fun = "predict"),
-      args = list(
+  )
+)
+
+set_pred(
+  model = "decision_tree",
+  eng = "rpart",
+  mode = "classification",
+  type = "prob",
+  value = list(
+    pre = NULL,
+    post = function(x, object) {
+      as_tibble(x)
+    },
+    func = c(pkg = NULL, fun = "predict"),
+    args = list(object = quote(object$fit), newdata = quote(new_data))
+  )
+)
+
+set_pred(
+  model = "decision_tree",
+  eng = "rpart",
+  mode = "classification",
+  type = "raw",
+  value = list(
+    pre = NULL,
+    post = NULL,
+    func = c(fun = "predict"),
+    args = list(object = quote(object$fit), newdata = quote(new_data))
+  )
+)
+
+# ------------------------------------------------------------------------------
+
+set_model_engine("decision_tree", "classification", "C5.0")
+set_dependency("decision_tree", "C5.0", "C50")
+
+set_model_arg(
+  model = "decision_tree",
+  eng = "C5.0",
+  parsnip = "min_n",
+  original = "minCases",
+  func = list(pkg = "dials", fun = "min_n"),
+  has_submodel = FALSE
+)
+
+set_fit(
+  model = "decision_tree",
+  eng = "C5.0",
+  mode = "classification",
+  value = list(
+    interface = "data.frame",
+    protect = c("x", "y", "weights"),
+    func = c(pkg = "parsnip", fun = "C5.0_train"),
+    defaults = list(trials = 1)
+  )
+)
+
+set_pred(
+  model = "decision_tree",
+  eng = "C5.0",
+  mode = "classification",
+  type = "class",
+  value = list(
+    pre = NULL,
+    post = NULL,
+    func = c(fun = "predict"),
+    args = list(object = quote(object$fit), newdata = quote(new_data))
+  )
+)
+
+
+set_pred(
+  model = "decision_tree",
+  eng = "C5.0",
+  mode = "classification",
+  type = "prob",
+  value = list(
+    pre = NULL,
+    post = function(x, object) {
+      as_tibble(x)
+    },
+    func = c(fun = "predict"),
+    args =
+      list(
         object = quote(object$fit),
-        newdata = quote(new_data)
+        newdata = quote(new_data),
+        type = "prob"
       )
-    )
-  )
-
-
-decision_tree_spark_data <-
-  list(
-    libs = "sparklyr",
-    fit = list(
-      interface = "formula",
-      protect = c("x", "formula"),
-      func = c(pkg = "sparklyr", fun = "ml_decision_tree_classifier"),
-      defaults =
-        list(
-          seed = expr(sample.int(10^5, 1))
-        )
-    ),
-    numeric = list(
-      pre = NULL,
-      post = format_spark_num,
-      func = c(pkg = "sparklyr", fun = "ml_predict"),
-      args =
-        list(
-          x = quote(object$fit),
-          dataset = quote(new_data)
-        )
-    ),
-    class = list(
-      pre = NULL,
-      post = format_spark_class,
-      func = c(pkg = "sparklyr", fun = "ml_predict"),
-      args =
-        list(
-          x = quote(object$fit),
-          dataset = quote(new_data)
-        )
-    ),
-    classprob = list(
-      pre = NULL,
-      post = format_spark_probs,
-      func = c(pkg = "sparklyr", fun = "ml_predict"),
-      args =
-        list(
-          x = quote(object$fit),
-          dataset = quote(new_data)
-        )
-    )
-  )
+  )
+)
+
+
+set_pred(
+ model = "decision_tree", + eng = "C5.0", + mode = "classification", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = list(object = quote(object$fit), + newdata = quote(new_data)) + ) +) + +# ------------------------------------------------------------------------------ + +set_model_engine("decision_tree", "classification", "spark") +set_model_engine("decision_tree", "regression", "spark") +set_dependency("decision_tree", "spark", "spark") + +set_model_arg( + model = "decision_tree", + eng = "spark", + parsnip = "tree_depth", + original = "max_depth", + func = list(pkg = "dials", fun = "tree_depth"), + has_submodel = FALSE +) + +set_model_arg( + model = "decision_tree", + eng = "spark", + parsnip = "min_n", + original = "min_instances_per_node", + func = list(pkg = "dials", fun = "min_n"), + has_submodel = FALSE +) + +set_fit( + model = "decision_tree", + eng = "spark", + mode = "regression", + value = list( + interface = "formula", + protect = c("x", "formula"), + func = c(pkg = "sparklyr", fun = "ml_decision_tree_classifier"), + defaults = + list(seed = expr(sample.int(10 ^ 5, 1))) + ) +) + +set_fit( + model = "decision_tree", + eng = "spark", + mode = "classification", + value = list( + interface = "formula", + protect = c("x", "formula"), + func = c(pkg = "sparklyr", fun = "ml_decision_tree_classifier"), + defaults = + list(seed = expr(sample.int(10 ^ 5, 1))) + ) +) + +set_pred( + model = "decision_tree", + eng = "spark", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = format_spark_num, + func = c(pkg = "sparklyr", fun = "ml_predict"), + args = list(object = quote(object$fit), dataset = quote(new_data)) + ) +) + +set_pred( + model = "decision_tree", + eng = "spark", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = format_spark_class, + func = c(pkg = "sparklyr", fun = "ml_predict"), + args = list(object = quote(object$fit), dataset = quote(new_data)) + ) +) + +set_pred( + model = "decision_tree", + eng = "spark", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = format_spark_probs, + func = c(pkg = "sparklyr", fun = "ml_predict"), + args = list(object = quote(object$fit), dataset = quote(new_data)) + ) +) diff --git a/R/engines.R b/R/engines.R index a1b32de5f..327f69dc6 100644 --- a/R/engines.R +++ b/R/engines.R @@ -1,23 +1,13 @@ -get_model_info <- function (x, engine) { - cls <- specific_model(x) - nm <- paste(cls, engine, "data", sep = "_") - res <- try(get(nm), silent = TRUE) - if (inherits(res, "try-error")) - stop("Can't find model object ", nm) - res -} - specific_model <- function(x) { cls <- class(x) cls[cls != "model_spec"] } - possible_engines <- function(object, ...) 
{ - cls <- specific_model(object) - key_df <- get(paste(cls, "engines", sep = "_")) - colnames(key_df[object$mode, , drop = FALSE]) + m_env <- get_model_env() + engs <- rlang::env_get(m_env, specific_model(object)) + unique(engs$engine) } check_engine <- function(object) { diff --git a/R/fit.R b/R/fit.R index e5b745b8f..1cb32f300 100644 --- a/R/fit.R +++ b/R/fit.R @@ -52,8 +52,6 @@ #' #' lr_mod <- logistic_reg() #' -#' lr_mod <- logistic_reg() -#' #' using_formula <- #' lr_mod %>% #' set_engine("glm") %>% @@ -125,7 +123,7 @@ fit.model_spec <- ) # populate `method` with the details for this model type - object <- get_method(object, engine = object$engine) + object <- add_methods(object, engine = object$engine) check_installs(object) @@ -215,7 +213,7 @@ fit_xy.model_spec <- ) # populate `method` with the details for this model type - object <- get_method(object, engine = object$engine) + object <- add_methods(object, engine = object$engine) check_installs(object) @@ -239,7 +237,7 @@ fit_xy.model_spec <- ... ), - data.frame_data.frame =, matrix_data.frame = + data.frame_data.frame = , matrix_data.frame = xy_xy( object = object, env = eval_env, @@ -249,7 +247,7 @@ fit_xy.model_spec <- ), # heterogenous combinations - matrix_formula =, data.frame_formula = + matrix_formula = , data.frame_formula = xy_form( object = object, env = eval_env, @@ -367,7 +365,7 @@ check_xy_interface <- function(x, y, cl, model) { print.model_fit <- function(x, ...) { cat("parsnip model object\n\n") - if(inherits(x$fit, "try-error")) { + if (inherits(x$fit, "try-error")) { cat("Model fit failed with error:\n", x$fit, "\n") } else { print(x$fit, ...) diff --git a/R/fit_helpers.R b/R/fit_helpers.R index 74f614ede..940daebe7 100644 --- a/R/fit_helpers.R +++ b/R/fit_helpers.R @@ -181,9 +181,3 @@ xy_form <- function(object, env, control, ...) 
{ res } -# ------------------------------------------------------------------------------ -## - - -#' @importFrom utils globalVariables -utils::globalVariables("data") diff --git a/R/linear_reg_data.R b/R/linear_reg_data.R index 6225d1023..54dbd3338 100644 --- a/R/linear_reg_data.R +++ b/R/linear_reg_data.R @@ -1,265 +1,369 @@ +set_new_model("linear_reg") -linear_reg_arg_key <- data.frame( - lm = c( NA, NA), - glmnet = c( "lambda", "alpha"), - spark = c("reg_param", "elastic_net_param"), - stan = c( NA, NA), - keras = c( "decay", NA), - stringsAsFactors = FALSE, - row.names = c("penalty", "mixture") +set_model_mode("linear_reg", "regression") + +# ------------------------------------------------------------------------------ + +set_model_engine("linear_reg", "regression", "lm") +set_dependency("linear_reg", "lm", "stats") + +set_fit( + model = "linear_reg", + eng = "lm", + mode = "regression", + value = list( + interface = "formula", + protect = c("formula", "data", "weights"), + func = c(pkg = "stats", fun = "lm"), + defaults = list() + ) ) -linear_reg_modes <- "regression" +set_pred( + model = "linear_reg", + eng = "lm", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = expr(object$fit), + newdata = expr(new_data), + type = "response" + ) + ) +) -linear_reg_engines <- data.frame( - lm = TRUE, - glmnet = TRUE, - spark = TRUE, - stan = TRUE, - keras = TRUE, - row.names = c("regression") +set_pred( + model = "linear_reg", + eng = "lm", + mode = "regression", + type = "conf_int", + value = list( + pre = NULL, + post = function(results, object) { + tibble::as_tibble(results) %>% + dplyr::select(-fit) %>% + setNames(c(".pred_lower", ".pred_upper")) + }, + func = c(fun = "predict"), + args = + list( + object = expr(object$fit), + newdata = expr(new_data), + interval = "confidence", + level = expr(level), + type = "response" + ) + ) +) +set_pred( + model = "linear_reg", + eng = "lm", + mode = "regression", + type = "pred_int", + value = list( + pre = NULL, + post = function(results, object) { + tibble::as_tibble(results) %>% + dplyr::select(-fit) %>% + setNames(c(".pred_lower", ".pred_upper")) + }, + func = c(fun = "predict"), + args = + list( + object = expr(object$fit), + newdata = expr(new_data), + interval = "prediction", + level = expr(level), + type = "response" + ) + ) ) +set_pred( + model = "linear_reg", + eng = "lm", + mode = "regression", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = list(object = expr(object$fit), newdata = expr(new_data)) + ) +) # ------------------------------------------------------------------------------ -linear_reg_lm_data <- - list( - libs = "stats", - fit = list( - interface = "formula", - protect = c("formula", "data", "weights"), - func = c(pkg = "stats", fun = "lm"), - defaults = list() - ), - numeric = list( - pre = NULL, - post = NULL, - func = c(fun = "predict"), - args = - list( - object = expr(object$fit), - newdata = expr(new_data), - type = "response" - ) - ), - confint = list( - pre = NULL, - post = function(results, object) { - tibble::as_tibble(results) %>% - dplyr::select(-fit) %>% - setNames(c(".pred_lower", ".pred_upper")) - }, - func = c(fun = "predict"), - args = - list( - object = expr(object$fit), - newdata = expr(new_data), - interval = "confidence", - level = expr(level), - type = "response" - ) - ), - predint = list( - pre = NULL, - post = function(results, object) { - 
tibble::as_tibble(results) %>% - dplyr::select(-fit) %>% - setNames(c(".pred_lower", ".pred_upper")) - }, - func = c(fun = "predict"), - args = - list( - object = expr(object$fit), - newdata = expr(new_data), - interval = "prediction", - level = expr(level), - type = "response" - ) - ), - raw = list( - pre = NULL, - func = c(fun = "predict"), - args = - list( - object = expr(object$fit), - newdata = expr(new_data) - ) - ) +set_model_engine("linear_reg", "regression", "glmnet") +set_dependency("linear_reg", "glmnet", "glmnet") + +set_model_arg( + model = "linear_reg", + eng = "glmnet", + parsnip = "penalty", + original = "lambda", + func = list(pkg = "dials", fun = "penalty"), + has_submodel = TRUE +) + +set_model_arg( + model = "linear_reg", + eng = "glmnet", + parsnip = "mixture", + original = "alpha", + func = list(pkg = "dials", fun = "mixture"), + has_submodel = FALSE +) + +set_fit( + model = "linear_reg", + eng = "glmnet", + mode = "regression", + value = list( + interface = "matrix", + protect = c("x", "y", "weights"), + func = c(pkg = "glmnet", fun = "glmnet"), + defaults = list(family = "gaussian") ) +) -# Note: For glmnet, you will need to make model-specific predict methods. -# See linear_reg.R -linear_reg_glmnet_data <- - list( - libs = "glmnet", - fit = list( - interface = "matrix", - protect = c("x", "y", "weights"), - func = c(pkg = "glmnet", fun = "glmnet"), - defaults = - list( - family = "gaussian" - ) - ), - numeric = list( - pre = NULL, - post = organize_glmnet_pred, - func = c(fun = "predict"), - args = - list( - object = expr(object$fit), - newx = expr(as.matrix(new_data)), - type = "response", - s = expr(object$spec$args$penalty) - ) - ), - raw = list( - pre = NULL, - func = c(fun = "predict"), - args = - list( - object = expr(object$fit), - newx = expr(as.matrix(new_data)) - ) - ) +set_pred( + model = "linear_reg", + eng = "glmnet", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = organize_glmnet_pred, + func = c(fun = "predict"), + args = + list( + object = expr(object$fit), + newx = expr(as.matrix(new_data)), + type = "response", + s = expr(object$spec$args$penalty) + ) ) +) -linear_reg_stan_data <- - list( - libs = "rstanarm", - fit = list( - interface = "formula", - protect = c("formula", "data", "weights"), - func = c(pkg = "rstanarm", fun = "stan_glm"), - defaults = - list( - family = expr(stats::gaussian) - ) - ), - numeric = list( - pre = NULL, - post = NULL, - func = c(fun = "predict"), - args = - list( - object = expr(object$fit), - newdata = expr(new_data) - ) - ), - confint = list( - pre = NULL, - post = function(results, object) { - res <- - tibble( - .pred_lower = - convert_stan_interval( - results, - level = object$spec$method$confint$extras$level - ), - .pred_upper = - convert_stan_interval( - results, - level = object$spec$method$confint$extras$level, - lower = FALSE - ), - ) - if(object$spec$method$confint$extras$std_error) - res$.std_error <- apply(results, 2, sd, na.rm = TRUE) - res - }, - func = c(pkg = "rstanarm", fun = "posterior_linpred"), - args = - list( - object = expr(object$fit), - newdata = expr(new_data), - transform = TRUE, - seed = expr(sample.int(10^5, 1)) - ) - ), - predint = list( - pre = NULL, - post = function(results, object) { - res <- - tibble( - .pred_lower = - convert_stan_interval( - results, - level = object$spec$method$predint$extras$level - ), - .pred_upper = - convert_stan_interval( - results, - level = object$spec$method$predint$extras$level, - lower = FALSE - ), - ) - 
if(object$spec$method$predint$extras$std_error) - res$.std_error <- apply(results, 2, sd, na.rm = TRUE) - res - }, - func = c(pkg = "rstanarm", fun = "posterior_predict"), - args = - list( - object = expr(object$fit), - newdata = expr(new_data), - seed = expr(sample.int(10^5, 1)) - ) - ), - raw = list( - pre = NULL, - func = c(fun = "predict"), - args = - list( - object = expr(object$fit), - newdata = expr(new_data) - ) - ) +set_pred( + model = "linear_reg", + eng = "glmnet", + mode = "regression", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list(object = expr(object$fit), + newx = expr(as.matrix(new_data))) + ) +) + +# ------------------------------------------------------------------------------ + +set_model_engine("linear_reg", "regression", "stan") +set_dependency("linear_reg", "stan", "rstanarm") + +set_fit( + model = "linear_reg", + eng = "stan", + mode = "regression", + value = list( + interface = "formula", + protect = c("formula", "data", "weights"), + func = c(pkg = "rstanarm", fun = "stan_glm"), + defaults = list(family = expr(stats::gaussian)) ) +) -#' @importFrom dplyr select rename -linear_reg_spark_data <- - list( - libs = "sparklyr", - fit = list( - interface = "formula", - protect = c("x", "formula", "weight_col"), - func = c(pkg = "sparklyr", fun = "ml_linear_regression") - ), - numeric = list( - pre = NULL, - post = function(results, object) { - results <- dplyr::rename(results, pred = prediction) - results <- dplyr::select(results, pred) - results - }, - func = c(pkg = "sparklyr", fun = "ml_predict"), - args = - list( - x = expr(object$fit), - dataset = expr(new_data) +set_pred( + model = "linear_reg", + eng = "stan", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = list(object = expr(object$fit), newdata = expr(new_data)) + ) +) + +set_pred( + model = "linear_reg", + eng = "stan", + mode = "regression", + type = "conf_int", + value = list( + pre = NULL, + post = function(results, object) { + res <- + tibble( + .pred_lower = + convert_stan_interval( + results, + level = object$spec$method$pred$conf_int$extras$level + ), + .pred_upper = + convert_stan_interval( + results, + level = object$spec$method$pred$conf_int$extras$level, + lower = FALSE + ), ) - ) + if (object$spec$method$pred$conf_int$extras$std_error) + res$.std_error <- apply(results, 2, sd, na.rm = TRUE) + res + }, + func = c(pkg = "rstanarm", fun = "posterior_linpred"), + args = + list( + object = expr(object$fit), + newdata = expr(new_data), + transform = TRUE, + seed = expr(sample.int(10^5, 1)) + ) ) +) -linear_reg_keras_data <- - list( - libs = c("keras", "magrittr"), - fit = list( - interface = "matrix", - protect = c("x", "y"), - func = c(pkg = "parsnip", fun = "keras_mlp"), - defaults = list(hidden_units = 1, act = "linear") - ), - numeric = list( - pre = NULL, - post = maybe_multivariate, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - x = quote(as.matrix(new_data)) +set_pred( + model = "linear_reg", + eng = "stan", + mode = "regression", + type = "pred_int", + value = list( + pre = NULL, + post = function(results, object) { + res <- + tibble( + .pred_lower = + convert_stan_interval( + results, + level = object$spec$method$pred$pred_int$extras$level + ), + .pred_upper = + convert_stan_interval( + results, + level = object$spec$method$pred$pred_int$extras$level, + lower = FALSE + ), ) - ) + if 
(object$spec$method$pred$pred_int$extras$std_error) + res$.std_error <- apply(results, 2, sd, na.rm = TRUE) + res + }, + func = c(pkg = "rstanarm", fun = "posterior_predict"), + args = + list( + object = expr(object$fit), + newdata = expr(new_data), + seed = expr(sample.int(10^5, 1)) + ) + ) +) + +set_pred( + model = "linear_reg", + eng = "stan", + mode = "regression", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = list(object = expr(object$fit), newdata = expr(new_data)) ) +) + +# ------------------------------------------------------------------------------ + +set_model_engine("linear_reg", "regression", "spark") +set_dependency("linear_reg", "spark", "sparklyr") + +set_model_arg( + model = "linear_reg", + eng = "spark", + parsnip = "penalty", + original = "reg_param", + func = list(pkg = "dials", fun = "penalty"), + has_submodel = TRUE +) + +set_model_arg( + model = "linear_reg", + eng = "spark", + parsnip = "mixture", + original = "elastic_net_param", + func = list(pkg = "dials", fun = "mixture"), + has_submodel = FALSE +) + + +set_fit( + model = "linear_reg", + eng = "spark", + mode = "regression", + value = list( + interface = "formula", + protect = c("x", "formula", "weight_col"), + func = c(pkg = "sparklyr", fun = "ml_linear_regression"), + defaults = list() + ) +) + +set_pred( + model = "linear_reg", + eng = "spark", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = function(results, object) { + results <- dplyr::rename(results, pred = prediction) + results <- dplyr::select(results, pred) + results + }, + func = c(pkg = "sparklyr", fun = "ml_predict"), + args = list(x = expr(object$fit), dataset = expr(new_data)) + ) +) + +# ------------------------------------------------------------------------------ + + +set_model_engine("linear_reg", "regression", "keras") +set_dependency("linear_reg", "keras", "keras") +set_dependency("linear_reg", "keras", "magrittr") + +set_fit( + model = "linear_reg", + eng = "keras", + mode = "regression", + value = list( + interface = "matrix", + protect = c("x", "y"), + func = c(pkg = "parsnip", fun = "keras_mlp"), + defaults = list(hidden_units = 1, act = "linear") + ) +) + +set_pred( + model = "linear_reg", + eng = "keras", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = maybe_multivariate, + func = c(fun = "predict"), + args = list(object = quote(object$fit), x = quote(as.matrix(new_data))) + ) +) diff --git a/R/logistic_reg.R b/R/logistic_reg.R index 566a0c4ad..1bc9062b6 100644 --- a/R/logistic_reg.R +++ b/R/logistic_reg.R @@ -355,8 +355,3 @@ predict_raw._lognet <- function (object, new_data, opts = list(), ...) { predict_raw.model_fit(object, new_data = new_data, opts = opts, ...) 
} - -# ------------------------------------------------------------------------------ - -#' @importFrom utils globalVariables -utils::globalVariables(c("group", ".pred")) diff --git a/R/logistic_reg_data.R b/R/logistic_reg_data.R index 6fa1e0c94..d04f0100e 100644 --- a/R/logistic_reg_data.R +++ b/R/logistic_reg_data.R @@ -1,354 +1,514 @@ +set_new_model("logistic_reg") -logistic_reg_arg_key <- data.frame( - glm = c( NA, NA), - glmnet = c( "lambda", "alpha"), - spark = c("reg_param", "elastic_net_param"), - stan = c( NA, NA), - keras = c( "decay", NA), - stringsAsFactors = FALSE, - row.names = c("penalty", "mixture") +set_model_mode("logistic_reg", "classification") + +# ------------------------------------------------------------------------------ + +set_model_engine("logistic_reg", "classification", "glm") +set_dependency("logistic_reg", "glm", "stats") + +set_fit( + model = "logistic_reg", + eng = "glm", + mode = "classification", + value = list( + interface = "formula", + protect = c("formula", "data", "weights"), + func = c(pkg = "stats", fun = "glm"), + defaults = list(family = expr(stats::binomial)) + ) ) -logistic_reg_modes <- "classification" +set_pred( + model = "logistic_reg", + eng = "glm", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = prob_to_class_2, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "response" + ) + ) +) -logistic_reg_engines <- data.frame( - glm = TRUE, - glmnet = TRUE, - spark = TRUE, - stan = TRUE, - keras = TRUE, - row.names = c("classification") +set_pred( + model = "logistic_reg", + eng = "glm", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = function(x, object) { + x <- tibble(v1 = 1 - x, v2 = x) + colnames(x) <- object$lvl + x + }, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "response" + ) + ) ) -# ------------------------------------------------------------------------------ +set_pred( + model = "logistic_reg", + eng = "glm", + mode = "classification", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data) + ) + ) +) -#' @importFrom stats qt -logistic_reg_glm_data <- - list( - libs = "stats", - fit = list( - interface = "formula", - protect = c("formula", "data", "weights"), - func = c(pkg = "stats", fun = "glm"), - defaults = - list( - family = expr(stats::binomial) - ) - ), - class = list( - pre = NULL, - post = prob_to_class_2, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "response" +set_pred( + model = "logistic_reg", + eng = "glm", + mode = "classification", + type = "conf_int", + value = list( + pre = NULL, + post = function(results, object) { + hf_lvl <- (1 - object$spec$method$pred$conf_int$extras$level)/2 + const <- + qt(hf_lvl, df = object$fit$df.residual, lower.tail = FALSE) + trans <- object$fit$family$linkinv + res_2 <- + tibble( + lo = trans(results$fit - const * results$se.fit), + hi = trans(results$fit + const * results$se.fit) ) - ), - classprob = list( - pre = NULL, - post = function(x, object) { - x <- tibble(v1 = 1 - x, v2 = x) - colnames(x) <- object$lvl - x - }, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "response" - ) - ), - raw = list( - pre = NULL, - func = c(fun = "predict"), - 
args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ), - confint = list( - pre = NULL, - post = function(results, object) { - hf_lvl <- (1 - object$spec$method$confint$extras$level)/2 - const <- - qt(hf_lvl, df = object$fit$df.residual, lower.tail = FALSE) - trans <- object$fit$family$linkinv - res_2 <- - tibble( - lo = trans(results$fit - const * results$se.fit), - hi = trans(results$fit + const * results$se.fit) - ) - res_1 <- res_2 - res_1$lo <- 1 - res_2$hi - res_1$hi <- 1 - res_2$lo - res <- bind_cols(res_1, res_2) - lo_nms <- paste0(".pred_lower_", object$lvl) - hi_nms <- paste0(".pred_upper_", object$lvl) - colnames(res) <- c(lo_nms[1], hi_nms[1], lo_nms[2], hi_nms[2]) - - if (object$spec$method$confint$extras$std_error) - res$.std_error <- results$se.fit - res - }, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - se.fit = TRUE, - type = "link" - ) - ) + res_1 <- res_2 + res_1$lo <- 1 - res_2$hi + res_1$hi <- 1 - res_2$lo + res <- bind_cols(res_1, res_2) + lo_nms <- paste0(".pred_lower_", object$lvl) + hi_nms <- paste0(".pred_upper_", object$lvl) + colnames(res) <- c(lo_nms[1], hi_nms[1], lo_nms[2], hi_nms[2]) + + if (object$spec$method$pred$conf_int$extras$std_error) + res$.std_error <- results$se.fit + res + }, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + se.fit = TRUE, + type = "link" + ) ) +) -# Note: For glmnet, you will need to make model-specific predict methods. -# See logistic_reg.R -logistic_reg_glmnet_data <- - list( - libs = "glmnet", - fit = list( - interface = "matrix", - protect = c("x", "y", "weights"), - func = c(pkg = "glmnet", fun = "glmnet"), - defaults = - list( - family = "binomial" - ) - ), - class = list( - pre = NULL, - post = organize_glmnet_class, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newx = quote(as.matrix(new_data)), - type = "response", - s = quote(object$spec$args$penalty) - ) - ), - classprob = list( - pre = NULL, - post = organize_glmnet_prob, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newx = quote(as.matrix(new_data)), - type = "response", - s = quote(object$spec$args$penalty) - ) - ), - raw = list( - pre = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newx = quote(as.matrix(new_data)) - ) - ) +# ------------------------------------------------------------------------------ + +set_model_engine("logistic_reg", "classification", "glmnet") +set_dependency("logistic_reg", "glmnet", "glmnet") + +set_model_arg( + model = "logistic_reg", + eng = "glmnet", + parsnip = "penalty", + original = "lambda", + func = list(pkg = "dials", fun = "penalty"), + has_submodel = TRUE +) + +set_model_arg( + model = "logistic_reg", + eng = "glmnet", + parsnip = "mixture", + original = "alpha", + func = list(pkg = "dials", fun = "mixture"), + has_submodel = FALSE +) + +set_fit( + model = "logistic_reg", + eng = "glmnet", + mode = "classification", + value = list( + interface = "matrix", + protect = c("x", "y", "weights"), + func = c(pkg = "glmnet", fun = "glmnet"), + defaults = list(family = "binomial") ) +) -logistic_reg_stan_data <- - list( - libs = "rstanarm", - fit = list( - interface = "formula", - protect = c("formula", "data", "weights"), - func = c(pkg = "rstanarm", fun = "stan_glm"), - defaults = - list( - family = expr(stats::binomial) - ) - ), - class = list( - pre = NULL, - post = function(x, object) { - x 
<- object$fit$family$linkinv(x) - x <- ifelse(x >= 0.5, object$lvl[2], object$lvl[1]) - unname(x) - }, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ), - classprob = list( - pre = NULL, - post = function(x, object) { - x <- object$fit$family$linkinv(x) - x <- tibble(v1 = 1 - x, v2 = x) - colnames(x) <- object$lvl - x - }, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ), - raw = list( - pre = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ), - confint = list( - pre = NULL, - post = function(results, object) { - res_2 <- - tibble( - lo = - convert_stan_interval( - results, - level = object$spec$method$confint$extras$level - ), - hi = - convert_stan_interval( - results, - level = object$spec$method$confint$extras$level, - lower = FALSE - ), - ) - res_1 <- res_2 - res_1$lo <- 1 - res_2$hi - res_1$hi <- 1 - res_2$lo - res <- bind_cols(res_1, res_2) - lo_nms <- paste0(".pred_lower_", object$lvl) - hi_nms <- paste0(".pred_upper_", object$lvl) - colnames(res) <- c(lo_nms[1], hi_nms[1], lo_nms[2], hi_nms[2]) - - if (object$spec$method$confint$extras$std_error) - res$.std_error <- apply(results, 2, sd, na.rm = TRUE) - res - }, - func = c(pkg = "rstanarm", fun = "posterior_linpred"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - transform = TRUE, - seed = expr(sample.int(10^5, 1)) - ) - ), - predint = list( - pre = NULL, - post = function(results, object) { - res_2 <- - tibble( - lo = - convert_stan_interval( - results, - level = object$spec$method$predint$extras$level - ), - hi = - convert_stan_interval( - results, - level = object$spec$method$predint$extras$level, - lower = FALSE - ), - ) - res_1 <- res_2 - res_1$lo <- 1 - res_2$hi - res_1$hi <- 1 - res_2$lo - res <- bind_cols(res_1, res_2) - lo_nms <- paste0(".pred_lower_", object$lvl) - hi_nms <- paste0(".pred_upper_", object$lvl) - colnames(res) <- c(lo_nms[1], hi_nms[1], lo_nms[2], hi_nms[2]) - - if (object$spec$method$predint$extras$std_error) - res$.std_error <- apply(results, 2, sd, na.rm = TRUE) - res - }, - func = c(pkg = "rstanarm", fun = "posterior_predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - seed = expr(sample.int(10^5, 1)) - ) - ) + +set_pred( + model = "logistic_reg", + eng = "glmnet", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = organize_glmnet_class, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newx = quote(as.matrix(new_data)), + type = "response", + s = quote(object$spec$args$penalty) + ) ) +) +set_pred( + model = "logistic_reg", + eng = "glmnet", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = organize_glmnet_prob, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newx = quote(as.matrix(new_data)), + type = "response", + s = quote(object$spec$args$penalty) + ) + ) +) -logistic_reg_spark_data <- - list( - libs = "sparklyr", - fit = list( - interface = "formula", - protect = c("x", "formula", "weight_col"), - func = c(pkg = "sparklyr", fun = "ml_logistic_regression"), - defaults = - list( - family = "binomial" - ) - ), - class = list( - pre = NULL, - post = format_spark_class, - func = c(pkg = "sparklyr", fun = "ml_predict"), - args = - list( - x = quote(object$fit), - dataset = quote(new_data) - ) - ), - classprob = list( - pre = 
NULL, - post = format_spark_probs, - func = c(pkg = "sparklyr", fun = "ml_predict"), - args = - list( - x = quote(object$fit), - dataset = quote(new_data) - ) - ) +set_pred( + model = "logistic_reg", + eng = "glmnet", + mode = "classification", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newx = quote(as.matrix(new_data)) + ) ) +) + +# ------------------------------------------------------------------------------ + +set_model_engine("logistic_reg", "classification", "spark") +set_dependency("logistic_reg", "spark", "sparklyr") + +set_model_arg( + model = "logistic_reg", + eng = "spark", + parsnip = "penalty", + original = "reg_param", + func = list(pkg = "dials", fun = "penalty"), + has_submodel = TRUE +) -logistic_reg_keras_data <- - list( - libs = c("keras", "magrittr"), - fit = list( - interface = "matrix", - protect = c("x", "y"), - func = c(pkg = "parsnip", fun = "keras_mlp"), - defaults = list(hidden_units = 1, act = "linear") - ), - class = list( - pre = NULL, - post = function(x, object) { - object$lvl[x + 1] - }, - func = c(pkg = "keras", fun = "predict_classes"), - args = - list( - object = quote(object$fit), - x = quote(as.matrix(new_data)) +set_model_arg( + model = "logistic_reg", + eng = "spark", + parsnip = "mixture", + original = "elastic_net_param", + func = list(pkg = "dials", fun = "mixture"), + has_submodel = FALSE +) + +set_fit( + model = "logistic_reg", + eng = "spark", + mode = "classification", + value = list( + interface = "formula", + protect = c("x", "formula", "weight_col"), + func = c(pkg = "sparklyr", fun = "ml_logistic_regression"), + defaults = + list( + family = "binomial" + ) + ) +) + +set_pred( + model = "logistic_reg", + eng = "spark", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = format_spark_class, + func = c(pkg = "sparklyr", fun = "ml_predict"), + args = + list( + x = quote(object$fit), + dataset = quote(new_data) + ) + ) +) + +set_pred( + model = "logistic_reg", + eng = "spark", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = format_spark_probs, + func = c(pkg = "sparklyr", fun = "ml_predict"), + args = + list( + x = quote(object$fit), + dataset = quote(new_data) + ) + ) +) + +# ------------------------------------------------------------------------------ + +set_model_engine("logistic_reg", "classification", "keras") +set_dependency("logistic_reg", "keras", "keras") +set_dependency("logistic_reg", "keras", "magrittr") + +set_model_arg( + model = "logistic_reg", + eng = "keras", + parsnip = "decay", + original = "decay", + func = list(pkg = "dials", fun = "weight_decay"), + has_submodel = FALSE +) + +set_fit( + model = "logistic_reg", + eng = "keras", + mode = "classification", + value = list( + interface = "matrix", + protect = c("x", "y"), + func = c(pkg = "parsnip", fun = "keras_mlp"), + defaults = list(hidden_units = 1, act = "linear") + ) +) + +set_pred( + model = "logistic_reg", + eng = "keras", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = function(x, object) { + object$lvl[x + 1] + }, + func = c(pkg = "keras", fun = "predict_classes"), + args = + list( + object = quote(object$fit), + x = quote(as.matrix(new_data)) + ) + ) +) + +set_pred( + model = "logistic_reg", + eng = "keras", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = function(x, object) { + x <- as_tibble(x) + colnames(x) <- object$lvl + x + }, + 
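# keras::predict_proba() returns class probabilities without level names, + # so the `post` hook above relabels the columns with the outcome levels. + 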
func = c(pkg = "keras", fun = "predict_proba"), + args = + list( + object = quote(object$fit), + x = quote(as.matrix(new_data)) + ) + ) +) + +# ------------------------------------------------------------------------------ + +set_model_engine("logistic_reg", "classification", "stan") +set_dependency("logistic_reg", "stan", "rstanarm") + +set_fit( + model = "logistic_reg", + eng = "stan", + mode = "classification", + value = list( + interface = "formula", + protect = c("formula", "data", "weights"), + func = c(pkg = "rstanarm", fun = "stan_glm"), + defaults = list(family = expr(stats::binomial)) + ) +) + +set_pred( + model = "logistic_reg", + eng = "stan", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = function(x, object) { + x <- object$fit$family$linkinv(x) + x <- ifelse(x >= 0.5, object$lvl[2], object$lvl[1]) + unname(x) + }, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data) + ) + ) +) + +set_pred( + model = "logistic_reg", + eng = "stan", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = function(x, object) { + x <- object$fit$family$linkinv(x) + x <- tibble(v1 = 1 - x, v2 = x) + colnames(x) <- object$lvl + x + }, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data) + ) + ) +) + + +set_pred( + model = "logistic_reg", + eng = "stan", + mode = "classification", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data) + ) + ) +) + +set_pred( + model = "logistic_reg", + eng = "stan", + mode = "classification", + type = "conf_int", + value = list( + pre = NULL, + post = function(results, object) { + res_2 <- + tibble( + lo = + convert_stan_interval( + results, + level = object$spec$method$pred$conf_int$extras$level + ), + hi = + convert_stan_interval( + results, + level = object$spec$method$pred$conf_int$extras$level, + lower = FALSE + ), ) - ), - classprob = list( - pre = NULL, - post = function(x, object) { - x <- as_tibble(x) - colnames(x) <- object$lvl - x - }, - func = c(pkg = "keras", fun = "predict_proba"), - args = - list( - object = quote(object$fit), - x = quote(as.matrix(new_data)) + res_1 <- res_2 + res_1$lo <- 1 - res_2$hi + res_1$hi <- 1 - res_2$lo + res <- bind_cols(res_1, res_2) + lo_nms <- paste0(".pred_lower_", object$lvl) + hi_nms <- paste0(".pred_upper_", object$lvl) + colnames(res) <- c(lo_nms[1], hi_nms[1], lo_nms[2], hi_nms[2]) + + if (object$spec$method$pred$conf_int$extras$std_error) + res$.std_error <- apply(results, 2, sd, na.rm = TRUE) + res + }, + func = c(pkg = "rstanarm", fun = "posterior_linpred"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + transform = TRUE, + seed = expr(sample.int(10^5, 1)) + ) + ) +) + +set_pred( + model = "logistic_reg", + eng = "stan", + mode = "classification", + type = "pred_int", + value = list( + pre = NULL, + post = function(results, object) { + res_2 <- + tibble( + lo = + convert_stan_interval( + results, + level = object$spec$method$pred$pred_int$extras$level + ), + hi = + convert_stan_interval( + results, + level = object$spec$method$pred$pred_int$extras$level, + lower = FALSE + ), ) - ) + res_1 <- res_2 + res_1$lo <- 1 - res_2$hi + res_1$hi <- 1 - res_2$lo + res <- bind_cols(res_1, res_2) + lo_nms <- paste0(".pred_lower_", object$lvl) + hi_nms <- paste0(".pred_upper_", object$lvl) + colnames(res) <- c(lo_nms[1], 
hi_nms[1], lo_nms[2], hi_nms[2]) + + if (object$spec$method$pred$pred_int$extras$std_error) + res$.std_error <- apply(results, 2, sd, na.rm = TRUE) + res + }, + func = c(pkg = "rstanarm", fun = "posterior_predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + seed = expr(sample.int(10^5, 1)) + ) ) +) diff --git a/R/mars_data.R b/R/mars_data.R index 0c1076e68..a4a84e268 100644 --- a/R/mars_data.R +++ b/R/mars_data.R @@ -1,77 +1,152 @@ -mars_arg_key <- data.frame( - earth = c( "nprune", "degree", "pmethod"), - stringsAsFactors = FALSE, - row.names = c("num_terms", "prod_degree", "prune_method") +set_new_model("mars") + +set_model_mode("mars", "classification") +set_model_mode("mars", "regression") + +# ------------------------------------------------------------------------------ + +set_model_engine("mars", "classification", "earth") +set_model_engine("mars", "regression", "earth") +set_dependency("mars", "earth", "earth") + +set_model_arg( + model = "mars", + eng = "earth", + parsnip = "num_terms", + original = "nprune", + func = list(pkg = "dials", fun = "num_terms"), + has_submodel = FALSE +) +set_model_arg( + model = "mars", + eng = "earth", + parsnip = "prod_degree", + original = "degree", + func = list(pkg = "dials", fun = "prod_degree"), + has_submodel = FALSE +) +set_model_arg( + model = "mars", + eng = "earth", + parsnip = "prune_method", + original = "pmethod", + func = list(pkg = "dials", fun = "prune_method"), + has_submodel = FALSE ) -mars_modes <- c("classification", "regression", "unknown") +set_fit( + model = "mars", + eng = "earth", + mode = "regression", + value = list( + interface = "data.frame", + protect = c("x", "y", "weights"), + func = c(pkg = "earth", fun = "earth"), + defaults = list(keepxy = TRUE) + ) +) -mars_engines <- data.frame( - earth = rep(TRUE, 3), - row.names = c("classification", "regression", "unknown") +set_fit( + model = "mars", + eng = "earth", + mode = "classification", + value = list( + interface = "data.frame", + protect = c("x", "y", "weights"), + func = c(pkg = "earth", fun = "earth"), + defaults = list(keepxy = TRUE) + ) ) -# ------------------------------------------------------------------------------ +set_pred( + model = "mars", + eng = "earth", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = maybe_multivariate, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "response" + ) + ) +) + +set_pred( + model = "mars", + eng = "earth", + mode = "regression", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list(object = quote(object$fit), + newdata = quote(new_data)) + ) +) -mars_earth_data <- - list( - libs = "earth", - fit = list( - interface = "data.frame", - protect = c("x", "y", "weights"), - func = c(pkg = "earth", fun = "earth"), - defaults = list(keepxy = TRUE) - ), - numeric = list( - pre = NULL, - post = maybe_multivariate, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "response" - ) - ), - class = list( - pre = NULL, - post = function(x, object) { - x <- ifelse(x[,1] >= 0.5, object$lvl[2], object$lvl[1]) - x - }, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "response" - ) - ), - classprob = list( - pre = NULL, - post = function(x, object) { - x <- x[,1] - x <- tibble(v1 = 1 - x, v2 = x) - colnames(x) <- object$lvl - x - }, 
- func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "response" - ) - ), - raw = list( - pre = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ) +set_pred( + model = "mars", + eng = "earth", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = function(x, object) { + x <- ifelse(x[, 1] >= 0.5, object$lvl[2], object$lvl[1]) + x + }, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "response" + ) ) +) +set_pred( + model = "mars", + eng = "earth", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = function(x, object) { + x <- x[, 1] + x <- tibble(v1 = 1 - x, v2 = x) + colnames(x) <- object$lvl + x + }, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "response" + ) + ) +) + +set_pred( + model = "mars", + eng = "earth", + mode = "classification", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list(object = quote(object$fit), + newdata = quote(new_data)) + ) +) diff --git a/R/misc.R b/R/misc.R index 3885e7817..7dda7d0df 100644 --- a/R/misc.R +++ b/R/misc.R @@ -132,14 +132,6 @@ make_call <- function(fun, ns, args, ...) { out } -resolve_args <- function(args, ...) { - for (i in seq(along = args)) { - if (!is_missing_arg(args[[i]])) - args[[i]] <- eval_tidy(args[[i]], ...) - } - args -} - levels_from_formula <- function(f, dat) { if (inherits(dat, "tbl_spark")) res <- NULL @@ -152,8 +144,8 @@ is_spark <- function(x) isTRUE(unname(x$method$fit$func["pkg"] == "sparklyr")) -show_fit <- function(mod, eng) { - mod <- translate(x = mod, engine = eng) +show_fit <- function(model, eng) { + mod <- translate(x = model, engine = eng) fit_call <- show_call(mod) call_text <- deparse(fit_call) call_text <- paste0(call_text, collapse = "\n") @@ -201,7 +193,7 @@ update_dot_check <- function(...) { # ------------------------------------------------------------------------------ new_model_spec <- function(cls, args, eng_args, mode, method, engine) { - spec_modes <- get(paste0(cls, "_modes")) + spec_modes <- rlang::env_get(get_model_env(), paste0(cls, "_modes")) if (!(mode %in% spec_modes)) stop("`mode` should be one of: ", paste0("'", spec_modes, "'", collapse = ", "), diff --git a/R/mlp.R b/R/mlp.R index c8c94054e..3fe96631f 100644 --- a/R/mlp.R +++ b/R/mlp.R @@ -229,3 +229,189 @@ check_args.mlp <- function(object) { invisible(object) } + +# keras wrapper for feed-forward nnet + +class2ind <- function (x, drop2nd = FALSE) { + if (!is.factor(x)) + stop("`x` should be a factor") + y <- model.matrix( ~ x - 1) + colnames(y) <- gsub("^x", "", colnames(y)) + attributes(y)$assign <- NULL + attributes(y)$contrasts <- NULL + if (length(levels(x)) == 2 & drop2nd) { + y <- y[, 1] + } + y +} + + +# ------------------------------------------------------------------------------ + +#' Simple interface to MLP models via keras +#' +#' Instead of building a `keras` model sequentially, `keras_mlp` can be used to +#' create a feedforward network with a single hidden layer. Regularization is +#' via either weight decay or dropout. +#' +#' @param x A data frame or matrix of predictors +#' @param y A vector (factor or numeric) or matrix (numeric) of outcome data. +#' @param hidden_units An integer for the number of hidden units. 
+#' @param decay A non-negative real number for the amount of weight decay. Either +#' this parameter _or_ `dropout` can be specified. +#' @param dropout The proportion of parameters to set to zero. Either +#' this parameter _or_ `decay` can be specified. +#' @param epochs An integer for the number of passes through the data. +#' @param act A character string for the type of activation function between layers. +#' @param seeds A vector of three positive integers to control randomness of the +#' calculations. +#' @param ... Currently ignored. +#' @return A `keras` model object. +#' @keywords internal +#' @export +keras_mlp <- + function(x, y, + hidden_units = 5, decay = 0, dropout = 0, epochs = 20, act = "softmax", + seeds = sample.int(10^5, size = 3), + ...) { + + if(decay > 0 & dropout > 0) + stop("Please use either dropout or weight decay.", call. = FALSE) + + if (!is.matrix(x)) + x <- as.matrix(x) + + if(is.character(y)) + y <- as.factor(y) + factor_y <- is.factor(y) + + if (factor_y) + y <- class2ind(y) + else { + if (isTRUE(ncol(y) > 1)) + y <- as.matrix(y) + else + y <- matrix(y, ncol = 1) + } + + model <- keras::keras_model_sequential() + if(decay > 0) { + model %>% + keras::layer_dense( + units = hidden_units, + activation = act, + input_shape = ncol(x), + kernel_regularizer = keras::regularizer_l2(decay), + kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[1]) + ) + } else { + model %>% + keras::layer_dense( + units = hidden_units, + activation = act, + input_shape = ncol(x), + kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[1]) + ) + } + if(dropout > 0) + model %>% + keras::layer_dense( + units = hidden_units, + activation = act, + input_shape = ncol(x), + kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[1]) + ) %>% + keras::layer_dropout(rate = dropout, seed = seeds[2]) + + if (factor_y) + model <- model %>% + keras::layer_dense( + units = ncol(y), + activation = 'softmax', + kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[3]) + ) + else + model <- model %>% + keras::layer_dense( + units = ncol(y), + activation = 'linear', + kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[3]) + ) + + arg_values <- parse_keras_args(...) + compile_call <- expr( + keras::compile(object = model) + ) + if(!any(names(arg_values$compile) == "loss")) + compile_call$loss <- + if(factor_y) "binary_crossentropy" else "mse" + if(!any(names(arg_values$compile) == "optimizer")) + compile_call$optimizer <- "adam" + for(arg in names(arg_values$compile)) + compile_call[[arg]] <- arg_values$compile[[arg]] + + model <- eval_tidy(compile_call) + + fit_call <- expr( + keras::fit(object = model) + ) + fit_call$x <- quote(x) + fit_call$y <- quote(y) + fit_call$epochs <- epochs + for(arg in names(arg_values$fit)) + fit_call[[arg]] <- arg_values$fit[[arg]] + + history <- eval_tidy(fit_call) + model + } + + +nnet_softmax <- function(results, object) { + if (ncol(results) == 1) + results <- cbind(1 - results, results) + + results <- apply(results, 1, function(x) exp(x)/sum(exp(x))) + results <- as_tibble(t(results)) + names(results) <- paste0(".pred_", object$lvl) + results +} + +# parse_keras_args() splits the extra arguments given to keras_mlp() into +# those meant for keras::compile() versus keras::fit(); names listed in +# `exclusions` are set directly by keras_mlp() and are dropped from the dots. +parse_keras_args <- function(...)
{ + exclusions <- c("object", "x", "y", "validation_data", "epochs") + fit_args <- c( + 'batch_size', + 'verbose', + 'callbacks', + 'view_metrics', + 'validation_split', + 'validation_data', + 'shuffle', + 'class_weight', + 'sample_weight', + 'initial_epoch', + 'steps_per_epoch', + 'validation_steps' + ) + compile_args <- c( + 'optimizer', + 'loss', + 'metrics', + 'loss_weights', + 'sample_weight_mode', + 'weighted_metrics', + 'target_tensors' + ) + dots <- list(...) + dots <- dots[!(names(dots) %in% exclusions)] + + list( + fit = dots[names(dots) %in% fit_args], + compile = dots[names(dots) %in% compile_args] + ) +} + +mlp_num_weights <- function(p, hidden_units, classes) { + ((p+1) * hidden_units) + ((hidden_units+1) * classes) +} + + diff --git a/R/mlp_data.R b/R/mlp_data.R index 724035e7e..4df47377a 100644 --- a/R/mlp_data.R +++ b/R/mlp_data.R @@ -1,312 +1,313 @@ -mlp_arg_key <- data.frame( - nnet = c("size", "decay", NA_character_, "maxit", NA_character_), - keras = c("hidden_units", "penalty", "dropout", "epochs", "activation"), - stringsAsFactors = FALSE, - row.names = c("hidden_units", "penalty", "dropout", "epochs", "activation") -) - -mlp_modes <- c("classification", "regression", "unknown") +set_new_model("mlp") -mlp_engines <- data.frame( - nnet = c(TRUE, TRUE, FALSE), - keras = c(TRUE, TRUE, FALSE), - row.names = c("classification", "regression", "unknown") -) +set_model_mode("mlp", "classification") +set_model_mode("mlp", "regression") # ------------------------------------------------------------------------------ -mlp_keras_data <- - list( - libs = c("keras", "magrittr"), - fit = list( - interface = "matrix", - protect = c("x", "y"), - func = c(pkg = "parsnip", fun = "keras_mlp"), - defaults = list() - ), - numeric = list( - pre = NULL, - post = maybe_multivariate, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - x = quote(as.matrix(new_data)) - ) - ), - class = list( - pre = NULL, - post = function(x, object) { - object$lvl[x + 1] - }, - func = c(pkg = "keras", fun = "predict_classes"), - args = - list( - object = quote(object$fit), - x = quote(as.matrix(new_data)) - ) - ), - classprob = list( - pre = NULL, - post = function(x, object) { - x <- as_tibble(x) - colnames(x) <- object$lvl - x - }, - func = c(pkg = "keras", fun = "predict_proba"), - args = - list( - object = quote(object$fit), - x = quote(as.matrix(new_data)) - ) - ), - raw = list( - pre = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - x = quote(as.matrix(new_data)) - ) - ) - ) +set_model_engine("mlp", "classification", "keras") +set_model_engine("mlp", "regression", "keras") +set_dependency("mlp", "keras", "keras") +set_dependency("mlp", "keras", "magrittr") +set_model_arg( + model = "mlp", + eng = "keras", + parsnip = "hidden_units", + original = "hidden_units", + func = list(pkg = "dials", fun = "hidden_units"), + has_submodel = FALSE +) +set_model_arg( + model = "mlp", + eng = "keras", + parsnip = "penalty", + original = "penalty", + func = list(pkg = "dials", fun = "weight_decay"), + has_submodel = FALSE +) +set_model_arg( + model = "mlp", + eng = "keras", + parsnip = "dropout", + original = "dropout", + func = list(pkg = "dials", fun = "dropout"), + has_submodel = FALSE +) +set_model_arg( + model = "mlp", + eng = "keras", + parsnip = "epochs", + original = "epochs", + func = list(pkg = "dials", fun = "epochs"), + has_submodel = FALSE +) +set_model_arg( + model = "mlp", + eng = "keras", + parsnip = "activation", + original = "activation", + 
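# `func` records the dials function associated with this argument (it can + # generate candidate values when tuning); `has_submodel` flags whether + # predictions at several argument values can come from a single fitted + # model, as with glmnet's `lambda`. + 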
func = list(pkg = "dials", fun = "activation"), + has_submodel = FALSE +) -nnet_softmax <- function(results, object) { - if (ncol(results) == 1) - results <- cbind(1 - results, results) - - results <- apply(results, 1, function(x) exp(x)/sum(exp(x))) - results <- as_tibble(t(results)) - names(results) <- paste0(".pred_", object$lvl) - results -} -mlp_nnet_data <- - list( - libs = "nnet", - fit = list( - interface = "formula", - protect = c("formula", "data", "weights"), - func = c(pkg = "nnet", fun = "nnet"), - defaults = list(trace = FALSE) - ), - numeric = list( - pre = NULL, - post = maybe_multivariate, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "raw" - ) - ), - class = list( - pre = NULL, - post = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "class" - ) - ), - classprob = list( - pre = NULL, - post = nnet_softmax, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "raw" - ) - ), - raw = list( - pre = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ) +set_fit( + model = "mlp", + eng = "keras", + mode = "regression", + value = list( + interface = "matrix", + protect = c("x", "y"), + func = c(pkg = "parsnip", fun = "keras_mlp"), + defaults = list() ) +) +set_fit( + model = "mlp", + eng = "keras", + mode = "classification", + value = list( + interface = "matrix", + protect = c("x", "y"), + func = c(pkg = "parsnip", fun = "keras_mlp"), + defaults = list() + ) +) -# ------------------------------------------------------------------------------ - -# keras wrapper for feed-forward nnet - -class2ind <- function (x, drop2nd = FALSE) { - if (!is.factor(x)) - stop("`x` should be a factor") - y <- model.matrix( ~ x - 1) - colnames(y) <- gsub("^x", "", colnames(y)) - attributes(y)$assign <- NULL - attributes(y)$contrasts <- NULL - if (length(levels(x)) == 2 & drop2nd) { - y <- y[, 1] - } - y -} - - -#' Simple interface to MLP models via keras -#' -#' Instead of building a `keras` model sequentially, `keras_mlp` can be used to -#' create a feedforward network with a single hidden layer. Regularization is -#' via either weight decay or dropout. -#' -#' @param x A data frame or matrix of predictors -#' @param y A vector (factor or numeric) or matrix (numeric) of outcome data. -#' @param hidden_units An integer for the number of hidden units. -#' @param decay A non-negative real number for the amount of weight decay. Either -#' this parameter _or_ `dropout` can specified. -#' @param dropout The proportion of parameters to set to zero. Either -#' this parameter _or_ `decay` can specified. -#' @param epochs An integer for the number of passes through the data. -#' @param act A character string for the type of activation function between layers. -#' @param seeds A vector of three positive integers to control randomness of the -#' calculations. -#' @param ... Currently ignored. -#' @return A `keras` model object. -#' @export -keras_mlp <- - function(x, y, - hidden_units = 5, decay = 0, dropout = 0, epochs = 20, act = "softmax", - seeds = sample.int(10^5, size = 3), - ...) { - - if(decay > 0 & dropout > 0) - stop("Please use either dropoput or weight decay.", call. 
= FALSE) - - if (!is.matrix(x)) - x <- as.matrix(x) - - if(is.character(y)) - y <- as.factor(y) - factor_y <- is.factor(y) +set_pred( + model = "mlp", + eng = "keras", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = maybe_multivariate, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + x = quote(as.matrix(new_data)) + ) + ) +) - if (factor_y) - y <- class2ind(y) - else { - if (isTRUE(ncol(y) > 1)) - y <- as.matrix(y) - else - y <- matrix(y, ncol = 1) - } +set_pred( + model = "mlp", + eng = "keras", + mode = "regression", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + x = quote(as.matrix(new_data)) + ) + ) - model <- keras::keras_model_sequential() - if(decay > 0) { - model %>% - keras::layer_dense( - units = hidden_units, - activation = act, - input_shape = ncol(x), - kernel_regularizer = keras::regularizer_l2(decay), - kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[1]) - ) - } else { - model %>% - keras::layer_dense( - units = hidden_units, - activation = act, - input_shape = ncol(x), - kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[1]) - ) - } - if(dropout > 0) - model %>% - keras::layer_dense( - units = hidden_units, - activation = act, - input_shape = ncol(x), - kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[1]) - ) %>% - keras::layer_dropout(rate = dropout, seed = seeds[2]) +) - if (factor_y) - model <- model %>% - keras::layer_dense( - units = ncol(y), - activation = 'softmax', - kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[3]) +set_pred( + model = "mlp", + eng = "keras", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = function(x, object) { + object$lvl[x + 1] + }, + func = c(pkg = "keras", fun = "predict_classes"), + args = + list( + object = quote(object$fit), + x = quote(as.matrix(new_data)) ) - else - model <- model %>% - keras::layer_dense( - units = ncol(y), - activation = 'linear', - kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[3]) + ) +) + +set_pred( + model = "mlp", + eng = "keras", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = function(x, object) { + x <- as_tibble(x) + colnames(x) <- object$lvl + x + }, + func = c(pkg = "keras", fun = "predict_proba"), + args = + list( + object = quote(object$fit), + x = quote(as.matrix(new_data)) ) + ) +) - arg_values <- parse_keras_args(...) 
- compile_call <- expr( - keras::compile(object = model) - ) - if(!any(names(arg_values$compile) == "loss")) - compile_call$loss <- - if(factor_y) "binary_crossentropy" else "mse" - if(!any(names(arg_values$compile) == "optimizer")) - compile_call$optimizer <- "adam" - for(arg in names(arg_values$compile)) - compile_call[[arg]] <- arg_values$compile[[arg]] +set_pred( + model = "mlp", + eng = "keras", + mode = "classification", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + x = quote(as.matrix(new_data)) + ) + ) +) - model <- eval_tidy(compile_call) +# ------------------------------------------------------------------------------ - fit_call <- expr( - keras::fit(object = model) - ) - fit_call$x <- quote(x) - fit_call$y <- quote(y) - fit_call$epochs <- epochs - for(arg in names(arg_values$fit)) - fit_call[[arg]] <- arg_values$fit[[arg]] +set_model_engine("mlp", "classification", "nnet") +set_model_engine("mlp", "regression", "nnet") +set_dependency("mlp", "nnet", "nnet") - history <- eval_tidy(fit_call) - model - } +set_model_arg( + model = "mlp", + eng = "nnet", + parsnip = "hidden_units", + original = "size", + func = list(pkg = "dials", fun = "hidden_units"), + has_submodel = FALSE +) +set_model_arg( + model = "mlp", + eng = "nnet", + parsnip = "penalty", + original = "decay", + func = list(pkg = "dials", fun = "weight_decay"), + has_submodel = FALSE +) +set_model_arg( + model = "mlp", + eng = "nnet", + parsnip = "epochs", + original = "maxit", + func = list(pkg = "dials", fun = "epochs"), + has_submodel = FALSE +) +set_fit( + model = "mlp", + eng = "nnet", + mode = "regression", + value = list( + interface = "formula", + protect = c("formula", "data", "weights"), + func = c(pkg = "nnet", fun = "nnet"), + defaults = list(trace = FALSE) + ) +) -parse_keras_args <- function(...) { - exclusions <- c("object", "x", "y", "validation_data", "epochs") - fit_args <- c( - 'batch_size', - 'verbose', - 'callbacks', - 'view_metrics', - 'validation_split', - 'validation_data', - 'shuffle', - 'class_weight', - 'sample_weight', - 'initial_epoch', - 'steps_per_epoch', - 'validation_steps' +set_fit( + model = "mlp", + eng = "nnet", + mode = "classification", + value = list( + interface = "formula", + protect = c("formula", "data", "weights"), + func = c(pkg = "nnet", fun = "nnet"), + defaults = list(trace = FALSE) ) - compile_args <- c( - 'optimizer', - 'loss', - 'metrics', - 'loss_weights', - 'sample_weight_mode', - 'weighted_metrics', - 'target_tensors' +) + +set_pred( + model = "mlp", + eng = "nnet", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = maybe_multivariate, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "raw" + ) ) - dots <- list(...) 
- dots <- dots[!(names(dots) %in% exclusions)] +) - list( - fit = dots[names(dots) %in% fit_args], - compile = dots[names(dots) %in% compile_args] +set_pred( + model = "mlp", + eng = "nnet", + mode = "regression", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data) + ) ) -} -mlp_num_weights <- function(p, hidden_units, classes) - ((p+1) * hidden_units) + ((hidden_units+1) * classes) +) +set_pred( + model = "mlp", + eng = "nnet", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "class" + ) + ) +) +set_pred( + model = "mlp", + eng = "nnet", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = nnet_softmax, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "raw" + ) + ) +) +set_pred( + model = "mlp", + eng = "nnet", + mode = "classification", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data) + ) + ) +) diff --git a/R/multinom_reg.R b/R/multinom_reg.R index b9aa2f40c..71e367489 100644 --- a/R/multinom_reg.R +++ b/R/multinom_reg.R @@ -322,7 +322,3 @@ check_glmnet_lambda <- function(dat, object) { dat } -# ------------------------------------------------------------------------------ - -#' @importFrom utils globalVariables -utils::globalVariables(c("group", ".pred")) diff --git a/R/multinom_reg_data.R b/R/multinom_reg_data.R index 186291003..7b7b16b3b 100644 --- a/R/multinom_reg_data.R +++ b/R/multinom_reg_data.R @@ -1,138 +1,228 @@ +set_new_model("multinom_reg") -multinom_reg_arg_key <- data.frame( - glmnet = c( "lambda", "alpha"), - spark = c("reg_param", "elastic_net_param"), - keras = c( "decay", NA), - stringsAsFactors = FALSE, - row.names = c("penalty", "mixture") +set_model_mode("multinom_reg", "classification") + +# ------------------------------------------------------------------------------ + +set_model_engine("multinom_reg", "classification", "glmnet") +set_dependency("multinom_reg", "glmnet", "glmnet") + +set_model_arg( + model = "multinom_reg", + eng = "glmnet", + parsnip = "penalty", + original = "lambda", + func = list(pkg = "dials", fun = "penalty"), + has_submodel = TRUE +) + +set_model_arg( + model = "multinom_reg", + eng = "glmnet", + parsnip = "mixture", + original = "alpha", + func = list(pkg = "dials", fun = "mixture"), + has_submodel = FALSE +) + +set_fit( + model = "multinom_reg", + eng = "glmnet", + mode = "classification", + value = list( + interface = "matrix", + protect = c("x", "y", "weights"), + func = c(pkg = "glmnet", fun = "glmnet"), + defaults = list(family = "multinomial") + ) +) + + +set_pred( + model = "multinom_reg", + eng = "glmnet", + mode = "classification", + type = "class", + value = list( + pre = check_glmnet_lambda, + post = organize_multnet_class, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newx = quote(as.matrix(new_data)), + type = "class", + s = quote(object$spec$args$penalty) + ) + ) ) -multinom_reg_modes <- "classification" +set_pred( + model = "multinom_reg", + eng = "glmnet", + mode = "classification", + type = "prob", + value = list( + pre = check_glmnet_lambda, + post = organize_multnet_prob, + func = c(fun = "predict"), + args = + list( + 
object = quote(object$fit), + newx = quote(as.matrix(new_data)), + type = "response", + s = quote(object$spec$args$penalty) + ) + ) +) -multinom_reg_engines <- data.frame( - glmnet = TRUE, - spark = TRUE, - keras = TRUE, - row.names = c("classification") +set_pred( + model = "multinom_reg", + eng = "glmnet", + mode = "classification", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newx = quote(as.matrix(new_data)) + ) + ) ) # ------------------------------------------------------------------------------ -multinom_reg_glmnet_data <- - list( - libs = "glmnet", - fit = list( - interface = "matrix", - protect = c("x", "y", "weights"), - func = c(pkg = "glmnet", fun = "glmnet"), - defaults = - list( - family = "multinomial" - ) - ), - class = list( - pre = check_glmnet_lambda, - post = organize_multnet_class, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newx = quote(as.matrix(new_data)), - type = "class", - s = quote(object$spec$args$penalty) - ) - ), - classprob = list( - pre = check_glmnet_lambda, - post = organize_multnet_prob, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newx = quote(as.matrix(new_data)), - type = "response", - s = quote(object$spec$args$penalty) - ) - ), - raw = list( - pre = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newx = quote(as.matrix(new_data)) - ) - ) +set_model_engine("multinom_reg", "classification", "spark") +set_dependency("multinom_reg", "spark", "sparklyr") + +set_model_arg( + model = "multinom_reg", + eng = "spark", + parsnip = "penalty", + original = "reg_param", + func = list(pkg = "dials", fun = "penalty"), + has_submodel = TRUE +) + +set_model_arg( + model = "multinom_reg", + eng = "spark", + parsnip = "mixture", + original = "elastic_net_param", + func = list(pkg = "dials", fun = "mixture"), + has_submodel = FALSE +) + +set_fit( + model = "multinom_reg", + eng = "spark", + mode = "classification", + value = list( + interface = "formula", + protect = c("x", "formula", "weight_col"), + func = c(pkg = "sparklyr", fun = "ml_logistic_regression"), + defaults = list(family = "multinomial") ) +) -multinom_reg_spark_data <- - list( - libs = "sparklyr", - fit = list( - interface = "formula", - protect = c("x", "formula", "weight_col"), - func = c(pkg = "sparklyr", fun = "ml_logistic_regression"), - defaults = - list( - family = "multinomial" - ) - ), - class = list( - pre = NULL, - post = format_spark_class, - func = c(pkg = "sparklyr", fun = "ml_predict"), - args = - list( - x = quote(object$fit), - dataset = quote(new_data) - ) - ), - classprob = list( - pre = NULL, - post = format_spark_probs, - func = c(pkg = "sparklyr", fun = "ml_predict"), - args = - list( - x = quote(object$fit), - dataset = quote(new_data) - ) - ) +set_pred( + model = "multinom_reg", + eng = "spark", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = format_spark_class, + func = c(pkg = "sparklyr", fun = "ml_predict"), + args = + list( + x = quote(object$fit), + dataset = quote(new_data) + ) ) +) -multinom_reg_keras_data <- - list( - libs = c("keras", "magrittr"), - fit = list( - interface = "matrix", - protect = c("x", "y"), - func = c(pkg = "parsnip", fun = "keras_mlp"), - defaults = list(hidden_units = 1, act = "linear") - ), - class = list( - pre = NULL, - post = function(x, object) { - object$lvl[x + 1] - }, - func = c(pkg = "keras", fun = 
"predict_classes"), - args = - list( - object = quote(object$fit), - x = quote(as.matrix(new_data)) - ) - ), - classprob = list( - pre = NULL, - post = function(x, object) { - x <- as_tibble(x) - colnames(x) <- object$lvl - x - }, - func = c(pkg = "keras", fun = "predict_proba"), - args = - list( - object = quote(object$fit), - x = quote(as.matrix(new_data)) - ) - ) +set_pred( + model = "multinom_reg", + eng = "spark", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = format_spark_probs, + func = c(pkg = "sparklyr", fun = "ml_predict"), + args = + list( + x = quote(object$fit), + dataset = quote(new_data) + ) ) +) + +# ------------------------------------------------------------------------------ + +set_model_engine("multinom_reg", "classification", "keras") +set_dependency("multinom_reg", "keras", "keras") +set_dependency("multinom_reg", "keras", "magrittr") + +set_model_arg( + model = "multinom_reg", + eng = "keras", + parsnip = "decay", + original = "decay", + func = list(pkg = "dials", fun = "weight_decay"), + has_submodel = FALSE +) + + +set_fit( + model = "multinom_reg", + eng = "keras", + mode = "classification", + value = list( + interface = "matrix", + protect = c("x", "y"), + func = c(pkg = "parsnip", fun = "keras_mlp"), + defaults = list(hidden_units = 1, act = "linear") + ) +) + +set_pred( + model = "multinom_reg", + eng = "keras", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = function(x, object) { + object$lvl[x + 1] + }, + func = c(pkg = "keras", fun = "predict_classes"), + args = + list(object = quote(object$fit), + x = quote(as.matrix(new_data))) + ) +) + +set_pred( + model = "multinom_reg", + eng = "keras", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = function(x, object) { + x <- as_tibble(x) + colnames(x) <- object$lvl + x + }, + func = c(pkg = "keras", fun = "predict_proba"), + args = + list(object = quote(object$fit), + x = quote(as.matrix(new_data))) + ) +) diff --git a/R/nearest_neighbor.R b/R/nearest_neighbor.R index 52c9a9f0e..4f0cbb165 100644 --- a/R/nearest_neighbor.R +++ b/R/nearest_neighbor.R @@ -29,7 +29,8 @@ #' `"classification"`. #' #' @param neighbors A single integer for the number of neighbors -#' to consider (often called `k`). +#' to consider (often called `k`). For \pkg{kknn}, a value of 5 +#' is used if `neighbors` is not specified. #' #' @param weight_func A *single* character for the type of kernel function used #' to weight distances between samples. Valid choices are: `"rectangular"`, @@ -54,7 +55,7 @@ #' #' \pkg{kknn} (classification or regression) #' -#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::nearest_neighbor(), "kknn")} +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::nearest_neighbor(mode = "regression"), "kknn")} #' #' @note #' For `kknn`, the underlying modeling function used is a restricted @@ -148,13 +149,32 @@ check_args.nearest_neighbor <- function(object) { args <- lapply(object$args, rlang::eval_tidy) - if(is.numeric(args$neighbors) && !positive_int_scalar(args$neighbors)) { + if (is.numeric(args$neighbors) && !positive_int_scalar(args$neighbors)) { stop("`neighbors` must be a length 1 positive integer.", call. = FALSE) } - if(is.character(args$weight_func) && length(args$weight_func) > 1) { + if (is.character(args$weight_func) && length(args$weight_func) > 1) { stop("The length of `weight_func` must be 1.", call. 
= FALSE) } invisible(object) } + +# ------------------------------------------------------------------------------ + +#' @export +translate.nearest_neighbor <- function(x, engine = x$engine, ...) { + if (is.null(engine)) { + message("Used `engine = 'kknn'` for translation.") + engine <- "kknn" + } + x <- translate.default(x, engine, ...) + + if (engine == "kknn") { + if (!any(names(x$method$fit$args) == "ks") || + is_missing_arg(x$method$fit$args$ks)) { + x$method$fit$args$ks <- 5 + } + } + x +} diff --git a/R/nearest_neighbor_data.R b/R/nearest_neighbor_data.R index 0191d8614..2a85f70d0 100644 --- a/R/nearest_neighbor_data.R +++ b/R/nearest_neighbor_data.R @@ -1,90 +1,171 @@ -nearest_neighbor_arg_key <- data.frame( - kknn = c("ks", "kernel", "distance"), - row.names = c("neighbors", "weight_func", "dist_power"), - stringsAsFactors = FALSE + +set_new_model("nearest_neighbor") + +set_model_mode("nearest_neighbor", "classification") +set_model_mode("nearest_neighbor", "regression") + +# ------------------------------------------------------------------------------ + +set_model_engine("nearest_neighbor", "classification", "kknn") +set_model_engine("nearest_neighbor", "regression", "kknn") +set_dependency("nearest_neighbor", "kknn", "kknn") + +set_model_arg( + model = "nearest_neighbor", + eng = "kknn", + parsnip = "neighbors", + original = "ks", + func = list(pkg = "dials", fun = "neighbors"), + has_submodel = FALSE +) +set_model_arg( + model = "nearest_neighbor", + eng = "kknn", + parsnip = "weight_func", + original = "kernel", + func = list(pkg = "dials", fun = "weight_func"), + has_submodel = FALSE +) +set_model_arg( + model = "nearest_neighbor", + eng = "kknn", + parsnip = "dist_power", + original = "distance", + func = list(pkg = "dials", fun = "distance"), + has_submodel = FALSE ) -nearest_neighbor_modes <- c("classification", "regression", "unknown") +set_fit( + model = "nearest_neighbor", + eng = "kknn", + mode = "regression", + value = list( + interface = "formula", + protect = c("formula", "data"), + func = c(pkg = "kknn", fun = "train.kknn"), + defaults = list() + ) +) -nearest_neighbor_engines <- data.frame( - kknn = c(TRUE, TRUE, FALSE), - row.names = c("classification", "regression", "unknown") +set_fit( + model = "nearest_neighbor", + eng = "kknn", + mode = "classification", + value = list( + interface = "formula", + protect = c("formula", "data"), + func = c(pkg = "kknn", fun = "train.kknn"), + defaults = list() + ) ) -# ------------------------------------------------------------------------------ +set_pred( + model = "nearest_neighbor", + eng = "kknn", + mode = "regression", + type = "numeric", + value = list( + # seems unnecessary here as the predict_numeric catches it based on the + # model mode + pre = function(x, object) { + if (object$fit$response != "continuous") { + stop("`kknn` model does not appear to use numeric predictions. Was ", + "the model fit with a continuous response variable?", + call. 
= FALSE) + } + x + }, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "raw" + ) + ) +) + +set_pred( + model = "nearest_neighbor", + eng = "kknn", + mode = "regression", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data) + ) + ) +) + +set_pred( + model = "nearest_neighbor", + eng = "kknn", + mode = "classification", + type = "class", + value = list( + pre = function(x, object) { + if (!(object$fit$response %in% c("ordinal", "nominal"))) { + stop("`kknn` model does not appear to use class predictions. Was ", + "the model fit with a factor response variable?", + call. = FALSE) + } + x + }, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "raw" + ) + ) +) + +set_pred( + model = "nearest_neighbor", + eng = "kknn", + mode = "classification", + type = "prob", + value = list( + pre = function(x, object) { + if (!(object$fit$response %in% c("ordinal", "nominal"))) { + stop("`kknn` model does not appear to use class predictions. Was ", + "the model fit with a factor response variable?", + call. = FALSE) + } + x + }, + post = function(result, object) as_tibble(result), + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "prob" + ) + ) +) -nearest_neighbor_kknn_data <- - list( - libs = "kknn", - fit = list( - interface = "formula", - protect = c("formula", "data", "kmax"), # kmax is not allowed - func = c(pkg = "kknn", fun = "train.kknn"), - defaults = list() - ), - numeric = list( - # seems unnecessary here as the predict_numeric catches it based on the - # model mode - pre = function(x, object) { - if (object$fit$response != "continuous") { - stop("`kknn` model does not appear to use numeric predictions. Was ", - "the model fit with a continuous response variable?", - call. = FALSE) - } - x - }, - post = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "raw" - ) - ), - class = list( - pre = function(x, object) { - if (!(object$fit$response %in% c("ordinal", "nominal"))) { - stop("`kknn` model does not appear to use class predictions. Was ", - "the model fit with a factor response variable?", - call. = FALSE) - } - x - }, - post = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "raw" - ) - ), - classprob = list( - pre = function(x, object) { - if (!(object$fit$response %in% c("ordinal", "nominal"))) { - stop("`kknn` model does not appear to use class predictions. Was ", - "the model fit with a factor response variable?", - call. 
= FALSE) - } - x - }, - post = function(result, object) as_tibble(result), - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "prob" - ) - ), - raw = list( - pre = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ) +set_pred( + model = "nearest_neighbor", + eng = "kknn", + mode = "classification", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data) + ) ) +) diff --git a/R/nullmodel.R b/R/nullmodel.R index c37ecfead..10772562d 100644 --- a/R/nullmodel.R +++ b/R/nullmodel.R @@ -160,6 +160,7 @@ predict.nullmodel <- function (object, new_data = NULL, type = NULL, ...) { #' @export null_model <- function(mode = "classification") { + null_model_modes <- unique(get_model_env()$null_model$mode) # Check for correct mode if (!(mode %in% null_model_modes)) stop("`mode` should be one of: ", diff --git a/R/nullmodel_data.R b/R/nullmodel_data.R index 80380a98a..1aba4e6bf 100644 --- a/R/nullmodel_data.R +++ b/R/nullmodel_data.R @@ -1,72 +1,128 @@ -null_model_arg_key <- data.frame( - parsnip = NULL, - row.names = NULL, - stringsAsFactors = FALSE +set_new_model("null_model") + +set_model_mode("null_model", "classification") +set_model_mode("null_model", "regression") + +# ------------------------------------------------------------------------------ + +set_model_engine("null_model", "classification", "parsnip") +set_model_engine("null_model", "regression", "parsnip") +set_dependency("null_model", "parsnip", "parsnip") + +set_fit( + model = "null_model", + eng = "parsnip", + mode = "regression", + value = list( + interface = "matrix", + protect = c("x", "y"), + func = c(fun = "nullmodel"), + defaults = list() + ) ) -null_model_modes <- c("classification", "regression", "unknown") +set_fit( + model = "null_model", + eng = "parsnip", + mode = "classification", + value = list( + interface = "matrix", + protect = c("x", "y"), + func = c(fun = "nullmodel"), + defaults = list() + ) +) -null_model_engines <- data.frame( - parsnip = c(TRUE, TRUE, FALSE), - row.names = c("classification", "regression", "unknown") +set_pred( + model = "null_model", + eng = "parsnip", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + new_data = quote(new_data), + type = "numeric" + ) + ) ) -# ------------------------------------------------------------------------------ +set_pred( + model = "null_model", + eng = "parsnip", + mode = "regression", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + new_data = quote(new_data), + type = "numeric" + ) + ) +) -null_model_parsnip_data <- - list( - libs = "parsnip", - fit = list( - interface = "matrix", - protect = c("x", "y"), - func = c(fun = "nullmodel"), - defaults = list() - ), - class = list( - pre = NULL, - post = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - new_data = quote(new_data), - type = "class" - ) - ), - classprob = list( - pre = NULL, - post = function(x, object) { - str(as_tibble(x)) - as_tibble(x) - }, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - new_data = quote(new_data), - type = "prob" - ) - ), - numeric = list( - pre = NULL, - post = NULL, - func = c(fun = 
"predict"), - args = - list( - object = quote(object$fit), - new_data = quote(new_data), - type = "numeric" - ) - ), - raw = list( - pre = NULL, - post = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - new_data = quote(new_data), - type = "raw" - ) +set_pred( + model = "null_model", + eng = "parsnip", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + new_data = quote(new_data), + type = "class" ) - ) + ) +) + +set_pred( + model = "null_model", + eng = "parsnip", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = function(x, object) { + str(as_tibble(x)) + as_tibble(x) + }, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + new_data = quote(new_data), + type = "prob" + ) + ) +) + +set_pred( + model = "null_model", + eng = "parsnip", + mode = "classification", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + new_data = quote(new_data), + type = "raw" + ) + ) +) + diff --git a/R/predict.R b/R/predict.R index 7cea3e0b7..414096d1a 100644 --- a/R/predict.R +++ b/R/predict.R @@ -154,9 +154,6 @@ predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...) res } -pred_types <- - c("raw", "numeric", "class", "link", "prob", "conf_int", "pred_int", "quantile") - #' @importFrom glue glue_collapse check_pred_type <- function(object, type) { if (is.null(type)) { diff --git a/R/predict_class.R b/R/predict_class.R index 0e292a8b3..91c84fc02 100644 --- a/R/predict_class.R +++ b/R/predict_class.R @@ -13,7 +13,7 @@ predict_class.model_fit <- function(object, new_data, ...) { stop("`predict.model_fit()` is for predicting factor outcomes.", call. = FALSE) - if (!any(names(object$spec$method) == "class")) + if (!any(names(object$spec$method$pred) == "class")) stop("No class prediction module defined for this model.", call. = FALSE) if (inherits(object$fit, "try-error")) { @@ -24,17 +24,17 @@ predict_class.model_fit <- function(object, new_data, ...) { new_data <- prepare_data(object, new_data) # preprocess data - if (!is.null(object$spec$method$class$pre)) - new_data <- object$spec$method$class$pre(new_data, object) + if (!is.null(object$spec$method$pred$class$pre)) + new_data <- object$spec$method$pred$class$pre(new_data, object) # create prediction call - pred_call <- make_pred_call(object$spec$method$class) + pred_call <- make_pred_call(object$spec$method$pred$class) res <- eval_tidy(pred_call) # post-process the predictions - if (!is.null(object$spec$method$class$post)) { - res <- object$spec$method$class$post(res, object) + if (!is.null(object$spec$method$pred$class$post)) { + res <- object$spec$method$pred$class$post(res, object) } # coerce levels to those in `object` diff --git a/R/predict_classprob.R b/R/predict_classprob.R index d902e7735..1dbbe5328 100644 --- a/R/predict_classprob.R +++ b/R/predict_classprob.R @@ -10,7 +10,7 @@ predict_classprob.model_fit <- function(object, new_data, ...) { stop("`predict.model_fit()` is for predicting factor outcomes.", call. = FALSE) - if (!any(names(object$spec$method) == "classprob")) + if (!any(names(object$spec$method$pred) == "prob")) stop("No class probability module defined for this model.", call. = FALSE) if (inherits(object$fit, "try-error")) { @@ -21,17 +21,17 @@ predict_classprob.model_fit <- function(object, new_data, ...) 
{ new_data <- prepare_data(object, new_data) # preprocess data - if (!is.null(object$spec$method$classprob$pre)) - new_data <- object$spec$method$classprob$pre(new_data, object) + if (!is.null(object$spec$method$pred$prob$pre)) + new_data <- object$spec$method$pred$prob$pre(new_data, object) # create prediction call - pred_call <- make_pred_call(object$spec$method$classprob) + pred_call <- make_pred_call(object$spec$method$pred$prob) res <- eval_tidy(pred_call) # post-process the predictions - if (!is.null(object$spec$method$classprob$post)) { - res <- object$spec$method$classprob$post(res, object) + if (!is.null(object$spec$method$pred$prob$post)) { + res <- object$spec$method$pred$prob$post(res, object) } # check and sort names diff --git a/R/predict_interval.R b/R/predict_interval.R index f38838e00..b492ef866 100644 --- a/R/predict_interval.R +++ b/R/predict_interval.R @@ -10,7 +10,7 @@ # @export predict_confint.model_fit <- function(object, new_data, level = 0.95, std_error = FALSE, ...) { - if (is.null(object$spec$method$confint)) + if (is.null(object$spec$method$pred$conf_int)) stop("No confidence interval method defined for this ", "engine.", call. = FALSE) @@ -22,19 +22,19 @@ predict_confint.model_fit <- function(object, new_data, level = 0.95, std_error new_data <- prepare_data(object, new_data) # preprocess data - if (!is.null(object$spec$method$confint$pre)) - new_data <- object$spec$method$confint$pre(new_data, object) + if (!is.null(object$spec$method$pred$conf_int$pre)) + new_data <- object$spec$method$pred$conf_int$pre(new_data, object) # Pass some extra arguments to be used in post-processor - object$spec$method$confint$extras <- + object$spec$method$pred$conf_int$extras <- list(level = level, std_error = std_error) - pred_call <- make_pred_call(object$spec$method$confint) + pred_call <- make_pred_call(object$spec$method$pred$conf_int) res <- eval_tidy(pred_call) # post-process the predictions - if (!is.null(object$spec$method$confint$post)) { - res <- object$spec$method$confint$post(res, object) + if (!is.null(object$spec$method$pred$conf_int$post)) { + res <- object$spec$method$pred$conf_int$post(res, object) } attr(res, "level") <- level @@ -59,7 +59,7 @@ predict_confint <- function(object, ...) # @export predict_predint.model_fit <- function(object, new_data, level = 0.95, std_error = FALSE, ...) { - if (is.null(object$spec$method$predint)) + if (is.null(object$spec$method$pred$pred_int)) stop("No prediction interval method defined for this ", "engine.", call. 
= FALSE) @@ -71,20 +71,20 @@ predict_predint.model_fit <- function(object, new_data, level = 0.95, std_error new_data <- prepare_data(object, new_data) # preprocess data - if (!is.null(object$spec$method$predint$pre)) - new_data <- object$spec$method$predint$pre(new_data, object) + if (!is.null(object$spec$method$pred$pred_int$pre)) + new_data <- object$spec$method$pred$pred_int$pre(new_data, object) # create prediction call # Pass some extra arguments to be used in post-processor - object$spec$method$predint$extras <- + object$spec$method$pred$pred_int$extras <- list(level = level, std_error = std_error) - pred_call <- make_pred_call(object$spec$method$predint) + pred_call <- make_pred_call(object$spec$method$pred$pred_int) res <- eval_tidy(pred_call) # post-process the predictions - if (!is.null(object$spec$method$predint$post)) { - res <- object$spec$method$predint$post(res, object) + if (!is.null(object$spec$method$pred$pred_int$post)) { + res <- object$spec$method$pred$pred_int$post(res, object) } attr(res, "level") <- level diff --git a/R/predict_numeric.R b/R/predict_numeric.R index 3a509546b..970107049 100644 --- a/R/predict_numeric.R +++ b/R/predict_numeric.R @@ -11,7 +11,7 @@ predict_numeric.model_fit <- function(object, new_data, ...) { "Use `predict_class()` or `predict_classprob()` for ", "classification models.", call. = FALSE) - if (!any(names(object$spec$method) == "numeric")) + if (!any(names(object$spec$method$pred) == "numeric")) stop("No prediction module defined for this model.", call. = FALSE) if (inherits(object$fit, "try-error")) { @@ -22,17 +22,17 @@ predict_numeric.model_fit <- function(object, new_data, ...) { new_data <- prepare_data(object, new_data) # preprocess data - if (!is.null(object$spec$method$numeric$pre)) - new_data <- object$spec$method$numeric$pre(new_data, object) + if (!is.null(object$spec$method$pred$numeric$pre)) + new_data <- object$spec$method$pred$numeric$pre(new_data, object) # create prediction call - pred_call <- make_pred_call(object$spec$method$numeric) + pred_call <- make_pred_call(object$spec$method$pred$numeric) res <- eval_tidy(pred_call) # post-process the predictions - if (!is.null(object$spec$method$numeric$post)) { - res <- object$spec$method$numeric$post(res, object) + if (!is.null(object$spec$method$pred$numeric$post)) { + res <- object$spec$method$pred$numeric$post(res, object) } if (is.vector(res)) { diff --git a/R/predict_quantile.R b/R/predict_quantile.R index 698ddb4c8..19d22654f 100644 --- a/R/predict_quantile.R +++ b/R/predict_quantile.R @@ -9,7 +9,7 @@ predict_quantile.model_fit <- function (object, new_data, quantile = (1:9)/10, ...) { - if (is.null(object$spec$method$quantile)) + if (is.null(object$spec$method$pred$quantile)) stop("No quantile prediction method defined for this ", "engine.", call. 
= FALSE) @@ -21,18 +21,18 @@ predict_quantile.model_fit <- new_data <- prepare_data(object, new_data) # preprocess data - if (!is.null(object$spec$method$quantile$pre)) - new_data <- object$spec$method$quantile$pre(new_data, object) + if (!is.null(object$spec$method$pred$quantile$pre)) + new_data <- object$spec$method$pred$quantile$pre(new_data, object) # Pass some extra arguments to be used in post-processor - object$spec$method$quantile$args$p <- quantile - pred_call <- make_pred_call(object$spec$method$quantile) + object$spec$method$pred$quantile$args$p <- quantile + pred_call <- make_pred_call(object$spec$method$pred$quantile) res <- eval_tidy(pred_call) # post-process the predictions - if(!is.null(object$spec$method$quantile$post)) { - res <- object$spec$method$quantile$post(res, object) + if(!is.null(object$spec$method$pred$quantile$post)) { + res <- object$spec$method$pred$quantile$post(res, object) } res diff --git a/R/predict_raw.R b/R/predict_raw.R index 315c9dd0a..2d859f8a7 100644 --- a/R/predict_raw.R +++ b/R/predict_raw.R @@ -4,17 +4,17 @@ # @export predict_raw.model_fit # @export predict_raw.model_fit <- function(object, new_data, opts = list(), ...) { - protected_args <- names(object$spec$method$raw$args) + protected_args <- names(object$spec$method$pred$raw$args) dup_args <- names(opts) %in% protected_args if (any(dup_args)) { opts <- opts[[!dup_args]] } if (length(opts) > 0) { - object$spec$method$raw$args <- - c(object$spec$method$raw$args, opts) + object$spec$method$pred$raw$args <- + c(object$spec$method$pred$raw$args, opts) } - if (!any(names(object$spec$method) == "raw")) + if (!any(names(object$spec$method$pred) == "raw")) stop("No raw prediction module defined for this model.", call. = FALSE) if (inherits(object$fit, "try-error")) { @@ -25,11 +25,11 @@ predict_raw.model_fit <- function(object, new_data, opts = list(), ...) { new_data <- prepare_data(object, new_data) # preprocess data - if (!is.null(object$spec$method$raw$pre)) - new_data <- object$spec$method$raw$pre(new_data, object) + if (!is.null(object$spec$method$pred$raw$pre)) + new_data <- object$spec$method$pred$raw$pre(new_data, object) # create prediction call - pred_call <- make_pred_call(object$spec$method$raw) + pred_call <- make_pred_call(object$spec$method$pred$raw) res <- eval_tidy(pred_call) diff --git a/R/rand_forest_data.R b/R/rand_forest_data.R index 65eb84864..77ddea12e 100644 --- a/R/rand_forest_data.R +++ b/R/rand_forest_data.R @@ -1,24 +1,3 @@ - -rand_forest_arg_key <- data.frame( - randomForest = c("mtry", "ntree", "nodesize"), - ranger = c("mtry", "num.trees", "min.node.size"), - spark = - c("feature_subset_strategy", "num_trees", "min_instances_per_node"), - stringsAsFactors = FALSE, - row.names = c("mtry", "trees", "min_n") -) - -rand_forest_modes <- c("classification", "regression", "unknown") - -rand_forest_engines <- data.frame( - ranger = c(TRUE, TRUE, FALSE), - randomForest = c(TRUE, TRUE, FALSE), - spark = c(TRUE, TRUE, FALSE), - row.names = c("classification", "regression", "unknown") -) - -# ------------------------------------------------------------------------------ - # wrappers for ranger ranger_class_pred <- function(results, object) { @@ -32,7 +11,7 @@ ranger_class_pred <- #' @importFrom stats qnorm ranger_num_confint <- function(object, new_data, ...) 
{ - hf_lvl <- (1 - object$spec$method$confint$extras$level)/2 + hf_lvl <- (1 - object$spec$method$pred$conf_int$extras$level)/2 const <- qnorm(hf_lvl, lower.tail = FALSE) res <- @@ -44,12 +23,12 @@ ranger_num_confint <- function(object, new_data, ...) { res$.pred_upper <- res$.pred + const * std_error res$.pred <- NULL - if(object$spec$method$confint$extras$std_error) + if (object$spec$method$pred$conf_int$extras$std_error) res$.std_error <- std_error res } ranger_class_confint <- function(object, new_data, ...) { - hf_lvl <- (1 - object$spec$method$confint$extras$level)/2 + hf_lvl <- (1 - object$spec$method$pred$conf_int$extras$level)/2 const <- qnorm(hf_lvl, lower.tail = FALSE) pred <- predict(object$fit, data = new_data, type = "response", ...)$predictions @@ -73,21 +52,21 @@ ranger_class_confint <- function(object, new_data, ...) { col_names <- paste0(c(".pred_lower_", ".pred_upper_"), lvl) res <- res[, col_names] - if(object$spec$method$confint$extras$std_error) + if (object$spec$method$pred$conf_int$extras$std_error) res <- bind_cols(res, std_error) res } ranger_confint <- function(object, new_data, ...) { - if(object$fit$forest$treetype == "Regression") { + if (object$fit$forest$treetype == "Regression") { res <- ranger_num_confint(object, new_data, ...) } else { - if(object$fit$forest$treetype == "Probability estimation") { + if (object$fit$forest$treetype == "Probability estimation") { res <- ranger_class_confint(object, new_data, ...) } else { - stop ("Cannot compute confidence intervals for a ranger forest ", - "of type ", object$fit$forest$treetype, ".", call. = FALSE) + stop("Cannot compute confidence intervals for a ranger forest ", + "of type ", object$fit$forest$treetype, ".", call. = FALSE) } } res @@ -95,186 +74,454 @@ ranger_confint <- function(object, new_data, ...) { # ------------------------------------------------------------------------------ +set_new_model("rand_forest") -rand_forest_ranger_data <- - list( - libs = "ranger", - fit = list( - interface = "formula", - protect = c("formula", "data", "case.weights"), - func = c(pkg = "ranger", fun = "ranger"), - defaults = - list( - num.threads = 1, - verbose = FALSE, - seed = expr(sample.int(10^5, 1)) - ) - ), - numeric = list( - pre = NULL, - post = function(results, object) results$predictions, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - data = quote(new_data), - type = "response", - seed = expr(sample.int(10^5, 1)), - verbose = FALSE - ) - ), - class = list( - pre = NULL, - post = ranger_class_pred, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - data = quote(new_data), - type = "response", - seed = expr(sample.int(10^5, 1)), - verbose = FALSE - ) - ), - classprob = list( - pre = function(x, object) { - if (object$fit$forest$treetype != "Probability estimation") - stop("`ranger` model does not appear to use class probabilities. Was ", - "the model fit with `probability = TRUE`?", - call. 
= FALSE) - x - }, - post = function(x, object) { - x <- x$prediction - as_tibble(x) - }, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - data = quote(new_data), - seed = expr(sample.int(10^5, 1)), - verbose = FALSE - ) - ), - raw = list( - pre = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - data = quote(new_data), - seed = expr(sample.int(10^5, 1)) - ) - ), - confint = list( - pre = NULL, - post = NULL, - func = c(fun = "ranger_confint"), - args = - list( - object = quote(object), - new_data = quote(new_data), - seed = expr(sample.int(10^5, 1)) - ) - ) +set_model_mode("rand_forest", "classification") +set_model_mode("rand_forest", "regression") + +# ------------------------------------------------------------------------------ +# ranger components + +set_model_engine("rand_forest", "classification", "ranger") +set_model_engine("rand_forest", "regression", "ranger") +set_dependency("rand_forest", "ranger", "ranger") + +set_model_arg( + model = "rand_forest", + eng = "ranger", + parsnip = "mtry", + original = "mtry", + func = list(pkg = "dials", fun = "mtry"), + has_submodel = FALSE +) +set_model_arg( + model = "rand_forest", + eng = "ranger", + parsnip = "trees", + original = "num.trees", + func = list(pkg = "dials", fun = "trees"), + has_submodel = FALSE +) +set_model_arg( + model = "rand_forest", + eng = "ranger", + parsnip = "min_n", + original = "min.node.size", + func = list(pkg = "dials", fun = "min_n"), + has_submodel = FALSE +) + +set_fit( + model = "rand_forest", + eng = "ranger", + mode = "classification", + value = list( + interface = "formula", + protect = c("formula", "data", "case.weights"), + func = c(pkg = "ranger", fun = "ranger"), + defaults = + list( + num.threads = 1, + verbose = FALSE, + seed = expr(sample.int(10 ^ 5, 1)) + ) ) +) -rand_forest_randomForest_data <- - list( - libs = "randomForest", - fit = list( - interface = "data.frame", - protect = c("x", "y"), - func = c(pkg = "randomForest", fun = "randomForest"), - defaults = - list() - ), - numeric = list( - pre = NULL, - post = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ), - class = list( - pre = NULL, - post = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ), - classprob = list( - pre = NULL, - post = function(x, object) { - as_tibble(as.data.frame(x)) - }, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "prob" - ) - ), - raw = list( - pre = NULL, - func = c(fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ) +set_fit( + model = "rand_forest", + eng = "ranger", + mode = "regression", + value = list( + interface = "formula", + protect = c("formula", "data", "case.weights"), + func = c(pkg = "ranger", fun = "ranger"), + defaults = + list( + num.threads = 1, + verbose = FALSE, + seed = expr(sample.int(10 ^ 5, 1)) + ) ) +) +set_pred( + model = "rand_forest", + eng = "ranger", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = ranger_class_pred, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + data = quote(new_data), + type = "response", + seed = expr(sample.int(10 ^ 5, 1)), + verbose = FALSE + ) + ) +) -rand_forest_spark_data <- - list( - libs = "sparklyr", - fit = list( - interface = "formula", - protect = c("x", "formula", "type"), - 
func = c(pkg = "sparklyr", fun = "ml_random_forest"), - defaults = - list( - seed = expr(sample.int(10^5, 1)) - ) - ), - numeric = list( - pre = NULL, - post = format_spark_num, - func = c(pkg = "sparklyr", fun = "ml_predict"), - args = - list( - x = quote(object$fit), - dataset = quote(new_data) +set_pred( + model = "rand_forest", + eng = "ranger", + mode = "classification", + type = "prob", + value = list( + pre = function(x, object) { + if (object$fit$forest$treetype != "Probability estimation") + stop( + "`ranger` model does not appear to use class probabilities. Was ", + "the model fit with `probability = TRUE`?", + call. = FALSE ) - ), - class = list( - pre = NULL, - post = format_spark_class, - func = c(pkg = "sparklyr", fun = "ml_predict"), - args = - list( - x = quote(object$fit), - dataset = quote(new_data) - ) - ), - classprob = list( - pre = NULL, - post = format_spark_probs, - func = c(pkg = "sparklyr", fun = "ml_predict"), - args = - list( - x = quote(object$fit), - dataset = quote(new_data) - ) - ) + x + }, + post = function(x, object) { + x <- x$prediction + as_tibble(x) + }, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + data = quote(new_data), + seed = expr(sample.int(10 ^ 5, 1)), + verbose = FALSE + ) + ) +) + +set_pred( + model = "rand_forest", + eng = "ranger", + mode = "classification", + type = "conf_int", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "ranger_confint"), + args = + list( + object = quote(object), + new_data = quote(new_data), + seed = expr(sample.int(10^5, 1)) + ) ) +) + +set_pred( + mod = "rand_forest", + eng = "ranger", + mode = "classification", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + data = quote(new_data), + seed = expr(sample.int(10 ^ 5, 1)) + ) + ) +) + +set_pred( + model = "rand_forest", + eng = "ranger", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = function(results, object) + results$predictions, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + data = quote(new_data), + type = "response", + seed = expr(sample.int(10 ^ 5, 1)), + verbose = FALSE + ) + ) +) + + +set_pred( + mod = "rand_forest", + eng = "ranger", + mode = "regression", + type = "conf_int", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "ranger_confint"), + args = + list( + object = quote(object), + new_data = quote(new_data), + seed = expr(sample.int(10^5, 1)) + ) + ) +) +set_pred( + model = "rand_forest", + eng = "ranger", + mode = "regression", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + data = quote(new_data), + seed = expr(sample.int(10 ^ 5, 1)) + ) + ) +) + +# ------------------------------------------------------------------------------ +# randomForest components + +set_model_engine("rand_forest", "classification", "randomForest") +set_model_engine("rand_forest", "regression", "randomForest") +set_dependency("rand_forest", "randomForest", "randomForest") + +set_model_arg( + model = "rand_forest", + eng = "randomForest", + parsnip = "mtry", + original = "mtry", + func = list(pkg = "dials", fun = "mtry"), + has_submodel = FALSE +) +set_model_arg( + model = "rand_forest", + eng = "randomForest", + parsnip = "trees", + original = "ntree", + func = list(pkg = "dials", fun = "trees"), + has_submodel = FALSE +) +set_model_arg( + model = "rand_forest", + eng = 
"randomForest", + parsnip = "min_n", + original = "nodesize", + func = list(pkg = "dials", fun = "min_n"), + has_submodel = FALSE +) + +set_fit( + model = "rand_forest", + eng = "randomForest", + mode = "classification", + value = list( + interface = "data.frame", + protect = c("x", "y"), + func = c(pkg = "randomForest", fun = "randomForest"), + defaults = + list() + ) +) + +set_fit( + model = "rand_forest", + eng = "randomForest", + mode = "regression", + value = list( + interface = "data.frame", + protect = c("x", "y"), + func = c(pkg = "randomForest", fun = "randomForest"), + defaults = + list() + ) +) + +set_pred( + model = "rand_forest", + eng = "randomForest", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list(object = quote(object$fit), + newdata = quote(new_data)) + ) +) + +set_pred( + model = "rand_forest", + eng = "randomForest", + mode = "regression", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list(object = quote(object$fit), + newdata = quote(new_data)) + ) +) + +set_pred( + model = "rand_forest", + eng = "randomForest", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = list(object = quote(object$fit), newdata = quote(new_data)) + ) +) + + +set_pred( + model = "rand_forest", + eng = "randomForest", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = function(x, object) { + as_tibble(as.data.frame(x)) + }, + func = c(fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "prob" + ) + ) +) + +set_pred( + model = "rand_forest", + eng = "randomForest", + mode = "classification", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list(object = quote(object$fit), + newdata = quote(new_data)) + ) +) + +# ------------------------------------------------------------------------------ +# spark components + +set_model_engine("rand_forest", "classification", "spark") +set_model_engine("rand_forest", "regression", "spark") +set_dependency("rand_forest", "spark", "sparklyr") + +set_model_arg( + model = "rand_forest", + eng = "spark", + parsnip = "mtry", + original = "feature_subset_strategy", + func = list(pkg = "dials", fun = "mtry"), + has_submodel = FALSE +) +set_model_arg( + model = "rand_forest", + eng = "spark", + parsnip = "trees", + original = "num_trees", + func = list(pkg = "dials", fun = "trees"), + has_submodel = FALSE +) +set_model_arg( + model = "rand_forest", + eng = "spark", + parsnip = "min_n", + original = "min_instances_per_node", + func = list(pkg = "dials", fun = "min_n"), + has_submodel = FALSE +) + +set_fit( + model = "rand_forest", + eng = "spark", + mode = "classification", + value = list( + interface = "formula", + protect = c("x", "formula", "type"), + func = c(pkg = "sparklyr", fun = "ml_random_forest"), + defaults = list(seed = expr(sample.int(10 ^ 5, 1))) + ) +) + +set_fit( + model = "rand_forest", + eng = "spark", + mode = "regression", + value = list( + interface = "formula", + protect = c("x", "formula", "type"), + func = c(pkg = "sparklyr", fun = "ml_random_forest"), + defaults = list(seed = expr(sample.int(10 ^ 5, 1))) + ) +) + +set_pred( + model = "rand_forest", + eng = "spark", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = format_spark_num, + func = c(pkg = "sparklyr", fun = 
"ml_predict"), + args = + list(x = quote(object$fit), + dataset = quote(new_data)) + ) +) + +set_pred( + model = "rand_forest", + eng = "spark", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = format_spark_class, + func = c(pkg = "sparklyr", fun = "ml_predict"), + args = + list(x = quote(object$fit), + dataset = quote(new_data)) + ) +) + +set_pred( + model = "rand_forest", + eng = "spark", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = format_spark_probs, + func = c(pkg = "sparklyr", fun = "ml_predict"), + args = + list(x = quote(object$fit), + dataset = quote(new_data)) + ) +) diff --git a/R/surv_reg.R b/R/surv_reg.R index 33ab8b47e..878e20a9a 100644 --- a/R/surv_reg.R +++ b/R/surv_reg.R @@ -36,7 +36,7 @@ #' The model can be created using the `fit()` function using the #' following _engines_: #' \itemize{ -#' \item \pkg{R}: `"flexsurv"`, `"survreg"` (the default) +#' \item \pkg{R}: `"flexsurv"`, `"survival"` (the default) #' } #' #' @section Engine Details: @@ -49,9 +49,9 @@ #' #' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::surv_reg(), "flexsurv")} #' -#' \pkg{survreg} +#' \pkg{survival} #' -#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::surv_reg(), "survreg")} +#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::surv_reg(), "survival")} #' #' Note that `model = TRUE` is needed to produce quantile #' predictions when there is a stratification variable and can be @@ -143,8 +143,8 @@ update.surv_reg <- function(object, dist = NULL, fresh = FALSE, ...) { #' @export translate.surv_reg <- function(x, engine = x$engine, ...) { if (is.null(engine)) { - message("Used `engine = 'survreg'` for translation.") - engine <- "survreg" + message("Used `engine = 'survival'` for translation.") + engine <- "survival" } x <- translate.default(x, engine, ...) 
x @@ -171,7 +171,7 @@ check_args.surv_reg <- function(object) { #' @importFrom stats setNames #' @importFrom dplyr mutate survreg_quant <- function(results, object) { - pctl <- object$spec$method$quantile$args$p + pctl <- object$spec$method$pred$quantile$args$p n <- nrow(results) p <- ncol(results) results <- @@ -208,8 +208,3 @@ flexsurv_quant <- function(results, object) { results <- map(results, setNames, c(".quantile", ".pred", ".pred_lower", ".pred_upper")) } -# ------------------------------------------------------------------------------ - -#' @importFrom utils globalVariables -utils::globalVariables(".label") - diff --git a/R/surv_reg_data.R b/R/surv_reg_data.R index 046b0a469..e52fdb2f9 100644 --- a/R/surv_reg_data.R +++ b/R/surv_reg_data.R @@ -1,120 +1,130 @@ -surv_reg_arg_key <- data.frame( - survreg = c("dist"), - flexsurv = c("dist"), - stringsAsFactors = FALSE, - row.names = c("dist") -) +set_new_model("surv_reg") +set_model_mode("surv_reg", "regression") -surv_reg_modes <- "regression" +# ------------------------------------------------------------------------------ -surv_reg_engines <- data.frame( - survreg = TRUE, - flexsurv = TRUE, - stringsAsFactors = TRUE, - row.names = c("regression") -) +set_model_engine("surv_reg", "regression", "flexsurv") +set_dependency("surv_reg", "flexsurv", "flexsurv") +set_dependency("surv_reg", "flexsurv", "survival") -# ------------------------------------------------------------------------------ +set_model_arg( + model = "surv_reg", + eng = "flexsurv", + parsnip = "dist", + original = "dist", + func = list(pkg = "dials", fun = "dist"), + has_submodel = FALSE +) -surv_reg_flexsurv_data <- - list( - libs = c("survival", "flexsurv"), - fit = list( - interface = "formula", - protect = c("formula", "data", "weights"), - func = c(pkg = "flexsurv", fun = "flexsurvreg"), - defaults = list() - ), - numeric = list( - pre = NULL, - post = flexsurv_mean, - func = c(fun = "summary"), - args = - list( - object = expr(object$fit), - newdata = expr(new_data), - type = "mean" - ) - ), - quantile = list( - pre = NULL, - post = flexsurv_quant, - func = c(fun = "summary"), - args = - list( - object = expr(object$fit), - newdata = expr(new_data), - type = "quantile", - quantiles = expr(quantile) - ) - ) +set_fit( + model = "surv_reg", + eng = "flexsurv", + mode = "regression", + value = list( + interface = "formula", + protect = c("formula", "data", "weights"), + func = c(pkg = "flexsurv", fun = "flexsurvreg"), + defaults = list() ) +) -# ------------------------------------------------------------------------------ +set_pred( + model = "surv_reg", + eng = "flexsurv", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = flexsurv_mean, + func = c(fun = "summary"), + args = + list( + object = expr(object$fit), + newdata = expr(new_data), + type = "mean" + ) + ) +) -surv_reg_survreg_data <- - list( - libs = c("survival"), - fit = list( - interface = "formula", - protect = c("formula", "data", "weights"), - func = c(pkg = "survival", fun = "survreg"), - defaults = list(model = TRUE) - ), - numeric = list( - pre = NULL, - post = NULL, - func = c(fun = "predict"), - args = - list( - object = expr(object$fit), - newdata = expr(new_data), - type = "response" - ) - ), - quantile = list( - pre = NULL, - post = survreg_quant, - func = c(fun = "predict"), - args = - list( - object = expr(object$fit), - newdata = expr(new_data), - type = "quantile", - p = expr(quantile) - ) - ) +set_pred( + model = "surv_reg", + eng = "flexsurv", + mode = 
"regression", + type = "quantile", + value = list( + pre = NULL, + post = flexsurv_quant, + func = c(fun = "summary"), + args = + list( + object = expr(object$fit), + newdata = expr(new_data), + type = "quantile", + quantiles = expr(quantile) + ) ) +) # ------------------------------------------------------------------------------ -# surv_reg_stan_data <- -# list( -# libs = c("brms"), -# fit = list( -# interface = "formula", -# protect = c("formula", "data", "weights"), -# func = c(pkg = "brms", fun = "brm"), -# defaults = list( -# family = expr(brms::weibull()), -# seed = expr(sample.int(10^5, 1)) -# ) -# ), -# numeric = list( -# pre = NULL, -# post = function(results, object) { -# tibble::as_tibble(results) %>% -# dplyr::select(Estimate) %>% -# setNames(".pred") -# }, -# func = c(fun = "predict"), -# args = -# list( -# object = expr(object$fit), -# newdata = expr(new_data), -# type = "response" -# ) -# ) -# ) +set_model_engine("surv_reg", "regression", "survival") +set_dependency("surv_reg", "survival", "survival") +set_model_arg( + model = "surv_reg", + eng = "survival", + parsnip = "dist", + original = "dist", + func = list(pkg = "dials", fun = "dist"), + has_submodel = FALSE +) + +set_fit( + model = "surv_reg", + eng = "survival", + mode = "regression", + value = list( + interface = "formula", + protect = c("formula", "data", "weights"), + func = c(pkg = "survival", fun = "survreg"), + defaults = list(model = TRUE) + ) +) + +set_pred( + model = "surv_reg", + eng = "survival", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = NULL, + func = c(fun = "predict"), + args = + list( + object = expr(object$fit), + newdata = expr(new_data), + type = "response" + ) + ) +) + +set_pred( + model = "surv_reg", + eng = "survival", + mode = "regression", + type = "quantile", + value = list( + pre = NULL, + post = survreg_quant, + func = c(fun = "predict"), + args = + list( + object = expr(object$fit), + newdata = expr(new_data), + type = "quantile", + p = expr(quantile) + ) + ) +) diff --git a/R/svm_poly_data.R b/R/svm_poly_data.R index 04b0bc55e..565c52d87 100644 --- a/R/svm_poly_data.R +++ b/R/svm_poly_data.R @@ -1,69 +1,150 @@ -svm_poly_arg_key <- data.frame( - kernlab = c( "C", "degree", "scale", "epsilon"), - row.names = c("cost", "degree", "scale_factor", "margin"), - stringsAsFactors = FALSE +set_new_model("svm_poly") + +set_model_mode("svm_poly", "classification") +set_model_mode("svm_poly", "regression") + +# ------------------------------------------------------------------------------ + +set_model_engine("svm_poly", "classification", "kernlab") +set_model_engine("svm_poly", "regression", "kernlab") +set_dependency("svm_poly", "kernlab", "kernlab") + +set_model_arg( + model = "svm_poly", + eng = "kernlab", + parsnip = "cost", + original = "C", + func = list(pkg = "dials", fun = "cost"), + has_submodel = FALSE ) -svm_poly_modes <- c("classification", "regression", "unknown") +set_model_arg( + model = "svm_poly", + eng = "kernlab", + parsnip = "degree", + original = "degree", + func = list(pkg = "dials", fun = "degree"), + has_submodel = FALSE +) -svm_poly_engines <- data.frame( - kernlab = c(TRUE, TRUE, FALSE), - row.names = c("classification", "regression", "unknown") +set_model_arg( + model = "svm_poly", + eng = "kernlab", + parsnip = "scale_factor", + original = "scale", + func = list(pkg = "dials", fun = "scale_factor"), + has_submodel = FALSE +) +set_model_arg( + model = "svm_poly", + eng = "kernlab", + parsnip = "margin", + original = "epsilon", + func = 
list(pkg = "dials", fun = "margin"), + has_submodel = FALSE ) -# ------------------------------------------------------------------------------ +set_fit( + model = "svm_poly", + eng = "kernlab", + mode = "regression", + value = list( + interface = "matrix", + protect = c("x", "y"), + func = c(pkg = "kernlab", fun = "ksvm"), + defaults = list(kernel = "polydot") + ) +) -svm_poly_kernlab_data <- - list( - libs = "kernlab", - fit = list( - interface = "matrix", - protect = c("x", "y"), - func = c(pkg = "kernlab", fun = "ksvm"), - defaults = list( - kernel = "polydot" +set_fit( + model = "svm_poly", + eng = "kernlab", + mode = "classification", + value = list( + interface = "matrix", + protect = c("x", "y"), + func = c(pkg = "kernlab", fun = "ksvm"), + defaults = list(kernel = "polydot") + ) +) + +set_pred( + model = "svm_poly", + eng = "kernlab", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = svm_reg_post, + func = c(pkg = "kernlab", fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "response" ) - ), - numeric = list( - pre = NULL, - post = svm_reg_post, - func = c(pkg = "kernlab", fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "response" - ) - ), - class = list( - pre = NULL, - post = NULL, - func = c(pkg = "kernlab", fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "response" - ) - ), - classprob = list( - pre = NULL, - post = function(result, object) as_tibble(result), - func = c(pkg = "kernlab", fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "probabilities" - ) - ), - raw = list( - pre = NULL, - func = c(pkg = "kernlab", fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ) ) +) + +set_pred( + model = "svm_poly", + eng = "kernlab", + mode = "regression", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(pkg = "kernlab", fun = "predict"), + args = list(object = quote(object$fit), newdata = quote(new_data)) + ) +) + +set_pred( + model = "svm_poly", + eng = "kernlab", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = NULL, + func = c(pkg = "kernlab", fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "response" + ) + ) +) + +set_pred( + model = "svm_poly", + eng = "kernlab", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = function(result, object) as_tibble(result), + func = c(pkg = "kernlab", fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "probabilities" + ) + ) +) + +set_pred( + model = "svm_poly", + eng = "kernlab", + mode = "classification", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(pkg = "kernlab", fun = "predict"), + args = list(object = quote(object$fit), newdata = quote(new_data)) + ) +) + diff --git a/R/svm_rbf_data.R b/R/svm_rbf_data.R index fdb12727d..3c905e856 100644 --- a/R/svm_rbf_data.R +++ b/R/svm_rbf_data.R @@ -1,69 +1,142 @@ -svm_rbf_arg_key <- data.frame( - kernlab = c( "C", "sigma", "epsilon"), - row.names = c("cost", "rbf_sigma", "margin"), - stringsAsFactors = FALSE +set_new_model("svm_rbf") + +set_model_mode("svm_rbf", "classification") +set_model_mode("svm_rbf", "regression") + +# 
------------------------------------------------------------------------------ + +set_model_engine("svm_rbf", "classification", "kernlab") +set_model_engine("svm_rbf", "regression", "kernlab") +set_dependency("svm_rbf", "kernlab", "kernlab") + +set_model_arg( + model = "svm_rbf", + eng = "kernlab", + parsnip = "cost", + original = "C", + func = list(pkg = "dials", fun = "cost"), + has_submodel = FALSE ) -svm_rbf_modes <- c("classification", "regression", "unknown") +set_model_arg( + model = "svm_rbf", + eng = "kernlab", + parsnip = "rbf_sigma", + original = "sigma", + func = list(pkg = "dials", fun = "rbf_sigma"), + has_submodel = FALSE +) -svm_rbf_engines <- data.frame( - kernlab = c(TRUE, TRUE, FALSE), - row.names = c("classification", "regression", "unknown") +set_model_arg( + model = "svm_rbf", + eng = "kernlab", + parsnip = "margin", + original = "epsilon", + func = list(pkg = "dials", fun = "margin"), + has_submodel = FALSE ) -# ------------------------------------------------------------------------------ +set_fit( + model = "svm_rbf", + eng = "kernlab", + mode = "regression", + value = list( + interface = "matrix", + protect = c("x", "y"), + func = c(pkg = "kernlab", fun = "ksvm"), + defaults = list(kernel = "rbfdot") + ) +) + +set_fit( + model = "svm_rbf", + eng = "kernlab", + mode = "classification", + value = list( + interface = "matrix", + protect = c("x", "y"), + func = c(pkg = "kernlab", fun = "ksvm"), + defaults = list(kernel = "rbfdot") + ) +) + +set_pred( + model = "svm_rbf", + eng = "kernlab", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = svm_reg_post, + func = c(pkg = "kernlab", fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "response" + ) + ) +) + +set_pred( + model = "svm_rbf", + eng = "kernlab", + mode = "regression", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(pkg = "kernlab", fun = "predict"), + args = list(object = quote(object$fit), newdata = quote(new_data)) + ) +) + +set_pred( + model = "svm_rbf", + eng = "kernlab", + mode = "classification", + type = "class", + value = list( + pre = NULL, + post = NULL, + func = c(pkg = "kernlab", fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "response" + ) + ) +) -svm_rbf_kernlab_data <- - list( - libs = "kernlab", - fit = list( - interface = "matrix", - protect = c("x", "y"), - func = c(pkg = "kernlab", fun = "ksvm"), - defaults = list( - kernel = "rbfdot" +set_pred( + model = "svm_rbf", + eng = "kernlab", + mode = "classification", + type = "prob", + value = list( + pre = NULL, + post = function(result, object) as_tibble(result), + func = c(pkg = "kernlab", fun = "predict"), + args = + list( + object = quote(object$fit), + newdata = quote(new_data), + type = "probabilities" ) - ), - numeric = list( - pre = NULL, - post = svm_reg_post, - func = c(pkg = "kernlab", fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "response" - ) - ), - class = list( - pre = NULL, - post = NULL, - func = c(pkg = "kernlab", fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "response" - ) - ), - classprob = list( - pre = NULL, - post = function(result, object) as_tibble(result), - func = c(pkg = "kernlab", fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data), - type = "probabilities" - ) - ), - raw = list( - pre = NULL, - func = 
c(pkg = "kernlab", fun = "predict"), - args = - list( - object = quote(object$fit), - newdata = quote(new_data) - ) - ) ) +) + +set_pred( + model = "svm_rbf", + eng = "kernlab", + mode = "classification", + type = "raw", + value = list( + pre = NULL, + post = NULL, + func = c(pkg = "kernlab", fun = "predict"), + args = list(object = quote(object$fit), newdata = quote(new_data)) + ) +) + diff --git a/R/translate.R b/R/translate.R index fd014f770..3522544cc 100644 --- a/R/translate.R +++ b/R/translate.R @@ -19,11 +19,11 @@ #' This function can be useful when you need to understand how #' `parsnip` goes from a generic model specific to a model fitting #' function. -#' +#' #' **Note**: this function is used internally and users should only use it #' to understand what the underlying syntax would be. It should not be used -#' to modify the model specification. -#' +#' to modify the model specification. +#' #' @examples #' lm_spec <- linear_reg(penalty = 0.01) #' @@ -41,7 +41,7 @@ #' #' @export -translate <- function (x, ...) +translate <- function(x, ...) UseMethod("translate") #' @importFrom utils getFromNamespace @@ -53,29 +53,34 @@ translate.default <- function(x, engine = x$engine, ...) { if (is.null(engine)) stop("Please set an engine.", call. = FALSE) + mod_name <- specific_model(x) + x$engine <- engine x <- check_engine(x) + if (x$mode == "unknown") { + stop("Model code depends on the mode; please specify one.", call. = FALSE) + } + if (is.null(x$method)) - x <- get_method(x, engine, ...) + x$method <- get_model_spec(mod_name, x$mode, engine) - arg_key <- get_module(specific_model(x)) + arg_key <- get_args(mod_name, engine) # deharmonize primary arguments - actual_args <- deharmonize(x$args, arg_key, x$engine) + actual_args <- deharmonize(x$args, arg_key) # check secondary arguments to see if they are in the final # expression unless there are dots, warn if protected args are # being altered - eng_arg_key <- arg_key[[x$engine]] - x$eng_args <- check_eng_args(x$eng_args, x$method$fit, eng_arg_key) + x$eng_args <- check_eng_args(x$eng_args, x$method$fit, arg_key$original) # keep only modified args - modifed_args <- !vapply(actual_args, null_value, lgl(1)) + modifed_args <- !purrr::map_lgl(actual_args, null_value) actual_args <- actual_args[modifed_args] # look for defaults if not modified in other - if(length(x$method$fit$defaults) > 0) { + if (length(x$method$fit$defaults) > 0) { in_other <- names(x$method$fit$defaults) %in% names(x$eng_args) x$defaults <- x$method$fit$defaults[!in_other] } @@ -89,40 +94,6 @@ translate.default <- function(x, engine = x$engine, ...) { x } -get_method <- function(x, engine = x$engine, ...) { - check_empty_ellipse(...) - x$engine <- engine - x <- check_engine(x) - x$method <- get_model_info(x, x$engine) - x -} - - -get_module <- function(nm) { - arg_key <- try( - getFromNamespace( - paste0(nm, "_arg_key"), - ns = "parsnip" - ), - silent = TRUE - ) - if(inherits(arg_key, "try-error")) { - arg_key <- try( - get(paste0(nm, "_arg_key")), - silent = TRUE - ) - } - if(inherits(arg_key, "try-error")) { - stop( - "Cannot find the model code: `", - paste0(nm, "_arg_key"), - "`", call. = FALSE - ) - } - arg_key -} - - #' @export print.model_spec <- function(x, ...) 
{ cat("Model Specification (", x$mode, ")\n\n", sep = "") @@ -144,3 +115,62 @@ check_mode <- function(object, lvl) { } object } + +# ------------------------------------------------------------------------------ +# new code for revised model data structures + +get_model_spec <- function(model, mode, engine) { + m_env <- get_model_env() + env_obj <- rlang::env_names(m_env) + env_obj <- grep(model, env_obj, value = TRUE) + + res <- list() + res$libs <- + rlang::env_get(m_env, paste0(model, "_pkgs")) %>% + purrr::pluck("pkg") %>% + purrr::pluck(1) + + res$fit <- + rlang::env_get(m_env, paste0(model, "_fit")) %>% + dplyr::filter(mode == !!mode & engine == !!engine) %>% + dplyr::pull(value) %>% + purrr::pluck(1) + + pred_code <- + rlang::env_get(m_env, paste0(model, "_predict")) %>% + dplyr::filter(mode == !!mode & engine == !!engine) %>% + dplyr::select(-engine, -mode) + + res$pred <- pred_code[["value"]] + names(res$pred) <- pred_code$type + + res +} + +get_args <- function(model, engine) { + m_env <- get_model_env() + rlang::env_get(m_env, paste0(model, "_args")) %>% + dplyr::filter(engine == !!engine) %>% + dplyr::select(-engine) +} + +# to replace harmonize +deharmonize <- function(args, key) { + if (length(args) == 0) + return(args) + parsn <- tibble(parsnip = names(args), order = seq_along(args)) + merged <- + dplyr::left_join(parsn, key, by = "parsnip") %>% + dplyr::arrange(order) + # TODO correct for bad merge? + + names(args) <- merged$original + args[!is.na(merged$original)] +} + +add_methods <- function(x, engine) { + x$engine <- engine + x <- check_engine(x) + x$method <- get_model_spec(specific_model(x), x$mode, x$engine) + x +} diff --git a/R/zzz.R b/R/zzz.R index dbee4e8cc..5978630f9 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -1,59 +1,59 @@ -## nocov start - -data_obj <- ls(pattern = "_data$") -data_obj <- data_obj[data_obj != "prepare_data"] - -#' @importFrom purrr map_dfr -#' @importFrom tibble as_tibble -data_names <- - map_dfr( - data_obj, - function(x) { - module <- names(get(x)) - if (length(module) > 1) { - module <- table(module) - module <- as_tibble(module) - module$object <- x - module - } else - module <- NULL - module - } - ) - -if(any(data_names$n > 1)) { - print(data_names[data_names$n > 1,]) - stop("Some models have duplicate module names.") -} -rm(data_names) - -# ------------------------------------------------------------------------------ - -engine_objects <- ls(pattern = "_engines$") -engine_objects <- engine_objects[engine_objects != "possible_engines"] - -#' @importFrom utils stack -get_engine_info <- function(x) { - y <- x - y <- get(y) - z <- stack(y) - z$mode <- rownames(y) - z$model <- gsub("_engines$", "", x) - z$object <- x - z <- z[z$values,] - z <- z[z$mode != "unknown",] - z$values <- NULL - names(z)[1] <- "engine" - z$engine <- as.character(z$engine) - z -} - -engine_info <- - purrr::map_df( - parsnip:::engine_objects, - get_engine_info -) - -rm(engine_objects) - -## nocov end +#' ## nocov start +#' +#' data_obj <- ls(pattern = "_data$") +#' data_obj <- data_obj[data_obj != "prepare_data"] +#' +#' #' @importFrom purrr map_dfr +#' #' @importFrom tibble as_tibble +#' data_names <- +#' map_dfr( +#' data_obj, +#' function(x) { +#' module <- names(get(x)) +#' if (length(module) > 1) { +#' module <- table(module) +#' module <- as_tibble(module) +#' module$object <- x +#' module +#' } else +#' module <- NULL +#' module +#' } +#' ) +#' +#' if(any(data_names$n > 1)) { +#' print(data_names[data_names$n > 1,]) +#' stop("Some models have duplicate module names.") 
+#' } +#' rm(data_names) +#' +#' # ------------------------------------------------------------------------------ +#' +#' engine_objects <- ls(pattern = "_engines$") +#' engine_objects <- engine_objects[engine_objects != "possible_engines"] +#' +#' #' @importFrom utils stack +#' get_engine_info <- function(x) { +#' y <- x +#' y <- get(y) +#' z <- stack(y) +#' z$mode <- rownames(y) +#' z$model <- gsub("_engines$", "", x) +#' z$object <- x +#' z <- z[z$values,] +#' z <- z[z$mode != "unknown",] +#' z$values <- NULL +#' names(z)[1] <- "engine" +#' z$engine <- as.character(z$engine) +#' z +#' } +#' +#' engine_info <- +#' purrr::map_df( +#' parsnip:::engine_objects, +#' get_engine_info +#' ) +#' +#' rm(engine_objects) +#' +#' ## nocov end diff --git a/README.md b/README.md index a4253bf86..a095fe62d 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,12 @@ -[](https://travis-ci.org/tidymodels/parsnip) -[](https://codecov.io/github/tidymodels/parsnip?branch=master) - +[](https://travis-ci.org/tidymodels/parsnip) +[](https://codecov.io/github/tidymodels/parsnip?branch=master) +[](http://cran.r-project.org/web/packages/parsnip) +[](http://cran.rstudio.com/package=parsnip) +[](https://www.tidyverse.org/lifecycle/#maturing) + One issue with different functions available in R _that do the same thing_ is that they can have different interfaces and arguments. For example, to fit a random forest _classification_ model, we might have: diff --git a/_pkgdown.yml b/_pkgdown.yml index 0f4430bf4..644889998 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -53,7 +53,19 @@ reference: contents: - lending_club - wa_churn - - check_times + - check_times + + - title: Developer Tools + contents: + - set_new_model + - starts_with("set_model_") + - set_dependency + - set_fit + - set_pred + - show_model_info + - starts_with("get_") + + navbar: left: diff --git a/docs/dev/articles/articles/Classification.html b/docs/dev/articles/articles/Classification.html index f702a65ee..9f5ef1ac1 100644 --- a/docs/dev/articles/articles/Classification.html +++ b/docs/dev/articles/articles/Classification.html @@ -101,24 +101,24 @@
To demonstrate parsnip for classification models, the credit data will be used.
library(tidymodels)
#> Registered S3 method overwritten by 'xts':
#> method from
#> as.zoo.xts zoo
-#> ── Attaching packages ───────────────────────────────── tidymodels 0.0.2 ──
+#> ── Attaching packages ──────────────────────────────────────────────────────────────────────────────────── tidymodels 0.0.2 ──
#> ✔ broom 0.5.1 ✔ purrr 0.3.2
-#> ✔ dials 0.0.2 ✔ recipes 0.1.4
+#> ✔ dials 0.0.2 ✔ recipes 0.1.5
#> ✔ dplyr 0.8.0.1 ✔ rsample 0.0.4
#> ✔ infer 0.4.0 ✔ yardstick 0.0.2
-#> ── Conflicts ──────────────────────────────────── tidymodels_conflicts() ──
+#> ── Conflicts ─────────────────────────────────────────────────────────────────────────────────────── tidymodels_conflicts() ──
#> ✖ purrr::discard() masks scales::discard()
#> ✖ dplyr::filter() masks stats::filter()
#> ✖ dplyr::lag() masks stats::lag()
#> ✖ recipes::step() masks stats::step()
data(credit_data)
set.seed(7075)
data_split <- initial_split(credit_data, strata = "Status", p = 0.75)
credit_train <- training(data_split)
@@ -136,41 +136,42 @@ Classification Example
test_normalized <- bake(credit_rec, new_data = credit_test, all_predictors())
keras will be used to fit a model with 5 hidden units, using a 10% dropout rate to regularize the model. At each training iteration (aka epoch), a random 20% of the data will be used to measure the cross-entropy of the model.
set.seed(57974)
nnet_fit <-
mlp(epochs = 100, hidden_units = 5, dropout = 0.1) %>%
- # Also set engine-specific arguments:
- set_engine("keras", verbose = 0, validation_split = .20) %>%
- fit(Status ~ ., data = juice(credit_rec))
-
-nnet_fit
-#> parsnip model object
-#>
-#> Model
-#> ___________________________________________________________________________
-#> Layer (type) Output Shape Param #
-#> ===========================================================================
-#> dense_1 (Dense) (None, 5) 115
-#> ___________________________________________________________________________
-#> dense_2 (Dense) (None, 5) 30
-#> ___________________________________________________________________________
-#> dropout_1 (Dropout) (None, 5) 0
-#> ___________________________________________________________________________
-#> dense_3 (Dense) (None, 2) 12
-#> ===========================================================================
-#> Total params: 157
-#> Trainable params: 157
-#> Non-trainable params: 0
-#> ___________________________________________________________________________
+ set_mode("classification") %>%
+ # Also set engine-specific arguments:
+ set_engine("keras", verbose = 0, validation_split = .20) %>%
+ fit(Status ~ ., data = juice(credit_rec))
+
+nnet_fit
+#> parsnip model object
+#>
+#> Model
+#> ___________________________________________________________________________
+#> Layer (type) Output Shape Param #
+#> ===========================================================================
+#> dense_1 (Dense) (None, 5) 115
+#> ___________________________________________________________________________
+#> dense_2 (Dense) (None, 5) 30
+#> ___________________________________________________________________________
+#> dropout_1 (Dropout) (None, 5) 0
+#> ___________________________________________________________________________
+#> dense_3 (Dense) (None, 2) 12
+#> ===========================================================================
+#> Total params: 157
+#> Trainable params: 157
+#> Non-trainable params: 0
+#> ___________________________________________________________________________
In parsnip, the predict() function can be used:
test_results <-
credit_test %>%
select(Status) %>%
as_tibble() %>%
mutate(
nnet_class = predict(nnet_fit, new_data = test_normalized) %>%
pull(.pred_class),
nnet_prob = predict(nnet_fit, new_data = test_normalized, type = "prob") %>%
pull(.pred_good)
)
@@ -183,12 +184,12 @@ Classification Example
#> # A tibble: 1 x 3
#> .metric .estimator .estimate
#> <chr> <chr> <dbl>
-#> 1 accuracy binary 0.801
+#> 1 accuracy binary 0.805
test_results %>% conf_mat(truth = Status, nnet_class)
#> Truth
#> Prediction bad good
-#> bad 187 95
-#> good 126 705
+#> bad 185 89
+#> good 128 711
check_mod_val.Rd
These functions are similar to constructors and can be used to validate that there are no conflicts with the underlying model structures used by the package.
+ +pred_types + +check_mod_val(model, new = FALSE, existence = FALSE) + +get_model_env() + +check_mode_val(mode) + +check_engine_val(eng) + +check_arg_val(arg) + +check_submodels_val(has_submodel) + +check_func_val(func) + +check_fit_info(fit_obj) + +check_pred_info(pred_obj, type) + +check_pkg_val(pkg) + +set_new_model(model) + +set_model_mode(model, mode) + +set_model_engine(model, mode, eng) + +set_model_arg(model, eng, parsnip, original, func, has_submodel) + +set_dependency(model, eng, pkg) + +get_dependency(model) + +set_fit(model, mode, eng, value) + +get_fit(model) + +set_pred(model, mode, eng, type, value) + +get_pred_type(model, type) + +show_model_info(model) + +get_from_env(items)+ +
| Argument | Description |
|---|---|
| model | A single character string for the model type (e.g. "rand_forest"). |
| new | A single logical to check that the model being registered has not already been registered. |
| existence | A single logical to check that the model has already been registered. |
| mode | A single character string for the model mode (e.g. "regression"). |
| eng | A single character string for the model engine. |
| arg | A single character string for the model argument name. |
| has_submodel | A single logical for whether the argument can make predictions on multiple submodels at once. |
| func | A named character vector that describes how to call a function. |
| fit_obj | A list with elements interface, protect, func, and defaults. |
| pred_obj | A list with elements pre, post, func, and args. |
| type | A single character value for the type of prediction. Possible values are: 'raw', 'numeric', 'class', 'link', 'prob', 'conf_int', 'pred_int', 'quantile'. |
| pkg | An optional character string for a package name. |
| parsnip | A single character string for the "harmonized" argument name that parsnip exposes. |
| original | A single character string for the argument name that the underlying model function uses. |
| value | A list that conforms to the fit_obj or pred_obj description, depending on context. |
| items | A character string of objects in the model environment. |
An object of class character of length 8.
These functions are available for users to add their own models or engines (in a package or otherwise) so that they can be accessed using parsnip. These are more thoroughly documented on the package web site (see references below).
In short, parsnip stores an environment object that contains all of the information and code about how models are used (e.g. fitting, predicting, etc). These functions can be used to add models to that environment, as well as helper functions that can be used to make sure that the model data is in the right format.
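As an illustrative sketch of that registration flow (the model name my_model and the stats::lm() engine are invented for this example, not part of the package), adding a model mirrors the set_fit()/set_pred() calls made for svm_rbf earlier in this changeset:

library(parsnip)

# Register a new model type, one mode, and one engine:
set_new_model("my_model")
set_model_mode("my_model", "regression")
set_model_engine("my_model", "regression", "stats")
set_dependency("my_model", "stats", "stats")

# Describe how to fit it; the value list matches the fit_obj format above:
set_fit(
  model = "my_model",
  eng = "stats",
  mode = "regression",
  value = list(
    interface = "formula",
    protect = c("formula", "data"),
    func = c(pkg = "stats", fun = "lm"),
    defaults = list()
  )
)

# Inspect what was registered:
show_model_info("my_model")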
"Making a parsnip model from scratch" +https://tidymodels.github.io/parsnip/articles/articles/Scratch.html
+ + ++# Show the information about a model: +show_model_info("rand_forest")#> Information for `rand_forest` +#> modes: unknown, classification, regression +#> +#> engines: +#> classification: randomForest, ranger, spark +#> regression: randomForest, ranger, spark +#> +#> arguments: +#> ranger: +#> mtry --> mtry +#> trees --> num.trees +#> min_n --> min.node.size +#> randomForest: +#> mtry --> mtry +#> trees --> ntree +#> min_n --> nodesize +#> spark: +#> mtry --> feature_subset_strategy +#> trees --> num_trees +#> min_n --> min_instances_per_node +#> +#> fit modules: +#> engine mode +#> ranger classification +#> ranger regression +#> randomForest classification +#> randomForest regression +#> spark classification +#> spark regression +#> +#> prediction modules: +#> mode engine methods +#> classification randomForest class, prob, raw +#> classification ranger class, conf_int, prob, raw +#> classification spark class, prob +#> regression randomForest numeric, raw +#> regression ranger conf_int, numeric, raw +#> regression spark numeric +#>#> [1] "boost_tree" "boost_tree_args" +#> [3] "boost_tree_fit" "boost_tree_modes" +#> [5] "boost_tree_pkgs" "boost_tree_predict" +#> [7] "decision_tree" "decision_tree_args" +#> [9] "decision_tree_fit" "decision_tree_modes" +#> [11] "decision_tree_pkgs" "decision_tree_predict" +#> [13] "linear_reg" "linear_reg_args" +#> [15] "linear_reg_fit" "linear_reg_modes" +#> [17] "linear_reg_pkgs" "linear_reg_predict" +#> [19] "logistic_reg" "logistic_reg_args" +#> [21] "logistic_reg_fit" "logistic_reg_modes" +#> [23] "logistic_reg_pkgs" "logistic_reg_predict" +#> [25] "mars" "mars_args" +#> [27] "mars_fit" "mars_modes" +#> [29] "mars_pkgs" "mars_predict" +#> [31] "mlp" "mlp_args" +#> [33] "mlp_fit" "mlp_modes" +#> [35] "mlp_pkgs" "mlp_predict" +#> [37] "models" "modes" +#> [39] "multinom_reg" "multinom_reg_args" +#> [41] "multinom_reg_fit" "multinom_reg_modes" +#> [43] "multinom_reg_pkgs" "multinom_reg_predict" +#> [45] "nearest_neighbor" "nearest_neighbor_args" +#> [47] "nearest_neighbor_fit" "nearest_neighbor_modes" +#> [49] "nearest_neighbor_pkgs" "nearest_neighbor_predict" +#> [51] "null_model" "null_model_args" +#> [53] "null_model_fit" "null_model_modes" +#> [55] "null_model_pkgs" "null_model_predict" +#> [57] "rand_forest" "rand_forest_args" +#> [59] "rand_forest_fit" "rand_forest_modes" +#> [61] "rand_forest_pkgs" "rand_forest_predict" +#> [63] "surv_reg" "surv_reg_args" +#> [65] "surv_reg_fit" "surv_reg_modes" +#> [67] "surv_reg_pkgs" "surv_reg_predict" +#> [69] "svm_poly" "svm_poly_args" +#> [71] "svm_poly_fit" "svm_poly_modes" +#> [73] "svm_poly_pkgs" "svm_poly_predict" +#> [75] "svm_rbf" "svm_rbf_args" +#> [77] "svm_rbf_fit" "svm_rbf_modes" +#> [79] "svm_rbf_pkgs" "svm_rbf_predict"+
#> Classes ‘tbl_df’, ‘tbl’ and 'data.frame': 13626 obs. of 25 variables: +#> Classes ‘tbl_df’, ‘tbl’ and 'data.frame': 13626 obs. of 25 variables: #> $ package : chr "A3" "abbyyR" "abc" "abc.data" ... #> $ authors : int 1 1 1 1 5 3 2 1 4 6 ... #> $ imports : num 0 6 0 0 3 1 0 4 0 7 ... diff --git a/docs/dev/reference/decision_tree.html b/docs/dev/reference/decision_tree.html index bbb7cf34b..22825643e 100644 --- a/docs/dev/reference/decision_tree.html +++ b/docs/dev/reference/decision_tree.html @@ -159,7 +159,7 @@@@ -168,7 +168,7 @@General Interface for Decision Tree Models
time that the model is fit. Other options and arguments can be set using set_engine(). If left to their defaults here (NULL), the values are taken from the underlying model functions. If parameters need to be modified, update() can be used in lieu of recreating the object from scratch.
tree_depth = NULL, min_n = NULL) # S3 method for decision_tree -update(object, cost_complexity = NULL, +update(object, cost_complexity = NULL, tree_depth = NULL, min_n = NULL, fresh = FALSE, ...)Arguments
@@ -205,15 +205,15 @@Arg
... -+ Not used for
update().Not used for
update().Details
-The model can be created using the
fit()function using the +The model can be created using the
fit()function using the following engines:-
- +
R:
"rpart"or"C5.0"(classification only)R:
"rpart"(the default) or"C5.0"(classification only)Spark:
"spark"Note that, for
rpartmodels, butcost_complexityand @@ -226,13 +226,13 @@Note
For models created using the spark engine, there are several differences to consider. First, only the formula -interface to via
@@ -269,7 +269,7 @@fit()is available; usingfit_xy()will +interface to viafit()is available; usingfit_xy()will generate an error. Second, the predictions will always be in a spark table format. The names will be the same as documented but without the dots. Third, there is no equivalent to factor columns in spark tables so class predictions are returned as character columns. Fourth, to retain the model object for a new -R session (viasave()), themodel$fitelement of theparsnip+R session (viasave()), themodel$fitelement of theparsnipobject should be serialized viaml_save(object$fit)and separately saved to disk. In a new session, the object can be reloaded and reattached to theparsnipobject.See also
- +Examples
@@ -288,12 +288,12 @@Examp #> Main Arguments: #> cost_complexity = 10 #> min_n = 3 -#>
#> Decision Tree Model Specification (unknown) +#>#> Decision Tree Model Specification (unknown) #> #> Main Arguments: #> cost_complexity = 1 #> min_n = 3 -#>#> Decision Tree Model Specification (unknown) +#>#> Decision Tree Model Specification (unknown) #> #> Main Arguments: #> cost_complexity = 1 diff --git a/docs/dev/reference/descriptors.html b/docs/dev/reference/descriptors.html index cf0396989..22694b7e7 100644 --- a/docs/dev/reference/descriptors.html +++ b/docs/dev/reference/descriptors.html @@ -134,7 +134,7 @@diff --git a/docs/dev/reference/keras_mlp.html b/docs/dev/reference/keras_mlp.html index 51ffd154c..ed96c15ed 100644 --- a/docs/dev/reference/keras_mlp.html +++ b/docs/dev/reference/keras_mlp.html @@ -139,7 +139,7 @@Data Set Characteristics Available when Fitting Models
-When using the
fit()functions there are some +When using the
fit()functions there are some variables that will be available for use in arguments. For example, if the user would like to choose an argument value based on the current number of rows in a data set, the.obs()@@ -174,7 +174,7 @@Details
.y(): The known outcomes returned in the format given. Either a vector, matrix, or data frame.
.dat(): A data frame containing all of the predictors and the -outcomes. Iffit_xy()was used, the outcomes are attached as the +outcomes. Iffit_xy()was used, the outcomes are attached as the column,..y.For example, if you use the model formula
Sepal.Width ~ .with theiris@@ -200,7 +200,7 @@Details
To use these in a model fit, pass them to a model specification. The evaluation is delayed until the time when the -model is run via
fit()(and the variables listed above are available). +model is run viafit()(and the variables listed above are available). For example:data("lending_club") diff --git a/docs/dev/reference/fit.html b/docs/dev/reference/fit.html index 59c76b4e1..8bd1bbf94 100644 --- a/docs/dev/reference/fit.html +++ b/docs/dev/reference/fit.html @@ -132,18 +132,18 @@Fit a Model Specification to a Dataset
-
fit()andfit_xy()take a model specification, translate the required +
fit()andfit_xy()take a model specification, translate the required code by substituting arguments, and execute the model fit routine.# S3 method for model_spec -fit(object, formula = NULL, data = NULL, +fit(object, formula = NULL, data = NULL, control = fit_control(), ...) # S3 method for model_spec -fit_xy(object, x = NULL, y = NULL, +fit_xy(object, x = NULL, y = NULL, control = fit_control(), ...)Arguments
@@ -206,21 +206,24 @@Value
Details
-
fit()andfit_xy()substitute the current arguments in the model +
fit()andfit_xy()substitute the current arguments in the model specification into the computational engine's code, checks them for validity, then fits the model using the data and the engine-specific code. Different model functions have different interfaces (e.g. formula orx/y) and these functions translate -between the interface used whenfit()orfit_xy()were invoked and the one +between the interface used whenfit()orfit_xy()were invoked and the one required by the underlying model.When possible, these functions attempt to avoid making copies of the data. For example, if the underlying model uses a formula and -
+fit()is invoked, the original data are references +fit()is invoked, the original data are references when the model is fit. However, if the underlying model uses something else, such asx/y, the formula is evaluated and the data are converted to the required format. In this case, any calls in the resulting model objects reference the temporary objects used to fit the model.If the model engine has not been set, the model's default engine will be used +(as discussed on each model page). If the
verbosityoption of +fit_control()is greater than zero, a warning will be produced.See also
@@ -231,28 +234,26 @@Examp
# Although `glm()` only has a formula interface, different # methods for specifying the model can be used -library(dplyr)#> +library(dplyr)#> #>#> #> #>#> #> #>#> #> -#>data("lending_club") lr_mod <- logistic_reg() using_formula <- lr_mod %>% set_engine("glm") %>% - fit(Class ~ funded_amnt + int_rate, data = lending_club) + fit(Class ~ funded_amnt + int_rate, data = lending_club) using_xy <- lr_mod %>% set_engine("glm") %>% - fit_xy(x = lending_club[, c("funded_amnt", "int_rate")], + fit_xy(x = lending_club[, c("funded_amnt", "int_rate")], y = lending_club$Class) using_formula#> parsnip model object diff --git a/docs/dev/reference/fit_control.html b/docs/dev/reference/fit_control.html index aa73f18a4..d93ee2181 100644 --- a/docs/dev/reference/fit_control.html +++ b/docs/dev/reference/fit_control.html @@ -131,7 +131,7 @@@@ -339,6 +359,7 @@Control the fit function
-@@ -154,7 +154,7 @@Options can be passed to the
fit()function that control the output and +Options can be passed to the
fit()function that control the output and computationsArg
diff --git a/docs/dev/reference/get_model_env.html b/docs/dev/reference/get_model_env.html new file mode 100644 index 000000000..0c2b8d9bb --- /dev/null +++ b/docs/dev/reference/get_model_env.html @@ -0,0 +1,191 @@ + + + + + + + + + catch A logical where a value of
TRUEwill evaluate -the model inside oftry(, silent = TRUE). If the model fails, +the model inside oftry(, silent = TRUE). If the model fails, an object is still returned (without an error) that inherits the class "try-error".Tools to Register Models — get_model_env • parsnip + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +++ + + + + + diff --git a/docs/dev/reference/index.html b/docs/dev/reference/index.html index 3252d3df9..0540feb14 100644 --- a/docs/dev/reference/index.html +++ b/docs/dev/reference/index.html @@ -226,6 +226,12 @@+ + + + + +++ + +++ +++ +Tools to Register Models
+ ++get_model_env.Rd+ ++ +Tools to Register Models
+ +get_model_env() + +set_new_model(model) + +set_model_mode(model, mode) + +set_model_engine(model, mode, eng) + +set_model_arg(model, eng, parsnip, original, func, has_submodel) + +set_dependency(model, eng, pkg) + +get_dependency(model) + +set_fit(model, mode, eng, value) + +get_fit(model) + +set_pred(model, mode, eng, type, value) + +get_pred_type(model, type) + +show_model_info(model) + +get_from_env(items)+ + +add_rowindex() + +
+ Add a column of row numbers to a data frame
+ + +@@ -329,6 +335,20 @@ <
Execution Time Data
+ ++ +Developer Tools
+ ++ + + ++
pred_typescheck_mod_val()get_model_env()check_mode_val()check_engine_val()check_arg_val()check_submodels_val()check_func_val()check_fit_info()check_pred_info()check_pkg_val()set_new_model()set_model_mode()set_model_engine()set_model_arg()set_dependency()get_dependency()set_fit()get_fit()set_pred()get_pred_type()show_model_info()get_from_env()+ Tools to Register Models
Contents
Models Infrastructure Data +Developer Tools Simple interface to MLP models via keras
keras_mlp(x, y, hidden_units = 5, decay = 0, dropout = 0, - epochs = 20, act = "softmax", seeds = sample.int(10^5, size = 3), + epochs = 20, act = "softmax", seeds = sample.int(10^5, size = 3), ...)
#> Classes ‘tbl_df’, ‘tbl’ and 'data.frame': 9857 obs. of 23 variables: +#> Classes ‘tbl_df’, ‘tbl’ and 'data.frame': 9857 obs. of 23 variables: #> $ funded_amnt : int 16100 32000 10000 16800 3500 10000 11000 15000 6000 20000 ... #> $ term : Factor w/ 2 levels "term_36","term_60": 1 2 1 2 1 1 1 1 1 2 ... #> $ int_rate : num 13.99 11.99 16.29 13.67 7.39 ... diff --git a/docs/dev/reference/linear_reg.html b/docs/dev/reference/linear_reg.html index 35a3375f8..31dfc2fbf 100644 --- a/docs/dev/reference/linear_reg.html +++ b/docs/dev/reference/linear_reg.html @@ -155,7 +155,7 @@@@ -163,7 +163,7 @@General Interface for Linear Regression Models
time that the model is fit. Other options and arguments can be set using set_engine(). If left to their defaults here (NULL), the values are taken from the underlying model functions. If parameters need to be modified, update() can be used in lieu of recreating the object from scratch.
linear_reg(mode = "regression", penalty = NULL, mixture = NULL) # S3 method for linear_reg -update(object, penalty = NULL, mixture = NULL, +update(object, penalty = NULL, mixture = NULL, fresh = FALSE, ...)Arguments
@@ -200,7 +200,7 @@Arg
@@ -209,9 +209,9 @@ ... -+ Not used for
update().Not used for
update().Details
The data given to the function are not saved and are only used to determine the mode of the model. For
-linear_reg(), the mode will always be "regression".The model can be created using the
fit()function using the +The model can be created using the
fit()function using the following engines:-
- +
R:
"lm"or"glmnet"R:
"lm"(the default) or"glmnet"Stan:
"stan"Spark:
"spark"- @@ -221,13 +221,13 @@
keras:
"keras"Note
For models created using the spark engine, there are several differences to consider. First, only the formula -interface to via
@@ -265,8 +265,8 @@fit()is available; usingfit_xy()will +interface to viafit()is available; usingfit_xy()will generate an error. Second, the predictions will always be in a spark table format. The names will be the same as documented but without the dots. Third, there is no equivalent to factor columns in spark tables so class predictions are returned as character columns. Fourth, to retain the model object for a new -R session (viasave()), themodel$fitelement of theparsnip+R session (viasave()), themodel$fitelement of theparsnipobject should be serialized viaml_save(object$fit)and separately saved to disk. In a new session, the object can be reloaded and reattached to theparsnipobject.See also
- +Examples
@@ -296,12 +296,12 @@Examp #> Main Arguments: #> penalty = 10 #> mixture = 0.1 -#>
#> Linear Regression Model Specification (regression) +#>#> Linear Regression Model Specification (regression) #> #> Main Arguments: #> penalty = 1 #> mixture = 0.1 -#>#> Linear Regression Model Specification (regression) +#>#> Linear Regression Model Specification (regression) #> #> Main Arguments: #> penalty = 1 diff --git a/docs/dev/reference/logistic_reg.html b/docs/dev/reference/logistic_reg.html index 7dde682bf..ca4332916 100644 --- a/docs/dev/reference/logistic_reg.html +++ b/docs/dev/reference/logistic_reg.html @@ -155,7 +155,7 @@@@ -163,7 +163,7 @@General Interface for Logistic Regression Models
time that the model is fit. Other options and arguments can be set using set_engine(). If left to their defaults here (NULL), the values are taken from the underlying model functions. If parameters need to be modified, update() can be used in lieu of recreating the object from scratch.
logistic_reg(mode = "classification", penalty = NULL, mixture = NULL) # S3 method for logistic_reg -update(object, penalty = NULL, mixture = NULL, +update(object, penalty = NULL, mixture = NULL, fresh = FALSE, ...)Arguments
@@ -200,16 +200,16 @@Arg
... -+ Not used for
update().Not used for
update().Details
For
-logistic_reg(), the mode will always be "classification".The model can be created using the
fit()function using the +The model can be created using the
fit()function using the following engines:
R: "glm" or "glmnet"
R: "glm" (the default) or "glmnet"
Stan: "stan"
Spark: "spark"
keras: "keras"
For models created using the spark engine, there are
several differences to consider. First, only the formula
interface via fit() is available; using fit_xy() will
generate an error. Second, the predictions will always be in a
spark table format. The names will be the same as documented but
without the dots. Third, there is no equivalent to factor
columns in spark tables so class predictions are returned as
character columns. Fourth, to retain the model object for a new
R session (via save()), the model$fit element of the parsnip
object should be serialized via ml_save(object$fit) and
separately saved to disk. In a new session, the object can be
reloaded and reattached to the parsnip object.
set_engine(). If left to their defaults
here (NULL), the values are taken from the underlying model
functions. If parameters need to be modified, update() can be used
in lieu of recreating the object from scratch.
Not used for update().
The model can be created using the fit() function with the
following engines:
R: "earth"
R: "earth" (the default)
Not used for update().
set_engine(). If left to their defaults
here (see above), the values are taken from the underlying model
functions. One exception is hidden_units when nnet::nnet is used; that
function's size argument has no default so a value of 5 units will be
used. Also, unless otherwise specified, the linout argument to
nnet::nnet() will be set to TRUE when a regression model is created.
If parameters need to be modified, update() can be used
in lieu of recreating the object from scratch.
The model can be created using the fit() function with the
following engines:
R: "nnet"
R: "nnet" (the default)
keras: "keras"
An error is thrown if both penalty and dropout are specified for
@@ -271,7 +271,7 @@
set_engine(). If left to their defaults
here (NULL), the values are taken from the underlying model
functions. If parameters need to be modified, update() can be used
in lieu of recreating the object from scratch.
A single integer for the number of neighbors
to consider (often called k). For kknn, a value of 5
is used if neighbors is not specified.
The model can be created using the fit() function with the
following engines:
R: "kknn"
R: "kknn" (the default)
kknn::train.kknn(formula = missing_arg(), data = missing_arg(), ks = 5)
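To see the ks = 5 default being filled in, translate() can be applied to a specification that leaves neighbors unset (a sketch; in this version the mode must be declared before translation):

library(parsnip)

nearest_neighbor() %>%
  set_engine("kknn") %>%
  set_mode("regression") %>%
  translate()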
The model can be created using the fit() function with the
following engines:
R: "parsnip"
| type | A single character value or |
|
|---|---|---|
| ... | -Not used for |
+ Not used for |
The model can be created using the fit() function with the
following engines:
R: "ranger" or "randomForest"
R: "ranger" (the default) or "randomForest"
Spark: "spark"
For models created using the spark engine, there are
several differences to consider. First, only the formula
interface via fit() is available; using fit_xy() will
generate an error. Second, the predictions will always be in a
spark table format. The names will be the same as documented but
without the dots. Third, there is no equivalent to factor
@@ -274,7 +274,7 @@
These objects are imported from other packages. Follow the links below to see their documentation.
%>%
surv_reg(mode = "regression", dist = NULL) # S3 method for surv_reg -update(object, dist = NULL, fresh = FALSE, ...)+update(object, dist = NULL, fresh = FALSE, ...)
| ... | Not used for update(). |
surv_reg(), the
mode will always be "regression".
Since survival models typically involve censoring (and require the use of
survival::Surv() objects), the fit() function will require that the
survival model be specified via the formula interface.
Also, for the flexsurv::flexsurvfit engine, the typical
strata function cannot be used. To achieve the same effect,
the extra parameter roles can be used (as described above).
For surv_reg(), the mode will always be "regression".
The model can be created using the fit() function with the
following engines:
R: "flexsurv", "survreg"
R: "flexsurv", "survival" (the default)
survival
survival::survreg(formula = missing_arg(), data = missing_arg(),
weights = missing_arg(), model = TRUE)
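Since the formula interface is required, a fit might be sketched as below (the survival package's lung data, with its time, status, age, and sex columns, is assumed here purely for illustration):

library(parsnip)
library(survival)

sr_fit <- surv_reg(dist = "weibull") %>%
  set_engine("survival") %>%
  fit(Surv(time, status) ~ age + sex, data = lung)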
@@ -233,7 +233,7 @@ R
See also
-
+
Examples
@@ -249,7 +249,7 @@ Examp
#>
#> Main Arguments:
#> dist = weibull
-#>
set_engine(). If left to their defaults
here (NULL), the values are taken from the underlying model
functions. If parameters need to be modified, update() can be used
in lieu of recreating the object from scratch.
Not used for update().
The model can be created using the fit() function with the
following engines:
R: "kernlab"
R: "kernlab" (the default)
set_engine(). If left to their defaults
here (NULL), the values are taken from the underlying model
functions. If parameters need to be modified, update() can be used
in lieu of recreating the object from scratch.
Not used for update().
The model can be created using the fit() function with the
following engines:
R: "kernlab"
R: "kernlab" (the default)
translate() produces a template call that lacks the specific
argument values (such as data, etc). These are filled in once
fit() is called with the specifics of the data for the model.
The call may also include varying arguments if these are in
the specification.
It does contain the resolved argument names that are specific to diff --git a/docs/dev/reference/varying_args.model_spec.html b/docs/dev/reference/varying_args.model_spec.html index 22a240d23..526546ac9 100644 --- a/docs/dev/reference/varying_args.model_spec.html +++ b/docs/dev/reference/varying_args.model_spec.html @@ -132,20 +132,20 @@
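For example, translating a glmnet linear regression shows the template call with penalty and mixture substituted and the data left as placeholders (a sketch; the argument values are arbitrary):

library(parsnip)

linear_reg(penalty = 0.01, mixture = 0.5) %>%
  set_engine("glmnet") %>%
  translate()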
varying_args() takes a model specification or a recipe and returns a tibble
of information on all possible varying arguments and whether or not they
are actually varying.
# S3 method for model_spec
varying_args(object, full = TRUE, ...)

# S3 method for recipe
varying_args(object, full = TRUE, ...)

# S3 method for step
varying_args(object, full = TRUE, ...)